In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
from PIL import Image

In [2]:
### Define the model network

class ConvNet(nn.Module):
    """Small convolutional classifier that returns log-probabilities over 10 classes.

    Layout: two 3x3 convolutions (1 -> 16 -> 32 channels, stride 1, no padding),
    a 2x2 max-pool, a 64-unit hidden linear layer and a 10-way output layer.
    The input size of ``fc1`` (4608) fixes the expected input to single-channel
    28x28 images.
    """

    def __init__(self):
        super().__init__()
        self.cn1 = nn.Conv2d(1, 16, 3, 1)
        self.cn2 = nn.Conv2d(16, 32, 3, 1)

        # dp1 runs on the 4-D feature maps, where channel-wise Dropout2d is intended.
        self.dp1 = nn.Dropout2d(0.10)
        # dp2 runs on the flattened (N, 64) activations. Dropout2d on 2-D input is
        # deprecated in recent PyTorch (it warns and is slated to become an error),
        # so plain element-wise Dropout is used here. Dropout layers hold no
        # parameters, so existing checkpoints still load unchanged.
        self.dp2 = nn.Dropout(0.25)

        self.fc1 = nn.Linear(4608, 64)  # 4608 = 12 * 12 * 32 after the two convs and the pool
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Run the network on a batch shaped (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 10) holding log-softmax scores along dim 1.
        """
        x = self.cn1(x)
        x = F.relu(x)

        x = self.cn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dp1(x)

        # Keep the batch dimension and flatten everything else for the linear layers.
        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.dp2(x)

        x = self.fc2(x)

        op = F.log_softmax(x, dim=1)

        return op

In [3]:
# Instantiate the network; its weights are replaced by the checkpoint loaded in the next cell.
model = ConvNet()

In [4]:
### Load the trained weights into the model

PATH_TO_MODEL = "./convnet.pth"

# map_location="cpu" places the stored tensors on the CPU when the file is read.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
trained_state = torch.load(PATH_TO_MODEL, map_location="cpu")
model.load_state_dict(trained_state)

<All keys matched successfully>

In [5]:
### Prepare the model for export

# Evaluation mode: the dropout layers become no-ops.
model.eval()

# Freeze every parameter; no gradients are needed for export or inference.
for param in model.parameters():
    param.requires_grad = False

### **모델 스크립팅**
- 모델에 더미 입력을 제공하지 않아도 바로 토치스크립트 코드로 변환 가능

In [6]:
# Compile the module to TorchScript; note that no example input is passed here.
scripted_model = torch.jit.script(model)

In [7]:
# Display the TorchScript IR graph of the scripted model.
scripted_model.graph

graph(%self : __torch__.ConvNet,
      %x.1 : Tensor):
  %51 : Function = prim::Constant[name="log_softmax"]()
  %49 : int = prim::Constant[value=3]()
  %33 : int = prim::Constant[value=-1]()
  %26 : Function = prim::Constant[name="_max_pool2d"]()
  %20 : int = prim::Constant[value=0]()
  %19 : None = prim::Constant()
  %7 : Function = prim::Constant[name="relu"]()
  %6 : bool = prim::Constant[value=0]()
  %17 : int = prim::Constant[value=2]() # C:\Users\doroc\AppData\Local\Temp\ipykernel_26416\2936341797.py:21:28
  %32 : int = prim::Constant[value=1]() # C:\Users\doroc\AppData\Local\Temp\ipykernel_26416\2936341797.py:24:29
  %2 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="cn1"](%self)
  %x.3 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # C:\Users\doroc\AppData\Local\Temp\ipykernel_26416\2936341797.py:16:12
  %x.5 : Tensor = prim::CallFunction(%7, %x.3, %6) # C:\Users\doroc\AppData\Local\Temp\ipykernel_26416\2936341797.py:17:12
  %9 : __torch__.torch.nn.modul

In [8]:
# Print the TorchScript source generated for forward().
print(scripted_model.code)

def forward(self,
    x: Tensor) -> Tensor:
  _0 = __torch__.torch.nn.functional._max_pool2d
  _1 = __torch__.torch.nn.functional.log_softmax
  x0 = (self.cn1).forward(x, )
  x1 = __torch__.torch.nn.functional.relu(x0, False, )
  x2 = (self.cn2).forward(x1, )
  x3 = __torch__.torch.nn.functional.relu(x2, False, )
  x4 = _0(x3, [2, 2], None, [0, 0], [1, 1], False, False, )
  x5 = (self.dp1).forward(x4, )
  x6 = torch.flatten(x5, 1, -1)
  x7 = (self.fc1).forward(x6, )
  x8 = __torch__.torch.nn.functional.relu(x7, False, )
  x9 = (self.dp2).forward(x8, )
  x10 = (self.fc2).forward(x9, )
  return _1(x10, 1, 3, None, )



In [9]:
# Round-trip the scripted model through disk: export it, then load it back.
scripted_path = 'scripted_convnet.pt'
torch.jit.save(scripted_model, scripted_path)  # export the scripted model

loaded_scripted_model = torch.jit.load(scripted_path)  # reload the model

### **스크립팅한 모델을 사용한 추론**

In [10]:
### Prepare the sample image

# Open the sample image with PIL; it is converted to a tensor in a later cell.
image = Image.open("./digit_image.jpg")

In [11]:
def image_to_tensor(image, size=(28, 28), mean=(0.1302,), std=(0.3069,)):
    """Convert an image into a normalized single-channel tensor for the network.

    The image is converted to grayscale, resized, turned into a tensor and then
    normalized channel-wise as ``(x - mean) / std``.

    Args:
        image: Input image accepted by ``transforms.functional.to_grayscale``
            (the notebook passes a PIL image).
        size: Target ``(height, width)`` passed to ``resize``. Defaults to the
            28x28 input the model in this notebook expects.
        mean: Per-channel mean used for normalization (one value, since the
            image is grayscale).
        std: Per-channel standard deviation used for normalization.

    Returns:
        The normalized image tensor, without a batch dimension.
    """
    gray_image = transforms.functional.to_grayscale(image)
    resized_image = transforms.functional.resize(gray_image, size)
    input_image_tensor = transforms.functional.to_tensor(resized_image)
    return transforms.functional.normalize(input_image_tensor, mean, std)

In [12]:
# Preprocess the sample image into the tensor that is fed to the models below.
input_tensor = image_to_tensor(image)

In [13]:
# Add a leading batch dimension and run the reloaded scripted model on the sample.
loaded_scripted_model(input_tensor.unsqueeze(0))

tensor([[-9.3505e+00, -1.2089e+01, -2.2391e-03, -8.9248e+00, -9.8197e+00,
         -1.3350e+01, -9.0460e+00, -1.4492e+01, -6.3023e+00, -1.2283e+01]])

### **ONNX로 파이토치 모델 내보내기**

**모델 ONNX 파일 저장**  
- 내부적으로는 모델 추적(`torch.jit.trace`)과 동일한 메커니즘을 사용해 모델을 직렬화하므로, 스크립팅과 달리 더미 입력(`demo_input`)이 필요하다.

In [14]:
# Example input used by the exporter: one single-channel 28x28 image filled with ones.
demo_input = torch.ones(1,1,28,28)
# NOTE(review): no input_names/output_names are passed, so the graph tensor names
# are chosen by the exporter. The TensorFlow inference cell later looks up
# 'input.1:0' and '18:0' — presumably those auto-generated names; confirm after
# any PyTorch/ONNX upgrade.
torch.onnx.export(model, demo_input, "convnet.onnx")

**저장한 onnx 모델 로딩**  
- 직렬화된 tensorflow 모델로 변환
- 이후 모델 아키텍처를 제대로 로딩했는지 확인하고 그래프의 입출력 노드를 식별

In [16]:
import onnx
from onnx_tf.backend import prepare

# Read the exported ONNX file back from disk.
model_onnx = onnx.load("./convnet.onnx")

# Wrap the ONNX model in a TensorFlow backend representation and write it out.
tf_rep = prepare(model_onnx)
tf_rep.export_graph("./convnet.pb")

**모델 그래프 파싱**  
- 모델 아키텍처를 제대로 로딩했는지 확인
- 그래프의 입출력 노드 식별

In [18]:
import tensorflow as tf

# Parse the serialized graph written by the ONNX -> TensorFlow conversion cell.
# The graph-mode symbols are taken from tf.compat.v1 so the cell runs on both
# TensorFlow 1.x (>= 1.13) and 2.x; top-level tf.gfile / tf.GraphDef exist only in 1.x.
with tf.compat.v1.gfile.GFile("./convnet.pb", "rb") as f:
    graph_definition = tf.compat.v1.GraphDef()
    graph_definition.ParseFromString(f.read())

# Import the definition into a fresh graph; name="" keeps node names unprefixed.
with tf.Graph().as_default() as model_graph:
    tf.import_graph_def(graph_definition, name="")

# Print each operation's output tensors so the input and output nodes can be found.
for op in model_graph.get_operations():
    print(op.values())



(<tf.Tensor 'Const:0' shape=(16, 1, 3, 3) dtype=float32>,)
(<tf.Tensor 'Const_1:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'Const_2:0' shape=(32, 16, 3, 3) dtype=float32>,)
(<tf.Tensor 'Const_3:0' shape=(32,) dtype=float32>,)
(<tf.Tensor 'Const_4:0' shape=(64, 4608) dtype=float32>,)
(<tf.Tensor 'Const_5:0' shape=(64,) dtype=float32>,)
(<tf.Tensor 'Const_6:0' shape=(10, 64) dtype=float32>,)
(<tf.Tensor 'Const_7:0' shape=(10,) dtype=float32>,)
(<tf.Tensor 'input.1:0' shape=(1, 1, 28, 28) dtype=float32>,)
(<tf.Tensor 'transpose/perm:0' shape=(4,) dtype=int32>,)
(<tf.Tensor 'transpose:0' shape=(3, 3, 1, 16) dtype=float32>,)
(<tf.Tensor 'Const_8:0' shape=() dtype=int32>,)
(<tf.Tensor 'split/split_dim:0' shape=() dtype=int32>,)
(<tf.Tensor 'split:0' shape=(3, 3, 1, 16) dtype=float32>,)
(<tf.Tensor 'transpose_1/perm:0' shape=(4,) dtype=int32>,)
(<tf.Tensor 'transpose_1:0' shape=(1, 28, 28, 1) dtype=float32>,)
(<tf.Tensor 'Const_9:0' shape=() dtype=int32>,)
(<tf.Tensor 'split_1/split_dim:0'

- 위 출력 목록에서 그래프의 입력 노드(`input.1:0`)와 출력 노드(`18:0`)를 식별할 수 있으며, 다음 셀에서 이 이름들을 사용한다.

**샘플 이미지에 대한 예측 생성**
- 신경망 모델의 입력과 출력 노드를 변수에 할당하고 TensorFlow 세션을 인스턴스화해서 그래프 실행

In [19]:
# Look up the graph's output and input tensors by the names found in the
# operation listing printed by the previous cell.
model_output = model_graph.get_tensor_by_name('18:0')
model_input = model_graph.get_tensor_by_name('input.1:0')

# Run the graph inside a context manager so the session's resources are released
# when the block exits (the session was previously left open). The torch tensor
# is converted to a NumPy array explicitly before being fed to TensorFlow.
with tf.compat.v1.Session(graph=model_graph) as sess:
    output = sess.run(model_output,
                      feed_dict={model_input: input_tensor.unsqueeze(0).numpy()})
print(output)


[[-9.35050774e+00 -1.20893326e+01 -2.23910273e-03 -8.92477798e+00
  -9.81972313e+00 -1.33498535e+01 -9.04598618e+00 -1.44924192e+01
  -6.30233145e+00 -1.22827682e+01]]
