Skip to content

Commit

Permalink
ndarray: more Base-like APIs (#303)
Browse files Browse the repository at this point in the history
* ndarray: make API of `sum` and `mean` be Base-like

- also fix the axis value mapping
- `mean(arr, axis=0)` is not Julian

* ndarray: Base-like `maximum` and `minimum`

- remove `mx.max`, `mx.min`, `mx.max_axis` and `mx.min_axis`

* ndarray: simple doc while remapping with `@_remap`

* ndarray: more test cases for dim as tuple

* ndarray: remap dot, the elegent way

* ndarray: remap `transpose` and add `permutedims`

* ndarray: docs of _remap

* ndarray: remap `prod`

* util: add _sig_checker

for discovering non-Julian APIs

* travis: add _sig_checker after testing
  • Loading branch information
iblislin authored and pluskid committed Nov 6, 2017
1 parent 9fcab40 commit 4f182ee
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 50 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ script:
- travis_wait 60 ${TRAVIS_DIR}/run_test.sh

after_success:
# See https://github.com/dmlc/MXNet.jl/pull/303#issuecomment-341171774
- julia -e 'using MXNet; mx._sig_checker()'

- source ${TRAVIS_DIR}/run_coverage.sh
- echo $TRAVIS_JULIA_VERSION
- julia -e 'Pkg.add("Documenter")'
Expand Down
32 changes: 30 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

## API Changes

* `reshape` of NDArray share the same interface with Base (#272).
* `reshape` of NDArray shares the same interface with Base (#272).
* `reshape(NDArray, dim; reverse=false)`
* `reshape(NDArray, dim...; reverse=false)`
* `Reshape` deprecated.

* `reshape` of SymbolicNode share the same interface with Base
* `reshape` of SymbolicNode shares the same interface with Base
and additional keyword argument (#279).

* `reshape(SymbolicNode, dim; reverse=false, name)`
Expand All @@ -27,6 +27,34 @@

* `srand!` deprecated, please use `srand` (#282)

* `mean` and `sum` of NDArray share the same interface with Base
and fix the `axis` indexing (#TBD).

* This is a breaking change; no deprecated warning.
* Before: `mean(arr, axis=0)`
* After: `mean(arr, 1)`

* `max` and `min` of NDArray renamed to `maximum` and `minimum` and share the
same interface with Base. The `axis` indexing is fixed, also. (#TBD)

* This is a breaking change; no deprecated warning.
* Before: `mx.max(arr, axis=0)` or `mx.max_axis(arr, axis=0)`
* After: `maximum(arr, 1)`

* `mx.transpose` for high dimension NDArray has been renamed to `permutedims`
and shares the same interface with Base. (#TBD)

* This is a breaking changes; no deprecated warning.
* Before: `mx.transpose(A, axis=[2, 1, 3])`
* After: `permutedims(A, [2, 1, 3])`

* `prod` of `NDArray` shares the same interface with Base and fix
the `axis` indexing. (#TBD).

* This is a breaking change; no deprecated warning.
* Before: `prod(arr, axis=-1)`
* After: `prod(arr, 1)`

# v0.2.2 (2017.05.14)
* Updated supported version of MXNet to 0.9.4.
* Improved build-system with support for auto-detecting GPU support.
Expand Down
148 changes: 111 additions & 37 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ type NDArray
end

function Base.show(io :: IO, arr :: NDArray)
println(io, "$(join(size(arr), "x")) mx.NDArray{$(eltype(arr))} @ $(context(arr)):")
println(io, "$(join(size(arr), "×")) mx.NDArray{$(eltype(arr))} @ $(context(arr)):")
Base.showarray(io, try_get_shared(arr, sync=:read), false, header=false)
end

Expand Down Expand Up @@ -971,31 +971,106 @@ function save(filename::String, data::Dict{Base.Symbol,NDArray})
filename, length(names), arrays, names)
end

import Base: reshape
################################################################################
# Mapping NDArray functions to Base-like API
################################################################################

"""
reshape(arr::NDArray, dim...; reverse=false)
reshape(arr::NDArray, dim; reverse=false)
"""
reshape{N}(arr::NDArray, dim::NTuple{N, Integer}; reverse::Bool=false) =
_reshape(arr, dim, reverse)
reshape{N}(arr::NDArray, dim::Vararg{Integer, N}; reverse::Bool=false) =
_reshape(arr, dim, reverse)
const _mxsig = Dict{Symbol,Expr}()

function _autoimport(name::Symbol)
if isdefined(Base, name)
:(import Base: $name)
else
:()
end
end

macro _remap(sig::Expr, imp::Expr)
fname = sig.args[1]
opname = string(imp.args[1])

@inline function _reshape{N}(arr::NDArray, dim::NTuple{N, Integer}, reverse::Bool)
op_handle = _get_cached_libmx_op_handle("reshape")
n_output = Ref(Cint(0))
hdls_ref = Ref{Ptr{MX_handle}}(C_NULL)
@mxcall(:MXImperativeInvoke,
(MX_handle, Cint, Ptr{MX_handle}, Ref{Cint}, Ref{Ptr{MX_handle}},
Cint, char_pp, char_pp),
op_handle, 1, [arr.handle], n_output, hdls_ref,
2, ["shape", "reverse"], [dump_mx_param(dim), dump_mx_param(!reverse)])
# not a typo ^^^^^^^^
@assert n_output[] == 1
NDArray(MX_NDArrayHandle(unsafe_load(hdls_ref[], 1)))
import_expr = _autoimport(fname)

if isa(imp.args[2], Expr) && imp.args[2].head == :parameters
ndin = imp.args[3:end]
mxargs = imp.args[2].args
else # no keyword arguments
ndin = imp.args[2:end]
mxargs = []
end

mxkeys = map(x -> string(x.args[1]), mxargs)
mxvals = Expr(:vect, map(x -> :(dump_mx_param($(x.args[2]))), mxargs)...)
ndhlds = Expr(:vect, map(x -> :($(x).handle), ndin)...)

func_body = quote
op_handle = _get_cached_libmx_op_handle($opname)
n_output = Ref(Cint(0))
hdls_ref = Ref{Ptr{MX_handle}}(C_NULL)
@mxcall(:MXImperativeInvoke,
(MX_handle,
Cint,
Ptr{MX_handle},
Ref{Cint},
Ref{Ptr{MX_handle}},
Cint,
char_pp,
char_pp),
op_handle,
$(length(ndin)),
$(ndhlds),
n_output,
hdls_ref,
$(length(mxargs)),
$mxkeys,
$mxvals)
NDArray(MX_NDArrayHandle(unsafe_load(hdls_ref[], 1)))
end

docstr = " $sig"
func_def = Expr(:function, sig, func_body)

esc(quote
$import_expr
@doc $docstr ->
$func_def
end)
end

macro _remap(sig::Expr, imp::Symbol)
imp = _mxsig[imp]

esc(quote
@_remap($sig, $imp)
end)
end

_mxsig[:reshape] = :(reshape(arr; shape = dim, reverse = !reverse))
@_remap reshape(arr::NDArray, dim...; reverse = false) reshape
@_remap reshape(arr::NDArray, dim; reverse = false) reshape

@_remap mean(arr::NDArray) mean(arr)
@_remap mean(arr::NDArray, region) mean(arr; axis = 0 .- region, keepdims = true)

@_remap sum(arr::NDArray) sum(arr)
@_remap sum(arr::NDArray, dims) sum(arr; axis = 0 .- dims, keepdims = true)

@_remap maximum(arr::NDArray) max(arr)
@_remap maximum(arr::NDArray, dims) max(arr; axis = 0 .- dims, keepdims = true)

@_remap minimum(arr::NDArray) min(arr)
@_remap minimum(arr::NDArray, dims) min(arr; axis = 0 .- dims, keepdims = true)

# See https://github.com/dmlc/MXNet.jl/issues/55
@_remap dot(x::NDArray, y::NDArray) dot(y, x)

# See https://github.com/dmlc/MXNet.jl/pull/123
@_remap transpose(arr::NDArray) transpose(_only2d(arr))
@_remap permutedims(arr::NDArray, axes) transpose(arr; axes = length(axes) .- tuple(axes...))

@_remap prod(arr::NDArray) prod(arr)
@_remap prod(arr::NDArray, dims) prod(arr; axis = 0 .- dims, keepdims = true)

################################################################################
# NDArray functions dynamically imported from libmxnet
################################################################################
Expand Down Expand Up @@ -1063,19 +1138,6 @@ function _get_ndarray_function_def(name :: String)
args = MX_handle[]
end

# XXX: hacky way of solving the problem that the arguments of `dot` should be swapped
# See https://github.com/dmlc/MXNet.jl/issues/55
if $name == "dot"
args = reverse(args)
end

# XXX: hacky way of solving the semantic difference of the axes parameter in Julia
# and in libmxnet.
# See https://github.com/dmlc/MXNet.jl/pull/123
if $name == "transpose"
kwargs = Any[key != :axes ? (key, arg) : (key, map(i->length(arg)-i, arg)) for (key, arg) in kwargs]
end

if length(output_vars) > 0
output_handles = map((x) -> Base.cconvert(MX_handle, x), output_vars)
# XXX: Julia 0.4 has bug: [Array{MX_handle}] == Array{MX_handle}
Expand Down Expand Up @@ -1123,9 +1185,21 @@ function _get_ndarray_function_def(name :: String)
return func_def, func_def2
end

const _op_import_bl = [ # import black list; do not import these funcs
"mean",
"reshape",
"sum",
"max",
"max_axis",
"min",
"min_axis",
"dot",
"transpose",
"prod",
]

macro _import_ndarray_functions()
black_list = ["reshape"] # do not import these funcs
names = filter(n -> (lowercase(n), black_list), _get_libmx_op_names())
names = filter(n -> (lowercase(n), _op_import_bl), _get_libmx_op_names())

func_exprs = map(names) do name
op_handle = _get_libmx_op_handle(name)
Expand Down
25 changes: 25 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,28 @@ function _format_signature(narg::Int, arg_names::Ref{char_pp})
return join([unsafe_string(name) for name in arg_names] , ", ")
end

@inline function _only2d(x)
@assert ndims(x) == 2
x
end

"""
libmxnet operators signature checker.
"""
function _sig_checker()
names = filter(n -> (lowercase(n), _op_import_bl), _get_libmx_op_names())
foreach(names) do name
op_handle = _get_libmx_op_handle(name)

desc, key_narg = _get_libmx_op_description(name, op_handle)
_sig = desc |> s -> split(s, '\n') |> first |> strip
_m = match(r"(axis|axes|keepdims|shape)", _sig)

if _m === nothing
return
end

warn(_sig)

end
end
96 changes: 85 additions & 11 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,71 @@ function test_reshape()
@test size(C) == (50, 4)
end

function test_sum()
info("NDArray::sum")

let A = reshape(1.0:8, 2, 2, 2) |> collect, X = mx.NDArray(A)
@test copy(sum(X))[] == sum(A)
@test copy(sum(X, 1)) == sum(A, 1)
@test copy(sum(X, 2)) == sum(A, 2)
@test copy(sum(X, 3)) == sum(A, 3)
@test copy(sum(X, [1, 2])) == sum(A, [1, 2])
@test copy(sum(X, (1, 2))) == sum(A, (1, 2))
end
end

function test_mean()
info("NDArray::mean")

let A = reshape(1.0:8, 2, 2, 2) |> collect, X = mx.NDArray(A)
@test copy(mean(X))[] == mean(A)
@test copy(mean(X, 1)) == mean(A, 1)
@test copy(mean(X, 2)) == mean(A, 2)
@test copy(mean(X, 3)) == mean(A, 3)
@test copy(mean(X, [1, 2])) == mean(A, [1, 2])
@test copy(mean(X, (1, 2))) == mean(A, (1, 2))
end
end

function test_maximum()
info("NDArray::maximum")

let A = reshape(1.0:8, 2, 2, 2) |> collect, X = mx.NDArray(A)
@test copy(maximum(X))[] == maximum(A)
@test copy(maximum(X, 1)) == maximum(A, 1)
@test copy(maximum(X, 2)) == maximum(A, 2)
@test copy(maximum(X, 3)) == maximum(A, 3)
@test copy(maximum(X, [1, 2])) == maximum(A, [1, 2])
@test copy(maximum(X, (1, 2))) == maximum(A, (1, 2))
end
end

function test_minimum()
info("NDArray::minimum")

let A = reshape(1.0:8, 2, 2, 2) |> collect, X = mx.NDArray(A)
@test copy(minimum(X))[] == minimum(A)
@test copy(minimum(X, 1)) == minimum(A, 1)
@test copy(minimum(X, 2)) == minimum(A, 2)
@test copy(minimum(X, 3)) == minimum(A, 3)
@test copy(minimum(X, [1, 2])) == minimum(A, [1, 2])
@test copy(minimum(X, (1, 2))) == minimum(A, (1, 2))
end
end

function test_prod()
info("NDArray::prod")

let A = reshape(1.0:8, 2, 2, 2) |> collect, X = mx.NDArray(A)
@test copy(prod(X))[] == prod(A)
@test copy(prod(X, 1)) == prod(A, 1)
@test copy(prod(X, 2)) == prod(A, 2)
@test copy(prod(X, 3)) == prod(A, 3)
@test copy(prod(X, [1, 2])) == prod(A, [1, 2])
@test copy(prod(X, (1, 2))) == prod(A, (1, 2))
end
end

function test_fill()
info("NDArray::fill")
thresh = 1e8
Expand Down Expand Up @@ -449,21 +514,25 @@ function test_fill()
end
end # function test_fill

function test_kwargs()
info("NDArray::kwargs")
dims1 = (2,3,4)
function test_transpose()
info("NDArray::transpose")
let A = rand(Float32, 2, 3), x = mx.NDArray(A)
@test size(x) == (2, 3)
@test size(x') == (3, 2)
end

A = rand(Float32, dims1)
x = mx.NDArray(A)
tx = mx.transpose(x, axes=(2,1,3))
tA = permutedims(A, [2,1,3])
@test size(tx) == size(tA)
@test all(copy(tx) .== tA)
info("NDArray::permutedims")
let A = collect(Float32, reshape(1.0:24, 2, 3, 4)), x = mx.NDArray(A)
A′ = permutedims(A, [2, 1, 3])
x′ = permutedims(x, [2, 1, 3])
@test size(A′) == size(x′)
@test A′ == copy(x′)
end
end

function test_show()
let str = sprint(show, mx.NDArray([1 2 3 4]))
@test contains(str, "1x4")
@test contains(str, "1×4")
@test contains(str, "mx.NDArray")
@test contains(str, "Int64")
@test contains(str, "CPU")
Expand All @@ -490,8 +559,13 @@ end
test_nd_as_jl()
test_dot()
test_reshape()
test_sum()
test_mean()
test_maximum()
test_minimum()
test_prod()
test_fill()
test_kwargs()
test_transpose()
test_show()
end

Expand Down

0 comments on commit 4f182ee

Please sign in to comment.