# Reverse mode AD

In [1]:
mutable struct Node
    val::Float64
    # parent index
    par1::Int64
    par2::Int64
    # adjoint ∂f/∂v
    adj::Float64
    der1::Float64
    der2::Float64
    index::Int64
end

mutable struct Tape
    tape::Vector{Node}
    size::Int64
end

In [2]:
# overloading basic commands
Base.:log(f::Node) = log_AD(f, tape)
Base.:*(f::Node, g::Node) = mult_AD(f,g, tape)
Base.:*(n::Integer, f::Node) = scalar_mult_AD(n,f, tape)
Base.:sin(f::Node) = sin_AD(f, tape)
Base.:+(f::Node, g::Node) = add_AD(f,g, tape)
Base.:-(f::Node,g::Node) = sub_AD(f,g, tape)
Base.:/(f::Node,g::Node) = divide_AD(f,g, tape)

# Define exponential function for node data type
Base.:exp(f::Node) = exp_AD(f,tape)

In [3]:
# here are the actual funtions
function log_AD(f::Node, tape)::Node
    ans = Node(log(f.val), f.index, -1, 0, 1/f.val, 0, tape.size)
    tape.tape[tape.size] = ans
    tape.size += 1
    return ans
end

function mult_AD(f::Node, g::Node, tape)::Node
    ans = Node(f.val * g.val, f.index, g.index, 0, g.val, f.val, tape.size)
    tape.tape[tape.size] = ans
    tape.size += 1
    return ans
end

function scalar_mult_AD(n::Integer, f::Node, tape)::Node
    ans = Node(n*f.val, f.index, -1, 0, n, 0, tape.size)
    tape.tape[tape.size] = ans
    tape.size += 1
    return ans
end

function divide_AD(f::Node, g::Node, tape)::Node
    ans = Node(f.val / g.val, f.index, g.index, 0, 1/g.val, -f.val/(g.val)^2, tape.size)
    tape.tape[tape.size] = ans
    tape.size += 1
    return ans
end

function sin_AD(f::Node, tape)::Node
    ans = Node(sin(f.val), f.index, -1, 0, cos(f.val), 0, tape.size)
    tape.tape[tape.size] = ans
    tape.size += 1
    return ans
end

function add_AD(f::Node, g::Node, tape)
    ans = Node(f.val + g.val, f.index, g.index, 0, 1, 1, tape.size)
    tape.tape[tape.size] = ans
    tape.size += 1
    return ans
end

function sub_AD(f::Node, g::Node, tape)
    ans = Node(f.val - g.val, f.index, g.index, 0, 1, -1, tape.size)
    tape.tape[tape.size] = ans
    tape.size += 1
    return ans
end

function exp_AD(f::Node,tape)::Node
    ans = Node(exp(f.val),f.index,-1,0,exp(f.val),0,tape.size)
    tape.tape[tape.size] = ans
    tape.size += 1
    return ans
end
    

exp_AD (generic function with 1 method)

In [4]:
function rewind(tape::Tape)
    tape.tape[tape.size - 1].adj = 1
    for i in tape.size-1:-1:1
        if (tape.tape[i].par1 != -1)
            tape.tape[tape.tape[i].par1].adj += tape.tape[i].adj * tape.tape[i].der1
        end
        if (tape.tape[i].par2 != -1)
            tape.tape[tape.tape[i].par2].adj += tape.tape[i].adj * tape.tape[i].der2
        end
    end
end

rewind (generic function with 1 method)

In [10]:
# initializing variables
x_1 = Node(1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1)
x_2 = Node(2.0, -1.0, -1.0, 0.0, 0.0, 0.0, 2)
x_3 = Node(3.0, -1.0, -1.0, 0.0, 0.0, 0.0, 3)

# initialize tape
x = 10
size = 1
tape = Tape(Array{Node,1}(undef, x), size)

# storing numbers in tape
tape.tape[tape.size] = x_1
tape.size += 1
tape.tape[tape.size] = x_2
tape.size += 1
tape.tape[tape.size] = x_3
tape.size += 1

# here is the function
#y=log(x_1)+x_1*x_2-sin(x_2)
y = ( x_1*x_2*sin(x_3) + exp(x_1*x_2) ) / x_3
#
# y = exp(2*x_1)

# rewind tape to find adjoints
rewind(tape)

# here are the partial derivatives
# ∂y/∂x_1 = x2 * sin(x3)/x3 + exp(x1 * x2) * x2/x3
println("from AD")
println(tape.tape[1].adj)
println("true")
println(2*sin(3)/3 + exp(1*2) * 2/ 3)
    
# ∂y/∂x_2  = x1*sin(x3) + exp(x1*x2)*x1/x3
println("from AD")
println(tape.tape[2].adj)
println("true")
println(1*sin(3)/3 + exp(1*2) * 1/ 3)

# ∂y/∂x_3  = x1*x2*(cos(x3)*x3 - sin(X3))/x3 - exp(x1*x2)/x3^2
println(tape.tape[3].adj)

LoadError: syntax: missing comma or ) in argument list

In [6]:
#df/dx1
2*sin(3) + exp(1*2) * 2/ 3

5.208277415406835

In [7]:
#df/dx2
1*sin(3) + exp(1*2) * 1/ 3

2.6041387077034175

In [8]:
#df/dx2
1*2*cos(3) -  exp(1*2)/3^2

-2.8009912264154075

Doesn't seem to match

In [9]:
# initializing variables
x_1 = Node(2.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1)

# initialize tape
x = 15
size = 1
tape2 = Tape(Array{Node,1}(undef, x), size)
y = exp(3*x_1)

# rewind tape to find adjoints
rewind(tape2)

# here are the partial derivatives
# ∂y/∂x_1 = 2*exp(x1)
println("AD computation: ")
println(tape2.tape[1].adj)

print("true derivative")
println(3*exp(2))



LoadError: BoundsError: attempt to access 10-element Vector{Node} at index [11]