## Prepare Models for TensorRT

ResNet:

baseline: 

pruned: 

pruned and quantized: 

SqueezeNet:

baseline: squeezenet_data_aug_cifar10_fp32.pth

pruned: squeezenet_p50.pth

pruned and quantized: squeezenet_int8_qat.pth

AlexNet: 

baseline: alexnet_bn_cifar10.pth

pruned: alexnet_p70.pth

pruned and quantized: alexnet_int8_qat.pth

In [None]:
from benchmarks.models.alexnet_model import AlexNetCIFAR10
from benchmarks.models.squeezenet_model import SqueezeNetCIFAR10
from benchmarks.models.resnet32_model import ResNet32

def build_model_from_name(name: str):
    """Instantiate the network architecture implied by a checkpoint filename.

    The match is a case-insensitive substring test, tried in this order:
    "resnet", "alexnet", "squeezenet". The first hit wins.

    Args:
        name: Checkpoint filename (or any string) naming the model family.

    Returns:
        A newly constructed ``ResNet32``, ``AlexNetCIFAR10`` or
        ``SqueezeNetCIFAR10`` instance. No weights are loaded here.

    Raises:
        ValueError: If none of the known family names occurs in ``name``.
    """
    # Normalise once so the checks (and the error message) are lowercase.
    name = name.lower()

    if "resnet" in name:
        return ResNet32()
    if "alexnet" in name:
        return AlexNetCIFAR10()
    if "squeezenet" in name:
        return SqueezeNetCIFAR10()

    raise ValueError(f"Unknown model type in filename: {name}")

In [None]:
import os
import torch
import onnx

# Export every PyTorch checkpoint in MODELS_DIR to an ONNX file in ONNX_DIR
# and run the ONNX structural checker on each result.
MODELS_DIR = "pth"    # input directory scanned for *.pth checkpoints
ONNX_DIR = "onnx"     # output directory for the exported *.onnx files
OPSET_VERSION = 17

os.makedirs(ONNX_DIR, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example input used to trace the model: a single 3x32x32 tensor. The batch
# dimension is declared dynamic in the export call below.
dummy_input = torch.randn(1, 3, 32, 32, device=device)

exported_count = 0

# sorted() gives a reproducible processing order; os.listdir order is arbitrary.
for fname in sorted(os.listdir(MODELS_DIR)):
    if not fname.endswith(".pth"):
        continue

    model_path = os.path.join(MODELS_DIR, fname)
    # Swap only the trailing extension; str.replace would also rewrite any
    # ".pth" that happens to appear earlier in the filename.
    onnx_name = os.path.splitext(fname)[0] + ".onnx"
    onnx_path = os.path.join(ONNX_DIR, onnx_name)

    print(f"\nProcessing {fname}")

    # build model
    model = build_model_from_name(fname)
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from
    # a trusted source. Assumes the file holds a bare state_dict that matches
    # the architecture returned above -- TODO confirm for the quantized
    # checkpoints.
    sd = torch.load(model_path, map_location=device)
    model.load_state_dict(sd)
    model.to(device)
    model.eval()

    # export to ONNX
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=OPSET_VERSION,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        # Mark the batch dimension dynamic on both ends so the output shape
        # is tied to the same symbolic "batch_size" as the input.
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )

    # Verify ONNX
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    exported_count += 1
    print(f"Exported & verified: {onnx_path}")

# Only claim success when something was actually exported.
if exported_count:
    print("\nAll models exported successfully.")
else:
    print(f"\nNo .pth checkpoints found in {MODELS_DIR!r}; nothing exported.")
