In [None]:
import numpy as np

from onnx import helper as oh
from onnx import TensorProto

from finn.core.modelwrapper import ModelWrapper
from finn.util.datatype import DataType
import finn.core.onnx_exec as oxe

def create_model(mul_weight=None):
    """Build a one-node ONNX model that multiplies a tensor by a scalar constant.

    The graph computes ``out_mul1 = in1_mul1 * in2_mul1``. ``in1_mul1`` is a
    float32 graph input of shape [1, 29, 128, 1]. ``in2_mul1`` is a float32
    initializer of shape [1].

    Args:
        mul_weight: Optional value for the scalar ``in2_mul1`` initializer.
            It is converted to a float32 array of shape [1]. When omitted, a
            random value is drawn uniformly from the finite float32 range.

    Returns:
        ModelWrapper wrapping the generated ONNX model.
    """
    mul1_node = oh.make_node(
        "Mul",
        inputs=["in1_mul1", "in2_mul1"],
        outputs=["out_mul1"],
    )

    in1_mul1 = oh.make_tensor_value_info("in1_mul1", TensorProto.FLOAT, [1, 29, 128, 1])
    in2_mul1 = oh.make_tensor_value_info("in2_mul1", TensorProto.FLOAT, [1])
    out_mul1 = oh.make_tensor_value_info("out_mul1", TensorProto.FLOAT, [1, 29, 128, 1])

    graph = oh.make_graph(
        nodes=[mul1_node],
        name="test_graph",
        inputs=[in1_mul1],
        outputs=[out_mul1],
        value_info=[in2_mul1],
    )

    onnx_model = oh.make_model(graph, producer_name="test_model")
    model = ModelWrapper(onnx_model)

    if mul_weight is None:
        # np.random.randint cannot take the float32 finfo bounds: they lie far
        # outside the int64 range, so randint raises ValueError. It would also
        # produce an integer array for a FLOAT tensor. Sample a float instead.
        # NOTE(review): values of this magnitude overflow most products to
        # inf; pass an explicit mul_weight if numerical results matter.
        f32 = np.finfo(np.float32)
        mul_weight = np.random.uniform(low=f32.min, high=f32.max, size=[1])

    # Match the declared tensor info: float32 with shape [1].
    model.set_initializer(
        "in2_mul1", np.asarray(mul_weight, dtype=np.float32).reshape([1])
    )

    return model


model = create_model()

# Persist the generated model and open it in Netron for visual inspection.
_onnx_path = "/tmp/test_scalar_mul_to_channelwise_mul.onnx"
model.save(_onnx_path)
from finn.util.visualization import showInNetron
showInNetron(_onnx_path)



In [None]:
from finn.transformation.base import Transformation
from finn.core.modelwrapper import ModelWrapper
import warnings

class ScalarMulToChannelwiseMul(Transformation):
    """Convert Mul nodes with a scalar constant operand into channelwise Muls.

    Work in progress: scalar Mul nodes are detected, but the rewrite into a
    channelwise Mul is not implemented yet. The channel axis (NCHW vs NHWC)
    of the data input is not determined. Until the rewrite exists, the
    transformation leaves the graph unchanged.
    """

    def apply(self, model):
        """Scan ``model`` for Mul nodes whose second input is a scalar constant.

        Args:
            model: ModelWrapper to transform.

        Returns:
            Tuple ``(model, graph_modified)``. ``graph_modified`` is always
            ``False`` for now because no rewrite is performed yet.
        """
        graph = model.graph
        graph_modified = False

        for n in graph.node:
            if n.op_type != "Mul":
                continue

            # NOTE(review): only input[1] is inspected. ONNX Mul is commutative,
            # so the constant operand may also sit at input[0].
            mul_weight = model.get_initializer(n.input[1])
            if mul_weight is None:
                # Presumably get_initializer returns None for a dynamic
                # (non-initializer) tensor — confirm against ModelWrapper.
                warnings.warn("Mul param is not constant, skipping")
                continue

            # A scalar here means every dimension is 1, e.g. shape [1] or ().
            # The original read `.name`, which numpy arrays do not have.
            is_scalar = all(dim == 1 for dim in np.shape(mul_weight))
            if not is_scalar:
                continue

            # TODO: rewrite this scalar Mul into a channelwise Mul. This needs
            # the channel axis of n.input[0] (NCHW vs NHWC). Options:
            #   1) infer the data layout, or
            #   2) trace consumers until a MultiThreshold node reveals the
            #      data_layout, keeping track of Transpose nodes on the way.

        return (model, graph_modified)
    