Skip to content

Commit

Permalink
ndarray: add modulo operator (#373)
Browse files Browse the repository at this point in the history
* ndarray: add modulo operator

* add news
  • Loading branch information
iblislin authored and pluskid committed Dec 15, 2017
1 parent 5908d97 commit a941f3a
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 5 deletions.
11 changes: 11 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,17 @@
4.0
```

* modulo operator. (#TBD)

```julia
x = NDArray(...)
y = NDArray(...)

x .% y
x .% 2
2 .% x
```

* Transposing a column `NDArray` to a row `NDArray` is supported now. (#TBD)

```julia
Expand Down
3 changes: 2 additions & 1 deletion src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using TakingBroadcastSeriously: Broadcasted, unwrap

for f in :[tan, asin, acos, atan,
for f in :[%,
tan, asin, acos, atan,
sinh, cosh, tanh, asinh, acosh, atanh].args
# copy from TakingBroadcastSeriously
@eval Base.$f(a::Broadcasted...) = Broadcasted(broadcast_($f, unwrap.(a)...))
Expand Down
27 changes: 23 additions & 4 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -709,17 +709,32 @@ broadcast_(::typeof(/), x::NDArray, y::NDArray) = _div(x, y)
broadcast_(::typeof(/), x::NDArray, y::Real) = _div_scalar(x, scalar = y)
broadcast_(::typeof(/), y::Real, x::NDArray) = _rdiv_scalar(x, scalar = y)

import Base: %

"""
.%(x::NDArray, y::NDArray)
.%(x::NDArray, y::Real)
.%(x::Real, y::NDArray)
Elementwise modulo for `NDArray`.
"""
%(x::NDArray, y::Real) = _mod_scalar(x, scalar = y)

broadcast_(::typeof(%), x::NDArray, y::NDArray) = _mod(x, y)
broadcast_(::typeof(%), x::NDArray, y::Real) = _mod_scalar(x, scalar = y)
broadcast_(::typeof(%), y::Real, x::NDArray) = _rmod_scalar(x, scalar = y)

import Base: ^

# document of `.^` is merged into SymbolicNode's

broadcast_(::typeof(^), x::NDArray, y::NDArray) = _power(x, y)
broadcast_(::typeof(^), x::NDArray, s::Real) = _power_scalar(x, scalar = s)
broadcast_(::typeof(^), s::Real, x::NDArray) = _rpower_scalar(x, scalar = s)
broadcast_(::typeof(^), x::NDArray, s::Real) = _power_scalar(x, scalar = s)
broadcast_(::typeof(^), s::Real, x::NDArray) = _rpower_scalar(x, scalar = s)

broadcast_(::typeof(^), ::Irrational{:e}, x::NDArray) = exp(x)
broadcast_(::typeof(^), x::NDArray, s::Irrational) = _power_scalar(x, scalar = s)
broadcast_(::typeof(^), s::Irrational, x::NDArray) = _rpower_scalar(x, scalar = s)
broadcast_(::typeof(^), x::NDArray, s::Irrational) = _power_scalar(x, scalar = s)
broadcast_(::typeof(^), s::Irrational, x::NDArray) = _rpower_scalar(x, scalar = s)

"""
fill!(arr::NDArray, x)
Expand Down Expand Up @@ -1147,6 +1162,9 @@ _mxsig[:reshape] = :(reshape(arr; shape = dim, reverse = !reverse))
@_remap _minus(x::NDArray, y::NDArray) _minus(x, y)
@_remap _minus!(x::NDArray, y::NDArray) _minus(x, y)

@_remap _mod(x::NDArray, y::NDArray) _mod(x, y)
@_remap _mod!(x::NDArray, y::NDArray) _mod(x, y)

################################################################################
# NDArray functions dynamically imported from libmxnet
################################################################################
Expand Down Expand Up @@ -1265,6 +1283,7 @@ const _op_import_bl = [ # import black list; do not import these funcs
# arithmetic
"_plus",
"_minus",
"_mod",

"dot",
"max",
Expand Down
32 changes: 32 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,37 @@ function test_rdiv()
end # function test_rdiv


function test_mod()
info("NDArray::mod")
const A = [1 2; 3 4]
const B = [1 1; 3 3]

let x = NDArray(A), y = NDArray(B)
C = A .% B
D = B .% A

w = x .% y
z = y .% x

@test copy(w) C
@test copy(z) D
end

let x = NDArray(A)
C = A .% 2
y = x .% 2
@test copy(y) C
end

info("NDArray::rmod")
let x = NDArray(A)
C = 11 .% A
y = 11 .% x
@test copy(y) C
end
end # function test_mod


function test_gd()
dims = rand_dims()
tw, aw = rand_tensors(dims)
Expand Down Expand Up @@ -888,6 +919,7 @@ end # function test_hyperbolic
test_mul()
test_div()
test_rdiv()
test_mod()
test_gd()
test_saveload()
test_clip()
Expand Down

0 comments on commit a941f3a

Please sign in to comment.