In [1]:
import numpy as np
from finn.util.visualization import showInNetron
from finn.core.modelwrapper import ModelWrapper
import onnx


def generate_model(permutation, default_data_layout, seed=None,
                   save_path="/tmp/test_move_transpose_past_mt.onnx", show=True):
    """Build a two-branch ``Transpose -> MultiThreshold -> Add`` test model.

    Each of the two graph inputs is transposed with ``permutation`` and then
    thresholded by a MultiThreshold node (``out_dtype='UINT4'``).  The two
    thresholded tensors are summed by an Add node.  Both MultiThreshold nodes
    share the same random, per-row sorted threshold tensor of shape [256, 15].
    The model is saved to ``save_path`` and optionally served in Netron.

    Args:
        permutation: Transpose ``perm``; must be ``[0, 3, 1, 2]`` (input
            [1, 128, 1, 256] -> [1, 256, 128, 1], MultiThreshold layout NCHW)
            or ``[0, 2, 3, 1]`` (input [1, 256, 128, 1] -> [1, 128, 1, 256],
            MultiThreshold layout NHWC).  Any sequence type is accepted.
        default_data_layout: If truthy and the layout is NCHW, the
            MultiThreshold nodes are created without an explicit
            ``data_layout`` attribute, so the node's default is exercised.
        seed: Optional seed for the random thresholds (None = nondeterministic).
        save_path: File the ONNX model is written to.
        show: If True, serve the saved model in Netron.

    Returns:
        The ModelWrapper of the generated model.

    Raises:
        ValueError: If ``permutation`` is not one of the supported values.
    """
    perm = list(permutation)
    # perm -> (graph input shape, Transpose output shape, MultiThreshold layout).
    # Both variants put the 256 channels on the MultiThreshold's channel axis,
    # matching the [256, 15] threshold tensors below.
    io_specs = {
        (0, 3, 1, 2): ([1, 128, 1, 256], [1, 256, 128, 1], "NCHW"),
        (0, 2, 3, 1): ([1, 256, 128, 1], [1, 128, 1, 256], "NHWC"),
    }
    key = tuple(perm)
    if key not in io_specs:
        # Fail early and clearly instead of hitting unbound locals later on.
        raise ValueError(
            f"unsupported permutation {perm}; expected one of "
            f"{[list(k) for k in io_specs]}"
        )
    in_shape, out_shape, data_layout = io_specs[key]

    # Attributes shared by both MultiThreshold nodes.  With default_data_layout
    # set and an NCHW layout, data_layout is left unset so the node falls back
    # to its own default (presumably NCHW, since the attribute is only omitted
    # in that case - TODO confirm against the MultiThreshold op).
    mt_attrs = {"domain": "finn.custom_op.general", "out_dtype": "UINT4"}
    if not (default_data_layout and data_layout == "NCHW"):
        mt_attrs["data_layout"] = data_layout

    branches = (1, 2)
    transpose_nodes = [
        onnx.helper.make_node(
            "Transpose",
            inputs=[f"in_transpose{i}"],
            outputs=[f"out_transpose{i}"],
            perm=perm,
        )
        for i in branches
    ]
    mt_nodes = [
        onnx.helper.make_node(
            "MultiThreshold",
            inputs=[f"out_transpose{i}", f"in2_multithreshold{i}"],
            outputs=[f"out_multithreshold{i}"],
            **mt_attrs,
        )
        for i in branches
    ]
    add_node = onnx.helper.make_node(
        "Add",
        inputs=["out_multithreshold1", "out_multithreshold2"],
        outputs=["out_add1"],
    )

    def float_tensor(name, shape):
        # All tensors in this test graph are declared as FLOAT.
        return onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, shape)

    graph = onnx.helper.make_graph(
        nodes=transpose_nodes + mt_nodes + [add_node],
        name="test_graph",
        inputs=[float_tensor(f"in_transpose{i}", in_shape) for i in branches],
        outputs=[float_tensor("out_add1", out_shape)],
        value_info=(
            [float_tensor(f"out_transpose{i}", out_shape) for i in branches]
            + [float_tensor(f"out_multithreshold{i}", out_shape) for i in branches]
            + [float_tensor(f"in2_multithreshold{i}", [256, 15]) for i in branches]
        ),
    )
    model = ModelWrapper(onnx.helper.make_model(graph, producer_name="test_model"))

    # 15 sorted thresholds per channel (presumably 2**4 - 1 for the UINT4
    # output), shared by both branches.  Cast to float32 so the initializer
    # matches the FLOAT type declared for in2_multithreshold{1,2} above.
    rng = np.random.default_rng(seed)
    mt_weights = np.sort(rng.integers(-1000, 1000, size=(256, 15)), axis=1).astype(np.float32)
    for i in branches:
        model.set_initializer(f"in2_multithreshold{i}", mt_weights)

    model.save(save_path)
    if show:
        showInNetron(save_path)

    return model

In [2]:
# Transpose permutations exercised in this notebook:
#   [0, 3, 1, 2]: NHWC -> NCHW, followed by an NCHW MultiThreshold
#   [0, 2, 3, 1]: NCHW -> NHWC, followed by an NHWC MultiThreshold
permutation = [
    [0, 3, 1, 2],
    [0, 2, 3, 1],
]
model = generate_model(permutation[0], default_data_layout=False)

