# Automatic Differentiation (AD)

In short, the promise of AD is

```julia
f(x) = 4x + x^2

df(x) = derivative(f, x)
```

such that

```julia
df(3) # 4 + 2*3 = 10
```

### What AD is not

**Symbolic rewriting:**
$$ f(x) = 4x + x^2 \quad \rightarrow \quad df(x) = 4 + 2x $$

**Numerical differentiation:**
$$ \frac{df}{dx} \approx \frac{f(x+h) - f(x)}{h} $$
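
For contrast, here is a rough sketch of the finite-difference approach. The names `finite_difference` and `p` are illustrative only (not from any package), and the step size `h = 1e-6` is an arbitrary assumption; the result is an approximation whose error depends on `h`.

```julia
# Forward finite difference: approximate, and sensitive to the step size h.
finite_difference(f, x; h = 1e-6) = (f(x + h) - f(x)) / h

p(x) = 4x + x^2            # same polynomial as in the introduction

approx = finite_difference(p, 3.0)
exact  = 4 + 2 * 3.0       # analytic derivative 4 + 2x evaluated at x = 3

println("approx = ", approx, ", error = ", abs(approx - exact))
```

AD, in contrast, does not involve a step size at all.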

## Forward mode AD

Key to AD is the application of the chain rule
$$\dfrac{d}{dx} f(g(x)) = \dfrac{df}{dg} \dfrac{dg}{dx}$$

Consider the function $f(a,b) = \ln(ab + \sin(a))$.

In [5]:
f(a,b) = log(a*b + sin(a))

f (generic function with 1 method)

In [6]:
f_derivative(a,b) = 1/(a*b + sin(a)) * (b + cos(a))

f_derivative (generic function with 1 method)

In [7]:
a = 3.1
b = 2.4
f_derivative(a,b)

0.18724182935843758

Dividing the function into the elementary steps, it corresponds to the following "*computational graph*":

<img src="imgs/comp_graph.svg" width=300px>

In [8]:
function f_graph(a,b)
    c1 = a*b
    c2 = sin(a)
    c3 = c1 + c2
    c4 = log(c3)
end

f_graph (generic function with 1 method)

In [9]:
f(a,b) == f_graph(a,b)

true

To calculate $\frac{\partial f}{\partial a}$ we have to apply the chain rule multiple times.

$\dfrac{\partial f}{\partial a} = \dfrac{\partial f}{\partial c_4} \dfrac{\partial c_4}{\partial a} = \dfrac{\partial f}{\partial c_4} \left( \dfrac{\partial c_4}{\partial c_3} \dfrac{\partial c_3}{\partial a}  \right) = \dfrac{\partial f}{\partial c_4} \left( \dfrac{\partial c_4}{\partial c_3} \left( \dfrac{\partial c_3}{\partial c_2} \dfrac{\partial c_2}{\partial a} + \dfrac{\partial c_3}{\partial c_1} \dfrac{\partial c_1}{\partial a}\right)  \right)$
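
As a worked check, the local derivatives of each elementary step can be read off from `f_graph` (with $f = c_4$) and plugged into the expression above:

$$ \dfrac{\partial c_1}{\partial a} = b, \quad \dfrac{\partial c_2}{\partial a} = \cos(a), \quad \dfrac{\partial c_3}{\partial c_1} = \dfrac{\partial c_3}{\partial c_2} = 1, \quad \dfrac{\partial c_4}{\partial c_3} = \dfrac{1}{c_3}, \quad \dfrac{\partial f}{\partial c_4} = 1 $$

$$ \dfrac{\partial f}{\partial a} = 1 \cdot \dfrac{1}{c_3} \left( 1 \cdot \cos(a) + 1 \cdot b \right) = \dfrac{b + \cos(a)}{ab + \sin(a)} $$

This agrees with `f_derivative` from above.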

In [10]:
function f_graph_derivative(a,b)
    c1 = a*b
    c1_ϵ = b
    
    c2 = sin(a)
    c2_ϵ = cos(a)
    
    c3 = c1 + c2
    c3_ϵ = c1_ϵ + c2_ϵ
    
    c4 = log(c3)
    c4_ϵ = 1/c3 * c3_ϵ
    
    c4, c4_ϵ
end

f_graph_derivative (generic function with 1 method)

In [11]:
f_graph_derivative(a,b)[2] == f_derivative(a,b)

true

**How can we automate this?**

In [12]:
# D for "dual number", invented by Clifford in 1873.
struct D <: Number
    x::Float64 # value
    ϵ::Float64 # derivative
end

In [13]:
import Base: +, *, /, -, sin, log, convert, promote_rule

a::D + b::D = D(a.x + b.x, a.ϵ + b.ϵ) # sum rule
a::D - b::D = D(a.x - b.x, a.ϵ - b.ϵ)
a::D * b::D = D(a.x * b.x, a.x * b.ϵ + a.ϵ * b.x) # product rule
a::D / b::D = D(a.x / b.x, (b.x * a.ϵ - a.x * b.ϵ)/b.x^2) # quotient rule

sin(a::D) = D(sin(a.x), cos(a.x) * a.ϵ)
log(a::D) = D(log(a.x), 1/a.x * a.ϵ)

Base.convert(::Type{D}, x::Real) = D(x, zero(x))
Base.promote_rule(::Type{D}, ::Type{<:Number}) = D
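
One way to read these definitions: a dual number can be written as $x + x'\epsilon$ together with the formal rule $\epsilon^2 = 0$. Under that assumption the product rule, for instance, falls out of ordinary multiplication:

$$ (x + x'\epsilon)(y + y'\epsilon) = xy + (x y' + x' y)\,\epsilon + x' y'\,\epsilon^2 = xy + (x y' + x' y)\,\epsilon $$

More generally, a first-order Taylor expansion gives $g(x + x'\epsilon) = g(x) + g'(x)\,x'\,\epsilon$, which is the pattern used for `sin` and `log`.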

In [14]:
f(D(a,1), b)

D(2.0124440881688996, 0.18724182935843758)

Boom! That was easy!

In [16]:
f_derivative(a,b)

0.18724182935843758

In [17]:
f(D(a,1), b).ϵ ≈ f_derivative(a,b)

true

**How does this work?!**

The trick of forward mode AD is to let Julia implicitly perform the mapping `f -> f_graph_derivative` for you and then let the compiler optimize the resulting code structure (that's what compilers do!).

In [18]:
@code_typed f(D(a,1), b)

CodeInfo(
1 ─ %1  = Base.getfield(a, :x)::Float64
│   %2  = Base.mul_float(%1, b)::Float64
│   %3  = Base.getfield(a, :x)::Float64
│   %4  = Base.mul_float(%3, 0.0)::Float64
│   %5  = Base.getfield(a, :ϵ)::Float64
│   %6  = Base.mul_float(%5, b)::Float64
│   %7  = Base.add_float(%4, %6)::Float64
│   %8  = Base.getfield(a, :x)::Float64
│   %9  = invoke Main.sin(%8::Float64)::Float64
│   %10 = Base.getfield(a, :x)::Float64
│   %11 = invoke Main.cos(%10::Float64)::Float64
│   %12 = Base.getfield(a, :ϵ)::Float64
│   %13 = Base.mul_float(%11, %12)::Float64
│   %14 = Base.add_float(%2, %9)::Float64
│   %15 = Base.add_float(%7, %13)::Float64
│   %16 = invoke Main.log(%14::Float64)::F

This is somewhat hard to parse, but substituting these operations into one another by hand shows that the code is equivalent to

```julia
D.x = log(a.x*b + sin(a.x))
D.ϵ = 1/(a.x*b + sin(a.x)) * (a.x*0 + (a.ϵ*b) + cos(a.x)*a.ϵ)
```

which, if we drop `a.x*0`, set `a.ϵ = 1`, and rename `a.x` $\rightarrow$ `a`, reads

```julia
D.x = log(a*b + sin(a))
D.ϵ = 1/(a*b + sin(a)) * (b + cos(a))
```

This precisely matches our definitions from above:

```julia
f(a,b) = log(a*b + sin(a))

f_derivative(a,b) = 1/(a*b + sin(a)) * (b + cos(a))
```

Importantly, the compiler sees the entire "rewritten" code and can therefore apply optimizations. In this simple example, the code produced by our simple forward mode AD is essentially identical to the explicit implementation.

In [19]:
@code_llvm debuginfo=:none f_graph_derivative(a,b)


define void @julia_f_graph_derivative_1561([2 x double]* noalias nocapture sret, double, double) {
top:
  %3 = fmul double %1, %2
  %4 = call double @j_sin_1562(double %1)
  %5 = call double @j_cos_1563(double %1)
  %6 = fadd double %3, %4
  %7 = fadd double %5, %2
  %8 = call double @j_log_1564(double %6)
  %9 = fdiv double 1.000000e+00, %6
  %10 = fmul double %9, %7
  %.sroa.0.0..sroa_idx = getelementptr inbounds [2 x double], [2 x double]* %0, i64 0, i64 0
  store double %8, double* %.sroa.0.0..sroa_idx, align 8
  %.sroa.2.0..sroa_idx1 = getelementptr inbounds [2 x double], [2 x double]* %0, i64 0, i64 1
  store double %10, double* %.sroa.2.0..sroa_idx1, align 8
  ret void
}


In [20]:
@code_llvm debuginfo=:none f(D(a,1), b)


define void @julia_f_1565([2 x double]* noalias nocapture sret, [2 x double]* nocapture nonnull readonly dereferenceable(16), double) {
top:
  %3 = getelementptr inbounds [2 x double], [2 x double]* %1, i64 0, i64 0
  %4 = load double, double* %3, align 8
  %5 = fmul double %4, %2
  %6 = fmul double %4, 0.000000e+00
  %7 = getelementptr inbounds [2 x double], [2 x double]* %1, i64 0, i64 1
  %8 = load double, double* %7, align 8
  %9 = fmul double %8, %2
  %10 = fadd double %6, %9
  %11 = call double @j_sin_1566(double %4)
  %12 = call double @j_cos_1567(double %4)
  %13 = fmul double %12, %8
  %14 = fadd double %5, %11
  %15 = fadd double %10, %13
  %16 = call double @j_log_1568(double %14)
  %17 = fdiv double 1.000000e+00, %14
  %18 = fmul double %17, %15
  %.sroa.0.0..sroa_idx = getelementptr inbounds [2 x double], [2 x double]* %0, i64 0, i64 0
  store double %16, double* %.sroa.0.0..sroa_idx, align 8
  %.sroa.2.0..sroa_idx1 = getelementptr inbounds [2 x double], [2 x double]* %0,

Our AD is already pretty powerful and general. Let's define the promised function `derivative`:

In [21]:
derivative(f::Function, x::Number) = f(D(x, one(x))).ϵ

derivative (generic function with 1 method)

In [26]:
g(x) = x + x^2

g (generic function with 1 method)

In [28]:
derivative(g, 3.0)

7.0

Anonymous functions often come in handy here:

In [30]:
derivative(x->3*x^2+4x+5, 2)

16.0

In [31]:
derivative(x->sin(x)*log(x), 3)

-1.0405779197678489

We can also define the partial derivative $\frac{\partial f(a,b)}{\partial a}$ from above:

In [33]:
df(x) = derivative(a->f(a,b),x)

df (generic function with 1 method)

Here, `b` is "wrapped into a closure".

In [34]:
df(1.23)

0.7020787235973817
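
The same closure trick works for the other argument. Below is a self-contained sketch of $\frac{\partial f}{\partial b}$; `Dual`, `dual_derivative`, `h`, `a0`, and `b0` are illustrative stand-ins for the `D`, `derivative`, `f`, `a`, and `b` defined above, and only the three operations needed here are implemented.

```julia
# Minimal stand-in for the dual number type from above (hypothetical name).
struct Dual
    x::Float64  # value
    ϵ::Float64  # derivative
end

# Only the rules needed when the *second* argument of h carries the derivative.
Base.:*(u::Real, v::Dual) = Dual(u * v.x, u * v.ϵ)   # constant times dual
Base.:+(u::Dual, v::Real) = Dual(u.x + v, u.ϵ)       # dual plus constant
Base.log(u::Dual) = Dual(log(u.x), u.ϵ / u.x)        # chain rule for log

dual_derivative(g, x) = g(Dual(x, 1)).ϵ

h(a, b) = log(a*b + sin(a))
a0, b0 = 3.1, 2.4

# This time `a0` is captured by the closure and `b` is the active variable.
dh_db(y) = dual_derivative(b -> h(a0, b), y)

dh_db(b0) ≈ a0 / (a0*b0 + sin(a0))   # compare with the analytic ∂f/∂b
```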

## Taking the derivative of *code*

> Repeat $t \leftarrow (t + x/t)/2$ until $t$ converges to $\sqrt{x}$.

In [35]:
@inline function Babylonian(x; N = 10)
    t = (1+x)/2
    for i = 2:N
        t = (t + x/t)/2
    end
    t
end

Babylonian (generic function with 1 method)

In [36]:
Babylonian(2)

1.414213562373095

In [37]:
sqrt(2)

1.4142135623730951

Using our forward mode AD, that is our dual numbers, we can compute the derivative of `Babylonian` **with no rewrite at all**.

In [38]:
Babylonian(D(5, 1))

D(2.23606797749979, 0.22360679774997896)

In [41]:
sqrt(5)

2.23606797749979

In [42]:
1 / (2*sqrt(5))

0.22360679774997896
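
To see what the dual numbers are doing inside the loop, here is a hand-written sketch that carries the derivative `dt` alongside `t`. The function name `babylonian_with_derivative` is made up for illustration; this is essentially the bookkeeping that `D` performs for us automatically.

```julia
function babylonian_with_derivative(x; N = 10)
    t  = (1 + x) / 2
    dt = 1 / 2                              # d/dx of (1 + x)/2
    for i = 2:N
        # Differentiate t_new = (t + x/t)/2 with respect to x.
        # Quotient rule: d(x/t)/dx = (t - x*dt) / t^2, using the old t and dt.
        dt = (dt + (t - x * dt) / t^2) / 2
        t  = (t + x / t) / 2
    end
    return t, dt
end

t, dt = babylonian_with_derivative(5.0)
t ≈ sqrt(5) && dt ≈ 1 / (2 * sqrt(5))
```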

**It just works and is efficient!**

In [43]:
@code_native debuginfo=:none Babylonian(D(5, 1))

	.section	__TEXT,__text,regular,pure_instructions
	movq	%rdi, %rax
	vmovsd	(%rsi), %xmm1           ## xmm1 = mem[0],zero
	vmovsd	8(%rsi), %xmm2          ## xmm2 = mem[0],zero
	movabsq	$5254196832, %rcx       ## imm = 0x1392CAE60
	vaddsd	(%rcx), %xmm1, %xmm4
	vxorpd	%xmm8, %xmm8, %xmm8
	vaddsd	%xmm8, %xmm2, %xmm5
	movabsq	$5254196840, %rcx       ## imm = 0x1392CAE68
	vmovsd	(%rcx), %xmm3           ## xmm3 = mem[0],zero
	vmulsd	%xmm3, %xmm4, %xmm6
	vaddsd	%xmm5, %xmm5, %xmm5
	vmulsd	%xmm8, %xmm4, %xmm4
	vsubsd	%xmm4, %xmm5, %xmm5
	movabsq	$5254196848, %rcx       ## imm = 0x1392CAE70
	vmovsd	(%rcx), %xmm4           ## xmm4 = mem[0],zero
	vmulsd	%xmm4, %xmm5, %xmm5
	vdivsd	%xmm6, %xmm1, %xmm9
	vmulsd	%xmm2, %xmm6, %xmm0
	vmulsd	%xmm1, %xmm5, %xmm7
	vsubsd	%xmm7, %xmm0, %xmm0
	vmulsd	%xmm6, %xmm6, %xmm7
	vdivsd	%xmm7, %xmm0, %xmm0
	vaddsd	%xmm9, %xmm6, %xmm6
	vaddsd	%xmm0, %xmm5, %xmm0
	vmulsd	%xmm3, %xmm6, %xmm5
	vaddsd	%xmm0, %xmm0, %xmm0
	vmulsd	%xmm8, %xmm6, %xmm6
	vsubsd	%xmm6, %xmm0, 

Recursion? Works as well...

In [47]:
function power(x, n)
    if n <= 0
        return 1
    else
        return x*power(x, n-1)
    end
end

power (generic function with 1 method)

In [61]:
4.0^3

64.0

In [59]:
derivative(x -> power(x,3), 4.0)

48.0

In [60]:
3*4.0^2 # 3*x^2

48.0

Differentiating our Vandermonde matrix from yesterday?

In [68]:
function vander_generic(x::AbstractVector{T}) where T
    m = length(x)
    V = Matrix{T}(undef, m, m)
    for j = 1:m
        V[j,1] = one(x[j])
    end
    for i = 2:m
        for j = 1:m
            V[j,i] = x[j] * V[j,i-1]
        end
    end
    return V
end

vander_generic (generic function with 1 method)

\begin{align}V=\begin{bmatrix}1&a&a^{2} &a^3\\1&b&b^{2} &b^3\\1&c&c^{2} &c^3\\1&d&d^{2} &d^3\end{bmatrix}\end{align}

\begin{align}\frac{dV}{da}=\begin{bmatrix}0&1&2a &3a^2\\0&0&0 &0\\0&0&0 &0\\0&0&0 &0\end{bmatrix}\end{align}

In [86]:
a, b, c, d = 2, 3, 4, 5
V = vander_generic([D(a,1), D(b,0), D(c,0), D(d,0)])

4×4 Array{D,2}:
 D(1.0, 0.0)  D(2.0, 1.0)   D(4.0, 4.0)   D(8.0, 12.0)
 D(1.0, 0.0)  D(3.0, 0.0)   D(9.0, 0.0)   D(27.0, 0.0)
 D(1.0, 0.0)  D(4.0, 0.0)  D(16.0, 0.0)   D(64.0, 0.0)
 D(1.0, 0.0)  D(5.0, 0.0)  D(25.0, 0.0)  D(125.0, 0.0)

In [87]:
[V[i,j].ϵ for i in axes(V,1), j in axes(V,2)]

4×4 Array{Float64,2}:
 0.0  1.0  4.0  12.0
 0.0  0.0  0.0   0.0
 0.0  0.0  0.0   0.0
 0.0  0.0  0.0   0.0

## Symbolically (because we can)

The below is mathematically equivalent, **though not exactly what the computation is doing**. Our AD isn't performing symbolic manipulations.

In [88]:
using SymPy

In [89]:
@vars x

(x,)

In [91]:
Babylonian(x; N=1)

x   1
─ + ─
2   2

In [92]:
diff(Babylonian(x; N=1))

1/2

In [94]:
simplify(Babylonian(x; N=5))

 16        15          14           13             12             11          
x   + 496⋅x   + 35960⋅x   + 906192⋅x   + 10518300⋅x   + 64512240⋅x   + 2257928
──────────────────────────────────────────────────────────────────────────────
               ⎛ 15        14         13           12           11            
            32⋅⎝x   + 155⋅x   + 6293⋅x   + 105183⋅x   + 876525⋅x   + 4032015⋅x

    10              9              8              7              6            
40⋅x   + 471435600⋅x  + 601080390⋅x  + 471435600⋅x  + 225792840⋅x  + 64512240⋅
──────────────────────────────────────────────────────────────────────────────
10             9             8             7             6            5       
   + 10855425⋅x  + 17678835⋅x  + 17678835⋅x  + 10855425⋅x  + 4032015⋅x  + 8765

 5             4           3          2            
x  + 10518300⋅x  + 906192⋅x  + 35960⋅x  + 496⋅x + 1
───────────────────────────────────────────────────
    4           3         2            ⎞           


In [93]:
simplify(diff(simplify(Babylonian(x; N=5)), x))

 30        29          28            27              26               25      
x   + 310⋅x   + 59799⋅x   + 4851004⋅x   + 215176549⋅x   + 5809257090⋅x   + 102
──────────────────────────────────────────────────────────────────────────────
                     ⎛ 30        29          28            27             26  
                  32⋅⎝x   + 310⋅x   + 36611⋅x   + 2161196⋅x   + 73961629⋅x   +

           24                  23                   22                   21   
632077611⋅x   + 1246240871640⋅x   + 10776333438765⋅x   + 68124037776390⋅x   + 
──────────────────────────────────────────────────────────────────────────────
             25                24                 23                  22      
 1603620018⋅x   + 23367042639⋅x   + 238538538360⋅x   + 1758637118685⋅x   + 957

                 20                     19                     18             
321156247784955⋅x   + 1146261110726340⋅x   + 3133113888931089⋅x   + 6614351291
──────────────────────────────────────────────────

## Don't reinvent the wheel: ForwardDiff.jl

Now that we have understood how forward mode AD works, we can use the more feature-complete package [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).

In [96]:
using ForwardDiff

In [97]:
ForwardDiff.derivative(Babylonian, 2)

0.35355339059327373

In [98]:
@edit ForwardDiff.derivative(Babylonian, 2)

(Note: [DiffRules.jl](https://github.com/JuliaDiff/DiffRules.jl))
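
As a small illustration of what the package offers beyond a scalar derivative, `ForwardDiff.gradient` differentiates with respect to all entries of a vector input. A sketch, assuming ForwardDiff.jl is installed; `f_vec` is an illustrative name that repackages `f(a,b)` from above as a function of a vector:

```julia
using ForwardDiff

# Same function as f(a,b) above, with v = [a, b].
f_vec(v) = log(v[1] * v[2] + sin(v[1]))

grad = ForwardDiff.gradient(f_vec, [3.1, 2.4])   # [∂f/∂a, ∂f/∂b]

grad[1] ≈ (2.4 + cos(3.1)) / (3.1 * 2.4 + sin(3.1))
```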

## If time permits: Reverse mode AD

Forward mode:
$\dfrac{\partial f}{\partial x} = \dfrac{\partial f}{\partial c_4} \dfrac{\partial c_4}{\partial x} = \dfrac{\partial f}{\partial c_4} \left( \dfrac{\partial c_4}{\partial c_3} \dfrac{\partial c_3}{\partial x}  \right) = \dfrac{\partial f}{\partial c_4} \left( \dfrac{\partial c_4}{\partial c_3} \left( \dfrac{\partial c_3}{\partial c_2} \dfrac{\partial c_2}{\partial x} + \dfrac{\partial c_3}{\partial c_1} \dfrac{\partial c_1}{\partial x}\right)  \right)$

Reverse mode:
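$\dfrac{\partial f}{\partial x} = \dfrac{\partial f}{\partial c_1} \dfrac{\partial c_1}{\partial x} + \dfrac{\partial f}{\partial c_2} \dfrac{\partial c_2}{\partial x} = \left( \dfrac{\partial f}{\partial c_3} \dfrac{\partial c_3}{\partial c_1} \right) \dfrac{\partial c_1}{\partial x} + \left( \dfrac{\partial f}{\partial c_3} \dfrac{\partial c_3}{\partial c_2} \right) \dfrac{\partial c_2}{\partial x} = \left( \left( \dfrac{\partial f}{\partial c_4} \dfrac{\partial c_4}{\partial c_3} \right) \dfrac{\partial c_3}{\partial c_1} \right) \dfrac{\partial c_1}{\partial x} + \left( \left( \dfrac{\partial f}{\partial c_4} \dfrac{\partial c_4}{\partial c_3} \right) \dfrac{\partial c_3}{\partial c_2} \right) \dfrac{\partial c_2}{\partial x}$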

Forward mode AD requires $n$ passes in order to compute an $n$-dimensional gradient.

Reverse mode AD requires only a single run in order to compute a complete gradient but requires two passes through the graph: a forward pass during which necessary intermediate values are computed and a backward pass which computes the gradient.
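
The two passes can be written out by hand for the computational graph from the forward mode section. This is only a sketch: the function name `f_graph_reverse` and the `_bar` variables (holding $\partial f / \partial c_i$) are illustrative, and real reverse mode tools generate the backward pass automatically.

```julia
function f_graph_reverse(a, b)
    # Forward pass: compute and keep the intermediate values.
    c1 = a * b
    c2 = sin(a)
    c3 = c1 + c2
    c4 = log(c3)

    # Backward pass: propagate ∂f/∂c_i from the output back to the inputs.
    c4_bar = 1.0                 # ∂f/∂c4, since f = c4
    c3_bar = c4_bar / c3         # ∂c4/∂c3 = 1/c3
    c2_bar = c3_bar              # ∂c3/∂c2 = 1
    c1_bar = c3_bar              # ∂c3/∂c1 = 1
    a_bar  = c1_bar * b + c2_bar * cos(a)   # a enters through c1 and c2
    b_bar  = c1_bar * a                     # b enters through c1 only

    return c4, (a_bar, b_bar)
end

value, grad = f_graph_reverse(3.1, 2.4)
```

Both partial derivatives come out of a single backward pass, whereas the dual number approach needs one pass per input.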

*Rule of thumb:*

Forward mode is good for $\mathbb{R} \rightarrow \mathbb{R}^n$ while reverse mode is good for $\mathbb{R}^n \rightarrow \mathbb{R}$.

An efficient source-to-source reverse mode AD is implemented in [Zygote.jl](https://github.com/FluxML/Zygote.jl), the AD underlying [Flux.jl](https://fluxml.ai/) (since version 0.10).

In [99]:
using Zygote

In [100]:
f(x) = 5*x + 3

f (generic function with 2 methods)

In [101]:
gradient(f, 5)

(5,)
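
For a function of several arguments, `gradient` returns one entry per argument. A brief sketch, assuming Zygote.jl is installed; the anonymous function repeats `f(a,b)` from the forward mode section:

```julia
using Zygote

# Both partial derivatives are returned as a tuple (∂f/∂a, ∂f/∂b).
da, db = gradient((a, b) -> log(a*b + sin(a)), 3.1, 2.4)

da ≈ (2.4 + cos(3.1)) / (3.1 * 2.4 + sin(3.1))
```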

In [102]:
@code_llvm debuginfo=:none gradient(f,5)


define [1 x i64] @julia_gradient_5239(i64) {
top:
  ret [1 x i64] [i64 5]
}


In [103]:
@code_llvm debuginfo=:none derivative(f,5)


define double @julia_derivative_5309(i64) {
top:
  %1 = sitofp i64 %0 to double
  %2 = fmul double %1, 0.000000e+00
  %3 = fadd double %2, 5.000000e+00
  %4 = fadd double %3, 0.000000e+00
  ret double %4
}


## Some nice reads

Papers:
* https://www.jmlr.org/papers/volume18/17-468/17-468.pdf

Lectures:


* https://mitmath.github.io/18337/lecture8/automatic_differentiation.html

Blog posts:

* ML in Julia: https://julialang.org/blog/2018/12/ml-language-compiler

* Nice example: https://fluxml.ai/2019/03/05/dp-vs-rl.html

* Nice interactive examples: https://fluxml.ai/experiments/

* Why Julia for ML? https://julialang.org/blog/2017/12/ml&pl

* Neural networks with differential equation layers: https://julialang.org/blog/2019/01/fluxdiffeq

* Implement Your Own Automatic Differentiation with Julia in ONE day: http://blog.rogerluo.me/2018/10/23/write-an-ad-in-one-day/

* Implement Your Own Source To Source AD in ONE day!: http://blog.rogerluo.me/2019/07/27/yassad/

Repositories:

* AD flavors, like forward and reverse mode AD: https://github.com/MikeInnes/diff-zoo (Mike is one of the smartest Julia ML heads)

Talks:

* AD is a compiler problem: https://juliacomputing.com/assets/pdf/CGO_C4ML_talk.pdf