Skip to content

Commit

Permalink
Merge pull request #279 from iblis17/sym-reshape
Browse files Browse the repository at this point in the history
sym-node: implement reshape api as Base
  • Loading branch information
vchuravy committed Sep 25, 2017
2 parents df13ddd + 55ccbfa commit 6d1ba53
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 18 deletions.
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
* `reshape(NDArray, dim...; reverse=false)`
* `Reshape` deprecated.

* `reshape` of SymbolicNode share the same interface with Base
and additional keyword argument (#279).

* `reshape(SymbolicNode, dim; reverse=false, name)`
* `reshape(SymbolicNode, dim...; reverse=false, name)`
* `Reshape` deprecated.

# v0.2.2 (2017.05.14)
* Updated supported version of MXNet to 0.9.4.
* Improved build-system with support for auto-detecting GPU support.
Expand Down
6 changes: 5 additions & 1 deletion src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# reshape (#272)
# NDArray reshape (#272)
@deprecate reshape(arr::NDArray; shape=()) reshape(arr, shape)
@deprecate Reshape(arr::NDArray; shape=()) reshape(arr, shape)

# SymbolicNode reshape (#279)
@deprecate reshape(sym::SymbolicNode; shape=()) reshape(sym, shape)
@deprecate Reshape(sym::SymbolicNode; shape=()) reshape(sym, shape)
121 changes: 104 additions & 17 deletions src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -656,9 +656,107 @@ function save(filename :: AbstractString, node :: SymbolicNode)
@mxcall(:MXSymbolSaveToFile, (MX_handle, char_p), node, filename)
end

import Base: reshape

"""
reshape(sym::SymbolicNode, dim; reverse=false, name)
reshape(sym::SymbolicNode, dim...; reverse=false, name)
Reshape SymbolicNode operator
Some dimensions of the shape can take special values from the set
{0, -1, -2, -3, -4}.
The significance of each is explained below:
- `0` copy this dimension from the input to the output shape.
Example:
- input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2)
- input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4)
- `-1` infers the dimension of the output shape by using the remainder of the
input dimensions keeping the size of the new array same as that of the input
array. At most one dimension of shape can be -1.
Example:
- input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4)
- input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8)
- input shape = (2,3,4), shape=(-1,), output shape = (24,)
- `-2` copy all/remainder of the input dimensions to the output shape.
Example:
- input shape = (2,3,4), shape = (-2,), output shape = (2,3,4)
- input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4)
- input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1)
- `-3` use the product of two consecutive dimensions of the input shape as the
output dimension.
Example:
- input shape = (2,3,4), shape = (-3,4), output shape = (6,4)
- input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20)
- input shape = (2,3,4), shape = (0,-3), output shape = (2,12)
- input shape = (2,3,4), shape = (-3,-2), output shape = (6,4)
- `-4` split one dimension of the input into two dimensions passed subsequent
to -4 in shape (can contain -1).
Example:
- input shape = (2,3,4), shape = (-4,1,2,-2), output shape = (1,2,3,4)
- input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)
If the argument `reverse` is set to `1`, then the special values are inferred
from right to left.
Example:
- with `reverse=false`, for input shape = (10,5,4), shape = (-1,0),
output shape would be (40,5)
- with `reverse=true`, output shape will be (50,4).
"""
reshape{N}(sym::SymbolicNode, dim::NTuple{N, Integer}; kwargs...) =
_reshape(sym, dim; kwargs...)
reshape(sym::SymbolicNode, dim::Integer...; kwargs...) =
_reshape(sym, dim; kwargs...)

@inline function _reshape{N}(sym::SymbolicNode, dim::NTuple{N, Integer};
reverse::Bool=false, name::String="")
op = _get_cached_libmx_op_handle("reshape")
node = _create_atomic_symbol(op.value, ["shape", "reverse"],
[dump_mx_param(dim), dump_mx_param(!reverse)])
name = get!(DEFAULT_NAME_MANAGER, name, "reshape")
_compose!(node, name=name, data=sym)
end

################################################################################
# Atomic SymbolicNode functions dynamically imported from libmxnet
################################################################################
@inline function _create_atomic_symbol(creator::MX_handle, keys::Vector{String},
vals::Vector{String})
ref_sym_hdr = Ref{MX_handle}(C_NULL)
@mxcall(:MXSymbolCreateAtomicSymbol,
(MX_handle, MX_uint, Ptr{char_p}, Ptr{char_p}, Ref{MX_handle}),
creator, length(keys), keys, vals, ref_sym_hdr)
SymbolicNode(MX_SymbolHandle(ref_sym_hdr[]))
end

@inline function _create_atomic_symbol(creator::MX_handle, keys::Vector{String},
vals::Vector{String},
attrs::Dict{Symbol, String})
node = _create_atomic_symbol(creator, keys, vals)
# set attrs
for (k, v) in attrs
set_attr(node, k, v)
end
node
end

function _define_atomic_symbol_creator(name :: String)
handle = _get_libmx_op_handle(name)
f_desc, key_narg = _get_libmx_op_description(name, handle)
Expand Down Expand Up @@ -709,7 +807,7 @@ function _define_atomic_symbol_creator(name :: String)
symbol_kws[k] = v
elseif k == :attrs
if isa(v, Dict)
attrs = convert(Dict{Symbol, AbstractString}, v)
attrs = convert(Dict{Symbol, String}, v)
else
throw(ArgumentError("attrs needs to be a Dictionary"))
end
Expand All @@ -731,24 +829,13 @@ function _define_atomic_symbol_creator(name :: String)
end
end)