Serving '/tmp/test_move_transpose_past_mt.onnx' at http://0.0.0.0:8081


In [3]:
# Move transpose past multithreshold
from finn.util.basic import get_by_name
from finn.custom_op.general.multithreshold import MultiThreshold
from finn.custom_op.registry import getCustomOp
from finn.transformation.general import SortGraph

# (Transpose perm, MultiThreshold data_layout) -> data_layout the MultiThreshold
# must use once it consumes the Transpose's *input* instead of its output:
#   perm [0, 3, 1, 2] turns NHWC into NCHW, so an NCHW MultiThreshold moved in
#   front of it sees NHWC data;
#   perm [0, 2, 3, 1] turns NCHW into NHWC, so an NHWC MultiThreshold moved in
#   front of it sees NCHW data.
# Any other combination is left untouched.
LAYOUT_AFTER_MOVE = {
    ((0, 3, 1, 2), "NCHW"): "NHWC",
    ((0, 2, 3, 1), "NHWC"): "NCHW",
}


def move_transpose_past_multithreshold(model):
    """Rewrite every ``Transpose -> MultiThreshold`` pair as ``MultiThreshold -> Transpose``.

    A pair is rewritten only when neither node forks, the Transpose has an
    explicit ``perm`` attribute and the (perm, data_layout) combination is
    listed in ``LAYOUT_AFTER_MOVE``.  The MultiThreshold's ``data_layout`` is
    updated to the layout of the untransposed tensor it now consumes.  The
    intermediate tensor's shape, layout and datatype annotations are updated
    to describe the MultiThreshold output.

    Note: ``model``'s nodes are modified in place; the returned model is the
    result of ``SortGraph`` applied afterwards.
    """
    for transpose in model.graph.node:
        if transpose.op_type != "Transpose" or model.is_fork_node(transpose):
            continue

        consumer = model.find_consumer(transpose.output[0])
        # No consumer (e.g. the Transpose output is a graph output) -> nothing to move.
        if (
            consumer is None
            or consumer.op_type != "MultiThreshold"
            or model.is_fork_node(consumer)
        ):
            continue

        perm_attr = get_by_name(transpose.attribute, "perm", "name")
        if perm_attr is None:
            # Transpose without explicit perm is not handled here.
            continue
        perm = tuple(onnx.helper.get_attribute_value(perm_attr))

        mt_inst = getCustomOp(consumer)
        new_layout = LAYOUT_AFTER_MOVE.get((perm, mt_inst.get_nodeattr("data_layout")))
        if new_layout is None:
            continue
        mt_inst.set_nodeattr("data_layout", new_layout)

        transpose_in = transpose.input[0]
        transpose_out = transpose.output[0]
        mt_out = consumer.output[0]
        transpose_in_shape = model.get_tensor_shape(transpose_in)
        transpose_in_layout = model.get_tensor_layout(transpose_in)

        # Rewire transpose_in -> MultiThreshold -> transpose_out -> Transpose -> mt_out.
        # Reusing transpose_out as the intermediate tensor keeps tensor names unique.
        consumer.input[0] = transpose_in
        consumer.output[0] = transpose_out
        transpose.input[0] = transpose_out
        transpose.output[0] = mt_out

        # The intermediate tensor now holds untransposed MultiThreshold output:
        # same shape/layout as the Transpose input, datatype of the MT output.
        model.set_tensor_shape(transpose_out, transpose_in_shape)
        if transpose_in_layout is not None:
            model.set_tensor_layout(transpose_out, transpose_in_layout)
        else:
            model.set_tensor_layout(transpose_out, list(new_layout))
        model.set_tensor_datatype(transpose_out, model.get_tensor_datatype(mt_out))

    # Node order changed; sort so the graph is topologically ordered again.
    return model.transform(SortGraph())


model = ModelWrapper("/tmp/test_move_transpose_past_mt.onnx")
model = move_transpose_past_multithreshold(model)

model.save("/tmp/test_move_transpose_past_mt_modified.onnx")
showInNetron("/tmp/test_move_transpose_past_mt_modified.onnx")
                
                

Stopping http://0.0.0.0:8081
Serving '/tmp/test_move_transpose_past_mt_modified.onnx' at http://0.0.0.0:8081


In [4]:
from finn.util.basic import gen_finn_dt_tensor
from finn.core.datatype import DataType
import finn.core.onnx_exec as oxe

model = ModelWrapper("/tmp/test_move_transpose_past_mt.onnx")
model_transformed = ModelWrapper("/tmp/test_move_transpose_past_mt_modified.onnx")

# Draw an independent random tensor for every (non-initializer) graph input,
# using that input's own shape and FINN datatype annotation.  Distinct inputs
# per branch make the equivalence check stronger than reusing one tensor.
initializer_names = {init.name for init in model.graph.initializer}
input_dict = {
    graph_input.name: gen_finn_dt_tensor(
        model.get_tensor_datatype(graph_input.name),
        model.get_tensor_shape(graph_input.name),
    )
    for graph_input in model.graph.input
    if graph_input.name not in initializer_names
}

is_same = oxe.compare_execution(model, model_transformed, input_dict)
assert is_same, "transformed model output differs from the original model"

print(is_same)

True
