In [1]:
import numpy as np
from finn.util.visualization import showInNetron
from finn.core.modelwrapper import ModelWrapper
import onnx

Transpose1_node = onnx.helper.make_node(
    "Transpose",
    inputs = ['in_transpose1'],
    outputs = ['out_transpose1'],
    perm = [0, 3, 1, 2]
)

Transpose2_node = onnx.helper.make_node(
    "Transpose",
    inputs = ['in_transpose2'],
    outputs = ['out_transpose2'],
    perm = [0, 3, 1, 2]
)

Join1_node = onnx.helper.make_node(
    "Add",
    inputs = ['out_transpose1', 'out_transpose2'],
    outputs = ['out_join1']
)

in_transpose1 = onnx.helper.make_tensor_value_info('in_transpose1', onnx.TensorProto.FLOAT, [1, 128, 1, 256])
in_transpose2 = onnx.helper.make_tensor_value_info('in_transpose2', onnx.TensorProto.FLOAT, [1, 128, 1, 256])
out_transpose1 = onnx.helper.make_tensor_value_info('out_transpose1', onnx.TensorProto.FLOAT, [1, 256, 128, 1])
out_transpose2 = onnx.helper.make_tensor_value_info('out_transpose2', onnx.TensorProto.FLOAT, [1, 256, 128, 1])
out_join1 = onnx.helper.make_tensor_value_info('out_join1', onnx.TensorProto.FLOAT, [1, 256, 128, 1])

graph = onnx.helper.make_graph(
    nodes = [
        Transpose1_node,
        Transpose2_node,
        Join1_node
    ],
    name = 'test_graph',
    inputs = [in_transpose1, in_transpose2],
    outputs = [out_join1],
    value_info = [
        in_transpose1,
        in_transpose2,
        out_transpose1,
        out_transpose2,
        out_join1
    ]
)

onnx_model = onnx.helper.make_model(graph, producer_name='test_model')
model = ModelWrapper(onnx_model)

model.save("/tmp/test_move_identical_op_past_join_op.onnx")
showInNetron("/tmp/test_move_identical_op_past_join_op.onnx")



Serving '/tmp/test_move_identical_op_past_join_op.onnx' at http://0.0.0.0:8081


In [2]:
# INSPIRATION: MoveLinearPastEltwiseAdd

import onnx
from finn.transformation.general import SortGraph

# Move identical operations on different branches past the common join node. Transformation should be applied when
# the identical operations are changing the data layout. For linear operations, see the transformation MoveLinearPastEltwiseAdd
# Specifically, this transformation matches and transforms the following patterns:
# f(x) + f(y) -> f(x + y)
# f(x) - f(y) -> f(x - y)
# where f(.) is currently only supporting 'Transpose'

model = ModelWrapper("/tmp/test_move_identical_op_past_join_op.onnx")

identical_op_list = ['Transpose']
join_op_list = ['Add']

graph = model.graph
node_ind = 0
graph_modified = False

def move_node(graph, n, prod0, prod1, node_ind):
    # found! move one of the identical_ops to output, remove the other one
    identical_op0_in0 = prod0.input[0]
    identical_op1_in0 = prod1.input[0]
    add_in0 = n.input[0]
    add_out = n.output[0]
    
    # Rewire
    n.input[0] = identical_op0_in0 # CHECK
    n.input[1] = identical_op1_in0 # CHECK
    
    # Now we create the new output tensor
    # Output tensor of the join node must have the same shape as its input tensor (shape preserving)
    new_shape = model.get_tensor_shape(identical_op0_in0)
    # FINN datatype should be set to the tensor datatype of the output tensor of the add node
    finn_data_type = model.get_tensor_datatype(add_out) 
    # ONNX datatype must be set to the tensor datatype of the output tensor of the add node
    value_info_addout = model.get_tensor_valueinfo(add_out)
    onnx_data_type = value_info_addout.type.tensor_type.elem_type
    
    # Set new tensor shape with appropriate ONNX and FINN datatypes.
    model.set_tensor_shape(
        tensor_name = add_in0,
        tensor_shape = new_shape,
        dtype = onnx_data_type
    )
    model.set_tensor_datatype(add_in0, finn_data_type)
    
    n.output[0] = add_in0 # CHECK
    prod0.input[0] = add_in0 # CHECK
    prod0.output[0] = add_out # CHECK
    
    graph.node.remove(prod1)
    
for n in model.graph.node:
    node_ind += 1
    if n.op_type in join_op_list and model.is_join_node(n):
        in0 = n.input[0]
        in1 = n.input[1]
        if in0 is None or in1 is None:
            continue
        
        prod0 = model.find_producer(in0)
        prod1 = model.find_producer(in1)
        if prod0 is None or prod1 is None or prod0==prod1: # is this needed?
            continue
        
        identical_op = prod0.op_type == prod1.op_type
        
        if identical_op and prod0.op_type in identical_op_list:
            # Currently, only transpose operation is supported. Adding additional op_types can be done by extending
            # the if-branches below. 
            if prod0.op_type == 'Transpose':
                move_node(graph, n, prod0, prod1, node_ind)
                
model = model.transform(SortGraph())
                
                
model.save("/tmp/test_move_identical_op_past_join_op_modified.onnx")
showInNetron("/tmp/test_move_identical_op_past_join_op_modified.onnx")
        

Stopping http://0.0.0.0:8081
Serving '/tmp/test_move_identical_op_past_join_op_modified.onnx' at http://0.0.0.0:8081


In [7]:
from finn.util.basic import gen_finn_dt_tensor
from finn.core.datatype import DataType
import finn.core.onnx_exec as oxe

# Create input data
input0_tensor_name = model.graph.input[0].name
input1_tensor_name = model.graph.input[1].name

# Note: it is assumed that both tensors have the same shape and data type
input_shape = model.get_tensor_shape(input0_tensor_name)
input_dtype = model.get_tensor_datatype(input0_tensor_name)
input_val = gen_finn_dt_tensor(input_dtype, input_shape)
input_dict = {}
input_dict[input0_tensor_name] = input_val
input_dict[input1_tensor_name] = input_val

model = ModelWrapper("/tmp/test_move_identical_op_past_join_op.onnx")
model_transformed = ModelWrapper("/tmp/test_move_transpose_past_mt_modified.onnx")
is_same = oxe.compare_execution(model, model_transformed, input_dict)

#print(is_same)

Fail: [ONNXRuntimeError] : 1 : FAIL : Error: Duplicate definition-site for (in_transpose1).