In [1]:
import numpy as np
from finn.util.visualization import showInNetron
from finn.core.modelwrapper import ModelWrapper
import onnx
from finn.core.datatype import DataType

# Fixture: two Transpose nodes with the same perm [0, 3, 1, 2] on separate
# branches, whose outputs are joined by an Add node.
Transpose1_node, Transpose2_node = (
    onnx.helper.make_node(
        "Transpose",
        inputs=[src_name],
        outputs=[dst_name],
        perm=[0, 3, 1, 2],
    )
    for src_name, dst_name in (
        ("in_transpose1", "out_transpose1"),
        ("in_transpose2", "out_transpose2"),
    )
)
Join1_node = onnx.helper.make_node(
    "Add",
    inputs=["out_transpose1", "out_transpose2"],
    outputs=["out_join1"],
)

# Graph inputs have shape [1, 128, 1, 256]; after the transpose every tensor is [1, 256, 128, 1].
in_transpose1, in_transpose2 = (
    onnx.helper.make_tensor_value_info(vi_name, onnx.TensorProto.FLOAT, [1, 128, 1, 256])
    for vi_name in ("in_transpose1", "in_transpose2")
)
out_transpose1, out_transpose2, out_join1 = (
    onnx.helper.make_tensor_value_info(vi_name, onnx.TensorProto.FLOAT, [1, 256, 128, 1])
    for vi_name in ("out_transpose1", "out_transpose2", "out_join1")
)

graph = onnx.helper.make_graph(
    nodes=[Transpose1_node, Transpose2_node, Join1_node],
    name="test_graph",
    inputs=[in_transpose1, in_transpose2],
    outputs=[out_join1],
    value_info=[out_transpose1, out_transpose2],
)

onnx_model = onnx.helper.make_model(graph, producer_name="test_model")
model = ModelWrapper(onnx_model)

# NOTE(review): this annotates a FLOAT tensor with a FINN UINT4 datatype —
# presumably to observe how the rewrite handles datatype annotations; confirm.
model.set_tensor_datatype("out_transpose1", DataType.UINT4)

fixture_path = "/tmp/test_move_identical_op_past_join_op.onnx"
model.save(fixture_path)
showInNetron(fixture_path)



Serving '/tmp/test_move_identical_op_past_join_op.onnx' at http://0.0.0.0:8081


In [2]:
# INSPIRATION: MoveLinearPastEltwiseAdd

import onnx
from finn.transformation.general import SortGraph

# Prototype rewrite: identical layout-changing ops sitting on the two branches
# that feed a join node are moved past that join node (for linear ops see
# MoveLinearPastEltwiseAdd). Target patterns:
#   f(x) + f(y) -> f(x + y)
#   f(x) - f(y) -> f(x - y)   (would additionally need 'Sub' in join_op_list)
# f(.) is currently only 'Transpose'.
source_model_path = "/tmp/test_move_identical_op_past_join_op.onnx"
model = ModelWrapper(source_model_path)

identical_op_list, join_op_list = ['Transpose'], ['Add']

graph = model.graph
node_ind, graph_modified = 0, False

def move_node(graph, n, prod0, prod1, node_ind, model=None):
    """Rewrite f(x) (+) f(y) into f(x (+) y) in place.

    The join node `n` is rewired to consume the inputs of prod0 and prod1. Its
    former first input tensor is reused as the join's new output and becomes the
    input of prod0. prod0 now produces the join's original output, and prod1 is
    removed from `graph`.

    Args:
        graph: ONNX graph containing the nodes; prod1 is removed from it.
        n: the join node, whose two inputs are produced by prod0 and prod1.
        prod0: producer of n.input[0]; it is moved behind the join.
        prod1: producer of n.input[1]; it is removed.
        node_ind: unused; kept so existing calls keep working.
        model: ModelWrapper that owns `graph`. If omitted, this falls back to the
            notebook-global `model`, which is what the earlier version used implicitly.
    """
    if model is None:
        # Backward compatibility: previously the global `model` was read implicitly.
        model = globals()["model"]

    identical_op0_in0 = prod0.input[0]
    identical_op1_in0 = prod1.input[0]
    add_in0 = n.input[0]
    add_out = n.output[0]

    # Rewire the join to consume the inputs of the identical ops directly.
    n.input[0] = identical_op0_in0
    n.input[1] = identical_op1_in0

    # add_in0 becomes the join's new output. The join is shape preserving, so it
    # takes the shape of the join's new first input and the FINN/ONNX datatypes
    # of the join's original output.
    new_shape = model.get_tensor_shape(identical_op0_in0)
    finn_data_type = model.get_tensor_datatype(add_out)
    value_info_addout = model.get_tensor_valueinfo(add_out)
    if value_info_addout is not None:
        onnx_data_type = value_info_addout.type.tensor_type.elem_type
    else:
        # No value_info recorded for the join output: fall back to FLOAT.
        onnx_data_type = onnx.TensorProto.FLOAT

    model.set_tensor_shape(
        tensor_name=add_in0,
        tensor_shape=new_shape,
        dtype=onnx_data_type,
    )
    model.set_tensor_datatype(add_in0, finn_data_type)

    # prod0 now sits behind the join: join -> add_in0 -> prod0 -> add_out
    n.output[0] = add_in0
    prod0.input[0] = add_in0
    prod0.output[0] = add_out

    graph.node.remove(prod1)
    
