In [None]:
# !pip install torch torchvision tqdm matplotlib onnx onnxscript
# sudo docker run -it --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v$(pwd):/run/host -p 8888:8888 nvcr.io/nvidia/pytorch:24.12-py3


import random
import os

import torch
torch.set_float32_matmul_precision('high')
import torch.utils.data as tud
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as tvt
import torchvision.transforms.v2 as tv2
import torchvision.transforms.functional as tvf
import torchvision.datasets as tds
import torchvision.utils as tu
import torchvision

from tqdm import tqdm
import matplotlib.pyplot as plt

# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
# To force CPU execution, overwrite with: device = 'cpu'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Use half of the logical cores for DataLoader workers. os.cpu_count() may
# return None when the count cannot be determined, so fall back to 1 core
# (which yields 0 workers, i.e. loading in the main process).
cpu_num = (os.cpu_count() or 1) // 2

In [None]:
dataset_root = './datasets'

# Side length (pixels) of the square crops fed to the network.
DIM=160

# Training pipeline: random DIM x DIM crop (with 4 px padding), colour jitter and
# random flips, then conversion to a float32 tensor scaled to [0, 1].
# All transforms come from the v2 API so the pipeline is internally consistent.
train_tfs = tv2.Compose([
    tv2.RandomCrop(DIM, padding=4),
    tv2.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    tv2.RandomHorizontalFlip(0.5),
    tv2.RandomVerticalFlip(0.25),
    tv2.ToImage(),
    tv2.ToDtype(torch.float32, scale=True)
])

# Validation pipeline: deterministic centre crop, same tensor conversion.
val_tfs = tv2.Compose([
    tv2.CenterCrop(DIM),
    tv2.ToImage(),
    tv2.ToDtype(torch.float32, scale=True),
])

# Some torchvision releases raise a RuntimeError when download=True and the
# extracted directory already exists, which breaks "Restart & Run All" on the
# second run. Only request a download when the data is not on disk yet.
# NOTE(review): "imagenette2-160" is the directory name the 160px archive is
# expected to extract to -- confirm against the installed torchvision.
imagenette_dir = os.path.join(dataset_root, "imagenette2-160")

train = tds.imagenette.Imagenette(
    dataset_root,
    "train",
    "160px",
    download=not os.path.isdir(imagenette_dir),
    transform=train_tfs
)

val = tds.imagenette.Imagenette(
    dataset_root,
    "val",
    "160px",
    download=False,
    transform=val_tfs
)

batchsize = 64

train_loader = tud.DataLoader(train, batch_size=batchsize, num_workers=cpu_num, shuffle=True)
# Validation metrics are aggregated over the whole split, so shuffling adds
# nothing; reuse the worker pool size so evaluation is not bottlenecked on loading.
val_loader = tud.DataLoader(val, batch_size=batchsize, num_workers=cpu_num, shuffle=False)

In [None]:
# Randomly initialised ResNet-18 (no pretrained weights).
model = torchvision.models.resnet18(weights=None)

# Imagenette has 10 classes, so swap the classification head for a
# 512 -> 10 linear layer.
model.fc = nn.Linear(512, 10, bias=True)

# Move to the selected device, then switch to training mode.
model = model.to(device)
model = model.train()

In [None]:
from apex.contrib.sparsity import ASP

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
epochs = 30
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

# NOTE(review): presumably this computes the structured-sparsity masks from the
# model's *current* weights and hooks the optimizer so they stay applied --
# confirm against the apex ASP documentation. The model is created with
# weights=None in this notebook, so the masks come from untrained weights;
# verify that this is intended.
ASP.prune_trained_model(model, optimizer)

