Skip to content

Commit

Permalink
ndarray: implement rdiv (#292)
Browse files Browse the repository at this point in the history
* ndarray: implement rdiv

e.g.

```julia
1 ./ mx.NDArray(Float32[1 2; 3 4])
```

* typo
  • Loading branch information
iblislin authored and pluskid committed Nov 6, 2017
1 parent 4f182ee commit a3317f1
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 8 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## New API

* `deepcopy` for NDArray (#273)

* `scalar ./ NDArray` is available now. (#292)
* `fill` and `fill!` for NDArray (#TBD)
An API correspond to Python's `mx.nd.full()`

Expand Down
31 changes: 24 additions & 7 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -682,25 +682,42 @@ function div_from!(dst :: NDArray, arg :: Union{Real, NDArray})
end
end

"""
Elementwise division of NDArray
"""
div(x::NDArray, y::NDArray) = _div(x, y)
div(x::NDArray, s::Real) = _div_scalar(x, scalar=s)
div(s::Real, x::NDArray) = _rdiv_scalar(x, scalar=s)

import Base: /
"""
./(arg0 :: NDArray, arg :: Union{Real, NDArray})
Elementwise dividing an `NDArray` by a scalar or another `NDArray` of the same shape.
"""
@compatdot function Base.broadcast(::typeof(/), arg0 :: NDArray, arg :: Union{Real, NDArray})
ret = copy(arg0, context(arg0))
div_from!(ret, arg)
@compatdot function Base.broadcast(::typeof(/), arg0 :: NDArray,
arg :: Union{Real, NDArray})
div(arg0, arg)
end

@compatdot function Base.broadcast(::typeof(/), arg0 :: Real, arg :: NDArray)
div(arg0, arg)
end

"""
/(arg0 :: NDArray, arg :: Real)
Divide an `NDArray` by a scalar. Matrix division (solving linear systems) is not implemented yet.
Divide an `NDArray` by a scalar.
Matrix division (solving linear systems) is not implemented yet.
"""
function /(arg0 :: NDArray, arg :: Real)
arg0 ./ arg
end
/(arg0 :: NDArray, arg :: Real) = div(arg0, arg)

"""
/(arg0 :: Real, arg :: NDArray)
Elementwise divide a scalar by an `NDArray`.
"""
/(arg0 :: Real, arg :: NDArray) = div(arg0, arg)


"""
Expand Down
40 changes: 40 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,47 @@ function test_div()
t6, a6 = rand_tensors(Float16, dims)
scalar_large = 1e4
@test reldiff(t6 / scalar_large, copy(a6 ./ scalar_large)) < 1e-1

let x = mx.NDArray([1 2; 3 4])
@test eltype(x) == Int
@test copy(x / 2) == [0 1; 1 2]
@test copy(x / 2.5) == [0 1; 1 2]
@test copy(x / 2.9) == [0 1; 1 2]
end
end


function test_rdiv()
info("NDarray::rdiv")

info("NDarray::rdiv::Inf16")
let x = 1 ./ mx.zeros(Float16, 4)
@test copy(x) == [Inf16, Inf16, Inf16, Inf16]
end

info("NDarray::rdiv::Inf32")
let x = 1 ./ mx.zeros(Float32, 4)
@test copy(x) == [Inf32, Inf32, Inf32, Inf32]
end

info("NDarray::rdiv::Inf64")
let x = 1 ./ mx.zeros(Float64, 4)
@test copy(x) == [Inf64, Inf64, Inf64, Inf64]
end

info("NDarray::rdiv::Int")
let x = 1 ./ mx.NDArray([1 2; 3 4])
@test copy(x) == [1 0; 0 0]
end

info("NDarray::rdiv::Float32")
let x = 1 ./ mx.NDArray(Float32[1 2; 3 4])
y = 1 ./ Float32[1 2; 3 4]
@test reldiff(copy(x), y) < 1e8
end
end # function test_rdiv


function test_gd()
dims = rand_dims()
tw, aw = rand_tensors(dims)
Expand Down Expand Up @@ -551,6 +590,7 @@ end
test_minus()
test_mul()
test_div()
test_rdiv()
test_gd()
test_saveload()
test_clip()
Expand Down

0 comments on commit a3317f1

Please sign in to comment.