Skip to content

Commit

Permalink
Merge pull request #272 from iblis17/issue-272
Browse files Browse the repository at this point in the history
reshape on NDArray
  • Loading branch information
vchuravy committed Sep 24, 2017
2 parents bbd0e66 + 3d4adfb commit b68ca2e
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 7 deletions.
9 changes: 9 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# v0.3.0 (TBD)

## API Changes

* `reshape` of NDArray shares the same interface with Base (#272).
* `reshape(NDArray, dim; reverse=false)`
* `reshape(NDArray, dim...; reverse=false)`
* `Reshape` deprecated.

# v0.2.2 (2017.05.14)
* Updated supported version of MXNet to 0.9.4.
* Improved build-system with support for auto-detecting GPU support.
Expand Down
2 changes: 2 additions & 0 deletions src/MXNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ include("visualize.jl")

include("nn-factory.jl")

include("deprecated.jl")

end # mx

end # module MXNet
3 changes: 3 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# reshape (#272)
# The old API took the target shape as a `shape` keyword argument
# (and also exposed a capitalized `Reshape`).  Both forms now forward
# to the Base-style positional `reshape(arr, shape)` introduced in #272.
@deprecate reshape(arr::NDArray; shape=()) reshape(arr, shape)
@deprecate Reshape(arr::NDArray; shape=()) reshape(arr, shape)
32 changes: 29 additions & 3 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,31 @@ function save(filename::String, data::Dict{Base.Symbol,NDArray})
filename, length(names), arrays, names)
end

import Base: reshape

"""
    reshape(arr::NDArray, dim...; reverse=false)
    reshape(arr::NDArray, dim; reverse=false)

Reshape `arr` to the target shape `dim`, given either as a single tuple
or as separate integer arguments, mirroring the two call forms of
`Base.reshape`.

The shape is forwarded to libmxnet's `reshape` operator, so the special
values accepted there may appear in `dim` (e.g. `0` keeps the matching
input dimension and `-1` infers the remaining size — see the libmxnet
reshape operator documentation to confirm the full set).  The `reverse`
keyword selects from which end those special values are matched against
`arr`'s shape; the flag is inverted before being passed down (see the
note inside `_reshape`).
"""
reshape{N}(arr::NDArray, dim::NTuple{N, Integer}; reverse::Bool=false) =
  _reshape(arr, dim, reverse)
reshape{N}(arr::NDArray, dim::Vararg{Integer, N}; reverse::Bool=false) =
  _reshape(arr, dim, reverse)

# Low-level worker for `reshape`: invokes libmxnet's imperative `reshape`
# operator directly through the C API and wraps the resulting handle.
@inline function _reshape{N}(arr::NDArray, dim::NTuple{N, Integer}, reverse::Bool)
  op_handle = _get_cached_libmx_op_handle("reshape")
  n_output  = Ref(Cint(0))                 # out: number of outputs produced
  hdls_ref  = Ref{Ptr{MX_handle}}(C_NULL)  # out: C array of output handles
  # MXImperativeInvoke(op, num_inputs, inputs, num_outputs, outputs,
  #                    num_params, param_keys, param_vals)
  @mxcall(:MXImperativeInvoke,
          (MX_handle, Cint, Ptr{MX_handle}, Ref{Cint}, Ref{Ptr{MX_handle}},
           Cint, char_pp, char_pp),
          op_handle, 1, [arr.handle], n_output, hdls_ref,
          2, ["shape", "reverse"], [dump_mx_param(dim), dump_mx_param(!reverse)])
  # not a typo ^^^^^^^^
  # NOTE(review): the `reverse` flag is deliberately negated before being
  # handed to libmxnet — presumably to bridge libmxnet's row-major
  # dimension ordering with Julia's column-major convention; confirm
  # against the libmxnet reshape operator docs.
  @assert n_output[] == 1   # reshape always yields exactly one NDArray
  NDArray(MX_NDArrayHandle(unsafe_load(hdls_ref[], 1)))
end