lossfn = nn.CrossEntropyLoss()
loss_plot = []  # per-epoch validation loss
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    for images, target in tqdm(train_loader):
        optimizer.zero_grad()
        images = images.to(device)
        targets = target.to(device)

        outs = model(images)
        loss = lossfn(outs, targets)
        loss.backward()
        optimizer.step()

    # ---- validation pass ----
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = len(val)
    with torch.no_grad():
        for images, target in val_loader:
            images = images.to(device)
            targets = target.to(device)
            outs = model(images)

            # Weight each batch's mean loss by its size so a shorter final
            # batch does not skew the epoch average; .item() keeps the running
            # total as a plain Python float instead of a list of device tensors.
            loss_sum += lossfn(outs, targets).item() * targets.shape[0]
            # argmax over logits equals argmax over softmax, so the softmax is
            # unnecessary; compare the whole batch at once instead of looping
            # over samples.
            correct += (outs.argmax(dim=1) == targets).sum().item()

    epoch_loss = loss_sum / total
    print("Epoch {}: {} ".format(epoch, epoch_loss))
    print("Current LR is {}".format(scheduler.get_last_lr()))
    print("{}/{} correct, {:.2f}%".format(correct, total, 100*correct/total))
    loss_plot.append(epoch_loss)
    scheduler.step()

In [None]:
# Serialise the whole module object (not just its state_dict) to disk.
torch.save(obj=model, f="sparse_resnet.pth")

In [None]:
# Export in inference mode so layers such as BatchNorm use their eval behaviour.
model.eval()
torch_input = torch.randn(1, 3, DIM, DIM, device=device)
# The TensorRT cells in this notebook address the graph tensors as 'input' and
# 'output', so assign those names explicitly at export time instead of relying
# on exporter-generated names.
torch.onnx.export(
    model,
    (torch_input,),
    "sparse_resnet.onnx",
    input_names=["input"],
    output_names=["output"],
)

In [None]:
from ipywidgets import interact

@interact(index=(0, len(val) - 1, 1))
def draw_preds(index=0):
    """Show validation sample `index` and print the model's predicted class id.

    Args:
        index: Position of the sample in `val`; driven by the slider that
            `interact` builds from the (min, max, step) tuple above.
    """
    model.eval()
    with torch.no_grad():
        sample = val[index][0]
        batch = sample.float().unsqueeze(0).to(device)
        probs = F.softmax(model(batch), dim=1)
        clsid = probs.argmax()
        # Channels-last layout for matplotlib.
        hwc = sample.float().cpu().squeeze().permute(1, 2, 0)
        plt.imshow(hwc, cmap='gray')
        print(int(clsid))

In [None]:
import tensorrt as trt

# Batch size the TensorRT engine is built for.
trt_batch_size = 1
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open("sparse_resnet.onnx", "rb") as f:
    print("Load ONNX.")
    # parse() reports failure through its return value; without this check a
    # broken model would silently leave `network` empty and only fail later
    # during the engine build.
    if not parser.parse(f.read()):
        parse_errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
        raise RuntimeError(
            "Failed to parse sparse_resnet.onnx:\n" + "\n".join(parse_errors)
        )


In [None]:
# The engine is built for one fixed input shape, so min / opt / max of the
# optimization profile are all the same.
static_shape = (trt_batch_size, 3, DIM, DIM)
profile = builder.create_optimization_profile()
profile.set_shape('input', static_shape, static_shape, static_shape)

# Builder configuration: 1 GiB workspace, the profile above for both building
# and calibration, and sparse-weight kernels enabled.
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 2 ** 30)
config.add_optimization_profile(profile)
config.set_calibration_profile(profile)
config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)


In [None]:
class DatasetCalibrator(trt.IInt8Calibrator):
    """INT8 calibrator that feeds TensorRT batches drawn from an indexable dataset.

    Samples are copied one by one into a preallocated GPU buffer whose shape
    (and therefore batch size) is taken from the `input` template tensor.
    """

    def __init__(self,
            input, dataset,
            algorithm=trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2):
        """
        Args:
            input: Template tensor; only its shape/dtype are used to allocate
                the staging buffer.
            dataset: Indexable dataset whose items unpack as (image, label).
            algorithm: Calibration algorithm reported to TensorRT.
        """
        super().__init__()
        self.dataset = dataset
        self.algorithm = algorithm
        # The staging buffer has to live on the GPU: TensorRT is handed its
        # raw device pointer in get_batch().
        self.buffer = torch.zeros_like(input).contiguous().cuda()
        # Number of dataset samples consumed so far.
        self.count = 0

    def get_batch(self, *args, **kwargs):
        """Fill the buffer with the next batch and return its device pointer.

        Returns:
            A one-element list holding the buffer address, or an empty list
            once every sample of the dataset has been consumed.
        """
        if self.count >= len(self.dataset):
            return []

        for slot in range(self.get_batch_size()):
            # Wrap around when the dataset size is not a multiple of the batch size.
            sample_idx = self.count % len(self.dataset)
            sample, _ = self.dataset[sample_idx]
            self.buffer[slot].copy_(sample.to(self.buffer.device))
            self.count += 1

        return [int(self.buffer.data_ptr())]

    def get_algorithm(self):
        """Return the calibration algorithm chosen at construction time."""
        return self.algorithm

    def get_batch_size(self):
        """Return the batch size, i.e. the leading dimension of the buffer."""
        return int(self.buffer.shape[0])

    def read_calibration_cache(self, *args, **kwargs):
        """No cache is kept; always returns None."""
        return None

    def write_calibration_cache(self, cache, *args, **kwargs):
        """Discard the calibration cache (nothing is persisted)."""
        pass


