In [None]:
# To generate the ONNX file, run the following in a terminal:
# python setup.py test --addopts "-k test_brevitas_quartznet" 
import numpy as np
from finn.util.visualization import showInNetron
from finn.core.modelwrapper import ModelWrapper

# Location of the exported QuartzNet ONNX file (see the test command above).
file_name = '/tmp/quartznet.onnx'
# Open the exported graph with FINN's Netron visualization helper.
showInNetron(file_name)

In [None]:
### Partition graph


In [None]:
### STREAMLINING SEQUENCE

from finn.transformation.streamline import *
from finn.transformation.streamline.reorder import MoveMulPastDWConv, MoveLinearPastEltwiseAdd, MoveMulPastFork
from finn.transformation.change_3d_tensors_to_4d import Change3DTo4DTensors
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
from finn.transformation.streamline.absorb import AbsorbConsecutiveTransposes # No effect (only on consecutive transpose nodes)
from finn.transformation.streamline.absorb import AbsorbTransposeIntoMultiThreshold
from finn.util.basic import get_by_name
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.util.visualization import showInNetron
from finn.transformation.general import GiveRandomTensorNames, GiveReadableTensorNames, GiveUniqueParameterTensors

# Load the sub-graph to be streamlined, then run the tensor-naming and
# parameter clean-up passes in a fixed order before any other rewriting.
model = ModelWrapper("/tmp/quartznet_repetitive_nodes_subpart.onnx")
for tidy_pass in (
    GiveRandomTensorNames,
    GiveReadableTensorNames,
    GiveUniqueParameterTensors,
):
    model = model.transform(tidy_pass())

# Streamlining passes, applied one after another in exactly this order.
# Each entry is a transformation class; it is instantiated right before
# it is applied to the current model.
streamlining_passes = (
    # Convert to supported format
    Change3DTo4DTensors,
    # Collapse BatchNorm to Add and Mul
    BatchNormToAffine,
    # Group additions
    MoveAddPastMul,
    MoveAddPastConv,
    MoveAddPastMul,
    # Group multiplications (first move Mul past fork nodes)
    MoveMulPastFork,
    MoveScalarMulPastConv,
    MoveMulPastDWConv,
    # Move Mul/Add past join node
    MoveLinearPastEltwiseAdd,
    # Collapse additions & multiplications
    CollapseRepeatedAdd,
    CollapseRepeatedMul,
    # Absorb Add/Mul into MultiThreshold
    AbsorbAddIntoMultiThreshold,
    FactorOutMulSignMagnitude,
    Absorb1BitMulIntoConv,
    AbsorbMulIntoMultiThreshold,
)
for pass_cls in streamlining_passes:
    model = model.transform(pass_cls())

# Ensure thresholds are integers.
# The input tensor of every MultiThreshold node is annotated as INT32 so
# that RoundAndClipThresholds has an integer input datatype to work with
# (the original note states the annotation is needed for that pass to work).
# NOTE(review): INT32 is a blanket annotation, not an inferred datatype --
# confirm it is wide enough for the accumulator values in this graph.
for node in model.graph.node:
    if node.op_type == "MultiThreshold":
        model.set_tensor_datatype(node.input[0], DataType.INT32)

model = model.transform(RoundAndClipThresholds())

# Final lowering steps: convolutions are lowered to matrix multiplications
# first, then transposes are absorbed into MultiThreshold nodes.
for pass_cls in (LowerConvsToMatMul, AbsorbTransposeIntoMultiThreshold):
    model = model.transform(pass_cls())

# Save the streamlined model and open the saved file for inspection.
streamlined_path = "/tmp/quartznet_repetitive_nodes_subpart_streamlined.onnx"
model.save(streamlined_path)
showInNetron(streamlined_path)