In [None]:
push!(LOAD_PATH, ".");
using Knet; using KnetONNX;

┌ Info: Recompiling stale cache file /Users/egeersu/.julia/compiled/v1.2/KnetONNX.ji for KnetONNX [top-level]
└ @ Base loading.jl:1240


  Updating registry at `~/.julia/registries/General`
  Updating git-repo `https://github.com/JuliaRegistries/General.git`
[2K[?25hFetching: [>                                        ]  0.0 %

In [None]:
# Parse MLP.onnx into a KnetONNX graph `g` (used by every cell below) and display it
# with PrintGraph.
g = ONNXtoGraph("MLP.onnx");
PrintGraph(g)

In [None]:
"""
One executable step of the converted graph: a callable `layer` plus the names of the
tensors it reads and writes in `KnetModel.tensors`.
"""
struct ModelLayer
    inputs  # names of the tensors passed to `layer`, in call order (presumably Strings)
    layer   # callable applied to the gathered inputs (e.g. a KnetLayers layer or AddLayer)
    outputs # names under which the layer's result(s) are stored in the tensor dict
end

"""
    ModelLayer(node, g)

Build a `ModelLayer` for one graph `node` by asking `node_to_layer` for its
`(inputs, layer, outputs)` triple and storing those three pieces as the struct fields.
"""
function ModelLayer(node, g)
    in_names, callable, out_names = node_to_layer(node, g)
    return ModelLayer(in_names, callable, out_names)
end

"""
    get_ModelLayers(g)

Convert every node of `g`, in graph order, into a `ModelLayer`; the result is a
`Vector{Any}`.
"""
get_ModelLayers(g) = Any[ModelLayer(node, g) for node in g.node]

In [None]:
# graph node, graph -> (input tensor names, layer, output tensor names)
"""
    node_to_layer(node, g)

Convert one ONNX graph `node` into the `(args, layer, outputs)` triple consumed by the
`ModelLayer` constructor, choosing the converter from `node.op_type`.

Throws an `ArgumentError` for operator types that have no converter yet.
"""
function node_to_layer(node, g)
    op = node.op_type
    if op == "Gemm"
        return node_to_gemm(node, g)
    elseif op == "Add"
        return node_to_add(node, g)
    end
    # Fail loudly here: falling through would return `nothing`, which only surfaces later
    # as an opaque `iterate(::Nothing)` error while `ModelLayer` destructures the result.
    throw(ArgumentError("unsupported ONNX op_type \"$op\"; only Gemm and Add are implemented"))
end

In [None]:
# returns (names of tensors used for forward pass, KnetLayer, output tensor names)
"""
    node_to_gemm(node, g)

Convert an ONNX `Gemm` node into a `KnetLayers.Linear` layer whose weight and bias are
taken from `g.initializer`.

Returns `(args, layer, outputs)`:
- `args` holds only the data-input tensor name (`node.input[1]`). The weight and bias are
  stored inside `layer`, so they are never looked up in `model.tensors`.
- `outputs` is `node.output`.
"""
function node_to_gemm(node, g)
    data_name = node.input[1]
    w_name = node.input[2]
    # NOTE(review): ONNX allows Gemm without a bias input; this requires it to be present.
    b_name = node.input[3]
    w = g.initializer[w_name]
    b = g.initializer[b_name]

    # Start from a throw-away 1→1 Linear layer and overwrite its parameters.
    layer = KnetONNX.KnetLayers.Linear(input=1, output=1)
    layer.bias = b
    # NOTE(review): Gemm's alpha/beta/transA/transB attributes are ignored; the transpose
    # assumes the exporter's weight layout (e.g. transB=1) — TODO confirm per model.
    layer.mult.weight = transpose(w)

    return ([data_name], layer, node.output)
end

In [None]:
"""
    AddLayer()

Stateless layer for the ONNX `Add` operator: calling it on `(x, y)` returns their
element-wise sum.
"""
struct AddLayer end

# ONNX `Add` broadcasts its operands, so use `.+`. It equals `+` for same-shaped arrays
# and scalars, and also accepts array/scalar and array/vector pairs, where `+` throws.
# NOTE(review): Julia aligns broadcast dimensions from the first axis while ONNX aligns
# trailing axes; these agree only if tensors are imported with reversed dims — TODO confirm.
(a::AddLayer)(x, y) = x .+ y

"""
    node_to_add(node, g)

Convert an ONNX `Add` node into `(args, layer, outputs)`:
- `args` is `node.input` (both operand names);
- `layer` is a fresh `AddLayer`;
- `outputs` is `node.output`.

`g` is unused; it is accepted so every converter has the same `(node, g)` signature.
"""
node_to_add(node, g) = (node.input, AddLayer(), node.output)

# Quick check: convert the third graph node on its own (requires `g` to have at least
# three nodes).
node3 = g.node[3]
ModelLayer(node3, g)

In [None]:
"""
Executable wrapper around a converted ONNX graph. Calling the model fills `tensors`
layer by layer until the graph outputs are available.
"""
mutable struct KnetModel
    tensors       # Dict: tensor name => current value; `Nothing` (the type) = not computed yet
    model_layers  # ModelLayer per graph node, in graph order
    model_inputs  # names of the graph inputs
    model_outputs # names of the graph outputs
end

"""
    KnetModel(g::KnetONNX.Types.Graph)

Build an executable `KnetModel` from the ONNX graph `g`.

- Every tensor named as a layer input or output starts as `Nothing` (not yet computed).
- Tensors read directly from `g.initializer` (e.g. a constant operand of an `Add`) are
  filled with their value up front. Otherwise no layer would ever produce them, and the
  forward loop could never finish.
"""
function KnetModel(g::KnetONNX.Types.Graph)
    model_layers = get_ModelLayers(g)
    tensors = TensorDict2(model_layers)
    for (name, value) in g.initializer
        # Seed only names that some layer reads. Weights already stored inside a layer
        # (e.g. Gemm) are not keys of `tensors` and are skipped.
        haskey(tensors, name) && (tensors[name] = value)
    end
    model_inputs = [i.name for i in g.input]
    model_outputs = [o.name for o in g.output]
    return KnetModel(tensors, model_layers, model_inputs, model_outputs)
end

"""
    TensorDict2(model_layers)

Return a `Dict` that maps every tensor name read or written by `model_layers` to
`Nothing`, the "not yet computed" marker.
"""
function TensorDict2(model_layers)
    tensors = Dict()
    for ml in model_layers, name in Iterators.flatten((ml.inputs, ml.outputs))
        tensors[name] = Nothing
    end
    return tensors
end

"""
    TensorDict(g::KnetONNX.Types.Graph)

Return a `Dict` that maps every tensor name used by the nodes of `g` to `Nothing`.
Entries for names in `g.initializer` hold the initializer's value instead.
"""
function TensorDict(g::KnetONNX.Types.Graph)
    tensors = Dict()
    for n in g.node, name in Iterators.flatten((n.input, n.output))
        tensors[name] = Nothing
    end
    # Initializers are written last so their concrete values replace the placeholders.
    for (name, value) in g.initializer
        tensors[name] = value
    end
    return tensors
end

In [None]:
# Build the model and inspect its tensor table before any forward pass.
model = KnetModel(g)
model.tensors

In [None]:
# model, layer -> compute forward pass
"""
    forward(km::KnetModel, ml::ModelLayer)

Evaluate a single layer of `km` and store its result(s) in `km.tensors`.

If any of `ml.inputs` has not been computed yet (its entry is still `Nothing`), nothing is
evaluated and the string `"oops!"` is returned, so a later sweep can retry.
"""
function forward(km::KnetModel, ml::ModelLayer)
    # GATHER INPUTS
    # `===` checks for the `Nothing` sentinel by identity, so no custom `==` defined for
    # array/tensor types is ever invoked.
    any(name -> km.tensors[name] === Nothing, ml.inputs) && return "oops!"
    inputs = [km.tensors[name] for name in ml.inputs]

    # FORWARD PASS
    # Splatting covers one- and multi-input layers alike: `f(v...)` with a single element
    # is the same call as `f(v[1])`.
    out = ml.layer(inputs...)

    # SAVE OUTPUTS
    if length(ml.outputs) == 1
        km.tensors[ml.outputs[1]] = out
    else
        # Multi-output layers (RNN etc.): each output name gets its own element of `out`
        # instead of a copy of the whole result.
        # NOTE(review): assumes such layers return one value per output name, in order
        # (e.g. a tuple) — TODO confirm once a multi-output converter exists.
        for (name, value) in zip(ml.outputs, out)
            km.tensors[name] = value
        end
    end
end

"""
    (m::KnetModel)(x)

Run the whole graph on `x` and return the model output(s).

Inputs:
- With one graph input, `x` is stored under that input's name.
- With several, `x` must hold one value per input; `x[i]` is stored under the `i`-th name.

Outputs:
- With one graph output, its value is returned.
- With several, a vector of the output values is returned, in `m.model_outputs` order.

Errors:
- An error is raised if a full sweep over the layers computes nothing new, i.e. some
  tensor can never be produced from what was provided.
"""
function (m::KnetModel)(x)
    # Forget results from any previous call. Otherwise every tensor is already filled,
    # the loop below never runs, and the previous call's output is returned again.
    for ml in m.model_layers, name in ml.outputs
        m.tensors[name] = Nothing
    end

    # REGISTER X
    # Note: `m.tensors[names...] = x` would store x under one tuple key when there are
    # several inputs, leaving the real input names unset.
    if length(m.model_inputs) == 1
        m.tensors[m.model_inputs[1]] = x
    else
        length(x) == length(m.model_inputs) || throw(DimensionMismatch(
            "model has $(length(m.model_inputs)) inputs but got $(length(x)) values"))
        for (i, name) in enumerate(m.model_inputs)
            m.tensors[name] = x[i]
        end
    end

    # LOOP UNTIL ALL TENSORS ARE CALCULATED
    pending = count(v -> v === Nothing, values(m.tensors))
    while pending > 0
        for layer in m.model_layers
            forward(m, layer)
        end
        still_pending = count(v -> v === Nothing, values(m.tensors))
        # A sweep that computes nothing new will never finish; fail instead of hanging.
        if still_pending >= pending
            error("KnetModel: $still_pending tensor(s) can never be computed; ",
                  "check that every graph input and constant is provided")
        end
        pending = still_pending
    end

    # RETURN MODEL OUTPUTS
    if length(m.model_outputs) == 1
        return m.tensors[m.model_outputs[1]]
    end
    return [m.tensors[name] for name in m.model_outputs]
end

In [None]:
# Rebuild the model so `tensors` starts from a freshly constructed state before the
# forward pass in the next cell.
model = KnetModel(g)
model.tensors

In [None]:
# Run the model on an all-ones 100×50 input.
# NOTE(review): assumes 100×50 matches the layout the first Gemm expects — TODO confirm.
dummy_input = ones(100,50);
model(dummy_input)

In [None]:
# Inspect every named tensor after the forward pass.
model.tensors