def _only_feeds_join(model, producer):
    """Return True if producer's output has exactly one consumer and is not a graph
    output, so moving the producer cannot change what any other node sees."""
    out_name = producer.output[0]
    if out_name in {vi.name for vi in model.graph.output}:
        return False
    return len(model.find_consumers(out_name) or []) == 1


def _same_attributes(node_a, node_b):
    """Return True if both nodes carry identical attributes (e.g. the same Transpose perm)."""
    return {a.name: a for a in node_a.attribute} == {a.name: a for a in node_b.attribute}


# move_node removes prod1 from the graph, and removing elements from
# model.graph.node while iterating over it can skip nodes. Instead, restart the
# scan after every rewrite. Each rewrite removes one node, so this terminates.
graph_modified = True
while graph_modified:
    graph_modified = False
    for node_ind, n in enumerate(model.graph.node, start=1):
        if n.op_type not in join_op_list or not model.is_join_node(n):
            continue
        if len(n.input) != 2:
            continue

        prod0 = model.find_producer(n.input[0])
        prod1 = model.find_producer(n.input[1])
        if prod0 is None or prod1 is None or prod0 == prod1:
            continue
        if prod0.op_type != prod1.op_type or prod0.op_type not in identical_op_list:
            continue
        # f(x) + g(y) != f(x + y) unless both ops are truly identical (same attributes),
        # and a producer with other consumers cannot be moved without breaking them.
        if not _same_attributes(prod0, prod1):
            continue
        if not (_only_feeds_join(model, prod0) and _only_feeds_join(model, prod1)):
            continue

        # Currently only Transpose is supported; further op_types need their own branch.
        if prod0.op_type == 'Transpose':
            move_node(graph, n, prod0, prod1, node_ind)
            graph_modified = True
            break

model = model.transform(SortGraph())

model.save("/tmp/test_move_identical_op_past_join_op_modified.onnx")
showInNetron("/tmp/test_move_identical_op_past_join_op_modified.onnx")
        

Stopping http://0.0.0.0:8081
Serving '/tmp/test_move_identical_op_past_join_op_modified.onnx' at http://0.0.0.0:8081


In [2]:
from finn.util.basic import gen_finn_dt_tensor
from finn.core.datatype import DataType
import finn.core.onnx_exec as oxe

# Check that the rewritten model computes the same result as the original one.
model = ModelWrapper("/tmp/test_move_identical_op_past_join_op.onnx")
model_transformed = ModelWrapper("/tmp/test_move_identical_op_past_join_op_modified.onnx")

# Draw an independent random tensor for every graph input, using each input's
# own shape and FINN datatype. Feeding the same tensor to both branches would
# hide a rewrite that mixes up the two join inputs.
input_dict = {
    inp.name: gen_finn_dt_tensor(
        model.get_tensor_datatype(inp.name),
        model.get_tensor_shape(inp.name),
    )
    for inp in model.graph.input
}

is_same = oxe.compare_execution(model, model_transformed, input_dict)
assert is_same, "transformed model output differs from the original model"
is_same


Fail: [ONNXRuntimeError] : 1 : FAIL : Node: Output:out_transpose1 [ShapeInferenceError] Can't merge shape info. Both source and target dimension have values but they differ. Source=128 Target=256 Dimension=1

# Transformation class: `MoveIdenticalOpPastJoinOp`

In [2]:
from finn.transformation.base import Transformation
import onnx
from finn.transformation.general import SortGraph
from finn.core.modelwrapper import ModelWrapper
from finn.util.visualization import showInNetron

