# Exporting TinyYoloV2 to ONNX

## Prepare Workspace

### Define Google Colab Flag

In [1]:
# Master switch: True when the notebook runs on Google Colab (enables the Drive
# mount, package installation and the /content path setup in later cells).
GOOGLE_COLAB: bool = True

### Mount Google Drive

In [None]:
if GOOGLE_COLAB:
    import os
    from google.colab import drive

    # Probe the same 'MyDrive' directory that the later cells read from, so the
    # "already mounted" check agrees with the paths the notebook actually uses.
    if not os.path.exists('/content/drive/MyDrive'):
        print("Mounting Google Drive...")
        drive.mount('/content/drive')
    else:
        print("Google Drive is already mounted.")

### Set-up Directories & Install Libraries
Create the directories needed and copy uploaded files into them

In [None]:
!pip uninstall onnxconverter-common -y
if GOOGLE_COLAB:
    !pip install torchinfo
    !pip install torchvision pillow
    !pip install onnx
    !pip install onnxscript
    !pip install onnxruntime-gpu==1.19.0
    !pip install onnxruntime-extensions
    !pip install onnxconverter-common

    !mkdir /content/data

    !cp /content/drive/MyDrive/eml_challenge/data/person_indices.json /content/data
    !cp -r /content/drive/MyDrive/eml_challenge/utils /content
    !cp /content/drive/MyDrive/eml_challenge/tinyyolov2_fused.py /content

### Define Path to Weights and Models

In [None]:
# Directory prefixes (each ends with '/') that later cells concatenate with file
# names: WEIGHTS_PATH holds the trained weights, MODELS_PATH receives the
# exported ONNX files.
WEIGHTS_PATH, MODELS_PATH = (
    ("/content/drive/MyDrive/eml_challenge/weights/", "/content/drive/MyDrive/eml_challenge/")
    if GOOGLE_COLAB
    else ("./", "./models/")
)

### Append Directory Paths to System Path

In [5]:
import sys
if GOOGLE_COLAB:
    # Make the files copied into the Colab workspace importable. Entries that are
    # already present are skipped so that re-running this cell does not keep
    # growing sys.path with duplicates.
    for extra_path in ('/content', '/content/data', '/content/utils', WEIGHTS_PATH):
        if extra_path not in sys.path:
            sys.path.append(extra_path)

### Import Libraries

In [6]:
# Pytorch libraries
import torch
import torchinfo
import torch.nn as nn

# Other libraries
import numpy as np
import tqdm
import time

# ONNX libraries
import onnx
import onnxruntime
from onnxconverter_common import float16

# EML libraries
from tinyyolov2_fused import FusedTinyYoloV2
from utils.dataloader_v2 import VOCDataset
from utils.ap import precision_recall_levels, ap, display_roc
from utils.yolo import nms, filter_boxes

## Define ONNX Export Functions

### Define export_to_onnx Function

In [7]:
def export_to_onnx(model, onnx_path):
    """Export a PyTorch model to an ONNX file and validate the result.

    The model is moved to ``torch_device`` and put into eval mode; both changes
    persist on the passed-in model after the call. The exported graph has one
    input named ``'input'`` and one output named ``'output'``, each with a
    dynamic first (batch) axis.

    Args:
        model: The ``torch.nn.Module`` to export.
        onnx_path: Destination path of the ONNX file.

    Raises:
        Whatever ``onnx.checker.check_model`` raises if the saved file is invalid.
    """
    model.to(device=torch_device)
    model.eval()

    # Example input that fixes the non-batch dimensions of the exported graph.
    # Zeros are used instead of torch.empty so the trace never runs on
    # uninitialised memory (which may contain NaN/Inf garbage).
    dummy_input = torch.zeros(TEST_BATCH_SIZE, IMAGE_CHANNELS, IMAGE_LENGTH, IMAGE_WIDTH, device=torch_device)

    # Export model using TorchScript based ONNX exporter
    torch.onnx.export(model,                     # model being run
                    dummy_input,                 # model input (or a tuple for multiple inputs)
                    onnx_path,                   # where to save the model (can be a file or file-like object)
                    export_params=True,          # store the trained parameter weights inside the model file
                    verbose=False,               # print information to stdout
                    opset_version=OPSET_VERSION, # the ONNX version to export the model to
                    input_names = ['input'],     # the model's input names
                    output_names = ['output'],   # the model's output names
                    dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                  'output' : {0 : 'batch_size'}})

    # Load and check if ONNX network was saved correctly
    onnx_net = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_net)

