# Verify Exported ONNX Model in FINN



In [None]:
import onnx 
import torch 

# Import model into FINN with ModelWrapper

The quantized model is initialised in a ModelWrapper to test how the model behaves with its new FINN structure.

In [None]:
import os
from qonnx.core.modelwrapper import ModelWrapper

model_dir = ""
ready_model_filename = model_dir + "mlp_ready.onnx"
model_for_sim = ModelWrapper(ready_model_filename)

FINN provides a number of functions to access information about the model. This can be used to verify information like the model inputs/outputs and the model shape are still correct

In [None]:
from qonnx.core.datatype import DataType

finnonnx_in_tensor_name = model_for_sim.graph.input[0].name
finnonnx_out_tensor_name = model_for_sim.graph.output[0].name
print("Input tensor name: %s" % finnonnx_in_tensor_name)
print("Output tensor name: %s" % finnonnx_out_tensor_name)
finnonnx_model_in_shape = model_for_sim.get_tensor_shape(finnonnx_in_tensor_name)
finnonnx_model_out_shape = model_for_sim.get_tensor_shape(finnonnx_out_tensor_name)
print("Input tensor shape: %s" % str(finnonnx_model_in_shape))
print("Output tensor shape: %s" % str(finnonnx_model_out_shape))
finnonnx_model_in_dt = model_for_sim.get_tensor_datatype(finnonnx_in_tensor_name)
finnonnx_model_out_dt = model_for_sim.get_tensor_datatype(finnonnx_out_tensor_name)
print("Input tensor datatype: %s" % str(finnonnx_model_in_dt.name))
print("Output tensor datatype: %s" % str(finnonnx_model_out_dt.name))
print("List of node operator types in the graph: ")
print([x.op_type for x in model_for_sim.graph.node])

Input tensor name: global_in
Output tensor name: global_out
Input tensor shape: [1, 2, 1024, 1]
Output tensor shape: [1, 4]
Input tensor datatype: INT8
Output tensor datatype: FLOAT32
List of node operator types in the graph: 
['Conv', 'Mul', 'Add', 'MultiThreshold', 'Mul', 'Conv', 'Mul', 'Add', 'MultiThreshold', 'Mul', 'Flatten', 'MatMul', 'Mul', 'Add', 'MultiThreshold', 'Mul', 'MatMul', 'Mul', 'Add']


Note that the output tensor is (as of yet) marked as a float32 value, even though we know the output is binary.

# Network preparation: Tidy-up transformations

Before running the verification, we need to prepare our FINN-ONNX model. In particular, all the intermediate tensors need to have statically defined shapes. To do this, we apply some graph transformations to the model like a kind of "tidy-up" to make it easier to process. 

In [None]:
from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.fold_constants import FoldConstants

model_for_sim = model_for_sim.transform(InferShapes())
model_for_sim = model_for_sim.transform(FoldConstants())
model_for_sim = model_for_sim.transform(GiveUniqueNodeNames())
model_for_sim = model_for_sim.transform(GiveReadableTensorNames())
model_for_sim = model_for_sim.transform(InferDataTypes())
model_for_sim = model_for_sim.transform(RemoveStaticGraphInputs())

verif_model_filename = model_dir + "cnn-verification.onnx"
model_for_sim.save(verif_model_filename)

# Load the Dataset

The dataset is loaded as before

In [None]:
def filter_strings(lst):
    filtered_list = [s for s in lst if not any(digit in s for digit in "3456789")]
    return filtered_list

In [None]:
import numpy as np
import os as os
from sklearn.model_selection import train_test_split
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

folder = "../fullPlutoImport"
files = os.listdir(folder)

filtered_files = filter_strings(files)

factor = 2
noFiles = len(filtered_files)

arr = np.ndarray((int(7800*noFiles/factor),128*factor*2), float)
labels = np.ndarray((int(7800*noFiles/factor),4))

seed = 0