################################################################################
# NDArray functions dynamically imported from libmxnet
################################################################################
Expand Down Expand Up @@ -993,7 +1018,6 @@ Upon calling, the output arguments will be automatically initialized with empty
Those functions always return the output arguments. If there is only one output (the typical situation), that
object (`NDArray`) is returned. Otherwise, a tuple containing all the outputs will be returned.
"""

function _get_ndarray_function_def(name :: String)
func_name = Symbol(name)

Expand Down Expand Up @@ -1076,7 +1100,9 @@ function _get_ndarray_function_def(name :: String)
end

macro _import_ndarray_functions()
names = _get_libmx_op_names()
black_list = ["reshape"] # do not import these funcs
names = filter(n -> ∉(lowercase(n), black_list), _get_libmx_op_names())

func_exprs = map(names) do name
op_handle = _get_libmx_op_handle(name)

Expand All @@ -1086,7 +1112,7 @@ macro _import_ndarray_functions()
func_name = Symbol(name)
expr = quote
# TODO the explicit exclusion of take will no longer be necessary when it is removed from Base
$((isdefined(Base, func_name) && func_name ≠ :take)? :(import Base.$func_name) : :())
$((isdefined(Base, func_name) && func_name ≠ :take) ? :(import Base.$func_name) : :())
$func_def
@doc $desc ->
$func_def2
Expand Down
31 changes: 27 additions & 4 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ function test_plus()
scalar_large = 1e8
@test reldiff(t4 + scalar_small, copy(a4 .+ scalar_small)) < thresh
@test reldiff(t4 + scalar_large, copy(a4 .+ scalar_large)) < thresh

t5 = zeros(Float64, dims)
a5 = copy(t5, mx.cpu())
scalar_small = 1e-8
Expand Down Expand Up @@ -169,7 +169,7 @@ function test_minus()
scalar_large = 1e8
@test reldiff(t4 - scalar_small, copy(a4 .- scalar_small)) < thresh
@test reldiff(t4 - scalar_large, copy(a4 .- scalar_large)) < thresh

t5 = zeros(Float64, dims)
a5 = copy(t5, mx.cpu())
scalar_small = 1e-8
Expand Down Expand Up @@ -213,7 +213,7 @@ function test_mul()
scalar_large = 1e8
@test reldiff(t4 * scalar_small, copy(a4 .* scalar_small)) < thresh
@test reldiff(t4 * scalar_large, copy(a4 .* scalar_large)) < thresh

t5, a5 = rand_tensors(Float64, dims)
scalar_small = 1e-8
scalar_large = 1e8
Expand Down Expand Up @@ -254,7 +254,7 @@ function test_div()
scalar_large = 1e8
@test reldiff(t4 / scalar_small, copy(a4 ./ scalar_small)) < thresh
@test reldiff(t4 / scalar_large, copy(a4 ./ scalar_large)) < thresh

t5, a5 = rand_tensors(Float64, dims)
scalar_small = 1e-8
scalar_large = 1e8
Expand Down Expand Up @@ -391,6 +391,28 @@ function test_eltype()
end
end

function test_reshape()
  info("NDArray::reshape")
  src = rand(2, 3, 4)

  # Both the vararg and the tuple call forms must produce the same result.
  for reshaped in (reshape(mx.NDArray(src), 4, 3, 2),
                   reshape(mx.NDArray(src), (4, 3, 2)))
    @test size(reshaped) == (4, 3, 2)
    @test copy(reshaped)[3, 1, 1] == src[1, 2, 1]
  end

  info("NDArray::reshape::reverse")
  blank = mx.zeros(10, 5, 4)

  # Special dims (-1 = infer, 0 = copy) match from opposite ends
  # depending on the `reverse` keyword.
  @test size(reshape(blank, -1, 0)) == (40, 5)
  @test size(reshape(blank, -1, 0, reverse=true)) == (50, 4)
end

function test_kwargs()
info("NDArray::kwargs")
dims1 = (2,3,4)
Expand Down Expand Up @@ -421,6 +443,7 @@ end
test_eltype()
test_nd_as_jl()
test_dot()
test_reshape()
test_kwargs()
end

Expand Down

0 comments on commit b68ca2e

Please sign in to comment.