### Define convert_onnx_fp32_to_fp16 Function

In [8]:
def convert_onnx_fp32_to_fp16(onnx_path, onnx_path_fp16):
    """Convert a saved ONNX model to float16 and write it to a new file.

    Args:
        onnx_path: Path of the existing ONNX model to read.
        onnx_path_fp16: Path where the converted model is saved.

    Raises:
        Whatever ``onnx.checker.check_model`` raises if the written file is invalid.
    """
    fp32_graph = onnx.load(onnx_path)
    onnx.save(float16.convert_float_to_float16(fp32_graph), onnx_path_fp16)

    # Re-read the file from disk so the check covers what was actually written
    onnx.checker.check_model(onnx.load(onnx_path_fp16))

## Define Precision and Latency Measurement Functions

### Define measure_pytorch_latency Function

In [33]:
def measure_pytorch_latency(model):
    """Time one full pass of ``model`` over ``test_loader``.

    The measured interval covers everything inside the loop: fetching batches
    from the loader, moving them to ``torch_device`` and the forward pass.

    Args:
        model: The ``torch.nn.Module`` to time; it is moved to ``torch_device``.

    Returns:
        Elapsed time in seconds as a float.
    """
    model.to(torch_device)
    use_cuda = torch.cuda.is_available()

    # CUDA kernels are launched asynchronously, so wait for pending GPU work on
    # both sides of the timed region; otherwise the clock only measures launches.
    if use_cuda:
        torch.cuda.synchronize()
    t_start = time.perf_counter()

    with torch.no_grad():
        for images, _ in tqdm.tqdm(test_loader, total=len(test_loader)):
            images = images.to(torch_device)
            model(images)

    if use_cuda:
        torch.cuda.synchronize()
    t_end = time.perf_counter()

    return t_end - t_start

### Define measure_pytorch_f16_latency() Function

In [36]:
def measure_pytorch_f16_latency(model):
    """Time one full pass of ``model`` over ``test_loader`` in float16.

    The model and every input batch are cast to ``torch.float16``. The measured
    interval covers batch loading, the device/dtype transfer and the forward pass.

    Args:
        model: The ``torch.nn.Module`` to time; it is moved to ``torch_device``
            and converted to float16.

    Returns:
        Elapsed time in seconds as a float.
    """
    model.to(device=torch_device, dtype=torch.float16)
    use_cuda = torch.cuda.is_available()

    # CUDA kernels are launched asynchronously, so wait for pending GPU work on
    # both sides of the timed region; otherwise the clock only measures launches.
    if use_cuda:
        torch.cuda.synchronize()
    t_start = time.perf_counter()

    with torch.no_grad():
        for images, _ in tqdm.tqdm(test_loader, total=len(test_loader)):
            images = images.to(torch_device, dtype=torch.float16)
            model(images)

    if use_cuda:
        torch.cuda.synchronize()
    t_end = time.perf_counter()

    return t_end - t_start

### Define measure_onnx_latency Function

