In [1]:
import os, random
import numpy as np
import torch

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

Device: cuda


In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

In [3]:
from torch.utils.data import DataLoader

test_loader_b1   = DataLoader(test_dataset, batch_size=1,   shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
test_loader_b64  = DataLoader(test_dataset, batch_size=64,  shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
test_loader_b128 = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

In [4]:
import sys
from pathlib import Path
import torch
import torch.nn as nn
import tensorrt as trt
import numpy as np

sys.path.insert(0, str(Path.cwd().parents[1]))
from models.resnet32_model import ResNetQAT
import torch

In [5]:
import torch
import torch.ao.quantization as tq
from models.resnet32_model import ResNetQAT

device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt = "../../pth/resnet_qat_preconvert.pth"

model = ResNetQAT(num_classes=10).to(device)
state = torch.load(ckpt, map_location="cpu")

model.load_state_dict(state, strict=False)  # strict=False is fine here
model.eval()

# IMPORTANT for Q/DQ export:
model.apply(tq.disable_observer)   # stops moving-average observer updates
model.apply(tq.enable_fake_quant)  # keeps fake-quant so exporter emits Q/DQ

ResNetQAT(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1

In [6]:
import re
import torch

def strip_and_remap_qat_state(ckpt: dict) -> dict:
    out = {}

    for k, v in ckpt.items():
        # 1) drop all QAT bookkeeping
        if (
            "activation_post_process" in k
            or "fake_quant" in k
            or "weight_fake_quant" in k
            or k.startswith("activation_post_process_")
        ):
            continue

        nk = k

        # 2) stem BN
        nk = nk.replace("conv1.bn.", "bn1.")

        # 3) residual blocks
        nk = re.sub(r"(layer\d+\.\d+)\.conv1\.bn\.", r"\1.bn1.", nk)
        nk = re.sub(r"(layer\d+\.\d+)\.conv2\.bn\.", r"\1.bn2.", nk)

        # 4) shortcut BN (conv, bn)
        nk = re.sub(r"(layer\d+\.\d+)\.shortcut\.0\.bn\.", r"\1.shortcut.1.", nk)

        out[nk] = v

    return out


# ---- LOAD ----
ckpt = torch.load("../../pth/resnet_qat_preconvert.pth", map_location="cpu")
state_fp32 = strip_and_remap_qat_state(ckpt)

missing, unexpected = model.load_state_dict(state_fp32, strict=False)
print("missing:", len(missing))
print("unexpected:", len(unexpected))
print("example missing:", missing[:20])
print("example unexpected:", unexpected[:20])

model.eval().to(device)

missing: 0
unexpected: 0
example missing: []
example unexpected: []


ResNetQAT(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1

In [7]:
@torch.no_grad()
def torch_acc(model, loader, device=device):
    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    return 100.0 * correct / total

print("Torch acc:", torch_acc(model, test_loader))

Torch acc: 88.64


In [8]:
import torch
import torch.ao.quantization as tq

model.eval()

# Freeze QAT bookkeeping so it doesn't try to update moving averages during export
model.apply(tq.disable_observer)     # ✅ stops moving_avg observer updates
model.apply(tq.enable_fake_quant)    # ✅ keeps fake-quant so Q/DQ gets emitted

ResNetQAT(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1

In [9]:
import torch

model.eval().to(device)

dummy_map = {
    1:   torch.randn(1,   3, 32, 32, device=device),
    64:  torch.randn(64,  3, 32, 32, device=device),
    128: torch.randn(128, 3, 32, 32, device=device),
}

for bs, dummy in dummy_map.items():
    out_path = f"resnet_qat_qdq_b{bs}_op18.onnx"

    torch.onnx.export(
        model,
        dummy,
        out_path,
        opset_version=18,          # ✅ REQUIRED for Q/DQ
        do_constant_folding=False, # ✅ MUST be False for QAT
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes=None,         # ✅ STATIC shapes
    )

    print("Exported:", out_path)

[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 66 of general pattern rewrite rules.
Exported: resnet_qat_qdq_b1_op18.onnx
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 66 of general pattern rewrite rules.
Exported: resnet_qat_qdq_b64_op18.onnx
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., stri

In [10]:
!ls -lh resnet_qat_qdq_b1_op18.onnx
!ls -lh resnet_qat_qdq_b64_op18.onnx
!ls -lh resnet_qat_qdq_b128_op18.onnx

-rw-r--r-- 1 ihsiao ihsiao 84K Dec 14 11:53 resnet_qat_qdq_b1_op18.onnx
-rw-r--r-- 1 ihsiao ihsiao 84K Dec 14 11:53 resnet_qat_qdq_b64_op18.onnx
-rw-r--r-- 1 ihsiao ihsiao 84K Dec 14 11:53 resnet_qat_qdq_b128_op18.onnx


In [18]:
import os
import tensorrt as trt
import torch

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

onnx_map = {
    1:   "resnet_int8_b1_op18.onnx",
    64:  "resnet_int8_b64_op18.onnx",
    128: "resnet_int8_b128_op18.onnx",
}

calib_loader_map = {
    1:   test_loader_b1,
    64:  test_loader_b64,
    128: test_loader_b128,
}

class EntropyCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, calib_loader, max_batches=200, cache_file="calib.cache"):
        super().__init__()
        self.cache_file = cache_file
        self.data_iter = iter(calib_loader)
        self.max_batches = max_batches
        self.batch_count = 0
        x0, _ = next(iter(calib_loader))
        self.batch_size = x0.shape[0]
        self.device_input = torch.empty_like(x0, device="cuda")

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        if self.batch_count >= self.max_batches:
            return None
        try:
            x, _ = next(self.data_iter)
        except StopIteration:
            return None
        self.batch_count += 1
        x = x.to("cuda", non_blocking=True)
        self.device_input.resize_(x.shape).copy_(x)
        return [int(self.device_input.data_ptr())]

    def read_calibration_cache(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

def build_int8_engine_static(onnx_path, engine_path, calib_loader, max_calib_batches=200):
    cache_path = engine_path.replace(".engine", ".cache")

    # ✅ CRITICAL: delete stale cache (common cause of ~10–25% accuracy)
    if os.path.exists(cache_path):
        os.remove(cache_path)
        print("Deleted stale cache:", cache_path)

    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:

        with open(onnx_path, "rb") as f:
            if not parser.parse(f.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
                raise RuntimeError(f"ONNX parse failed: {onnx_path}")

        # ✅ Force network output to FP32 (prevents INT8 logits / clipping issues)
        for i in range(network.num_outputs):
            network.get_output(i).dtype = trt.float32

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

        config.set_flag(trt.BuilderFlag.INT8)
        config.set_flag(trt.BuilderFlag.FP16)  # safe + helps performance/fallback

        # Helps TRT respect dtype constraints (esp. keeping outputs FP32)
        if hasattr(trt.BuilderFlag, "OBEY_PRECISION_CONSTRAINTS"):
            config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)

        config.int8_calibrator = EntropyCalibrator(
            calib_loader,
            max_batches=max_calib_batches,
            cache_file=cache_path
        )

        serialized = builder.build_serialized_network(network, config)
        if serialized is None:
            raise RuntimeError(f"Engine build failed: {onnx_path}")

        with open(engine_path, "wb") as f:
            f.write(serialized)

    print(f"Saved: {engine_path}")

# build all three
for bs, onnx_path in onnx_map.items():
    build_int8_engine_static(
        onnx_path=onnx_path,
        engine_path=f"resnet_ptq_int8_b{bs}.engine",
        calib_loader=calib_loader_map[bs],
        max_calib_batches=200,
    )

[12/14/2025-12:16:02] [TRT] [I] ----------------------------------------------------------------
[12/14/2025-12:16:02] [TRT] [I] ONNX IR version:  0.0.10
[12/14/2025-12:16:02] [TRT] [I] Opset version:    18
[12/14/2025-12:16:02] [TRT] [I] Producer name:    pytorch
[12/14/2025-12:16:02] [TRT] [I] Producer version: 2.9.1+cu128
[12/14/2025-12:16:02] [TRT] [I] Domain:           
[12/14/2025-12:16:02] [TRT] [I] Model version:    0
[12/14/2025-12:16:02] [TRT] [I] Doc string:       
[12/14/2025-12:16:02] [TRT] [I] ----------------------------------------------------------------


  network.get_output(i).dtype = trt.float32
  config.int8_calibrator = EntropyCalibrator(


[12/14/2025-12:16:03] [TRT] [I] Perform graph optimization on calibration graph.
[12/14/2025-12:16:03] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[12/14/2025-12:16:03] [TRT] [I] Compiler backend is used during engine build.
[12/14/2025-12:16:04] [TRT] [I] Detected 1 inputs and 1 output network tensors.
[12/14/2025-12:16:05] [TRT] [I] Total Host Persistent Memory: 148544 bytes
[12/14/2025-12:16:05] [TRT] [I] Total Device Persistent Memory: 0 bytes
[12/14/2025-12:16:05] [TRT] [I] Max Scratch Memory: 4608 bytes
[12/14/2025-12:16:05] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 102 steps to complete.
[12/14/2025-12:16:05] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.596822ms to assign 4 blocks to 102 nodes requiring 201216 bytes.
[12/14/2025-12:16:05] [TRT] [I] Total Activation Memory: 201216 bytes
[12/14/2025-12:16:05] [TRT] [I] Total Weights Memory: 3251752 bytes
[12/14/2025-12:16:05] [TRT] [

In [19]:
import torch
import tensorrt as trt

@torch.no_grad()
def trt_accuracy_static(engine_path, test_loader, num_batches=None):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

    with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()

    names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    inp = [n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT][0]
    out = [n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT][0]

    # engine fixed shapes
    in_shape = tuple(engine.get_tensor_shape(inp))
    out_shape = tuple(engine.get_tensor_shape(out))
    fixed_bsz = in_shape[0]  # should be 1 or 64 or 128

    # output dtype
    trt_dtype = engine.get_tensor_dtype(out)
    torch_dtype = {
        trt.DataType.FLOAT: torch.float32,
        trt.DataType.HALF:  torch.float16,
        trt.DataType.INT8:  torch.int8,
        trt.DataType.INT32: torch.int32,
    }[trt_dtype]

    stream = torch.cuda.current_stream()
    correct = 0
    total = 0

    for bi, (x_cpu, y_cpu) in enumerate(test_loader):
        if num_batches is not None and bi >= num_batches:
            break

        x = x_cpu.to("cuda", non_blocking=True)
        y = y_cpu.to("cuda", non_blocking=True)

        if x.shape[0] != fixed_bsz:
            raise RuntimeError(f"Batch mismatch: loader={x.shape[0]} but engine expects {fixed_bsz}")

        yhat = torch.empty(out_shape, device="cuda", dtype=torch_dtype)

        context.set_tensor_address(inp, int(x.data_ptr()))
        context.set_tensor_address(out, int(yhat.data_ptr()))

        ok = context.execute_async_v3(stream_handle=stream.cuda_stream)
        if not ok:
            raise RuntimeError("TRT execute failed")

        pred = yhat.float().argmax(dim=1)
        correct += (pred == y).sum().item()
        total += x.shape[0]

    torch.cuda.synchronize()
    return 100.0 * correct / total

In [21]:
acc1   = trt_accuracy_static("resnet_ptq_int8_b1.engine",   test_loader_b1)
acc64  = trt_accuracy_static("resnet_ptq_int8_b64.engine",  test_loader_b64)
acc128 = trt_accuracy_static("resnet_ptq_int8_b128.engine", test_loader_b128)

print(f"INT8 TRT Acc b1:   {acc1:.2f}%")
print(f"INT8 TRT Acc b64:  {acc64:.2f}%")
print(f"INT8 TRT Acc b128: {acc128:.2f}%")

[12/14/2025-12:19:33] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
[12/14/2025-12:19:41] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
[12/14/2025-12:19:42] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
INT8 TRT Acc b1:   22.36%
INT8 TRT Acc b64:  22.25%
INT8 TRT Acc b128: 22.30%


In [None]:
import onnx
m = onnx.load(out_onnx)
ops = set(n.op_type for n in m.graph.node)
print("QuantizeLinear:", "QuantizeLinear" in ops)
print("DequantizeLinear:", "DequantizeLinear" in ops)

In [None]:
import torch
ckpt = torch.load("../../pth/resnet_qat_preconvert.pth", map_location="cpu")
print("num keys:", len(ckpt))
print("first 40 keys:")
for k in list(ckpt.keys())[:40]:
    print(" ", k)

In [None]:
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print("missing:", len(missing))
print("unexpected:", len(unexpected))
print("example missing:", missing[:20])
print("example unexpected:", unexpected[:20])

In [17]:
import onnx

m = onnx.load("resnet_int8_b1_op18.onnx")
ops = sorted({n.op_type for n in m.graph.node})

print("Ops:", ops)
print("Has QuantizeLinear:", "QuantizeLinear" in ops)
print("Has DequantizeLinear:", "DequantizeLinear" in ops)

Ops: ['Add', 'Conv', 'Gemm', 'ReduceMean', 'Relu', 'Reshape']
Has QuantizeLinear: False
Has DequantizeLinear: False