In [3]:
class MoveIdenticalOpPastJoinOp(Transformation):
    """
    Move identical operations on different branches past the common join node.

    The transformation should be applied when the identical operations change
    the data layout; for linear operations see MoveLinearPastEltwiseAdd.
    Specifically, this transformation matches and rewrites the patterns
        f(x) + f(y) -> f(x + y)
        f(x) - f(y) -> f(x - y)
    where the join op types come from `join_node_list` and f(.) from
    `identical_op_list` (currently only 'Transpose' is expected).

    Two producers only count as identical if they have the same op_type AND the
    same attributes (e.g. the same Transpose `perm`), since f(x) + g(y) != f(x + y).
    A match is also skipped when either producer's output has another consumer
    or is a graph output, because rewiring that producer would change what the
    other consumer sees.
    """

    def __init__(self, identical_op_list, join_node_list):
        super().__init__()
        # op_types that may be moved past the join (f in the patterns above)
        self.ops_to_move = identical_op_list
        # op_types of the join nodes (e.g. "Add")
        self.join_node_op = join_node_list

    @staticmethod
    def _have_identical_attributes(node0, node1):
        """Return True if both nodes carry the same attributes, regardless of order."""
        attrs0 = {attr.name: attr for attr in node0.attribute}
        attrs1 = {attr.name: attr for attr in node1.attribute}
        return attrs0 == attrs1

    @staticmethod
    def _has_single_consumer(model, node):
        """Return True if the node's single output has exactly one consumer and is
        not a graph output, i.e. the node can be moved without affecting others."""
        if len(node.output) != 1:
            return False
        out_name = node.output[0]
        if out_name in {vi.name for vi in model.graph.output}:
            return False
        return len(model.find_consumers(out_name) or []) == 1

    def move_node(self, model, n, prod0, prod1):
        """Rewrite f(x) (+) f(y) into f(x (+) y) in place.

        The join node `n` is rewired to consume the inputs of prod0 and prod1.
        Its former first input tensor is reused as the join's new output and
        becomes the input of prod0. prod0 now produces the join's original
        output, and prod1 is removed from the graph.
        """
        identical_op0_in0 = prod0.input[0]
        identical_op1_in0 = prod1.input[0]
        add_in0 = n.input[0]
        add_out = n.output[0]

        # Rewire the join to consume the inputs of the identical ops directly.
        n.input[0] = identical_op0_in0
        n.input[1] = identical_op1_in0

        # add_in0 now carries the join result. The join is shape preserving, so
        # it takes the shape of the join's new first input.
        # NOTE(review): with broadcasting inputs this assumes input 0 already has the full shape.
        new_shape = model.get_tensor_shape(identical_op0_in0)
        # add_in0 must also take over the ONNX elem type and FINN datatype of the
        # join's original output. Otherwise it keeps the (possibly narrower) FINN
        # datatype of the former f(x) output, which is wrong for a sum.
        finn_data_type = model.get_tensor_datatype(add_out)
        value_info_addout = model.get_tensor_valueinfo(add_out)
        if value_info_addout is not None:
            onnx_data_type = value_info_addout.type.tensor_type.elem_type
        else:
            onnx_data_type = onnx.TensorProto.FLOAT

        model.set_tensor_shape(
            tensor_name=add_in0,
            tensor_shape=new_shape,
            dtype=onnx_data_type,
        )
        model.set_tensor_datatype(add_in0, finn_data_type)

        # prod0 now sits behind the join: join -> add_in0 -> prod0 -> add_out
        n.output[0] = add_in0
        prod0.input[0] = add_in0
        prod0.output[0] = add_out

        model.graph.node.remove(prod1)

    def apply(self, model):
        graph_modified = False
        for n in model.graph.node:
            if n.op_type not in self.join_node_op or not model.is_join_node(n):
                continue
            if len(n.input) != 2:
                continue

            prod0 = model.find_producer(n.input[0])
            prod1 = model.find_producer(n.input[1])
            # The join must be preceded by two different producers.
            if prod0 is None or prod1 is None or prod0 == prod1:
                continue
            if prod0.op_type != prod1.op_type or prod0.op_type not in self.ops_to_move:
                continue
            if not self._have_identical_attributes(prod0, prod1):
                continue
            if not (
                self._has_single_consumer(model, prod0)
                and self._has_single_consumer(model, prod1)
            ):
                continue

            self.move_node(model, n, prod0, prod1)
            graph_modified = True
            # graph.node was just mutated (prod1 removed). Stop iterating here and
            # rely on ModelWrapper.transform re-applying this transformation while
            # it reports graph_modified=True.
            break

        if graph_modified:
            model = model.transform(SortGraph(), make_deepcopy=False, cleanup=False)

        return (model, graph_modified)
        
class MoveTransposePastJoinAdd(MoveIdenticalOpPastJoinOp):
    """Rewrite Transpose(x) + Transpose(y) into Transpose(x + y)."""

    def __init__(self):
        super().__init__(identical_op_list=["Transpose"], join_node_list=["Add"])

In [4]:
src_path = "/tmp/test_move_identical_op_past_join_op.onnx"
dst_path = "/tmp/test_move_identical_op_past_join_op_modified.onnx"

