In [5]:
push!(LOAD_PATH, ".");
using Knet; using KnetONNX;

In [93]:
# Read "branch1.onnx" into a graph object `g` and print its nodes
# (the printed listing follows this cell).
g = ONNXtoGraph("branch1.onnx");
PrintGraph(g)

model inputs: ["input"]
model outputs: ["7"]
(op1) Gemm
	input1: input
	input2: linear1.weight
	input3: linear1.bias
	output1: 5
(op2) Gemm
	input1: input
	input2: linear2.weight
	input3: linear2.bias
	output1: 6
(op3) Add
	input1: 5
	input2: 6
	output1: 7


In [94]:
# One executable step of the converted model.
#   inputs  : names (strings) of the tensors fed to `layer`
#   layer   : the callable that performs the computation (e.g. a KnetLayer)
#   outputs : names (strings) under which the result is stored
struct ModelLayer
    inputs
    layer
    outputs
end

# Convenience constructor: convert graph node `node` of graph `g` with
# `node_to_layer` and wrap the resulting triple in a ModelLayer.
function ModelLayer(node, g)
    input_names, knet_layer, output_names = node_to_layer(node, g)
    return ModelLayer(input_names, knet_layer, output_names)
end

# Convert every node of graph `g` (in `g.node` order) into a ModelLayer and
# return them as an untyped vector.
function get_ModelLayers(g)
    return Any[ModelLayer(node, g) for node in g.node]
end

get_ModelLayers (generic function with 1 method)

In [95]:
# (graph node, graph) -> (input tensor names, layer, output tensor names)
# Dispatches on `node.op_type` to the matching converter. The returned triple is
# what the ModelLayer constructor expects.
# Throws ArgumentError for operators that have no converter, instead of silently
# returning `nothing` (which would only fail later when the result is unpacked).
function node_to_layer(node, g)
    op = node.op_type
    op == "Gemm" && return node_to_gemm(node, g)
    op == "Add" && return node_to_add(node, g)
    throw(ArgumentError("unsupported ONNX operator: $op"))
end

node_to_layer (generic function with 1 method)

In [96]:
# Convert a Gemm graph node into the pieces needed to build a ModelLayer.
# Returns (names of tensors used for the forward pass, KnetLayer, output tensor names).
# NOTE(review): the node's Gemm attributes (alpha, beta, transA, transB) are not
# read here, and a bias (third input) is assumed to be present — confirm this
# holds for every model that is converted.
function node_to_gemm(node, g)
    # The layer is a Knet Layer. It is created with placeholder sizes; its
    # parameters are replaced below with the values stored in the graph.
    layer = KnetONNX.KnetLayers.Linear(input=1,output=1)

    # node.input is [data, weight, bias]; weight and bias are looked up by name
    # in g.initializer.
    w_name = node.input[2]
    b_name = node.input[3]
    w = g.initializer[w_name]
    b = g.initializer[b_name]
    layer.bias = b
    # The stored weight is transposed before use — presumably because the
    # exporter keeps it as (out, in); TODO confirm.
    layer.mult.weight = transpose(w)

    # Only the data tensor has to be fetched at forward time. It is returned by
    # NAME so the caller can read it from model.tensors.
    args = [node.input[1]]

    # These three values are what ModelLayer is built from.
    return (args, layer, node.output)
end

node_to_gemm (generic function with 1 method)

In [97]:
# Parameter-free layer for the ONNX Add op: calling it returns the sum of its
# two arguments.
struct AddLayer end

function (a::AddLayer)(x, y)
    return x + y
end

# Convert an Add graph node into (input tensor names, AddLayer, output tensor names).
# All of the node's inputs are forwarded to the layer; `g` is not used.
node_to_add(node, g) = (node.input, AddLayer(), node.output)

# Quick check: convert the third graph node (the Add op in the listing above)
# into a ModelLayer and display it.
node3 = g.node[3]
ModelLayer(node3, g)

ModelLayer(AbstractString["5", "6"], AddLayer(), AbstractString["7"])

In [98]:
# Executable model built from a graph.
#   tensors       : table mapping tensor name -> current value
#   model_layers  : the layers to run
#   model_inputs  : names of the graph's input tensors
#   model_outputs : names of the graph's output tensors
mutable struct KnetModel
    tensors
    model_layers
    model_inputs
    model_outputs
end

# Assemble a KnetModel from a graph: the tensor table, the converted layers and
# the names of the graph's inputs and outputs.
function KnetModel(g::KnetONNX.Types.Graph)
    return KnetModel(
        TensorDict(g),
        get_ModelLayers(g),
        [inp.name for inp in g.input],
        [outp.name for outp in g.output],
    )
end

# Build the tensor table for graph `g`: every tensor name that a node reads or
# writes is registered with the placeholder `Nothing` (the type itself, which
# the forward pass tests against), then every entry of g.initializer overwrites
# its placeholder with the stored value.
function TensorDict(g::KnetONNX.Types.Graph)
    table = Dict()
    for op in g.node, name in Iterators.flatten((op.input, op.output))
        table[name] = Nothing
    end
    for (name, stored) in g.initializer
        table[name] = stored
    end
    return table
end

TensorDict (generic function with 1 method)

In [110]:
# Build the model from the graph and display its tensor table.
model = KnetModel(g)
model.tensors

