In [1]:
# 필요한 import문
import io
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

### pytorch

In [2]:
# PyTorch에서 구현된 초해상도 모델
import torch.nn as nn
import torch.nn.init as init

class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)


In [3]:
torch_model = SuperResolutionNet(upscale_factor=3)

In [4]:
# 미리 학습된 가중치를 읽어옵니다
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # 임의의 수

In [5]:
# 모델을 미리 학습된 가중치로 초기화합니다
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# 모델을 추론 모드로 전환합니다
torch_model.eval()

SuperResolutionNet(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

In [6]:
# 모델에 대한 입력값
x = torch.ones(1,1,224,224)
x

tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]])

In [7]:
# 모델의 출력
torch_out = torch_model(x)
torch_out

tensor([[[[0.7651, 0.8741, 0.9048,  ..., 0.8559, 0.7837, 0.7129],
          [0.8241, 1.0140, 1.1038,  ..., 1.1016, 1.0106, 0.7954],
          [0.8494, 1.0746, 1.1395,  ..., 1.1076, 1.0816, 0.8434],
          ...,
          [0.9251, 1.0277, 1.0539,  ..., 1.0218, 1.0811, 0.9037],
          [0.8515, 0.9133, 0.9483,  ..., 0.9819, 1.0057, 0.8684],
          [0.7569, 0.7720, 0.7775,  ..., 0.7539, 0.8017, 0.7727]]]],
       grad_fn=<UnsafeViewBackward>)

In [8]:
# 모델 변환
torch.onnx.export(torch_model,               # 실행될 모델
                  x,                         # 모델 입력값 (튜플 또는 여러 입력값들도 가능)
                  "super_resolution.onnx",   # 모델 저장 경로 (파일 또는 파일과 유사한 객체 모두 가능)
                  export_params=True,        # 모델 파일 안에 학습된 모델 가중치를 저장할지의 여부
                  opset_version=10,          # 모델을 변환할 때 사용할 ONNX 버전
                  do_constant_folding=True,  # 최적하시 상수폴딩을 사용할지의 여부
                  input_names = ['input'],   # 모델의 입력값을 가리키는 이름
                  output_names = ['output'], # 모델의 출력값을 가리키는 이름
                  dynamic_axes={'input' : {0 : 'batch_size'},    # 가변적인 길이를 가진 차원
                                'output' : {0 : 'batch_size'}})

### onnx

In [9]:
import onnx

onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

In [10]:
import onnxruntime

ort_session = onnxruntime.InferenceSession("super_resolution.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

In [11]:
# ONNX 런타임에서 계산된 결과값
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}

ort_outs = ort_session.run(None, ort_inputs)

In [12]:
ort_outs

[array([[[[0.76511073, 0.8740847 , 0.90481126, ..., 0.85588545,
           0.7836736 , 0.7129    ],
          [0.82408434, 1.0140268 , 1.1037575 , ..., 1.1016412 ,
           1.0106119 , 0.79541755],
          [0.84935516, 1.0745931 , 1.1394703 , ..., 1.1076125 ,
           1.0815643 , 0.843422  ],
          ...,
          [0.92509943, 1.0276617 , 1.0539078 , ..., 1.0218285 ,
           1.081129  , 0.9036602 ],
          [0.851483  , 0.91334385, 0.9482872 , ..., 0.9819409 ,
           1.005691  , 0.86839116],
          [0.7569455 , 0.772007  , 0.77752113, ..., 0.75390404,
           0.8017341 , 0.7727227 ]]]], dtype=float32)]

In [13]:
ort_outs[0].shape

(1, 1, 672, 672)

In [14]:
672*672

451584

In [15]:
# ONNX 런타임과 PyTorch에서 연산된 결과값 비교
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Exported model has been tested with ONNXRuntime, and the result looks good!


In [16]:
torch_out

tensor([[[[0.7651, 0.8741, 0.9048,  ..., 0.8559, 0.7837, 0.7129],
          [0.8241, 1.0140, 1.1038,  ..., 1.1016, 1.0106, 0.7954],
          [0.8494, 1.0746, 1.1395,  ..., 1.1076, 1.0816, 0.8434],
          ...,
          [0.9251, 1.0277, 1.0539,  ..., 1.0218, 1.0811, 0.9037],
          [0.8515, 0.9133, 0.9483,  ..., 0.9819, 1.0057, 0.8684],
          [0.7569, 0.7720, 0.7775,  ..., 0.7539, 0.8017, 0.7727]]]],
       grad_fn=<UnsafeViewBackward>)

In [17]:
ort_outs

[array([[[[0.76511073, 0.8740847 , 0.90481126, ..., 0.85588545,
           0.7836736 , 0.7129    ],
          [0.82408434, 1.0140268 , 1.1037575 , ..., 1.1016412 ,
           1.0106119 , 0.79541755],
          [0.84935516, 1.0745931 , 1.1394703 , ..., 1.1076125 ,
           1.0815643 , 0.843422  ],
          ...,
          [0.92509943, 1.0276617 , 1.0539078 , ..., 1.0218285 ,
           1.081129  , 0.9036602 ],
          [0.851483  , 0.91334385, 0.9482872 , ..., 0.9819409 ,
           1.005691  , 0.86839116],
          [0.7569455 , 0.772007  , 0.77752113, ..., 0.75390404,
           0.8017341 , 0.7727227 ]]]], dtype=float32)]