In [1]:
import numpy as np
from finn.util.visualization import showInNetron
from finn.core.modelwrapper import ModelWrapper
import onnx


def generate_model(perm, default_data_layout):
    """Build, save and display a two-branch Transpose -> MultiThreshold -> Add graph.

    Each of the two graph inputs is fed through a Transpose node using ``perm``
    and then through a MultiThreshold node; both branches are summed by an Add
    node. The model is written to ``/tmp/test_move_transpose_past_mt.onnx`` and
    opened with ``showInNetron``.

    Parameters
    ----------
    perm : list of int
        Permutation of the Transpose nodes. Only ``[0, 3, 1, 2]`` and
        ``[0, 2, 3, 1]`` are supported.
    default_data_layout : bool
        If ``True`` and the layout after the Transpose is ``'NCHW'``, the
        ``data_layout`` attribute of the MultiThreshold nodes is left unset.
        In every other case the attribute is set explicitly.

    Returns
    -------
    ModelWrapper
        The generated model, with the threshold initializers already set.

    Raises
    ------
    ValueError
        If ``perm`` is not one of the two supported permutations.
    """
    # `data_layout` names the layout of the tensor *after* the Transpose, i.e.
    # the layout seen by the MultiThreshold nodes.
    if perm == [0, 3, 1, 2]:
        in_shape = [1, 128, 1, 256]
        out_shape = [1, 256, 128, 1]
        data_layout = 'NCHW'
    elif perm == [0, 2, 3, 1]:
        in_shape = [1, 256, 128, 1]
        out_shape = [1, 128, 1, 256]
        data_layout = 'NHWC'
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError on `in_shape`; fail early with a clear message.
        raise ValueError(
            "Unsupported perm %r; expected [0, 3, 1, 2] or [0, 2, 3, 1]" % (perm,)
        )

    Transpose1_node = onnx.helper.make_node(
        "Transpose",
        inputs = ['in_transpose1'],
        outputs = ['out_transpose1'],
        perm = perm
    )

    Transpose2_node = onnx.helper.make_node(
        "Transpose",
        inputs = ['in_transpose2'],
        outputs = ['out_transpose2'],
        perm = perm
    )

    # Attributes shared by both MultiThreshold nodes. `data_layout` is only
    # omitted when the caller asked for the default and the layout is NCHW.
    mt_attrs = {
        'domain': 'finn.custom_op.general',
        'out_dtype': 'UINT4',
    }
    if not (default_data_layout is True and data_layout == 'NCHW'):
        mt_attrs['data_layout'] = data_layout

    Multithreshold1_node = onnx.helper.make_node(
        "MultiThreshold",
        inputs = ['out_transpose1', 'in2_multithreshold1'],
        outputs = ['out_multithreshold1'],
        **mt_attrs
    )

    Multithreshold2_node = onnx.helper.make_node(
        "MultiThreshold",
        inputs = ['out_transpose2', 'in2_multithreshold2'],
        outputs = ['out_multithreshold2'],
        **mt_attrs
    )

    Add1_node = onnx.helper.make_node(
        "Add",
        inputs = ['out_multithreshold1', 'out_multithreshold2'],
        outputs = ['out_add1']
    )

    # Graph inputs and output
    in_transpose1 = onnx.helper.make_tensor_value_info('in_transpose1', onnx.TensorProto.FLOAT, in_shape)
    in_transpose2 = onnx.helper.make_tensor_value_info('in_transpose2', onnx.TensorProto.FLOAT, in_shape)
    out_add1 = onnx.helper.make_tensor_value_info('out_add1', onnx.TensorProto.FLOAT, out_shape)

    # Intermediate tensors: everything downstream of the Transpose has out_shape
    out_transpose1 = onnx.helper.make_tensor_value_info('out_transpose1', onnx.TensorProto.FLOAT, out_shape)
    out_transpose2 = onnx.helper.make_tensor_value_info('out_transpose2', onnx.TensorProto.FLOAT, out_shape)
    out_multithreshold1 = onnx.helper.make_tensor_value_info('out_multithreshold1', onnx.TensorProto.FLOAT, out_shape)
    out_multithreshold2 = onnx.helper.make_tensor_value_info('out_multithreshold2', onnx.TensorProto.FLOAT, out_shape)

    # Threshold tensors, declared with shape [256, 15]
    in2_multithreshold1 = onnx.helper.make_tensor_value_info('in2_multithreshold1', onnx.TensorProto.FLOAT, [256, 15])
    in2_multithreshold2 = onnx.helper.make_tensor_value_info('in2_multithreshold2', onnx.TensorProto.FLOAT, [256, 15])

    graph = onnx.helper.make_graph(
        nodes = [
            Transpose1_node,
            Transpose2_node,
            Multithreshold1_node,
            Multithreshold2_node,
            Add1_node
        ],
        name = "test_graph",
        inputs = [in_transpose1, in_transpose2],
        outputs = [out_add1],
        value_info = [
            out_transpose1,
            out_transpose2,
            out_multithreshold1,
            out_multithreshold2,
            in2_multithreshold1,
            in2_multithreshold2
        ]
    )

    onnx_model = onnx.helper.make_model(graph, producer_name="test_model")
    model = ModelWrapper(onnx_model)

    # Random integer thresholds in [-1000, 1000), sorted along axis 1 so each
    # row is non-decreasing; the same values are used for both branches.
    mt_weights = np.random.randint(low=-1000, high=1000, size=[256,15])
    mt_weights = np.sort(mt_weights, 1)
    model.set_initializer('in2_multithreshold1', mt_weights)
    model.set_initializer('in2_multithreshold2', mt_weights)

    model.save("/tmp/test_move_transpose_past_mt.onnx")
    showInNetron("/tmp/test_move_transpose_past_mt.onnx")

    return model

