# **1. 라이브러리 import**

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
from PIL import Image

# **2. Modeling**

In [2]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.cn1 = nn.Conv2d(1, 16, 3, 1)
        self.cn2 = nn.Conv2d(16, 32, 3, 1)
        self.dp1 = nn.Dropout2d(0.10)
        self.dp2 = nn.Dropout2d(0.25)
        self.fc1 = nn.Linear(4608, 64) # 4608 is basically 12 X 12 X 32
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.cn1(x)
        x = F.relu(x)

        x = self.cn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dp1(x)

        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.dp2(x)

        x = self.fc2(x)

        op = F.log_softmax(x, dim=1)

        return op

In [3]:
model = ConvNet()

In [4]:
### 모델 가중치 복원

PATH_TO_MODEL = "./convnet.pth"
model.load_state_dict(torch.load(PATH_TO_MODEL, map_location="cpu"))

<All keys matched successfully>

In [5]:
model.eval()

for p in model.parameters():
    p.requires_grad_(False)

# **3. 이미지 전처리**

In [6]:
# 샘플 이미지 로딩
image = Image.open("./digit_image.jpg")

In [7]:
def image_to_tensor(image):
    gray_image = transforms.functional.to_grayscale(image)
    resized_image = transforms.functional.resize(gray_image, (28, 28))
    input_image_tensor = transforms.functional.to_tensor(resized_image)
    input_image_tensor_norm = transforms.functional.normalize(input_image_tensor, 
                                                              (0.1302,), (0.3069,))
    return input_image_tensor_norm

In [8]:
# 전처리 함수 정의

input_tensor = image_to_tensor(image)

# **4. 모델 추적**

In [9]:
demo_input = torch.ones(1, 1, 28, 28)
traced_model = torch.jit.trace(model, demo_input)

In [10]:
traced_model.graph

graph(%self.1 : __torch__.ConvNet,
      %x.1 : Float(1, 1, 28, 28, strides=[784, 784, 28, 1], requires_grad=0, device=cpu)):
  %fc2 : __torch__.torch.nn.modules.linear.___torch_mangle_2.Linear = prim::GetAttr[name="fc2"](%self.1)
  %dp2 : __torch__.torch.nn.modules.dropout.___torch_mangle_1.Dropout2d = prim::GetAttr[name="dp2"](%self.1)
  %fc1 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc1"](%self.1)
  %dp1 : __torch__.torch.nn.modules.dropout.Dropout2d = prim::GetAttr[name="dp1"](%self.1)
  %cn2 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d = prim::GetAttr[name="cn2"](%self.1)
  %cn1 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="cn1"](%self.1)
  %123 : Tensor = prim::CallMethod[name="forward"](%cn1, %x.1)
  %input.3 : Float(1, 16, 26, 26, strides=[10816, 676, 26, 1], requires_grad=0, device=cpu) = aten::relu(%123) # c:\Users\doroc\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\functional.py:1299:0
  %124 : Tensor 

In [11]:
# 추적한 모델 뒤에 있는 정확한 토치스크립트 코드 살펴보기

print(traced_model.code)

def forward(self,
    x: Tensor) -> Tensor:
  fc2 = self.fc2
  dp2 = self.dp2
  fc1 = self.fc1
  dp1 = self.dp1
  cn2 = self.cn2
  cn1 = self.cn1
  input = torch.relu((cn1).forward(x, ))
  input0 = torch.relu((cn2).forward(input, ))
  input1 = torch.max_pool2d(input0, [2, 2], annotate(List[int], []), [0, 0], [1, 1])
  input2 = torch.flatten((dp1).forward(input1, ), 1)
  input3 = torch.relu((fc1).forward(input2, ))
  _0 = (fc2).forward((dp2).forward(input3, ), )
  return torch.log_softmax(_0, 1)



In [12]:
# 추적 파일 내보내기(저장하기)

torch.jit.save(traced_model, 'traced_convnet.pt')

# **5. 모델 로딩하기**

In [13]:
loaded_traced_model = torch.jit.load('traced_convnet.pt')

In [15]:
model(input_tensor.unsqueeze(0))

tensor([[-9.3505e+00, -1.2089e+01, -2.2391e-03, -8.9248e+00, -9.8197e+00,
         -1.3350e+01, -9.0460e+00, -1.4492e+01, -6.3023e+00, -1.2283e+01]])

In [14]:
loaded_traced_model(input_tensor.unsqueeze(0))

tensor([[-9.3505e+00, -1.2089e+01, -2.2391e-03, -8.9248e+00, -9.8197e+00,
         -1.3350e+01, -9.0460e+00, -1.4492e+01, -6.3023e+00, -1.2283e+01]])

- 추적한 모델이 원래의 모델과 동일하게 제대로 작동하고 있음을 확인할 수 있다.