Skip to content

Commit

Permalink
ndarray: change internal api of mul/div to help autograd (#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored and pluskid committed Dec 9, 2017
1 parent 233fcfc commit eb819b0
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,12 +650,12 @@ import Base: *
Currently only multiplication a scalar with an `NDArray` is implemented.
"""
*(x:: NDArray, y::Real) = x .* y
*(x::Real, y::NDArray) = y .* x
*(x::NDArray, y::Real) = _mul_scalar(x, scalar = y)
*(y::Real, x::NDArray) = _mul_scalar(x, scalar = y)

broadcast_(::typeof(*), x::NDArray, y::NDArrayOrReal) =
mul_to!(copy(x, context(x)), y)
broadcast_(::typeof(*), x::Real, y::NDArray) = y .* x
broadcast_(::typeof(*), x::NDArray, y::Real) = x * y
broadcast_(::typeof(*), y::Real, x::NDArray) = x * y
broadcast_(::typeof(*), x::NDArray, y::NDArray) = _mul(x, y)

"""
*(A::NDArray, B::NDArray)
Expand Down Expand Up @@ -703,25 +703,23 @@ of the same shape.
* Matrix division (solving linear systems) is not implemented yet.
"""
/(x::NDArray, y::Real) = x ./ y
/(x::NDArray, y::Real) = _div_scalar(x, scalar = y)

broadcast_(::typeof(/), x::NDArray, y::NDArrayOrReal) =
div_from!(copy(x, context(x)), y)

broadcast_(::typeof(/), x::Real, y::NDArray) =
rdiv_from!(x, copy(y, context(y)))
broadcast_(::typeof(/), x::NDArray, y::NDArray) = _div(x, y)
broadcast_(::typeof(/), x::NDArray, y::Real) = _div_scalar(x, scalar = y)
broadcast_(::typeof(/), y::Real, x::NDArray) = _rdiv_scalar(x, scalar = y)

import Base: ^

# document of `.^` is merged into SymbolicNode's

broadcast_(::typeof(^), x::NDArray, y::NDArray) = _power(x, y)
broadcast_(::typeof(^), x::NDArray, s::Real) = _power_scalar(x, scalar=s)
broadcast_(::typeof(^), s::Real, x::NDArray) = _rpower_scalar(x, scalar=s)
broadcast_(::typeof(^), x::NDArray, s::Real) = _power_scalar(x, scalar = s)
broadcast_(::typeof(^), s::Real, x::NDArray) = _rpower_scalar(x, scalar = s)

broadcast_(::typeof(^), ::Irrational{:e}, x::NDArray) = exp(x)
broadcast_(::typeof(^), x::NDArray, s::Irrational) = _power_scalar(x, scalar=s)
broadcast_(::typeof(^), s::Irrational, x::NDArray) = _rpower_scalar(x, scalar=s)
broadcast_(::typeof(^), x::NDArray, s::Irrational) = _power_scalar(x, scalar = s)
broadcast_(::typeof(^), s::Irrational, x::NDArray) = _rpower_scalar(x, scalar = s)

"""
fill!(arr::NDArray, x)
Expand Down

0 comments on commit eb819b0

Please sign in to comment.