In [2]:
# Candidate Transpose permutations; the second entry is the one used here.
perm = [
    [0, 3, 1, 2],
    [0, 2, 3, 1],
]
# Build the reference model with the data_layout attribute set explicitly.
model = generate_model(perm=perm[1], default_data_layout=False)

Serving '/tmp/test_move_transpose_past_mt.onnx' at http://0.0.0.0:8081


In [6]:
# Move transpose past multithreshold
from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
from finn.transformation.general import SortGraph
import finn.core.data_layout as DataLayout

# Prototype of the transformation: reload the generated model and swap every
# Transpose -> MultiThreshold pair on a linear segment.
model = ModelWrapper("/tmp/test_move_transpose_past_mt.onnx")

graph = model.graph
graph_modified = False

# The two Transpose permutations this prototype knows how to move
nhwc_to_nchw = [0, 3, 1, 2]
nchw_to_nhwc = [0, 2, 3, 1]

for n in graph.node:
    print(n.op_type)
    if (n.op_type == 'Transpose' and not model.is_fork_node(n)):
        consumer = model.find_consumer(n.output[0])
        if (consumer is not None and consumer.op_type == 'MultiThreshold' and not model.is_fork_node(consumer)):
            # NOTE: `perm` stays a notebook-level name on purpose; the
            # "Check if order changed" cell further down reads it.
            perm = get_by_name(n.attribute, 'perm', 'name')
            perm = onnx.helper.get_attribute_value(perm)
            mt_inst = getCustomOp(consumer)
            mt_data_layout = mt_inst.get_nodeattr('data_layout')

            # Set new attribute
            if mt_data_layout == 'NCHW':
                if perm == nhwc_to_nchw:
                    # Transpose(NHWC) -> NCHW. As we move the transpose node after the multithreshold node,
                    # the input tensor will be in NHWC format
                    new_data_layout = 'NHWC'
                    new_tensor_data_layout = DataLayout.NHWC
                else:
                    continue

            elif mt_data_layout == 'NHWC':
                if perm == nchw_to_nhwc:
                    # Transpose(NCHW) -> NHWC. As we move the transpose node after the multithreshold node,
                    # the input tensor will be in NCHW format
                    new_data_layout = 'NCHW'
                    new_tensor_data_layout = DataLayout.NCHW
                else:
                    continue

            else:
                # Any other layout is not handled. Without this branch
                # `new_data_layout` would be undefined, or stale from a
                # previous iteration, when used below.
                continue

            # Update the node attribute data_layout
            mt_inst.set_nodeattr('data_layout', new_data_layout)

            # Now we will rewire the nodes accordingly
            transpose_in = n.input[0]
            transpose_out = n.output[0]
            multithreshold_out = consumer.output[0]
            transpose_in_shape = model.get_tensor_shape(transpose_in)

            # First we move the multithreshold node in front of the transpose node
            consumer.input[0] = transpose_in
            consumer.output[0] = transpose_out
            # We must change the shape of transpose_out tensor to match the shape of the input tensor
            model.set_tensor_shape(transpose_out, transpose_in_shape)
            model.set_tensor_layout(transpose_out, new_tensor_data_layout)
            model.set_tensor_layout(transpose_in, new_tensor_data_layout)

            # Finally, we ensure the right tensors are connected to the Transpose node
            n.input[0] = consumer.output[0]
            n.output[0] = multithreshold_out
            graph_modified = True

