<a href="https://colab.research.google.com/github/cu-applied-math/SciML-Class/blob/main/Demos/AutomaticDifferentiation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Automatic Differentiation

Code resources
#### Python
Some of the big ones are:
- TensorFlow
- PyTorch
- [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)

All of these Python packages do more than just AD.

#### Matlab
- [ADiMat](https://www.sc.informatik.tu-darmstadt.de/res/sw/adimat/general/index.en.jsp) does reverse-mode AD

There are not many high-quality reverse-mode AD packages for Matlab.

#### Julia
There are a lot of good choices, and the selection will only get better as more people use Julia.
- See [JuliaDiff](https://juliadiff.org/) for a curated list of packages. The 2020 answer by ChrisRackauckas in this [forum post](https://discourse.julialang.org/t/state-of-automatic-differentiation-in-julia/43083) has a good discussion of pros and cons (e.g., which packages can do reverse mode, which can do Hessian-vector products, etc.).

#### Fortran and classic scientific computing languages
- see [autodiff.org](http://www.autodiff.org/?module=Tools), e.g., [ADIFOR 2.0](https://www.mcs.anl.gov/research/projects/adifor/) for Fortran77


## Demo
This demo is forked from the version at [github.com/cu-numcomp/numcomp-class/blob/master/Differentiation.ipynb](https://github.com/cu-numcomp/numcomp-class/blob/master/Differentiation.ipynb), which was written by Prof. Jed Brown for CU's CSCI 3656 "Numerical Computation" in Spring 2020 and released under the simplified BSD license. (The original version also has demos on finite differences, which are not included here.)

Copyright (c) 2017-2020, Jed Brown
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
  list of conditions and the following disclaimer in the documentation and/or
  other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')

# Symbolic differentiation

We've been differentiating basic mathematical functions, for which there is a formula for the derivative.
Symbolic differentiation is a tool that can compute those expressions (and generate code to evaluate the expressions numerically).

In [None]:
import sympy
from sympy.abc import x

f = sympy.cos(x**sympy.pi) * sympy.log(x)
f

log(x)*cos(x**pi)

In [None]:
sympy.diff(f, x)

-pi*x**pi*log(x)*sin(x**pi)/x + cos(x**pi)/x

#### You can even ask `sympy` to give you a formula you can evaluate in a programming language

In [None]:
sympy.ccode(f, 'y')

'y = log(x)*cos(pow(x, M_PI));'

In [None]:
sympy.fcode(f, 'y')

'      y = log(x)*cos(x**3.1415926535897932d0)'

#### And `sympy` can evaluate the formula itself
In Mathematica, this is like adding `//N` to the end of an expression.

In [None]:
f.evalf(40, subs={x: 1.9})

0.2155134138380419067452319459177557208730

#### A more complicated function
and its derivative

In [None]:
def g(x, m=np):
    y = x
    for i in range(2):
        # a = m.log(y)
        # b = y ** m.pi
        # c = m.cos(b)
        # y = c * a
        y = m.cos(y**m.pi) * m.log(y)
    return y

gexpr = g(x, m=sympy)
gexpr

log(log(x)*cos(x**pi))*cos((log(x)*cos(x**pi))**pi)

In [None]:
sympy.diff(gexpr, x)

-pi*(log(x)*cos(x**pi))**pi*(-pi*x**pi*log(x)*sin(x**pi)/x + cos(x**pi)/x)*log(log(x)*cos(x**pi))*sin((log(x)*cos(x**pi))**pi)/(log(x)*cos(x**pi)) + (-pi*x**pi*log(x)*sin(x**pi)/x + cos(x**pi)/x)*cos((log(x)*cos(x**pi))**pi)/(log(x)*cos(x**pi))
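Much of that expression consists of repeated subexpressions such as `log(x)*cos(x**pi)`. As a sketch that goes slightly beyond the original demo, SymPy's common subexpression elimination, `sympy.cse`, can pull the repeats out into named temporaries, which already starts to resemble the step-by-step derivative code written by hand below. The block redefines `g` (with the math namespace passed explicitly) so it runs on its own.

```python
import sympy
from sympy.abc import x

def g(x, m):
    # Same function as above; m is the math namespace (here: sympy)
    y = x
    for i in range(2):
        y = m.cos(y**m.pi) * m.log(y)
    return y

dg = sympy.diff(g(x, sympy), x)

# cse returns (list of (temporary symbol, expression), list of reduced expressions)
replacements, reduced = sympy.cse(dg)
for sym, sub in replacements:
    print(sym, '=', sub)
print('result =', reduced[0])

# Substituting the temporaries back (last one first) recovers the full derivative
expanded = reduced[0]
for sym, sub in reversed(replacements):
    expanded = expanded.subs(sym, sub)
```

CSE only removes duplication after the fact; the symbolic derivative is still built first, which is the cost that the hand-coded and algorithmic approaches avoid.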

#### Another complicated function
An example of a Speelpenning function (from [Peder Olsen's slides](https://researcher.watson.ibm.com/researcher/files/us-pederao/ADTalk.pdf)).

In [None]:
m = 8
#t = np.random.randn(m)
t = np.arange(1,m+1)

In [None]:
def f(x):
  y = 1
  for i in range(m):
    y *= x - t[i]
  return y

f(x)

(x - 8)*(x - 7)*(x - 6)*(x - 5)*(x - 4)*(x - 3)*(x - 2)*(x - 1)

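But $f$ is just a degree-8 polynomial!  Is `sympy` clever enough to do the right thing?  No: it applies the product rule factor by factor, producing a sum of 8 products with 7 factors each.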

In [None]:
sympy.diff(f(x), x)

(x - 8)*(x - 7)*(x - 6)*(x - 5)*(x - 4)*(x - 3)*(x - 2) + (x - 8)*(x - 7)*(x - 6)*(x - 5)*(x - 4)*(x - 3)*(x - 1) + (x - 8)*(x - 7)*(x - 6)*(x - 5)*(x - 4)*(x - 2)*(x - 1) + (x - 8)*(x - 7)*(x - 6)*(x - 5)*(x - 3)*(x - 2)*(x - 1) + (x - 8)*(x - 7)*(x - 6)*(x - 4)*(x - 3)*(x - 2)*(x - 1) + (x - 8)*(x - 7)*(x - 5)*(x - 4)*(x - 3)*(x - 2)*(x - 1) + (x - 8)*(x - 6)*(x - 5)*(x - 4)*(x - 3)*(x - 2)*(x - 1) + (x - 7)*(x - 6)*(x - 5)*(x - 4)*(x - 3)*(x - 2)*(x - 1)

`sympy` does the right thing if we give it some help...

In [None]:
sympy.expand( f(x) )

x**8 - 36*x**7 + 546*x**6 - 4536*x**5 + 22449*x**4 - 67284*x**3 + 118124*x**2 - 109584*x + 40320

In [None]:
sympy.diff(sympy.expand( f(x) ), x)

8*x**7 - 252*x**6 + 3276*x**5 - 22680*x**4 + 89796*x**3 - 201852*x**2 + 236248*x - 109584

# Hand-coding derivatives

The size of these expressions grows exponentially in the number of loop iterations, yet one can write efficient code for computing the derivative by hand.  We use the variational notation

$$ \operatorname{d} f = f'(x) \operatorname{d} x $$

which allows us to break a large computation into simple pieces that we can compute incrementally, instead of trying to build up expressions for complicated functions.  That is, we can differentiate a composition $h(g(f(x)))$ as

\begin{align}
  \operatorname{d} h &= h' \operatorname{d} g \\
  \operatorname{d} g &= g' \operatorname{d} f \\
  \operatorname{d} f &= f' \operatorname{d} x.
\end{align}
Consider our example `g` from above.

In [None]:
def gprime(x):
    y = x
    dy = 1
    for i in range(2):
        a = np.log(y)
        da = 1/y * dy
        b = y ** np.pi
        db = np.pi * y ** (np.pi - 1) * dy
        c = np.cos(b)
        dc = -np.sin(b) * db
        y = c * a
        dy = dc * a + c * da
    return y, dy

print('by hand', gprime(1.9))

by hand (-1.5346823414986814, -34.03241959914048)


* This code is pretty mechanical to write
* It's hard to maintain as you add new features
* It's hard to debug
  * You can test using finite differencing
  * You can take apart pieces for unit testing and/or debugging
* If you know you'll be writing this sort of code, plan ahead!
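Testing with finite differencing, as suggested in the list above, can be as simple as comparing against a central difference. The sketch below copies `g` and `gprime` from above in compact form so that it runs on its own; the helper `central_difference` and the step size `h=1e-6` are choices made for this illustration.

```python
import numpy as np

def g(x):
    y = x
    for i in range(2):
        y = np.cos(y**np.pi) * np.log(y)
    return y

def gprime(x):
    # Compact copy of the hand-coded forward-mode derivative above
    y, dy = x, 1.0
    for i in range(2):
        a, da = np.log(y), dy / y
        b, db = y**np.pi, np.pi * y**(np.pi - 1) * dy
        c, dc = np.cos(b), -np.sin(b) * db
        y, dy = c * a, dc * a + c * da
    return y, dy

def central_difference(fun, x, h=1e-6):
    # Truncation error is O(h^2); rounding error grows like O(machine eps / h)
    return (fun(x + h) - fun(x - h)) / (2 * h)

x0 = 1.9
_, dy_hand = gprime(x0)
dy_fd = central_difference(g, x0)
rel_err = abs(dy_fd - dy_hand) / abs(dy_hand)
print(dy_hand, dy_fd, rel_err)
```

Agreement to roughly 6 or more digits is what one can expect from a central difference; a bug in the hand-coded derivative usually shows up as a discrepancy in the leading digits.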

### Variational notation is handy (an example)

We'll differentiate the expression

$$ I = A^{-1} A $$

by applying the product rule,

$$ 0 = A^{-1} (\operatorname dA) + (\operatorname dA^{-1}) A, $$

and collecting terms:

$$ \operatorname dA^{-1} = - A^{-1} (\operatorname dA) A^{-1}. $$

This expression for the derivative $\operatorname d A^{-1}$ in direction $\operatorname d A$ is useful when differentiating algorithms that involve linear algebra.
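A quick numerical check of $\operatorname dA^{-1} = - A^{-1} (\operatorname dA) A^{-1}$, sketched with NumPy: the random matrix `A` (shifted by a multiple of the identity so it is safely invertible), the direction `dA`, and the step `eps` are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
A = rng.standard_normal((n, n)) + 10 * np.eye(n)  # diagonal shift keeps A well-conditioned
dA = rng.standard_normal((n, n))                  # direction of the perturbation

Ainv = np.linalg.inv(A)
dAinv_formula = -Ainv @ dA @ Ainv

# Central difference of the matrix inverse in direction dA
eps = 1e-6
dAinv_fd = (np.linalg.inv(A + eps * dA) - np.linalg.inv(A - eps * dA)) / (2 * eps)

rel_err = np.linalg.norm(dAinv_fd - dAinv_formula) / np.linalg.norm(dAinv_formula)
print(rel_err)
```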

## Reverse-mode

What we've done above is called "forward mode", and amounts to placing the parentheses in the chain rule like

$$ \operatorname d h = \frac{dh}{dg} \left(\frac{dg}{df} \left(\frac{df}{dx} \operatorname d x \right) \right) .$$

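The expression means the same thing if we rearrange the parentheses as below. Evaluating the products in this order, starting from the output $h$ and working back toward the input, is called "reverse mode":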

$$ \operatorname d h = \left( \left( \left( \frac{dh}{dg} \right) \frac{dg}{df} \right) \frac{df}{dx} \right) \operatorname d x .$$
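When the intermediate quantities are vectors, each factor becomes a Jacobian matrix, and the two parenthesizations have very different costs. The sketch below uses hypothetical random stand-ins `J1`, `J2`, `J3` for $\frac{df}{dx}$, $\frac{dg}{df}$, $\frac{dh}{dg}$ (with a scalar output $h$); it is only meant to show that both orders give the same gradient.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5
J1 = rng.standard_normal((n, n))  # stand-in for df/dx
J2 = rng.standard_normal((n, n))  # stand-in for dg/df
J3 = rng.standard_normal((1, n))  # stand-in for dh/dg (scalar output h)

# Forward order: push each input direction dx through J1, then J2, then J3.
# Every sweep costs matrix-vector products, and we need n sweeps.
grad_fwd = np.zeros((1, n))
for j in range(n):
    dx = np.eye(n)[:, j]
    grad_fwd[0, j] = (J3 @ (J2 @ (J1 @ dx)))[0]

# Reverse order: one sweep of row-vector times matrix products from the output.
grad_rev = (J3 @ J2) @ J1

print(np.allclose(grad_fwd, grad_rev))
```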

In [None]:
def gprime_rev(x):
    # First compute all the values by going through the iteration forwards
    # I'm unrolling two iterations here for clarity ("static single assignment" form)
    # It is possible to write code that keeps the loop structure.
    a1 = np.log(x)
    b1 = x ** np.pi
    c1 = np.cos(b1)
    y1 = c1 * a1
    a2 = np.log(y1)
    b2 = y1 ** np.pi
    c2 = np.cos(b2)
    y = c2 * a2 # Result
    # Now go backwards computing dy/d_ for each variable
    y_ = 1
    y_c2 = y_ * a2
    y_a2 = c2 * y_
    y_b2 = -y_c2 * np.sin(b2) # dy/db2 = dy/dc2 dc2/db2
    y_y1 = y_b2 * np.pi * y1 ** (np.pi - 1) + y_a2 / y1
    y_c1 = y_y1 * a1
    y_a1 = c1 * y_y1
    y_b1 = -y_c1 * np.sin(b1)
    y_x = y_b1 * np.pi * x ** (np.pi - 1) + y_a1 / x
    return y, y_x

print('forward', gprime(1.9))
print('reverse', gprime_rev(1.9))

forward (-1.5346823414986814, -34.03241959914048)
reverse (-1.5346823414986814, -34.03241959914049)


* This is fairly mechanical, similar to forward-mode
* It is more complicated than forward-mode
* This sort of code is tricky to debug
  * You can test using forward-mode or finite differencing
* We need the results of intermediate computations in reverse order
  * We have to store those values somewhere ("taping" in the literature)
  * Or we have to recompute them (see "hierarchical checkpointing")
* Reverse-mode is also known as the "adjoint" method and "back-propagation".
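To make "taping" concrete, here is a minimal sketch built on a hypothetical `Tape`/`Var` design written only for this illustration (it is not the API of any AD library). Each operation records its inputs and local partial derivatives while running forward; a reverse sweep over the recorded nodes then accumulates adjoints, reproducing `gprime_rev` for the function `g` without unrolling anything by hand.

```python
import math

class Tape:
    """Hypothetical minimal tape: node i stores [(parent_index, d node_i / d parent)]."""
    def __init__(self):
        self.parents = []

    def record(self, value, parents):
        self.parents.append(parents)
        return Var(value, self, len(self.parents) - 1)

    def gradient(self, out):
        # Reverse sweep: adjoint[i] = d out / d node_i.  Parents always have
        # smaller indices, so each adjoint is complete before it is propagated.
        adjoint = [0.0] * len(self.parents)
        adjoint[out.index] = 1.0
        for i in reversed(range(len(self.parents))):
            for p, partial in self.parents[i]:
                adjoint[p] += adjoint[i] * partial
        return adjoint

class Var:
    def __init__(self, value, tape, index):
        self.value, self.tape, self.index = value, tape, index

def mul(a, b):
    return a.tape.record(a.value * b.value, [(a.index, b.value), (b.index, a.value)])

def log(a):
    return a.tape.record(math.log(a.value), [(a.index, 1.0 / a.value)])

def cos(a):
    return a.tape.record(math.cos(a.value), [(a.index, -math.sin(a.value))])

def pow_const(a, p):
    # Power with a constant (non-taped) exponent p
    return a.tape.record(a.value ** p, [(a.index, p * a.value ** (p - 1))])

def g_taped(x):
    y = x
    for _ in range(2):
        y = mul(cos(pow_const(y, math.pi)), log(y))
    return y

tape = Tape()
x = tape.record(1.9, [])  # input node: no parents
y = g_taped(x)
dy_dx = tape.gradient(y)[x.index]
print(y.value, dy_dx)
```

The tape stores every intermediate value's local partials, which is exactly the memory cost mentioned above; checkpointing trades some of that storage for recomputation.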
  
### Why reverse?

If all we had were scalar functions of scalar inputs, we would never need reverse mode.  But suppose we are given a dot product with a constant vector,

$$ f(\mathbf x) = \mathbf c^T \mathbf x = \begin{pmatrix} c_0 & c_1 & c_2 & \dotsb \end{pmatrix} \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ \vdots \end{pmatrix} $$
and wish to compute the gradient
$$ \nabla_{\mathbf x} f = \frac{\partial f}{\partial \mathbf x} = \begin{pmatrix} \frac{\partial f}{\partial x_0} & \frac{\partial f}{\partial x_1} & \frac{\partial f}{\partial x_2} & \dotsb \end{pmatrix} . $$

In [None]:
def dot(c, x):
    n = len(c)
    sum = 0
    for i in range(n):
        sum += c[i] * x[i]
    return sum

n = 20
c = np.random.randn(n)
x = np.random.randn(n)
f = dot(c, x)
f

-0.5295781449140192

If we use forward mode, we can only compute one direction at a time, effectively computing
$$ \left(\nabla_{\mathbf x} f\right) \cdot \operatorname d \mathbf x $$
for one vector $\operatorname d \mathbf x$ at a time.
We can compute the full gradient by choosing $\operatorname d \mathbf x$ to be each column of the identity in turn.

In [None]:
def dot_x(c, x, dx):
    """Compute derivative in direction dx"""
    n = len(c)
    dsum = 0
    for i in range(n):
        dsum += c[i] * dx[i]
    return dsum

def grad_dot(c, x):
    n = len(c)
    I = np.eye(n)
    grad = np.zeros(n)
    for j in range(n):
        dx = I[:,j]
        grad[j] = dot_x(c, x, dx)
    return grad

grad_dot(c, x)

array([ 1.06724138, -0.37276519, -0.41470903,  0.37528405,  0.06614393,
       -1.50760904,  1.80053804, -0.24942083,  1.67704712,  1.13885786,
       -0.67176614, -1.28040531, -1.50094259, -0.2799402 , -0.33825806,
       -0.04775478,  0.37819716, -0.28330565,  0.37293594,  0.73033729])

We have now run the loop in `dot_x` once per component of the vector.  Evaluating `dot` costs $O(n)$, but computing the gradient this way costs $O(n^2)$: each direction costs $O(n)$ and there are $n$ directions.

Compare with reverse mode:

In [None]:
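def grad_dot_rev(c, x):
    n = len(c)
    sum_ = 1.0          # seed: d(sum)/d(sum) = 1
    x_ = np.zeros(n)    # x_[i] will hold d(sum)/d(x[i])
    for i in reversed(range(n)):
        # reverse of `sum += c[i] * x[i]`
        x_[i] = sum_ * c[i]
    return x_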

grad_dot_rev(c, x)

array([ 1.06724138, -0.37276519, -0.41470903,  0.37528405,  0.06614393,
       -1.50760904,  1.80053804, -0.24942083,  1.67704712,  1.13885786,
       -0.67176614, -1.28040531, -1.50094259, -0.2799402 , -0.33825806,
       -0.04775478,  0.37819716, -0.28330565,  0.37293594,  0.73033729])

* We get the same values in only $O(n)$ work!
* The astute reader may recall that we already worked out this case,
$$ \frac{\partial \mathbf c^T \mathbf x}{\partial \mathbf x} = \mathbf c^T .$$

## Shape of the gradient (Jacobian)

Suppose we have a vector-valued function of vector-valued input, $\mathbf f(\mathbf x)$ where $\mathbf f$ has length $m$ and $\mathbf x$ has length $n$.
* The gradient (Jacobian) matrix $J = \nabla_{\mathbf x} \mathbf f$ has shape $m\times n$.
* Usually in optimization, $m=1$ because we only have one objective
* If $m\ll n$ then finite differencing and forward-mode differentiation will be much more expensive than reverse-mode differentiation
  * Find a way to use reverse-mode!
* If $m \approx n$ then either is about as efficient, but forward-mode is simpler.
* If $m \gg n$ then forward-mode is the ticket.
* In real computations, there may be expensive stages that have lower-dimensional inputs or outputs, and that structure can be exploited. An example is
  $$ \mathbf f(\mathbf x) = \mathbf q \sigma(\mathbf q^T \mathbf x) $$
  where $\sigma$ is an expensive nonlinear function.
  The Jacobian $J = \nabla_{\mathbf x} \mathbf f$ is a square matrix, but naive forward- and reverse-mode would both require $n$ evaluations of $\sigma$.
  Since $\sigma$ is a scalar-valued function of a scalar argument, $\sigma'(\mathbf q^T \mathbf x)$ is just one number, and thus $J = \sigma'(\mathbf q^T \mathbf x)\, \mathbf q \mathbf q^T$ is readily available (and since it is rank-1, you don't need to store all $n^2$ entries). Models of this sort show up frequently in physical modeling.
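A small NumPy check of the rank-1 formula $J = \sigma'(\mathbf q^T \mathbf x)\, \mathbf q \mathbf q^T$, assuming $\sigma = \tanh$ as a cheap stand-in for the expensive nonlinearity (any smooth scalar function would do). The rank-1 expression needs a single evaluation of $\sigma'$, while the column-by-column central-difference Jacobian needs $2n$ evaluations of $\mathbf f$.

```python
import numpy as np

def sigma(s):
    # Stand-in for an "expensive" scalar nonlinearity (assumption for illustration)
    return np.tanh(s)

def dsigma(s):
    return 1.0 - np.tanh(s) ** 2

def f(x, q):
    return q * sigma(q @ x)

rng = np.random.default_rng(2)
n = 6
q = rng.standard_normal(n)
x = rng.standard_normal(n)

# Rank-1 Jacobian: one evaluation of sigma' gives all n^2 entries
J_rank1 = dsigma(q @ x) * np.outer(q, q)

# Column-by-column central differences: 2n evaluations of f (and of sigma)
h = 1e-6
J_fd = np.empty((n, n))
for j in range(n):
    e = np.zeros(n)
    e[j] = h
    J_fd[:, j] = (f(x + e, q) - f(x - e, q)) / (2 * h)

print(np.allclose(J_rank1, J_fd, atol=1e-7))
```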

# Algorithmic (automatic) differentiation

Next, we'll consider ways to have libraries/compilers generate code like the hand-written derivatives above.
We'll use the [JAX](https://jax.readthedocs.io/en/latest/) library, which can differentiate NumPy-style computations written with `jax.numpy` (and offload them to GPUs, which we won't use now).
Uncomment the line below if you need to install `jax` and `jaxlib`.

In [None]:
# ! pip install jax jaxlib

In [None]:
import jax
import jax.numpy as jnp

# def g_jax(x):
#     """Same function as before, but using jnp in place of np."""
#     y = x
#     for i in range(2):
#         y = jnp.cos(y**jnp.pi) * jnp.log(y)
#     return y

g_jax = lambda x : g(x, m=jnp)

gprime_jax = jax.grad(g_jax)
print(gprime_jax(1.9))
print(gprime(1.9)[1])

-34.03244
-34.03241959914048
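The two results above agree only to about 6 significant digits because JAX uses 32-bit floats by default. A sketch of switching to 64-bit floats via the `jax_enable_x64` configuration flag; it assumes a fresh session, since the flag should be set before any JAX arrays are created.

```python
import jax
jax.config.update("jax_enable_x64", True)  # set before creating any JAX arrays
import jax.numpy as jnp

def g_jax(x):
    # Same function g as above, written directly with jax.numpy
    y = x
    for _ in range(2):
        y = jnp.cos(y ** jnp.pi) * jnp.log(y)
    return y

dg = jax.grad(g_jax)(1.9)
print(dg.dtype, dg)
```

With 64-bit floats enabled, the JAX gradient should agree with the hand-coded `gprime(1.9)[1]` to nearly all printed digits.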


#### Example with linear function

In [None]:
n = 20
c = np.random.randn(n)
y = np.random.randn(n)

# SRB changing example a bit
# Let h(y) = dot(y,c), then grad h = c
h = lambda y : jnp.vdot(y,c)

print("Gradient via JAX AD is")
print( jax.grad(h)(y) )
# Alternatively,
# jax.grad(jnp.vdot)(y, c) passes both y and c to the function,
#     but only differentiates with respect to the first input (in this case, y)

print("Gradient worked out by hand is")
print( c )

Gradient via JAX AD is
[ 0.30825612  0.77228576  0.19304241  0.19611755 -0.30141184 -0.51392907
 -0.5446703  -1.9717118   1.0074332  -1.6879      1.598894    0.343732
 -0.12325148  0.05796708  1.5408208   0.7222661   0.32491362 -0.86020875
  0.8780183  -0.9273103 ]
Gradient worked out by hand is
[ 0.30825612  0.77228577  0.19304241  0.19611755 -0.30141184 -0.51392906
 -0.5446703  -1.97171174  1.00743313 -1.68789991  1.59889398  0.34373199
 -0.12325148  0.05796708  1.54082087  0.7222661   0.32491361 -0.86020875
  0.87801833 -0.92731027]
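As the comment in the cell above notes, `jax.grad` differentiates with respect to the first argument by default. A short sketch of its `argnums` parameter, which selects the argument(s) to differentiate with respect to; the random vectors here are fresh stand-ins for `c` and `y`.

```python
import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.default_rng(0)
c = rng.standard_normal(5)
y = rng.standard_normal(5)

# Default argnums=0: gradient of vdot(y, c) with respect to y (equals c)
grad_y = jax.grad(jnp.vdot)(y, c)
# argnums=1: gradient with respect to c (equals y)
grad_c = jax.grad(jnp.vdot, argnums=1)(y, c)
# argnums=(0, 1): a tuple holding both gradients
grad_both = jax.grad(jnp.vdot, argnums=(0, 1))(y, c)
print(grad_y, grad_c)
```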


### Another example
Here, JAX isn't especially fast (its timing below is comparable to `sympy`'s, probably because it has to unroll the loop), but it is more accurate.  The `sympy` version uses an unstable formula and doesn't have good accuracy: it returns 0, which is incorrect.

In [None]:
m = 30
t = np.random.randn(m)

def f(x):
  y = 1
  for i in range(m):
    y *= x - t[i]
  return y

f(x)

x0 = t[0] + 1e-15

In [None]:
sympy.diff(f(x), x).evalf(25,subs={x: x0})

0

In [None]:
jax.grad(f)(x0)

DeviceArray(2.837033e-13, dtype=float32)

In [None]:
%%timeit
sympy.diff(f(x), x).evalf(15,subs={x: x0})

10 loops, best of 5: 70.7 ms per loop


In [None]:
%%timeit
jax.grad(f)(x0)

10 loops, best of 5: 69.6 ms per loop


## Software

* Algorithmic differentiation (AD) software has been around for over 40 years.
* There are two classical approaches:
  * Source transformation: the AD tool emits Fortran (or C, etc.) code, which is then compiled by a normal compiler
  * Operator overloading: each basic operation is overloaded to act on objects that hold both values and derivatives
* Source transformation is usually more efficient, since it can retain loop structure, etc.
* Classical implementations tend to have poor ergonomics, odd restrictions on use, and poor composability.
* Vectorization has been poor with most classical tools.
* AD *implementations* have come a long way in the past few years (even though the math is old).
* Much of that progress comes from just-in-time compilation and extensive software engineering.
* Exemplars:
  * [JAX](https://jax.readthedocs.io/en/latest/) for Python
  * [Zygote.jl](http://fluxml.ai/Zygote.jl/latest/) for Julia
* AD is great within its domain, but it is still intrusive (especially for multi-language projects, languages with poor AD tooling, etc.).  Even in JAX, you'll see [various constraints](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html), such as not being able to update an array in place.

In [None]:
z = jnp.zeros(3)
z[1] = 1

TypeError: ignored
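JAX arrays are immutable, so the assignment above fails. The documented alternative is the `.at[...]` indexed-update syntax, which returns a modified copy (inside `jax.jit`, the compiler can often turn such an update into an in-place one when the old array is not reused). A short sketch:

```python
import jax.numpy as jnp

z = jnp.zeros(3)
# Functional update: returns a new array with entry 1 set; z itself is unchanged
z_new = z.at[1].set(1.0)
# Other indexed updates follow the same pattern, e.g. accumulating into an entry
z_acc = z_new.at[1].add(2.0)
print(z, z_new, z_acc)
```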

* If you work in this space, you'll eventually learn to judge when to use AD and when to hand-code a derivative.  This type of decision lies at the intersection of numerical analysis and software engineering.
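The operator-overloading approach listed under Software can be sketched for forward mode with "dual numbers". The `Dual` class and the `dual_math` namespace below are hypothetical, written only for this illustration: each value carries its derivative along, and overloaded operations apply the corresponding differentiation rule. Because `g` takes its math namespace as an argument, the unmodified `g` from the symbolic section computes its own derivative when handed `Dual` inputs.

```python
import math
from types import SimpleNamespace

class Dual:
    """Hypothetical minimal forward-mode type: a value plus its derivative (tangent)."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __mul__(self, other):
        if not isinstance(other, Dual):
            other = Dual(other)  # constants have zero derivative
        # Product rule
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)
    __rmul__ = __mul__

    def __pow__(self, p):
        # Constant (non-Dual) exponent only
        return Dual(self.val ** p, p * self.val ** (p - 1) * self.dot)

def cos(u):
    return Dual(math.cos(u.val), -math.sin(u.val) * u.dot)

def log(u):
    return Dual(math.log(u.val), u.dot / u.val)

dual_math = SimpleNamespace(cos=cos, log=log, pi=math.pi)

def g(x, m):
    # Same g as in the symbolic-differentiation section
    y = x
    for i in range(2):
        y = m.cos(y**m.pi) * m.log(y)
    return y

out = g(Dual(1.9, 1.0), dual_math)  # seed dx = 1
print(out.val, out.dot)
```

This performs the same arithmetic as the hand-coded `gprime`, but the derivative rules live in one place (the overloaded operations) instead of being interleaved with every line of the function.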