Skip to content

Commit

Permalink
Handle broadcasting in load_proto
Browse files Browse the repository at this point in the history
Closes #118
  • Loading branch information
malmaud committed Jan 30, 2017
1 parent 93b8431 commit 94303d4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,11 +553,20 @@ function load_proto(tensor::tensorflow.TensorProto)
val = reinterpret(eltype(val), tensor.tensor_content)
end
end
if length(val) == 0
if length(val) == 0 && length(dim) == 0
zeros(eltype(val),0)
elseif length(dim) == 0
val[1]
else
# https://www.tensorflow.org/api_docs/python/constant_op/constant_value_tensors#constant
if length(val) < prod(dim)
last_val = val[end]
original_length = length(val)
resize!(val, prod(dim))
for i in (original_length+1):length(val)
val[i] = last_val
end
end
reshape(val, dim) |> convert_major_order
end
end
Expand Down
1 change: 1 addition & 0 deletions test/proto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ let
t = get_def(constant(val)).attr["value"].tensor
@test all(load_proto(t) .== map(Vector{UInt8}, val))
end

0 comments on commit 94303d4

Please sign in to comment.