# We must ensure the graph is sorted, because we reordered the nodes.
model = model.transform(SortGraph())

model.save("/tmp/test_move_transpose_past_mt_modified.onnx")
showInNetron("/tmp/test_move_transpose_past_mt_modified.onnx")
                
                

Transpose
Transpose
MultiThreshold
MultiThreshold
Add
Stopping http://0.0.0.0:8081
Serving '/tmp/test_move_transpose_past_mt_modified.onnx' at http://0.0.0.0:8081


In [35]:
from finn.util.basic import gen_finn_dt_tensor
from finn.core.datatype import DataType
import finn.core.onnx_exec as oxe

# Load the reference model and the transformed model once each
# (the reference model used to be loaded twice from the same file).
model = ModelWrapper("/tmp/test_move_transpose_past_mt.onnx")
model_transformed = ModelWrapper("/tmp/test_move_transpose_past_mt_modified.onnx")

# Create input data
input0_tensor_name = model.graph.input[0].name
input1_tensor_name = model.graph.input[1].name

# Note: it is assumed that both tensors have the same shape and data type,
# so a single random tensor is generated and fed to both inputs.
input_shape = model.get_tensor_shape(input0_tensor_name)
input_dtype = model.get_tensor_datatype(input0_tensor_name)
input_val = gen_finn_dt_tensor(input_dtype, input_shape)
input_dict = {
    input0_tensor_name: input_val,
    input1_tensor_name: input_val,
}

# Execute both models on the same inputs and compare their outputs
is_same = oxe.compare_execution(model, model_transformed, input_dict)
assert is_same, "transformed model output differs from the reference model"

print(is_same)

True


In [38]:
# Check if order changed: the first consumer of each graph input must be a
# different op type before and after the transformation.
node0_input0_model = model.find_consumer(model.graph.input[0].name).op_type
node1_input1_model = model.find_consumer(model.graph.input[1].name).op_type
node0_input0_model_transformed = model_transformed.find_consumer(model_transformed.graph.input[0].name).op_type
node1_input1_model_transformed = model_transformed.find_consumer(model_transformed.graph.input[1].name).op_type
assert node0_input0_model != node0_input0_model_transformed
assert node1_input1_model != node1_input1_model_transformed

# Read the permutation from the reference model instead of relying on a
# `perm` variable left over from an earlier cell.
reference_transpose = model.find_consumer(model.graph.input[0].name)
assert reference_transpose.op_type == 'Transpose'
perm = onnx.helper.get_attribute_value(
    get_by_name(reference_transpose.attribute, 'perm', 'name')
)