In [None]:
# Select the reduced precision for the engine: INT8 (needs calibration data)
# or FP16.
use_int8=True
if use_int8:
    # Template tensor for the calibrator's staging buffer. The original code
    # referenced an undefined name `batch_size` here (NameError); the engine is
    # built for `trt_batch_size`, so the calibration batches must match it.
    data = torch.zeros(trt_batch_size, 3, DIM, DIM)
    config.set_flag(trt.BuilderFlag.INT8)

    # NOTE(review): this rebinds the `val_loader` used by the training cell;
    # only its `.dataset` is needed here. Kept because later cells may rely on
    # the batch-size-1 loader -- confirm and rename if not.
    val_loader = tud.DataLoader(val, batch_size=trt_batch_size, shuffle=True)
    config.int8_calibrator = DatasetCalibrator(data, val_loader.dataset)
else:
    config.set_flag(trt.BuilderFlag.FP16)

In [None]:
# Build the serialized engine. The builder signals failure by returning None
# rather than raising, so fail loudly here instead of in the cell that writes
# the engine to disk.
engine = builder.build_serialized_network(network, config)
if engine is None:
    raise RuntimeError("TensorRT engine build failed; check the logger output above.")

In [None]:
# Persist the serialized engine so it can be reloaded (or profiled) later.
with open("sparse_resnet.engine", mode="wb") as engine_file:
    engine_file.write(engine)

In [None]:
# https://github.com/NVIDIA/TensorRT/tree/main/samples/python/yolov3_onnx
import tensorrt as trt

# Reload the engine from disk and create an execution context for inference.
# Note: `engine` is rebound here from the serialized blob to the deserialized
# engine object.
logger = trt.Logger()
runtime = trt.Runtime(logger)
with open("sparse_resnet.engine", mode="rb") as engine_file:
    engine = runtime.deserialize_cuda_engine(engine_file.read())
context = engine.create_execution_context()

Profile the engine:
`trtexec --loadEngine=sparse_resnet.engine`

In [None]:
# Locate the engine's I/O tensors through the name-based API. The tensor names
# are looked up by position and classified by their I/O mode instead of being
# hard-coded, so this works regardless of how the exporter named them.
# Assumes the engine has exactly one input and one output tensor.
io_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
input_binding_idx = next(
    idx for idx, name in enumerate(io_names)
    if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
)
output_binding_idx = next(
    idx for idx, name in enumerate(io_names)
    if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT
)
input_name = io_names[input_binding_idx]
output_name = io_names[output_binding_idx]

input_shape = (1, 3, DIM, DIM)
if not context.set_input_shape(input_name, input_shape):
    raise RuntimeError(
        "Engine rejected input shape {} for tensor '{}'".format(input_shape, input_name)
    )
# Ask the context for the output shape instead of hard-coding it: the
# classifier output does not have the same shape as the input image.
output_shape = tuple(context.get_tensor_shape(output_name))

# NOTE(review): float32 I/O buffers assume the engine keeps float32 inputs and
# outputs -- confirm with engine.get_tensor_dtype if the precision setup changes.
input_buffer = torch.zeros(input_shape, dtype=torch.float32, device=torch.device('cuda'))
output_buffer = torch.zeros(output_shape, dtype=torch.float32, device=torch.device('cuda'))

# Device pointers ordered by I/O tensor index.
bindings = [None] * len(io_names)
bindings[input_binding_idx] = input_buffer.data_ptr()
bindings[output_binding_idx] = output_buffer.data_ptr()