In [1]:
import os, random
import numpy as np
import torch

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
assert device == "cuda", "Need CUDA for TensorRT."

Device: cuda


In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)

test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

test_loader_b1   = DataLoader(test_dataset, batch_size=1,   shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
test_loader_b64  = DataLoader(test_dataset, batch_size=64,  shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
test_loader_b128 = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

57.7%IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100.0%


In [3]:
import sys
from pathlib import Path

sys.path.insert(0, str(Path.cwd().parents[1]))

from models.alexnet_model import AlexNetCIFAR10, AlexNetCIFAR10_QAT  # <-- IMPORTANT: FP32 model

In [6]:
import torch

CKPT_PATH = "../../pth/alexnet_qat_preconvert.pth"

def strip_preconvert_state_generic(ckpt: dict) -> dict:
    out = {}
    for k, v in ckpt.items():
        if k.startswith("module."):
            k = k[len("module."):]
        if ("activation_post_process" in k) or ("fake_quant" in k) or ("weight_fake_quant" in k):
            continue
        if k.startswith("activation_post_process_"):
            continue
        out[k] = v
    return out

def remap_by_shape_and_order(ckpt_state: dict, model_state: dict):
    """
    Greedy match ckpt tensors to model tensors by (suffix, shape) in model key order.
    Works when naming differs but layer order/structure is the same.
    """
    def suffix(k):
        # keep BN stats vs weights separate
        if k.endswith("running_mean"): return "running_mean"
        if k.endswith("running_var"):  return "running_var"
        if k.endswith("num_batches_tracked"): return "nbt"
        if k.endswith(".weight"): return "weight"
        if k.endswith(".bias"):   return "bias"
        return "other"

    # group ckpt keys by (suffix, shape)
    buckets = {}
    for k, v in ckpt_state.items():
        if not torch.is_tensor(v):
            continue
        key = (suffix(k), tuple(v.shape), str(v.dtype))
        buckets.setdefault(key, []).append(k)

    used = set()
    new_state = {}

    for mk, mv in model_state.items():
        if not torch.is_tensor(mv):
            continue
        key = (suffix(mk), tuple(mv.shape), str(mv.dtype))
        cands = buckets.get(key, [])
        # pick first unused candidate
        pick = None
        for ck in cands:
            if ck not in used:
                pick = ck
                break
        if pick is not None:
            new_state[mk] = ckpt_state[pick]
            used.add(pick)

    return new_state

# ---- load ckpt ----
raw = torch.load(CKPT_PATH, map_location="cpu")
ckpt_state = strip_preconvert_state_generic(raw)

# IMPORTANT: use your FP32 AlexNet class here
model = AlexNetCIFAR10(num_classes=10)

model_state = model.state_dict()
mapped = remap_by_shape_and_order(ckpt_state, model_state)

missing, unexpected = model.load_state_dict(mapped, strict=False)

bad_missing = [k for k in missing if k.endswith(".weight") or k.endswith(".bias") or k.endswith("running_mean") or k.endswith("running_var")]
print("missing:", len(missing), "unexpected:", len(unexpected))
print("bad_missing (should be empty):", bad_missing[:50])

assert len(bad_missing) == 0, "Still missing real params => model definition doesn't match checkpoint structure."

model.eval().to(device)
print("Loaded FP32 AlexNet from preconvert checkpoint ✅")

missing: 0 unexpected: 0
bad_missing (should be empty): []
Loaded FP32 AlexNet from preconvert checkpoint ✅


In [7]:
import numpy as np
import torch

@torch.no_grad()
def torch_acc(model, loader, device="cuda", max_batches=None):
    model.eval()
    correct = total = 0
    for bi, (x, y) in enumerate(loader):
        if max_batches is not None and bi >= max_batches:
            break
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    return 100.0 * correct / total

print("Torch acc (full):", torch_acc(model, test_loader, device=device))
print("Torch acc (first 50):", torch_acc(model, test_loader_b128, device=device, max_batches=50))

Torch acc (full): 89.06
Torch acc (first 50): 88.953125


In [8]:
import torch

model.eval().to(device)

dummy_map = {
    1:   torch.randn(1,   3, 32, 32, device=device),
    64:  torch.randn(64,  3, 32, 32, device=device),
    128: torch.randn(128, 3, 32, 32, device=device),
}

onnx_map = {}
with torch.no_grad():  # not required, but safe
    for bs, dummy in dummy_map.items():
        out_path = f"alexnet_fp32_b{bs}_op13.onnx"
        torch.onnx.export(
            model, dummy, out_path,
            opset_version=13,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["logits"],
            dynamic_axes=None,
            dynamo=False
        )
        onnx_map[bs] = out_path
        print("Exported:", out_path)

  torch.onnx.export(


Exported: alexnet_fp32_b1_op13.onnx
Exported: alexnet_fp32_b64_op13.onnx
Exported: alexnet_fp32_b128_op13.onnx


In [9]:
import os
import torch
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

# ----------------------------
# INT8 Entropy Calibrator
# ----------------------------
class EntropyCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, calib_loader, max_batches=200, cache_file="calib.cache"):
        super().__init__()
        self.data_iter = iter(calib_loader)
        self.max_batches = max_batches
        self.batch_count = 0
        self.cache_file = cache_file

        # initial shape (just to allocate something)
        x0, _ = next(iter(calib_loader))
        self.batch_size = x0.shape[0]
        self.device_input = torch.empty_like(x0, device="cuda")

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        if self.batch_count >= self.max_batches:
            return None

        try:
            x, _ = next(self.data_iter)
        except StopIteration:
            return None

        self.batch_count += 1
        x = x.to("cuda", non_blocking=True)

        # ✅ IMPORTANT: ensure buffer matches incoming shape
        if self.device_input.numel() != x.numel():
            self.device_input = torch.empty_like(x, device="cuda")
        else:
            self.device_input = self.device_input.view_as(x)

        self.device_input.copy_(x)
        return [int(self.device_input.data_ptr())]

    def read_calibration_cache(self):
        return None  # ✅ force fresh calibration

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

# ----------------------------
# Build static INT8 engine
# ----------------------------
def build_int8_engine_static(onnx_path, engine_path, calib_loader, max_calib_batches=200):
    # ✅ IMPORTANT: delete old cache so TRT can't reuse stale scales
    cache_path = engine_path.replace(".engine", ".cache")
    if os.path.exists(cache_path):
        os.remove(cache_path)

    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:

        with open(onnx_path, "rb") as f:
            if not parser.parse(f.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
                raise RuntimeError(f"ONNX parse failed: {onnx_path}")

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

        config.set_flag(trt.BuilderFlag.INT8)
        config.int8_calibrator = EntropyCalibrator(
            calib_loader,
            max_batches=max_calib_batches,
            cache_file=cache_path
        )

        serialized = builder.build_serialized_network(network, config)
        if serialized is None:
            raise RuntimeError(f"INT8 engine build failed: {onnx_path}")

        with open(engine_path, "wb") as f:
            f.write(serialized)

    print("Saved:", engine_path)

# ----------------------------
# Build engines for each batch size
# ----------------------------
calib_loader_map = {
    1:   test_loader_b1,
    64:  test_loader_b64,
    128: test_loader_b128,
}

engine_map = {}
for bs, onnx_path in onnx_map.items():
    engine_path = f"alexnet_int8_b{bs}.engine"
    build_int8_engine_static(
        onnx_path=onnx_path,
        engine_path=engine_path,
        calib_loader=calib_loader_map[bs],
        max_calib_batches=200
    )
    engine_map[bs] = engine_path

print("INT8 engines built:", engine_map)

[12/14/2025-13:48:31] [TRT] [I] [MemUsageChange] Init CUDA: CPU -2, GPU +0, now: CPU 701, GPU 3714 (MiB)
[12/14/2025-13:48:31] [TRT] [I] ----------------------------------------------------------------
[12/14/2025-13:48:31] [TRT] [I] ONNX IR version:  0.0.7
[12/14/2025-13:48:31] [TRT] [I] Opset version:    13
[12/14/2025-13:48:31] [TRT] [I] Producer name:    pytorch
[12/14/2025-13:48:31] [TRT] [I] Producer version: 2.9.1
[12/14/2025-13:48:31] [TRT] [I] Domain:           
[12/14/2025-13:48:31] [TRT] [I] Model version:    0
[12/14/2025-13:48:31] [TRT] [I] Doc string:       
[12/14/2025-13:48:31] [TRT] [I] ----------------------------------------------------------------


  config.int8_calibrator = EntropyCalibrator(


[12/14/2025-13:48:32] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +279, GPU +6, now: CPU 1178, GPU 3720 (MiB)
[12/14/2025-13:48:32] [TRT] [I] Perform graph optimization on calibration graph.
[12/14/2025-13:48:32] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[12/14/2025-13:48:32] [TRT] [I] Compiler backend is used during engine build.
[12/14/2025-13:48:32] [TRT] [I] Detected 1 inputs and 1 output network tensors.
[12/14/2025-13:48:32] [TRT] [I] Total Host Persistent Memory: 38336 bytes
[12/14/2025-13:48:32] [TRT] [I] Total Device Persistent Memory: 0 bytes
[12/14/2025-13:48:32] [TRT] [I] Max Scratch Memory: 4608 bytes
[12/14/2025-13:48:32] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 24 steps to complete.
[12/14/2025-13:48:32] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.202061ms to assign 3 blocks to 24 nodes requiring 266752 bytes.
[12/14/2025-13:48:32] [TRT] [I] Total Activa

In [10]:
import tensorrt as trt

def inspect_engine(engine_path):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as rt:
        eng = rt.deserialize_cuda_engine(f.read())
    names = [eng.get_tensor_name(i) for i in range(eng.num_io_tensors)]
    inp = [n for n in names if eng.get_tensor_mode(n) == trt.TensorIOMode.INPUT][0]
    out = [n for n in names if eng.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT][0]
    print(engine_path)
    print("  IN :", inp, eng.get_tensor_shape(inp), eng.get_tensor_dtype(inp))
    print("  OUT:", out, eng.get_tensor_shape(out), eng.get_tensor_dtype(out))

inspect_engine("alexnet_int8_b128.engine")

alexnet_int8_b128.engine
  IN : input (128, 3, 32, 32) DataType.FLOAT
  OUT: logits (128, 10) DataType.FLOAT


In [11]:
import time
import numpy as np
import torch
import tensorrt as trt

TRT_LOGGER_EVAL = trt.Logger(trt.Logger.WARNING)

def load_engine(engine_path):
    with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER_EVAL) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    if engine is None:
        raise RuntimeError(f"Failed to load engine: {engine_path}")
    return engine, engine.create_execution_context()

def get_io_names(engine):
    names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    inp = [n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT][0]
    out = [n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT][0]
    return inp, out

def trt_dtype_to_torch(dt):
    return {
        trt.DataType.FLOAT: torch.float32,
        trt.DataType.HALF:  torch.float16,
        trt.DataType.INT8:  torch.int8,
        trt.DataType.INT32: torch.int32,
        trt.DataType.BOOL:  torch.bool,
    }[dt]

@torch.no_grad()
def trt_eval(engine_path, loader, warmup=50, iters=200, acc_batches=50):
    engine, context = load_engine(engine_path)
    inp_name, out_name = get_io_names(engine)

    in_shape  = tuple(engine.get_tensor_shape(inp_name))
    out_shape = tuple(engine.get_tensor_shape(out_name))
    bsz = in_shape[0]

    out_torch_dtype = trt_dtype_to_torch(engine.get_tensor_dtype(out_name))

    # non-default stream (avoids TRT warning + better perf)
    stream = torch.cuda.Stream()

    # --------------- Accuracy (first acc_batches) ---------------
    correct = 0
    total = 0
    for bi, (x_cpu, y_cpu) in enumerate(loader):
        if bi >= acc_batches:
            break
        if x_cpu.shape[0] != bsz:
            raise RuntimeError(f"Batch mismatch: loader={x_cpu.shape[0]} vs engine={bsz}")

        x = x_cpu.to("cuda", non_blocking=True)
        y = y_cpu.to("cuda", non_blocking=True)
        yhat = torch.empty(out_shape, device="cuda", dtype=out_torch_dtype)

        with torch.cuda.stream(stream):
            context.set_tensor_address(inp_name, int(x.data_ptr()))
            context.set_tensor_address(out_name, int(yhat.data_ptr()))
            ok = context.execute_async_v3(stream_handle=stream.cuda_stream)
            if not ok:
                raise RuntimeError("TRT execute failed")
        stream.synchronize()

        pred = yhat.float().argmax(dim=1)
        correct += (pred == y).sum().item()
        total += bsz

    acc = 100.0 * correct / max(1, total)

    # --------------- Latency / Throughput microbench ---------------
    # Use one batch from loader
    x_cpu, _ = next(iter(loader))
    if x_cpu.shape[0] != bsz:
        raise RuntimeError(f"Batch mismatch: loader={x_cpu.shape[0]} vs engine={bsz}")
    x = x_cpu.to("cuda", non_blocking=True)
    yhat = torch.empty(out_shape, device="cuda", dtype=out_torch_dtype)

    # warmup
    for _ in range(warmup):
        with torch.cuda.stream(stream):
            context.set_tensor_address(inp_name, int(x.data_ptr()))
            context.set_tensor_address(out_name, int(yhat.data_ptr()))
            ok = context.execute_async_v3(stream_handle=stream.cuda_stream)
            if not ok:
                raise RuntimeError("TRT execute failed")
    stream.synchronize()

    # timed runs (use CUDA events for accurate GPU timing)
    starter = torch.cuda.Event(enable_timing=True)
    ender   = torch.cuda.Event(enable_timing=True)

    times_ms = []
    for _ in range(iters):
        starter.record(stream)
        with torch.cuda.stream(stream):
            context.set_tensor_address(inp_name, int(x.data_ptr()))
            context.set_tensor_address(out_name, int(yhat.data_ptr()))
            ok = context.execute_async_v3(stream_handle=stream.cuda_stream)
            if not ok:
                raise RuntimeError("TRT execute failed")
        ender.record(stream)
        stream.synchronize()
        times_ms.append(starter.elapsed_time(ender))

    times_ms = np.array(times_ms, dtype=np.float64)
    p50 = float(np.percentile(times_ms, 50))
    p90 = float(np.percentile(times_ms, 90))
    p99 = float(np.percentile(times_ms, 99))
    mean = float(times_ms.mean())

    # throughput (images/sec)
    ips = (1000.0 / mean) * bsz

    return {
        "engine": engine_path,
        "batch": bsz,
        "acc_%": acc,
        "lat_mean_ms": mean,
        "lat_p50_ms": p50,
        "lat_p90_ms": p90,
        "lat_p99_ms": p99,
        "throughput_img_s": ips,
    }

# ---- Run for b1/b64/b128 ----
results = []
results.append(trt_eval(engine_map[1],   test_loader_b1,   warmup=50, iters=200, acc_batches=50))
results.append(trt_eval(engine_map[64],  test_loader_b64,  warmup=50, iters=200, acc_batches=50))
results.append(trt_eval(engine_map[128], test_loader_b128, warmup=50, iters=200, acc_batches=50))

for r in results:
    print(f"\n{r['engine']}")
    print(f"  batch: {r['batch']}")
    print(f"  acc: {r['acc_%']:.2f}% (first 50 batches)")
    print(f"  latency ms: mean={r['lat_mean_ms']:.3f}, p50={r['lat_p50_ms']:.3f}, p90={r['lat_p90_ms']:.3f}, p99={r['lat_p99_ms']:.3f}")
    print(f"  throughput: {r['throughput_img_s']:.1f} img/s")


alexnet_int8_b1.engine
  batch: 1
  acc: 96.00% (first 50 batches)
  latency ms: mean=0.135, p50=0.128, p90=0.136, p99=0.161
  throughput: 7384.2 img/s

alexnet_int8_b64.engine
  batch: 64
  acc: 88.66% (first 50 batches)
  latency ms: mean=0.197, p50=0.197, p90=0.199, p99=0.206
  throughput: 324801.9 img/s

alexnet_int8_b128.engine
  batch: 128
  acc: 86.84% (first 50 batches)
  latency ms: mean=0.243, p50=0.242, p90=0.245, p99=0.252
  throughput: 527815.9 img/s


In [12]:
print("Torch acc (model used for export):", torch_acc(model, test_loader))

Torch acc (model used for export): 89.06