In [25]:
def measure_onnx_latency(onnx_session):
    """Time one full pass of a float32 ONNX Runtime session over ``test_loader``.

    With CUDA available, inputs and outputs are exchanged through an IO binding
    on GPU buffers owned by PyTorch. Without CUDA, the session is fed NumPy
    arrays through ``onnx_session.run``.

    Args:
        onnx_session: An ``onnxruntime.InferenceSession`` whose graph has an input
            named ``'input'`` and an output named ``'output'``.

    Returns:
        Elapsed time in seconds as a float (includes batch loading and transfers).
    """
    use_cuda = torch.cuda.is_available()
    binding = None
    bound_rows = None

    if use_cuda:
        # The GPU buffer and binding are created only on the CUDA path: binding a
        # CPU tensor's pointer as a 'cuda' buffer would be invalid.
        ort_output = torch.empty([TEST_BATCH_SIZE, OUTPUT_DIMENSON_1, OUTPUT_DIMENSON_2, OUTPUT_DIMENSON_3, OUTPUT_DIMENSON_4]
                                 , dtype=torch.float32
                                 , device=torch_device)

        binding = onnx_session.io_binding()
        binding.bind_output(
            name='output',
            device_type='cuda',
            device_id=0,
            element_type=np.float32,
            shape=tuple(ort_output.shape),
            buffer_ptr=ort_output.data_ptr(),
        )
        bound_rows = TEST_BATCH_SIZE
        torch.cuda.synchronize()

    t_start = time.perf_counter()

    with torch.no_grad():
        for images, _ in tqdm.tqdm(test_loader, total=len(test_loader)):
            if use_cuda:
                images = images.to(torch_device)

                # A final batch may hold fewer samples than TEST_BATCH_SIZE; rebind
                # the output to the leading rows of the same buffer in that case.
                rows = images.shape[0]
                if rows != bound_rows:
                    binding.bind_output(
                        name='output',
                        device_type='cuda',
                        device_id=0,
                        element_type=np.float32,
                        shape=(rows,) + tuple(ort_output.shape[1:]),
                        buffer_ptr=ort_output.data_ptr(),
                    )
                    bound_rows = rows

                binding.bind_input(
                    name='input',
                    device_type='cuda',
                    device_id=0,
                    element_type=np.float32,
                    shape=tuple(images.shape),
                    buffer_ptr=images.data_ptr(),
                )

                onnx_session.run_with_iobinding(binding)
            else:
                ort_input = {onnx_session.get_inputs()[0].name: to_numpy(images)}
                onnx_session.run(None, ort_input)

    # Make sure outstanding GPU work is finished before reading the clock
    if use_cuda:
        torch.cuda.synchronize()
    t_end = time.perf_counter()

    return t_end - t_start

### Define measure_onnx_fp16_latency Function

In [12]:
def measure_onnx_fp16_latency(onnx_session):
    """Time one full pass of a float16 ONNX Runtime session over ``test_loader``.

    Every batch is cast to ``torch.float16`` before being handed to the session.
    With CUDA available, inputs and outputs are exchanged through an IO binding
    on GPU buffers owned by PyTorch. Without CUDA, the session is fed NumPy
    arrays through ``onnx_session.run``.

    Args:
        onnx_session: An ``onnxruntime.InferenceSession`` whose graph has a float16
            input named ``'input'`` and a float16 output named ``'output'``.

    Returns:
        Elapsed time in seconds as a float (includes batch loading and transfers).
    """
    use_cuda = torch.cuda.is_available()
    binding_fp16 = None
    bound_rows = None

    if use_cuda:
        # The GPU buffer and binding are created only on the CUDA path: binding a
        # CPU tensor's pointer as a 'cuda' buffer would be invalid.
        ort_output = torch.empty([TEST_BATCH_SIZE, OUTPUT_DIMENSON_1, OUTPUT_DIMENSON_2, OUTPUT_DIMENSON_3, OUTPUT_DIMENSON_4]
                                 , dtype=torch.float16
                                 , device=torch_device)

        binding_fp16 = onnx_session.io_binding()
        binding_fp16.bind_output(
            name='output',
            device_type='cuda',
            device_id=0,
            element_type=np.float16,
            shape=tuple(ort_output.shape),
            buffer_ptr=ort_output.data_ptr(),
        )
        bound_rows = TEST_BATCH_SIZE
        torch.cuda.synchronize()

    t_start = time.perf_counter()

    with torch.no_grad():
        for images, _ in tqdm.tqdm(test_loader, total=len(test_loader)):
            if use_cuda:
                images = images.to(torch_device, dtype=torch.float16)

                # A final batch may hold fewer samples than TEST_BATCH_SIZE; rebind
                # the output to the leading rows of the same buffer in that case.
                rows = images.shape[0]
                if rows != bound_rows:
                    binding_fp16.bind_output(
                        name='output',
                        device_type='cuda',
                        device_id=0,
                        element_type=np.float16,
                        shape=(rows,) + tuple(ort_output.shape[1:]),
                        buffer_ptr=ort_output.data_ptr(),
                    )
                    bound_rows = rows

                binding_fp16.bind_input(
                    name='input',
                    device_type='cuda',
                    device_id=0,
                    element_type=np.float16,
                    shape=tuple(images.shape),
                    buffer_ptr=images.data_ptr(),
                )

                onnx_session.run_with_iobinding(binding_fp16)
            else:
                images = images.to(dtype=torch.float16)
                ort_input = {onnx_session.get_inputs()[0].name: to_numpy(images)}
                onnx_session.run(None, ort_input)

    # Make sure outstanding GPU work is finished before reading the clock
    if use_cuda:
        torch.cuda.synchronize()
    t_end = time.perf_counter()

    return t_end - t_start

