## Purpose

This notebook exports the trained Brevitas quantized regression CNN into a FINN-compatible QONNX model.

This step is mandatory for FINN and bridges:

1-Model.ipynb ‚Üí FINN toolflow

The output of this notebook is:

model_regression_qonnx.onnx

This file will be consumed by 2-finn.ipynb.

In [1]:
import torch
import torch.nn as nn
import onnx

import brevitas.nn as qnn
import brevitas.quant as quant
from brevitas.export import export_qonnx


No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


## Import model definitions 

These definitions are taken directly from 1-Model.ipynb.

In [2]:
import torch.nn as nn
import brevitas.nn as qnn
import brevitas.quant as quant


# =========================================================
# FP32 reference model (used only to load weights)
# =========================================================
class EllipseRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 5)

    def forward(self, x):
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool(torch.relu(self.bn3(self.conv3(x))))
        x = self.pool(torch.relu(self.bn4(self.conv4(x))))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


# =========================================================
# Quantized QAT model (this is what FINN consumes)
# =========================================================
class QuantEllipseRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = qnn.QuantConv2d(1, 32, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint)
        self.conv2 = qnn.QuantConv2d(32, 64, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint)
        self.conv3 = qnn.QuantConv2d(64, 128, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint)
        self.conv4 = qnn.QuantConv2d(128, 256, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint)
        self.pool = nn.MaxPool2d(2)
        self.relu = qnn.QuantReLU(bit_width=8)
        self.fc1 = qnn.QuantLinear(256, 128, weight_quant=quant.Int8WeightPerTensorFixedPoint)
        self.fc2 = qnn.QuantLinear(128, 5, weight_quant=quant.Int8WeightPerTensorFixedPoint)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = self.pool(self.relu(self.conv4(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

In [3]:
# =========================================================
# Quantized QAT model - CORRECTED to match checkpoint
# =========================================================
class QuantEllipseRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Conv layers WITHOUT bias (BatchNorm handles the bias)
        self.conv1 = qnn.QuantConv2d(1, 32, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = qnn.QuantConv2d(32, 64, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = qnn.QuantConv2d(64, 128, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = qnn.QuantConv2d(128, 256, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=False)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2)
        self.act = qnn.QuantReLU(bit_width=8)
        self.fc1 = qnn.QuantLinear(256, 512, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=False)
        self.fc2 = qnn.QuantLinear(512, 256, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=False)
        self.fc_out = nn.Linear(256, 5, bias=False)

    def forward(self, x):
        x = self.pool(self.act(self.bn1(self.conv1(x))))
        x = self.pool(self.act(self.bn2(self.conv2(x))))
        x = self.pool(self.act(self.bn3(self.conv3(x))))
        x = self.pool(self.act(self.bn4(self.conv4(x))))
        x = x.view(x.size(0), -1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return self.fc_out(x)

## Load trained QAT weights

These are the same weights you trained in 1-Model.ipynb.

In [4]:
model = QuantEllipseRegressionModel()
model.load_state_dict(torch.load("ellipse_qat_best.pth", map_location="cpu"))
model.eval()

QuantEllipseRegressionModel(
  (conv1): QuantConv2d(
    1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (output_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (weight_quant): WeightQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (tensor_quant): RescalingIntQuant(
        (int_quant): IntQuant(
          (float_to_int_impl): RoundSte()
          (tensor_clamp_impl): TensorClampSte()
          (delay_wrapper): DelayWrapper(
            (delay_impl): _NoDelay()
          )
        )
        (scaling_impl): StatsFromParameterScaling(
          (parameter_list_stats): _ParameterListStats(
            (first_tracked_param): _ViewParameterWrapper(
              (view_shape_impl): OverTensorView()
            )
            (stats): _Stats(
              (stats_impl): AbsMax()
            )
          )
  

In [5]:
class QuantEllipseRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = qnn.QuantConv2d(1, 32, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=True)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = qnn.QuantConv2d(32, 64, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=True)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = qnn.QuantConv2d(64, 128, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=True)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = qnn.QuantConv2d(128, 256, 3, padding=1, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=True)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2)
        self.act = qnn.QuantReLU(bit_width=8)
        self.fc1 = qnn.QuantLinear(256, 512, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=True)
        self.fc2 = qnn.QuantLinear(512, 256, weight_quant=quant.Int8WeightPerTensorFixedPoint, bias=False)
        self.fc_out = nn.Linear(256, 5)

    def forward(self, x):
        x = self.pool(self.act(self.bn1(self.conv1(x))))
        x = self.pool(self.act(self.bn2(self.conv2(x))))
        x = self.pool(self.act(self.bn3(self.conv3(x))))
        x = self.pool(self.act(self.bn4(self.conv4(x))))
        x = x.view(x.size(0), -1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return self.fc_out(x)

## Define dummy input (static shape required by FINN)

FINN requires fixed spatial dimensions during ONNX export.

Input specification:

Grayscale image

Shape: (N, 1, 20, 20)

In [6]:
dummy_input = torch.randn(1, 1, 20, 20)

## Export to FINN-compatible QONNX

‚ö†Ô∏è Do NOT use torch.onnx.export

Always use Brevitas QONNX exporter so that:

Quantization metadata is preserved

Bit-widths are explicitly encoded

FINN can infer datatypes correctly

In [7]:
export_path = "ellipse_regression_qonnx.onnx"


export_qonnx(
model,
dummy_input,
export_path=export_path
)


print("Exported QONNX model to:", export_path)

verbose: False, log level: Level.ERROR

Exported QONNX model to: ellipse_regression_qonnx.onnx


## Verify ONNX model correctness

This ensures the exported model is structurally valid before FINN processing.

In [8]:
onnx_model = onnx.load(export_path)
onnx.checker.check_model(onnx_model)
print("ONNX model check passed ‚úî")

ONNX model check passed ‚úî


## (Optional but recommended) Quick inference sanity check

This confirms that the exported model produces outputs with the expected shape.

In [None]:
# Cell : (Optional but recommended) Quick inference sanity check
with torch.no_grad():
    out = model(dummy_input)
    print("Output shape:", out.shape)
    print("Expected: torch.Size([1, 5])")
    
    # Verify output is valid
    assert out.shape == (1, 5), f"Expected shape (1, 5), got {out.shape}"
    print("‚úÖ Inference sanity check passed!")

Output shape: torch.Size([1, 5])
Expected: torch.Size([1, 5])
‚úÖ Inference sanity check passed!


In [None]:
# Cell : Verify QONNX Export
import onnx
from collections import Counter

onnx_model = onnx.load(export_path)
onnx.checker.check_model(onnx_model)

print("‚úÖ ONNX model check passed")
print(f"‚úÖ Model name: {onnx_model.graph.name}")
print(f"‚úÖ Opset version: {onnx_model.opset_import[0].version}")
print(f"‚úÖ Producer: {onnx_model.producer_name}")

# Count node types
node_types = [node.op_type for node in onnx_model.graph.node]
print(f"\nüìä Total nodes: {len(node_types)}")
print("\nüîç Node type breakdown:")
for op_type, count in Counter(node_types).most_common():
    print(f"  {op_type}: {count}")

# Check inputs and outputs
print(f"\nüì• Inputs: {len(onnx_model.graph.input)}")
for inp in onnx_model.graph.input:
    shape = [dim.dim_value for dim in inp.type.tensor_type.shape.dim]
    print(f"  - {inp.name}: {shape}")

print(f"\nüì§ Outputs: {len(onnx_model.graph.output)}")
for out in onnx_model.graph.output:
    shape = [dim.dim_value for dim in out.type.tensor_type.shape.dim]
    print(f"  - {out.name}: {shape}")

‚úÖ ONNX model check passed
‚úÖ Model name: torch_jit
‚úÖ Opset version: 14
‚úÖ Producer: pytorch

üìä Total nodes: 34

üîç Node type breakdown:
  Quant: 11
  Relu: 5
  Conv: 4
  BatchNormalization: 4
  MaxPool: 4
  MatMul: 3
  Transpose: 2
  Reshape: 1

üì• Inputs: 18
  - x.7: [1, 1, 20, 20]
  - bn1.weight: [32]
  - bn1.bias: [32]
  - bn1.running_mean: [32]
  - bn1.running_var: [32]
  - bn2.weight: [64]
  - bn2.bias: [64]
  - bn2.running_mean: [64]
  - bn2.running_var: [64]
  - bn3.weight: [128]
  - bn3.bias: [128]
  - bn3.running_mean: [128]
  - bn3.running_var: [128]
  - bn4.weight: [256]
  - bn4.bias: [256]
  - bn4.running_mean: [256]
  - bn4.running_var: [256]
  - onnx::MatMul_86: [256, 5]

üì§ Outputs: 1
  - 82: [1, 5]


In [10]:
# Cell : Verify QONNX Model (Skip ONNX Runtime - use for FINN only)
print("‚ÑπÔ∏è Skipping ONNX Runtime test")
print("   Reason: QONNX models contain Brevitas custom ops (e.g., 'Quant')")
print("   that are not supported by standard ONNX Runtime.")
print("")
print("‚úÖ This model is designed for FINN compilation.")
print("   FINN will handle the custom quantization operators.")

# Instead, verify the model structure
import onnx

onnx_model = onnx.load(export_path)
print(f"\nüìä Model Info:")
print(f"  - Nodes: {len(onnx_model.graph.node)}")
print(f"  - Inputs: {len(onnx_model.graph.input)}")
print(f"  - Outputs: {len(onnx_model.graph.output)}")
print(f"  - Opset: {onnx_model.opset_import[0].version}")

# Check for Brevitas custom ops
custom_ops = set()
for node in onnx_model.graph.node:
    if node.domain == "onnx.brevitas":
        custom_ops.add(node.op_type)

if custom_ops:
    print(f"\nüéØ Brevitas QONNX operators found: {', '.join(custom_ops)}")
    print("   ‚úÖ Model ready for FINN!")
else:
    print("\n‚ö†Ô∏è No Brevitas QONNX operators found")

‚ÑπÔ∏è Skipping ONNX Runtime test
   Reason: QONNX models contain Brevitas custom ops (e.g., 'Quant')
   that are not supported by standard ONNX Runtime.

‚úÖ This model is designed for FINN compilation.
   FINN will handle the custom quantization operators.

üìä Model Info:
  - Nodes: 34
  - Inputs: 18
  - Outputs: 1
  - Opset: 14

üéØ Brevitas QONNX operators found: Quant
   ‚úÖ Model ready for FINN!


In [1]:
# Cell: Launch Netron Visualizer

import netron
import os

qonnx_path = "ellipse_regression_qonnx.onnx"
if os.path.exists(qonnx_path):
    print(f" Launching Netron for: {qonnx_path}")
    netron.start(qonnx_path)
else:
    print(f" File not found: {qonnx_path}")
    print("   Run the QONNX export cell first!")

Serving 'ellipse_regression_qonnx.onnx' at http://localhost:8081


 Launching Netron for: ellipse_regression_qonnx.onnx
