<center><h1>Export to ONNX</h1></center>

# Create a Model

In [None]:
# Super-resolution is a technique for increasing the resolution of images and videos.

import torch.nn as nn
import torch.nn.init as init
import torch

# Super Resolution model definition in PyTorch
class SuperResolutionNet(nn.Module):
    """Sub-pixel convolutional network for single-channel super-resolution.

    Three ReLU-activated convolutions extract features. A fourth convolution
    emits ``upscale_factor ** 2`` channels, which ``nn.PixelShuffle`` rearranges
    into a single channel enlarged ``upscale_factor`` times in each spatial
    dimension. Every convolution is padded so height and width are preserved
    until the pixel shuffle.

    Args:
        upscale_factor: integer factor by which height and width are enlarged.
        inplace: forwarded to ``nn.ReLU``.

    Shapes:
        input ``(N, 1, H, W)`` -> output ``(N, 1, H * r, W * r)``, ``r = upscale_factor``.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()

        # Keep this registration order: it determines the state_dict layout and
        # the order in which the layers' default initializers draw random numbers.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Upscale ``x`` of shape (N, 1, H, W) to (N, 1, H * r, W * r)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        # No activation before the shuffle: conv4 directly produces the sub-pixels.
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialize the conv weights (biases keep Conv2d's default init)."""
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        # conv4 feeds the pixel shuffle rather than a ReLU, so it uses the default gain.
        init.orthogonal_(self.conv4.weight)

## Instantiate the Model

In [None]:
# Create the super-resolution model (3x upscaling) from the definition above.
# Its randomly initialized weights are replaced by the pretrained checkpoint
# loaded further down.
torch_model = SuperResolutionNet(upscale_factor=3)
torch_model

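# Load the Trained Model

No training happens in this notebook: the network is initialized with pretrained weights downloaded from `model_url`.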

## Load Pretrained Model Weights

In [None]:
# Location of the pretrained weights (a state_dict checkpoint for SuperResolutionNet);
# they are downloaded and loaded into the model in a later cell.
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # batch dimension of the dummy input used to trace the ONNX export (arbitrary)

## Initialize Model with Pretrained Weights

In [None]:
import torch.utils.model_zoo as model_zoo

# Initialize the model with the pretrained weights.
# map_location="cpu" always materializes the checkpoint tensors on the CPU:
# torch_model was built on the CPU, and loading onto whatever device the
# checkpoint was saved from would waste GPU memory, or fail outright if that
# GPU does not exist on this machine.
pretrained_state = model_zoo.load_url(model_url, map_location="cpu")
torch_model.load_state_dict(pretrained_state)

# Summarize what was loaded instead of dumping every weight value.
for param_name, tensor in torch_model.state_dict().items():
    print(f"{param_name:<14} {tuple(tensor.shape)}")

## Switch the Model to Evaluation Mode

In [None]:
# Put the model in evaluation mode: from here on only forward passes are run,
# and the exporter should see inference behaviour. eval() is train(False).
torch_model.eval()

# Export the Model to ONNX

## Provide Input to the Model

In [None]:
# Dummy input used only to trace the model during export. forward() has no
# value-dependent branching, so the random values are irrelevant; only the
# layout (batch, channels=1, height, width) and dtype matter.
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
x.shape

## Export

In [None]:
import torch.onnx
# Export the model to ONNX by running (tracing) it on the dummy input `x`.
torch.onnx.export(
    torch_model,                 # model being run
    x,                           # model input (or a tuple for multiple inputs)
    "super_resolution.onnx",     # where to save the model (can be a file or file-like object)
    export_params=True,          # store the trained parameter weights inside the model file
    input_names=["input"],       # explicit graph input name (instead of an auto-generated one)
    output_names=["output"],     # explicit graph output name
    # Keep the batch dimension symbolic so the exported model is not pinned to
    # the arbitrary batch size of the dummy input.
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    verbose=True,                # show the output of the export process
)

# Verify ONNX Model

In [None]:
import onnx

model = onnx.load("super_resolution.onnx")       # load the exported ONNX model
onnx.checker.check_model(model)                  # check that the model IR is well formed (raises on failure)

# printable_graph returns a str. Print it so the newlines render; left as the
# cell's last expression, Jupyter would show the escaped repr with literal "\n".
print(onnx.helper.printable_graph(model.graph))  # human-readable presentation of the graph