### Define measure_pytorch_avg_precision Function

In [13]:
def measure_pytorch_avg_precision(model):
    """Compute the average precision of ``model`` over ``test_loader``.

    Each batch is run through the model with ``yolo=True``, then passed through
    ``filter_boxes`` (``CONFIDENCE_THRESHOLD``) and ``nms`` (``NMS_THRESHOLD``).
    Per-sample precision/recall levels are collected, reduced with ``ap`` and
    plotted with ``display_roc``.

    Args:
        model: The ``torch.nn.Module`` to evaluate; it is moved to ``torch_device``.

    Returns:
        The value returned by ``ap`` for the collected precision/recall levels.
    """
    model.to(torch_device)

    precisions, recalls = [], []

    with torch.no_grad():
        for images, target in tqdm.tqdm(test_loader, total=len(test_loader)):
            images = images.to(torch_device)
            target = target.to(torch_device)

            detections = model(images, yolo=True)
            detections = nms(filter_boxes(detections, CONFIDENCE_THRESHOLD), NMS_THRESHOLD)

            # Collect the precision/recall levels of every sample in the batch
            for sample in range(len(target)):
                sample_precision, sample_recall = precision_recall_levels(target[sample], detections[sample])
                precisions.append(sample_precision)
                recalls.append(sample_recall)

    # Reduce the collected levels to a single score, then plot them
    average_precision = ap(precisions, recalls)
    display_roc(precisions, recalls)

    return average_precision

### Define measure_pytorch_fp16_avg_precision Function

In [14]:
def measure_pytorch_fp16_avg_precision(model):
    """Compute the average precision of ``model`` over ``test_loader`` in float16.

    The model, the inputs and the targets are all cast to ``torch.float16``.
    Each batch is run through the model with ``yolo=True``, then passed through
    ``filter_boxes`` (``CONFIDENCE_THRESHOLD``) and ``nms`` (``NMS_THRESHOLD``).
    Per-sample precision/recall levels are collected, reduced with ``ap`` and
    plotted with ``display_roc``.

    Args:
        model: The ``torch.nn.Module`` to evaluate; it is moved to ``torch_device``
            and converted to float16.

    Returns:
        The value returned by ``ap`` for the collected precision/recall levels.
    """
    model.to(device=torch_device, dtype=torch.float16)

    precisions, recalls = [], []

    with torch.no_grad():
        for images, target in tqdm.tqdm(test_loader, total=len(test_loader)):
            images = images.to(device=torch_device, dtype=torch.float16)
            target = target.to(device=torch_device, dtype=torch.float16)

            detections = model(images, yolo=True)
            detections = nms(filter_boxes(detections, CONFIDENCE_THRESHOLD), NMS_THRESHOLD)

            # Collect the precision/recall levels of every sample in the batch
            for sample in range(len(target)):
                sample_precision, sample_recall = precision_recall_levels(target[sample], detections[sample])
                precisions.append(sample_precision)
                recalls.append(sample_recall)

    # Reduce the collected levels to a single score, then plot them
    average_precision = ap(precisions, recalls)
    display_roc(precisions, recalls)

    return average_precision