# Locate the MultiThreshold nodes as the consumers of the graph inputs rather
# than by position in graph.node, which depends on the sort order.
mt0_node = model_transformed.find_consumer(model_transformed.graph.input[0].name)
mt1_node = model_transformed.find_consumer(model_transformed.graph.input[1].name)
assert mt0_node.op_type == 'MultiThreshold'
assert mt1_node.op_type == 'MultiThreshold'
mt0_input = mt0_node.input[0]
mt1_input = mt1_node.input[0]

# After the move, the MultiThreshold inputs carry the layout of the original
# Transpose input.
if perm == [0, 3, 1, 2]:
    expected_layout = DataLayout.NHWC
elif perm == [0, 2, 3, 1]:
    expected_layout = DataLayout.NCHW
else:
    # Do not silently skip the layout checks for an unexpected permutation
    raise AssertionError("unexpected Transpose perm: %r" % (perm,))
assert model_transformed.get_tensor_layout(mt0_input) == expected_layout
assert model_transformed.get_tensor_layout(mt1_input) == expected_layout

[0, 2, 3, 1]


# Class

In [None]:
class MoveTransposePastMultiThreshold(Transformation):
    """Moves Transpose nodes past MultiThreshold nodes on linear segments
    of the graph.

    A Transpose -> MultiThreshold pair is rewritten to
    MultiThreshold -> Transpose when neither node is a fork node and the
    Transpose permutation matches the MultiThreshold ``data_layout``:

    * ``data_layout == "NCHW"`` with ``perm == [0, 3, 1, 2]``
    * ``data_layout == "NHWC"`` with ``perm == [0, 2, 3, 1]``

    The MultiThreshold ``data_layout`` attribute is switched to the other
    layout, and the tensor layout annotations of its new input and output are
    updated to match. All other pairs are left untouched.
    """

    def apply(self, model):
        """Apply the transformation to ``model``.

        Returns a tuple ``(model, graph_modified)``; ``graph_modified`` is True
        if at least one Transpose node was moved.
        """
        graph = model.graph
        graph_modified = False
        # MultiThreshold data_layout -> (Transpose perm that produces it,
        # new data_layout attribute, new tensor layout annotation)
        layout_swaps = {
            "NCHW": ([0, 3, 1, 2], "NHWC", DataLayout.NHWC),
            "NHWC": ([0, 2, 3, 1], "NCHW", DataLayout.NCHW),
        }
        for n in graph.node:
            if n.op_type != "Transpose" or model.is_fork_node(n):
                continue
            consumer = model.find_consumer(n.output[0])
            if (
                consumer is None
                or consumer.op_type != "MultiThreshold"
                or model.is_fork_node(consumer)
            ):
                continue

            perm_attr = get_by_name(n.attribute, "perm", "name")
            if perm_attr is None:
                # Transpose without an explicit perm: not one of the two
                # layout conversions handled here
                continue
            perm = oh.get_attribute_value(perm_attr)

            mt_inst = getCustomOp(consumer)
            mt_data_layout = mt_inst.get_nodeattr("data_layout")
            if mt_data_layout not in layout_swaps:
                # Unknown layout: skip instead of reusing values left over
                # from a previous iteration (or hitting an unbound name)
                continue
            expected_perm, new_data_layout, new_tensor_data_layout = layout_swaps[
                mt_data_layout
            ]
            if perm != expected_perm:
                continue

            # As we move the transpose node after the multithreshold node,
            # the multithreshold input will be in the layout of the original
            # transpose input. Update the node attribute data_layout.
            mt_inst.set_nodeattr("data_layout", new_data_layout)

            # Now we will rewire the nodes accordingly
            transpose_in = n.input[0]
            transpose_out = n.output[0]
            multithreshold_out = consumer.output[0]
            transpose_in_shape = model.get_tensor_shape(transpose_in)

            # Move the multithreshold node in front of the transpose node
            consumer.input[0] = transpose_in
            consumer.output[0] = transpose_out
            # Change the shape of transpose_out tensor to match the shape
            # of the input tensor
            model.set_tensor_shape(transpose_out, transpose_in_shape)
            model.set_tensor_layout(transpose_out, new_tensor_data_layout)
            model.set_tensor_layout(transpose_in, new_tensor_data_layout)

            # Ensure the right tensors are connected to the Transpose node
            n.input[0] = consumer.output[0]
            n.output[0] = multithreshold_out
            graph_modified = True

        # We must ensure the graph is sorted, because we reordered the nodes
        if graph_modified:
            model = model.transform(SortGraph(), make_deepcopy=False, cleanup=False)

        return (model, graph_modified)

