In [29]:
push!(LOAD_PATH, ".");
using Knet; using KnetONNX;

In [30]:
# Parse the ONNX file into a KnetONNX graph and print its model inputs/outputs and nodes.
g = ONNXtoGraph("branch1.onnx");
PrintGraph(g)

model inputs: ["input"]
model outputs: ["7"]
(op1) Gemm
	input1: input
	input2: linear1.weight
	input3: linear1.bias
	output1: 5
(op2) Gemm
	input1: input
	input2: linear2.weight
	input3: linear2.bias
	output1: 6
(op3) Add
	input1: 5
	input2: 6
	output1: 7


In [31]:
"""
    ModelLayer(inputs, layer, outputs)

One executable node of a `KnetModel`.

- `inputs`: names of the tensors the node reads (looked up in the model's tensor table)
- `layer`: callable (a Knet layer) that computes the node's result
- `outputs`: names of the tensors the node produces

The field types are type parameters, so each field keeps its concrete type
instead of being stored as `Any`.
"""
struct ModelLayer{I,L,O}
    inputs::I   # list of tensor names (strings)
    layer::L    # a Knet layer (callable)
    outputs::O  # list of tensor names (strings)
end

"""
    ModelLayer(node, g)

Convert graph node `node` of graph `g` into a `ModelLayer` using
`KnetONNX.node_to_layer`, which yields the input names, the layer and the output names.
"""
function ModelLayer(node, g)
    input_names, knet_layer, output_names = KnetONNX.node_to_layer(node, g)
    return ModelLayer(input_names, knet_layer, output_names)
end

"""
    get_ModelLayers(g)

Convert every node of graph `g`, in graph order, to a `ModelLayer` and return them
as a vector. The comprehension narrows the element type from the actual elements
instead of producing an untyped `Vector{Any}`.
"""
function get_ModelLayers(g)
    return [ModelLayer(node, g) for node in g.node]
end

get_ModelLayers (generic function with 1 method)

In [38]:
# graph node, graph -> ModelLayer
function node_to_add(node, g)
    args = node.input
    outs = node.output[1]
    #return (args, layer, outs)
    args
end



# Try the Add converter on the third graph node (op3, the `Add` node in the listing above).
node3 = g.node[3]
node_to_add(node3, g)

2-element Array{AbstractString,1}:
 "5"
 "6"

In [13]:
"""
    KnetModel(tensors, model_layers, model_inputs, model_outputs)

Executable form of an ONNX graph.

- `tensors`: maps a tensor name to its value; entries not computed yet hold the
  placeholder `Nothing` (the type itself, not the value `nothing`)
- `model_layers`: the `ModelLayer`s to run
- `model_inputs`: names of the graph inputs
- `model_outputs`: names of the graph outputs
"""
mutable struct KnetModel
    tensors::Any
    model_layers::Any
    model_inputs::Any
    model_outputs::Any
end

"""
    KnetModel(g::KnetONNX.Types.Graph)

Build a runnable model from a parsed ONNX graph. The model holds a tensor table
(placeholders plus initializer values), one `ModelLayer` per graph node, and the
names of the graph's inputs and outputs.
"""
function KnetModel(g::KnetONNX.Types.Graph)
    table = TensorDict(g)
    layers = get_ModelLayers(g)
    input_names = [v.name for v in g.input]
    output_names = [v.name for v in g.output]
    return KnetModel(table, layers, input_names, output_names)
end

"""
    TensorDict(g::KnetONNX.Types.Graph)

Build the tensor table for graph `g`.

Every tensor name read or written by a node first gets the placeholder `Nothing`
(the type, not the value `nothing`). Entries from `g.initializer` then overwrite
their placeholders with the stored values. Returns an untyped `Dict`.
"""
function TensorDict(g::KnetONNX.Types.Graph)
    table = Dict()
    # Same insertion order as before: per node, all inputs then all outputs.
    for node in g.node, names in (node.input, node.output), name in names
        table[name] = Nothing
    end
    for (name, value) in g.initializer
        table[name] = value
    end
    return table
end

TensorDict (generic function with 1 method)

In [20]:
# model, layer -> compute forward pass
"""
    forward(km::KnetModel, ml::ModelLayer)

Run one layer of the model if all of its inputs are available.

Every name in `ml.inputs` is looked up in `km.tensors`. If any entry still holds
the `Nothing` placeholder, nothing is computed and `nothing` is returned.
Otherwise the layer is applied to the first input only. The result is stored in
`km.tensors` under `ml.outputs[1]` and returned.
"""
function forward(km::KnetModel, ml::ModelLayer)
    # Not ready yet: some producer of an input has not run. No need to copy anything.
    any(name -> km.tensors[name] === Nothing, ml.inputs) && return nothing

    # Only the first input is passed: node_to_gemm already initialized the weights
    # inside `ml.layer`. TODO: call `ml.layer(inputs...)` once multi-input layers work.
    out = ml.layer(km.tensors[ml.inputs[1]])

    km.tensors[ml.outputs[1]] = out
    return out
end

# Run the whole model on input `x` and return its output(s).
#
# `x` is bound to the first model input; only a single model input is supported.
# The output of every layer is reset to the `Nothing` placeholder. The layers are
# then swept repeatedly with `forward` until no placeholder remains, so they may
# be stored in any order.
#
# Returns the single output tensor, or a vector of output tensors when the graph
# has several outputs. Throws an ErrorException if a sweep fills no new tensor.
# That means some tensor can never be computed; previously this looped forever.
function (m::KnetModel)(x)
    m.tensors[m.model_inputs[1]] = x

    # Clear results from any previous call. Otherwise every tensor is already
    # filled, no layer runs, and the stale outputs of the last call are returned.
    for layer in m.model_layers
        m.tensors[layer.outputs[1]] = Nothing
    end

    pending = count(v -> v === Nothing, values(m.tensors))
    while pending > 0
        for layer in m.model_layers
            forward(m, layer)
        end
        remaining = count(v -> v === Nothing, values(m.tensors))
        # A sweep that fills nothing new will never fill anything: stop.
        remaining < pending ||
            error("KnetModel: $remaining tensor(s) cannot be computed from the given input")
        pending = remaining
    end

    outs = [m.tensors[name] for name in m.model_outputs]
    return length(outs) == 1 ? outs[1] : outs
end

In [25]:
# Build an executable model from the graph and run it on a constant dummy input.
model = KnetModel(g);

# NOTE(review): presumably 40 input features x a batch of 5 columns — confirm
# against the weight shapes stored in branch1.onnx.
dummy_input = ones(40,5);
model(dummy_input)

2×5 Array{Float64,2}:
  0.558039   0.558039   0.558039   0.558039   0.558039
 -0.149698  -0.149698  -0.149698  -0.149698  -0.149698