In [64]:
import numpy as np
from onnx import helper as oh
from onnx import TensorProto

from finn.core.modelwrapper import ModelWrapper
from finn.core.datatype import DataType
from finn.util.basic import gen_finn_dt_tensor
import finn.core.onnx_exec as oxe


# The test graph consists of two stages with identical wiring:
#   MultiThreshold -> two parallel 1x1 Convs -> one MultiThreshold per branch
#   -> Add of both branches.
# Stage 0 starts at MultiThreshold0, stage 1 at MultiThreshold3.


def _multithreshold_node(idx, data_in):
    """Return MultiThreshold node number ``idx`` reading ``data_in``.

    The threshold input is ``in2_multithreshold<idx>``, the output is
    ``out_multithreshold<idx>`` and the output datatype is UINT4.
    """
    return oh.make_node(
        "MultiThreshold",
        inputs=[data_in, "in2_multithreshold{}".format(idx)],
        outputs=["out_multithreshold{}".format(idx)],
        name="MultiThreshold{}".format(idx),
        domain="finn.custom_op.general",
        out_dtype="UINT4",
    )


def _conv1x1_node(idx, data_in):
    """Return 1x1 Conv node number ``idx`` (stride 1, no padding) reading ``data_in``.

    The weight input is ``in2_conv<idx>`` and the output is ``out_conv<idx>``.
    """
    return oh.make_node(
        "Conv",
        inputs=[data_in, "in2_conv{}".format(idx)],
        outputs=["out_conv{}".format(idx)],
        name="Conv{}".format(idx),
        dilations=[1, 1],
        group=1,
        kernel_shape=[1, 1],
        pads=[0, 0, 0, 0],
        strides=[1, 1],
    )


def _add_node(idx, lhs, rhs):
    """Return elementwise Add node number ``idx`` computing ``lhs + rhs``."""
    return oh.make_node(
        "Add",
        inputs=[lhs, rhs],
        outputs=["out_add{}".format(idx)],
        name="Add{}".format(idx),
    )


# Stage 0
MultiThreshold0_node = _multithreshold_node(0, "in1_multithreshold0")
Conv0_node = _conv1x1_node(0, "out_multithreshold0")
Conv1_node = _conv1x1_node(1, "out_multithreshold0")
MultiThreshold1_node = _multithreshold_node(1, "out_conv0")
MultiThreshold2_node = _multithreshold_node(2, "out_conv1")
Add0_node = _add_node(0, "out_multithreshold1", "out_multithreshold2")

# Stage 1
MultiThreshold3_node = _multithreshold_node(3, "out_add0")
Conv2_node = _conv1x1_node(2, "out_multithreshold3")
Conv3_node = _conv1x1_node(3, "out_multithreshold3")
MultiThreshold4_node = _multithreshold_node(4, "out_conv2")
MultiThreshold5_node = _multithreshold_node(5, "out_conv3")
Add1_node = _add_node(1, "out_multithreshold4", "out_multithreshold5")

# Tensor shapes shared by the value infos below.
ACT_SHAPE = [1, 256, 128, 1]     # graph input/output and every intermediate activation
THRESH_SHAPE = [256, 15]         # MultiThreshold threshold tensors
CONV_W_SHAPE = [256, 256, 1, 1]  # 1x1 Conv weight tensors


def _float_vi(tensor_name, shape):
    """Return a FLOAT tensor ValueInfoProto named ``tensor_name`` with ``shape``."""
    return oh.make_tensor_value_info(tensor_name, TensorProto.FLOAT, shape)


# Inputs/outputs (global)
in1_multithreshold0 = _float_vi('in1_multithreshold0', ACT_SHAPE)
out_add1 = _float_vi('out_add1', ACT_SHAPE)

# Initializers (their values are attached further down with set_initializer)
in2_multithreshold0 = _float_vi('in2_multithreshold0', THRESH_SHAPE)
in2_conv0 = _float_vi('in2_conv0', CONV_W_SHAPE)
in2_conv1 = _float_vi('in2_conv1', CONV_W_SHAPE)
in2_multithreshold1 = _float_vi('in2_multithreshold1', THRESH_SHAPE)
in2_multithreshold2 = _float_vi('in2_multithreshold2', THRESH_SHAPE)
in2_multithreshold3 = _float_vi('in2_multithreshold3', THRESH_SHAPE)
in2_conv2 = _float_vi('in2_conv2', CONV_W_SHAPE)
in2_conv3 = _float_vi('in2_conv3', CONV_W_SHAPE)
in2_multithreshold4 = _float_vi('in2_multithreshold4', THRESH_SHAPE)
in2_multithreshold5 = _float_vi('in2_multithreshold5', THRESH_SHAPE)

