Skip to content

Commit

Permalink
ndarray: broadcasting along dimension on arith operators (#401)
Browse files Browse the repository at this point in the history
* `+`
* `-`
* `*`
* `/`
* `%`
* `^`

```julia
julia> x = NDArray([1 2 3;
                    4 5 6])
2×3 mx.NDArray{Int64,2} @ CPU0:
 1  2  3
 4  5  6

julia> y = NDArray([1;
                    10])
2-element mx.NDArray{Int64,1} @ CPU0:
  1
 10

julia> x .+ y
2×3 mx.NDArray{Int64,2} @ CPU0:
  2   3   4
 14  15  16
```
  • Loading branch information
iblislin committed Jan 7, 2018
1 parent 0ef5966 commit 49399fc
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 12 deletions.
28 changes: 28 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,34 @@

### `NDArray`

* Broadcasting along dimensions is supported on the following operators (#401):

* `+`
* `-`
* `*`
* `/`
* `%`
* `^`

```julia
julia> x = NDArray([1 2 3;
4 5 6])
2×3 mx.NDArray{Int64,2} @ CPU0:
1 2 3
4 5 6

julia> y = NDArray([1;
10])
2-element mx.NDArray{Int64,1} @ CPU0:
1
10

julia> x .+ y
2×3 mx.NDArray{Int64,2} @ CPU0:
2 3 4
14 15 16
```

* Please use dot-call on the following trigonometric functions.
  Also, the `arc*` functions have been renamed for consistency with `Base`.
(#TBD)
Expand Down
40 changes: 40 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,43 @@ end
@deprecate log_softmax(x::NDArray; axis = ndims(x)) log_softmax.(x, axis)

@deprecate clip(x; a_min = 0, a_max = 0) clip(x, a_min, a_max)

# Deprecated alias for `.+`; kept so existing user code keeps working.
broadcast_plus(x::NDArray, y::NDArray) =
    (warn("broadcast_plus(x, y) is deprecated, use x .+ y instead."); x .+ y)

# Deprecated alias for `.+`; kept so existing user code keeps working.
broadcast_add(x::NDArray, y::NDArray) =
    (warn("broadcast_add(x, y) is deprecated, use x .+ y instead."); x .+ y)

# Deprecated alias for `.-`; kept so existing user code keeps working.
broadcast_sub(x::NDArray, y::NDArray) =
    (warn("broadcast_sub(x, y) is deprecated, use x .- y instead."); x .- y)

# Deprecated alias for `.-`; kept so existing user code keeps working.
broadcast_minus(x::NDArray, y::NDArray) =
    (warn("broadcast_minus(x, y) is deprecated, use x .- y instead."); x .- y)

# Deprecated alias for `.*`; kept so existing user code keeps working.
broadcast_mul(x::NDArray, y::NDArray) =
    (warn("broadcast_mul(x, y) is deprecated, use x .* y instead."); x .* y)

# Deprecated alias for `./`; kept so existing user code keeps working.
broadcast_div(x::NDArray, y::NDArray) =
    (warn("broadcast_div(x, y) is deprecated, use x ./ y instead."); x ./ y)

# Deprecated alias for `.%`; kept so existing user code keeps working.
broadcast_mod(x::NDArray, y::NDArray) =
    (warn("broadcast_mod(x, y) is deprecated, use x .% y instead."); x .% y)

# Deprecated alias for `.^`; kept so existing user code keeps working.
broadcast_power(x::NDArray, y::NDArray) =
    (warn("broadcast_power(x, y) is deprecated, use x.^y instead."); x.^y)
72 changes: 62 additions & 10 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,12 @@ added together. Note at least the first or second argument needs to be an
# NDArray + scalar: forwarded to libmxnet's scalar-add op.
+(x::NDArray, y::Real) = _plus_scalar(x, scalar = y)
+(y::Real, x::NDArray) = _plus_scalar(x, scalar = y)

# `.+` with a scalar operand degenerates to plain `+`.
broadcast_(::typeof(+), x::NDArray, y::Real) = x + y
broadcast_(::typeof(+), x::Real, y::NDArray) = x + y

# `.+` on two arrays: same eltype and rank uses elementwise `+`;
# differing ranks dispatch to libmxnet's broadcast_add so dimensions
# are broadcast (e.g. a 2×3 matrix .+ a length-2 vector).
broadcast_(::typeof(+), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} = x + y
broadcast_(::typeof(+), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
    _broadcast_add(x, y)

"""
sub_from!(dst::NDArray, args::NDArrayOrReal...)
Expand Down Expand Up @@ -646,8 +650,12 @@ Or create the negative of `x`.
# NDArray - scalar (and the reversed form) via libmxnet scalar ops.
-(x::NDArray, y::Real) = _minus_scalar(x, scalar = y)
-(y::Real, x::NDArray) = _rminus_scalar(x, scalar = y)

# `.-` with a scalar operand degenerates to plain `-`.
broadcast_(::typeof(-), x::NDArray, y::Real) = x - y
broadcast_(::typeof(-), x::Real, y::NDArray) = x - y

# `.-` on two arrays: same eltype and rank uses elementwise `-`;
# differing ranks dispatch to libmxnet's broadcast_minus.
broadcast_(::typeof(-), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} = x - y
broadcast_(::typeof(-), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
    _broadcast_minus(x, y)

"""
mul_to!(dst::NDArray, arg::NDArrayOrReal)
Expand Down Expand Up @@ -675,9 +683,13 @@ Elementwise multiplication for `NDArray`.
# NDArray * scalar via libmxnet's scalar-mul op.
*(x::NDArray, y::Real) = _mul_scalar(x, scalar = y)
*(y::Real, x::NDArray) = _mul_scalar(x, scalar = y)

# `.*` with a scalar operand degenerates to plain `*`.
broadcast_(::typeof(*), x::NDArray, y::Real) = x * y
broadcast_(::typeof(*), y::Real, x::NDArray) = x * y

# `.*` on two arrays: same eltype and rank is elementwise `_mul`;
# differing ranks dispatch to libmxnet's broadcast_mul.
broadcast_(::typeof(*), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} =
    _mul(x, y)
broadcast_(::typeof(*), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
    _broadcast_mul(x, y)

"""
*(A::NDArray, B::NDArray)
Expand Down Expand Up @@ -735,10 +747,14 @@ of the same shape.
"""
# NDArray / scalar via libmxnet's scalar-div op.
/(x::NDArray, y::Real) = _div_scalar(x, scalar = y)

# `./` with a scalar operand forwards to the scalar div ops.
broadcast_(::typeof(/), x::NDArray, y::Real) = _div_scalar(x, scalar = y)
broadcast_(::typeof(/), y::Real, x::NDArray) = _rdiv_scalar(x, scalar = y)

# `./` on two arrays: same eltype and rank is elementwise `_div`;
# differing ranks dispatch to libmxnet's broadcast_div.
broadcast_(::typeof(/), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} =
    _div(x, y)
broadcast_(::typeof(/), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
    _broadcast_div(x, y)

function broadcast_(::typeof(/), x::NDArray{T}, y::Real) where {T<:Integer}
@assert(round(T, y) != zero(T), "Integer divided by zero")
_div_scalar(x, scalar = y)
Expand Down Expand Up @@ -773,22 +789,30 @@ Elementwise modulo for `NDArray`.
"""
# NDArray % scalar via libmxnet's scalar-mod op (positional form, see @_remap).
%(x::NDArray, y::Real) = _mod_scalar(x, y)

# `.%` with a scalar operand forwards to the scalar mod ops.
broadcast_(::typeof(%), x::NDArray, y::Real) = _mod_scalar(x, y)
broadcast_(::typeof(%), y::Real, x::NDArray) = _rmod_scalar(x, y)

# `.%` on two arrays: same eltype and rank is elementwise `_mod`;
# differing ranks dispatch to libmxnet's broadcast_mod.
broadcast_(::typeof(%), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} =
    _mod(x, y)
broadcast_(::typeof(%), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
    _broadcast_mod(x, y)

import Base: ^

# document of `.^` is merged into SymbolicNode's

# `.^` with a scalar operand forwards to the scalar power ops.
broadcast_(::typeof(^), x::NDArray, s::Real) = _power_scalar(x, scalar = s)
broadcast_(::typeof(^), s::Real, x::NDArray) = _rpower_scalar(x, scalar = s)

# Special-case `e.^x` as the elementwise exponential; other irrationals
# go through the scalar power ops.
broadcast_(::typeof(^), ::Irrational{:e}, x::NDArray) = exp(x)
broadcast_(::typeof(^), x::NDArray, s::Irrational) = _power_scalar(x, scalar = s)
broadcast_(::typeof(^), s::Irrational, x::NDArray) = _rpower_scalar(x, scalar = s)

# `.^` on two arrays: same eltype and rank is elementwise `_power`;
# differing ranks dispatch to libmxnet's broadcast_power.
broadcast_(::typeof(^), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} =
    _power(x, y)
broadcast_(::typeof(^), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
    _broadcast_power(x, y)

"""
fill!(arr::NDArray, x)
Expand Down Expand Up @@ -1373,6 +1397,24 @@ julia> mx.log_softmax.(x)
@_remap _rmod_scalar(x::NDArray, y::Real) _rmod_scalar(x; scalar = y)
@_remap _rmod_scalar!(x::NDArray, y::Real) _rmod_scalar(x; scalar = y)

# Bindings to libmxnet's dimension-broadcasting arithmetic operators.
# Each op gets a plain form and a `!` form; the `!` variant presumably
# writes into an existing output NDArray per @_remap's in-place
# convention — NOTE(review): confirm against the @_remap macro definition.
@_remap _broadcast_add(x::NDArray, y::NDArray) broadcast_add(x, y)
@_remap _broadcast_add!(x::NDArray, y::NDArray) broadcast_add(x, y)

@_remap _broadcast_minus(x::NDArray, y::NDArray) broadcast_minus(x, y)
@_remap _broadcast_minus!(x::NDArray, y::NDArray) broadcast_minus(x, y)

@_remap _broadcast_mul(x::NDArray, y::NDArray) broadcast_mul(x, y)
@_remap _broadcast_mul!(x::NDArray, y::NDArray) broadcast_mul(x, y)

@_remap _broadcast_div(x::NDArray, y::NDArray) broadcast_div(x, y)
@_remap _broadcast_div!(x::NDArray, y::NDArray) broadcast_div(x, y)

@_remap _broadcast_mod(x::NDArray, y::NDArray) broadcast_mod(x, y)
@_remap _broadcast_mod!(x::NDArray, y::NDArray) broadcast_mod(x, y)

@_remap _broadcast_power(x::NDArray, y::NDArray) broadcast_power(x, y)
@_remap _broadcast_power!(x::NDArray, y::NDArray) broadcast_power(x, y)

################################################################################
# NDArray functions dynamically imported from libmxnet
################################################################################
Expand Down Expand Up @@ -1529,6 +1571,16 @@ const _op_import_bl = [ # import black list; do not import these funcs
"relu",
"softmax",
"log_softmax",

# broadcast
"broadcast_add",
"broadcast_plus",
"broadcast_minus",
"broadcast_sub",
"broadcast_mul",
"broadcast_div",
"broadcast_mod",
"broadcast_power",
]

macro _import_ndarray_functions()
Expand Down
106 changes: 104 additions & 2 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ using ..Main: rand_dims
################################################################################
# Test Implementations
################################################################################
# Generate a random host tensor of the given dims (eltype defaults to
# mx.MX_float) together with its NDArray copy on CPU; returns the pair
# (tensor, array).
rand_tensors(dims::NTuple{N,Int}) where {N} = rand_tensors(mx.MX_float, dims)
function rand_tensors(::Type{T}, dims::NTuple{N,Int}) where {N,T}
    tensor = rand(T, dims)
    array = copy(tensor, mx.cpu())
    return (tensor, array)
Expand Down Expand Up @@ -330,6 +330,23 @@ function test_plus()
y = x .+ 2.9
@test copy(y) == [3, 4, 5]
end

info("NDArray::broadcast_add")
let
A = [1 2 3;
4 5 6]
B = [1,
2]
x = NDArray(A)
y = NDArray(B)

z = x .+ y
@test copy(z) == A .+ B

# TODO
# @inplace x .+= y
# @test copy(x) == A .+ B
end
end

function test_minus()
Expand Down Expand Up @@ -386,6 +403,23 @@ function test_minus()
let x = mx.NDArray([1, 2, 3])
@test copy(x .- π) [-2, -1, 0]
end

info("NDArray::broadcast_minus")
let
A = [1 2 3;
4 5 6]
B = [1,
2]
x = NDArray(A)
y = NDArray(B)

z = x .- y
@test copy(z) == A .- B

# TODO
# @inplace x .-= y
# @test copy(x) == A .- B
end
end

function test_mul()
Expand Down Expand Up @@ -445,6 +479,23 @@ function test_mul()
@test eltype(x) == Int
@test copy(y) == [3, 6, 9]
end

info("NDArray::broadcast_mul")
let
A = [1 2 3;
4 5 6]
B = [1,
2]
x = NDArray(A)
y = NDArray(B)

z = x .* y
@test copy(z) == A .* B

# TODO
# @inplace x .*= y
# @test copy(x) == A .* B
end
end

function test_div()
Expand Down Expand Up @@ -499,6 +550,23 @@ function test_div()

@test_throws AssertionError x ./ 0.5
end

info("NDArray::broadcast_div")
let
A = Float32[1 2 3;
4 5 6]
B = Float32[1,
2]
x = NDArray(A)
y = NDArray(B)

z = x ./ y
@test copy(z) == A ./ B

# TODO
# @inplace x ./= y
# @test copy(x) == A ./ B
end
end


Expand Down Expand Up @@ -624,6 +692,23 @@ function test_mod()
@inplace x .%= y
@test copy(x) C
end

info("NDArray::broadcast_mod")
let
A = [1 2 3;
4 5 6]
B = [1,
2]
x = NDArray(A)
y = NDArray(B)

z = x .% y
@test copy(z) == A .% B

# TODO
# @inplace x .%= y
# @test copy(x) == A .% B
end
end # function test_mod


Expand Down Expand Up @@ -788,6 +873,23 @@ function test_power()
end

# TODO: Float64: wait for https://github.com/apache/incubator-mxnet/pull/8012

info("NDArray::broadcast_power")
let
A = [1 2 3;
4 5 6]
B = [1,
2]
x = NDArray(A)
y = NDArray(B)

z = x.^y
@test copy(z) == A.^B

# TODO
# @inplace x .^= y
# @test copy(x) == A.^B
end
end # function test_power

function test_sqrt()
Expand Down

0 comments on commit 49399fc

Please sign in to comment.