# Test

In [None]:
import pytest

import numpy as np
from onnx import helper as oh
from onnx import TensorProto

from finn.core.modelwrapper import ModelWrapper
import finn.core.data_layout as DataLayout
from finn.transformation.streamline.reorder import MoveTransposePastMultiThreshold
from finn.util.basic import gen_finn_dt_tensor
import finn.core.onnx_exec as oxe

def create_model(permutation, default_data_layout):
    """Build a two-branch Transpose -> MultiThreshold -> Add test model.

    Each of the two graph inputs is fed through a Transpose node using
    ``permutation`` and then through a MultiThreshold node; both branches are
    summed by an Add node.

    Parameters
    ----------
    permutation : list of int
        Permutation of the Transpose nodes. Only ``[0, 3, 1, 2]`` and
        ``[0, 2, 3, 1]`` are supported.
    default_data_layout : bool
        If ``True`` and the layout after the Transpose is ``'NCHW'``, the
        ``data_layout`` attribute of the MultiThreshold nodes is left unset.
        In every other case the attribute is set explicitly.

    Returns
    -------
    ModelWrapper
        The generated model, with the threshold initializers already set.

    Raises
    ------
    ValueError
        If ``permutation`` is not one of the two supported permutations.
    """
    # `data_layout` names the layout of the tensor *after* the Transpose, i.e.
    # the layout seen by the MultiThreshold nodes.
    if permutation == [0, 3, 1, 2]:
        in_shape = [1, 128, 1, 256]
        out_shape = [1, 256, 128, 1]
        data_layout = 'NCHW'
    elif permutation == [0, 2, 3, 1]:
        in_shape = [1, 256, 128, 1]
        out_shape = [1, 128, 1, 256]
        data_layout = 'NHWC'
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError on `in_shape`; fail early with a clear message.
        raise ValueError(
            "Unsupported permutation %r; expected [0, 3, 1, 2] or [0, 2, 3, 1]"
            % (permutation,)
        )

    Transpose1_node = oh.make_node(
        "Transpose",
        inputs = ['in_transpose1'],
        outputs = ['out_transpose1'],
        perm = permutation
    )

    Transpose2_node = oh.make_node(
        "Transpose",
        inputs = ['in_transpose2'],
        outputs = ['out_transpose2'],
        perm = permutation
    )

    # Attributes shared by both MultiThreshold nodes. `data_layout` is only
    # omitted when the caller asked for the default and the layout is NCHW.
    mt_attrs = {
        'domain': 'finn.custom_op.general',
        'out_dtype': 'UINT4',
    }
    if not (default_data_layout is True and data_layout == 'NCHW'):
        mt_attrs['data_layout'] = data_layout

    Multithreshold1_node = oh.make_node(
        "MultiThreshold",
        inputs = ['out_transpose1', 'in2_multithreshold1'],
        outputs = ['out_multithreshold1'],
        **mt_attrs
    )

    Multithreshold2_node = oh.make_node(
        "MultiThreshold",
        inputs = ['out_transpose2', 'in2_multithreshold2'],
        outputs = ['out_multithreshold2'],
        **mt_attrs
    )

    Add1_node = oh.make_node(
        "Add",
        inputs = ['out_multithreshold1', 'out_multithreshold2'],
        outputs = ['out_add1']
    )

    # Graph inputs and output
    in_transpose1 = oh.make_tensor_value_info('in_transpose1', TensorProto.FLOAT, in_shape)
    in_transpose2 = oh.make_tensor_value_info('in_transpose2', TensorProto.FLOAT, in_shape)
    out_add1 = oh.make_tensor_value_info('out_add1', TensorProto.FLOAT, out_shape)

    # Intermediate tensors: everything downstream of the Transpose has out_shape
    out_transpose1 = oh.make_tensor_value_info('out_transpose1', TensorProto.FLOAT, out_shape)
    out_transpose2 = oh.make_tensor_value_info('out_transpose2', TensorProto.FLOAT, out_shape)
    out_multithreshold1 = oh.make_tensor_value_info('out_multithreshold1', TensorProto.FLOAT, out_shape)
    out_multithreshold2 = oh.make_tensor_value_info('out_multithreshold2', TensorProto.FLOAT, out_shape)

    # Threshold tensors, declared with shape [256, 15]
    in2_multithreshold1 = oh.make_tensor_value_info('in2_multithreshold1', TensorProto.FLOAT, [256, 15])
    in2_multithreshold2 = oh.make_tensor_value_info('in2_multithreshold2', TensorProto.FLOAT, [256, 15])

    graph = oh.make_graph(
        nodes = [
            Transpose1_node,
            Transpose2_node,
            Multithreshold1_node,
            Multithreshold2_node,
            Add1_node
        ],
        name = "test_graph",
        inputs = [in_transpose1, in_transpose2],
        outputs = [out_add1],
        value_info = [
            out_transpose1,
            out_transpose2,
            out_multithreshold1,
            out_multithreshold2,
            in2_multithreshold1,
            in2_multithreshold2
        ]
    )

    onnx_model = oh.make_model(graph, producer_name="test_model")
    model = ModelWrapper(onnx_model)

    # Random integer thresholds in [-1000, 1000), sorted along axis 1 so each
    # row is non-decreasing; the same values are used for both branches.
    mt_weights = np.random.randint(low=-1000, high=1000, size=[256,15])
    mt_weights = np.sort(mt_weights, 1)
    model.set_initializer('in2_multithreshold1', mt_weights)
    model.set_initializer('in2_multithreshold2', mt_weights)

    return model