# Apply the class-based transformation to the fixture from the first cell and inspect it.
model = ModelWrapper(src_path).transform(MoveTransposePastJoinAdd())
model.save(dst_path)
showInNetron(dst_path)

Stopping http://0.0.0.0:8081
Serving '/tmp/test_move_identical_op_past_join_op_modified.onnx' at http://0.0.0.0:8081


# Tests: numerical equivalence and node reordering

In [5]:
#import pytest
import numpy as np

from onnx import helper as oh
from onnx import TensorProto

from finn.core.modelwrapper import ModelWrapper
# from finn.transformation.streamline.reorder import MoveTransposePastFork
from finn.util.basic import gen_finn_dt_tensor
import finn.core.onnx_exec as oxe

# perm = [0, 3, 1, 2]
def create_model(perm):
    if perm == [0, 3, 1, 2]:
        in_shape = [1, 128, 1, 256]
        out_shape = [1, 256, 128, 1]
    if perm == [0, 2, 3, 1]:
        in_shape = [1, 256, 128, 1]
        out_shape = [1, 128, 1, 256]
    
    Transpose1_node = onnx.helper.make_node(
        "Transpose",
        inputs = ['in_transpose1'],
        outputs = ['out_transpose1'],
        perm = perm
    )

    Transpose2_node = onnx.helper.make_node(
        "Transpose",
        inputs = ['in_transpose2'],
        outputs = ['out_transpose2'],
        perm = perm
    )

    Join1_node = onnx.helper.make_node(
        "Add",
        inputs = ['out_transpose1', 'out_transpose2'],
        outputs = ['out_join1']
    )

    in_transpose1 = onnx.helper.make_tensor_value_info('in_transpose1', onnx.TensorProto.FLOAT, in_shape)
    in_transpose2 = onnx.helper.make_tensor_value_info('in_transpose2', onnx.TensorProto.FLOAT, in_shape)
    out_transpose1 = onnx.helper.make_tensor_value_info('out_transpose1', onnx.TensorProto.FLOAT, out_shape)
    out_transpose2 = onnx.helper.make_tensor_value_info('out_transpose2', onnx.TensorProto.FLOAT, out_shape)
    out_join1 = onnx.helper.make_tensor_value_info('out_join1', onnx.TensorProto.FLOAT, out_shape)

    graph = onnx.helper.make_graph(
        nodes = [
            Transpose1_node,
            Transpose2_node,
            Join1_node
        ],
        name = 'test_graph',
        inputs = [in_transpose1, in_transpose2],
        outputs = [out_join1],
        value_info = [
            out_transpose1,
            out_transpose2,
        ]
    )

    onnx_model = onnx.helper.make_model(graph, producer_name='test_model')
    model = ModelWrapper(onnx_model)

    return model


# Permutation of transpose node
# @pytest.mark.parametrize("perm", [[0, 3, 1, 2], [0, 2, 3, 1]])
def test_move_identical_op_past_join_op(perm):
    """Check MoveTransposePastJoinAdd on the two-branch Transpose -> Add fixture.

    Verifies three things:
    - the rewritten model is numerically equivalent to the original;
    - the Add now consumes the graph inputs directly;
    - a single Transpose is left, producing the graph output.
    """
    model = create_model(perm)

    # Draw independent random data for every input. Identical inputs would not
    # catch a rewrite that wires both Add inputs to the same tensor.
    input_dict = {
        inp.name: gen_finn_dt_tensor(
            model.get_tensor_datatype(inp.name),
            model.get_tensor_shape(inp.name),
        )
        for inp in model.graph.input
    }

    # transform() works on a copy here, so `model` stays the unmodified reference.
    model_transformed = model.transform(MoveTransposePastJoinAdd())

    assert oxe.compare_execution(model, model_transformed, input_dict)

    # Before the rewrite every graph input feeds a Transpose; afterwards it feeds the Add.
    for inp_orig, inp_new in zip(model.graph.input, model_transformed.graph.input):
        assert model.find_consumers(inp_orig.name)[0].op_type == "Transpose"
        assert model_transformed.find_consumers(inp_new.name)[0].op_type == "Add"

    # Exactly one Transpose remains, and it produces the graph output.
    op_types = [node.op_type for node in model_transformed.graph.node]
    assert op_types.count("Transpose") == 1
    out_producer = model_transformed.find_producer(model_transformed.graph.output[0].name)
    assert out_producer.op_type == "Transpose"
     

In [7]:
# Run the test for both Transpose permutations covered by the fixture.
for perm in ([0, 3, 1, 2], [0, 2, 3, 1]):
    test_move_identical_op_past_join_op(perm)