## Some examples of automatic differentiation (AD)

Let's consider Speelpenning's example $f(x)=\prod_{i=1}^{n} x_i$.
We can code this in Julia in a variety of ways:

In [None]:
f(x) = prod(x)
g(x) = reduce(*,x)

The Julia documentation recommends using `prod(x)` rather than `reduce(*, x)`.
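One concrete reason for this recommendation: according to its docstring, `prod` widens signed integer types smaller than the system word size to `Int` before multiplying, while `reduce(*, ...)` keeps the element type and can silently overflow. A minimal illustration (the variable names are ours):

```julia
# Int8 holds values in -128:127, so 100 * 2 does not fit into an Int8
v = Int8[100, 2]

p_prod   = prod(v)        # prod widens small signed integers to Int first
p_reduce = reduce(*, v)   # reduce keeps Int8 and wraps around on overflow

println(p_prod, " ", typeof(p_prod))      # 200 Int64 (on a 64-bit system)
println(p_reduce, " ", typeof(p_reduce))  # -56 Int8
```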

Let's give it a try:

In [None]:
x = collect(1.0:5.0) # the numbers 1 to 5 as Float64: [1.0, 2.0, 3.0, 4.0, 5.0]
f(x)

Now let's calculate the gradient.
There is no need to implement AD ourselves: we can choose one of the many AD packages already available in the Julia ecosystem.
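Before reaching for a package, it helps to know the exact answer: for Speelpenning's example $\partial f/\partial x_j = \prod_{i \ne j} x_i$, so for $x = (1,2,3,4,5)$ the gradient is $(120, 60, 40, 30, 24)$. A small reference implementation (the name `speelpenning_grad` is ours, not from any package) computes it in $O(n)$ with prefix and suffix products, avoiding division so that it also works when some $x_i = 0$:

```julia
# Exact gradient of f(x) = prod(x): g[j] = product of all x[i] with i != j.
# Prefix/suffix products avoid dividing by x[j] (which would fail for x[j] == 0).
function speelpenning_grad(x::AbstractVector{<:Real})
    n = length(x)
    g = ones(float(eltype(x)), n)
    left = one(eltype(g))            # running product x[1] * ... * x[j-1]
    for j in 1:n
        g[j] = left
        left *= x[j]
    end
    right = one(eltype(g))           # running product x[j+1] * ... * x[n]
    for j in n:-1:1
        g[j] *= right
        right *= x[j]
    end
    return g
end

speelpenning_grad([1.0, 2.0, 3.0, 4.0, 5.0])   # [120.0, 60.0, 40.0, 30.0, 24.0]
```

These values are a convenient reference for checking the AD results below.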

Let's start with forward (tangent) mode:

In [None]:
import ForwardDiff
ForwardDiff.gradient(f,x)
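ForwardDiff works by propagating dual numbers through the primal code. The toy sketch below (the type `Dual` and the functions `forward_gradient` and `speelpenning` are hypothetical and far simpler than ForwardDiff, which for example carries several partials per pass in chunks) shows the principle: each pass pushes one tangent direction $e_j$ through the computation, so the full gradient of $f:\mathbb{R}^n \to \mathbb{R}$ needs $n$ passes.

```julia
# Toy dual number: a value together with one directional derivative (tangent)
struct Dual <: Number
    val::Float64
    der::Float64
end

# Product rule: (a + a' eps) * (b + b' eps) = a*b + (a'*b + a*b') eps
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)

# One forward pass per input direction e_j
function forward_gradient(f, x::Vector{Float64})
    n = length(x)
    g = zeros(n)
    for j in 1:n
        xd = [Dual(x[i], i == j ? 1.0 : 0.0) for i in 1:n]  # seed tangent e_j
        g[j] = f(xd).der
    end
    return g
end

speelpenning(xs) = foldl(*, xs)   # only needs * on the element type
forward_gradient(speelpenning, [1.0, 2.0, 3.0, 4.0, 5.0])   # [120.0, 60.0, 40.0, 30.0, 24.0]
```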

What about reverse mode?

In [None]:
import Zygote
Zygote.gradient(f,x)
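Reverse mode instead runs the computation forward once, records the intermediate results, and then propagates a single adjoint backwards; one reverse sweep yields the whole gradient. Below is a hand-written sketch for the product, returning a `(value, pullback)` pair in the spirit of `Zygote.pullback` (the names `prod_with_pullback`, `ybar`, `xbar`, `pbar` are ours; Zygote derives comparable code automatically from the primal):

```julia
# Hand-written reverse mode for y = x[1] * x[2] * ... * x[n]
function prod_with_pullback(x::Vector{Float64})
    n = length(x)
    # forward sweep: store the partial products p[k] = x[1] * ... * x[k]
    p = similar(x)
    p[1] = x[1]
    for k in 2:n
        p[k] = p[k-1] * x[k]
    end
    # reverse sweep: propagate the adjoint ybar from p[n] back to the inputs
    function pullback(ybar)
        xbar = similar(x)
        pbar = ybar                 # adjoint of p[n]
        for k in n:-1:2
            # primal step was p[k] = p[k-1] * x[k]
            xbar[k] = pbar * p[k-1]
            pbar = pbar * x[k]      # now the adjoint of p[k-1]
        end
        xbar[1] = pbar              # primal step was p[1] = x[1]
        return xbar
    end
    return p[n], pullback
end

y, back = prod_with_pullback([1.0, 2.0, 3.0, 4.0, 5.0])
back(1.0)   # [120.0, 60.0, 40.0, 30.0, 24.0]
```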

Maybe also FiniteDifferences just to be sure?

In [None]:
import FiniteDifferences
FiniteDifferences.grad(FiniteDifferences.central_fdm(2, 1), f, x) # 2-point central stencil: second-order accurate approximation of the first derivative
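For comparison, a plain central-difference gradient is easy to write by hand (`central_fd_gradient` and its fixed step size `h` are illustrative choices, not FiniteDifferences' adaptive algorithm). It costs $2n$ evaluations of $f$. Since Speelpenning's example is linear in each single $x_i$, the central difference is exact here apart from rounding error:

```julia
# Central finite-difference gradient: g[i] ≈ (f(x + h e_i) - f(x - h e_i)) / (2h)
function central_fd_gradient(f, x::Vector{Float64}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

central_fd_gradient(prod, [1.0, 2.0, 3.0, 4.0, 5.0])   # ≈ [120, 60, 40, 30, 24] up to rounding
```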

An aside: FiniteDifferences can also do Richardson extrapolation out of the box. Compare the error of a plain central difference with that of the extrapolated estimate:

In [None]:
FiniteDifferences.central_fdm(2, 1)(sin, 1.0) - cos(1.0) # error of the plain central FD

In [None]:
FiniteDifferences.extrapolate_fdm(FiniteDifferences.central_fdm(2, 1), sin, 1.0)[1] - cos(1.0) # [1] is the estimate, the second entry is an error estimate
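The idea behind Richardson extrapolation can be shown with a single step by hand (a simplified sketch; `central_diff` and `richardson` are our names, and `extrapolate_fdm` performs a more elaborate, adaptive version). The central difference $D(h) = \frac{f(x+h)-f(x-h)}{2h}$ has the error expansion $D(h) = f'(x) + c_2 h^2 + c_4 h^4 + \dots$, so the combination $\frac{4D(h/2) - D(h)}{3}$ cancels the $h^2$ term:

```julia
# Plain central difference; its error expansion contains only even powers of h
central_diff(f, x, h) = (f(x + h) - f(x - h)) / (2h)

# One Richardson step: combining step sizes h and h/2 cancels the h^2 term
richardson(f, x, h) = (4 * central_diff(f, x, h / 2) - central_diff(f, x, h)) / 3

h = 0.1
err_plain = central_diff(sin, 1.0, h) - cos(1.0)  # roughly -cos(1) * h^2 / 6, about -9e-4
err_rich  = richardson(sin, 1.0, h) - cos(1.0)    # roughly -cos(1) * h^4 / 480, about -1.1e-7
```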

For scalar functions, Zygote overloads Julia's postfix adjoint operator `'` so that `f'` is the derivative of `f`:

In [None]:
import Zygote
using Plots
plot(x->sin(x))
plot!(x->sin'(x))
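Conceptually, `f'` just returns a new function that evaluates the derivative of `f`. A toy forward-mode analogue (the type `DualNum` and the function `derivative` are hypothetical names; Zygote's `'` uses reverse mode instead) needs one differentiation rule per primitive, here only for `sin`:

```julia
# Toy dual number carrying a value and its derivative
struct DualNum <: Number
    val::Float64
    der::Float64
end

# Differentiation rule for sin: d/dx sin(u) = cos(u) * u'
Base.sin(d::DualNum) = DualNum(sin(d.val), cos(d.val) * d.der)

# Return a new function that evaluates f' by seeding the tangent with 1
derivative(f) = x -> f(DualNum(x, 1.0)).der

dsin = derivative(sin)
dsin(1.0) == cos(1.0)   # true: the rule for sin is applied exactly
```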

Note that the derivative returned by `'` is compiled to native code like any other Julia function; we can inspect the generated LLVM IR:

In [None]:
s(x)=x^2
@code_llvm debuginfo=:none s'(3.0)

## Limitations of Adjoint AD in Zygote

Mutating arrays is not allowed!
Consider the following contrived example of constructing a matrix.

In [None]:
function mat_sum(p)
    A=zeros(3,3)
    for i in 1:3
        for j in 1:3
            A[i,j] = j*p
        end
    end
    sum(A)
end

Calculating the derivative with Zygote will not work!

In [None]:
mat_sum'(1.0)

Code written in a functional, non-mutating style (array comprehensions, `map`, `reduce`, and similar "pythonic" constructs) usually works well with Zygote. However, this rules out some common Julia idioms: in-place updates such as element-wise assignment with `.=` are not supported either.
If one is not used to programming this way, such code quickly becomes hard to maintain.

In [None]:
function mat_sum_no_mutation(p)
    A=mapreduce(i->ones(3)*i*p, hcat, 1:3) # for i in 1:3 build the column ones(3)*i*p, then concatenate the columns horizontally
    sum(A)
end
mat_sum_no_mutation'(1.0)
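Another non-mutating variant builds the same matrix with a two-dimensional array comprehension instead of `setindex!`. This is a sketch (the name `mat_sum_comprehension` is ours); comprehensions are generally supported by Zygote, but that is an assumption to confirm with `mat_sum_comprehension'(1.0)`. Since column $j$ holds $j\,p$, the sum is $3(1+2+3)\,p = 18p$, so the expected derivative is 18, which a central difference confirms:

```julia
# Same 3x3 matrix as mat_sum, A[i, j] = j * p, built without mutation
function mat_sum_comprehension(p)
    A = [j * p for i in 1:3, j in 1:3]
    return sum(A)
end

# Sanity check of the expected derivative 18 with a central difference
h = 1e-6
(mat_sum_comprehension(1.0 + h) - mat_sum_comprehension(1.0 - h)) / (2h)   # ≈ 18.0
```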

## Specifying custom adjoints
If rewriting the primal code is not feasible, or if mathematical structure can be exploited, we can specify custom rules for the reverse pass.
Let's consider a common case: solving a linear system of equations as one step inside a larger computation:

In [1]:
import Random
import LinearAlgebra as LA
import Zygote
import ForwardDiff
import ChainRulesCore

function solve(A,b)
    C = LA.factorize(A)
    x = C \ b
end

solve (generic function with 1 method)


`factorize` decides which factorization to use based on the type of the matrix `A` and, for a plain `Matrix`, on its structure (LU for a general square matrix, Cholesky for an SPD matrix, etc.).

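Let's add a custom `rrule` for `Symmetric` and `Hermitian` matrices. Since $A^T = A$ (or $A^H = A$ in the complex Hermitian case), the pullback can reuse the factorization computed in the forward pass.
Note that the ∂ prefix in names like `∂x` is only a convention; the variables could be called anything. The returned tuple starts with the tangent of the function itself, followed by one adjoint per argument, in the order of the arguments of the primal routine.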

In [17]:
function ChainRulesCore.rrule(::typeof(solve), A::T, b::AbstractVector) where T <: Union{LA.Hermitian,LA.Symmetric}
    @info "Chainrule for symmetric Matrix solve triggered."
    C = LA.factorize(A) # recompute the primal here so that C is available in the pullback
    x = C \ b
    function solve_pullback(∂x) # incoming adjoints
        ∂b = C \ ∂x    # A is symmetric, so C also solves A^T ∂b = ∂x
        ∂A = -∂b * x'  # outer product -∂b x^T
        return ChainRulesCore.NoTangent(), ∂A, ∂b # no tangent for solve itself, then the adjoints of A and b
    end
    return x, solve_pullback # return primal result x, pullback will be called during reversal
end
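The pullback implements the standard adjoint of a linear solve. Differentiating $A x = b$ gives $dx = A^{-1}(db - dA\,x)$. Pairing with the incoming adjoint $\bar{x}$ (called `∂x` in the code) and using $\bar{b}^T dA\, x = \sum_{ij} \bar{b}_i x_j\, dA_{ij}$ yields

$$
\bar{x}^T dx = \underbrace{(A^{-T}\bar{x})^T}_{\bar{b}^T}\, db - \bar{b}^T dA\, x
\quad\Longrightarrow\quad
\bar{b} = A^{-T}\bar{x}, \qquad \bar{A} = -\bar{b}\, x^T .
$$

For symmetric $A$ we have $A^{-T} = A^{-1}$, which is why the factorization `C` from the forward pass can be reused.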

And another rule for general matrices. Here the factorization from the forward pass is discarded and a new one is built for $A^T$ in the pullback:

In [29]:
function ChainRulesCore.rrule(::typeof(solve), A, b)
    @info "Chainrule for general matrix solve triggered."
    x = solve(A, b) # reuse primal implementation
    function solve_pullback(∂x)
        ∂b = solve(A', ∂x) # reuse primal implementation to solve A^T ∂b = ∂x
        ∂A = -∂b * x'      # outer product -∂b x^T
        return ChainRulesCore.NoTangent(), ∂A, ∂b
    end
    return x, solve_pullback
end

In [38]:
A = [1.0 0.5
     0.5 1.0]
b = [1.0, 1.0]
Zygote.jacobian(solve, A, b) # Jacobians of x = solve(A, b) w.r.t. A (2×4) and b (2×2)

┌ Info: Chainrule for general matrix solve triggered.
└ @ Main In[18]:2


([-0.888888888888889 0.4444444444444445 -0.8888888888888888 0.4444444444444444; 0.4444444444444445 -0.888888888888889 0.4444444444444444 -0.8888888888888888], [1.3333333333333333 -0.6666666666666666; -0.6666666666666666 1.3333333333333333])
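A hand-written pullback is easy to get wrong (for example, by transposing the outer product in $\bar{A}$). A cheap safeguard is the dot-product test: for random directions $dA$, $db$ and a random adjoint $\bar{x}$, the reverse-mode prediction $\langle \bar{A}, dA\rangle + \langle \bar{b}, db\rangle$ must match a finite-difference derivative of $t \mapsto \bar{x}^T x(A + t\,dA,\, b + t\,db)$. A sketch using only the standard libraries (all variable names are illustrative, and the adjoint formulas are written out directly instead of calling the `rrule`):

```julia
using LinearAlgebra, Random

Random.seed!(1)
n    = 4
A    = randn(n, n) + 2n * I   # nonsymmetric; the diagonal shift keeps it well conditioned
b    = randn(n)
xbar = randn(n)               # random incoming adjoint
dA   = randn(n, n)            # random perturbation directions
db   = randn(n)

# Reverse mode: adjoints of the solve x = A \ b
x    = A \ b
bbar = A' \ xbar              # bbar = A^{-T} xbar
Abar = -bbar * x'             # Abar = -bbar x^T (the transpose, -x * bbar', generally fails this test)

# Directional derivative predicted by the adjoints
reverse_dd = dot(Abar, dA) + dot(bbar, db)

# Same directional derivative by a central finite difference in t
phi(t) = dot(xbar, (A + t * dA) \ (b + t * db))
h = 1e-6
fd_dd = (phi(h) - phi(-h)) / (2h)

isapprox(reverse_dd, fd_dd; rtol = 1e-6)
```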

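Now we construct a function $\mathbb{R}^n \rightarrow \mathbb{R}$ that uses `solve()`.
We fix the random seed inside the function so that $A$ and $b$ are the same on every call, which makes the gradient reproducible; the call to `Random.seed!` is marked as non-differentiable.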

In [20]:
ChainRulesCore.@non_differentiable Random.seed!(::Any...)
function foo(p)
    Random.seed!(0)
    n = length(p)
    A = randn(n,n)
    A = LA.Hermitian(A + A' + LA.Diagonal(p))
    b = randn(n)
    x = solve(A,b)
    sum(x)
end
p = randn(10)

10-element Vector{Float64}:
 -1.1736201831150836
 -1.3114173281605968
 -0.8892580619082893
 -1.2434775000832616
  0.5752833946674542
  0.09603324973701619
 -1.5140451496538492
  0.4428517672329901
  0.6947751821060821
  0.6495724780590411

In [21]:
Zygote.gradient(foo,p)


┌ Info: Chainrule for symmetric Matrix solve triggered.
└ @ Main In[17]:2


([-0.16224512743328662, -0.5781190985071835, -0.3286065218357177, -0.5162917589606281, 0.015955143927428382, -0.12420307704303024, -0.10617593288054072, -0.002902709816762701, -0.4718317939250771, 0.23973659938718328],)

In [22]:
ForwardDiff.gradient(foo,p)

10-element Vector{Float64}:
 -0.1622451274332865
 -0.5781190985071833
 -0.32860652183571687
 -0.5162917589606288
  0.01595514392742836
 -0.12420307704303027
 -0.10617593288054047
 -0.0029027098167627007
 -0.471831793925077
  0.23973659938718336
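Both results agree up to floating-point rounding. Because $p$ only enters the diagonal of $A$, the chain rule through the symmetric solve gives $\partial\,\mathrm{sum}(x)/\partial p_i = \bar{A}_{ii} = -\bar{b}_i x_i$ with $\bar{b} = A^{-1}\mathbf{1}$. The sketch below checks this closed form against central differences; `build_system`, `foo_plain`, `foo_grad` and `fd_grad` are hypothetical helpers that rebuild `foo` with plain LinearAlgebra, drawing $A$ and $b$ with the same seed and in the same order as `foo`:

```julia
using LinearAlgebra, Random

# Rebuild A and b like foo: fixed seed, then randn for the matrix and for b
function build_system(p)
    Random.seed!(0)
    n = length(p)
    M = randn(n, n)
    A = Symmetric(M + M' + Diagonal(p))
    b = randn(n)
    return A, b
end

function foo_plain(p)
    A, b = build_system(p)
    return sum(A \ b)
end

# Closed-form gradient: d sum(x) / d p_i = -bbar_i * x_i with bbar = A \ ones(n)
function foo_grad(p)
    A, b = build_system(p)
    x = A \ b
    bbar = A \ ones(length(p))   # A is symmetric, so A^{-T} = A^{-1}
    return -bbar .* x            # diagonal of Abar = -bbar * x'
end

# Central finite differences for comparison
function fd_grad(f, p; h = 1e-6)
    g = similar(p)
    for i in eachindex(p)
        e = zeros(length(p)); e[i] = h
        g[i] = (f(p + e) - f(p - e)) / (2h)
    end
    return g
end

p = collect(range(-1.0, 1.0; length = 10))
isapprox(foo_grad(p), fd_grad(foo_plain, p); rtol = 1e-5)   # compare closed form and finite differences
```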