local hdr = _get_cached_libmx_op_handle($name)

# create the SymbolicNode
ref_sym_hdr = Ref{MX_handle}()
@mxcall(:MXSymbolCreateAtomicSymbol,
(MX_handle, MX_uint, Ptr{char_p}, Ptr{char_p}, Ref{MX_handle}),
hdr, length(param_keys), param_keys, param_vals, ref_sym_hdr)
sym_hdr = ref_sym_hdr[]
local op = _get_cached_libmx_op_handle($name)
node = _create_atomic_symbol(op.value, param_keys, param_vals, attrs)

node = SymbolicNode(MX_SymbolHandle(sym_hdr))
# generate a new name for the new symbol if user not provided in kwargs
hint = lowercase($name)
name = get!(DEFAULT_NAME_MANAGER, name, hint)

# set attrs
for (k, v) in attrs
set_attr(node, k, v)
end

if length(symbol_kws) == 0
_compose!(node, name, args...)
elseif length(args) == 1
Expand Down Expand Up @@ -778,11 +865,11 @@ macro _import_atomic_symbol_creators()
# XXX: those are operators defined for NDArray, we exclude them here
# because the calling convention for the type signature is not strong
# enough to disambiguate the method for NDArray and SymbolicNode
const ignored_ops = ["_set_value"]
const ignored_ops = ["_set_value", "reshape"] # in lowercase

op_names = _get_libmx_op_names()
func_exprs = map(op_names) do name
if name ignored_ops
if lowercase(name) ignored_ops
expr = _define_atomic_symbol_creator(name)
end
end
Expand Down
93 changes: 93 additions & 0 deletions test/unittest/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,98 @@ function test_functions()
typeof(mx.sum(data)) == mx.SymbolicNode
end

function test_reshape()
info("SymbolicNode::reshape(sym, dim...)")

A = mx.NDArray(collect(1:24))
x = mx.Variable(:x)
y = mx.reshape(x, 2, 3, 4)
e = mx.bind(y, mx.cpu(), Dict(:x => A))
mx.forward(e)
out = e.outputs[1]

@test size(out) == (2, 3, 4)
@test copy(out) == reshape(1:24, 2, 3, 4)

info("SymbolicNode::reshape(sym, dim)")

A = mx.NDArray(collect(1:24))
x = mx.Variable(:x)
y = mx.reshape(x, (2, 3, 4))
e = mx.bind(y, mx.cpu(), Dict(:x => A))
mx.forward(e)
out = e.outputs[1]

@test size(out) == (2, 3, 4)
@test copy(out) == reshape(1:24, 2, 3, 4)

info("SymbolicNode::reshape::reverse")

A = mx.zeros(10, 5, 4)
x = mx.Variable(:x)
y = mx.reshape(x, -1, 0, reverse=true)
e = mx.bind(y, mx.cpu(), Dict(:x => A))
mx.forward(e)
out = e.outputs[1]

@test size(out) == (50, 4)

info("SymbolicNode::reshape::0")

A = mx.zeros(2, 3, 4)
x = mx.Variable(:x)
y = mx.reshape(x, 4, 0, 2)
e = mx.bind(y, mx.cpu(), Dict(:x => A))
mx.forward(e)
out = e.outputs[1]

@test size(out) == (4, 3, 2)

info("SymbolicNode::reshape::-1")

A = mx.zeros(2, 3, 4)
x = mx.Variable(:x)
y = mx.reshape(x, 6, 1, -1)
e = mx.bind(y, mx.cpu(), Dict(:x => A))
mx.forward(e)
out = e.outputs[1]

@test size(out) == (6, 1, 4)

info("SymbolicNode::reshape::-2")

A = mx.zeros(2, 3, 4, 2)
x = mx.Variable(:x)
y = mx.reshape(x, 3, 2, -2)
e = mx.bind(y, mx.cpu(), Dict(:x => A))
mx.forward(e)
out = e.outputs[1]

@test size(out) == (3, 2, 4, 2)

info("SymbolicNode::reshape::-3")

A = mx.zeros(2, 3, 4, 5)
x = mx.Variable(:x)
y = mx.reshape(x, -3, -3)
e = mx.bind(y, mx.cpu(), Dict(:x => A))
mx.forward(e)
out = e.outputs[1]

@test size(out) == (6, 20)

info("SymbolicNode::reshape::-4")

A = mx.zeros(2, 3, 4)
x = mx.Variable(:x)
y = mx.reshape(x, 0, 0, -4, 2, 2)
e = mx.bind(y, mx.cpu(), Dict(:x => A))
mx.forward(e)
out = e.outputs[1]

@test size(out) == (2, 3, 2, 2)
end

function test_dot()
info("SymbolicNode::dot")
x = mx.Variable(:x)
Expand Down Expand Up @@ -164,6 +256,7 @@ end
test_saveload()
test_attrs()
test_functions()
test_reshape()
test_dot()
test_print()
test_misc()
Expand Down

0 comments on commit 6d1ba53

Please sign in to comment.