<img width="150" alt="Logo_ER10" src="https://user-images.githubusercontent.com/3244249/151994514-b584b984-a148-4ade-80ee-0f88b0aefa45.png">

## PyTorch and PyTorch Lightning to ONNX conversion

This notebook shows how to convert a trained PyTorch or PyTorch Lightning model to ONNX, the generic format supported by DIANNA. <br>
It is based on the tutorial at https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.utils import model_zoo
import onnx
import onnxruntime as ort

Create an example model using PyTorch.

In [2]:
class SuperResolutionNet(nn.Module):
    """Small convolutional super-resolution network.

    Three convolutions, each followed by a ReLU, feed a fourth convolution
    that produces ``upscale_factor ** 2`` channels. A pixel shuffle then
    rearranges those channels into the spatial dimensions.

    Args:
        upscale_factor: Factor passed to ``nn.PixelShuffle``; also fixes the
            number of output channels of the last convolution.
        inplace: Forwarded to ``nn.ReLU``.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()

        # Sub-modules are registered in this order on purpose: it determines
        # the order of the state-dict keys and of the printed module summary.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Run the three ReLU-activated convolutions, then conv4 and the pixel shuffle."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise all convolution weights (conv1 to conv4, in order)."""
        # The layers followed by a ReLU use the ReLU gain; conv4 has no
        # activation after it and keeps the default gain.
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)


# Instantiate the super-resolution model defined above with an upscale factor of 3.
torch_model = SuperResolutionNet(upscale_factor=3)

Load pretrained weights and apply them to the example model.

In [3]:
# Location of the pretrained super-resolution weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'

# Fetch the weights, mapping every tensor onto the CPU, and copy them into the model
pretrained_state = model_zoo.load_url(model_url, map_location=torch.device('cpu'))
torch_model.load_state_dict(pretrained_state)

# Switch the model to inference mode (kept as the last expression so the
# notebook displays the module summary)
torch_model.eval()

SuperResolutionNet(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

Evaluate model on some random input.

In [4]:
# Random input of shape (batch=1, channels=1, 224, 224)
input_shape = (1, 1, 224, 224)
x = torch.randn(input_shape, requires_grad=True)
# Forward pass; these predictions serve as the reference for the ONNX model
pred = torch_model(x)

Export the model in ONNX format.

In [5]:
onnx_file = 'pytorch_super_resolution_net.onnx'

# Axis 0 of both the input and the output is exported as a variable-length
# axis named 'batch_size'.
dynamic_axes = {
    'input': {0: 'batch_size'},
    'output': {0: 'batch_size'},
}

# Export the model
torch.onnx.export(
    torch_model,                # model being run
    x,                          # example model input (or a tuple for multiple inputs)
    onnx_file,                  # destination (a file name or a file-like object)
    export_params=True,         # store the trained parameter weights inside the model file
    opset_version=10,           # ONNX opset version to target
    do_constant_folding=True,   # fold constants as an optimization during export
    input_names=['input'],      # names given to the model's inputs
    output_names=['output'],    # names given to the model's outputs
    dynamic_axes=dynamic_axes,
)

Verify the PyTorch and ONNX predictions match.

In [6]:
# Verify that the exported file is a structurally valid ONNX model
onnx_model = onnx.load(onnx_file)
onnx.checker.check_model(onnx_model)

# Get ONNX predictions.
# The execution provider is named explicitly: ONNX Runtime builds that ship
# providers besides the CPU one (e.g. onnxruntime-gpu >= 1.9) raise a
# ValueError when a session is created without `providers`.
sess = ort.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

# ONNX Runtime takes numpy arrays, so detach the tensor from the autograd
# graph before converting it.
onnx_input = {input_name: x.detach().numpy().astype(np.float32)}
pred_onnx = sess.run([output_name], onnx_input)[0]

# Compare to the PyTorch predictions (last expression, so the notebook
# displays the boolean result)
np.allclose(pred.detach().numpy(), pred_onnx, atol=1e-5)

True

If your model was created with PyTorch Lightning, the conversion is similar to the one above. <br>

The cells below show how to convert your trained PyTorch Lightning model to ONNX.<br>
They are based on the tutorial at https://pytorch-lightning.readthedocs.io/en/latest/deploy/production_advanced.html

In [7]:
# import pytorch-lightning
import pytorch_lightning as pl

Create an example model using PyTorch Lightning.

In [8]:
class SuperResolutionNetLightning(pl.LightningModule):
    """Small convolutional super-resolution network as a LightningModule.

    Three convolutions, each followed by a ReLU, feed a fourth convolution
    that produces ``upscale_factor ** 2`` channels. A pixel shuffle then
    rearranges those channels into the spatial dimensions.

    Args:
        upscale_factor: Factor passed to ``nn.PixelShuffle``; also fixes the
            number of output channels of the last convolution.
        inplace: Forwarded to ``nn.ReLU``.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()

        # Sub-modules are registered in this order on purpose: it determines
        # the order of the state-dict keys and of the printed module summary.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Run the three ReLU-activated convolutions, then conv4 and the pixel shuffle."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise all convolution weights (conv1 to conv4, in order)."""
        # The layers followed by a ReLU use the ReLU gain; conv4 has no
        # activation after it and keeps the default gain.
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)


# Instantiate the Lightning super-resolution model defined above with an upscale factor of 3.
torch_lightning_model = SuperResolutionNetLightning(upscale_factor=3)

Load pretrained weights and apply them to the example model.

In [9]:
# Fetch the pretrained weights, mapping every tensor onto the CPU, and copy
# them into the Lightning model
pretrained_state = model_zoo.load_url(model_url, map_location=torch.device('cpu'))
torch_lightning_model.load_state_dict(pretrained_state)

# Switch the model to inference mode (kept as the last expression so the
# notebook displays the module summary)
torch_lightning_model.eval()

SuperResolutionNetLightning(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

Evaluate model on some random input.

In [10]:
# Random input of shape (batch=1, channels=1, 224, 224)
input_shape = (1, 1, 224, 224)
x = torch.randn(input_shape, requires_grad=True)
# Forward pass; these predictions serve as the reference for the ONNX model
pred = torch_lightning_model(x)

Export the model in ONNX format.

In [13]:
# Export the Lightning model to ONNX, using `x` as the example input
onnx_file = 'pytorch_lightning_super_resolution_net.onnx'
torch_lightning_model.to_onnx(
    file_path=onnx_file,
    input_sample=x,
    export_params=True,
)

Verify the PyTorch and ONNX predictions match.

In [14]:
# Verify that the exported file is a structurally valid ONNX model
onnx_model = onnx.load(onnx_file)
onnx.checker.check_model(onnx_model)

# Get ONNX predictions.
# The execution provider is named explicitly: ONNX Runtime builds that ship
# providers besides the CPU one (e.g. onnxruntime-gpu >= 1.9) raise a
# ValueError when a session is created without `providers`.
sess = ort.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

# ONNX Runtime takes numpy arrays, so detach the tensor from the autograd
# graph before converting it.
onnx_input = {input_name: x.detach().numpy().astype(np.float32)}
pred_onnx = sess.run([output_name], onnx_input)[0]

# Compare to the PyTorch predictions (last expression, so the notebook
# displays the boolean result)
np.allclose(pred.detach().numpy(), pred_onnx, atol=1e-5)

True