## Export PyTorch Model to ONNX

Zahra needs `.onnx` files to use the model in LabVIEW, but PyTorch saves model weights/checkpoints as `.pth` files. Hence, we need to convert the PyTorch `.pth` checkpoint into an `.onnx` model.

Define a function that exports the model to `.onnx`. As arguments, it takes the trained model, a `torch.Tensor` with the same shape *(BATCH, CHANNEL, HEIGHT, WIDTH)* as the model expects (the values in the tensor aren't important), and the filepath where the model should be saved. An optional `cpu` flag must stay `True`, because only CPU export is supported.

## Define Ryan's Model and Load Its Weights

In [1]:
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import seaborn as sns
import pandas as pd
import numpy as np
import warnings
import pickle
import glob
import math
import time
import cv2
import sys
import os

import tifffile as tiff

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from typing import List, Callable, Union, Any, TypeVar, Tuple

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import sklearn

sns.set_theme(style = "white")

warnings.filterwarnings("ignore")

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
def ConvBlock(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Build a Conv2d -> BatchNorm2d -> ReLU unit.

    The convolution has no bias because the BatchNorm that follows adds its own
    shift. The layers are returned as an ``nn.Sequential`` with indices 0, 1, 2.
    Those indices appear in state_dict keys, so their order must not change.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)

def FinalBlock(in_channels, out_channels):
    """Map decoder features to ``out_channels`` with two 1x1 convolutions.

    A 1x1 ConvBlock (conv + BatchNorm + ReLU) mixes channels first. A plain 1x1
    Conv2d then produces the output, with no normalization or activation.
    """
    channel_mixer = ConvBlock(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
    projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
    return nn.Sequential(channel_mixer, projection)

def MiddleBlock(in_channels, out_channels):
    """Build the bottleneck: two 3x3 ConvBlocks with Dropout(p=0.2) between them."""
    layers = [
        ConvBlock(in_channels, out_channels),
        nn.Dropout(p=0.2),
        ConvBlock(out_channels, out_channels),
    ]
    return nn.Sequential(*layers)

class ResidualBlock(nn.Module):
    """ Residual encoder block.

    The main branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. It is added to
    the (projected) input and passed through a ReLU. ``forward`` returns two
    tensors:

    * the result after a 2x2 max-pool, which feeds the next encoder stage;
    * the un-pooled activation, which is used as the U-Net skip connection.

    Args:
        in_channels: Channels of the input tensor.
        feature_maps: Channels produced by the block.
        stride: Stride of the first 3x3 conv and of the default projection.
        downsample: Optional module applied to the identity branch. It must map
            the input to the same shape as the main branch. When None (the
            default), a 1x1 conv + BatchNorm projection is built. Its submodule
            names (``downsample.0.*``, ``downsample.1.*``) are what existing
            checkpoints expect.
    """
    def __init__(self, in_channels, feature_maps, stride = 1, downsample = None):
        super().__init__()

        self.relu = nn.ReLU(inplace = True)
        self.maxpool = nn.MaxPool2d(kernel_size = (2, 2), stride = None)

        # Use the caller's projection if one is given. Otherwise build the
        # default 1x1 projection so the identity has ``feature_maps`` channels.
        if downsample is None:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels, feature_maps, kernel_size = (1, 1), stride = stride, bias = False),
                nn.BatchNorm2d(feature_maps)
            )
        self.downsample = downsample

        self.conv1 = nn.Conv2d(in_channels, feature_maps,  kernel_size = (3, 3), stride = stride, padding = 1, bias = False)
        self.bn1   = nn.BatchNorm2d(feature_maps)

        self.conv2 = nn.Conv2d(feature_maps, feature_maps, kernel_size = (3, 3), stride = 1,      padding = 1, bias = False)
        self.bn2   = nn.BatchNorm2d(feature_maps)

    def forward(self, x):
        """Return ``(pooled, skip_connection)`` for input ``x``."""
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)

        # Guard kept in case a caller later replaces the projection with None.
        if self.downsample is not None:
            identity = self.downsample(identity)

        x = x + identity

        skip_connection = self.relu(x)

        x = self.maxpool(skip_connection)

        return x, skip_connection

