Skip to content

Commit

Permalink
ndarray: make _minus type stable (#345)
Browse files Browse the repository at this point in the history
* ndarray: make _minus type stable

The current importer `_import_ndarray_functions`
provide `_minus(x, y; out = x)` and its return value is rely on
keyword argument `out`.
But Julia cannot (or hard to) do type inference on keyword argument
at the moment,
so this commit propose a new method `_minus!(x, y)` which modified the
first argument, instead of provide a keyword argument. The new method
can make type stable.

fix #341

* add test cases
  • Loading branch information
iblislin authored and pluskid committed Nov 27, 2017
1 parent e0f625a commit 935eb35
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 11 deletions.
30 changes: 25 additions & 5 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ function sub_from!(dst::NDArray, arg::NDArrayOrReal)
if isa(arg, Real)
_minus_scalar(dst, scalar=convert(eltype(dst), arg), out=dst)
else
_minus(dst, arg, out=dst)
_minus!(dst, arg)
end
end

Expand Down Expand Up @@ -1037,6 +1037,15 @@ function _autoimport(name::Symbol)
end
end

function _outexpr(name::Symbol, x #= the first arg of `sig` =#)
if endswith(string(name), "!") # `func!`
Ptr, 1, :([[MX_handle(x.handle)]]), :($x)
else
retexpr = :(NDArray(MX_NDArrayHandle(unsafe_load(hdls_ref[], 1))))
Ref, 0, :(Ref{Ptr{MX_handle}}(C_NULL)), retexpr
end
end

macro _remap(sig::Expr, imp::Expr)
fname = sig.args[1]
opname = string(imp.args[1])
Expand All @@ -1055,16 +1064,19 @@ macro _remap(sig::Expr, imp::Expr)
mxvals = Expr(:vect, map(x -> :(dump_mx_param($(x.args[2]))), mxargs)...)
ndhlds = Expr(:vect, map(x -> :($(x).handle), ndin)...)

# handler for `func!` which has side effect on first argument.
T, n_output, hdls_ref, retexpr = _outexpr(fname, sig.args[2].args[1])

func_body = quote
op_handle = _get_cached_libmx_op_handle($opname)
n_output = Ref(Cint(0))
hdls_ref = Ref{Ptr{MX_handle}}(C_NULL)
n_output = Ref(Cint($n_output))
hdls_ref = $hdls_ref
@mxcall(:MXImperativeInvoke,
(MX_handle,
Cint,
Ptr{MX_handle},
Ref{Cint},
Ref{Ptr{MX_handle}},
$T{Ptr{MX_handle}},
Cint,
char_pp,
char_pp),
Expand All @@ -1076,7 +1088,7 @@ macro _remap(sig::Expr, imp::Expr)
$(length(mxargs)),
$mxkeys,
$mxvals)
NDArray(MX_NDArrayHandle(unsafe_load(hdls_ref[], 1)))
$retexpr
end

docstr = " $sig"
Expand Down Expand Up @@ -1123,6 +1135,13 @@ _mxsig[:reshape] = :(reshape(arr; shape = dim, reverse = !reverse))
@_remap prod(arr::NDArray) prod(arr)
@_remap prod(arr::NDArray, dims) prod(arr; axis = 0 .- dims, keepdims = true)

################################################################################
# remapping to solving type unstablility
################################################################################

@_remap _minus(x::NDArray, y::NDArray) _minus(x, y)
@_remap _minus!(x::NDArray, y::NDArray) _minus(x, y)

################################################################################
# NDArray functions dynamically imported from libmxnet
################################################################################
Expand Down Expand Up @@ -1248,6 +1267,7 @@ const _op_import_bl = [ # import black list; do not import these funcs
"dot",
"transpose",
"prod",
"_minus",
]

macro _import_ndarray_functions()
Expand Down
18 changes: 12 additions & 6 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,12 @@ function test_minus()
scalar_large = Float16(1e4)
@test t6 - scalar_small copy(a6 .- scalar_small)
@test t6 - scalar_large copy(a6 .- scalar_large)

info("NDArray::minus::type stablility")
let x = mx.zeros(dims), y = mx.ones(dims)
@inferred x - y
@inferred x .- y
end
end

function test_mul()
Expand Down Expand Up @@ -361,29 +367,29 @@ end


function test_rdiv()
info("NDarray::rdiv")
info("NDArray::rdiv")

info("NDarray::rdiv::Inf16")
info("NDArray::rdiv::Inf16")
let x = 1 ./ mx.zeros(Float16, 4)
@test copy(x) == [Inf16, Inf16, Inf16, Inf16]
end

info("NDarray::rdiv::Inf32")
info("NDArray::rdiv::Inf32")
let x = 1 ./ mx.zeros(Float32, 4)
@test copy(x) == [Inf32, Inf32, Inf32, Inf32]
end

info("NDarray::rdiv::Inf64")
info("NDArray::rdiv::Inf64")
let x = 1 ./ mx.zeros(Float64, 4)
@test copy(x) == [Inf64, Inf64, Inf64, Inf64]
end

info("NDarray::rdiv::Int")
info("NDArray::rdiv::Int")
let x = 1 ./ mx.NDArray([1 2; 3 4])
@test copy(x) == [1 0; 0 0]
end

info("NDarray::rdiv::Float32")
info("NDArray::rdiv::Float32")
let x = 1 ./ mx.NDArray(Float32[1 2; 3 4])
y = 1 ./ Float32[1 2; 3 4]
@test copy(x) y
Expand Down

0 comments on commit 935eb35

Please sign in to comment.