## Export PyTorch Model to ONNX

Zahra needs `.onnx` files to be used in LabVIEW, but PyTorch saves the model weights/checkpoints as `.pth` files. Hence, we need to convert the PyTorch `.pth` weights to the `.onnx` format.

Define a function that converts the model to `.onnx`. As arguments, it takes the trained model, a `torch.Tensor` of the same size *(BATCH, CHANNEL, HEIGHT, WIDTH)* as is expected by the model (the values in the tensor aren't important), and the filepath of where to save the model.

In [2]:
%load_ext autoreload
%autoreload 2
import os
from pathlib import Path
import torch
import torch.nn
import torch.onnx
import onnx

def my_export_onnx(model:torch.nn.Module, im:torch.Tensor, filepath:str, cpu:bool = True):
    """
    Export a trained model to an `.onnx` file and run the ONNX checker on it.

    Every argument is validated before anything is printed or the model is
    modified, so a rejected call has no side effects. On a successful call
    the model is switched to eval mode and moved to the cpu.

    Args:
        model: Model to convert to .onnx file
        im (torch.Tensor): Dummy input of shape (B, 1, H, W) passed to the
            exporter. The batch axis is exported as dynamic.
        filepath (str): Location to save file. Must end in `.onnx` and its
            parent directory must already exist. Path-like objects are
            accepted and converted to a plain string.
        cpu (bool): True to send to cpu before export. Only True is supported.

    Raises:
        ValueError: If the parent directory of `filepath` does not exist, the
            extension is not `.onnx`, or `im` is not a 4-D single-channel tensor.
        NotImplementedError: If `cpu` is False.
    """
    # Normalise path-like input once so the exporter, the checker and the
    # messages below all see the same plain string.
    filepath = os.fspath(filepath)
    target = Path(filepath)

    save_dir = target.parent.absolute()
    if not os.path.isdir(save_dir):
        raise ValueError(f"Invalid path to save: {filepath}. Parent directory doesn't exist")
    if target.suffix != ".onnx":
        raise ValueError(f"Invalid path to save: {filepath}. Must be `.onnx` file.")

    _shape = im.shape
    if len(_shape) != 4:
        raise ValueError(f"Invalid input tensor shape {_shape}. Must be (?, 1, ?, ?) -> (B, C, H, W).")
    if _shape[1] != 1:
        raise ValueError(f"Invalid input tensor shape {_shape}. Must have 1 channel -> (B, C, H, W).")

    # Fail fast: reject the unsupported mode before printing or touching the
    # model, so the caller's model keeps its train/eval state on failure.
    if not cpu:
        raise NotImplementedError("Must use cpu")

    print(f"Starting `.onnx.` export to: {filepath}")

    # set the model to inference mode and keep model and input on the same device
    model.eval()
    model = model.cpu()
    im = im.cpu()

    # Export model to .onnx file
    torch.onnx.export(
        model,                          # Model to save
        im,                             # Dummy torch.Tensor of expected size
        filepath,                       # Filepath to save
        export_params = True,           # store the trained parameter weights inside the model file
        opset_version = 12,             # the ONNX version to export the model to
        do_constant_folding = True,     # whether to execute constant folding for optimization
        input_names = ['images'],       # the model's input names
        output_names = ['outputs'],     # the model's output names
        dynamic_axes = {                # Axes of inputs outputs that can change at runtime (aka diff batch size than im )
            "images": {0: "batch_size"},
            "outputs": {0: "batch_size"},
        }
    )

    # Reload the written file and let the ONNX checker validate it
    model_onnx = onnx.load(filepath)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    print("ONNX file successfully created!")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


With the function defined, we can now load the model from the `.pth` file and save it to an `.onnx` file.

In [3]:
# Location of the training run and the names of the checkpoint / exported file
models_dir = "../src/models/dl4mia_tissue_unet/results"
dir_name = "20230508_104630"
src_name = "best.pth"
# Exported file is named "<run folder>_<checkpoint name without extension>.onnx"
dst_name = "{}_{}.onnx".format(dir_name, Path(src_name).stem)

# The checkpoint is read from, and the .onnx file written to, the run's folder
src_file = "/".join((models_dir, dir_name, src_name))
dst_file = "/".join((models_dir, dir_name, dst_name))

# Stop here if the checkpoint is missing, before any model code is loaded
assert os.path.exists(src_file), "Source `.pth` file not a valid path"

In [4]:
# Rebuild the trained network from its checkpoint so it can be exported
from src.models.dl4mia_tissue_unet import model as v1
from src.models.dl4mia_tissue_unet import model_v2 as v2

device = "cpu"
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Read the checkpoint, mapping every stored tensor onto `device`
print(f"Loading `.pth` file: {src_file}")
checkpoint = torch.load(src_file, map_location=device)
state_dict, model_dict = checkpoint["model_state_dict"], checkpoint["model_dict"]

# Print the checkpoint metadata, skipping the state-dict and logger-data entries
print("Trained model info:")
for key in checkpoint:
    if "state_dict" in key or "logger_data" in key:
        continue
    print(f"\t{key} = {checkpoint[key]}")

# Instantiate the network with the stored constructor kwargs, then load the weights
model = v2.UNet(**model_dict["kwargs"])
model.load_state_dict(state_dict, strict=True)
model.to(device)
print("`.pth` file successfully loaded!")

Loading `.pth` file: ../src/models/dl4mia_tissue_unet/results/20230508_104630/best.pth
Trained model info:
	epoch = 98
	val_loss = 0.14062084142978376
	val_ap = 0.9189547437887925
	val_dice = 0.859847577718588
	best_loss = 0
	best_dice = 0.9133516412514907
	train_cuda = True
	model_dict = {'name': 'unet', 'kwargs': {'num_classes': 1, 'depth': 3, 'in_channels': 1, 'batch_norm': True}}
`.pth` file successfully loaded!


In [5]:
# Placeholder input for the model, laid out as (B, C, H, W)
batch, channel = 1, 1
height = width = 512
im = torch.zeros((batch, channel, height, width))
print(f"Created model input of size {im.shape} (B, C, H, W)")

Created model input of size torch.Size([1, 1, 512, 512]) (B, C, H, W)


In [6]:
# Sanity check: run one forward pass and compare input and output shapes.
# The pass runs in eval mode with autograd disabled so that the check itself
# cannot modify the trained model.
# NOTE(review): the checkpoint reports batch_norm=True; presumably the UNet
# uses batch-norm layers with running statistics, which a training-mode
# forward pass on this all-zero tensor would update before export - confirm.
model.eval()
with torch.no_grad():
    out = model(im)
print(im.shape)
print(out.shape)

torch.Size([1, 1, 512, 512])
torch.Size([1, 1, 512, 512])


In [7]:
# Write the loaded model to the destination `.onnx` file (cpu export by default)
my_export_onnx(model, im, filepath=dst_file)

Starting `.onnx.` export to: ../src/models/dl4mia_tissue_unet/results/20230508_104630/20230508_104630_best.onnx
verbose: False, log level: Level.ERROR

ONNX file successfully created!
