In [37]:
using KnetOnnx

In [38]:
graph = ONNXtoGraph("model_gather.onnx")
PrintGraph(graph)

model inputs: ["0"]
model outputs: ["4"]
(op1) Constant
	output1: 1
(op2) Gather
	input1: 0
	input2: 1
	output1: 2
(op3) Constant
	output1: 3
(op4) Gather
	input1: 2
	input2: 3
	output1: 4


In [39]:
model = KnetModel(graph);

In [40]:
x1 = randn(3,4)

3×4 Array{Float64,2}:
 0.173932  -0.538634  -0.293712  -0.0894061
 0.386333  -0.458431  -0.201749   0.209467 
 1.03815    0.291993   0.536754   0.38918  

In [41]:
model(x1)

2-element Array{Any,1}:
 Any[[0.173932028981112, 0.38633331075208355, 1.038149444151256], [0.173932028981112, 0.38633331075208355, 1.038149444151256]]
 Any[[0.173932028981112, 0.38633331075208355, 1.038149444151256], [0.173932028981112, 0.38633331075208355, 1.038149444151256]]

In [42]:
model.tensors

Dict{Any,Any} with 5 entries:
  "4" => Any[Any[[0.173932, 0.386333, 1.03815], [0.173932, 0.386333, 1.03815]],…
  "1" => Float32[0.0, 0.0]
  "0" => [0.173932 -0.538634 -0.293712 -0.0894061; 0.386333 -0.458431 -0.201749…
  "2" => Any[[0.173932, 0.386333, 1.03815], [0.173932, 0.386333, 1.03815]]
  "3" => Float32[0.0, 0.0]

In [None]:
"""
    Gather(axis)

Callable layer implementing (a subset of) the ONNX `Gather` operator.

`axis` is meant to hold the gather axis already shifted to Julia's 1-based convention
(ONNX 0 -> Julia 1); `converter_gather` adds 1 to the ONNX attribute for this reason.

NOTE(review): the call method below does not read `axis` yet. Slices are always taken
along the second dimension of `data` (`data[:, i]`). Confirm the intended ONNX-to-Julia
axis mapping before relying on non-default axes.
"""
struct Gather{T}
    axis::T
end

"""
    (g::Gather)(data, indices)

Gather slices `data[:, i]` selected by the zero-based ONNX `indices`.

Indices are shifted to Julia's 1-based convention and converted to `Int32`. ONNX
constants can arrive as floats (e.g. `Float32[0.0, 0.0]`), and float indices cannot be
used for indexing. The result depends on the rank of `indices`:

- rank 0 (scalar): the single slice `data[:, i]`;
- rank 1: a `Vector{Any}` of slices (via `gather_rank1`);
- rank 2: a nested `Vector{Any}` of slices (via `gather_rank2`).

# Throws
- `ArgumentError` if `indices` has rank greater than 2 (not implemented yet).
- `InexactError` if an index is not integer-valued.
"""
function (g::Gather)(data, indices)
    rank = ndims(indices)
    # Fail loudly instead of returning `nothing`, which would only break downstream.
    rank > 2 && throw(ArgumentError(
        "Gather for indices with rank > 2 is not implemented yet."))

    # 0-based ONNX indices -> 1-based Julia integer indices.
    julia_indices = Int32.(indices .+ 1)

    # `x[]` unwraps both plain numbers and 0-dimensional arrays.
    rank == 0 && return data[:, julia_indices[]]
    rank == 1 && return gather_rank1(data, julia_indices)
    return gather_rank2(data, julia_indices)
end


"""
    gather_rank1(data, indices)

Collect the slice `data[:, indices[k]]` for each `k` in `1:size(indices)[1]`, in order.

`indices` must already be 1-based Julia indices. Returns a `Vector{Any}` of slices;
empty `indices` yields `Any[]`.
"""
function gather_rank1(data, indices)
    n = size(indices)[1]
    return Any[data[:, indices[k]] for k in 1:n]
end

"""
    gather_rank2(data, indices)

For a 2-D `indices` (already 1-based), return a nested list whose element `[r][c]` is
the slice `data[:, indices[r, c]]`.

Both the outer and the inner containers are `Vector{Any}`.
"""
function gather_rank2(data, indices)
    nrows, ncols = size(indices)
    return Any[Any[data[:, indices[r, c]] for c in 1:ncols] for r in 1:nrows]
end

"""
    converter_gather(node, g)

Convert an ONNX `Gather` node into an `(args, layer, outs)` triple.

`args` and `outs` are the node's input and output names, passed through unchanged.
`layer` is `KL.Gather(axis)`, where `axis` is the ONNX `axis` attribute plus 1 (Julia
axes are 1-based). The ONNX `axis` attribute is optional and defaults to 0, so a node
without it maps to Julia axis 1.

`g` is not used here (presumably the enclosing graph, kept for a uniform converter
signature — confirm).
"""
function converter_gather(node, g)
    args = node.input
    outs = node.output
    # `axis` may be omitted in the ONNX node (spec default: 0); +1 is for Julia.
    axis = get(node.attribute, :axis, 0) + 1
    layer = KL.Gather(axis)
    return (args, layer, outs)
end