# Quantized Model Cleanup
This notebook takes the FINN-ONNX FACILE model exported by the `quant_train` notebook, tidies it up and streamlines it, and stops before the conversion to HLS layers.

### Load in FINN and transform

In [1]:
import onnx
from finn.util.test import get_test_model_trained
import brevitas.onnx as bo
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs
from finn.transformation.infer_datatypes import InferDataTypes

In [2]:
# Load the Brevitas export, then apply the tidy-up transformations one after
# another. The order of the tuple below is the order in which they run.
model = ModelWrapper("quant_models/facile.onnx")
for tidy_pass in (
    InferShapes,
    FoldConstants,
    GiveUniqueNodeNames,
    GiveReadableTensorNames,
    InferDataTypes,
    RemoveStaticGraphInputs,
):
    # instantiate each transformation right before it is applied
    model = model.transform(tidy_pass())

# Checkpoint the tidied graph so later cells can reload it from disk.
model.save("quant_models/facile_tidy.onnx")

In [3]:
# Open the tidied checkpoint in Netron for a visual check of the graph.
from finn.util.visualization import showInNetron

tidy_model_path = "quant_models/facile_tidy.onnx"
showInNetron(tidy_model_path)

Serving 'quant_models/facile_tidy.onnx' at http://0.0.0.0:8081


In [4]:
# Pre-processing stage: reload the tidied model and annotate its global
# input/output tensors. The actual preprocessing merge is currently disabled.
from finn.util.pytorch import ToTensor
from finn.transformation.merge_onnx_models import MergeONNXModels
from finn.core.datatype import DataType

model = ModelWrapper("quant_models/facile_tidy.onnx")
global_inp_name = model.graph.input[0].name
# input shape; only referenced by the disabled preprocessing export below
ishape = model.get_tensor_shape(global_inp_name)

# --- disabled: export a ToTensor preprocessing graph and merge it in front ---
# (torchvision's ToTensor divides uint8 inputs by 255)
#totensor_pyt = ToTensor()
#chkpt_preproc_name = "xor_preproc.onnx"
#bo.export_finn_onnx(totensor_pyt, ishape, chkpt_preproc_name)
#pre_model = ModelWrapper(chkpt_preproc_name)
#model = model.transform(MergeONNXModels(pre_model))

# Look the global tensor names up again (this sits after the disabled merge
# step) and echo them.
global_inp_name = model.graph.input[0].name
global_oup_name = model.graph.output[0].name
print(global_inp_name, global_oup_name, sep="\n")

# Annotate both the global input and the global output as UINT8.
# NOTE(review): the original comment said "UINT8 for all BNN-PYNQ models";
# confirm UINT8 is the intended datatype for this model's input and output.
for io_tensor_name in (global_inp_name, global_oup_name):
    model.set_tensor_datatype(io_tensor_name, DataType.UINT8)

# NOTE(review): the file name says "with_preproc" although the merge above is
# commented out; the name is kept unchanged.
model.save("quant_models/facile_with_preproc.onnx")
showInNetron("quant_models/facile_with_preproc.onnx")

global_in
global_out

Stopping http://0.0.0.0:8081
Serving 'quant_models/facile_with_preproc.onnx' at http://0.0.0.0:8081


In [5]:
from finn.transformation.infer_datatypes import InferDataTypes

# Post-processing stage. Inserting a Top-1 node at the end is left disabled:
#model = model.transform(InsertTopK(k=1))
chkpt_name = "quant_models/facile_postproc.onnx"

# Run the tidy-up transformations again on the in-memory model; the tuple
# order is the execution order.
for tidy_pass in (
    InferShapes,
    FoldConstants,
    GiveUniqueNodeNames,
    GiveReadableTensorNames,
    InferDataTypes,
    RemoveStaticGraphInputs,
):
    model = model.transform(tidy_pass())

# Checkpoint, then view the saved file in Netron.
model.save(chkpt_name)

showInNetron(chkpt_name)


Stopping http://0.0.0.0:8081
Serving 'quant_models/facile_postproc.onnx' at http://0.0.0.0:8081


### Streamlining

In [6]:
from finn.transformation.streamline import Streamline
from finn.transformation.streamline.reorder import MoveScalarLinearPastInvariants
import finn.transformation.streamline.absorb as absorb

streamlined_path = "quant_models/facile_streamlined.onnx"

# Start from the post-processing checkpoint on disk.
model = ModelWrapper("quant_models/facile_postproc.onnx")

# 1) MoveScalarLinearPastInvariants - original note: "move initial Mul (from
#    preproc) past the Reshape".
#    NOTE(review): the preprocessing merge is commented out in this notebook,
#    so confirm there is still a leading Mul for this pass to move.
# 2) Streamline - the streamlining transformation itself.
for streamline_pass in (MoveScalarLinearPastInvariants, Streamline):
    model = model.transform(streamline_pass())

model.save(streamlined_path)
showInNetron(streamlined_path)


Stopping http://0.0.0.0:8081
Serving 'quant_models/facile_streamlined.onnx' at http://0.0.0.0:8081


In [7]:
from finn.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.general import RemoveUnusedTensors

hls_ready_path = "quant_models/facile_ready_for_hls_conv.onnx"

# Continues from the streamlined `model` that is still in memory.
# Left disabled (would run after the absorb passes below):
#   absorb final add-mul nodes into TopK
#model = model.transform(absorb.AbsorbScalarMulAddIntoTopK())
#model = model.transform(RoundAndClipThresholds())
for cleanup_pass in (
    # XNOR-popcount conversion, then absorb Add / Mul into MultiThreshold
    ConvertBipolarMatMulToXnorPopcount,
    absorb.AbsorbAddIntoMultiThreshold,
    absorb.AbsorbMulIntoMultiThreshold,
    # bit of tidy-up
    InferDataLayouts,
    RemoveUnusedTensors,
):
    model = model.transform(cleanup_pass())

# Final checkpoint of this notebook, then view it in Netron.
model.save(hls_ready_path)
showInNetron(hls_ready_path)


Stopping http://0.0.0.0:8081
Serving 'quant_models/facile_ready_for_hls_conv.onnx' at http://0.0.0.0:8081


