In [1]:
import os, random
import numpy as np
import torch

# Seed the Python, NumPy and PyTorch RNGs with the same value so reruns of the
# notebook draw the same random numbers.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

Device: cuda


In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Evaluation-time preprocessing: convert to a tensor, then normalise each
# channel with fixed mean/std constants.
# NOTE(review): presumably the same statistics used during training — confirm.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)

# CIFAR-10 test split (train=False), downloaded into ./data when absent.
test_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=test_transform,
)

# Fixed order (shuffle=False) so repeated evaluations see identical batches.
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [3]:
import sys
from pathlib import Path
import torch
import torch.nn as nn
import tensorrt as trt
import numpy as np

# Put the directory two levels above the current working directory at the front
# of sys.path so the `models` package imported on the next line can be found.
# NOTE(review): this depends on where the kernel was started — confirm the cwd.
sys.path.insert(0, str(Path.cwd().parents[1]))
from models.resnet32_model import ResNetQAT
import torch

In [4]:
import torch.ao.quantization as tq

# Build the 10-class network and place it on the selected device, then switch
# it to inference mode. `eval()` returns the module, so the cell still echoes
# the model structure as its output.
model = ResNetQAT(num_classes=10).to(device)
model.eval()

ResNetQAT(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1

In [5]:
import re
import torch

def strip_and_remap_qat_state(ckpt: dict) -> dict:
    """Convert a QAT state dict into one keyed like the plain (non-QAT) model.

    Entries whose key mentions an observer or fake-quant module are dropped.
    BatchNorm entries stored under a conv module (``<conv>.bn.*``) are renamed
    to the stand-alone BatchNorm names (``bn1``, ``bn2``, ``shortcut.1``).

    Args:
        ckpt: Mapping from parameter/buffer name to its value.

    Returns:
        A new dict holding the surviving entries under their remapped names.
        Values are passed through untouched (same objects as in ``ckpt``).
    """
    # Any key containing one of these substrings is quantization bookkeeping.
    qat_markers = ("activation_post_process", "fake_quant")

    # (pattern, replacement) pairs applied in this order after the stem rename.
    block_rules = (
        (r"(layer\d+\.\d+)\.conv1\.bn\.", r"\1.bn1."),
        (r"(layer\d+\.\d+)\.conv2\.bn\.", r"\1.bn2."),
        # shortcut is (conv, bn), so its BatchNorm is remapped to index 1
        (r"(layer\d+\.\d+)\.shortcut\.0\.bn\.", r"\1.shortcut.1."),
    )

    remapped = {}
    for name, value in ckpt.items():
        if any(marker in name for marker in qat_markers):
            continue

        # Plain substring replacement: handles the stem ("conv1.bn.") and also
        # rewrites "...conv1.bn." inside residual blocks to "...bn1.".
        new_name = name.replace("conv1.bn.", "bn1.")
        for pattern, replacement in block_rules:
            new_name = re.sub(pattern, replacement, new_name)

        remapped[new_name] = value

    return remapped


# ---- LOAD ----
# Read the pre-convert QAT checkpoint onto the CPU and rename its keys so they
# line up with `model`.
ckpt = torch.load("../../pth/resnet_qat_preconvert.pth", map_location="cpu")
state_fp32 = strip_and_remap_qat_state(ckpt)

# strict=False reports mismatched keys instead of raising, so they can be
# inspected here; both lists should be empty for a clean load.
load_report = model.load_state_dict(state_fp32, strict=False)
missing = load_report.missing_keys
unexpected = load_report.unexpected_keys
print("missing:", len(missing))
print("unexpected:", len(unexpected))
print("example missing:", missing[:20])
print("example unexpected:", unexpected[:20])

model.eval().to(device)

missing: 0
unexpected: 0
example missing: []
example unexpected: []


ResNetQAT(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1

In [6]:
@torch.no_grad()
def torch_acc(model, loader, device=device):
    """Return the accuracy of ``model`` over ``loader`` as a percentage.

    Args:
        model: Module called on each input batch; it is switched to eval mode.
        loader: Iterable yielding ``(inputs, labels)`` pairs.
        device: Device both tensors are moved to before the forward pass. The
            default is bound to the notebook-level ``device`` at the moment
            this function is defined.

    Returns:
        ``100.0 * correct / total`` as a float.
    """
    model.eval()
    hits, seen = 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # assumes the model output is (batch, classes): take the index of the
        # largest score along dim 1 as the prediction
        predicted = model(inputs).argmax(1)
        hits += (predicted == labels).sum().item()
        seen += labels.size(0)
    return 100.0 * hits / seen

# Reference number: accuracy of the PyTorch model on the test loader.
print("Torch acc:", torch_acc(model, test_loader))

Torch acc: 88.64


In [11]:
import torch

model.eval().to(device)

# One random example input per batch size. No dynamic axes are requested, so
# each export is traced with that fixed batch dimension.
dummy_map = {
    bs: torch.randn(bs, 3, 32, 32, device=device)
    for bs in (1, 64, 128)
}

for bs in dummy_map:
    dummy = dummy_map[bs]
    out_path = f"resnet_qat_qdq_b{bs}_op18.onnx"

    # NOTE(review): the weights loaded into `model` had every fake-quant /
    # observer entry stripped, so despite the "qdq" file name the exported
    # graph may contain no QuantizeLinear/DequantizeLinear nodes — confirm.
    torch.onnx.export(
        model,
        dummy,
        out_path,
        opset_version=18,
        do_constant_folding=False,  # leave constants unfolded in the exported graph
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes=None,          # fixed (static) input shape
    )

    print("Exported:", out_path)

[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 66 of general pattern rewrite rules.
Exported: resnet_qat_qdq_b1_op18.onnx
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 66 of general pattern rewrite rules.
Exported: resnet_qat_qdq_b64_op18.onnx
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., stri

In [12]:
# List the three exported ONNX files to confirm they exist and check their sizes.
!ls -lh resnet_qat_qdq_b1_op18.onnx
!ls -lh resnet_qat_qdq_b64_op18.onnx
!ls -lh resnet_qat_qdq_b128_op18.onnx

-rw-r--r-- 1 ihsiao ihsiao 84K Dec 14 11:12 resnet_qat_qdq_b1_op18.onnx
-rw-r--r-- 1 ihsiao ihsiao 84K Dec 14 11:12 resnet_qat_qdq_b64_op18.onnx
-rw-r--r-- 1 ihsiao ihsiao 84K Dec 14 11:12 resnet_qat_qdq_b128_op18.onnx


In [13]:
import tensorrt as trt

# Shared TensorRT logger at INFO severity; passed to both the builder and the
# ONNX parser in build_qdq_engine.
TRT_LOGGER = trt.Logger(trt.Logger.INFO)

# Batch size -> ONNX file name, one entry per static batch size.
onnx_map = {bs: f"resnet_qat_qdq_b{bs}_op18.onnx" for bs in (1, 64, 128)}

def build_qdq_engine(onnx_path, engine_path, workspace_bytes=(1 << 30)):
    """Parse an ONNX model and write a serialized TensorRT engine to disk.

    The builder config enables FP16 kernels. No INT8 flag or calibrator is
    set here, so any INT8 execution has to come from Quantize/Dequantize
    layers already present in the ONNX graph; a warning is printed when the
    parsed network contains none.

    Args:
        onnx_path: Path of the ONNX file to parse.
        engine_path: Destination file for the serialized engine.
        workspace_bytes: Limit for TensorRT's WORKSPACE memory pool, in bytes.

    Raises:
        RuntimeError: If the ONNX parser reports a failure, or if the builder
            returns no serialized network.
    """
    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:

        with open(onnx_path, "rb") as f:
            onnx_bytes = f.read()

        # Hand the parser the model's path along with its bytes so weights
        # stored in a separate file next to the .onnx are resolved relative to
        # the model rather than to the current working directory.
        if not parser.parse(onnx_bytes, str(onnx_path)):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError(f"ONNX parse failed for {onnx_path}")

        # Make it visible when a "Q/DQ" build is attempted on a graph without
        # any quantization layers: the build still succeeds, but not as INT8.
        qdq_types = (trt.LayerType.QUANTIZE, trt.LayerType.DEQUANTIZE)
        num_qdq = sum(
            1
            for i in range(network.num_layers)
            if network.get_layer(i).type in qdq_types
        )
        if num_qdq == 0:
            print(
                f"WARNING: {onnx_path} contains no Quantize/Dequantize layers; "
                "the engine will be built without Q/DQ INT8 precision."
            )

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)

        # Let TensorRT pick FP16 kernels where it chooses to (e.g. for layers
        # that do not run in INT8).
        config.set_flag(trt.BuilderFlag.FP16)

        serialized = builder.build_serialized_network(network, config)
        if serialized is None:
            raise RuntimeError(f"Engine build failed for {onnx_path}")

        with open(engine_path, "wb") as f:
            f.write(serialized)

    print("Saved:", engine_path)

# Build one engine per exported ONNX file; the engine file name carries the
# batch size it was exported with.
for bs, onnx_path in onnx_map.items():
    build_qdq_engine(onnx_path, f"resnet_qat_qdq_b{bs}.engine")

[12/14/2025-11:29:16] [TRT] [I] [MemUsageChange] Init CUDA: CPU -2, GPU +0, now: CPU 612, GPU 3047 (MiB)
[12/14/2025-11:29:16] [TRT] [I] ----------------------------------------------------------------
[12/14/2025-11:29:16] [TRT] [I] ONNX IR version:  0.0.10
[12/14/2025-11:29:16] [TRT] [I] Opset version:    18
[12/14/2025-11:29:16] [TRT] [I] Producer name:    pytorch
[12/14/2025-11:29:16] [TRT] [I] Producer version: 2.9.1+cu128
[12/14/2025-11:29:16] [TRT] [I] Domain:           
[12/14/2025-11:29:16] [TRT] [I] Model version:    0
[12/14/2025-11:29:16] [TRT] [I] Doc string:       
[12/14/2025-11:29:16] [TRT] [I] ----------------------------------------------------------------
[12/14/2025-11:29:17] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +287, GPU +6, now: CPU 1100, GPU 3053 (MiB)
[12/14/2025-11:29:17] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[12/14/2025-11:29:33] [TRT] [I] Compiler backend is used during engine bu

In [None]:
# Evaluate each serialized engine, presumably with a loader of the matching
# batch size.
# NOTE(review): `trt_accuracy_static` and `test_loader_b1` / `test_loader_b64` /
# `test_loader_b128` are not defined in any cell visible here — they must exist
# in the kernel before this cell can run.
acc1   = trt_accuracy_static("resnet_qat_qdq_b1.engine",   test_loader_b1)
acc64  = trt_accuracy_static("resnet_qat_qdq_b64.engine",  test_loader_b64)
acc128 = trt_accuracy_static("resnet_qat_qdq_b128.engine", test_loader_b128)

# NOTE(review): the "INT8" label assumes the engines were built from a graph
# containing Q/DQ nodes; the ONNX op check in this notebook printed False for
# both QuantizeLinear and DequantizeLinear — confirm the precision before
# reporting these numbers as INT8.
print(f"INT8 TRT Acc b1:   {acc1:.2f}%")
print(f"INT8 TRT Acc b64:  {acc64:.2f}%")
print(f"INT8 TRT Acc b128: {acc128:.2f}%")

In [8]:
import onnx
# `out_onnx` is not assigned in any visible cell, which breaks a fresh
# Restart & Run All. Pin it to the batch-1 file written by the export cell.
out_onnx = "resnet_qat_qdq_b1_op18.onnx"

# Collect the distinct operator types in the graph and report whether the
# quantization ops are among them.
m = onnx.load(out_onnx)
ops = {node.op_type for node in m.graph.node}
print("QuantizeLinear:", "QuantizeLinear" in ops)
print("DequantizeLinear:", "DequantizeLinear" in ops)

QuantizeLinear: False
DequantizeLinear: False


In [9]:
import torch
# Read the raw checkpoint and show how many entries it holds and how the first
# 40 keys are named.
ckpt = torch.load("../../pth/resnet_qat_preconvert.pth", map_location="cpu")
print("num keys:", len(ckpt))
print("first 40 keys:")
for key in list(ckpt)[:40]:
    print(" ", key)

num keys: 802
first 40 keys:
  activation_post_process_0.fake_quant_enabled
  activation_post_process_0.observer_enabled
  activation_post_process_0.scale
  activation_post_process_0.zero_point
  activation_post_process_0.activation_post_process.eps
  activation_post_process_0.activation_post_process.min_val
  activation_post_process_0.activation_post_process.max_val
  conv1.weight
  conv1.bn.weight
  conv1.bn.bias
  conv1.bn.running_mean
  conv1.bn.running_var
  conv1.bn.num_batches_tracked
  conv1.weight_fake_quant.fake_quant_enabled
  conv1.weight_fake_quant.observer_enabled
  conv1.weight_fake_quant.scale
  conv1.weight_fake_quant.zero_point
  conv1.weight_fake_quant.activation_post_process.eps
  conv1.weight_fake_quant.activation_post_process.min_val
  conv1.weight_fake_quant.activation_post_process.max_val
  activation_post_process_1.fake_quant_enabled
  activation_post_process_1.observer_enabled
  activation_post_process_1.scale
  activation_post_process_1.zero_point
  activatio

In [10]:
# Diagnostic: load the raw (un-remapped) checkpoint to see which keys fail to
# line up with `model`.
# NOTE(review): strict=False still copies every entry whose name matches, so
# this cell modifies `model` as a side effect.
raw_report = model.load_state_dict(ckpt, strict=False)
missing = raw_report.missing_keys
unexpected = raw_report.unexpected_keys
print("missing:", len(missing))
print("unexpected:", len(unexpected))
print("example missing:", missing[:20])
print("example unexpected:", unexpected[:20])

missing: 132
unexpected: 767
example missing: ['bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var']
example unexpected: ['activation_post_process_0.fake_quant_enabled', 'activation_post_process_0.observer_enabled', 'activation_post_process_0.scale', 'activation_post_process_0.zero_point', 'activation_post_process_0.activation_post_process.eps', 'activation_post_process_0.activation_post_process.min_val', 'activation_post_process_0.activation_post_process.max_val', 'activation_post_process_1.fake_quant_enabled', 'activation_post_process_1.observer_enabled', 'activation_post_proc