## Requirements

`pip install onnx onnxruntime onnxruntime-gpu`

In [59]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, SubsetRandomSampler
import torch.optim as optim

from torchvision.models import resnet18
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

import time
from tqdm import tqdm
import numpy as np

%load_ext autoreload
%autoreload 2 

INFERENCE_BATCH_SIZE=64
TRAIN_BATCH_SIZE=128

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## 1. Train a resnet18 from scratch

In [60]:
# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load FashionMNIST dataset
train_dataset = CIFAR10(root="./data", train=True, transform=transform, download=True)
test_dataset = CIFAR10(root="./data", train=False, transform=transform, download=True)

# Create data loaders
train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=TRAIN_BATCH_SIZE,
                          shuffle=True, num_workers=1)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=INFERENCE_BATCH_SIZE,
                         shuffle=False, num_workers=1, drop_last=True)

# Initialize the loss function
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified


In [61]:
def train(model, num_epochs, lr, writer, start=0, test_every=5):

    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
    
    print(f"Learning rate: {lr}")
    
    for epoch in range(num_epochs):
        running_loss = 0.0

        for images, labels in train_loader:
            optimizer.zero_grad()
            labels = labels.to(device)
            outputs = model(images.to(device))

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        
#         writer.add_scalar("Training loss", avg_loss, start+epoch+1)
        print(f'Epoch {start + epoch + 1}/{start + num_epochs}, Loss: {avg_loss:.3f}, LR: {scheduler.get_last_lr()[0]}')

        if (start+epoch+1) % test_every == 0:
            acc, inf_time = evaluate(model)
#             writer.add_scalar("Test acc", acc, start+epoch+1)

In [62]:
def evaluate(model, half_precision=False):
    
    # Evaluate the model on the test set
    model.eval()
    correct = 0
    total = 0
    
    inf_time = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            
            if half_precision:
                images = images.half()

            start = time.time()    
            outputs = model(images)
            inf_time += (time.time() - start)

            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = correct / total * 100
    inf_time = inf_time / (len(test_loader) * INFERENCE_BATCH_SIZE) * 1_000
    
    print(f'Test Accuracy: {accuracy:.2f}%, Inference time per sample: {inf_time:.3f} ms (inference batch size: {INFERENCE_BATCH_SIZE})')

    return accuracy, inf_time

In [63]:
device=torch.device("cuda")

In [64]:
model = resnet18(num_classes=10).to(device)
lr = 1e-3

train(model, num_epochs=10, lr=lr, writer=None)

Learning rate: 0.001
Epoch 1/10, Loss: 1.430, LR: 2.4471741852423235e-05
Epoch 2/10, Loss: 1.053, LR: 0.0009045084971874203
Epoch 3/10, Loss: 0.876, LR: 0.00020610737385379736
Epoch 4/10, Loss: 0.749, LR: 0.0006545084971873971
Epoch 5/10, Loss: 0.643, LR: 0.0005000000000002132
Test Accuracy: 71.73%, Inference time per sample: 0.058 ms (inference batch size: 64)
Epoch 6/10, Loss: 0.757, LR: 0.0003454915028125687
Epoch 7/10, Loss: 0.573, LR: 0.0007938926261462524
Epoch 8/10, Loss: 0.478, LR: 9.549150281258283e-05
Epoch 9/10, Loss: 0.405, LR: 0.0009755282581478662
Epoch 10/10, Loss: 0.348, LR: 0.0
Test Accuracy: 73.08%, Inference time per sample: 0.061 ms (inference batch size: 64)


## 2. Convert PyTorch model to ONNX

In [68]:
dummy_input = torch.rand(1, 3, 32, 32).to(device)

## NEVERMIND the following. The forward function implementation is not simple so
# instead of returning a ScriptModule, trace() is returning a TopLevelTracedModule 
# which indicates that the model was not entirely traceable.

model.eval()

# convert the model to a ScriptModule using a dummy input tensor
model_sm = torch.jit.trace(model, example_inputs=dummy_input)

In [69]:
print(type(model))
print(type(model_sm))

<class 'torchvision.models.resnet.ResNet'>
<class 'torch.jit._trace.TopLevelTracedModule'>


In [70]:
# convert the model to an ONNX model using torch.onnx.export

# make sure the model is in eval mode so that layers that 
# behave differently start doing so before tracing begins
model.eval()

torch.onnx.export(model,
                  args=dummy_input,                 # won't be used since we're passing a ScriptModule
                  f="fp32.onnx",
                  opset_version=17,
                  input_names = ['input'],
                  output_names = ['output'],
                  dynamic_axes={
                      'input' : {0 : 'batch_size'},    # flexibility for batch dim
                      'output' : {0 : 'batch_size'}
                      }
                 )

