Skip to content

Commit

Permalink
sym: broadcast unfusion (#314)
Browse files Browse the repository at this point in the history
* sym: broadcast unfusion for `add`

* sym: broadcast unfusion for `minus`

* sym: broadcast unfusion for `multiplication`

* sym: broadcast unfusion for `div`

* sym: broadcast unfusion for `power`

* sym: broadcast unfusion for `power` with irrational
  • Loading branch information
iblislin authored and pluskid committed Nov 13, 2017
1 parent 09b9718 commit 9304e6e
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 84 deletions.
4 changes: 2 additions & 2 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ import Base: /
"""
./(x::NDArray, y::NDArray)
./(x::NDArray, y::Real)
./(x:: Real, y::NDArray)
./(x::Real, y::NDArray)
* Elementwise dividing an `NDArray` by a scalar or another `NDArray`
of the same shape.
Expand All @@ -746,7 +746,7 @@ import Base: ^
.^(x::NDArray, s::Real)
.^(s::Real, x::NDArray)
Elementwise power of NDArray.
Elementwise power of `NDArray`.
"""
^

Expand Down
173 changes: 92 additions & 81 deletions src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ Make a new node by composing `self` with `args`. Or the arguments
can be specified using keyword arguments.
"""
mutable struct SymbolicNode
handle :: MX_SymbolHandle
handle::MX_SymbolHandle
end
function Base.unsafe_convert(::Type{MX_handle}, obj::SymbolicNode)

const SymbolicNodeOrReal = Union{SymbolicNode, Real}

@unfuse SymbolicNode # for broadcasting

Base.unsafe_convert(::Type{MX_handle}, obj::SymbolicNode) =
Base.unsafe_convert(MX_handle, obj.handle)
end
Base.convert(t::Type{MX_handle}, obj::SymbolicNode) = Base.unsafe_convert(t, obj)
Base.cconvert(t::Type{MX_handle}, obj::SymbolicNode) = Base.unsafe_convert(t, obj)

Expand Down Expand Up @@ -483,105 +487,112 @@ function Base.getindex(self :: SymbolicNode, idx :: Int)
return SymbolicNode(MX_SymbolHandle(ref_hdr[]))
end

import Base.broadcast
import Base: +
function +(self :: SymbolicNode, args :: Union{SymbolicNode,Real}...)
ret = self
for arg in args
if isa(arg, SymbolicNode)
ret = _Plus(ret, arg)

"""
+(args...)
.+(args...)
Elementwise summation of `SymbolicNode`.
"""
function +(x::SymbolicNode, ys::SymbolicNodeOrReal...)
ret = x
for y ys
if y isa SymbolicNode
ret = _plus(ret, y)
else
ret = _PlusScalar(ret, scalar=MX_float(arg))
ret = _plus_scalar(ret, scalar=MX_float(y))
end
end
ret
end
@compatdot function Base.broadcast(::typeof(+), self::SymbolicNode, args::Union{SymbolicNode,Real}...)
+(self, args...)
end
function +(s1 :: Real, self :: SymbolicNode, args :: Union{SymbolicNode,Real}...)
+(self, s1, args...)
end
@compatdot function Base.broadcast(::typeof(+), s1::Real, self::SymbolicNode,
args::Union{SymbolicNode,Real}...)
+(self, s1, args...)
end

+(s::Real, x::SymbolicNode, ys::SymbolicNodeOrReal...) = +(x + s, ys...)

broadcast_(::typeof(+), x::SymbolicNode, ys::SymbolicNodeOrReal...) = +(x, ys...)
broadcast_(::typeof(+), s::Real, x::SymbolicNode, ys::SymbolicNodeOrReal...) = +(x + s, ys...)

import Base: -
function -(self :: SymbolicNode, arg :: SymbolicNode)
_Minus(self, arg)
end
@compatdot function Base.broadcast(::typeof(-), self :: SymbolicNode, arg :: SymbolicNode)
-(self, arg)
end
function -(self :: SymbolicNode, arg :: Real)
_MinusScalar(self, scalar=MX_float(arg))
end
@compatdot function Base.broadcast(::typeof(-), self :: SymbolicNode, arg :: Real)
-(self, arg)
end

function -(arg :: Real, self :: SymbolicNode)
_RMinusScalar(self, scalar=arg)
end
@compatdot function Base.broadcast(::typeof(-), arg :: Real, self :: SymbolicNode)
-(arg, self)
end
"""
-(x, y)
.-(x, y)
function -(self :: SymbolicNode)
-(0, self)
end
Elementwise substraction of `SymbolicNode`.
Operating with `Real` is available.
"""
x::SymbolicNode - y::SymbolicNode = _minus(x, y)
x::SymbolicNode - s::Real = _minus_scalar(x, scalar=MX_float(s))
s::Real - x::SymbolicNode = _rminus_scalar(x, scalar=MX_float(s))

-(x::SymbolicNode) = 0 - x

broadcast_(::typeof(-), x::SymbolicNode, y::SymbolicNodeOrReal) = x - y
broadcast_(::typeof(-), s::Real, x::SymbolicNode) = s - x

import Base: *
@compatdot function Base.broadcast(::typeof(*), self :: SymbolicNode, args :: Union{SymbolicNode,Real}...)
ret = self
for arg in args
if isa(arg, SymbolicNode)
ret = _Mul(ret, arg)

"""
.*(x, y)
Elementwise multiplication of `SymbolicNode`.
"""
x::SymbolicNode * s::Real = _mul_scalar(x, scalar=MX_float(s))
s::Real * x::SymbolicNode = _mul_scalar(x, scalar=MX_float(s))

function broadcast_(::typeof(*), x::SymbolicNode, ys::SymbolicNodeOrReal...)
ret = x
for y in ys
if y isa SymbolicNode
ret = _mul(ret, y)
else
ret = _MulScalar(ret, scalar=MX_float(arg))
ret = _mul_scalar(ret, scalar=MX_float(y))
end
end
ret
end
@compatdot function Base.broadcast(::typeof(*), arg :: Real, self :: SymbolicNode,
args :: Union{SymbolicNode,Real}...)
broadcast(*, self, arg, args...)
end
function *(arg :: Real, self :: SymbolicNode)
_MulScalar(self, scalar=arg)
end
function *(self :: SymbolicNode, arg :: Real)
*(arg, self)
end

broadcast_(::typeof(*), s::Real, x::SymbolicNode, ys::SymbolicNodeOrReal...) =
broadcast_(*, x * s, ys...)

import Base: /
@compatdot function Base.broadcast(::typeof(/), self :: SymbolicNode, arg :: SymbolicNode)
_Div(self, arg)
end
@compatdot function Base.broadcast(::typeof(/), self :: SymbolicNode, arg :: Real)
_DivScalar(self, scalar=MX_float(arg))
end
function /(self :: SymbolicNode, arg :: Real)
self ./ arg
end
function /(arg :: Real, self :: SymbolicNode)
_RDivScalar(self, scalar=arg)
end
@compatdot function Base.broadcast(::typeof(/), arg :: Real, self :: SymbolicNode)
_RDivScalar(self, scalar=arg)
end

"""
./(x, y)
* Elementwise dividing a `SymbolicNode` by a scalar or another `SymbolicNode`
of the same shape.
* Elementwise divide a scalar by an `SymbolicNode`.
* Matrix division (solving linear systems) is not implemented yet.
"""
x::SymbolicNode / s::Real = _DivScalar(x, scalar=MX_float(s))

broadcast_(::typeof(/), x::SymbolicNode, y::SymbolicNode) = _div(x, y)
broadcast_(::typeof(/), x::SymbolicNode, s::Real) = _div_scalar(x, scalar=MX_float(s))
broadcast_(::typeof(/), s::Real, x::SymbolicNode) = _rdiv_scalar(x, scalar=MX_float(s))


import Base: ^
@compatdot function Base.broadcast(::typeof(^), self :: SymbolicNode, pow :: SymbolicNode)
_Power(self, pow)
end
@compatdot function Base.broadcast(::typeof(^), self :: SymbolicNode, pow :: AbstractFloat)
_PowerScalar(self, scalar=pow)
end
function ^(self :: SymbolicNode, pow :: AbstractFloat)
self .^ pow
end

"""
.^(x, y)
Elementwise power of `SymbolicNode`.
Operating with `Real` is available.
"""
^

broadcast_(::typeof(^), x::SymbolicNode, y::SymbolicNode) = _power(x, y)
broadcast_(::typeof(^), x::SymbolicNode, s::Real) = _power_scalar(x, scalar=MX_float(s))
broadcast_(::typeof(^), s::Real, x::SymbolicNode) = _rpower_scalar(x, scalar=MX_float(s))

broadcast_(::typeof(^), ::Irrational{:e}, x::SymbolicNode) = exp(x)
broadcast_(::typeof(^), x::SymbolicNode, s::Irrational) =
_power_scalar(x, scalar=MX_float(s))
broadcast_(::typeof(^), s::Irrational, x::SymbolicNode) =
_rpower_scalar(x, scalar=MX_float(s))

function _compose!(node :: SymbolicNode; kwargs...)
name = char_p(0)
Expand Down
12 changes: 12 additions & 0 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,15 @@ function mlpchain()
mx.Activation(act_type=:relu) =>
mx.FullyConnected(name=:fc2, num_hidden=10)
end

"""
execution helper of SymbolicNode
"""
function exec(x::mx.SymbolicNode; feed...)
ks, vs = zip(feed...)
vs′ = mx.NDArray.(vs)

e = mx.bind(x, context = mx.cpu(), args = Dict(zip(ks, vs′)))
mx.forward(e)
e.outputs
end

0 comments on commit 9304e6e

Please sign in to comment.