Skip to content

Commit

Permalink
ndarray: elementwise power for irrational (#310)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored and pluskid committed Nov 9, 2017
1 parent 4919273 commit 1fc03f2
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 5 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@
* `x.^y`
* where `x` and `y` are `NDArray`s.

* Elementwise power of irrational and `NDArray` (#TBD)
* `e.^x`
* `x.^e`
* `π.^x`

## API Changes

* `reshape` of NDArray shares the same interface with Base (#272).
Expand Down
11 changes: 6 additions & 5 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,12 @@ end
#
# TODO: find a better solution in case this cause issues in the future.
################################################################################
dump_mx_param(val::Any) = string(val)
dump_mx_param(val::Float64) = @sprintf("%.16e", val)
dump_mx_param(val::Float32) = @sprintf("%.8e", val)
dump_mx_param(val::Float16) = @sprintf("%.4e", val)
dump_mx_param(shape::NTuple{N, T}) where {N, T<:Integer} =
dump_mx_param(val::Any) = string(val)
dump_mx_param(val::Float64) = @sprintf("%.16e", val)
dump_mx_param(val::Float32) = @sprintf("%.8e", val)
dump_mx_param(val::Float16) = @sprintf("%.4e", val)
dump_mx_param(val::Irrational) = @sprintf("%.16e", val)
dump_mx_param(shape::NTuple{N, <:Integer}) where N =
string(tuple(flipdim([shape...], 1)...))


Expand Down
4 changes: 4 additions & 0 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,10 @@ broadcast_(::typeof(^), x::NDArray, y::NDArray) = _power(x, y)
broadcast_(::typeof(^), x::NDArray, s::Real) = _power_scalar(x, scalar=s)
broadcast_(::typeof(^), s::Real, x::NDArray) = _rpower_scalar(x, scalar=s)

broadcast_(::typeof(^), ::Irrational{:e}, x::NDArray) = exp(x)
broadcast_(::typeof(^), x::NDArray, s::Irrational) = _power_scalar(x, scalar=s)
broadcast_(::typeof(^), s::Irrational, x::NDArray) = _rpower_scalar(x, scalar=s)

"""
fill!(x, arr::NDArray)
Expand Down
16 changes: 16 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,22 @@ function test_power()
@test copy(y.^x) == Float32[2 4; 8 16]
end

info("NDArray::power::e.^x::x.^e")
let x = mx.zeros(2, 3), A = [1 1 1; 1 1 1]
@test copy(e.^x) A
end

let A = Float32[1 2; 3 4], x = mx.NDArray(A)
@test copy(e.^x) e.^A
@test copy(x.^e) A.^e
end

info("NDArray::power::π.^x::x.^π")
let A = Float32[1 2; 3 4], x = mx.NDArray(A)
@test copy.^x) π.^A
@test copy(x.^π) A.^π
end

# TODO: Float64: wait for https://github.com/apache/incubator-mxnet/pull/8012
end # function test_power

Expand Down

0 comments on commit 1fc03f2

Please sign in to comment.