# Value_infos (intermediate activations)
out_multithreshold0 = _float_vi('out_multithreshold0', ACT_SHAPE)
out_conv0 = _float_vi('out_conv0', ACT_SHAPE)
out_conv1 = _float_vi('out_conv1', ACT_SHAPE)
out_multithreshold1 = _float_vi('out_multithreshold1', ACT_SHAPE)
out_multithreshold2 = _float_vi('out_multithreshold2', ACT_SHAPE)
out_add0 = _float_vi('out_add0', ACT_SHAPE)
out_multithreshold3 = _float_vi('out_multithreshold3', ACT_SHAPE)
out_conv2 = _float_vi('out_conv2', ACT_SHAPE)
out_conv3 = _float_vi('out_conv3', ACT_SHAPE)
out_multithreshold4 = _float_vi('out_multithreshold4', ACT_SHAPE)
out_multithreshold5 = _float_vi('out_multithreshold5', ACT_SHAPE)

# Nodes in topological order: stage 0 (MultiThreshold0 .. Add0), then
# stage 1 (MultiThreshold3 .. Add1).
_graph_nodes = [
    MultiThreshold0_node, Conv0_node, Conv1_node,
    MultiThreshold1_node, MultiThreshold2_node, Add0_node,
    MultiThreshold3_node, Conv2_node, Conv3_node,
    MultiThreshold4_node, MultiThreshold5_node, Add1_node,
]

# Threshold / weight tensors, in declaration order.
_param_infos = [
    in2_multithreshold0, in2_conv0, in2_conv1,
    in2_multithreshold1, in2_multithreshold2, in2_multithreshold3,
    in2_conv2, in2_conv3,
    in2_multithreshold4, in2_multithreshold5,
]

# Intermediate activations, in the order their producing nodes appear.
_activation_infos = [
    out_multithreshold0, out_conv0, out_conv1,
    out_multithreshold1, out_multithreshold2, out_add0,
    out_multithreshold3, out_conv2, out_conv3,
    out_multithreshold4, out_multithreshold5,
]

graph = oh.make_graph(
    nodes=_graph_nodes,
    name="test_graph",
    inputs=[in1_multithreshold0],
    outputs=[out_add1],
    value_info=_param_infos + _activation_infos,
)

onnx_model = oh.make_model(graph, producer_name="test_model")
model = ModelWrapper(onnx_model)

# Annotate every 1x1 Conv weight as INT4. The test cell checks that these
# annotations survive partitioning and unfolding.
for i in range(4):
    model.set_tensor_datatype('in2_conv' + str(i), DataType.INT4)

# MultiThreshold thresholds: 6 nodes x 256 rows x 15 thresholds, sorted
# ascending along the last axis. Cast to float32 so the initializers match the
# FLOAT type declared for them above, and the dtype used for the Conv weights.
mt_weights = np.random.randint(low=-1000, high=1000, size=[6, 256, 15])
mt_weights = np.sort(mt_weights, 2).astype(np.float32)
for i in range(6):
    model.set_initializer('in2_multithreshold' + str(i), mt_weights[i])

# Conv weights drawn from the full INT4 range [-8, 7]. np.random.randint
# excludes `high`, so it must be 8 for the value 7 to ever be generated.
conv_weights = np.random.randint(low=-8, high=8, size=[4, 256, 256, 1, 1]).astype(np.float32)
for i in range(4):
    model.set_initializer('in2_conv' + str(i), conv_weights[i])

model.save("/tmp/test_extend_partition.onnx")
from finn.util.visualization import showInNetron
showInNetron("/tmp/test_extend_partition.onnx")


Stopping http://0.0.0.0:8081
Serving '/tmp/test_extend_partition.onnx' at http://0.0.0.0:8081


In [65]:
## Partition the graph first
from finn.transformation.create_generic_partitions import PartitionFromDict

src_path = "/tmp/test_extend_partition.onnx"
dst_path = "/tmp/test_extend_partition_partitioned.onnx"

