## Load the model

In [1]:
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier mapping 28x28 inputs to 10 output scores.

    Layout: Flatten -> Linear(784, 512) -> ReLU -> Linear(512, 512) -> ReLU
    -> Linear(512, 10) -> ReLU.

    The attribute names (`flatten`, `linear_relu_stack`) and the Sequential
    indices 0-5 determine the state_dict keys, so they must not change or
    previously saved weights will no longer load.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 512
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden_units), nn.ReLU(),
            nn.Linear(hidden_units, hidden_units), nn.ReLU(),
            # NOTE(review): the trailing ReLU clamps the output scores to be
            # non-negative, which is unusual for logits. It is kept because
            # the saved weights were presumably trained with it.
            nn.Linear(hidden_units, 10), nn.ReLU(),
        )

    def forward(self, x):
        # Flatten every dim after dim 0 into one feature vector, then apply the MLP.
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)

In [3]:
# Loading a saved state_dict also requires instantiating the model class first:
# the file only supplies tensors that are copied into an existing module.
model = NeuralNetwork()
# map_location="cpu" lets a checkpoint saved from a GPU model load on a
# CPU-only machine; without it torch.load tries to restore tensors to CUDA.
model.load_state_dict(torch.load('data/model.pth', map_location="cpu"))
# eval() must be called: it switches layers such as dropout and batch
# normalization to inference behavior. This network has neither, but calling
# it before export/inference is still the correct habit.
model.eval()


NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)

## Model Inference

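+ ONNX (Open Neural Network Exchange) is an open format for machine-learning models; ONNX Runtime lets you train a model once and then accelerate its inference on any hardware or cloud.
+ In other words, exporting a model to ONNX makes it easy to run inference from many languages and platforms, not just from Python/PyTorch.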

In [4]:
# Export the PyTorch model to ONNX. The exporter runs the model on a dummy
# input to record the graph, so only the input's shape/dtype matter.
input_image = torch.zeros((1, 28, 28))
onnx_model = 'data/model.onnx'
onnx.export(
    model,
    input_image,
    onnx_model,
    input_names=["input"],
    output_names=["logits"],
    # Without dynamic_axes the recorded graph hard-codes a batch size of 1;
    # marking dim 0 as dynamic lets the ONNX model score any number of images.
    dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
)

In [5]:
# FashionMNIST test split, stored under data/ (downloaded if missing);
# each image is converted to a tensor by ToTensor.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Human-readable names for label indices 0-9.
classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

# First test sample to run inference on: (image tensor, label index).
x, y = test_data[0]

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [20]:
# Create an ONNX Runtime inference session from the exported model file.
# Since onnxruntime 1.9, builds with several execution providers (e.g. the GPU
# package) raise an error unless `providers` is passed explicitly; the CPU
# provider is always available and is sufficient for this small MLP.
session = onnxruntime.InferenceSession(
    onnx_model, None, providers=["CPUExecutionProvider"]
)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# ONNX Runtime consumes NumPy arrays, so convert the sample tensor first.
result = session.run([output_name], {input_name: x.numpy()})
# result holds one array per requested output. result[0][0] is the score
# vector of the single image in the batch, and argmax(0) picks the index of the
# highest-scoring class.
predicted, actual = classes[result[0][0].argmax(0)], classes[y]

print(f'Predicted: "{predicted}", Actual: "{actual}"')


Predicted: "Ankle boot", Actual: "Ankle boot"
