In [1]:
abstract type Node end
abstract type Operator end
abstract type LeafNode <: Node end

In [23]:
mutable struct Variable{T} <: LeafNode
    value::T
    grad::T
    name::String
end
Variable(value) = Variable(value, zero(value), "?")
Variable(value, name) = Variable(value, zero(value), name)

Variable

In [3]:
struct Tensor{T} <: LeafNode
    value::Array{Variable{T}}
    name::String
end
Tensor(value, name) = Tensor(Variable.(value), name)

Tensor

In [4]:
import Base: zero, one
zero(::Variable{T}) where T = Variable(zero(T), zero(T), "0")
one(::Variable{T}) where T = Variable(one(T), zero(T), "1")
nothing

In [5]:
struct Method{OT} <: Operator
    f::OT
end

In [6]:
struct Broadcasted{OT} <: Operator
    f::OT
end

In [7]:
OT = Operator
AT = Tuple
KT = NamedTuple
struct ComputableNode <: Node
    op::OT
    args::AT
    kwargs::KT
end
ComputableNode(op::Function, args, kwargs) = ComputableNode(Method(op), args, kwargs)
ComputableNode(op, args)                   = ComputableNode(op, args, NamedTuple())

ComputableNode

In [8]:
mutable struct CachedNode{NT <: Node, OUT} <: Node
    node::NT
    out::OUT
end

function register(op, args...; kwargs...)
    node = ComputableNode(op, args, kwargs.data)
    out  = forward(node)
    CachedNode(node, out)
end

register (generic function with 1 method)

In [9]:
arg(x::ComputableNode, i::Int) = x.args[i]
args(x::ComputableNode) = x.args
kwargs(x::ComputableNode) = x.kwargs
operator(x::ComputableNode) = x.f

arg(x::CachedNode, i::Int) = x.node.args[i]
args(x::CachedNode) = x.node.args
kwargs(x::CachedNode) = x.node.kwargs
operator(x::CachedNode) = x.node.f

operator (generic function with 2 methods)

In [34]:
import Base: show
mime = "text/plain"
show(io::IO, x::Method)         = print(io, "fn ",  x.f);
show(io::IO, x::Broadcasted)    = print(io, "bc ",  x.f);
show(io::IO, x::Operator)       = print(io, "op ",  x.f);
show(io::IO, x::Variable)       = print(io, "var ", x.name, " ", typeof(x.value), " ∇ ");
show(io::IO, x::Tensor)         = print(io, "tsr ", x.name, " ", join(size(value(x)), "×"));
show(io::IO, x::CachedNode)     = begin
    print(io, "cached ", x.node);# show(io, mime, x.out);
end
show(io::IO, x::ComputableNode) = print(io, "[", x.op, "](", join(x.args,","), ")");

In [35]:
forward(cached::CachedNode) = cached.out = forward(cached.node)
forward(node::ComputableNode) = forward(node.op, map(forward, node.args)...; map(forward, node.kwargs)...)
forward(op::Operator, args...; kwargs...) = op.f(args...; kwargs...)
forward(op::Broadcasted, args...; kwargs...) = op.f.(args...)
forward(leaf::LeafNode) = value(leaf)
forward(x) = x
forward(x::NT) where {NT <: Node} = error("forward method is not implemented for node type: $NT")

forward (generic function with 7 methods)