### Define measure_onnx_avg_precision Function

In [15]:
def measure_onnx_avg_precision(onnx_session):
    """Compute the average precision of a float32 ONNX session over ``test_loader``.

    The raw session output of each batch is passed through ``filter_boxes``
    (``CONFIDENCE_THRESHOLD``) and ``nms`` (``NMS_THRESHOLD``). Per-sample
    precision/recall levels are collected, reduced with ``ap`` and plotted with
    ``display_roc``.

    With CUDA available, inference uses an IO binding on GPU buffers owned by
    PyTorch; otherwise the session is fed NumPy arrays on the CPU.

    Args:
        onnx_session: An ``onnxruntime.InferenceSession`` whose graph has an input
            named ``'input'`` and an output named ``'output'``.

    Returns:
        The value returned by ``ap`` for the collected precision/recall levels.
    """
    test_precision = []
    test_recall = []

    use_cuda = torch.cuda.is_available()
    binding = None
    bound_rows = None

    if use_cuda:
        # The GPU buffer and binding are created only on the CUDA path: binding a
        # CPU tensor's pointer as a 'cuda' buffer would be invalid.
        ort_output = torch.empty([TEST_BATCH_SIZE, OUTPUT_DIMENSON_1, OUTPUT_DIMENSON_2, OUTPUT_DIMENSON_3, OUTPUT_DIMENSON_4]
                                 , dtype=torch.float32
                                 , device=torch_device)

        binding = onnx_session.io_binding()
        binding.bind_output(
            name='output',
            device_type='cuda',
            device_id=0,
            element_type=np.float32,
            shape=tuple(ort_output.shape),
            buffer_ptr=ort_output.data_ptr(),
        )
        bound_rows = TEST_BATCH_SIZE

    with torch.no_grad():
        for images, target in tqdm.tqdm(test_loader, total=len(test_loader)):
            if use_cuda:
                images = images.to(torch_device)
                target = target.to(torch_device)

                # A final batch may hold fewer samples than TEST_BATCH_SIZE; rebind
                # the output to the leading rows of the same buffer in that case.
                rows = images.shape[0]
                if rows != bound_rows:
                    binding.bind_output(
                        name='output',
                        device_type='cuda',
                        device_id=0,
                        element_type=np.float32,
                        shape=(rows,) + tuple(ort_output.shape[1:]),
                        buffer_ptr=ort_output.data_ptr(),
                    )
                    bound_rows = rows

                binding.bind_input(
                    name='input',
                    device_type='cuda',
                    device_id=0,
                    element_type=np.float32,
                    shape=tuple(images.shape),
                    buffer_ptr=images.data_ptr(),
                )
                onnx_session.run_with_iobinding(binding)
                # Only the leading rows were written for this batch
                raw_output = ort_output[:rows]
            else:
                ort_input = {onnx_session.get_inputs()[0].name: to_numpy(images)}
                raw_output = torch.from_numpy(onnx_session.run(None, ort_input)[0])

            # NMS must run on the confidence-filtered boxes, not on the raw output
            output = filter_boxes(raw_output, CONFIDENCE_THRESHOLD)
            output = nms(output, NMS_THRESHOLD)

            # Calculate precision and recall for each sample
            for i in range(len(target)):
                precision, recall = precision_recall_levels(target[i], output[i])
                test_precision.append(precision)
                test_recall.append(recall)

    # Calculate average precision with collected samples
    average_precision = ap(test_precision, test_recall)
    # Plot the collected precision/recall levels
    display_roc(test_precision, test_recall)

    return average_precision

### Define measure_onnx_fp16_avg_precision Function

