# Autodifferentiation
## Forward mode

In [1]:
"""
    Var(x)

Leaf of an expression tree: an input variable holding the numeric value `x`.

The type is declared `mutable` so that every `Var` has its own identity. With an
immutable struct, `===` (and the default `==`, which falls back to `===`)
compares field values, so two different variables that happen to store the same
number, e.g. `Var(3)` and `Var(3)`, would be indistinguishable when deciding
which input an expression is differentiated with respect to.
"""
mutable struct Var
    x::Real
end

# Value and partial derivative of a sum: both are accumulated by addition.
function _forward_sum(operands, variable::Var)
    total, ∂total = 0, 0
    for operand in operands
        p, ∂p = derivative(operand, variable)
        total += p      # sum of the operand values
        ∂total += ∂p    # sum of the operand partials
    end
    return (total, ∂total)
end

# Value and partial derivative of a product, using the general product rule.
function _forward_product(operands, variable::Var)
    results = [derivative(operand, variable) for operand in operands]
    vals = map(first, results)
    ∂vals = map(last, results)

    product = 1
    for v in vals
        product *= v
    end

    # ∂(abc) = (bc)∂a + (ac)∂b + (ab)∂c.
    # The product of the *other* factors is formed explicitly rather than as
    # `product / vals[i]`, which would give 0/0 = NaN whenever a factor is zero.
    ∂product = 0
    for i in eachindex(vals)
        others = 1
        for j in eachindex(vals)
            j == i || (others *= vals[j])
        end
        ∂product += others * ∂vals[i]
    end
    return (product, ∂product)
end

"""
    derivative(ex::Union{Expr,Var}, variable::Var)

Forward-mode differentiation: evaluate `ex` and its partial derivative with
respect to `variable` in a single recursive pass.

# Arguments
- `ex`: a `Var` leaf, or a call expression whose operation is `+` or `*` and
  whose operands are themselves `Var`s or such expressions.
- `variable`: the `Var` to differentiate with respect to. A leaf counts as this
  variable when it is `===` to it.

# Returns
A tuple `(value, partial)`. No division is performed, so integer inputs yield
integer results.

# Throws
- `ArgumentError` if `ex` is an expression that is not a `+` or `*` call.
"""
function derivative(ex::Union{Expr,Var}, variable::Var)
    # Base case: a leaf has derivative 1 w.r.t. itself and 0 otherwise.
    if ex isa Var
        return (ex.x, ex === variable ? 1 : 0)
    end

    if ex.head === :call && !isempty(ex.args)
        op = ex.args[1]             # Expr(:call, op, a, b, ...)
        operands = ex.args[2:end]
        op === :+ && return _forward_sum(operands, variable)
        op === :* && return _forward_product(operands, variable)
    end

    # Fail loudly instead of silently returning `nothing` to the caller.
    throw(ArgumentError("unsupported expression: $ex"))
end



derivative (generic function with 1 method)

In [2]:
# Demo: z = x*x + y*x*y evaluated at x = 6, y = 7.
x, y = Var(6), Var(7)

# The same expression tree as the quoted form, built node by node.
z = Expr(:call, :+, Expr(:call, :*, x, x), Expr(:call, :*, y, x, y))

# Print the value of z together with its partial for each input in turn.
for (label, wrt) in (("(z,∂z/∂x):", x), ("(z,∂z/∂y):", y))
    println(label, derivative(z, wrt))
end

(z,∂z/∂x):(330, 61.0)
(z,∂z/∂y):(330, 84.0)


## Reverse mode

In [1]:
"""
    Var(x, partial)
    Var(x)

Input variable for reverse-mode differentiation.

# Fields
- `x`: the numeric value of the variable.
- `partial`: the partial derivative accumulated for this variable; it is a
  mutable field so it can be updated in place.
"""
mutable struct Var
    x::Real
    partial::Real
end

# Convenience constructor: a fresh variable starts with a zero partial.
function Var(x::Real)
    return Var(x, 0)
end