i = 0;
for idx, npz in enumerate(filtered_files):
    
    a = np.load(os.path.join(folder, npz))
    
    start_idx = (idx*int(7800/factor)) if idx <20 else (idx)*int(7800/factor)-1
    end_idx = (1+idx)*int(7800/factor) if idx <20 else (1+idx)*int(7800/factor)-1
           
    reshaped_arr = a["samples"].reshape(int(7800/factor), 128*factor)
    
    float_array = np.ndarray((int(7800/factor), 128*factor*2), float)
    for j in range(reshaped_arr.shape[0]):
        float_array[j] = np.ravel((reshaped_arr[j].real, reshaped_arr[j].imag),'F')
    arr[start_idx:end_idx] = float_array
    labels[start_idx:end_idx] = np.tile(a["active_channels"],  (int(7800/factor), 1))

    i+=1
    if i >= noFiles:
        break
    
normalized_array = 255 * (arr + 2) / (4) - 128

ver_arr = TensorDataset(torch.tensor(normalized_array, dtype=torch.int8), torch.tensor(labels, dtype=torch.int8))


n_verification_inputs = 100
input_tensor = ver_arr.tensors[0][:n_verification_inputs]


Original array min: -2.0, max: 2.0
Normalized array min: -128.0, max: 127.0


# Rebuild MLP

The model is remade in Brevitas using the same weights as before using its state dictionary.

In [None]:
input_bits = 4
a_bits = 4
w_bits = 4
filters_conv = 32
filters_dense = 64

In [None]:


from torch import nn
import brevitas.nn as qnn
from brevitas.quant import IntBias
from brevitas.inject.enum import ScalingImplType
from brevitas.inject.defaults import Int8ActPerTensorFloatMinMaxInit

# Setting seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)

class InputQuantizer(Int8ActPerTensorFloatMinMaxInit):
    bit_width = input_bits
    min_val = -2.0
    max_val = 2.0
    scaling_impl_type = ScalingImplType.CONST # Fix the quantization range to [min_val, max_val]

model = nn.Sequential(
    # Input quantization layer
    qnn.QuantHardTanh(act_quant=InputQuantizer),

    qnn.QuantConv2d(1, filters_conv, 3, padding=1, weight_bit_width=w_bits, bias=False),
    nn.BatchNorm2d(filters_conv),
    qnn.QuantReLU(bit_width=a_bits),
    #nn.MaxPool2d(2,padding=1),

    qnn.QuantConv2d(filters_conv, 2*filters_conv, 3, padding=1, weight_bit_width=w_bits, bias=False),
    nn.BatchNorm2d(2*filters_conv),
    qnn.QuantReLU(bit_width=a_bits),
    nn.MaxPool2d(2,padding=1),

    qnn.QuantConv2d(2*filters_conv, 2*filters_conv, 3, padding=1, weight_bit_width=w_bits, bias=False),
    nn.BatchNorm2d(2*filters_conv),
    qnn.QuantReLU(bit_width=a_bits),
    #nn.MaxPool2d(2,padding=1),

    qnn.QuantConv2d(2*filters_conv, filters_conv, 3, padding=1, weight_bit_width=w_bits, bias=False),
    nn.BatchNorm2d(filters_conv),
    qnn.QuantReLU(bit_width=a_bits),
    nn.MaxPool2d(2,padding=1),
    
    nn.Flatten(),

    qnn.QuantLinear(filters_conv*65*2, 4, weight_bit_width=w_bits, bias=False),
)

