To get the ONNX file, run: 'python setup.py test --addopts "-k test_brevitas_quartznet"'

In [1]:
import numpy as np
from finn.util.visualization import showInNetron
from finn.core.modelwrapper import ModelWrapper

file_name = '/tmp/quartznet.onnx'
showInNetron(file_name)

Serving '/tmp/quartznet.onnx' at http://0.0.0.0:8081


In [2]:
### STREAMLINING
from finn.util.visualization import showInNetron
from finn.core.modelwrapper import ModelWrapper

from finn.transformation.streamline import *
from finn.transformation.streamline.reorder import MoveMulPastDWConv, MoveLinearPastEltwiseAdd, MoveMulPastFork
from finn.transformation.change_3d_tensors_to_4d import Change3DTo4DTensors
from finn.transformation.streamline.absorb import AbsorbConsecutiveTransposes # No effect (only on consecutive transpose nodes)
from finn.transformation.streamline.absorb import AbsorbTransposeIntoMultiThreshold
from finn.util.basic import get_by_name
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.util.visualization import showInNetron
from finn.transformation.general import GiveUniqueNodeNames, GiveRandomTensorNames, GiveReadableTensorNames, GiveUniqueParameterTensors

model = ModelWrapper("/tmp/quartznet.onnx")

######## Tidy-up

model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveRandomTensorNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(GiveUniqueParameterTensors())

# Convert to supported format
model = model.transform(Change3DTo4DTensors())

########

# Collapse BatchNorm to Add and Mul
model = model.transform(BatchNormToAffine())

# Group additions
model = model.transform(MoveAddPastMul())
model = model.transform(MoveAddPastConv())
model = model.transform(MoveAddPastMul())

# Group multiplications
#### Move mul past fork
model = model.transform(MoveMulPastFork())
model = model.transform(MoveScalarMulPastConv())
model = model.transform(MoveMulPastDWConv())

# Move Mul/Add past join node
model = model.transform(MoveLinearPastEltwiseAdd())

# Collapes additions & multiplications
model = model.transform(CollapseRepeatedAdd())
model = model.transform(CollapseRepeatedMul())

# Absorb Add/Mul into multithreshold
model = model.transform(AbsorbAddIntoMultiThreshold())
model = model.transform(FactorOutMulSignMagnitude())
model = model.transform(Absorb1BitMulIntoConv())
model = model.transform(AbsorbMulIntoMultiThreshold())

# Ensure thresholds are integers
## Add quantization annotation to ensure RoundAndClipThresholds works
for n in model.graph.node:
    if n.op_type=="MultiThreshold":
        odtype = get_by_name(n.attribute, "out_dtype", name_field="name").s.decode("utf-8")
        dtype = getattr(DataType, odtype) 
        #model.set_tensor_datatype(n.input[0], dtype)
        model.set_tensor_datatype(n.input[0], DataType.INT32)

#from finn.transformation.infer_datatypes import InferDataTypes
#model = model.transform(InferDataTypes())
        
model = model.transform(RoundAndClipThresholds())

model.save("/tmp/quartznet_streamlined.onnx")
showInNetron("/tmp/quartznet_streamlined.onnx")

Stopping http://0.0.0.0:8081
Serving '/tmp/quartznet_streamlined.onnx' at http://0.0.0.0:8081


In [1]:
## PARTITIONING
from finn.util.visualization import showInNetron
from finn.core.modelwrapper import ModelWrapper

from finn.transformation.create_generic_partitions import PartitionFromDict

model = ModelWrapper("/tmp/quartznet_streamlined.onnx")

#partitionings = {0: range(0, 3), 
#                1: range(3, 27),
#                2: range(27, 51),
#                3: range(51, 75),
#                4: range(75, 99),
#                5: range(99, 123),
#                6: range(123, 147),
#                7: range(147, 171),
#                8: range(171, 195),
#                9: range(195, 219),
#                10: range(219, 243),
#                11: range(243, 267),
#                12: range(267, 291),
#                13: range(291, 315),
#                14: range(315, 339),
#                15: range(339, 363),
#                16: range(363, 376)}
partitionings = {0: range(0, 3), 
                1: range(3, 75),
                2: range(75, 147),
                3: range(147, 219),
                4: range(219, 291),
                5: range(291, 363),
                6: range(363, 376)}

model = model.transform(PartitionFromDict(partitionings))

model.save("/tmp/quartznet_streamlined_partitioned.onnx")
showInNetron("/tmp/quartznet_streamlined_partitioned.onnx")

Serving '/tmp/quartznet_streamlined_partitioned.onnx' at http://0.0.0.0:8081