Dict{Any,Any} with 8 entries:
  "linear1.bias"   => Float32[0.0216016, 0.0483992, -0.0252835, -0.00927702, 0.…
  "linear2.weight" => Float32[0.0591993 -0.0515833 … 5.5477e-5 -0.089789; -0.03…
  "5"              => Nothing
  "linear2.bias"   => Float32[0.0425705, 0.0738719, -0.00153907, 0.064978, 0.06…
  "6"              => Nothing
  "7"              => Nothing
  "input"          => Nothing
  "linear1.weight" => Float32[-0.0625393 0.00734643 … 0.0502301 -0.0454649; -0.…

In [244]:
# model, layer -> compute forward pass
# Runs `ml.layer` on the tensors named in `ml.inputs` and stores the result in
# `km.tensors` under the names in `ml.outputs`.
# Returns "oops!" without running the layer when any input still holds the
# `Nothing` placeholder; otherwise returns the layer's raw result.
function forward(km::KnetModel, ml::ModelLayer)
    # GATHER INPUTS
    # The placeholder is the type `Nothing` itself, so compare by identity.
    for input in ml.inputs
        km.tensors[input] === Nothing && return "oops!"
    end
    inputs = [km.tensors[key] for key in ml.inputs]

    # FORWARD PASS
    # Splatting covers both the single-input and the multi-input case.
    out = ml.layer(inputs...)

    # SAVE OUTPUTS
    # A layer with several outputs (rnn etc.) that returns a tuple with one
    # element per output name gets each element stored under its own name.
    # In every other case the whole result is stored under each output name.
    if out isa Tuple && length(ml.outputs) > 1 && length(out) == length(ml.outputs)
        for (name, value) in zip(ml.outputs, out)
            km.tensors[name] = value
        end
    else
        for name in ml.outputs
            km.tensors[name] = out
        end
    end
    return out
end

# Run the whole model on `x` and return the model output(s).
#   x : the input tensor when the model has one input; otherwise a collection
#       with one entry per name in `m.model_inputs`, in the same order.
# Returns the output tensor when the model has one output, otherwise a vector
# with one tensor per name in `m.model_outputs`.
# Throws ArgumentError when the number of supplied inputs does not match, and
# an ErrorException when some tensor can never be computed.
function (m::KnetModel)(x)
    # RESET
    # Clear results left over from a previous call. Without this, every tensor
    # is already filled and the loop below is skipped, returning stale output.
    for layer in m.model_layers, name in layer.outputs
        m.tensors[name] = Nothing
    end

    # REGISTER X
    # Index by each name individually: `dict[a, b]` would use the tuple (a, b)
    # as a single key.
    if length(m.model_inputs) == 1
        m.tensors[m.model_inputs[1]] = x
    else
        length(x) == length(m.model_inputs) || throw(ArgumentError(
            "expected $(length(m.model_inputs)) inputs, got $(length(x))"))
        for (name, value) in zip(m.model_inputs, x)
            m.tensors[name] = value
        end
    end

    # LOOP UNTIL ALL TENSORS ARE CALCULATED
    # Sweep over all layers until no tensor holds the `Nothing` placeholder.
    # A sweep that fills nothing new means the rest can never be computed, so
    # stop with an error instead of looping forever.
    pending = count(v -> v === Nothing, values(m.tensors))
    while pending > 0
        for layer in m.model_layers
            forward(m, layer)
        end
        remaining = count(v -> v === Nothing, values(m.tensors))
        remaining < pending ||
            error("cannot compute $(remaining) tensor(s): their inputs are never produced")
        pending = remaining
    end

    # RETURN MODEL OUTPUTS
    if length(m.model_outputs) == 1
        return m.tensors[m.model_outputs[1]]
    end
    return [m.tensors[name] for name in m.model_outputs]
end

In [246]:
# Rebuild the model from the graph (fresh tensor table) and display the table.
model = KnetModel(g)
model.tensors

Dict{Any,Any} with 8 entries:
  "linear1.bias"   => Float32[0.0216016, 0.0483992, -0.0252835, -0.00927702, 0.…
  "linear2.weight" => Float32[0.0591993 -0.0515833 … 5.5477e-5 -0.089789; -0.03…
  "5"              => Nothing
  "linear2.bias"   => Float32[0.0425705, 0.0738719, -0.00153907, 0.064978, 0.06…
  "6"              => Nothing
  "7"              => Nothing
  "input"          => Nothing
  "linear1.weight" => Float32[-0.0625393 0.00734643 … 0.0502301 -0.0454649; -0.…

In [247]:
# Run the model on an all-ones 100×50 dummy input and display the result.
dummy_input = ones(100,50);
model(dummy_input)

10×50 Array{Float64,2}:
  0.115668   0.115668   0.115668  …   0.115668   0.115668   0.115668
  0.661428   0.661428   0.661428      0.661428   0.661428   0.661428
  1.26032    1.26032    1.26032       1.26032    1.26032    1.26032 
 -0.799757  -0.799757  -0.799757     -0.799757  -0.799757  -0.799757
 -0.14344   -0.14344   -0.14344      -0.14344   -0.14344   -0.14344 
  0.188061   0.188061   0.188061  …   0.188061   0.188061   0.188061
 -0.815466  -0.815466  -0.815466     -0.815466  -0.815466  -0.815466
  1.73175    1.73175    1.73175       1.73175    1.73175    1.73175 
  0.508684   0.508684   0.508684      0.508684   0.508684   0.508684
 -1.20156   -1.20156   -1.20156      -1.20156   -1.20156   -1.20156 