# Split the 12-node graph into two GenericPartition nodes. The ranges are
# assumed to be node indices in graph order: 0-5 -> partition 0,
# 6-11 -> partition 1.
partitionings = {0: range(0, 6), 1: range(6, 12)}

model = ModelWrapper(src_path).transform(PartitionFromDict(partitionings))
model.save(dst_path)

showInNetron(dst_path)

Stopping http://0.0.0.0:8081
Serving '/tmp/test_extend_partition_partitioned.onnx' at http://0.0.0.0:8081


# Transformation

In [60]:
from finn.util.basic import get_by_name

model = ModelWrapper("/tmp/test_extend_partition_partitioned.onnx")

# Positions (in model.graph.node) of the GenericPartition nodes to unfold.
partitions_to_expand = [0, 1]

# Enumerate a snapshot of the node list. Nodes are appended to and removed
# from model.graph.node inside the loop, and the indices must keep referring
# to the original node order.
for node_ind, n in enumerate(list(model.graph.node)):
    if n.op_type != 'GenericPartition' or node_ind not in partitions_to_expand:
        continue

    path_to_model = get_by_name(n.attribute, 'model', 'name').s.decode('utf-8')
    model_partition = ModelWrapper(path_to_model)

    # Append nodes
    for partition_node in model_partition.graph.node:
        model.graph.node.append(partition_node)

    # Append value infos
    for vi_name in [x.name for x in model_partition.graph.value_info]:
        model.graph.value_info.append(model_partition.get_tensor_valueinfo(vi_name))

    # Append initializers
    for init in model_partition.graph.initializer:
        model.graph.initializer.append(init)

    # Append tensor annotations, except for the partition's input/output
    # tensors. The parent model retains those after partitioning, so copying
    # them again would create duplicate annotations.
    in_out_names = {x.name for x in model_partition.graph.input}
    in_out_names |= {x.name for x in model_partition.graph.output}
    for a in model_partition.graph.quantization_annotation:
        if a.tensor_name not in in_out_names:
            model.graph.quantization_annotation.append(a)

    # The partition's contents are now inlined; drop the GenericPartition node.
    model.graph.node.remove(n)

from finn.transformation.general import SortGraph
# The inlined nodes were appended at the end of the node list; re-sort the graph.
model = model.transform(SortGraph())

model.save("/tmp/test_extend_partition_undo.onnx")

showInNetron("/tmp/test_extend_partition_undo.onnx")

Stopping http://0.0.0.0:8081
Serving '/tmp/test_extend_partition_undo.onnx' at http://0.0.0.0:8081


In [66]:
from finn.util.basic import get_by_name
from finn.transformation.general import SortGraph

model = ModelWrapper("/tmp/test_extend_partition_partitioned.onnx")
graph = model.graph
partitions_to_expand = [0, 1]

# Initialise the flag up front. Otherwise the `if graph_modified:` check below
# raises NameError when none of the listed partitions exists.
graph_modified = False

# Snapshot {node index: node} of all GenericPartition nodes before the graph
# is edited, so the indices keep referring to the original node order.
partition_node_ind = {ind: n for ind, n in enumerate(graph.node) if n.op_type == 'GenericPartition'}

for k, v in partition_node_ind.items():
    if k not in partitions_to_expand:
        continue

    path_to_model = get_by_name(v.attribute, 'model', 'name').s.decode('utf-8')
    model_partition = ModelWrapper(path_to_model)

    # Append nodes
    for partition_node in model_partition.graph.node:
        graph.node.append(partition_node)

    # Append value infos
    for vi_name in [x.name for x in model_partition.graph.value_info]:
        graph.value_info.append(model_partition.get_tensor_valueinfo(vi_name))

    # Append initializers
    for init in model_partition.graph.initializer:
        graph.initializer.append(init)

    # Append tensor annotations, except for the partition's input/output
    # tensors. The parent model retains those after partitioning, so copying
    # them again would create duplicate annotations.
    in_out_names = {x.name for x in model_partition.graph.input}
    in_out_names |= {x.name for x in model_partition.graph.output}
    for a in model_partition.graph.quantization_annotation:
        if a.tensor_name not in in_out_names:
            graph.quantization_annotation.append(a)

    graph.node.remove(v)
    graph_modified = True

if graph_modified:
    # The inlined nodes were appended at the end of the node list; re-sort.
    model = model.transform(SortGraph())