In [103]:
## LOWERING and ABSORB_TRANSPOSE_INTO_MULTITHRESHOLD
from finn.util.visualization import showInNetron
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
from finn.transformation.streamline.absorb import AbsorbTransposeIntoMultiThreshold
from finn.transformation.streamline.reorder import MoveTransposePastMultiThreshold, MoveTransposePastJoinAdd, MoveTransposeBeforeFork
from finn.util.basic import get_by_name


model = ModelWrapper("/tmp/quartznet_streamlined_partitioned.onnx")

node_ind=0
for n in model.graph.node:
    path_to_partition = get_by_name(n.attribute, "model", "name").s.decode('utf-8')
    print(path_to_partition)
    model_partition = ModelWrapper(path_to_partition)
    
    # Lower
    model_partition = model_partition.transform(LowerConvsToMatMul())
    # Absorb transpose nodes
    model_partition = model_partition.transform(AbsorbTransposeIntoMultiThreshold())
    # Reorder remaining transpose nodes
    model_partition = model_partition.transform(MoveTransposePastMultiThreshold())
    model_partition = model_partition.transform(MoveTransposePastJoinAdd())
    model_partition = model_partition.transform(MoveTransposeBeforeFork())
    
    model_partition.save(path_to_partition)
    
    node_ind+=1

model.save("/tmp/quartznet_streamlined_lowered.onnx")
showInNetron("/tmp/quartznet_streamlined_lowered.onnx")

/tmp/finn_dev_mirza/partitioning_c5rw2io6/partition_0.onnx
/tmp/finn_dev_mirza/partitioning_c5rw2io6/partition_1.onnx
/tmp/finn_dev_mirza/partitioning_c5rw2io6/partition_2.onnx
/tmp/finn_dev_mirza/partitioning_c5rw2io6/partition_3.onnx
/tmp/finn_dev_mirza/partitioning_c5rw2io6/partition_4.onnx
/tmp/finn_dev_mirza/partitioning_c5rw2io6/partition_5.onnx
/tmp/finn_dev_mirza/partitioning_c5rw2io6/partition_6.onnx
Stopping http://0.0.0.0:8081
Serving '/tmp/quartznet_streamlined_lowered.onnx' at http://0.0.0.0:8081


In [104]:
model = ModelWrapper("/tmp/quartznet_streamlined_lowered.onnx")
showInNetron("/tmp/quartznet_streamlined_lowered.onnx")

p = model.graph.node[0]
path = get_by_name(p.attribute, "model", "name").s.decode("utf-8")

showInNetron(path)


Stopping http://0.0.0.0:8081
Serving '/tmp/quartznet_streamlined_lowered.onnx' at http://0.0.0.0:8081
Stopping http://0.0.0.0:8081
Serving '/tmp/finn_dev_mirza/partitioning_c5rw2io6/partition_0.onnx' at http://0.0.0.0:8081


In [105]:
model = ModelWrapper("/tmp/quartznet_streamlined_lowered.onnx")
showInNetron("/tmp/quartznet_streamlined_lowered.onnx")

p = model.graph.node[1]
path = get_by_name(p.attribute, "model", "name").s.decode("utf-8")

showInNetron(path)


Stopping http://0.0.0.0:8081
Serving '/tmp/quartznet_streamlined_lowered.onnx' at http://0.0.0.0:8081
Stopping http://0.0.0.0:8081
Serving '/tmp/finn_dev_mirza/partitioning_c5rw2io6/partition_1.onnx' at http://0.0.0.0:8081


In [107]:
## UNFOLD and ABSORB TRANSPOSE again
from finn.util.visualization import showInNetron
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.unfold_partitions import UnfoldPartitions
from finn.transformation.streamline.absorb import AbsorbTransposeIntoMultiThreshold
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.create_generic_partitions import PartitionFromDict

model = ModelWrapper("/tmp/quartznet_streamlined_lowered.onnx")

new_partitionings = [{0: range(0, 5), 1: range(5, 92)},
                    {2: range(2, 89)},
                    {3: range(3, 90)},
                    {4: range(4, 91)},
                    {5: range(5, 92)},
                    {6: range(6, 21)} 
                    ]

new_partitionings = [{1: range(5, 92)},
                    {2: range(6, 93)},
                    {3: range(7, 94)},
                    {4: range(8, 95)},
                    {5: range(9, 96)},
                    {6: range(11, 21)} 
                    ]