"""
    Node(ex, val, child)
    Node(ex)

One vertex of the evaluation graph.

# Fields
- `ex`: the expression or variable this node stands for.
- `val`: the numeric value stored for `ex`.
- `child`: the nodes of the operands of `ex` (empty for a variable).
"""
mutable struct Node
    ex::Union{Expr,Var}
    val::Real
    child::Vector{Node}
end

# Convenience constructor: value 0 and no children yet.
function Node(ex::Union{Expr,Var})
    return Node(ex, 0, Node[])
end

#Returns evaluation graph consisting of Nodes
"""
    evaluate(ex::Union{Var,Expr})

Build the evaluation graph of `ex` and compute the value of every node.

# Arguments
- `ex`: a `Var` leaf, or a call expression whose operation is `+` or `*` and
  whose operands are themselves `Var`s or such expressions.

# Returns
The root `Node`. Its `val` is the value of `ex`, and its `child` vector holds
one node per operand, in operand order.

# Throws
- `ArgumentError` if `ex` is an expression that is not a `+` or `*` call.
"""
function evaluate(ex::Union{Var,Expr})
    # Base case: a variable becomes a leaf carrying its own value.
    ex isa Var && return Node(ex, ex.x, Node[])

    if ex.head === :call && !isempty(ex.args)
        op = ex.args[1]             # Expr(:call, op, a, b, ...)
        if op === :+ || op === :*
            node = Node(ex)
            combine = op === :+ ? (+) : (*)
            node.val = op === :+ ? 0 : 1    # identity element of the operation
            for operand in ex.args[2:end]
                sub = evaluate(operand)     # recurse down to the leaves
                push!(node.child, sub)
                node.val = combine(node.val, sub.val)
            end
            return node
        end
    end

    # Fail loudly instead of silently returning `nothing` to the caller.
    throw(ArgumentError("unsupported expression: $ex"))
end

"""
    derivative(node::Node, gradient::Real)

Reverse-mode sweep: propagate `gradient` from `node` down to the leaves of its
evaluation graph, adding each leaf's contribution to that variable's `partial`.

# Arguments
- `node`: root of the (sub)graph to traverse; node values must already be set.
- `gradient`: derivative of the final output with respect to `node`
  (pass `1` when `node` is the output itself).

# Returns
`nothing`. The results are accumulated in the `partial` field of the `Var`s.
Partials are added to, never reset, so repeating the sweep adds the
contributions again.

# Throws
- `ArgumentError` if a node holds an expression that is not a `+` or `*` call.
"""
function derivative(node::Node, gradient::Real)
    ex = node.ex

    # Base case: accumulate the incoming gradient on the variable.
    if ex isa Var
        ex.partial += gradient
        return nothing
    end

    op = (ex.head === :call && !isempty(ex.args)) ? ex.args[1] : nothing
    if op === :+
        # Every operand of a sum receives the gradient unchanged.
        for sub in node.child
            derivative(sub, gradient)
        end
    elseif op === :*
        # Operand i receives gradient * (product of its sibling values).
        # The sibling product is formed explicitly rather than as
        # `node.val / child.val`, which would give 0/0 = NaN for a zero factor.
        children = node.child
        for i in eachindex(children)
            siblings = 1
            for j in eachindex(children)
                j == i || (siblings *= children[j].val)
            end
            derivative(children[i], gradient * siblings)
        end
    else
        # Fail loudly instead of silently skipping part of the graph.
        throw(ArgumentError("unsupported expression: $ex"))
    end
    return nothing
end



derivative (generic function with 1 method)

In [2]:
# Demo: z = x*(x + y) + y*x*y evaluated at x = 6, y = 7.
x, y = Var(6), Var(7)

# The same expression tree as the quoted form, built node by node.
z = Expr(:call, :+,
         Expr(:call, :*, x, Expr(:call, :+, x, y)),
         Expr(:call, :*, y, x, y))

t = evaluate(z)     # forward sweep: build the graph and the node values
derivative(t, 1)    # reverse sweep, seeded with ∂z/∂z = 1

# Report the value of z and the partial accumulated on each input.
for (label, value) in (("z: ", t.val), ("∂z/∂x: ", x.partial), ("∂z/∂y: ", y.partial))
    println(label, value)
end

z: 372
∂z/∂x: 68.0
∂z/∂y: 90.0
