# Automatic differentiation

Goal: given a program (a function written in a programming language) that computes $f(x)$, compute the *exact* derivative $f'(x)$, up to roundoff errors.

## Types

In [1]:
3

3

In [2]:
typeof(3)

Int64

In [3]:
3.7

3.7

In [4]:
typeof(3.7)

Float64

In [5]:
typeof(3 + 4im)

Complex{Int64}

In [6]:
typeof([3,4,5,7])

Array{Int64,1}

In [8]:
typeof([3,4, "hello"])

Array{Any,1}

In [9]:
supertype(Int64)

Signed

In [11]:
supertype(Signed)

Integer

In [12]:
supertype(Integer)

Real

In [14]:
supertype(Real)

Number

In [15]:
supertype(Number)

Any

In [16]:
function fib(n::Integer)
    if n < 3
        return 1
    else
        return fib(n-1) + fib(n-2)
    end
end

fib (generic function with 1 method)

In [17]:
[fib(n) for n = 1:10]

10-element Array{Int64,1}:
  1
  1
  2
  3
  5
  8
 13
 21
 34
 55

In [18]:
supertype(Float64)

AbstractFloat

In [19]:
supertype(AbstractFloat)

Real

In [20]:
Int64 <: Integer

true

In [25]:
Float64(3)

3.0
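
The `supertype` calls above trace two paths through Julia's type tree, both ending at `Any`. As a minimal sketch, a small helper can collect a whole path at once. The name `supertype_chain` is hypothetical and not part of Julia; it only wraps the built-in `supertype`.

```julia
# Collect a type together with all of its supertypes, ending at Any.
# (supertype_chain is a made-up helper name for this illustration.)
function supertype_chain(T::Type)
    chain = Type[T]
    while chain[end] !== Any
        push!(chain, supertype(chain[end]))
    end
    return chain
end

supertype_chain(Int64)    # Int64, Signed, Integer, Real, Number, Any
supertype_chain(Float64)  # Float64, AbstractFloat, Real, Number, Any
```

The two chains share the tail `Real, Number, Any`, which is why a method declared for `x::Real` accepts both integers and floating-point values.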

## Multiple dispatch

I can define multiple versions of the same function that act on different types of arguments:

In [27]:
f(x) = 3 + x  # will be called on any type of x  -- f(x::Any)

f (generic function with 1 method)

In [28]:
f(x::Integer) = 14

f (generic function with 2 methods)

In [29]:
f(3) # calls f(x::Integer)

14

In [30]:
f(3.0) # calls f(x::Any)

6.0

In [31]:
f(x::Array) = length(x)

f (generic function with 3 methods)

In [32]:
f([3,4,5])

3

When you call `f(x)`, it runs the *most specific* version of `f` for the type of the argument `x`:

In [33]:
f(x::Int64) = 17

f (generic function with 4 methods)

In [34]:
f(3)

17

In [35]:
f(Int32(3))

14

In [36]:
g(x,y) = 1
g(x,y::Integer) = 3
g(x::Integer,y) = 4

g (generic function with 3 methods)

In [37]:
g(5.7, 3.2)

1

In [38]:
g(3,6.7)

4

In [39]:
g(3,4)

MethodError: MethodError: g(::Int64, ::Int64) is ambiguous. Candidates:
  g(x::Integer, y) in Main at In[36]:3
  g(x, y::Integer) in Main at In[36]:2
Possible fix, define
  g(::Integer, ::Integer)

In [40]:
g(x::Integer, y::Integer) = 14

g (generic function with 4 methods)

In [41]:
g(3,4)

14

In [42]:
methods(g)

## Dual numbers

A dual number is basically a pair of a number and a derivative, and the arithmetic rules correspond to the rules for propagating derivatives.
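
Written out for pairs $(a, a')$ and $(b, b')$, where the second entry is the derivative, these propagation rules are the familiar sum, product, quotient, and chain rules. The pair notation is chosen here for illustration:

$$
\begin{aligned}
(a, a') \pm (b, b') &= (a \pm b,\; a' \pm b') \\
(a, a') \cdot (b, b') &= (ab,\; a'b + ab') \\
(a, a') \,/\, (b, b') &= \left(\frac{a}{b},\; \frac{a'b - ab'}{b^2}\right) \\
\sin\big((a, a')\big) &= (\sin a,\; a' \cos a)
\end{aligned}
$$

Seeding the input as $(x, 1)$, since $dx/dx = 1$, makes the second entry of the result equal to $f'(x)$.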

See also:

* [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl): "serious" Julia package for automatic differentiation based on dual numbers.
* [Automatic differentiation in 18.S096](https://github.com/stevengj/18S096/blob/master/lectures/other/Automatic%20differentiation%20and%20applications.ipynb).

We want to define a new *type* of object in Julia, `Dual`, along with its rules of arithmetic.

In [1]:
struct Dual <: Number
    val::Real # value
    der::Real # derivative
end

In [2]:
d = Dual(3, 4)

Dual(3, 4)

In [3]:
typeof(d)

Dual

Define some arithmetic rules:

In [4]:
+

+ (generic function with 163 methods)

In [5]:
methods(+)

In [6]:
Base.:+(x::Dual, y::Dual) = Dual(x.val + y.val, x.der + y.der)
Base.:-(x::Dual, y::Dual) = Dual(x.val - y.val, x.der - y.der)

In [7]:
Dual(3,4) + Dual(5,6)

Dual(8, 10)

In [8]:
a = [Dual(3,4), Dual(5,6), Dual(9,10)]

3-element Array{Dual,1}:
  Dual(3, 4)
  Dual(5, 6)
 Dual(9, 10)

In [9]:
sum(a)

Dual(17, 20)

In [10]:
Dual(3,4) - Dual(5,6) 

Dual(-2, -2)

In [11]:
Base.:*(x::Dual, y::Dual) = Dual(x.val * y.val, x.der * y.val + x.val * y.der)

In [12]:
fun(x) = x^3 + x

fun (generic function with 1 method)

Let's try to evaluate $\operatorname{fun}'(4)$:

In [13]:
fun(Dual(4,1))

Dual(68, 49)

The correct derivative is $3x^2 + 1$, which gives:

In [14]:
x = 4
3x^2 + 1

49

Let's compute the derivative of $x^5$ at $x=2$. The value should be $2^5 = 32$, and the derivative should be $5 \times 2^4 = 80$:

In [15]:
Dual(2,1)^5

Dual(32, 80)
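
No `^` method was defined for `Dual`, yet `x^3` and `x^5` worked. The likely reason, stated here as an assumption about Julia's generic fallbacks rather than something shown in the cells above, is that an integer power of a `Number` is computed by repeated multiplication, so the product rule defined for `*` is all that is needed. A self-contained sketch that checks this by multiplying by hand:

```julia
struct Dual <: Number
    val::Real # value
    der::Real # derivative
end

# product rule: (ab)' = a'b + ab'
Base.:*(x::Dual, y::Dual) = Dual(x.val * y.val, x.der * y.val + x.val * y.der)

x = Dual(2, 1)               # seed: value 2, derivative dx/dx = 1
by_hand  = x * x * x * x * x # five factors, using only the * rule above
by_power = x^5               # relies on the generic integer-power fallback

(by_hand.val, by_hand.der)   # (32, 80)
(by_power.val, by_power.der) # (32, 80)
```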

In [16]:
Dual(2,1) + 3

ErrorException: promotion of types Dual and Int64 failed to change any arguments

In [17]:
2 + 3.7

5.7

In [18]:
promote_rule

promote_rule (generic function with 133 methods)

In [19]:
methods(promote_rule)

In [21]:
x = Dual(3,4)

Dual(3, 4)

In [24]:
foo(x::Dual) = 3

foo (generic function with 1 method)

In [28]:
foo(::Integer) = 7
foo(::Float64) = 8

foo (generic function with 3 methods)

In [29]:
foo(3)

7

In [30]:
foo(3.5)

8

In [27]:
Dual isa Type{Dual}

true

In [31]:
Base.promote_rule(::Type{Dual}, ::Type{<:Number}) = Dual

In [47]:
promote_rule(Float64, Int)

Float64

In [48]:
promote_rule(Dual, Int)

Dual

In [51]:
@which promote(Dual(3,4), 4.5)

In [34]:
Dual(3,4) + 5

MethodError: MethodError: no method matching Dual(::Int64)
Closest candidates are:
  Dual(::Real, !Matched::Real) at In[1]:2
  Dual(::Any, !Matched::Any) at In[1]:2
  Dual(::T<:Number) where T<:Number at boot.jl:725
  ...

In [35]:
Dual(x::Number) = Dual(x, 0)

Dual

In [36]:
Dual(3,4) + 5

Dual(8, 4)
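
Mixed arithmetic such as `Dual(3,4) + 5` therefore needs two pieces: a promotion rule (which type to convert to) and a one-argument constructor (how to convert an ordinary number). The following self-contained sketch shows the two working together. Restricting both definitions to `Real` rather than `Number` is a choice made for this sketch, not what the cells above do.

```julia
struct Dual <: Number
    val::Real # value
    der::Real # derivative
end

Base.:+(x::Dual, y::Dual) = Dual(x.val + y.val, x.der + y.der)

# How to convert: an ordinary number is a constant, so its derivative is 0.
Dual(x::Real) = Dual(x, 0)

# What to convert to: Dual combined with any real number promotes to Dual.
Base.promote_rule(::Type{Dual}, ::Type{<:Real}) = Dual

promoted = promote(Dual(3, 4), 5)  # both entries are now Dual
s1 = Dual(3, 4) + 5                # value 8, derivative 4
s2 = 5 + Dual(3, 4)                # argument order does not matter
```

Once both definitions exist, every operator defined for a pair of `Dual` values also accepts an ordinary real number on either side.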

In [37]:
f(x) = 3x^2 + 2x + 1

f (generic function with 1 method)

In [38]:
f(3)

34

In [39]:
f(Dual(3,1))

Dual(34, 20)

In [40]:
# manual derivative
f′(x) = 6x + 2

f′ (generic function with 1 method)

In [41]:
f′(3)

20

In [42]:
3 * Dual(2,4)

Dual(6, 12)

In [52]:
Dual(2,4) / 3

ErrorException: / not defined for Dual

In [54]:
Base.:/(x::Dual, y::Dual) = Dual(x.val / y.val, (x.der * y.val - x.val * y.der) / y.val^2)

In [55]:
Dual(3,4) / 2

Dual(1.5, 2.0)

In [56]:
f(x) = 1/x^3

f (generic function with 1 method)

In [59]:
f(2)

0.125

In [60]:
f(Dual(2,1))

Dual(0.125, -0.1875)

In [61]:
-3 / 2^4

-0.1875

In [62]:
sin(Dual(2,1))

MethodError: MethodError: no method matching sin(::Dual)
Closest candidates are:
  sin(!Matched::BigFloat) at mpfr.jl:683
  sin(!Matched::Missing) at math.jl:1056
  sin(!Matched::Complex{Float16}) at math.jl:1005
  ...

In [63]:
mysin(x) = x - x^3/6 + x^5/120

mysin (generic function with 1 method)

In [65]:
mysin(0.1)

0.09983341666666667

In [66]:
sin(0.1)

0.09983341664682815

In [67]:
mysin(Dual(0.1,1))

Dual(0.09983341666666667, 0.9950041666666667)

In [68]:
cos(0.1)

0.9950041652780258

In [69]:
Base.sin(x::Dual) = Dual(sin(x.val), cos(x.val)*x.der)

In [70]:
sin(sin(sin(Dual(1,1))))

Dual(0.6784304773607402, 0.2645082703959581)

In [71]:
sin(sin(sin(1)))

0.6784304773607402
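
The `sin` rule is the chain rule: apply the function to the value, and multiply the derivative part by the function's derivative at that value. The same pattern extends to other elementary functions. The `cos` and `exp` methods below are not defined in the cells above; they are an illustrative sketch of that pattern.

```julia
struct Dual <: Number
    val::Real # value
    der::Real # derivative
end

# Chain rule: g(u)' = g'(u) * u'
Base.sin(x::Dual) = Dual(sin(x.val),  cos(x.val) * x.der)
Base.cos(x::Dual) = Dual(cos(x.val), -sin(x.val) * x.der)
Base.exp(x::Dual) = Dual(exp(x.val),  exp(x.val) * x.der)

# Derivative of exp(sin(x)) at x = 1, compared with the hand-derived formula.
d = exp(sin(Dual(1, 1)))
expected = cos(1) * exp(sin(1))
d.der ≈ expected  # true
```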

# Performance

The `Dual` type defined above works, but it is slow. Let's try to understand why.

In [75]:
# install with:  ] add BenchmarkTools
using BenchmarkTools

In [76]:
@btime sin(1)

  7.988 ns (0 allocations: 0 bytes)


0.8414709848078965

In [78]:
@benchmark sin(1)

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     7.540 ns (0.00% GC)
  median time:      7.789 ns (0.00% GC)
  mean time:        8.010 ns (0.00% GC)
  maximum time:     54.465 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     999

In [82]:
@btime (3.7+2.5im) + (4.5+6.4im)

  1.504 ns (0 allocations: 0 bytes)


8.2 + 8.9im

In [83]:
@btime Dual(3.7,2.5) + Dual(4.5,6.4)

  134.681 ns (7 allocations: 160 bytes)


Dual(8.2, 8.9)

In [84]:
a = rand(10^5)

100000-element Array{Float64,1}:
 0.44400521872407306  
 0.4632049540213794   
 0.18918637498391822  
 0.23664273217516407  
 0.5718319804572034   
 0.3135945358548766   
 0.42709292870253845  
 0.44755220598447587  
 0.17300006457620754  
 0.17253257532796185  
 0.6251881309639795   
 0.8137315380906045   
 0.5065892771699303   
 ⋮                    
 0.6012587887575735   
 0.0007215745390714012
 0.6808662402900936   
 0.8138007077960143   
 0.6106868164429891   
 0.8298862931854845   
 0.834683433827301    
 0.28432798729998576  
 0.43462694058307383  
 0.07734277175037008  
 0.3498099105335042   
 0.09688165144720329  

In [85]:
ac = complex(a)

100000-element Array{Complex{Float64},1}:
   0.44400521872407306 + 0.0im
    0.4632049540213794 + 0.0im
   0.18918637498391822 + 0.0im
   0.23664273217516407 + 0.0im
    0.5718319804572034 + 0.0im
    0.3135945358548766 + 0.0im
   0.42709292870253845 + 0.0im
   0.44755220598447587 + 0.0im
   0.17300006457620754 + 0.0im
   0.17253257532796185 + 0.0im
    0.6251881309639795 + 0.0im
    0.8137315380906045 + 0.0im
    0.5065892771699303 + 0.0im
                       ⋮      
    0.6012587887575735 + 0.0im
 0.0007215745390714012 + 0.0im
    0.6808662402900936 + 0.0im
    0.8138007077960143 + 0.0im
    0.6106868164429891 + 0.0im
    0.8298862931854845 + 0.0im
     0.834683433827301 + 0.0im
   0.28432798729998576 + 0.0im
   0.43462694058307383 + 0.0im
   0.07734277175037008 + 0.0im
    0.3498099105335042 + 0.0im
   0.09688165144720329 + 0.0im

In [86]:
ad = Dual.(a)

100000-element Array{Dual,1}:
   Dual(0.44400521872407306, 0)
    Dual(0.4632049540213794, 0)
   Dual(0.18918637498391822, 0)
   Dual(0.23664273217516407, 0)
    Dual(0.5718319804572034, 0)
    Dual(0.3135945358548766, 0)
   Dual(0.42709292870253845, 0)
   Dual(0.44755220598447587, 0)
   Dual(0.17300006457620754, 0)
   Dual(0.17253257532796185, 0)
    Dual(0.6251881309639795, 0)
    Dual(0.8137315380906045, 0)
    Dual(0.5065892771699303, 0)
                              ⋮
    Dual(0.6012587887575735, 0)
 Dual(0.0007215745390714012, 0)
    Dual(0.6808662402900936, 0)
    Dual(0.8138007077960143, 0)
    Dual(0.6106868164429891, 0)
    Dual(0.8298862931854845, 0)
     Dual(0.834683433827301, 0)
   Dual(0.28432798729998576, 0)
   Dual(0.43462694058307383, 0)
   Dual(0.07734277175037008, 0)
    Dual(0.3498099105335042, 0)
   Dual(0.09688165144720329, 0)

In [88]:
@btime sum($a)

  15.356 μs (0 allocations: 0 bytes)


50046.85362857155

In [89]:
@btime sum($ac)

  112.907 μs (0 allocations: 0 bytes)


50046.85362857155 + 0.0im

In [90]:
@btime sum($ad)

  13.087 ms (299997 allocations: 6.10 MiB)


Dual(50046.85362857155, 0)

In [91]:
typeof(a)

Array{Float64,1}

To make a faster `Dual` number, we need the `val` and `der` fields to have *concrete* types. At first glance, it seems like this would lose generality:

In [92]:
struct FastInflexibleDual <: Number
    val::Float64 # value
    der::Float64 # derivative
end 
Base.:+(x::FastInflexibleDual, y::FastInflexibleDual) = FastInflexibleDual(x.val + y.val, x.der + y.der)
FastInflexibleDual(x::Real) = FastInflexibleDual(x, 0)

FastInflexibleDual

In [93]:
afd = FastInflexibleDual.(a)

100000-element Array{FastInflexibleDual,1}:
   FastInflexibleDual(0.44400521872407306, 0.0)
    FastInflexibleDual(0.4632049540213794, 0.0)
   FastInflexibleDual(0.18918637498391822, 0.0)
   FastInflexibleDual(0.23664273217516407, 0.0)
    FastInflexibleDual(0.5718319804572034, 0.0)
    FastInflexibleDual(0.3135945358548766, 0.0)
   FastInflexibleDual(0.42709292870253845, 0.0)
   FastInflexibleDual(0.44755220598447587, 0.0)
   FastInflexibleDual(0.17300006457620754, 0.0)
   FastInflexibleDual(0.17253257532796185, 0.0)
    FastInflexibleDual(0.6251881309639795, 0.0)
    FastInflexibleDual(0.8137315380906045, 0.0)
    FastInflexibleDual(0.5065892771699303, 0.0)
                                              ⋮
    FastInflexibleDual(0.6012587887575735, 0.0)
 FastInflexibleDual(0.0007215745390714012, 0.0)
    FastInflexibleDual(0.6808662402900936, 0.0)
    FastInflexibleDual(0.8138007077960143, 0.0)
    FastInflexibleDual(0.6106868164429891, 0.0)
    FastInflexibleDual(0.8298862931854845, 0

In [94]:
@btime sum($afd)

  106.673 μs (0 allocations: 0 bytes)


FastInflexibleDual(50046.85362857155, 0.0)

In [95]:
FastInflexibleDual(3,4)

FastInflexibleDual(3.0, 4.0)
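
One way to see the difference between the two structs is to ask Julia whether their field types are concrete. The sketch below uses the built-in queries `isconcretetype` and `isbitstype`. The interpretation in the comments (abstract fields are stored as references to separately allocated values) is the usual explanation for the allocations in the benchmarks, offered here as background rather than something measured above.

```julia
struct Dual <: Number
    val::Real # abstract field type: the actual type is only known at run time
    der::Real
end

struct FastInflexibleDual <: Number
    val::Float64 # concrete field type: stored inline as 8 bytes
    der::Float64
end

isconcretetype(Real)            # false
isconcretetype(Float64)         # true

isbitstype(Dual)                # false: fields refer to boxed values
isbitstype(FastInflexibleDual)  # true: plain data, no pointers
sizeof(FastInflexibleDual)      # 16 (two Float64 values)
```

The price of `FastInflexibleDual` is visible in the last cell above: integer arguments are converted to `Float64`, so the type can no longer hold other kinds of numbers.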

## Solution: Define a *parameterized* type

Define a whole *family* (a *set*) of types:

In [96]:
struct FastDual{T<:Real} <: Number
    val::T # value
    der::T # derivative
end 
Base.:+(x::FastDual, y::FastDual) = FastDual(x.val + y.val, x.der + y.der)

In [97]:
FastDual(3,4)

FastDual{Int64}(3, 4)
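
Each choice of `T` gives a separate concrete type, and `FastDual` by itself names the whole family. A short sketch using only built-in type queries illustrates the relationships:

```julia
struct FastDual{T<:Real} <: Number
    val::T # value
    der::T # derivative
end

# Every member belongs to the family ...
FastDual{Int64}   <: FastDual        # true
FastDual{Float64} <: FastDual        # true

# ... but members are unrelated to one another, even when the parameters are.
FastDual{Float64} <: FastDual{Real}  # false

# A member with a concrete parameter is concrete and stored inline.
isconcretetype(FastDual{Float64})    # true
isbitstype(FastDual{Float64})        # true
isconcretetype(FastDual)             # false: the family itself is not concrete
```

A method written for `x::FastDual` accepts every member of the family, while an array of type `Array{FastDual{Float64},1}` can still be stored compactly.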

In [98]:
FastDual(3,4.7)

MethodError: MethodError: no method matching FastDual(::Int64, ::Float64)
Closest candidates are:
  FastDual(::T<:Real, !Matched::T<:Real) where T<:Real at In[96]:2
  FastDual(::T<:Number) where T<:Number at boot.jl:725

In [99]:
FastDual(x::Real) = FastDual(x, 0)

FastDual

In [100]:
FastDual(3)

FastDual{Int64}(3, 0)

In [101]:
FastDual(3.7)

MethodError: MethodError: no method matching FastDual(::Float64, ::Int64)
Closest candidates are:
  FastDual(::Real) at In[99]:1
  FastDual(::T<:Real, !Matched::T<:Real) where T<:Real at In[96]:2
  FastDual(::T<:Number) where T<:Number at boot.jl:725

In [102]:
promote(3, 4.7)

(3.0, 4.7)

In [103]:
FastDual(x::Real, y::Real) = FastDual(promote(x,y)...)

FastDual

In [104]:
FastDual(3,4.7)

FastDual{Float64}(3.0, 4.7)

In [105]:
FastDual(3.8)

FastDual{Float64}(3.8, 0.0)

In [106]:
afd = FastDual.(a)

100000-element Array{FastDual{Float64},1}:
   FastDual{Float64}(0.44400521872407306, 0.0)
    FastDual{Float64}(0.4632049540213794, 0.0)
   FastDual{Float64}(0.18918637498391822, 0.0)
   FastDual{Float64}(0.23664273217516407, 0.0)
    FastDual{Float64}(0.5718319804572034, 0.0)
    FastDual{Float64}(0.3135945358548766, 0.0)
   FastDual{Float64}(0.42709292870253845, 0.0)
   FastDual{Float64}(0.44755220598447587, 0.0)
   FastDual{Float64}(0.17300006457620754, 0.0)
   FastDual{Float64}(0.17253257532796185, 0.0)
    FastDual{Float64}(0.6251881309639795, 0.0)
    FastDual{Float64}(0.8137315380906045, 0.0)
    FastDual{Float64}(0.5065892771699303, 0.0)
                                             ⋮
    FastDual{Float64}(0.6012587887575735, 0.0)
 FastDual{Float64}(0.0007215745390714012, 0.0)
    FastDual{Float64}(0.6808662402900936, 0.0)
    FastDual{Float64}(0.8138007077960143, 0.0)
    FastDual{Float64}(0.6106868164429891, 0.0)
    FastDual{Float64}(0.8298862931854845, 0.0)
     FastDual{Flo

In [107]:
@btime sum($afd)

  106.647 μs (0 allocations: 0 bytes)


FastDual{Float64}(50046.85362857155, 0.0)
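
Only `+` is defined for `FastDual` above. As a sketch of how the rest of the arithmetic might carry over, the product rule can be written once for every `FastDual{T}`. The `*` method and the test function `p` below are assumptions modeled on the earlier `Dual` code, not definitions from the cells above.

```julia
struct FastDual{T<:Real} <: Number
    val::T # value
    der::T # derivative
end

# Promote mixed argument types to a common T, as in the cells above.
FastDual(x::Real, y::Real) = FastDual(promote(x, y)...)

Base.:+(x::FastDual, y::FastDual) = FastDual(x.val + y.val, x.der + y.der)
# Product rule, written the same way as for Dual.
Base.:*(x::FastDual, y::FastDual) = FastDual(x.val * y.val, x.der * y.val + x.val * y.der)

# Hypothetical test function: p(x) = x^3 + x, written with explicit products.
p(x) = x * x * x + x

r = p(FastDual(4.0, 1.0))
(r.val, r.der)  # (68.0, 49.0), matching 4^3 + 4 and 3*4^2 + 1
typeof(r)       # FastDual{Float64}
```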

In [108]:
aslow = [3, "string", [3,4,5], Dual(3,5)]

4-element Array{Any,1}:
          3         
           "string" 
           [3, 4, 5]
 Dual(3, 5)         