nodes = [n for n in model.graph.node]
for ind, n in enumerate(nodes):
    if ind == 0:
        node_ind_to_unfold = [ind, ind+1] # unfold current and next node
    else:
        node_ind_to_unfold = [ind+6] # ind+1 is the Transpose node (+5 for initial nodes)
    
    model = model.transform(UnfoldPartitions(node_ind_to_unfold))
    model = model.transform(AbsorbTransposeIntoMultiThreshold())
    
    if ind==0:
        model = model.transform(PartitionFromDict(new_partitionings[0]), "/tmp/finn_dev_mirza/partitioning_0lwlvajs")
    if ind==1:
        model = model.transform(PartitionFromDict(new_partitionings[1], "/tmp/finn_dev_mirza/partitioning_0lwlvajs"))
    if ind==2:
        model = model.transform(PartitionFromDict(new_partitionings[2], "/tmp/finn_dev_mirza/partitioning_35sdx5v_"))
    if ind==3:
        model = model.transform(PartitionFromDict(new_partitionings[3], "/tmp/finn_dev_mirza/partitioning_35sdx5v_"))
    if ind==4:
        model = model.transform(PartitionFromDict(new_partitionings[4], "/tmp/finn_dev_mirza/partitioning_35sdx5v_"))
    if ind==5:
        break
    #    model = model.transform(PartitionFromDict(new_partitionings[5]), "/tmp/finn_dev_mirza/partitioning_35sdx5v_")
    
model.save("/tmp/quartznet_temp_test.onnx")
showInNetron("/tmp/quartznet_temp_test.onnx")

Stopping http://0.0.0.0:8081
Serving '/tmp/quartznet_temp_test.onnx' at http://0.0.0.0:8081


In [109]:
model = ModelWrapper("/tmp/quartznet_temp_test.onnx")
#showInNetron("/tmp/quartznet_temp_test.onnx")

p = model.graph.node[5]
path = get_by_name(p.attribute, "model", "name").s.decode("utf-8")

showInNetron(path)


Stopping http://0.0.0.0:8081
Serving '/tmp/finn_dev_mirza/partitioning_9r84q0d4/partition_1.onnx' at http://0.0.0.0:8081


# Compare 2 models

1. Original QuartzNet
2. Any other


In [2]:
import numpy as np
from finn.core.modelwrapper import ModelWrapper
from finn.util.basic import gen_finn_dt_tensor
import finn.core.onnx_exec as oxe

import time
t1 = time.perf_counter()

################################################################################################
####
#### MODEL 1
####
model_1 = ModelWrapper("/tmp/quartznet.onnx")

#### MODEL 1
# Create input data
input0_tensor_name = model_1.graph.input[0].name

input_shape = model_1.get_tensor_shape(input0_tensor_name)
#input_dtype = model_1.get_tensor_datatype(input0_tensor_name)
#input_val = gen_finn_dt_tensor(input_dtype, input_shape)
input_val = np.random.randint(low=-10000, high=10000, size=input_shape).astype(np.float32)
input_dict = {}
input_dict[input0_tensor_name] = input_val
output0_tensor_name = model_1.graph.output[0].name

expected_m1_dict = oxe.execute_onnx(model_1, input_dict, return_full_exec_context = False)
expected_m1 = expected_m1_dict[output0_tensor_name]
################################################################################################


t2 = time.perf_counter() - t1
print("Elapsed time: {}".format(t2))

Elapsed time: 356.18293706100667


In [3]:
import time
t1 = time.perf_counter()

################################################################################################
####
#### MODEL 2
####
#model_2 = ModelWrapper("/tmp/quartznet_streamlined.onnx") #CORRECT
model_2 = ModelWrapper("/tmp/quartznet_streamlined_partitioned.onnx") #CORRECT
#model_2 = ModelWrapper("/tmp/quartznet_streamlined_lowered.onnx") #CORRECT
#model_2 = ModelWrapper("/tmp/quartznet_temp_test.onnx") #CORRECT?

#### MODEL 2
m1_input_val = input_val

input0_tensor_name = model_2.graph.input[0].name
#input_shape = model_2.get_tensor_shape(input0_tensor_name)
#input_dtype = model_2.get_tensor_datatype(input0_tensor_name)
input_dict = {}
m2_input_val = np.reshape(m1_input_val, np.shape(m1_input_val)+(1,))
input_dict[input0_tensor_name] = m2_input_val
output0_tensor_name = model_2.graph.output[0].name

expected_m2_dict = oxe.execute_onnx(model_2, input_dict, return_full_exec_context = False)
expected_m2 = expected_m2_dict[output0_tensor_name]

expected_m2 = np.reshape(expected_m2, np.shape(expected_m1))
m2_input_val = np.reshape(m2_input_val, np.shape(m1_input_val))


assert(m1_input_val==m2_input_val).all()
assert(expected_m1==expected_m2).all()
################################################################################################


t2 = time.perf_counter() - t1
print("Elapsed time: {}".format(t2))

  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype

  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype

  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype.".format(tensor, dtype)
  "FINN datatype

Elapsed time: 698.7329327710031


In [None]:
start: 16:34