In [1]:
using Test

# Basic experiments with reverse autodiff

(This is not intended to be a full implementation of autodiff, just an exercise to understand how reverse-mode autodiff can be implemented.)

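First, we construct a context that stores the sensitivity dependencies recorded during the forward pass. Entry `k` states that node `updatesources[k]` depends on node `updatetargets[k]` with local partial derivative `updateweights[k]`.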

In [2]:
mutable struct AutodiffContext{T}
    numparams :: Int
    updatesources :: Array{Int,1}
    updatetargets :: Array{Int,1}
    updateweights :: Array{T,1}
end

function AutodiffContext{T}() where T
    return AutodiffContext{T}(0, Array{Int}(undef, 0), Array{Int}(undef, 0), Array{T}(undef, 0))
end
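To make the three parallel arrays concrete, the sketch below fills plain vectors, laid out like the fields of `AutodiffContext`, with the entries the operators defined later would record for the hypothetical expression `y = 3x + 1` (node 1 is `x`, node 2 is `3x`, node 3 is `y`). Because there is only one path from `x` to `y`, the chain rule reduces to multiplying the recorded weights.

```julia
# Hypothetical hand-built tape for y = 3x + 1, stored as three parallel vectors.
# Node numbering: 1 = x, 2 = 3x, 3 = 3x + 1.
updatesources = Int[]
updatetargets = Int[]
updateweights = Float64[]

# node 2 = 3 * node 1, so d(node 2)/d(node 1) = 3
push!(updatesources, 2); push!(updatetargets, 1); push!(updateweights, 3.0)

# node 3 = node 2 + 1, so d(node 3)/d(node 2) = 1
push!(updatesources, 3); push!(updatetargets, 2); push!(updateweights, 1.0)

# With a single path from x to y, dy/dx is the product of the edge weights.
dydx = prod(updateweights)
```

With several paths between two nodes, the contributions of the paths are summed instead, which is what the reverse sweep later does with `+=`.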

Next, we construct a type that wraps a value whose sensitivity will be tracked. Each new `ForwardVal` is assigned the next node index in its context.

In [3]:
struct ForwardVal{T}
    value :: T
    index :: Int
    context :: AutodiffContext{T}
end

function ForwardVal(value::T, context::AutodiffContext{T}) where T
    context.numparams += 1
    index = context.numparams
    return ForwardVal(value, index, context)
end
;

Then we define how to propagate the sensitivity through various operations.  We start with addition:

In [4]:
function (Base.:+)(a::ForwardVal{T}, b::T) where T
    # create a new node for this operator
    context = a.context
    result = ForwardVal(a.value + b, context)

    # add instructions to include this operator's sensitivity in a's sensitivity 
    push!(context.updatesources, result.index)
    push!(context.updatetargets, a.index)
    push!(context.updateweights, one(T))
    
    return result
end

function (Base.:+)(a::T, b::ForwardVal{T}) where T
    return b + a
end

function (Base.:+)(a::ForwardVal{T}, b::ForwardVal{T}) where T
    # record a's contribution by reusing the ForwardVal + T method above
    result = a + b.value
    # a and b are assumed to share the same context
    context = a.context

    # add instructions to include this operator's sensitivity in b's sensitivity
    push!(context.updatesources, result.index)
    push!(context.updatetargets, b.index)
    push!(context.updateweights, one(T))

    return result
end
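When both operands are tracked, one entry is pushed per operand, even if both operands are the same node. The backward sweep accumulates with `+=`, so contributions arriving along different paths add up. A minimal sketch with plain vectors and a hand-written tape for the hypothetical expression `y = x + x` (node 1 is `x`, node 2 is `y`) shows this giving dy/dx = 2:

```julia
# Hand-written tape for y = x + x: both entries point from node 2 back to node 1.
sources = [2, 2]
targets = [1, 1]
weights = [1.0, 1.0]

adj = zeros(2)
adj[end] = 1.0   # dy/dy = 1

# Walk the entries in reverse, accumulating each contribution into its target.
for k in length(weights):-1:1
    adj[targets[k]] += adj[sources[k]] * weights[k]
end
# adj[1] now holds dy/dx
```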

We next add multiplication:

In [5]:
function (Base.:*)(a::ForwardVal{T}, b::T) where T
    # create a new node for this operator
    context = a.context
    result = ForwardVal(a.value * b, context)

    # add instructions to include this operator's sensitivity in a's sensitivity 
    push!(context.updatesources, result.index)
    push!(context.updatetargets, a.index)
    push!(context.updateweights, b)
    
    return result
end

function (Base.:*)(a::T, b::ForwardVal{T}) where T
    return b * a
end

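function (Base.:*)(a::ForwardVal{T}, b::ForwardVal{T}) where T
    # record a's contribution (weight b.value) by reusing the ForwardVal * T method above
    result = a * b.value
    # a and b are assumed to share the same context
    context = a.context

    # add instructions to include this operator's sensitivity in b's sensitivity
    push!(context.updatesources, result.index)
    push!(context.updatetargets, b.index)
    push!(context.updateweights, a.value)

    return result
end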
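The weights pushed in the addition and multiplication cells are the local partial derivatives of each operation with respect to each tracked operand:

$$
r = a + b:\quad \frac{\partial r}{\partial a} = \frac{\partial r}{\partial b} = 1,
\qquad\qquad
r = a\,b:\quad \frac{\partial r}{\partial a} = b,\quad \frac{\partial r}{\partial b} = a.
$$

This is why the multiplication methods push `b` (or `b.value`) as the weight for `a`, and `a.value` as the weight for `b`.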

For simplicity, we'll stop here. Other operations such as division, exponentiation, and unary functions follow the same pattern: compute the result, create a new node, and push one entry per tracked operand whose weight is the local partial derivative. Real-world use would need them, but they would complicate this exercise without adding much educational value.
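As a sketch of that pattern, the standalone code below re-creates a minimal tape under hypothetical names (`Tape`, `Tracked`, `track!`, `record!`, `backprop`; none of these come from a library) so that it runs on its own, and adds subtraction, division and `exp`. Each method computes the value, allocates a node, and records one weight per tracked operand: $-1$ for the subtracted operand, $1/b$ and $-a/b^2$ for division, and $e^a$ for `exp`.

```julia
# Standalone sketch; Tape, Tracked, track!, record!, backprop are hypothetical names.

# Tape of (source, target, weight) entries, stored as parallel arrays.
mutable struct Tape{T}
    n::Int
    sources::Vector{Int}
    targets::Vector{Int}
    weights::Vector{T}
end
Tape{T}() where {T} = Tape{T}(0, Int[], Int[], T[])

# A value together with its node index on a tape.
struct Tracked{T}
    value::T
    index::Int
    tape::Tape{T}
end

# Allocate the next node index for a new value.
function track!(tape::Tape{T}, value::T) where {T}
    tape.n += 1
    return Tracked(value, tape.n, tape)
end

# Record that d(result)/d(input) = w.
function record!(result::Tracked{T}, input::Tracked{T}, w::T) where {T}
    push!(result.tape.sources, result.index)
    push!(result.tape.targets, input.index)
    push!(result.tape.weights, w)
    return result
end

# Subtraction: d(a - b)/da = 1, d(a - b)/db = -1.
function Base.:-(a::Tracked{T}, b::Tracked{T}) where {T}
    r = track!(a.tape, a.value - b.value)
    record!(r, a, one(T))
    record!(r, b, -one(T))
    return r
end

# Division: d(a / b)/da = 1/b, d(a / b)/db = -a/b^2.
function Base.:/(a::Tracked{T}, b::Tracked{T}) where {T}
    r = track!(a.tape, a.value / b.value)
    record!(r, a, one(T) / b.value)
    record!(r, b, -a.value / b.value^2)
    return r
end

# Unary function: d(exp(a))/da = exp(a).
function Base.exp(a::Tracked{T}) where {T}
    r = track!(a.tape, exp(a.value))
    record!(r, a, r.value)
    return r
end

# Same reverse sweep as `back`; assumes the output is the last node created.
function backprop(tape::Tape{T}) where {T}
    adj = zeros(T, tape.n)
    adj[end] = one(T)
    for k in length(tape.weights):-1:1
        adj[tape.targets[k]] += adj[tape.sources[k]] * tape.weights[k]
    end
    return adj
end

tape = Tape{Float64}()
x = track!(tape, 3.0)
y = track!(tape, 2.0)
z = exp(x / y - y)    # z = exp(x/y - y)
g = backprop(tape)    # g[x.index] ≈ z/y, g[y.index] ≈ -z * (x/y^2 + 1)
```

Only tracked-with-tracked methods are shown; mixed methods taking a plain constant would follow the same structure as the `ForwardVal` methods above.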

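We next define a function to test this with, limited to the operations defined above. The constants are written as `Float64` literals (`3.`, `2.`, `1.`) because the mixed methods only accept a plain operand of the same type `T` as the wrapped value; an `Int` literal such as `3` would raise a `MethodError`.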

In [6]:
f(x, a, b, c) = ((((3. *a)*x)*x + 2. *b*x) + c) + 1.;

As a reference to test our results against, we'll hand-compute the partial derivatives in symbolic form. 

In [7]:
f_x(x, a, b, c) = 6a*x + 2b
f_a(x, a, b, c) = 3x^2
f_b(x, a, b, c) = 2x
f_c(x, a, b, c) = 1;
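Expanding the nested expression in `f` makes the hand derivation explicit:

$$
f(x, a, b, c) = 3a\,x^2 + 2b\,x + c + 1,
$$

$$
\frac{\partial f}{\partial x} = 6a\,x + 2b,\qquad
\frac{\partial f}{\partial a} = 3x^2,\qquad
\frac{\partial f}{\partial b} = 2x,\qquad
\frac{\partial f}{\partial c} = 1.
$$

At the evaluation point used below, $(x, a, b, c) = (3, 2, 4, 5)$, these evaluate to $44$, $27$, $6$ and $1$, and $f = 84$.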

We wrap each parameter in a `ForwardVal` and propagate the wrappers through the function. The resulting context holds 12 nodes (the 4 inputs plus 8 intermediate results) and 13 recorded entries.

In [8]:
context = AutodiffContext{Float64}()
x₀ = 3.0
a₀ = 2.0
b₀ = 4.0
c₀ = 5.0

xforward = ForwardVal(x₀, context)
aforward = ForwardVal(a₀, context)
bforward = ForwardVal(b₀, context)
cforward = ForwardVal(c₀, context)

fforward = f(xforward, aforward, bforward, cforward)

ForwardVal{Float64}(84.0, 12, AutodiffContext{Float64}(12, [5, 6, 6, 7, 7, 8, 9, 9, 10, 10, 11, 11, 12], [2, 5, 1, 6, 1, 3, 8, 1, 7, 9, 10, 4, 11], [3.0, 3.0, 6.0, 3.0, 18.0, 2.0, 3.0, 8.0, 1.0, 1.0, 1.0, 1.0, 1.0]))
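The printed context is the entire tape. The sketch below copies those three arrays and prints each entry as a local partial derivative, using node labels assigned by hand from the order in which `f` creates nodes (the labels are our own, not part of the context). It also checks that every source index is larger than its target index, which is what allows a single reverse sweep to work.

```julia
# Tape arrays copied from the AutodiffContext printed above.
sources = [5, 6, 6, 7, 7, 8, 9, 9, 10, 10, 11, 11, 12]
targets = [2, 5, 1, 6, 1, 3, 8, 1, 7, 9, 10, 4, 11]
weights = [3.0, 3.0, 6.0, 3.0, 18.0, 2.0, 3.0, 8.0, 1.0, 1.0, 1.0, 1.0, 1.0]

# Hand-assigned labels for nodes 1..12, in creation order.
labels = ["x", "a", "b", "c", "3a", "3a*x", "3a*x*x", "2b", "2b*x",
          "3a*x*x + 2b*x", "(...) + c", "(...) + 1"]

# Print each entry as d(source)/d(target) = weight.
for (s, t, w) in zip(sources, targets, weights)
    println("d(", labels[s], ")/d(", labels[t], ") = ", w)
end

# Every entry points from a newer node to an older one.
is_topological = all(sources .> targets)
```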

Let's just confirm that we calculated the correct value of the function itself:

In [9]:
@test fforward.value == f(x₀, a₀, b₀, c₀)

[32m[1mTest Passed[22m[39m

Now we back-propagate using the data gathered during the forward pass, processing the recorded entries in reverse order.

In [10]:
function back(context::AutodiffContext{T}) where T
    result = zeros(T, context.numparams)
    
    # the output is assumed to be the last node created; its sensitivity with respect to itself is 1
    result[end] = one(T)
    
    # back-propagate the sensitivities
    for i = length(context.updateweights):-1:1
        result[context.updatetargets[i]] += result[context.updatesources[i]] * context.updateweights[i]
    end
    
    return result
end
;
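The loop in `back` is the standard adjoint recurrence. Writing $\bar v_j$ for the sensitivity of the output with respect to node $j$, and $(s_k, t_k, w_k)$ for the $k$-th recorded entry, it computes

$$
\bar v_N = 1, \qquad
\bar v_{t_k} \mathrel{+}= w_k\,\bar v_{s_k} \quad \text{for } k = K, K-1, \dots, 1,
\qquad \text{where } w_k = \frac{\partial v_{s_k}}{\partial v_{t_k}},
$$

with $N$ equal to `numparams` and $K$ the number of entries. Reverse order is sufficient because every entry that targets node $j$ was recorded when some later node was created, that is, after all entries whose source is $j$. So by the time an entry with source $s_k$ is processed, $\bar v_{s_k}$ has already received all of its contributions.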

In [11]:
sensitivities = back(fforward.context)

12-element Array{Float64,1}:
 44.0
 27.0
  6.0
  1.0
  9.0
  3.0
  1.0
  3.0
  1.0
  1.0
  1.0
  1.0
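Only the first four entries correspond to parameters; entries 5 to 12 are the sensitivities of intermediate nodes. They can be checked by hand from $f = (3a)\,x^2 + (2b)\,x + c + 1$:

$$
\frac{\partial f}{\partial (3a)} = x^2 = 9, \qquad
\frac{\partial f}{\partial (3a\,x)} = x = 3, \qquad
\frac{\partial f}{\partial (2b)} = x = 3.
$$

Nodes 7 and 9 to 12 ($3a\,x^2$, $2b\,x$ and the three running sums) enter $f$ with coefficient 1, so their sensitivities are 1.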

We now pull out the sensitivities for each parameter and confirm that they match the hand-calculated partial derivatives of f.

In [12]:
@test sensitivities[xforward.index] == f_x(x₀, a₀, b₀, c₀)
@test sensitivities[aforward.index] == f_a(x₀, a₀, b₀, c₀)
@test sensitivities[bforward.index] == f_b(x₀, a₀, b₀, c₀)
@test sensitivities[cforward.index] == f_c(x₀, a₀, b₀, c₀)

[32m[1mTest Passed[22m[39m

We thus appear to have a working example of reverse autodiff.