In [16]:
def measure_onnx_fp16_avg_precision(onnx_session):
    """Compute the average precision of a float16 ONNX session over ``test_loader``.

    Inputs and targets are cast to ``torch.float16``. The raw session output of
    each batch is passed through ``filter_boxes`` (``CONFIDENCE_THRESHOLD``) and
    ``nms`` (``NMS_THRESHOLD``). Per-sample precision/recall levels are
    collected, reduced with ``ap`` and plotted with ``display_roc``.

    With CUDA available, inference uses an IO binding on GPU buffers owned by
    PyTorch; otherwise the session is fed NumPy arrays on the CPU.

    Args:
        onnx_session: An ``onnxruntime.InferenceSession`` whose graph has a float16
            input named ``'input'`` and a float16 output named ``'output'``.

    Returns:
        The value returned by ``ap`` for the collected precision/recall levels.
    """
    test_precision = []
    test_recall = []

    use_cuda = torch.cuda.is_available()
    binding_fp16 = None
    bound_rows = None

    if use_cuda:
        # The GPU buffer and binding are created only on the CUDA path: binding a
        # CPU tensor's pointer as a 'cuda' buffer would be invalid.
        ort_output = torch.empty([TEST_BATCH_SIZE, OUTPUT_DIMENSON_1, OUTPUT_DIMENSON_2, OUTPUT_DIMENSON_3, OUTPUT_DIMENSON_4]
                                 , dtype=torch.float16
                                 , device=torch_device)

        binding_fp16 = onnx_session.io_binding()
        binding_fp16.bind_output(
            name='output',
            device_type='cuda',
            device_id=0,
            element_type=np.float16,
            shape=tuple(ort_output.shape),
            buffer_ptr=ort_output.data_ptr(),
        )
        bound_rows = TEST_BATCH_SIZE

    with torch.no_grad():
        for images, target in tqdm.tqdm(test_loader, total=len(test_loader)):
            if use_cuda:
                images = images.to(device=torch_device, dtype=torch.float16)
                target = target.to(device=torch_device, dtype=torch.float16)

                # A final batch may hold fewer samples than TEST_BATCH_SIZE; rebind
                # the output to the leading rows of the same buffer in that case.
                rows = images.shape[0]
                if rows != bound_rows:
                    binding_fp16.bind_output(
                        name='output',
                        device_type='cuda',
                        device_id=0,
                        element_type=np.float16,
                        shape=(rows,) + tuple(ort_output.shape[1:]),
                        buffer_ptr=ort_output.data_ptr(),
                    )
                    bound_rows = rows

                binding_fp16.bind_input(
                    name='input',
                    device_type='cuda',
                    device_id=0,
                    element_type=np.float16,
                    shape=tuple(images.shape),
                    buffer_ptr=images.data_ptr(),
                )
                onnx_session.run_with_iobinding(binding_fp16)
                # Only the leading rows were written for this batch
                raw_output = ort_output[:rows]
            else:
                images = images.to(dtype=torch.float16)
                target = target.to(dtype=torch.float16)

                ort_input = {onnx_session.get_inputs()[0].name: to_numpy(images)}
                raw_output = torch.from_numpy(onnx_session.run(None, ort_input)[0])

            # NMS must run on the confidence-filtered boxes, not on the raw output
            output = filter_boxes(raw_output, CONFIDENCE_THRESHOLD)
            output = nms(output, NMS_THRESHOLD)

            # Calculate precision and recall for each sample
            for i in range(len(target)):
                precision, recall = precision_recall_levels(target[i], output[i])
                test_precision.append(precision)
                test_recall.append(recall)

    # Calculate average precision with collected samples
    average_precision = ap(test_precision, test_recall)
    # Plot the collected precision/recall levels
    display_roc(test_precision, test_recall)

    return average_precision

## Define Helper Functions

### Define to_numpy Function

