# 04 - Model Export
# AutonomousVehiclePerception/notebooks/04_model_export.ipynb

Export trained models to ONNX and TorchScript, then apply torch.compile optimization and dynamic quantization.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import time
from pathlib import Path

from src.model.cnn_2d import PerceptionCNN2D
from src.model.cnn_3d_voxel import VoxelBackbone3D
from src.model.fpn_resnet import FPNDetector
from src.model.export import (
    export_onnx, export_torchscript,
    optimize_with_compile, quantize_dynamic
)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f'Device: {device}')

# CHECKPOINT_DIR is where trained weights are looked up; EXPORT_DIR receives
# the exported artifacts and is created (with parents) if it does not exist.
CHECKPOINT_DIR = Path('../checkpoints')
EXPORT_DIR = Path('../exports')
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

## Load Trained Models

In [None]:
def _restore_weights(net, ckpt_path, label):
    """Load ``model_state_dict`` from ``ckpt_path`` into ``net`` when the file exists.

    Prints the checkpoint's recorded epoch and validation loss on success, or
    a notice that ``net`` keeps its randomly initialised weights otherwise.
    ``label`` is the human-readable model name used in both messages.
    """
    if not ckpt_path.exists():
        print(f'No {label} checkpoint found, using random weights')
        return
    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    net.load_state_dict(state['model_state_dict'])
    print(f'Loaded {label} checkpoint (epoch {state["epoch"]}, val_loss={state["val_loss"]:.4f})')


# 2D CNN
cnn2d = PerceptionCNN2D(num_classes=9)
ckpt_2d = CHECKPOINT_DIR / 'cnn_2d_best.pth'
_restore_weights(cnn2d, ckpt_2d, '2D CNN')
cnn2d.eval()

# 3D Voxel CNN
cnn3d = VoxelBackbone3D(in_channels=1, num_classes=5)
ckpt_3d = CHECKPOINT_DIR / 'voxel3d_best.pth'
_restore_weights(cnn3d, ckpt_3d, '3D CNN')
cnn3d.eval()

# FPN-ResNet50: built with pretrained=False and no checkpoint is loaded here,
# so it is exported with whatever weights the constructor produces.
fpn = FPNDetector(num_classes=9, pretrained=False)
fpn.eval()
print('FPN-ResNet50 initialized')

## ONNX Export

Export to ONNX for cross-platform deployment (TensorRT, .NET, Java).

In [None]:
# Sample inputs handed to the exporters: a 4-D tensor of shape
# (1, 3, 480, 640) and a 5-D tensor of shape (1, 1, 20, 128, 128).
# Later cells reuse both tensors.
dummy_2d = torch.randn(1, 3, 480, 640)
dummy_3d = torch.randn(1, 1, 20, 128, 128)

print('=== ONNX Export ===')
# (model, sample input, output file name) for each ONNX artifact.
onnx_jobs = (
    (cnn2d, dummy_2d, 'cnn_2d.onnx'),
    (cnn3d, dummy_3d, 'cnn_3d_voxel.onnx'),
    (fpn, dummy_2d, 'fpn_resnet.onnx'),
)
for net, sample, filename in onnx_jobs:
    export_onnx(net, sample, EXPORT_DIR / filename)

## TorchScript Export

Export to TorchScript for C++ runtime and TorchServe.

In [None]:
print('=== TorchScript Export ===')
# (model, sample input, output file name) for each TorchScript artifact.
# Only the 2D and 3D CNNs are exported in this cell.
torchscript_jobs = (
    (cnn2d, dummy_2d, 'cnn_2d.pt'),
    (cnn3d, dummy_3d, 'cnn_3d_voxel.pt'),
)
for net, sample, filename in torchscript_jobs:
    export_torchscript(net, sample, EXPORT_DIR / filename)

## torch.compile Optimization

Apply graph-mode compilation for GPU inference speedup.

In [None]:
print('=== torch.compile Benchmark ===')


def _sync_device():
    """Wait for all queued CUDA work to finish; does nothing on CPU."""
    if device.type == 'cuda':
        torch.cuda.synchronize()


