In [1]:
import tensorrt as trt
import numpy as np

import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torchvision.datasets as tds
import torchvision.transforms.v2 as tv2

import matplotlib.pyplot as plt

import loaders

In [2]:
trt_batch_size = 1
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open("sparse_resnet.onnx", "rb") as f:
    print("Load ONNX.")
    parser.parse(f.read())


[01/11/2025-20:54:01] [TRT] [I] [MemUsageChange] Init CUDA: CPU +18, GPU +0, now: CPU 158, GPU 2376 (MiB)
Load ONNX.
[01/11/2025-20:54:02] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +2283, GPU +440, now: CPU 2597, GPU 2816 (MiB)


In [3]:
DIM=160
profile = builder.create_optimization_profile()
profile.set_shape(
    'input',
    (trt_batch_size, 3, DIM, DIM),
    (trt_batch_size, 3, DIM, DIM),
    (trt_batch_size, 3, DIM, DIM)
)

config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)
config.add_optimization_profile(profile)
config.set_calibration_profile(profile)
config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)


  config.set_calibration_profile(profile)


In [4]:
class DatasetCalibrator(trt.IInt8Calibrator):
    def __init__(self,
            input, dataset,
            algorithm=trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2):
        super(DatasetCalibrator, self).__init__()
        self.dataset = dataset
        self.algorithm = algorithm
        self.buffer = torch.zeros_like(input).contiguous().cuda() # Must move tensor to GPU here!
        self.count = 0

    def get_batch(self, *args, **kwargs):
        if self.count < len(self.dataset):
            for buffer_idx in range(self.get_batch_size()):
                dataset_idx = self.count % len(self.dataset) # roll around if not multiple of dataset
                image, _ = self.dataset[dataset_idx]
                image = image.to(self.buffer.device)
                self.buffer[buffer_idx].copy_(image)

                self.count += 1
            return [int(self.buffer.data_ptr())]
        else:
            return []

    def get_algorithm(self):
        return self.algorithm

    def get_batch_size(self):
        return int(self.buffer.shape[0])

    def read_calibration_cache(self, *args, **kwargs):
        return None

    def write_calibration_cache(self, cache, *args, **kwargs):
        pass


In [5]:
use_int8=True
batch_size=1

DIM, train_loader, val_loader = loaders.get_loaders(batch_size)

if use_int8:
    data = torch.zeros(batch_size, 3, DIM, DIM)
    config.set_flag(trt.BuilderFlag.INT8)
    
    val_loader = tud.DataLoader(val_loader.dataset, batch_size=trt_batch_size, shuffle=True)
    config.int8_calibrator = DatasetCalibrator(data, val_loader.dataset)
else:
    config.set_flag(trt.BuilderFlag.FP16)

  config.int8_calibrator = DatasetCalibrator(data, val_loader.dataset)


In [6]:
engine = builder.build_serialized_network(network, config)

