# Export to ONNX

In this notebook, we export our trained PyTorch model to the ONNX format so that it can later be used for inference (for example, with ONNX Runtime).

**1. Import Required Libraries:-** 

In [1]:
from torchvision import transforms
import torch
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import time
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
#Local Imports
from dataset import HeadposeDataset
from model import FSANet
import onnx
import onnxruntime


**2. Define Model and Load from Saved Checkpoint:-**

In [2]:
# Prefer the GPU when one is available, but fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FSANet(var=True).to(device)

# Load Model Checkpoint
# map_location remaps tensors that were saved from a GPU onto `device`, so the
# checkpoint can also be restored on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
chkpt_dic = torch.load('checkpoints/fsavar-09082020.chkpt', map_location=device)

# The checkpoint stores the model weights under ['best_states']['model'].
model.load_state_dict(chkpt_dic['best_states']['model'])

# Set model to inference-ready (evaluation mode); kept as the last expression
# so the cell still displays the module structure.
model.eval()

FSANet(
  (msms): MultiStreamMultiStage(
    (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (s0_conv0): SepConvBlock(
      (conv): SepConv2d(
        (depthwise): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3)
        (pointwise): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
      )
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
    )
    (s0_conv1_0): SepConvBlock(
      (conv): SepConv2d(
        (depthwise): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16)
        (pointwise): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
    )
    (s0_conv1_1): SepConvBlock(
      (conv): SepConv2d(
        (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
        (pointwise): Conv2d(32, 32, kernel_si

**3. Export model to ONNX:-**

In [4]:
# Export to ONNX
# Random dummy input of shape (1, 3, 64, 64); it is used both to trace the
# model during export and, later, to compare PyTorch and ONNX Runtime outputs.
x = torch.randn(1, 3, 64, 64).to(device)

# Reference PyTorch output for the later comparison. No gradients are needed
# for it, so skip building the autograd graph.
with torch.no_grad():
    model_out = model(x)

save_path = "pretrained/fsanet-var-iter-688590.onnx"
# Make sure the target directory exists; otherwise the export fails when it
# tries to open the output file.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

torch.onnx.export(model,                     # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  save_path,                 # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=9,           # the ONNX operator-set (opset) version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'])   # the model's output names

**4. Reload model from ONNX:-**

In [5]:
# Verify ONNX model
# Use a dedicated name for the loaded ONNX graph: rebinding `model` here would
# replace the PyTorch module and break any re-run of the export cell.
onnx_model = onnx.load(save_path)

# Check that the IR is well formed (raises if the model is invalid)
onnx.checker.check_model(onnx_model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(onnx_model.graph))

graph torch-jit-export (
  %input[FLOAT, 1x3x64x64]
) initializers (
  %595[INT64, 1]
  %596[INT64, 1]
  %597[INT64, 1]
  %598[INT64, 1]
  %599[INT64, 1]
  %600[INT64, 1]
  %601[INT64, 1]
  %602[INT64, 1]
  %603[INT64, 1]
  %604[INT64, 1]
  %605[INT64, 1]
  %606[INT64, 1]
  %607[INT64, 1]
  %608[INT64, 1]
  %609[FLOAT, 3x21x64x16]
  %610[FLOAT, scalar]
  %611[FLOAT, scalar]
  %612[INT64, 1]
  %613[INT64, 1]
  %614[INT64, 1]
  %615[INT64, 1]
  %616[INT64, 1]
  %617[INT64, 1]
  %618[INT64, 1]
  %619[INT64, 1]
  %620[INT64, 1]
  %caps_layer.affine_w[FLOAT, 3x21x16x64]
  %esp_s1.pred_fc.bias[FLOAT, 9]
  %esp_s1.pred_fc.weight[FLOAT, 9x8]
  %esp_s1.scale_fc.bias[FLOAT, 3]
  %esp_s1.scale_fc.weight[FLOAT, 3x4]
  %esp_s1.shift_fc.bias[FLOAT, 3]
  %esp_s1.shift_fc.weight[FLOAT, 3x4]
  %esp_s2.pred_fc.bias[FLOAT, 9]
  %esp_s2.pred_fc.weight[FLOAT, 9x8]
  %esp_s2.scale_fc.bias[FLOAT, 3]
  %esp_s2.scale_fc.weight[FLOAT, 3x4]
  %esp_s2.shift_fc.bias[FLOAT, 3]
  %esp_s2.shift_fc.weight[FLOAT, 3x4]


**5. Compare ONNX Runtime Output with the PyTorch Model Output:-**

In [6]:
# Create an ONNX Runtime session for the exported model file.
# Execution providers are passed explicitly: ONNX Runtime >= 1.9 raises an
# error when `providers` is omitted on builds that ship non-CPU providers.
# Passing every available provider keeps the previous default selection.
ort_session = onnxruntime.InferenceSession(
    save_path,
    providers=onnxruntime.get_available_providers(),
)

def to_numpy(tensor):
    """Return the contents of a torch tensor as a NumPy array on the host.

    Tensors that track gradients are detached from the autograd graph first,
    since NumPy conversion is not allowed on tensors that require grad.
    """
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    return tensor.cpu().numpy()

# Run the exported graph in ONNX Runtime on the same dummy tensor that was
# used for the export, keyed by the name of the graph's first input.
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# The PyTorch output and the first ONNX Runtime output must agree within the
# given tolerances; assert_allclose raises AssertionError otherwise, so the
# success message below is only printed when the outputs match.
np.testing.assert_allclose(
    actual=to_numpy(model_out),
    desired=ort_outs[0],
    rtol=1e-03,
    atol=1e-05,
)

print("Model Testing was Successful, ONNXRuntime Model Output matches with Pytorch Model Output!")

Model Testing was Successful, ONNXRuntime Model Output matches with Pytorch Model Output!