def _time_forward(net, sample, warmup=5, iters=20):
    """Return the latency in milliseconds of ``iters`` forward passes.

    ``warmup`` untimed passes are run first. CUDA kernels are launched
    asynchronously, so the device is synchronized before starting and after
    finishing each timed call; without that, the wall clock would capture
    only the launch overhead rather than the actual execution time.
    """
    latencies_ms = []
    with torch.no_grad():
        for _ in range(warmup):
            net(sample)
        for _ in range(iters):
            _sync_device()
            start = time.perf_counter()
            net(sample)
            _sync_device()
            latencies_ms.append((time.perf_counter() - start) * 1000)
    return latencies_ms


# Benchmark original vs compiled on a freshly constructed model, so the
# exported `cnn2d` instance is left untouched.
model = PerceptionCNN2D(num_classes=9).to(device).eval()
test_input = torch.randn(1, 3, 480, 640).to(device)

# Original timing
times = _time_forward(model, test_input)
print(f'Original:  {sum(times)/len(times):.2f} ms avg')

# Compiled timing (the warmup passes inside _time_forward absorb the one-off
# compilation cost, so it is not counted in the reported latency).
if device.type == 'cuda':
    compiled_model = optimize_with_compile(model, mode='reduce-overhead')
    times_compiled = _time_forward(compiled_model, test_input)
    print(f'Compiled:  {sum(times_compiled)/len(times_compiled):.2f} ms avg')
    print(f'Speedup:   {sum(times)/sum(times_compiled):.2f}x')
else:
    print('torch.compile benchmark skipped (CPU only, best results on GPU)')

## Dynamic Quantization (INT8)

Apply INT8 quantization for reduced model size and faster CPU inference.

In [None]:
import os

print('=== INT8 Dynamic Quantization ===')
# The second argument is the destination path handed to the project helper.
quantized = quantize_dynamic(cnn2d, EXPORT_DIR / 'cnn_2d_int8.pth')

# Report the on-disk size of the artifacts written by the earlier export cells.
BYTES_PER_MB = 1024 * 1024
onnx_size = (EXPORT_DIR / 'cnn_2d.onnx').stat().st_size / BYTES_PER_MB
pt_size = (EXPORT_DIR / 'cnn_2d.pt').stat().st_size / BYTES_PER_MB

print('\nModel sizes:')
print(f'  ONNX:        {onnx_size:.1f} MB')
print(f'  TorchScript: {pt_size:.1f} MB')

# Sanity check: run the float and quantized models on the same input and
# report how far their outputs diverge.
with torch.no_grad():
    orig_out = cnn2d(dummy_2d)
    quant_out = quantized(dummy_2d)
    max_abs_diff = (orig_out - quant_out).abs().max()

print(f'\nOriginal output shape:  {orig_out.shape}')
print(f'Quantized output shape: {quant_out.shape}')
print(f'Max output diff: {max_abs_diff:.6f}')

## ONNX Validation

In [None]:
# Only the optional imports are guarded: a failure while validating or running
# the model should surface as its own error, not as an "install onnx" hint.
try:
    import numpy as np
    import onnx
    import onnxruntime as ort
except ImportError:
    print('Install onnx and onnxruntime for validation: pip install onnx onnxruntime')
else:
    print('=== ONNX Validation ===')
    onnx_path = str(EXPORT_DIR / 'cnn_2d.onnx')

    # Structural check of the serialized graph.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print('ONNX model valid ✓')

    # Run inference with ONNX Runtime. The CPU provider is requested
    # explicitly because the result is compared against CPU PyTorch output.
    session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    # Read the input name from the graph instead of assuming the exporter
    # called it 'input'.
    input_name = session.get_inputs()[0].name
    ort_input = {input_name: dummy_2d.numpy()}
    ort_output = session.run(None, ort_input)
    print(f'ONNX Runtime output shape: {ort_output[0].shape}')

    # Compare with PyTorch
    with torch.no_grad():
        pt_output = cnn2d(dummy_2d).numpy()
    diff = np.abs(pt_output - ort_output[0]).max()
    print(f'Max diff PyTorch vs ONNX Runtime: {diff:.6f}')

# List every entry in the export directory, sorted by path, with its size.
print('\n=== Export Summary ===')
BYTES_IN_MB = 1024 * 1024
for artifact in sorted(EXPORT_DIR.glob('*')):
    artifact_mb = artifact.stat().st_size / BYTES_IN_MB
    print(f'  {artifact.name}: {artifact_mb:.1f} MB')