In [12]:
value(x::CachedNode) = value(x.out)
value(x::Variable) = x.value
value(x::Tensor) = x.value .|> value 
value(x) = x
value(x::NT) where {NT <: Node} = error("Expected value in this node $x of type $T
 check if you defined a non-cached node
 or overload value function for your node.")
grad(x::Variable) = x.grad
grad(x::Tensor) = x.value .|> grad

grad (generic function with 2 methods)

In [38]:
function backward(x::Variable, grad)
    x.grad += grad
    nothing
end

function backward(x::Tensor, grad)
    backward.(x.value, grad)
    nothing
end

function backward(cached::CachedNode, f::Function, grad)
    println("@", cached)
    grad_inputs = gradient(cached, grad)
    for (each, each_grad) in zip(args(cached), grad_inputs)
        backward(each, each_grad)
    end
    nothing
end

backward(cached::CachedNode) = backward(cached, 1.0)
backward(cached::CachedNode, grad) = backward(cached, cached.node.op, grad)
backward(cached::CachedNode, op::Method, grad) = backward(cached, op.f, grad)
backward(cached::CachedNode, op::Broadcasted, grad) = backward(cached, op.f, grad)

backward (generic function with 7 methods)

In [39]:
gradient(x::CachedNode, grad) = gradient(x.node.op, grad, x.out, map(value, x.node.args)...; map(value, x.node.kwargs)...)
gradient(x::Broadcasted, grad, out, arg) = begin
    println("gradient grad ", grad)
    println("gradient out  ", out[:])
    println("gradient arg  ", arg[:])
    ∇(a,b) = gradient(x.f, grad, a, b)[1]
    ∇.(out, arg)
end
gradient(x::Method, grad, out, args...; kwargs...) = gradient(x.f, grad, out, args...; kwargs...)
gradient(op, grad, out, args...; kwargs...) = error("gradient of operator $op is not defined\n
 Possible Fix:\n
 define one of the following:\n
 1. gradient(::typeof($op), grad, out, args...; kwargs...)\n
 2. gradient(op::Method{typeof($op)}, grad, out, args...; kwargs...)\n")

gradient (generic function with 21 methods)

In [15]:
import Base: +, -, *, /
+(x::Node) = register(+, x)
-(x::Node) = register(-, x)
gradient(::typeof(+), grad, output, x) = (grad * 1, )
gradient(::typeof(-), grad, output, x) = (grad *-1, )
+(x::Node, y::Node) = register(+, x, y)
-(x::Node, y::Node) = register(-, x, y)
*(x::Node, y::Node) = register(*, x, y)
/(x::Node, y::Node) = register(/, x, y)
gradient(::typeof(+), grad, output, x, y) = (grad * one(x),   grad * one(y))
gradient(::typeof(-), grad, output, x, y) = (grad * one(x),   grad *-one(y))
gradient(::typeof(*), grad, output, x, y) = (grad * y,        grad * x)
gradient(::typeof(/), grad, output, x, y) = (grad * one(x)/y, grad *-x/y/y)

gradient (generic function with 10 methods)

In [16]:
import Base: abs, sin, cos, tan, exp, sqrt, max, min
abs(x::Node)  = register(abs, x)
sin(x::Node)  = register(sin, x)
cos(x::Node)  = register(cos, x)
tan(x::Node)  = register(tan, x)
exp(x::Node)  = register(exp, x)
sqrt(x::Node) = register(sqrt, x)
max(x::Node, y::Node) = register(max, isless(value(x), value(y)) ? y : x)
min(x::Node, y::Node) = register(min, isless(value(x), value(y)) ? x : y)
gradient(::typeof(sqrt), grad, output, x) = (grad * 0.5/sqrt(x), )
gradient(::typeof(abs), grad, output, x)  = (grad * sign(x), )
gradient(::typeof(sin), grad, output, x)  = (grad * cos(x), )
gradient(::typeof(cos), grad, output, x)  = (grad *-sin(x), )
gradient(::typeof(tan), grad, output, x)  = (grad *(tan(x)^2 + 1), )
gradient(::typeof(exp), grad, output, x)  = (grad * exp(x), )
gradient(::typeof(max), grad, output, x)  = (grad * one(x),)
gradient(::typeof(min), grad, output, x)  = (grad * one(x),)
gradient(::typeof(min), grad, output, x, y) = (isless(value(x), value(y)) ? grad * one(x) : grad * zero(x),
                                               isless(value(x), value(y)) ? grad * zero(y) : grad * one(y))

gradient (generic function with 19 methods)

In [47]:
import Base: maximum, broadcasted
broadcasted(f::Function, x::Node) = register(Broadcasted(f), x)
maximum(x::Node; kwargs...) = register(Method(maximum), x; kwargs...)
gradient(::typeof(maximum), grad, output, x; kwargs...) = begin
    res = zeros(size(x))
    
    for k in axes(x, 3)
        for j in axes(x, 2)
            for i in axes(x, 1)
                if x[i,j,k] == output[1,1,k]
                    res[i,j,k] = grad * one(eltype(output))
                end
            end
        end
    end
    (res, )
end

gradient (generic function with 21 methods)

In [48]:
maxpool(x::Node) = register(maxpool, x)
gradient(::typeof(maxpool), grad, output, input) = begin
    res = similar(input)

    for i in axes(input, 1)
        for j in axes(input, 2)
            for k in axes(input, 3)
                I, J = ceil.(Integer, (i, j) ./ 2)
                if input[i,j,k] == output[I,J,k]
                    res[i,j,k] = grad * one(eltype(output))
                end
            end
        end
    end
    (res, )
end

gradient (generic function with 21 methods)

In [49]:
r = rand(8,8,3)
v = Tensor(r, "v")

tsr v 8×8×3

In [50]:
function maxpool(input)
    n, m, k = size(input)
    N, M = floor.(Integer, (n, m) ./ 2)
    output = similar(input, N, M, k)
    for i=1:N
        for j=1:M
            region = input[ 2(i)-1:2(i)-0, 2(j)-1:2(j)-0, :]
            output[i, j, :] = maximum(region; dims=(1,2))
        end
    end
    output
end
c = maxpool(maxpool(v))

cached [fn maxpool](cached [fn maxpool](tsr v 8×8×3))

In [51]:
backward(c)
grad(v)

@cached [fn maxpool](cached [fn maxpool](tsr v 8×8×3))
@cached [fn maxpool](tsr v 8×8×3)


MethodError: MethodError: Cannot `convert` an object of type Array{Float64,3} to an object of type Float64
Closest candidates are:
  convert(::Type{T}, !Matched::T) where T<:Number at number.jl:6
  convert(::Type{T}, !Matched::Number) where T<:Number at number.jl:7
  convert(::Type{T}, !Matched::Base.TwicePrecision) where T<:Number at twiceprecision.jl:250
  ...