In [54]:
# Some standard imports
import io
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torch.nn.functional as F

In [27]:
# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init


class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)
        init.zeros_(self.conv4.bias)  

# Create the super-resolution model by using the above model definition.
model = SuperResolutionNet(upscale_factor=3)

In [28]:
def print_state_dict(state_dict):    
    print(len(state_dict))
    for layer in state_dict:
        print(layer, '\t', state_dict[layer].shape)
    print(state_dict['conv4.bias'])
print_state_dict(model.state_dict())

8
conv1.weight 	 torch.Size([64, 1, 5, 5])
conv1.bias 	 torch.Size([64])
conv2.weight 	 torch.Size([64, 64, 3, 3])
conv2.bias 	 torch.Size([64])
conv3.weight 	 torch.Size([32, 64, 3, 3])
conv3.bias 	 torch.Size([32])
conv4.weight 	 torch.Size([9, 32, 3, 3])
conv4.bias 	 torch.Size([9])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [29]:
# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'

model.load_state_dict(model_zoo.load_url(model_url))

print_state_dict(model.state_dict())
# set the model to inference mode
model.eval()


8
conv1.weight 	 torch.Size([64, 1, 5, 5])
conv1.bias 	 torch.Size([64])
conv2.weight 	 torch.Size([64, 64, 3, 3])
conv2.bias 	 torch.Size([64])
conv3.weight 	 torch.Size([32, 64, 3, 3])
conv3.bias 	 torch.Size([32])
conv4.weight 	 torch.Size([9, 32, 3, 3])
conv4.bias 	 torch.Size([9])
tensor([-0.0151, -0.0191, -0.0362, -0.0224,  0.0548,  0.0113,  0.0529,  0.0258,
        -0.0180])


SuperResolutionNet(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

In [44]:
# Input to the model
x = torch.randn(1, 1, 224, 224, requires_grad=True)

# Export the model
torch.onnx.export(model,               # model being run
                  x,                         # model input 
                  "D:\\super_resolution.onnx",   # where to save the model (can be a file or file-like object)                  
                  opset_version=11,          # the ONNX version to export the model to                  
                  input_names = ['input'],   # the model's input names
                  output_names = ['output']  # the model's output names
                  )

In [47]:
input_name = 'input'
output_name = 'output'
torch.onnx.export(model,               # model being run
                  x,                         # model input 
                  "D:\\super_resolution_2.onnx",   # where to save the model (can be a file or file-like object)                  
                  opset_version=11,          # the ONNX version to export the model to                  
                  input_names = [input_name],   # the model's input names
                  output_names = [output_name],  # the model's output names
                  dynamic_axes= {
                        input_name: {0: 'batch_size', 2 : 'in_width', 3: 'int_height'},
                        output_name: {0: 'batch_size', 2: 'out_width', 3:'out_height'}}
                  )

In [55]:
F.interpolate?

[1;31mSignature:[0m
[0mF[0m[1;33m.[0m[0minterpolate[0m[1;33m([0m[1;33m
[0m    [0minput[0m[1;33m:[0m [0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m,[0m[1;33m
[0m    [0msize[0m[1;33m:[0m [0mOptional[0m[1;33m[[0m[0mint[0m[1;33m][0m [1;33m=[0m [1;32mNone[0m[1;33m,[0m[1;33m
[0m    [0mscale_factor[0m[1;33m:[0m [0mOptional[0m[1;33m[[0m[0mList[0m[1;33m[[0m[0mfloat[0m[1;33m][0m[1;33m][0m [1;33m=[0m [1;32mNone[0m[1;33m,[0m[1;33m
[0m    [0mmode[0m[1;33m:[0m [0mstr[0m [1;33m=[0m [1;34m'nearest'[0m[1;33m,[0m[1;33m
[0m    [0malign_corners[0m[1;33m:[0m [0mOptional[0m[1;33m[[0m[0mbool[0m[1;33m][0m [1;33m=[0m [1;32mNone[0m[1;33m,[0m[1;33m
[0m    [0mrecompute_scale_factor[0m[1;33m:[0m [0mOptional[0m[1;33m[[0m[0mbool[0m[1;33m][0m [1;33m=[0m [1;32mNone[0m[1;33m,[0m[1;33m
[0m[1;33m)[0m [1;33m->[0m [0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m[0m[1;33m[0m[0m
[1;31mDocstring:[0m
D

In [58]:
class SuperResolutionNet2(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet2, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x, scale):
        print(scale)        
        y = F.interpolate(x, scale_factor= 1./float(scale), mode="bilinear")
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)
        init.zeros_(self.conv4.bias)  

# Create the super-resolution model by using the above model definition.
model2 = SuperResolutionNet2(upscale_factor=3)

In [59]:
input_name = 'input'
output_name = 'output'
torch.onnx.export(model2,               
                  (x, 2),                         
                  "D:\\super_resolution_3.onnx",   
                  opset_version=11,          
                  input_names = [input_name],  
                  output_names = [output_name],
                  dynamic_axes= {
                        input_name: {0: 'batch_size', 2 : 'in_width', 3: 'int_height'},
                        output_name: {0: 'batch_size', 2: 'out_width', 3:'out_height'}}
                  )

tensor(2)


  y = F.interpolate(x, scale_factor= 1./float(scale), mode="bilinear")