# permutation of transpose node
@pytest.mark.parametrize("perm", [[0, 3, 1, 2], [0, 2, 3, 1]])
# default data layout variable
@pytest.mark.parametrize("default_data_layout", [True, False])
def test_move_transpose_past_multithreshold(perm, default_data_layout):
    """Check that MoveTransposePastMultiThreshold preserves the model output,
    changes the first consumer of each graph input, and annotates the tensors
    around the first two nodes with the expected data layout."""
    model = create_model(perm, default_data_layout)

    # Both graph inputs receive the same random tensor; it is assumed that
    # they share shape and data type.
    input_names = [model.graph.input[idx].name for idx in (0, 1)]
    random_input = gen_finn_dt_tensor(
        model.get_tensor_datatype(input_names[0]),
        model.get_tensor_shape(input_names[0]),
    )
    input_dict = {name: random_input for name in input_names}

    model_transformed = model.transform(MoveTransposePastMultiThreshold())

    # Functional equivalence of the original and the transformed model
    assert oxe.compare_execution(model, model_transformed, input_dict)

    # The op type consuming each graph input must differ after the transform
    for idx in (0, 1):
        op_before = model.find_consumer(model.graph.input[idx].name).op_type
        op_after = model_transformed.find_consumer(
            model_transformed.graph.input[idx].name
        ).op_type
        assert op_before != op_after

    # Inputs, then outputs, of the first two nodes of the transformed graph
    leading_nodes = [model_transformed.graph.node[idx] for idx in (0, 1)]
    checked_tensors = [node.input[0] for node in leading_nodes]
    checked_tensors += [node.output[0] for node in leading_nodes]

    # Check if data_layout is set correctly
    if perm == [0, 3, 1, 2]:
        for tensor_name in checked_tensors:
            assert model_transformed.get_tensor_layout(tensor_name) == DataLayout.NHWC
    if perm == [0, 2, 3, 1]:
        for tensor_name in checked_tensors:
            assert model_transformed.get_tensor_layout(tensor_name) == DataLayout.NCHW
