Skip to content

Commit

Permalink
ndarray: remap expand_dims (#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored and pluskid committed Dec 21, 2017
1 parent bfeba81 commit 3787895
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 18 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* `NDArray`
* `context()`
* `empty()`
* `expand_dims()`

* `SymbolicNode`
* `Variable`
Expand Down
3 changes: 2 additions & 1 deletion src/MXNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ export SymbolicNode,
# ndarray.jl
export NDArray,
context,
empty
empty,
expand_dims

# executor.jl
export Executor,
Expand Down
39 changes: 36 additions & 3 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,8 @@ end
# Mapping NDArray functions to Base-like API
################################################################################

const _mxsig = Dict{Symbol,Expr}()
const _ndsig = Dict{Symbol,Expr}()
const _nddoc = Dict{Symbol,Any}()

function _autoimport(name::Symbol, sig::Expr)
if name == :broadcast_
Expand Down Expand Up @@ -1074,6 +1075,9 @@ _broadcast_target(sig::Expr) = sig.args[2].args[].args[end]
Generate docstring from function signature
"""
function _docsig(fname::Symbol, sig::Expr)
s = get(_nddoc, fname, "")
!isempty(s) && return s

if fname !== :broadcast_
" $sig"
else
Expand Down Expand Up @@ -1141,14 +1145,14 @@ macro _remap(sig::Expr, imp::Expr)
end

macro _remap(sig::Expr, imp::Symbol)
imp = _mxsig[imp]
imp = _ndsig[imp]

esc(quote
@_remap($sig, $imp)
end)
end

_mxsig[:reshape] = :(reshape(arr; shape = dim, reverse = !reverse))
_ndsig[:reshape] = :(reshape(arr; shape = dim, reverse = !reverse))
@_remap reshape(arr::NDArray, dim...; reverse = false) reshape
@_remap reshape(arr::NDArray, dim; reverse = false) reshape

Expand All @@ -1175,6 +1179,34 @@ _mxsig[:reshape] = :(reshape(arr; shape = dim, reverse = !reverse))
@_remap prod(arr::NDArray) prod(arr)
@_remap prod(arr::NDArray, dims) prod(arr; axis = 0 .- dims, keepdims = true)

_nddoc[:expand_dims] =
"""
expand_dims(x::NDArray, dim)
Insert a new axis into `dim`.
```julia
julia> x
4 mx.NDArray{Float64,1} @ CPU0:
1.0
2.0
3.0
4.0
julia> mx.expand_dims(x, 1)
1×4 mx.NDArray{Float64,2} @ CPU0:
1.0 2.0 3.0 4.0
julia> mx.expand_dims(x, 2)
4×1 mx.NDArray{Float64,2} @ CPU0:
1.0
2.0
3.0
4.0
```
"""
@_remap expand_dims(x::NDArray, dim) expand_dims(x; axis = -dim)

# trigonometric functions, remap to keep consistent with Base
@_remap broadcast_(::typeof(sin), x::NDArray) sin(x)
@_remap broadcast_(::typeof(cos), x::NDArray) cos(x)
Expand Down Expand Up @@ -1318,6 +1350,7 @@ const _op_import_bl = [ # import black list; do not import these funcs
"_full", # we already have `mx.fill`
"_ones", # we already have `mx.ones`
"_zeros", # we already have `mx.zeros`
"expand_dims",

# arithmetic
"_plus",
Expand Down
55 changes: 41 additions & 14 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,27 +779,53 @@ function test_eltype()
end

function test_reshape()
info("NDArray::reshape")
A = rand(2, 3, 4)
info("NDArray::reshape")
A = rand(2, 3, 4)

B = reshape(mx.NDArray(A), 4, 3, 2)
@test size(B) == (4, 3, 2)
@test copy(B)[3, 1, 1] == A[1, 2, 1]
B = reshape(NDArray(A), 4, 3, 2)
@test size(B) == (4, 3, 2)
@test copy(B)[3, 1, 1] == A[1, 2, 1]

C = reshape(mx.NDArray(A), (4, 3, 2))
@test size(C) == (4, 3, 2)
@test copy(C)[3, 1, 1] == A[1, 2, 1]
C = reshape(NDArray(A), (4, 3, 2))
@test size(C) == (4, 3, 2)
@test copy(C)[3, 1, 1] == A[1, 2, 1]

info("NDArray::reshape::reverse")
A = mx.zeros(10, 5, 4)
info("NDArray::reshape::reverse")
A = mx.zeros(10, 5, 4)

B = reshape(A, -1, 0)
@test size(B) == (40, 5)
B = reshape(A, -1, 0)
@test size(B) == (40, 5)

C = reshape(A, -1, 0, reverse=true)
@test size(C) == (50, 4)
C = reshape(A, -1, 0, reverse=true)
@test size(C) == (50, 4)
end

function test_expand_dims()
info("NDArray::expand_dims")
let A = [1, 2, 3, 4], x = NDArray(A)
@test size(x) == (4,)

y = expand_dims(x, 1)
@test size(y) == (1, 4)

y = expand_dims(x, 2)
@test size(y) == (4, 1)
end

let A = [1 2; 3 4; 5 6], x = NDArray(A)
@test size(x) == (3, 2)

y = expand_dims(x, 1)
@test size(y) == (1, 3, 2)

y = expand_dims(x, 2)
@test size(y) == (3, 1, 2)

y = expand_dims(x, 3)
@test size(y) == (3, 2, 1)
end
end # test_expand_dims

function test_sum()
info("NDArray::sum")

Expand Down Expand Up @@ -1025,6 +1051,7 @@ end # function test_hyperbolic
test_nd_as_jl()
test_dot()
test_reshape()
test_expand_dims()
test_sum()
test_mean()
test_maximum()
Expand Down

0 comments on commit 3787895

Please sign in to comment.