In [1]:
import torch
import tensorrt as trt

print("Torch:", torch.__version__, "CUDA:", torch.cuda.is_available())
print("TensorRT:", trt.__version__)

Torch: 2.9.1+cu128 CUDA: True
TensorRT: 10.14.1.48.post1


In [2]:
import sys
from pathlib import Path
import torch
import torch.nn as nn
import tensorrt as trt
import numpy as np

sys.path.insert(0, str(Path.cwd().parents[1]))

from models.resnet32_model import ResNet, ResNetQAT

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

Device: cuda


In [3]:
model = ResNetQAT(num_classes=10)

state = torch.load("../../pth/resnet_qat_preconvert.pth", map_location="cpu")
model.load_state_dict(state, strict=False)

model.eval()

ResNetQAT(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlockQAT(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1

In [8]:
import torch.ao.quantization as tq

model.eval()

# ðŸ”‘ CRITICAL STEP
model_int8 = tq.convert(model, inplace=False)

dummy_map = {
    1:   torch.randn(1,   3, 32, 32),
    64:  torch.randn(64,  3, 32, 32),
    128: torch.randn(128, 3, 32, 32),
}

for bs, dummy in dummy_map.items():
    out_path = f"resnet_qat_int8_b{bs}_op18.onnx"
    torch.onnx.export(
        model_int8, dummy, out_path,
        opset_version=18,
        do_constant_folding=False,   # IMPORTANT for QAT
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes=None,
    )
    print("Exported", out_path)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  model_int8 = tq.convert(model, inplace=False)


[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`... âœ…
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... âœ…
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... âœ…
Applied 99 of general pattern rewrite rules.
Exported resnet_qat_int8_b1_op18.onnx
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.export(..., strict=False)`... âœ…
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... âœ…
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... âœ…
Applied 99 of general pattern rewrite rules.
Exported resnet_qat_int8_b64_op18.onnx
[torch.onnx] Obtain model graph for `ResNetQAT([...]` with `torch.export.expo

In [9]:
import onnx

m = onnx.load("resnet_qat_int8_b1_op18.onnx")
ops = sorted(set(n.op_type for n in m.graph.node))
print("QuantizeLinear:", "QuantizeLinear" in ops)
print("DequantizeLinear:", "DequantizeLinear" in ops)

QuantizeLinear: False
DequantizeLinear: False


In [10]:
!ls -lh resnet_qat_int8_b1_op18.onnx
!ls -lh resnet_qat_int8_b64_op18.onnx
!ls -lh resnet_qat_int8_b128_op18.onnx

-rw-r--r-- 1 ihsiao ihsiao 78K Dec 14 09:47 resnet_qat_int8_b1_op18.onnx
-rw-r--r-- 1 ihsiao ihsiao 78K Dec 14 09:47 resnet_qat_int8_b64_op18.onnx
-rw-r--r-- 1 ihsiao ihsiao 78K Dec 14 09:47 resnet_qat_int8_b128_op18.onnx


In [11]:
import onnx
m = onnx.load("resnet_qat_int8_b1_op13.onnx")
onnx.checker.check_model(m)
print([(op.domain, op.version) for op in m.opset_import])

[('', 18)]


In [12]:
import tensorrt as trt
print(trt.__version__)

10.14.1.48.post1


In [13]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

test_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=test_transform
)

test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

In [14]:
from torch.utils.data import DataLoader

test_loader_b1   = DataLoader(test_dataset, batch_size=1,   shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
test_loader_b64  = DataLoader(test_dataset, batch_size=64,  shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
test_loader_b128 = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

In [15]:
import torch
import tensorrt as trt

@torch.no_grad()
def trt_accuracy_static(engine_path, test_loader, num_batches=None):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

    with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()

    names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    inp = [n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT][0]
    out = [n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT][0]

    # engine fixed shapes
    in_shape = tuple(engine.get_tensor_shape(inp))
    out_shape = tuple(engine.get_tensor_shape(out))
    fixed_bsz = in_shape[0]  # should be 1 or 64 or 128

    # output dtype
    trt_dtype = engine.get_tensor_dtype(out)
    torch_dtype = {
        trt.DataType.FLOAT: torch.float32,
        trt.DataType.HALF:  torch.float16,
        trt.DataType.INT8:  torch.int8,
        trt.DataType.INT32: torch.int32,
    }[trt_dtype]

    stream = torch.cuda.current_stream()
    correct = 0
    total = 0

    for bi, (x_cpu, y_cpu) in enumerate(test_loader):
        if num_batches is not None and bi >= num_batches:
            break

        x = x_cpu.to("cuda", non_blocking=True)
        y = y_cpu.to("cuda", non_blocking=True)

        if x.shape[0] != fixed_bsz:
            raise RuntimeError(f"Batch mismatch: loader={x.shape[0]} but engine expects {fixed_bsz}")

        yhat = torch.empty(out_shape, device="cuda", dtype=torch_dtype)

        context.set_tensor_address(inp, int(x.data_ptr()))
        context.set_tensor_address(out, int(yhat.data_ptr()))

        ok = context.execute_async_v3(stream_handle=stream.cuda_stream)
        if not ok:
            raise RuntimeError("TRT execute failed")

        pred = yhat.float().argmax(dim=1)
        correct += (pred == y).sum().item()
        total += x.shape[0]

    torch.cuda.synchronize()
    return 100.0 * correct / total

In [18]:
import tensorrt as trt
import torch
import onnx

# ---- QUIET logger (prevents per-batch spam) ----
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# ---- YOUR 3 STATIC ONNX FILES (fixed batch) ----
onnx_map = {
    1:   "resnet_qat_fp32_b1_op13.onnx",
    64:  "resnet_qat_fp32_b64_op13.onnx",
    128: "resnet_qat_fp32_b128_op13.onnx",
}

# ---- CALIBRATION LOADERS MUST MATCH BATCH SIZE ----
calib_loader_map = {
    1:   test_loader_b1,
    64:  test_loader_b64,
    128: test_loader_b128,
}

CALIB_BATCHES = 200          # 200 is fine; 500 if you want extra
WORKSPACE     = 1 << 28      # 256MB (raise to 1<<29 if build fails)

# ---------------- Calibrator ----------------
class EntropyCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, calib_loader, max_batches=200, cache_file="calib.cache"):
        super().__init__()
        self.cache_file = cache_file
        self.data_iter = iter(calib_loader)
        self.max_batches = max_batches
        self.batch_count = 0

        x0, _ = next(iter(calib_loader))
        self.batch_size = x0.shape[0]
        self.device_input = torch.empty_like(x0, device="cuda")

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        if self.batch_count >= self.max_batches:
            return None
        try:
            x, _ = next(self.data_iter)
        except StopIteration:
            return None

        self.batch_count += 1
        x = x.to("cuda", non_blocking=True)
        self.device_input.resize_(x.shape).copy_(x)
        return [int(self.device_input.data_ptr())]

    def read_calibration_cache(self):
        try:
            with open(self.cache_file, "rb") as f:
                return f.read()
        except FileNotFoundError:
            return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

# ---------------- Builder ----------------
def build_int8_engine_static(onnx_path, engine_path, calib_loader, max_calib_batches=200):
    # sanity check: if ONNX already has Q/DQ, you generally should NOT calibrate
    m = onnx.load(onnx_path)
    ops = set(n.op_type for n in m.graph.node)
    has_qdq = ("QuantizeLinear" in ops) or ("DequantizeLinear" in ops)
    if has_qdq:
        print(f"[WARN] {onnx_path} contains Quantize/Dequantize ops. "
              f"PTQ calibration may be wrong for this file.")

    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:

        with open(onnx_path, "rb") as f:
            if not parser.parse(f.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
                raise RuntimeError(f"ONNX parse failed for {onnx_path}")

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, WORKSPACE)

        # INT8 PTQ
        config.set_flag(trt.BuilderFlag.INT8)
        cache_file = engine_path.replace(".engine", ".cache")
        config.int8_calibrator = EntropyCalibrator(
            calib_loader,
            max_batches=max_calib_batches,
            cache_file=cache_file
        )

        # static ONNX => NO optimization profile needed
        serialized = builder.build_serialized_network(network, config)
        if serialized is None:
            raise RuntimeError(f"INT8 engine build failed for {onnx_path}")

        with open(engine_path, "wb") as f:
            f.write(serialized)

    print(f"Saved: {engine_path} | cache: {cache_file} | calib_batches={max_calib_batches}")

# ---------------- Build all 3 ----------------
for bs, onnx_path in onnx_map.items():
    build_int8_engine_static(
        onnx_path=onnx_path,
        engine_path=f"resnet_qat_int8_b{bs}.engine",
        calib_loader=calib_loader_map[bs],
        max_calib_batches=CALIB_BATCHES
    )



  config.int8_calibrator = EntropyCalibrator(


[12/14/2025-09:52:06] [TRT] [W] Missing scale and zero-point for tensor onnx::Conv_391_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[12/14/2025-09:52:06] [TRT] [W] Missing scale and zero-point for tensor onnx::Conv_421, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[12/14/2025-09:52:06] [TRT] [W] Missing scale and zero-point for tensor onnx::Conv_418, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[12/14/2025-09:52:06] [TRT] [W] Missing scale and zero-point for tensor onnx::Conv_415, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[12/14/2025-09:52:06] [TRT] [W] Missing scale and zero-point for tensor onnx::Conv_412, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[12/14/2025-09:52:06] [TRT] [W] Missing scale and zero-point for tensor onnx::Conv_

In [19]:
!ls -lh resnet_qat_int8_b1.engine
!ls -lh resnet_qat_int8_b64.engine
!ls -lh resnet_qat_int8_b128.engine

-rw-r--r-- 1 ihsiao ihsiao 1.2M Dec 14 09:52 resnet_qat_int8_b1.engine
-rw-r--r-- 1 ihsiao ihsiao 856K Dec 14 09:52 resnet_qat_int8_b64.engine
-rw-r--r-- 1 ihsiao ihsiao 877K Dec 14 09:53 resnet_qat_int8_b128.engine


In [20]:
acc1   = trt_accuracy_static("resnet_qat_int8_b1.engine",   test_loader_b1)
acc64  = trt_accuracy_static("resnet_qat_int8_b64.engine",  test_loader_b64)
acc128 = trt_accuracy_static("resnet_qat_int8_b128.engine", test_loader_b128)

print(f"INT8 TRT Acc b1:   {acc1:.2f}%")
print(f"INT8 TRT Acc b64:  {acc64:.2f}%")
print(f"INT8 TRT Acc b128: {acc128:.2f}%")

[12/14/2025-09:53:44] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
[12/14/2025-09:53:51] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
[12/14/2025-09:53:52] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
INT8 TRT Acc b1:   10.03%
INT8 TRT Acc b64:  9.84%
INT8 TRT Acc b128: 9.88%


In [21]:
@torch.no_grad()
def torch_acc(model, loader, device="cuda"):
    model.eval().to(device)
    correct = total = 0
    for x,y in loader:
        x,y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        correct += (pred==y).sum().item()
        total += y.size(0)
    return 100*correct/total

print("Torch acc:", torch_acc(model, test_loader))

Torch acc: 9.83


In [None]:
print(test_loader.dataset.transform)

In [None]:
import onnx
m = onnx.load("resnet_qat_fp32_b1_op13.onnx")  # <-- change to your qat onnx filename
ops = sorted(set(n.op_type for n in m.graph.node))
print("QuantizeLinear:", "QuantizeLinear" in ops)
print("DequantizeLinear:", "DequantizeLinear" in ops)
print("Num nodes:", len(m.graph.node))

In [None]:
import onnx
m = onnx.load("resnet_qat_fp32_b1_op13.onnx")
print("Inputs:")
for i in m.graph.input:
    print(i.name)
print("Outputs:")
for o in m.graph.output:
    print(o.name)

In [None]:
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
with open("resnet_qat_int8_b1.engine","rb") as f, trt.Runtime(TRT_LOGGER) as rt:
    engine = rt.deserialize_cuda_engine(f.read())

print("Engine name:", engine.name)
print("Has refittable:", engine.refittable)
print("Num layers:", engine.num_layers)
print("Has dynamic shapes:", engine.num_optimization_profiles > 0)

In [22]:
import onnx
m = onnx.load("resnet_qat_fp32_b1_op13.onnx")   # <-- whichever file you're using
ops = sorted(set(n.op_type for n in m.graph.node))
print("Has QuantizeLinear?", "QuantizeLinear" in ops)
print("Has DequantizeLinear?", "DequantizeLinear" in ops)

Has QuantizeLinear? False
Has DequantizeLinear? False