In [17]:
def to_numpy(tensor):
    """Return the contents of a torch tensor as a NumPy array on the CPU.

    Tensors that track gradients are detached first, since ``.numpy()`` cannot
    be called on a tensor that requires grad.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

## Define Global Variables

### Define Constants

In [18]:
# Input image geometry; the ONNX export uses these as the
# (channels, length, width) dimensions that follow the batch axis.
IMAGE_LENGTH, IMAGE_WIDTH, IMAGE_CHANNELS = 320, 320, 3

# Non-batch dimensions of the buffers pre-allocated for the ONNX output.
# NOTE(review): presumably these must match the exported network's real output
# shape for the image size above - confirm when changing either.
OUTPUT_DIMENSON_1, OUTPUT_DIMENSON_2, OUTPUT_DIMENSON_3, OUTPUT_DIMENSON_4 = 5, 10, 10, 6

### Define Hyperparameters

In [19]:
# Thresholds handed to filter_boxes (confidence) and nms (NMS) respectively
CONFIDENCE_THRESHOLD, NMS_THRESHOLD = 0.0, 0.5

# Number of images per batch served by the test DataLoader
TEST_BATCH_SIZE: int = 1

# ONNX opset version requested from torch.onnx.export
OPSET_VERSION: int = 20

### Define Measurement Cases Flags

In [30]:
# Switches selecting which latency measurements the execution cell runs
PYTORCH_LATENCY_MEASUREMENT: bool = True
PYTORCH_LATENCY_FP16_MEASUREMENT: bool = True
ONNX_LATENCY_MEASUREMENT: bool = True
ONNX_LATENCY_FP16_MEASUREMENT: bool = True

# Switches selecting which average-precision measurements the execution cell runs
PYTORCH_PRECISION_MEASUREMENT: bool = False
PYTORCH_FP16_PRECISION_MEASUREMENT: bool = False
ONNX_PRECISION_MEASUREMENT: bool = False
ONNX_PRECISION_FP16_MEASUREMENT: bool = False

# When True (and running on Colab), the last cell releases the Colab runtime
DISCONNECT_RUN_TIME: bool = False

### Define dataLoader, device & models

In [None]:
# Choose the torch device and the matching ONNX Runtime execution provider, then
# load the fine-tuned weights (remapped onto the CPU when no GPU is present).
if not torch.cuda.is_available():
    torch_device = torch.device('cpu')
    provider = ["CPUExecutionProvider"]
    sd = torch.load(f"{WEIGHTS_PATH}fused_voc_fine_tuned.pt", weights_only=True, map_location=torch_device)
    print("Using CPU")
else:
    torch_device = torch.device('cuda:0')
    provider = ["CUDAExecutionProvider"]
    sd = torch.load(f"{WEIGHTS_PATH}fused_voc_fine_tuned.pt", weights_only=True)
    print("Using GPU")

# Validation split restricted to person images, served in a fixed order
test_dataset = VOCDataset(root="/content/data", year="2012", image_set='val', only_person=True) # Contains 2232 pictures
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
)

# Two independent copies of the network with identical weights: one stays in
# float32, the other is converted to float16 on the selected device.
net = FusedTinyYoloV2(num_classes=1)
net_fp16 = FusedTinyYoloV2(num_classes=1)
for network in (net, net_fp16):
    network.load_state_dict(sd)

net_fp16.to(device=torch_device, dtype=torch.float16)

# Destinations of the exported ONNX files
onnx_filepath = f"{MODELS_PATH}tiny_yolo.onnx"
onnx_filepath_fp16 = f"{MODELS_PATH}tiny_yolo_fp16.onnx"

# Export the float32 network, then derive the float16 ONNX file from it
export_to_onnx(model=net, onnx_path=onnx_filepath)
convert_onnx_fp32_to_fp16(onnx_path=onnx_filepath, onnx_path_fp16=onnx_filepath_fp16)

### Define ONNX Inference Session

In [None]:
# Session options shared by both sessions: enable every graph optimization level
session_options = onnxruntime.SessionOptions()
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

# One inference session per exported file: float32 first, then float16
ort_session = onnxruntime.InferenceSession(
    onnx_filepath,
    sess_options=session_options,
    providers=provider,
)
ort_session_fp16 = onnxruntime.InferenceSession(
    onnx_filepath_fp16,
    sess_options=session_options,
    providers=provider,
)

## Execute Workspace

### Execute Measurements

In [37]:
_BANNER = "**************************************************************************************"


def _run_measurement(label, measure, unit=""):
    """Run one measurement between two banner lines and return its value.

    Args:
        label: Name printed in the "Measuring ..." line and before the result.
        measure: Zero-argument callable that performs the measurement.
        unit: Text appended after the formatted result (e.g. " seconds").

    Returns:
        Whatever ``measure()`` returned.
    """
    print(_BANNER)
    print(f"Measuring {label}...")
    value = measure()
    print(f"{label}: {value:.5}{unit}")
    print(_BANNER + "\n")
    return value


# Each enabled measurement stores its result under its own name so that the FP16
# runs no longer overwrite the FP32 results.
if PYTORCH_LATENCY_MEASUREMENT:
    pytorch_latency = _run_measurement(
        "Pytorch Latency", lambda: measure_pytorch_latency(model=net), unit=" seconds")

if PYTORCH_LATENCY_FP16_MEASUREMENT:
    pytorch_fp16_latency = _run_measurement(
        "Pytorch FP16 Latency", lambda: measure_pytorch_f16_latency(model=net_fp16), unit=" seconds")

if ONNX_LATENCY_MEASUREMENT:
    onnx_latency = _run_measurement(
        "ONNX Latency", lambda: measure_onnx_latency(onnx_session=ort_session), unit=" seconds")

if ONNX_LATENCY_FP16_MEASUREMENT:
    onnx_fp16_latency = _run_measurement(
        "ONNX FP16 Latency", lambda: measure_onnx_fp16_latency(onnx_session=ort_session_fp16), unit=" seconds")

if PYTORCH_PRECISION_MEASUREMENT:
    pytorch_avg_precision = _run_measurement(
        "Pytorch Average Precision", lambda: measure_pytorch_avg_precision(model=net))

if PYTORCH_FP16_PRECISION_MEASUREMENT:
    pytorch_fp16_avg_precision = _run_measurement(
        "Pytorch FP16 Average Precision", lambda: measure_pytorch_fp16_avg_precision(model=net_fp16))

if ONNX_PRECISION_MEASUREMENT:
    onnx_avg_precision = _run_measurement(
        "ONNX Average Precision", lambda: measure_onnx_avg_precision(onnx_session=ort_session))

if ONNX_PRECISION_FP16_MEASUREMENT:
    onnx_fp16_avg_precision = _run_measurement(
        "ONNX FP16 Average Precision", lambda: measure_onnx_fp16_avg_precision(onnx_session=ort_session_fp16))

**************************************************************************************
Measuring Pytorch Latency...


100%|██████████| 2232/2232 [00:16<00:00, 137.85it/s]


Pytorch Latency: 16.194 seconds
**************************************************************************************

**************************************************************************************
Measuring Pytorch FP16 Latency...


100%|██████████| 2232/2232 [00:16<00:00, 134.52it/s]


Pytorch FP16 Latency: 16.594 seconds
**************************************************************************************

**************************************************************************************
Measuring ONNX Latency...


100%|██████████| 2232/2232 [00:22<00:00, 99.88it/s]


ONNX Latency: 22.349 seconds
**************************************************************************************

**************************************************************************************
Measuring ONNX FP16 Latency...


100%|██████████| 2232/2232 [00:19<00:00, 111.73it/s]

ONNX FP16 Latency: 19.979 seconds
**************************************************************************************






### Disconnect runtime

In [None]:
# Release the Colab runtime once everything has run; only when explicitly
# requested through DISCONNECT_RUN_TIME and when actually running on Colab.
if GOOGLE_COLAB:
    if DISCONNECT_RUN_TIME:
        from google.colab import runtime
        runtime.unassign()