Skip to content

Commit

Permalink
mnist example
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav274 committed Jun 5, 2022
1 parent 42a0eb4 commit 6e320a9
Show file tree
Hide file tree
Showing 3 changed files with 424 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tutorials/apps/mnist/eva_mnist_udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pandas as pd
from torch import Tensor

from eva.udfs.pytorch_abstract_udf import PytorchAbstractUDF
from eva.models.catalog.frame_info import FrameInfo
from eva.models.catalog.properties import ColorSpace
from mnist_raw_script import mnist
from torchvision.transforms import Compose, ToTensor, Normalize


class MnistCNN(PytorchAbstractUDF):

@property
def name(self) -> str:
return 'MnistCNN'

def __init__(self):
super().__init__()
self.model = mnist()
self.mode.eval()

@property
def input_format(self):
return FrameInfo(1, 28, 28, ColorSpace.RGB)

@property
def labels(self):
return list([str(num) for num in range(10)])

def transforms(self) -> Compose:
return Compose([
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

def _get_predictions(self, frames: Tensor) -> pd.DataFrame:
outcome = pd.DataFrame()
outcome['label'] = self.model(frames)
return outcome
44 changes: 44 additions & 0 deletions tutorials/apps/mnist/mnist_raw_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch.nn as nn
from collections import OrderedDict
import torch.utils.model_zoo as model_zoo

model_urls = {
'mnist': 'http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/mnist-b07bb66b.pth' # noqa
}


class MLP(nn.Module):
def __init__(self, input_dims, n_hiddens, n_class):
super(MLP, self).__init__()
assert isinstance(input_dims, int), 'Please provide int for input_dims'
self.input_dims = input_dims
current_dims = input_dims
layers = OrderedDict()

if isinstance(n_hiddens, int):
n_hiddens = [n_hiddens]
else:
n_hiddens = list(n_hiddens)
for i, n_hidden in enumerate(n_hiddens):
layers['fc{}'.format(i + 1)] = nn.Linear(current_dims, n_hidden)
layers['relu{}'.format(i + 1)] = nn.ReLU()
layers['drop{}'.format(i + 1)] = nn.Dropout(0.2)
current_dims = n_hidden
layers['out'] = nn.Linear(current_dims, n_class)

self.model = nn.Sequential(layers)

def forward(self, input):
input = input.view(input.size(0), -1)
assert input.size(1) == self.input_dims
return self.model.forward(input)


def mnist(input_dims=784, n_hiddens=[256, 256], n_class=10, pretrained=None):
model = MLP(input_dims, n_hiddens, n_class)
if pretrained is not None:
m = model_zoo.load_url(model_urls['mnist'])
state_dict = m.state_dict() if isinstance(m, nn.Module) else m
assert isinstance(state_dict, (dict, OrderedDict)), type(state_dict)
model.load_state_dict(state_dict)
return model
Loading

0 comments on commit 6e320a9

Please sign in to comment.