# Verify Exported ONNX Model in FINN

<font color="red">**Live FINN tutorial:** We recommend clicking **Cell -> Run All** when you start reading this notebook for "latency hiding".</font>

**Important: This notebook depends on the 1-train-mlp-with-brevitas notebook, because we are using the ONNX model that was exported there. So please make sure the needed .onnx file is generated before you run this notebook.**

**Also remember to 'close and halt' any other FINN notebooks, since Netron visualizations use the same port.**

In this notebook we will show how to import the network we trained in Brevitas and verify it in the FINN compiler. 
This verification process can actually be done at various stages in the compiler [as explained in this notebook](../bnn-pynq/tfc_end2end_verification.ipynb) but for this example we'll only consider the first step: verifying the exported high-level FINN-ONNX model.
Another goal of this notebook is to introduce you to the concept of *graph transformations* -- we'll be applying some transformations to the graph to make it executable for verification. 
Once this model is sucessfully verified, we'll generate an FPGA accelerator from it in the next notebook.

In [1]:
import onnx 
import torch 

**This is important -- always import onnx before torch**. This is a workaround for a [known bug](https://github.com/onnx/onnx/issues/2394).

## Outline
-------------
1. [Import model into FINN with ModelWrapper](#brevitas_import_visualization)
2. [Network preparations: Tidy-up transformations](#network_preparations)
3. [Load the dataset and Brevitas model](#load_dataset) 
4. [Compare FINN and Brevitas execution](#compare_brevitas)

# 1. Import model into FINN with ModelWrapper <a id="brevitas_import_visualization"></a>

Now that we have the model in .onnx format, we can work with it using FINN. To import it into FINN, we'll use the [`ModelWrapper`](https://finn.readthedocs.io/en/latest/source_code/finn.core.html#qonnx.core.modelwrapper.ModelWrapper). It is a wrapper around the ONNX model which provides several helper functions to make it easier to work with the model.

In [2]:
import os
from qonnx.core.modelwrapper import ModelWrapper

model_dir = os.environ['FINN_ROOT'] + "/notebooks/FINN_Brevitas"
ready_model_filename = model_dir + "/finn-brevitas-ready.onnx"
model_for_sim = ModelWrapper(ready_model_filename)

Let's have a look at some of the member functions exposed by `ModelWrapper` to see what kind of information we can extract from it.

In [3]:
dir(model_for_sim)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_model_proto',
 'analysis',
 'check_all_tensor_shapes_specified',
 'check_compatibility',
 'cleanup',
 'find_consumer',
 'find_consumers',
 'find_direct_predecessors',
 'find_direct_successors',
 'find_producer',
 'find_upstream',
 'fix_float64',
 'get_all_tensor_names',
 'get_finn_nodes',
 'get_initializer',
 'get_metadata_prop',
 'get_node_from_name',
 'get_node_index',
 'get_nodes_by_op_type',
 'get_non_finn_nodes',
 'get_tensor_datatype',
 'get_tensor_fanout',
 'get_tensor_layout',
 'get_tensor_shape',
 'get_tensor_sparsity',
 'get_tensor_valueinfo',
 'graph',
 'is_fork_node',
 'is_join_node',
 'make_empty_exec_conte

Many of these helper functions relate to extracting information about the structure and properties of the ONNX model. You can find out more about examining and manipulating ONNX models programmatically in [this tutorial](../../basics/0_how_to_work_with_onnx.ipynb), but we'll show a few basic functions here. For instance, we can extract the shape and datatype annotation for various tensors in the graph, as well as information related to the operation types associated with each node.

In [4]:
from qonnx.core.datatype import DataType

finnonnx_in_tensor_name = model_for_sim.graph.input[0].name
finnonnx_out_tensor_name = model_for_sim.graph.output[0].name
print("Input tensor name: %s" % finnonnx_in_tensor_name)
print("Output tensor name: %s" % finnonnx_out_tensor_name)
finnonnx_model_in_shape = model_for_sim.get_tensor_shape(finnonnx_in_tensor_name)
finnonnx_model_out_shape = model_for_sim.get_tensor_shape(finnonnx_out_tensor_name)
print("Input tensor shape: %s" % str(finnonnx_model_in_shape))
print("Output tensor shape: %s" % str(finnonnx_model_out_shape))
finnonnx_model_in_dt = model_for_sim.get_tensor_datatype(finnonnx_in_tensor_name)
finnonnx_model_out_dt = model_for_sim.get_tensor_datatype(finnonnx_out_tensor_name)
print("Input tensor datatype: %s" % str(finnonnx_model_in_dt.name))
print("Output tensor datatype: %s" % str(finnonnx_model_out_dt.name))
print("List of node operator types in the graph: ")
print([x.op_type for x in model_for_sim.graph.node])

Input tensor name: global_in
Output tensor name: global_out
Input tensor shape: [1, 3, 64, 64]
Output tensor shape: [1, 4]
Input tensor datatype: FLOAT32
Output tensor datatype: FLOAT32
List of node operator types in the graph: 
['Reshape', 'MatMul', 'Mul', 'Add', 'BatchNormalization', 'MultiThreshold', 'Mul', 'MatMul', 'Mul', 'Add', 'BatchNormalization', 'MultiThreshold', 'Mul', 'MatMul', 'Mul', 'Add', 'BatchNormalization', 'MultiThreshold', 'Mul', 'MatMul', 'Mul', 'Add']


Note that the output tensor is (as of yet) marked as a float32 value, even though we know the output is binary. This will be automatically inferred by the compiler in the next step when we run the `InferDataTypes` transformation.

# 2. Network preparation: Tidy-up transformations <a id="network_preparations"></a>

Before running the verification, we need to prepare our FINN-ONNX model. In particular, all the intermediate tensors need to have statically defined shapes. To do this, we apply some graph transformations to the model like a kind of "tidy-up" to make it easier to process. 

**Graph transformations in FINN.** The whole FINN compiler is built around the idea of transformations, which gradually transform the model into a synthesizable hardware description. Although FINN offers functionality that automatically calls a standard sequence of transformations (covered in the next notebook), you can also manually call individual transformations (like we do here), as well as adding your own transformations, to create custom flows. You can read more about these transformations in [this notebook](../bnn-pynq/tfc_end2end_example.ipynb).

In [5]:
from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.fold_constants import FoldConstants

model_for_sim = model_for_sim.transform(InferShapes())
model_for_sim = model_for_sim.transform(FoldConstants())
model_for_sim = model_for_sim.transform(GiveUniqueNodeNames())
model_for_sim = model_for_sim.transform(GiveReadableTensorNames())
model_for_sim = model_for_sim.transform(InferDataTypes())
model_for_sim = model_for_sim.transform(RemoveStaticGraphInputs())

verif_model_filename = model_dir + "/finn-brevitas-verification.onnx"
model_for_sim.save(verif_model_filename)

**Would the FINN compiler still work if we didn't do this?** The compilation step in the next notebook applies these transformations internally and would work fine, but we're going to use FINN's verification capabilities below and these require the tidy-up transformations.

Let's view our ready-to-go model after the transformations. Note that all intermediate tensors now have their shapes specified (indicated by numbers next to the arrows going between layers). Additionally, the datatype inference step has propagated quantization annotations to the outputs of `MultiThreshold` layers (expand by clicking the + next to the name of the tensor to see the quantization annotation) and the final output tensor.

In [6]:
from finn.util.visualization import showInNetron

showInNetron(verif_model_filename)

OSError: [Errno 98] Address already in use

# 3. Load the Dataset and the Brevitas Model <a id="load_dataset"></a>

We'll use some example data from the quantized UNSW-NB15 dataset (from the previous notebook) to use as inputs for the verification. 

In [7]:
"""
import numpy as np
from torch.utils.data import TensorDataset

def get_preqnt_dataset(data_dir: str, train: bool):
    unsw_nb15_data = np.load(data_dir + "/unsw_nb15_binarized.npz")
    if train:
        partition = "train"
    else:
        partition = "test"
    part_data = unsw_nb15_data[partition].astype(np.float32)
    part_data = torch.from_numpy(part_data)
    part_data_in = part_data[:, :-1]
    part_data_out = part_data[:, -1]
    return TensorDataset(part_data_in, part_data_out)

n_verification_inputs = 100
test_quantized_dataset = get_preqnt_dataset(".", False)
input_tensor = test_quantized_dataset.tensors[0][:n_verification_inputs]
input_tensor.shape

train_quantized_dataset = get_preqnt_dataset(".", True)
test_quantized_dataset = get_preqnt_dataset(".", False)

print("Samples in each set: train = %d, test = %s" % (len(train_quantized_dataset), len(test_quantized_dataset))) 
print("Shape of one input sample: " +  str(train_quantized_dataset[0][0].shape))

"""

import numpy as np
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import pandas as pd
from sklearn.model_selection import train_test_split
from PIL import Image

data_path = "/home/administrateur/finn/notebooks/kidneydataset/CT-KIDNEY-DATASET-Normal-Cyst-Tumor-Stone/CT-KIDNEY-DATASET-Normal-Cyst-Tumor-Stone"

# Prepare dataset
def load_dataset(data_path):
    images = []
    labels = []
    for subfolder in os.listdir(data_path):
        subfolder_path = os.path.join(data_path, subfolder)
        if not os.path.isdir(subfolder_path):
            continue
        for image_filename in os.listdir(subfolder_path):
            image_path = os.path.join(subfolder_path, image_filename)
            images.append(image_path)
            labels.append(subfolder)
    return pd.DataFrame({'image': images, 'label': labels})

# Define Custom Dataset class
class CustomDataset(Dataset):
    def __init__(self, dataframe, transform=None, class_indices=None):
        self.dataframe = dataframe
        self.transform = transform
        self.class_indices = class_indices

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_path = self.dataframe.iloc[idx]['image']
        image = Image.open(img_path).convert('RGB')
        label = self.class_indices[self.dataframe.iloc[idx]['label']]

        if self.transform:
            image = self.transform(image)

        return image, label

input_size = (3, 64, 64)

transform = transforms.Compose([
    transforms.Resize((input_size[1], input_size[2])),
    #transforms.RandomHorizontalFlip(),
    #transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

data = load_dataset(data_path)
train_df, dummy_df = train_test_split(data, train_size=0.01, shuffle=True, stratify=data['label'], random_state=123)

class_indices = {label: idx for idx, label in enumerate(train_df['label'].unique())}
train_quantized_dataset = CustomDataset(train_df, transform=transform, class_indices=class_indices)


Let's also bring up the MLP we trained in Brevitas from the previous notebook. We'll compare its outputs to what is generated by FINN.

In [8]:
"""
input_size = 593      
hidden1 = 64      
hidden2 = 64
hidden3 = 64
weight_bit_width = 2
act_bit_width = 2
num_classes = 1

from brevitas.nn import QuantLinear, QuantReLU
import torch.nn as nn

brevitas_model = nn.Sequential(
      QuantLinear(input_size, hidden1, bias=True, weight_bit_width=weight_bit_width),
      nn.BatchNorm1d(hidden1),
      nn.Dropout(0.5),
      QuantReLU(bit_width=act_bit_width),
      QuantLinear(hidden1, hidden2, bias=True, weight_bit_width=weight_bit_width),
      nn.BatchNorm1d(hidden2),
      nn.Dropout(0.5),
      QuantReLU(bit_width=act_bit_width),
      QuantLinear(hidden2, hidden3, bias=True, weight_bit_width=weight_bit_width),
      nn.BatchNorm1d(hidden3),
      nn.Dropout(0.5),
      QuantReLU(bit_width=act_bit_width),
      QuantLinear(hidden3, num_classes, bias=True, weight_bit_width=weight_bit_width)
)

# replace this with your trained network checkpoint if you're not
# using the pretrained weights
trained_state_dict = torch.load(model_dir + "/state_dict.pth")["models_state_dict"][0]

# Uncomment the following line if you previously chose to train the network yourself
#trained_state_dict = torch.load("state_dict_self-trained.pth")

brevitas_model.load_state_dict(trained_state_dict, strict=False)
"""

from brevitas.nn import QuantLinear, QuantReLU, QuantConv2d
import torch.nn as nn

input_size = (3, 64, 64)

class QuantMobileNetV2Model(nn.Module):
    def __init__(self, num_classes=4):
        super(QuantMobileNetV2Model, self).__init__()

        self.conv1 = QuantConv2d(3, 6, 5, bias=True, weight_bit_width=8, padding=2)
        self.lin2_2 = QuantLinear(int(6*input_size[1]*input_size[2]/4), 32, bias=True, weight_bit_width=8)
        
        self.lin1 = QuantLinear(input_size[0]*input_size[1]*input_size[2], 64, bias=True, weight_bit_width=8)
        self.bnorm1 = nn.BatchNorm1d(64)
        self.drop1 = nn.Dropout(0.5)
        self.relu1 = QuantReLU(bit_width=8)
        self.lin2 = QuantLinear(64, 32, bias=True, weight_bit_width=8)
        self.bnorm2 = nn.BatchNorm1d(32)
        self.drop2 = nn.Dropout(0.5)
        self.relu2 = QuantReLU(bit_width=8)
        self.lin3 = QuantLinear(32, 16, bias=True, weight_bit_width=8)
        self.bnorm3 = nn.BatchNorm1d(16)
        self.drop3 = nn.Dropout(0.5)
        self.relu3 = QuantReLU(bit_width=8)
        self.lin4 = QuantLinear(16, 4, bias=True, weight_bit_width=8)

    def forward(self, x):
        
        out = x.view(-1, input_size[0]*input_size[1]*input_size[2])
        out = self.lin1(out)
        out = self.bnorm1(out)
        out = self.drop1(out)
        out = self.relu1(out)
        out = self.lin2(out)

        #out = self.conv1(x)
        #out = self.relu1(out)
        #out = nn.functional.max_pool2d(out, 2)
        #out = out.view(-1, int(6*input_size[1]*input_size[2]/4))
        #out = self.lin2_2(out)
         
        out = self.bnorm2(out)
        #print(out.shape)
        out = self.drop2(out)
        #print(out.shape)
        out = self.relu2(out)
        out = self.lin3(out)
        out = self.bnorm3(out)
        #print(out.shape)
        out = self.drop3(out)
        #print(out.shape)
        out = self.relu3(out)
        out = self.lin4(out)
        #print(out.shape)

        return out

brevitas_model = QuantMobileNetV2Model(num_classes=4) #.to(device)

#trained_state_dict = torch.load(model_dir + "/state_dict.pth")["models_state_dict"][0]
trained_state_dict = torch.load("state_dict_self-trained.pth")
brevitas_model.load_state_dict(trained_state_dict, strict=False)

<All keys matched successfully>

In [9]:
def inference_with_brevitas(current_inp):
    brevitas_output = brevitas_model.forward(current_inp)
    # apply sigmoid + threshold
    # brevitas_output = torch.sigmoid(brevitas_output)
    
    #brevitas_output = (brevitas_output.detach().numpy() > 0.5) * 1
    brevitas_output = brevitas_output.detach().numpy()
    # convert output to bipolar
    #brevitas_output = 2*brevitas_output - 1
    return brevitas_output

# 4. Compare FINN & Brevitas execution <a id="compare_brevitas"></a>

Let's make helper functions to execute the same input with Brevitas and FINN. For FINN, we'll use the [`finn.core.onnx_exec`](https://finn.readthedocs.io/en/latest/source_code/finn.core.html#finn.core.onnx_exec.execute_onnx) function to execute the exported FINN-ONNX on the inputs. Note that this ONNX execution is for verification only; not for accelerated execution.

Recall that the quantized values from the dataset are 593-bit binary {0, 1} vectors whereas our exported model takes 600-bit bipolar {-1, +1} vectors, so we'll have to preprocess it a bit before we can use it for verifying the ONNX model.

In [10]:
import finn.core.onnx_exec as oxe

def inference_with_finn_onnx(current_inp):
    finnonnx_in_tensor_name = model_for_sim.graph.input[0].name
    finnonnx_model_in_shape = model_for_sim.get_tensor_shape(finnonnx_in_tensor_name)
    finnonnx_out_tensor_name = model_for_sim.graph.output[0].name
    # convert input to numpy for FINN
    current_inp = current_inp.detach().numpy()
    # add padding and re-scale to bipolar
    # current_inp = np.pad(current_inp, [(0, 0), (0, 7)])
    # current_inp = 2*current_inp-1
    # reshape to expected input (add 1 for batch dimension)
    current_inp = current_inp.reshape(finnonnx_model_in_shape)
    # create the input dictionary
    input_dict = {finnonnx_in_tensor_name : current_inp} 
    # run with FINN's execute_onnx
    output_dict = oxe.execute_onnx(model_for_sim, input_dict)
    #get the output tensor
    finn_output = output_dict[finnonnx_out_tensor_name] 
    return finn_output

Now we can call our inference helper functions for each input and compare the outputs.

In [12]:
import numpy as np
from tqdm import trange

verify_range = trange(len(train_quantized_dataset), desc="FINN execution", position=0, leave=True)
brevitas_model.eval()

ok = 0
nok = 0

"""
for i in verify_range:
    # run in Brevitas with PyTorch tensor
    current_inp = input_tensor[i].reshape((1, 3, 224, 224))
    brevitas_output = inference_with_brevitas(current_inp)
    finn_output = inference_with_finn_onnx(current_inp)
    # compare the outputs
    ok += 1 if finn_output == brevitas_output else 0
    nok += 1 if finn_output != brevitas_output else 0
    verify_range.set_description("ok %d nok %d" % (ok, nok))
    verify_range.refresh()
"""


for images, labels in train_quantized_dataset:
    # run in Brevitas with PyTorch tensor
    # print(images.shape)
    current_inp = images.reshape((1, input_size[0], input_size[1], input_size[2]))
    brevitas_output = inference_with_brevitas(current_inp)
    finn_output = inference_with_finn_onnx(current_inp)
    print(brevitas_output)
    print(finn_output)
    # compare the outputs
    ok += 1 if (finn_output == brevitas_output).all() else 0
    nok += 1 if (finn_output != brevitas_output).all() else 0
    verify_range.set_description("ok %d nok %d" % (ok, nok))
    verify_range.refresh()

    with open('test.npy', 'wb') as f:
        np.save(f, images)

ok 0 nok 1:   0%|                                 | 0/124 [02:50<?, ?it/s]
ok 0 nok 1:   0%|                                 | 0/124 [00:02<?, ?it/s]

[[ 0.4677283   0.17800015  0.66173327 -0.2792176 ]]
[[ 0.46772826  0.17800012  0.6617332  -0.27921766]]


ok 0 nok 1:   0%|                                 | 0/124 [00:04<?, ?it/s]

[[-0.32734597  0.41883895 -0.09124881 -0.6060329 ]]
[[-0.3273459   0.41883895 -0.09124884 -0.6060329 ]]


ok 0 nok 1:   0%|                                 | 0/124 [00:06<?, ?it/s]

[[ 0.25232828 -0.06624807  0.18801059 -0.04100133]]
[[ 0.25232828 -0.06624808  0.18801059 -0.04100133]]


KeyboardInterrupt: 

In [None]:
try:
    assert ok == n_verification_inputs
    print("Verification succeeded. Brevitas and FINN-ONNX execution outputs are identical")
except AssertionError:
    assert False, "Verification failed. Brevitas and FINN-ONNX execution outputs are NOT identical"

This concludes our second notebook. In the next one, we'll take the ONNX model we just verified all the way down to FPGA hardware with the FINN compiler.