Skip to content

Commit

Permalink
ndarray: elementwise power (#293)
Browse files Browse the repository at this point in the history
* ndarray: elementwise power

```julia
x.^2

2.^x

x.^y
```

* ndarray: elementwise power unfusion

* Update NEWS
  • Loading branch information
iblislin authored and pluskid committed Nov 8, 2017
1 parent a488d7a commit 57cc677
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
2.0 4.0
```

* Elementwise power of `NDArray`. (#293)
* `x.^2`
* `2.^x`
* `x.^y`
* where `x` and `y` are `NDArray`s.

## API Changes

* `reshape` of NDArray shares the same interface with Base (#272).
Expand Down
16 changes: 15 additions & 1 deletion src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,21 @@ broadcast_(::typeof(/), x::NDArray, y::NDArrayOrReal) =
broadcast_(::typeof(/), x::Real, y::NDArray) =
rdiv_from!(x, copy(y, context(y)))

import Base: ^

"""
.^(x::NDArray, y::NDArray)
.^(x::NDArray, s::Real)
.^(s::Real, x::NDArray)
Elementwise power of NDArray.
"""
^

broadcast_(::typeof(^), x::NDArray, y::NDArray) = _power(x, y)
broadcast_(::typeof(^), x::NDArray, s::Real) = _power_scalar(x, scalar=s)
broadcast_(::typeof(^), s::Real, x::NDArray) = _rpower_scalar(x, scalar=s)

"""
fill!(x, arr::NDArray)
Expand All @@ -713,7 +728,6 @@ end

fill(x, dims::Integer...) = fill(x, dims)


"""
Manipulating as Julia Arrays
----------------------------
Expand Down
72 changes: 72 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,77 @@ function test_clip()
@test all(clip_down .<= copy(clipped) .<= clip_up)
end

function test_power()
info("NDArray::power")
thresh = 1e8

info("NDArray::power::Int::x.^n")
let x = mx.NDArray([1 2; 3 4])
@test eltype(x) == Int
@test copy(x.^-1) == [1 0; 0 0]
@test copy(x.^0) == [1 1; 1 1]
@test copy(x.^1) == [1 2; 3 4]
@test copy(x.^1.1) == [1 2; 3 4]
@test copy(x.^2) == [1 4; 9 16]
@test copy(x.^2.9) == [1 4; 9 16]
@test copy(x.^3) == [1 8; 27 64]
end

info("NDArray::power::Int::n.^x")
let x = mx.NDArray([1 2; 3 4])
@test eltype(x) == Int
@test copy(0.^x) == [0 0; 0 0]
@test copy(1.^x) == [1 1; 1 1]
@test copy(1.1.^x) == [1 1; 1 1]
@test copy(2.^x) == [2 4; 8 16]
@test copy(2.9.^x) == [2 4; 8 16]
@test copy(3.^x) == [3 9; 27 81]
end

info("NDArray::power::Int::x.^y")
let x = mx.NDArray([1 2; 3 4]), y = mx.NDArray([2 2; 2 2])
@test eltype(x) == Int
@test eltype(y) == Int
@test copy(x.^y) == [1 4; 9 16]
@test copy(y.^x) == [2 4; 8 16]
end

info("NDArray::power::Float32::x.^n")
let x = mx.NDArray(Float32[1 2; 3 4]), A = Float32[1 2; 3 4]
@test eltype(x) == Float32
@test copy(x.^0) == Float32[1 1; 1 1]
@test copy(x.^1) == Float32[1 2; 3 4]
@test copy(x.^2) == Float32[1 4; 9 16]
@test copy(x.^3) == Float32[1 8; 27 64]

@test reldiff(copy(x.^-1), A.^-1) < thresh
@test reldiff(copy(x.^1.1), A.^1.1) < thresh
@test reldiff(copy(x.^2.9), A.^2.9) < thresh
end

info("NDArray::power::Float32::n.^x")
let x = mx.NDArray(Float32[1 2; 3 4]), A = Float32[1 2; 3 4]
@test eltype(x) == Float32
@test copy(0.^x) == Float32[0 0; 0 0]
@test copy(1.^x) == Float32[1 1; 1 1]
@test copy(2.^x) == Float32[2 4; 8 16]
@test copy(3.^x) == Float32[3 9; 27 81]

@test reldiff(copy(1.1.^x), 1.1.^A) < thresh
@test reldiff(copy(2.9.^x), 2.9.^A) < thresh
end

info("NDArray::power::Float32::x.^y")
let x = mx.NDArray(Float32[1 2; 3 4]), y = mx.NDArray(Float32[2 2; 2 2])
@test eltype(x) == Float32
@test eltype(y) == Float32
@test copy(x.^y) == Float32[1 4; 9 16]
@test copy(y.^x) == Float32[2 4; 8 16]
end

# TODO: Float64: wait for https://github.com/apache/incubator-mxnet/pull/8012
end # function test_power

function test_sqrt()
dims = rand_dims()
info("NDArray::sqrt::dims = $dims")
Expand Down Expand Up @@ -599,6 +670,7 @@ end
test_gd()
test_saveload()
test_clip()
test_power()
test_sqrt()
test_eltype()
test_nd_as_jl()
Expand Down

0 comments on commit 57cc677

Please sign in to comment.