In [68]:
import os 
import cv2 
import numpy as np 
import requests 
import torch 
import torch.onnx 
from torch import nn 
from torch.nn.functional import interpolate

In [69]:
class MyInterpolate(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, input, scales):
        return interpolate(input,
                           scale_factor=scales[-2:].tolist(),
                           mode='bicubic',
                           align_corners=False)

    @staticmethod
    def symbolic(g, input, scales):
        return g.op('Resize', 
                    input, 
                    g.op('Constant', value_t=torch.tensor([], dtype=torch.float32)),
                    scales, 
                    mode_s='cubic',
                    coordinate_transformation_mode_s='half_pixel',
                    cubic_coeff_a_f=-0.75,
                    nearest_mode_s='floor')

In [70]:
class SuperResolutionNet(nn.Module): 
    def __init__(self): 
        super().__init__() 
 
        self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4) 
        self.conv2 = nn.Conv2d(64,32,kernel_size=1,padding=0) 
        self.conv3 = nn.Conv2d(32,3,kernel_size=5,padding=2) 
 
        self.relu = nn.ReLU() 
 
    def forward(self, x, upsample_factor): 
        x = MyInterpolate.apply(x, upsample_factor)
        out = self.relu(self.conv1(x)) 
        out = self.relu(self.conv2(out)) 
        out = self.conv3(out) 
        return out 

In [71]:
def init_torch_model(): 
    torch_model = SuperResolutionNet() 
 
    state_dict = torch.load('srcnn.pth')['state_dict'] 
 
    # Adapt the checkpoint 
    for old_key in list(state_dict.keys()): 
        new_key = '.'.join(old_key.split('.')[1:]) 
        state_dict[new_key] = state_dict.pop(old_key) 
 
    torch_model.load_state_dict(state_dict) 
    torch_model.eval() 
    return torch_model 

In [72]:
model = init_torch_model() 
input_img = cv2.imread('face.png').astype(np.float32) 
print(input_img.shape)
 
# HWC to NCHW 
input_img = np.transpose(input_img, [2, 0, 1]) 
input_img = np.expand_dims(input_img, 0) 
 
# Inference 
factor = torch.tensor([1, 1, 3, 3])
torch_output = model(torch.from_numpy(input_img), 
                     factor).detach().numpy() 
 
# NCHW to HWC 
torch_output = np.squeeze(torch_output, 0) 
torch_output = np.clip(torch_output, 0, 255) 
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8) 
 
# Show image 
cv2.imwrite("face_torch.png", torch_output)
print(torch_output.shape)

(256, 256, 3)
(768, 768, 3)


In [73]:
x = torch.randn(1, 3, 256, 256)
factor = torch.tensor([1, 1, 3, 3], dtype=torch.float32)
with torch.no_grad():
    torch.onnx.export(model, (x, factor), "srcnn2.onnx", verbose=True, 
                      input_names=["input", "factor"], output_names=["output"]) 
    print("Exported model to ONNX")

  scale_factor=scales[-2:].tolist(),


Exported graph: graph(%input : Float(1, 3, 256, 256, strides=[196608, 65536, 256, 1], requires_grad=0, device=cpu),
      %factor : Float(4, strides=[1], requires_grad=0, device=cpu),
      %conv1.weight : Float(64, 3, 9, 9, strides=[243, 81, 9, 1], requires_grad=1, device=cpu),
      %conv1.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
      %conv2.weight : Float(32, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
      %conv2.bias : Float(32, strides=[1], requires_grad=1, device=cpu),
      %conv3.weight : Float(3, 32, 5, 5, strides=[800, 25, 5, 1], requires_grad=1, device=cpu),
      %conv3.bias : Float(3, strides=[1], requires_grad=1, device=cpu)):
  %/Constant_output_0 : Float(0, strides=[1], device=cpu) = onnx::Constant[value=[ CPUFloatType{0} ], onnx_name="/Constant"](), scope: __main__.SuperResolutionNet::
  %/Resize_output_0 : Float(*, *, *, *, strides=[1769472, 589824, 768, 1], requires_grad=0, device=cpu) = onnx::Resize[coordinate_transformation_

In [74]:
import onnx

onnx_model = onnx.load("srcnn2.onnx")
try:
    onnx.checker.check_model(onnx_model)
    print("Model is valid")
except onnx.onnx_cpp2py_export.checker.ValidationError as e:
    print("Model is invalid: %s" % e)

Model is valid


In [75]:
import onnxruntime as ort

ort_session = ort.InferenceSession("srcnn2.onnx")
input_factor = np.array([1, 1, 4, 4], dtype=np.float32)
ort_input = {"input": input_img, "factor": input_factor}
ort_output = ort_session.run(["output"], ort_input)[0]

ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort4x4.png", ort_output)

True