[01/11/2025-20:54:34] [TRT] [I] Perform graph optimization on calibration graph.
[01/11/2025-20:54:34] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[01/11/2025-20:54:34] [TRT] [I] Compiler backend is used during engine build.
[01/11/2025-20:54:34] [TRT] [I] Detected 1 inputs and 1 output network tensors.
[01/11/2025-20:54:35] [TRT] [I] Total Host Persistent Memory: 83808 bytes
[01/11/2025-20:54:35] [TRT] [I] Total Device Persistent Memory: 38912 bytes
[01/11/2025-20:54:35] [TRT] [I] Max Scratch Memory: 512 bytes
[01/11/2025-20:54:35] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 51 steps to complete.
[01/11/2025-20:54:35] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.199519ms to assign 3 blocks to 51 nodes requiring 3686400 bytes.
[01/11/2025-20:54:35] [TRT] [I] Total Activation Memory: 3686400 bytes
[01/11/2025-20:54:35] [TRT] [I] Total Weights Memory: 74073128 bytes
[01/11/2025-20:54:35] [TRT

In [7]:
with open("sparse_resnet.engine", 'wb') as f:
    f.write(engine)

In [8]:
# https://github.com/NVIDIA/TensorRT/tree/main/samples/python/yolov3_onnx
import tensorrt as trt

logger = trt.Logger()
runtime = trt.Runtime(logger)
with open("sparse_resnet.engine", 'rb') as f:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

In [9]:
# Use "get_tensor_name" instead
inputs = [] 
outputs = []
allocs = []
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    mode = engine.get_tensor_mode(name)

    dtype = engine.get_tensor_dtype(name)
    shape = engine.get_tensor_shape(name)
    print(name, mode, dtype, shape)

    allocation = torch.zeros(list(shape), dtype=torch.float32, device=torch.device('cuda'))
    allocs.append(allocation)
    binding = {
        "index": i,
        "name": name,
        "dtype": np.dtype(trt.nptype(dtype)),
        "shape": [shape],
        "allocation": allocation.data_ptr(),
    }
    if mode == trt.TensorIOMode.INPUT:
        inputs.append(binding)
    elif mode == trt.TensorIOMode.OUTPUT:
        outputs.append(binding)
    else:
        pass

input TensorIOMode.INPUT DataType.FLOAT (1, 3, 160, 160)
output TensorIOMode.OUTPUT DataType.FLOAT (1, 10)


In [10]:
torch.randn(1, 3, 160, 160)

tensor([[[[ 0.0608, -0.6201,  0.3521,  ...,  0.3822, -0.6803,  1.3382],
          [-0.1981,  0.9123, -1.4168,  ..., -0.2005,  0.3921, -1.3323],
          [ 1.3873,  0.1214,  0.5647,  ...,  1.0953,  0.0478,  1.2977],
          ...,
          [ 0.3277, -0.0807, -0.3259,  ...,  0.5205, -0.2888,  0.6869],
          [-0.8896,  1.1568,  0.7104,  ..., -0.3125, -1.1489,  1.8882],
          [ 0.7842, -0.5409,  0.4277,  ...,  0.0184,  0.3026, -0.7409]],

         [[ 1.1056,  1.1562, -0.6393,  ..., -1.0844,  0.3970, -1.5665],
          [-0.6366,  0.3048, -2.1940,  ..., -0.7860,  0.3569,  2.2620],
          [ 0.4650,  1.0774,  1.3685,  ..., -1.0992,  0.3529,  0.2582],
          ...,
          [-0.2267, -1.0669,  0.2567,  ..., -1.0006, -0.9153,  0.2930],
          [-1.5658,  1.2536,  0.0382,  ...,  2.0141, -0.7068, -0.5478],
          [ 0.9830,  0.3725,  1.2499,  ...,  1.2461,  0.9127,  1.9115]],

         [[-1.6640,  1.2523, -2.7089,  ..., -0.0626, -0.1099,  0.6272],
          [ 1.0141,  0.5485,  

In [11]:
allocs[0].copy_(torch.randn(1, 3, 160, 160))
tensors = [x.data_ptr() for x in allocs]
context.execute_v2(tensors)
print(allocs[1])

tensor([[-1.3157, -0.3486, -0.7165,  1.4545,  5.2947, -2.4003, -6.5083, -2.4640,
          5.8920, -0.2899]], device='cuda:0')


In [12]:
losses = []
correct = 0
total = len(val_loader.dataset)

for i, (images, target) in enumerate(val_loader.dataset):
    # print(images, target)
    with torch.no_grad():
        allocs[0].copy_(images)
        tensors = [x.data_ptr() for x in allocs]
        context.execute_v2(tensors)
        outs = allocs[1]

        preds = F.softmax(outs, dim=1)
        cls = outs.argmax()
        lbl = target
        if cls == lbl:
            correct += 1

print("{}/{} correct, {:.2f}%".format(correct, total, 100*correct/total))

3314/3925 correct, 84.43%


In [13]:
from ipywidgets import interact

@interact(index=(0, len(val_loader.dataset) - 1, 1))
def draw_preds(index=0):
    with torch.no_grad():
        image = val_loader.dataset[index][0]
        
        # pred = model(image.float().unsqueeze(0).to(device))

        allocs[0].copy_(image)
        tensors = [x.data_ptr() for x in allocs]
        context.execute_v2(tensors)
        pred = allocs[1]
        
        pred = F.softmax(pred, dim=1)
        clsid = pred.argmax()
        plt.imshow(image.float().cpu().squeeze().permute(1, 2, 0), cmap='gray')
        print(int(clsid))

interactive(children=(IntSlider(value=0, description='index', max=3924), Output()), _dom_classes=('widget-inte…