Skip to content

Commit

Permalink
ndarray: inplace modulo operators (#389)
Browse files Browse the repository at this point in the history
```julia
mod_from!(x, y)
mod_from!(x, 2)
rmod_from!(2, x)
```
  • Loading branch information
iblislin authored and pluskid committed Jan 4, 2018
1 parent a1cef7f commit 86ffb5d
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 4 deletions.
10 changes: 9 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
4.0
```

* modulo operator. (#TBD)
* Modulo operator. (#TBD)

```julia
x = NDArray(...)
Expand All @@ -128,6 +128,14 @@
2 .% x
```

* Inplace modulo operator, `mod_from!` and `rmod_from!`. (#TBD)

```julia
mod_from!(x, y)
mod_from!(x, 2)
rmod_from!(2, x)
```

* `cat`, `vcat`, `hcat` is implemented. (#TBD)

E.g. `hcat`
Expand Down
40 changes: 37 additions & 3 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,24 @@ function broadcast_(::typeof(/), x::NDArray{T}, y::Real) where {T<:Integer}
_div_scalar(x, scalar = y)
end

"""
mod_from!(x::NDArray, y::NDArray)
mod_from!(x::NDArray, y::Real)
Elementwise modulo for `NDArray`.
Inplace updating.
"""
mod_from!(x::NDArray, y::NDArray) = _mod!(x, y)
mod_from!(x::NDArray, y::Real) = _mod_scalar!(x, y)

"""
rmod_from!(y::Real, x::NDArray)
Elementwise modulo for `NDArray`.
Inplace updating.
"""
rmod_from!(y::Real, x::NDArray) = _rmod_scalar!(x, y)

import Base: %

"""
Expand All @@ -761,8 +779,8 @@ Elementwise modulo for `NDArray`.
%(x::NDArray, y::Real) = _mod_scalar(x, scalar = y)

broadcast_(::typeof(%), x::NDArray, y::NDArray) = _mod(x, y)
broadcast_(::typeof(%), x::NDArray, y::Real) = _mod_scalar(x, scalar = y)
broadcast_(::typeof(%), y::Real, x::NDArray) = _rmod_scalar(x, scalar = y)
broadcast_(::typeof(%), x::NDArray, y::Real) = _mod_scalar(x, y)
broadcast_(::typeof(%), y::Real, x::NDArray) = _rmod_scalar(x, y)

import Base: ^

Expand Down Expand Up @@ -1061,8 +1079,13 @@ function _autoimport(name::Symbol, sig::Expr)
end
end

_isinplace(name::Symbol) = endswith(string(name), "!")

_writable(name::Symbol, x) =
_isinplace(name) ? :(@assert $x.writable "this NDArray isn't writable") : :()

function _outexpr(name::Symbol, x #= the first arg of `sig` =#)
if endswith(string(name), "!") # `func!`
if _isinplace(name) # `func!`
Ptr, 1, :([[MX_handle(x.handle)]]), :($x)
else
retexpr = :(NDArray(MX_NDArrayHandle(unsafe_load(hdls_ref[], 1))))
Expand Down Expand Up @@ -1124,7 +1147,10 @@ macro _remap(sig::Expr, imp::Expr)
# handler for `func!` which has side effect on first argument.
T, n_output, hdls_ref, retexpr = _outexpr(fname, _firstarg(sig))

assert_expr = _writable(fname, _firstarg(sig))

func_body = quote
$assert_expr
op_handle = _get_cached_libmx_op_handle($opname)
n_output = Ref(Cint($n_output))
hdls_ref = $hdls_ref
Expand Down Expand Up @@ -1346,6 +1372,12 @@ julia> mx.log_softmax.(x)
@_remap _mod(x::NDArray, y::NDArray) _mod(x, y)
@_remap _mod!(x::NDArray, y::NDArray) _mod(x, y)

@_remap _mod_scalar(x::NDArray, y::Real) _mod_scalar(x; scalar = y)
@_remap _mod_scalar!(x::NDArray, y::Real) _mod_scalar(x; scalar = y)

@_remap _rmod_scalar(x::NDArray, y::Real) _rmod_scalar(x; scalar = y)
@_remap _rmod_scalar!(x::NDArray, y::Real) _rmod_scalar(x; scalar = y)

################################################################################
# NDArray functions dynamically imported from libmxnet
################################################################################
Expand Down Expand Up @@ -1467,6 +1499,8 @@ const _op_import_bl = [ # import black list; do not import these funcs
"_plus",
"_minus",
"_mod",
"_mod_scalar",
"_rmod_scalar",

"dot",
"max",
Expand Down
47 changes: 47 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ function test_mod()
@test copy(z) D
end

info("NDArray::mod::scalar")
let x = NDArray(A)
C = A .% 2
y = x .% 2
Expand All @@ -547,6 +548,52 @@ function test_mod()
y = 11 .% x
@test copy(y) C
end

info("NDArray::mod_from!")
let
x = NDArray(A)
y = NDArray(B)
C = A .% B
mx.mod_from!(x, y)
@test copy(x) C
end

let
x = NDArray(A)
y = NDArray(B)
C = B .% A
mx.mod_from!(y, x)

@test copy(y) C
end

info("NDArray::mod_from!::scalar")
let
x = NDArray(A)
C = A .% 2
mx.mod_from!(x, 2)
@test copy(x) C
end

info("NDArray::rmod_from!")
let
x = NDArray(A)
C = 11 .% A
mx.rmod_from!(11, x)
@test copy(x) C
end

info("NDArray::mod_from!::writable")
let
x = NDArray(A)
y = NDArray(B)
x.writable = false
y.writable = false
@test_throws AssertionError mx.mod_from!(x, y)
@test_throws AssertionError mx.mod_from!(y, x)
@test_throws AssertionError mx.mod_from!(x, 2)
@test_throws AssertionError mx.rmod_from!(2, x)
end
end # function test_mod


Expand Down

0 comments on commit 86ffb5d

Please sign in to comment.