class Encoder(nn.Module):
    """U-Net encoder: three residual stages followed by the bottleneck.

    ``forward`` returns ``[x0, x1, x2, x3]``:

    * x0: 32-channel skip at full resolution;
    * x1: 64-channel skip at 1/2 resolution;
    * x2: 128-channel skip at 1/4 resolution;
    * x3: 256-channel bottleneck output at 1/8 resolution.

    These resolutions hold with the default stride=1, because each stage
    halves the size only through its 2x2 max-pool.
    """
    def __init__(self, in_channels):
        super().__init__()

        # Attribute names are part of the state_dict keys; keep them unchanged.
        self.Encoder_0 = ResidualBlock(in_channels=in_channels, feature_maps=32)
        self.Encoder_1 = ResidualBlock(in_channels=32, feature_maps=64)
        self.Encoder_2 = ResidualBlock(in_channels=64, feature_maps=128)

        self.Middle = MiddleBlock(in_channels=128, out_channels=256)

    def forward(self, x):
        features = []
        for stage in (self.Encoder_0, self.Encoder_1, self.Encoder_2):
            x, skip = stage(x)
            features.append(skip)
        features.append(self.Middle(x))
        return features
    
class DecoderBlock(nn.Module):
    """Upsample x2, concatenate the encoder skip on channels, then apply two ConvBlocks.

    ``conv_block_0`` is built for ``in_channels + out_channels`` inputs. The
    skip connection must therefore carry ``out_channels`` channels and match
    the upsampled spatial size.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2)

        concat_channels = in_channels + out_channels
        self.conv_block_0 = ConvBlock(concat_channels, out_channels)
        self.conv_block_1 = ConvBlock(out_channels, out_channels)

    def forward(self, x, skip_connection):
        upsampled = self.up(x)
        merged = torch.cat([upsampled, skip_connection], dim=1)
        return self.conv_block_1(self.conv_block_0(merged))
    
class Decoder(nn.Module):
    """U-Net decoder head.

    Three decoder stages (256 -> 128 -> 64 -> 32 channels) consume the encoder
    skips from deepest to shallowest. A final 1x1 block then maps the result to
    ``out_channels``.
    """
    def __init__(self, out_channels):
        super().__init__()

        # Attribute names are part of the state_dict keys; keep them unchanged.
        self.decoder_0 = DecoderBlock(256, 128)
        self.decoder_1 = DecoderBlock(128, 64)
        self.decoder_2 = DecoderBlock(64, 32)

        self.FinalBlock = FinalBlock(in_channels=32, out_channels=out_channels)

    def forward(self, x0, x1, x2, x3):
        x = x3
        stages = ((self.decoder_0, x2), (self.decoder_1, x1), (self.decoder_2, x0))
        for block, skip in stages:
            x = block(x, skip)
        return self.FinalBlock(x)

class UNet(nn.Module):
    """Residual-encoder U-Net: input (B, in_channels, H, W) -> (B, out_channels, H, W).

    The encoder applies three 2x2 max-pools and the decoder three x2 upsamples,
    each followed by a concatenation with a skip. H and W must therefore be
    divisible by 8 for the skip shapes to line up.
    """
    def __init__(self, in_channels, out_channels=1):
        super().__init__()

        self.backbone = Encoder(in_channels)
        self.head = Decoder(out_channels)

    def forward(self, x):
        skip0, skip1, skip2, bottleneck = self.backbone(x)
        return self.head(skip0, skip1, skip2, bottleneck)


In [3]:
# This is the preprocessing I do to this image.
def preprocess_image(image, max_value = 65_533):
    """Linearly rescale intensities from [0, max_value] to [-1, 1].

    0 maps to -1, ``max_value / 2`` maps to 0, and ``max_value`` maps to 1.
    Works element-wise on scalars, NumPy arrays and torch tensors.

    NOTE(review): 65_533 looks like it may be a typo for 65_535 (the uint16
    maximum). The default is kept as-is; confirm it matches the normalization
    used when the model was trained.
    """
    return (image / max_value) * 2 - 1

model = UNet(1, 1)
# map_location keeps loading working on machines without the device the
# checkpoint was saved from.
checkpoint = torch.load("./model_0_cpu.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

# Switch to inference mode BEFORE any forward pass. In training mode, every
# BatchNorm2d layer folds the statistics of its input batch into
# running_mean/running_var. A sanity-check pass on random data would therefore
# overwrite the trained statistics that later get exported to ONNX.
# eval() also disables the Dropout in the bottleneck.
model.eval()

# Verify output size is as expected.
with torch.no_grad():
    check = model(torch.rand(2, 1, 512, 512))
check.shape

torch.Size([2, 1, 512, 512])

## Code for making ONNX model

In [5]:
%load_ext autoreload
%autoreload 2
import os
from pathlib import Path
import torch
import torch.nn
import torch.onnx
import onnx

def my_export_onnx(model:torch.nn.Module, im:torch.Tensor, filepath:str, cpu:bool = True):
    """
    Export model to `.onnx` file and validate it with the ONNX checker.

    Every argument is validated before the model is touched, so a rejected call
    leaves ``model`` unchanged. The model's original train/eval mode is restored
    once the export finishes.

    Args:
        model: Model to convert to .onnx file. ``Module.cpu()`` works in place,
            so the caller's model is left on the CPU afterwards.
        im (torch.Tensor): Dummy input of the shape expected at inference,
            (B, 1, H, W); its values are ignored. The batch axis is exported
            as dynamic. H and W are fixed to the sizes of ``im``.
        filepath (str): Where to save the file. It must end in ``.onnx`` and
            its parent directory must already exist.
        cpu (bool): Must be True. The model and input are moved to the CPU
            before export; GPU export is not implemented.

    Raises:
        ValueError: If ``filepath`` or the shape of ``im`` is invalid.
        NotImplementedError: If ``cpu`` is False.
    """
    path = Path(filepath)
    if not path.parent.absolute().is_dir():
        raise ValueError(f"Invalid path to save: {filepath}. Parent directory doesn't exist")
    if path.suffix != ".onnx":
        raise ValueError(f"Invalid path to save: {filepath}. Must be `.onnx` file.")
    shape = im.shape
    if len(shape) != 4:
        raise ValueError(f"Invalid input tensor shape {shape}. Must be (?, 1, ?, ?) -> (B, C, H, W).")
    if shape[1] != 1:
        raise ValueError(f"Invalid input tensor shape {shape}. Must have 1 channel -> (B, C, H, W).")
    # Reject unsupported options before any side effect (log line, mode switch).
    if not cpu:
        raise NotImplementedError("Must use cpu")

    print(f"Starting `.onnx.` export to: {filepath}")

    was_training = model.training
    # Trace inference behavior: BatchNorm uses its running statistics and
    # Dropout is disabled.
    model.eval()
    try:
        model = model.cpu()
        im = im.cpu()

        # Export model to .onnx file
        torch.onnx.export(
            model,                          # Model to save
            im,                             # Dummy torch.Tensor of expected size
            filepath,                       # Filepath to save
            export_params = True,           # store the trained parameter weights inside the model file
            opset_version = 12,             # the ONNX opset version to target
            do_constant_folding = True,     # whether to execute constant folding for optimization
            input_names = ['images'],       # the model's input names
            output_names = ['outputs'],     # the model's output names
            dynamic_axes = {                # Axes that may change at runtime (only the batch size here)
                "images": {0: "batch_size"},
                "outputs": {0: "batch_size"},
            }
        )
    finally:
        # Don't leave the caller's model stuck in eval mode.
        model.train(was_training)

    # Checks: reload the written file and run the ONNX model validator.
    model_onnx = onnx.load(filepath)
    onnx.checker.check_model(model_onnx)

    print("ONNX file successfully created!")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Dummy input with the (B, C, H, W) shape the model expects; only the shape matters.
batch, channel, height, width = 1, 1, 512, 512
im = torch.zeros((batch, channel, height, width))
print(f"Created model input of size {im.shape} (B, C, H, W)")

Created model input of size torch.Size([1, 1, 512, 512]) (B, C, H, W)


In [7]:
# Sanity-check the output shape in inference mode. A training-mode forward pass
# would update the BatchNorm running statistics that are exported to ONNX next.
model.eval()
with torch.no_grad():
    out = model(im)
print(im.shape)
print(out.shape)

torch.Size([1, 1, 512, 512])
torch.Size([1, 1, 512, 512])


In [8]:
# Convert the loaded model, tracing it with the dummy input built above.
onnx_name = "model_0_cpu.onnx"
my_export_onnx(model, im, onnx_name)

Starting `.onnx.` export to: model_0_cpu.onnx
verbose: False, log level: Level.ERROR

ONNX file successfully created!