Sequential(
  (0): QuantHardTanh(
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (act_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (fused_activation_quant_proxy): FusedActivationQuantProxy(
        (activation_impl): Identity()
        (tensor_quant): RescalingIntQuant(
          (int_quant): IntQuant(
            (float_to_int_impl): RoundSte()
            (tensor_clamp_impl): TensorClamp()
            (delay_wrapper): DelayWrapper(
              (delay_impl): _NoDelay()
            )
            (input_view_impl): Identity()
          )
          (scaling_impl): ConstScaling(
            (restrict_clamp_scaling): _RestrictClampValue(
              (clamp_min_ste): Identity()
              (restrict_value_impl): FloatRestrictValue()
            )
            (restrict_init_module): Identity()
            (value): StatelessBuffer()
          )
          (int_scaling_impl): IntScaling()
      

In [14]:
# Uncomment the following line if you previously chose to train the network yourself
trained_state_dict = torch.load("state_dict_self-trained.pth")

model.load_state_dict(trained_state_dict, strict=False)

  trained_state_dict = torch.load("state_dict_self-trained.pth")


<All keys matched successfully>

Optional: quantize the output.

In [None]:
from brevitas.nn import QuantIdentity

class BipolarForExport(nn.Module):
    def __init__(self, my_pretrained_model):
        super(BipolarForExport, self).__init__()
        self.pretrained = my_pretrained_model
        self.qnt_output = QuantIdentity(
            quant_type='binary', 
            scaling_impl_type='const',
            bit_width=1, min_val=-1.0, max_val=1.0)
    
    def forward(self, x):
        out_original = self.pretrained(x)
        out_final = self.qnt_output(out_original)   # output as {-1,1}     
        return out_final



model = BipolarForExport(model)

# 4. Compare FINN & Brevitas execution <a id="compare_brevitas"></a>

FINN provides the finn.core.onnx_exec function to simulate what happens in FINN with the given model. By executing on it using this function it can be verified that the model will act in the same manor as the Brevitas model.

In [17]:
import finn.core.onnx_exec as oxe

def inference_with_finn_onnx(current_inp):
    finnonnx_in_tensor_name = model_for_sim.graph.input[0].name
    finnonnx_model_in_shape = model_for_sim.get_tensor_shape(finnonnx_in_tensor_name)
    finnonnx_out_tensor_name = model_for_sim.graph.output[0].name
    # convert input to numpy for FINN
    current_inp = current_inp.detach().numpy()
    # reshape to expected input (add 1 for batch dimension)
    current_inp = current_inp.reshape(finnonnx_model_in_shape)
    # create the input dictionary
    input_dict = {finnonnx_in_tensor_name : current_inp} 
    # run with FINN's execute_onnx
    output_dict = oxe.execute_onnx(model_for_sim, input_dict)
    #get the output tensor
    finn_output = output_dict[finnonnx_out_tensor_name] 
    return finn_output

To get outputs from the brevitas model, simply run as normal.

In [None]:
def inference_with_brevitas(current_inp):
    model.eval() 
    brevitas_output = model(current_inp)

    return brevitas_output

Now the inference helper functions are called for each input and the outputs compared.

In [None]:


import numpy as np
from tqdm import trange

verify_range = trange(n_verification_inputs, desc="FINN execution", position=0, leave=True)
model.eval()

ok = 0
nok = 0

for i in verify_range:
    # run in Brevitas with PyTorch tensor
    current_inp = input_tensor[i].reshape((1, 128*factor*2))
    brevitas_output = inference_with_brevitas(current_inp).detach().numpy().astype(np.int16)
    finn_output = inference_with_finn_onnx(current_inp).astype(np.int16)
    # compare the outputs
    ok += 1 if (finn_output==brevitas_output).all() else 0
    nok += 1 if (finn_output != brevitas_output).any() else 0
    verify_range.set_description("ok %d nok %d" % (ok, nok))
    verify_range.refresh()



ok 5 nok 0:   3%|▎         | 3/100 [00:00<00:04, 21.15it/s]

[[ 1 -1  2  0]]
[[ 1 -1  2  0]]

[[ 2 -1  0  0]]
[[ 2 -1  0  0]]

[[ 3 -1  0  0]]
[[ 3 -1  0  0]]

[[3 0 1 0]]
[[3 0 1 0]]

[[3 0 0 0]]
[[3 0 0 0]]



ok 10 nok 0:   9%|▉         | 9/100 [00:00<00:04, 22.16it/s]

[[ 2 -1  1  0]]
[[ 2 -1  1  0]]

[[ 0 -1  0  2]]
[[ 0 -1  0  2]]

[[2 0 3 0]]
[[2 0 3 0]]

[[ 1 -1  0  2]]
[[ 1 -1  0  2]]

[[ 1 -1  3  0]]
[[ 1 -1  3  0]]



ok 15 nok 0:  15%|█▌        | 15/100 [00:00<00:03, 22.67it/s]

[[ 1 -1  0  2]]
[[ 1 -1  0  2]]

[[ 0 -1  1  1]]
[[ 0 -1  1  1]]

[[ 4 -1  0  1]]
[[ 4 -1  0  1]]

[[3 0 0 1]]
[[3 0 0 1]]

[[ 2 -1  4  3]]
[[ 2 -1  4  3]]



ok 20 nok 0:  18%|█▊        | 18/100 [00:00<00:03, 21.81it/s]

[[2 0 0 0]]
[[2 0 0 0]]

[[ 0 -1  0  2]]
[[ 0 -1  0  2]]

[[ 1 -1  0  1]]
[[ 1 -1  0  1]]

[[ 2 -2  2  0]]
[[ 2 -2  2  0]]

[[ 3 -1  1  0]]
[[ 3 -1  1  0]]



ok 25 nok 0:  24%|██▍       | 24/100 [00:01<00:03, 22.66it/s]

[[ 3 -2  0  2]]
[[ 3 -2  0  2]]

[[ 2 -1  0  0]]
[[ 2 -1  0  0]]

[[ 0 -1  0  1]]
[[ 0 -1  0  1]]

[[ 1 -1  4  2]]
[[ 1 -1  4  2]]

[[1 0 3 0]]
[[1 0 3 0]]



ok 30 nok 0:  30%|███       | 30/100 [00:01<00:03, 22.60it/s]

[[ 0 -1  0  4]]
[[ 0 -1  0  4]]

[[1 0 0 3]]
[[1 0 0 3]]

[[ 3 -1  3  2]]
[[ 3 -1  3  2]]

[[ 0 -1  2  1]]
[[ 0 -1  2  1]]

[[ 0 -1  2  0]]
[[ 0 -1  2  0]]



ok 35 nok 0:  33%|███▎      | 33/100 [00:01<00:02, 22.78it/s]

[[ 1 -1  1  1]]
[[ 1 -1  1  1]]

[[ 0 -1  3  0]]
[[ 0 -1  3  0]]

[[ 1 -2  3  3]]
[[ 1 -2  3  3]]

[[ 2 -1  1  3]]
[[ 2 -1  1  3]]

[[ 0 -1  0  1]]
[[ 0 -1  0  1]]



ok 40 nok 0:  39%|███▉      | 39/100 [00:01<00:02, 22.85it/s]

[[2 0 1 0]]
[[2 0 1 0]]

[[ 1 -2  2  0]]
[[ 1 -2  2  0]]

[[ 0 -1  0  0]]
[[ 0 -1  0  0]]

[[ 1 -1 -1  0]]
[[ 1 -1 -1  0]]

[[ 3 -2  2  0]]
[[ 3 -2  2  0]]



ok 45 nok 0:  45%|████▌     | 45/100 [00:01<00:02, 23.11it/s]

[[ 2 -1  1  3]]
[[ 2 -1  1  3]]

[[ 1 -1  1  0]]
[[ 1 -1  1  0]]

[[ 2 -1  2  0]]
[[ 2 -1  2  0]]

[[ 2 -2  4  2]]
[[ 2 -2  4  2]]

[[ 2 -1  1  0]]
[[ 2 -1  1  0]]



ok 50 nok 0:  48%|████▊     | 48/100 [00:02<00:02, 23.16it/s]

[[ 2 -1  5  0]]
[[ 2 -1  5  0]]

[[ 1  0 -1  2]]
[[ 1  0 -1  2]]

[[ 1 -1  1  2]]
[[ 1 -1  1  2]]

[[ 3 -1  3  0]]
[[ 3 -1  3  0]]

[[1 0 1 0]]
[[1 0 1 0]]



ok 55 nok 0:  54%|█████▍    | 54/100 [00:02<00:01, 23.12it/s]

[[0 0 1 0]]
[[0 0 1 0]]

[[ 1 -1  1  0]]
[[ 1 -1  1  0]]

[[ 0 -1  1  2]]
[[ 0 -1  1  2]]

[[ 1 -1  3  1]]
[[ 1 -1  3  1]]

[[ 0 -1  0  2]]
[[ 0 -1  0  2]]



ok 60 nok 0:  60%|██████    | 60/100 [00:02<00:01, 22.28it/s]

[[ 3 -1  2  1]]
[[ 3 -1  2  1]]

[[ 2 -1  1  1]]
[[ 2 -1  1  1]]

[[ 0 -1  1  2]]
[[ 0 -1  1  2]]

[[ 5 -1  2  0]]
[[ 5 -1  2  0]]

[[4 0 0 0]]
[[4 0 0 0]]



ok 65 nok 0:  63%|██████▎   | 63/100 [00:02<00:01, 22.07it/s]

[[ 1 -1  1  0]]
[[ 1 -1  1  0]]

[[ 0 -1  2  1]]
[[ 0 -1  2  1]]

[[1 0 2 0]]
[[1 0 2 0]]

[[ 1 -1  3  1]]
[[ 1 -1  3  1]]

[[ 0 -1  4  1]]
[[ 0 -1  4  1]]



ok 70 nok 0:  69%|██████▉   | 69/100 [00:03<00:01, 22.42it/s]

[[ 0 -1  4  2]]
[[ 0 -1  4  2]]

[[ 1  0  3 -1]]
[[ 1  0  3 -1]]

[[ 0 -1  2  3]]
[[ 0 -1  2  3]]

[[3 0 0 1]]
[[3 0 0 1]]

[[ 5 -1  1  2]]
[[ 5 -1  1  2]]



ok 75 nok 0:  75%|███████▌  | 75/100 [00:03<00:01, 21.99it/s]

[[ 0 -1  2  1]]
[[ 0 -1  2  1]]

[[ 0 -1  1  1]]
[[ 0 -1  1  1]]

[[1 0 0 0]]
[[1 0 0 0]]

[[3 0 1 2]]
[[3 0 1 2]]

[[ 3 -2  1  0]]
[[ 3 -2  1  0]]



ok 80 nok 0:  78%|███████▊  | 78/100 [00:03<00:01, 21.42it/s]

[[ 3 -1  2  0]]
[[ 3 -1  2  0]]

[[ 1 -1  0  1]]
[[ 1 -1  0  1]]

[[ 0 -2  3  2]]
[[ 0 -2  3  2]]

[[ 0  0 -1  1]]
[[ 0  0 -1  1]]

[[ 1 -1  2  0]]
[[ 1 -1  2  0]]



ok 85 nok 0:  84%|████████▍ | 84/100 [00:03<00:00, 21.42it/s]

[[1 0 3 2]]
[[1 0 3 2]]

[[0 0 0 0]]
[[0 0 0 0]]

[[ 1 -1  2  0]]
[[ 1 -1  2  0]]

[[ 1 -1  3  2]]
[[ 1 -1  3  2]]

[[ 1 -1  0  0]]
[[ 1 -1  0  0]]



ok 90 nok 0:  90%|█████████ | 90/100 [00:04<00:00, 21.03it/s]

[[ 1 -1  2  0]]
[[ 1 -1  2  0]]

[[ 1 -1  3 -1]]
[[ 1 -1  3 -1]]

[[ 1 -1  0  0]]
[[ 1 -1  0  0]]

[[ 1 -1  0  0]]
[[ 1 -1  0  0]]

[[ 3 -1  1  0]]
[[ 3 -1  1  0]]



ok 95 nok 0:  93%|█████████▎| 93/100 [00:04<00:00, 21.08it/s]

[[ 1 -1  0  0]]
[[ 1 -1  0  0]]

[[ 6 -1  2 -1]]
[[ 6 -1  2 -1]]

[[ 0 -1  0  2]]
[[ 0 -1  0  2]]

[[ 0 -1  0  1]]
[[ 0 -1  0  1]]

[[ 2 -1  0  2]]
[[ 2 -1  0  2]]



ok 100 nok 0: 100%|██████████| 100/100 [00:04<00:00, 22.11it/s]

[[ 1 -1  0  1]]
[[ 1 -1  0  1]]

[[ 0 -1  2  2]]
[[ 0 -1  2  2]]

[[ 1 -1  3  0]]
[[ 1 -1  3  0]]

[[ 2 -1  4  0]]
[[ 2 -1  4  0]]

[[ 1 -1  0  1]]
[[ 1 -1  0  1]]






In [17]:
try:
    assert ok == n_verification_inputs
    print("Verification succeeded. Brevitas and FINN-ONNX execution outputs are identical")
except AssertionError:
    assert False, "Verification failed. Brevitas and FINN-ONNX execution outputs are NOT identical"

Verification succeeded. Brevitas and FINN-ONNX execution outputs are identical