model.save("/tmp/test_extend_partition_undo.onnx")

showInNetron("/tmp/test_extend_partition_undo.onnx")

Stopping http://0.0.0.0:8081
Serving '/tmp/test_extend_partition_undo.onnx' at http://0.0.0.0:8081


# Test

In [68]:
import finn.core.onnx_exec as oxe

# Reference: the partitioned model.
model = ModelWrapper("/tmp/test_extend_partition_partitioned.onnx")

# Random input matching the declared shape and FINN datatype of the graph input.
input0_tensor_name = model.graph.input[0].name
input_shape = model.get_tensor_shape(input0_tensor_name)
input_dtype = model.get_tensor_datatype(input0_tensor_name)
input_val = gen_finn_dt_tensor(input_dtype, input_shape)
input_dict = {input0_tensor_name: input_val}

# Model under test: the same graph with its partitions unfolded again.
model_transformed = ModelWrapper("/tmp/test_extend_partition_undo.onnx")

# Unfolding must actually remove the GenericPartition nodes...
assert not any(n.op_type == 'GenericPartition' for n in model_transformed.graph.node), \
    "unfolded model still contains GenericPartition nodes"

# ...and must not change the computed result.
assert oxe.compare_execution(model, model_transformed, input_dict)

# Check if data_types are retained: every Conv weight keeps its INT4 annotation.
for n in model_transformed.graph.node:
    if n.op_type == "Conv":
        assert model_transformed.get_tensor_datatype(n.input[1]) == DataType.INT4



In [None]:
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from finn.transformation.base import Transformation
from finn.util.basic import get_by_name
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.general import SortGraph

class UnfoldPartitions(Transformation):
    """Unfolds GenericPartition type nodes by inserting the graph pointed to by
     the model attribute.

     Argument 0: unfolding_index
     * List (or other iterable) of node indices of the GenericPartition nodes
       to unfold. An index is a position in model.graph.node at the time
       apply() is called. A single int is accepted as well. Indices that do
       not point at a GenericPartition node are ignored.
     """

    def __init__(self, unfolding_index):
        super().__init__()
        # Accept a bare index for convenience; always store a list.
        if isinstance(unfolding_index, int):
            unfolding_index = [unfolding_index]
        self.unfolding_index = unfolding_index

    def apply(self, model):
        graph = model.graph
        graph_modified = False
        targets = set(self.unfolding_index)

        # Snapshot the selected partition nodes before editing the graph, so
        # the indices refer to the node order on entry to apply().
        partition_nodes_dict = {
            ind: n
            for ind, n in enumerate(graph.node)
            if n.op_type == 'GenericPartition' and ind in targets
        }

        for v in partition_nodes_dict.values():
            model_attr = get_by_name(v.attribute, 'model', 'name')
            if model_attr is None:
                raise ValueError(
                    "GenericPartition node '%s' has no 'model' attribute" % v.name
                )
            model_partition = ModelWrapper(model_attr.s.decode('utf-8'))
            partition_graph = model_partition.graph

            # Append nodes
            for partition_node in partition_graph.node:
                graph.node.append(partition_node)

            # Append value infos
            for vi_name in [x.name for x in partition_graph.value_info]:
                graph.value_info.append(model_partition.get_tensor_valueinfo(vi_name))

            # Append initializers
            for init in partition_graph.initializer:
                graph.initializer.append(init)

            # Append tensor annotations, except for the input/output tensors,
            # as these will be retained in the 'upper' model after
            # partitioning.
            in_out_names = {x.name for x in partition_graph.input}
            in_out_names |= {x.name for x in partition_graph.output}
            for a in partition_graph.quantization_annotation:
                if a.tensor_name not in in_out_names:
                    graph.quantization_annotation.append(a)

            graph.node.remove(v)
            graph_modified = True

        if graph_modified:
            # The inlined nodes were appended at the end of the node list;
            # re-sort the graph.
            model = model.transform(SortGraph())

        # Every requested partition is unfolded in this single pass, so report
        # that no further pass is needed. ModelWrapper.transform() calls
        # apply() again while it returns True. On such a re-run the indices
        # would refer to the new, re-sorted node order and could unfold
        # GenericPartition nodes that were never requested, e.g. a parallel
        # partition or one nested inside an inlined subgraph.
        return (model, False)