# Running this means we are done with PyTorch. The model is now in ONNX 
# so we'll read it in using Onnx and move forward from there.

## 3. Run the ONNX model using ONNX runtime

In [71]:
import onnx
import onnxruntime

In [72]:
## With the release of ONNX runtime, the following isn't 
# needed. We can directly pass the model file to the session 
# creator. 

# model_onnx = onnx.load("fp32.onnx")
# onnx.checker.check_model(model_onnx)

In [73]:
# create an onnx runtime session
options = onnxruntime.SessionOptions()
options.enable_profiling=True
session = onnxruntime.InferenceSession("fp32.onnx",   # alternatively, can pass model_onnx 
                                      providers=[
                                          'CUDAExecutionProvider',
                                          'CPUExecutionProvider',
#                                           'TensorrtExecutionProvider',
                                      ],
                                      sess_options=options)

In [74]:
# the exection providers (EPs) list is traversed in order
# if a given EP is available, it's used, otherwise we move 
# on to the next one. Since CUDA support was available while starting the
# InferenceSession, the following list had returned both the EPs.
# If there was no CUDA support, the list would have only CPUExecutionProvider
# despite passing other EPs in the InferenceSession initialization
print(session.get_providers())

['CUDAExecutionProvider', 'CPUExecutionProvider']


In [75]:
def to_numpy(x):
    return x.detach().cpu().numpy() if x.requires_grad else x.cpu().numpy()

In [76]:
# create input dictionary with {name: tensor} key,value pairs. 
x = torch.randn(8, 3, 32, 32).to(device)
inputs = {session.get_inputs()[0].name: to_numpy(x)}

In [77]:
%%time
# The `output_names` argument is used for output selection.
# Passing None means no selection, return everything
output = session.run(None, inputs)
session.end_profiling()

CPU times: user 65.4 ms, sys: 84.1 ms, total: 150 ms
Wall time: 263 ms


'onnxruntime_profile__2024-03-01_03-59-18.json'

In [78]:
print((output[0]).shape)

(8, 10)


In [79]:
print(onnxruntime.get_device())
onnxruntime.get_available_providers()

GPU


['TensorrtExecutionProvider',
 'CUDAExecutionProvider',
 'AzureExecutionProvider',
 'CPUExecutionProvider']

## 4. Quantize the ONNX model (CPU)

As per `onnxruntime` documentation [here](https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html), by quantization they mean 8-bit linear quantization of an ONNX model.  
There are two things that need to be done before running `quantize_static()`:
1. First, we need to "preprocess" the model using the preprocess script provided by ONNX. 
2. We need to write a CalibrationDataReader that would provide data for the quantization calibration process.

### 4.1 Preprocess the model

This runs shape inference for the model and performs optimizations in the computation graph like fusing operations for efficiency (e.g., conv+bn), reducing redundancy, etc.

In [80]:
!python -m onnxruntime.quantization.preprocess --input fp32.onnx --output fp32_preproc.onnx

### 4.2 Quantize the model (including calibration)

Define a `CalibrationDataReader` object to provide the calibration data needed for static quatization.

In [81]:
num_calib_batches = 64
CALIB_BATCH_SIZE = 8

# which samples to use from the training dataset for calibration
indices = list(np.random.choice(len(test_dataset), size=num_calib_batches*CALIB_BATCH_SIZE, replace=False))

calib_loader = DataLoader(dataset=train_dataset, 
                          batch_size=CALIB_BATCH_SIZE,
                          sampler=SubsetRandomSampler(indices), 
                          num_workers=1)

In [83]:
import onnxruntime.quantization as oqt

# create the CalibrationDataReader for your dataset
class ResNetCalibLoader(oqt.CalibrationDataReader):
    def __init__(self, dataloader, model_path):
        self.dataloader = dataloader
        self.enum_data = None
        
        # create an inference session to find the input name
        session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        self.input_name = session.get_inputs()[0].name
        
    def get_next(self):
        if self.enum_data is None:
            self.enum_data = iter(
                [{self.input_name: im.cpu().numpy()} for im, _ in self.dataloader])
        
        # None is the value to return when the iterator is empty
        return next(self.enum_data, None)
        
    def rewind(self,):
        self.enum_data = None

In [84]:
calib_reader = ResNetCalibLoader(calib_loader, "fp32_preproc.onnx")
oqt.quantize_static(
        model_input="fp32_preproc.onnx", 
        model_output="int8.onnx", 
        calibration_data_reader=calib_reader,
        quant_format=oqt.QuantFormat.QDQ,
        weight_type=oqt.QuantType.QInt8,
        per_channel=False
        )