In [9]:
import os 
 
import cv2 
import numpy as np 
import requests 
import torch 
import torch.onnx 
from torch import nn 

In [10]:
class SuperResolutionNet(nn.Module):
    """Bicubic upsampling followed by a three-layer convolutional refinement.

    The input is first enlarged by ``upscale_factor`` with bicubic
    interpolation, then passed through a 9x9 -> 1x1 -> 5x5 convolution stack.
    Every convolution pads so that spatial size is preserved, so a
    ``(N, 3, H, W)`` input yields a ``(N, 3, H * upscale_factor,
    W * upscale_factor)`` output.

    Args:
        upscale_factor: Scale factor handed to ``nn.Upsample``.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor

        # Fixed (parameter-free) bicubic enlargement applied before the convs.
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bicubic',
            align_corners=False,
        )

        # 3 -> 64 -> 32 -> 3 channels; padding keeps height/width unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, padding=2)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Upsample ``x`` and refine it; the last convolution has no activation."""
        upsampled = self.img_upsampler(x)
        features = self.relu(self.conv1(upsampled))
        features = self.relu(self.conv2(features))
        return self.conv3(features)

In [11]:
def init_torch_model(checkpoint_path='srcnn.pth', upscale_factor=3):
    """Build a ``SuperResolutionNet`` and load pretrained weights into it.

    Args:
        checkpoint_path: Path of the checkpoint file. It must unpickle to a
            mapping with a ``'state_dict'`` entry.
        upscale_factor: Scale factor passed to ``SuperResolutionNet``.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the renamed
            keys do not match the model's parameters.
    """
    torch_model = SuperResolutionNet(upscale_factor=upscale_factor)

    # map_location='cpu' lets the checkpoint load on machines without a GPU
    # even if its tensors were saved from a CUDA device.
    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Adapt the checkpoint: drop the first dotted component of every key
    # (presumably a wrapper-module prefix -- verify against the checkpoint) so
    # the names line up with this model's ``conv1`` / ``conv2`` / ``conv3``.
    state_dict = {
        key.partition('.')[2]: value
        for key, value in checkpoint['state_dict'].items()
    }

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model

In [12]:
# Download checkpoint and test image
urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth',
    'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png']
names = ['srcnn.pth', 'face.png']
for url, name in zip(urls, names):
    if not os.path.exists(name):
        # Check the HTTP status *before* creating the file: otherwise an error
        # page would be cached under `name` and the exists() guard above would
        # prevent any later retry.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        with open(name, 'wb') as out_file:
            out_file.write(response.content)

model = init_torch_model()

# cv2.imread returns None (it does not raise) when the file is missing or unreadable.
input_img = cv2.imread('face.png')
if input_img is None:
    raise FileNotFoundError("cv2.imread could not read 'face.png'")
input_img = input_img.astype(np.float32)
print(input_img.shape)

# HWC to NCHW
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)

# Inference (no autograd graph is needed for a forward-only pass)
with torch.no_grad():
    torch_output = model(torch.from_numpy(input_img)).numpy()

# NCHW to HWC
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)

# Save the result to disk
cv2.imwrite("face_torch.png", torch_output)
print(torch_output.shape)

(256, 256, 3)
(768, 768, 3)


In [13]:
# Export the model to "srcnn.onnx", using a random 1x3x256x256 tensor as the
# example input and naming the graph's input/output tensors explicitly.
x = torch.randn(1, 3, 256, 256)
export_options = dict(
    verbose=True,
    input_names=['input'],
    output_names=['output'],
)
with torch.no_grad():
    torch.onnx.export(model, x, "srcnn.onnx", **export_options)
print("Exported model to ONNX")

Exported graph: graph(%input : Float(1, 3, 256, 256, strides=[196608, 65536, 256, 1], requires_grad=0, device=cpu),
      %conv1.weight : Float(64, 3, 9, 9, strides=[243, 81, 9, 1], requires_grad=1, device=cpu),
      %conv1.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
      %conv2.weight : Float(32, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
      %conv2.bias : Float(32, strides=[1], requires_grad=1, device=cpu),
      %conv3.weight : Float(3, 32, 5, 5, strides=[800, 25, 5, 1], requires_grad=1, device=cpu),
      %conv3.bias : Float(3, strides=[1], requires_grad=1, device=cpu)):
  %/img_upsampler/Constant_output_0 : Float(4, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value= 1  1  3  3 [ CPUFloatType{4} ], onnx_name="/img_upsampler/Constant"](), scope: __main__.SuperResolutionNet::/torch.nn.modules.upsampling.Upsample::img_upsampler # /root/projects/modelDep/venv/lib/python3.10/site-packages/torch/nn/functional.py:4073:0
  %onnx::Resiz

In [14]:
import onnx

# Load the exported file and run ONNX's structural validation on it.
onnx_model = onnx.load("srcnn.onnx")
try:
    onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
    # onnx.checker.ValidationError is the public alias of the exception raised
    # by the checker; the onnx_cpp2py_export path is a private binding module.
    print(f"Model is invalid: {e}")
else:
    print("Model is valid")

Model is valid


In [15]:
import onnxruntime as ort

# Name the execution provider explicitly: since ONNX Runtime 1.9, builds that
# ship non-CPU providers refuse to create a session without `providers`.
ort_session = ort.InferenceSession("srcnn.onnx", providers=["CPUExecutionProvider"])

# Feed the same NCHW float32 array that was given to the PyTorch model.
ort_input = {"input": input_img}
ort_output = ort_session.run(["output"], ort_input)[0]

# NCHW to HWC, clamped to the 8-bit range, mirroring the PyTorch post-processing.
ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort.png", ort_output)

True