Skip to content

Commit

Permalink
ndarray: change internal api of plus to help autograd (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored and pluskid committed Dec 9, 2017
1 parent daf787c commit 233fcfc
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,10 @@ Summation. Multiple arguments of either scalar or `NDArray` could be
added together. Note at least the first or second argument needs to be an
`NDArray` to avoid ambiguity of built-in summation.
"""
+(x::NDArray, ys::NDArrayOrReal...) = add_to!(copy(x, context(x)), ys...)
+(x::Real, y::NDArray, zs::NDArrayOrReal...) = add_to!(copy(y, context(y)), x, zs...)
+(x::NDArray) = x
+(x::NDArray, y::NDArray) = _plus(x, y)
+(x::NDArray, y::Real) = _plus_scalar(x, scalar = y)
+(y::Real, x::NDArray) = _plus_scalar(x, scalar = y)

broadcast_(::typeof(+), x::NDArray, y::NDArrayOrReal) = x + y
broadcast_(::typeof(+), x::Real, y::NDArray) = x + y
Expand Down Expand Up @@ -1205,20 +1207,16 @@ function _get_ndarray_function_def(name :: String)
args = MX_handle[]
end

if length(output_vars) > 0
output_handles = map((x) -> Base.cconvert(MX_handle, x), output_vars)
# XXX: Julia 0.4 has bug: [Array{MX_handle}] == Array{MX_handle}
output_handles_pp = Array{Array{MX_handle}}(1)
output_handles_pp[1] = Base.cconvert(Ptr{MX_handle}, output_handles)
output_handles_pp = if length(output_vars) > 0
[map(x -> x.handle, output_vars)]
else
output_handles_pp = [Base.convert(Ptr{MX_handle}, 0)]
[Ptr{MX_handle}(C_NULL)]
end
num_outputs_p = [convert(Cint, num_outputs)]

kw_keys_str = String[string(x[1]) for x in kwargs]
kw_vals_str = String[dump_mx_param(x[2]) for x in kwargs]

#op_handle = _get_cached_libmx_op_handle($(QuoteNode(name)))
op_handle = _get_cached_libmx_op_handle($(name))
@mxcall(:MXImperativeInvoke,
(MX_handle, Cint, Ptr{MX_handle},
Expand All @@ -1229,13 +1227,13 @@ function _get_ndarray_function_def(name :: String)
length(kwargs), kw_keys_str, kw_vals_str)

if out == nothing
handle_array = unsafe_wrap(Array, output_handles_pp[], num_outputs_p[])
handle_array = [MX_NDArrayHandle(x) for x in handle_array]
arrays = [NDArray(hdr) for hdr in handle_array]
if length(arrays) == 1
return arrays[1]
n = num_outputs_p[]
hdls = unsafe_wrap(Array{MX_handle}, output_handles_pp[], n)
xs = NDArray[NDArray(MX_NDArrayHandle(x)) for x in hdls]
if n == 1
return xs[]
else
return arrays
return xs
end
else
return out
Expand Down

0 comments on commit 233fcfc

Please sign in to comment.