In the lecture, we saw an example of forward mode auto-differentiation. 

Another way to compute derivatives using autodiff is the so-called backward (or reverse) mode.

In the backward mode, there are two passes - forward and backward.

<br>
<br>


<div>
    <center>
        <img src="figures/reverse_mode_autodiff.png" width="1000"/>
    </center>
</div>

<br>
<br>

Notice how a single forward+backward pass yields the derivatives with respect to both `x1` and `x2`.

In contrast, forward mode needs one pass for the derivative with respect to `x1` and another pass for `x2`.
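
The difference in the number of passes can be sketched with JAX's `jax.jvp` (forward mode) and `jax.vjp` (reverse mode). The function `toy` below is a hypothetical example chosen for illustration, not the function from the lecture; the evaluation point is arbitrary.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function (an assumption for illustration, not the lecture's fn)
def toy(x1, x2):
    return x1 * jnp.sin(x2)

x1, x2 = jnp.float32(2.0), jnp.float32(0.5)
one, zero = jnp.float32(1.0), jnp.float32(0.0)

# Forward mode: one pass per input direction
_, df_dx1_fwd = jax.jvp(toy, (x1, x2), (one, zero))  # seed x1
_, df_dx2_fwd = jax.jvp(toy, (x1, x2), (zero, one))  # seed x2

# Reverse mode: one forward pass (jax.vjp) and one backward pass (vjp_fn)
y, vjp_fn = jax.vjp(toy, x1, x2)
df_dx1_rev, df_dx2_rev = vjp_fn(jnp.ones_like(y))

print("forward mode:", df_dx1_fwd, df_dx2_fwd)
print("reverse mode:", df_dx1_rev, df_dx2_rev)
```

Both modes should agree on the values; only the number of passes needed to obtain all input derivatives differs.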

$\textbf{Exercise}$:

Rewrite the derivative function `fn_prime(x1, x2)` for the function `fn(x1, x2)` defined below using reverse-mode differentiation concepts.

Do this by calculating the adjoints. 
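
As a hint, here is a minimal sketch of the adjoint bookkeeping on a different, smaller function. The names `h` and `h_prime` are hypothetical and chosen only for this illustration; the same pattern (forward pass storing intermediates, then a backward pass accumulating adjoints) should carry over to `fn`.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: h(x1, x2) = sin(x1 * x2) + x1
def h(x1, x2):
    u = x1 * x2
    v = jnp.sin(u)
    return v + x1

def h_prime(x1, x2):
    # Forward pass: compute and keep the intermediate values
    u = x1 * x2
    v = jnp.sin(u)
    w = v + x1

    # Backward pass: propagate adjoints from the output to the inputs
    w_bar = 1.0                  # dw/dw
    v_bar = w_bar                # w = v + x1
    x1_bar = w_bar               # direct contribution of x1 to w
    u_bar = v_bar * jnp.cos(u)   # v = sin(u)
    x1_bar += u_bar * x2         # u = x1 * x2
    x2_bar = u_bar * x1          # u = x1 * x2
    return x1_bar, x2_bar

print(h_prime(0.5, 2.0))
print(jax.grad(h, argnums=(0, 1))(0.5, 2.0))
```

Note that `x1` feeds into two operations, so its adjoint is the sum of both contributions.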

In [117]:
import jax
from jax import numpy as jnp

import numpy as np

# This is the same function as in the lecture

def fn(x1, x2):
    
    a = x1/x2      
    b = jnp.exp(x2)     
    c = jnp.sin(a)
    d = a - b
    e = c + d
    g = d * e
    
    return g

In [None]:
# Write down the function

def fn_prime(x1, x2):
    
    

Once you write this down, compare the outputs with the `jax.grad` method.

In [None]:
print("Your output")
print(fn_prime(1.0,1.0))

print("JAX output")
print(jax.grad(fn, argnums=(0,1))(1.0,1.0))

<br>

A common step in doing a statistical analysis is calculating the Hessian matrix for a negative log-likelihood function.

The Hessian matrix of a function, as you might be aware, is the matrix of second derivatives of the function with respect to its parameters:

<br>
<br>


<div>
    <center>
        <img src="figures/hessian.png" width="500"/>
    </center>
</div>

<br>
<br>

For a function of $n$ parameters, the Hessian is therefore an $n \times n$ matrix.

The more parameters we have, the larger it is. In ML models, the number of parameters can range from thousands to billions!

This is why most ML optimizations do not compute the full Hessian matrix by default.
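
To get a feeling for the scaling, here is a rough back-of-the-envelope sketch. It assumes the Hessian is stored as a dense matrix of 4-byte (float32) entries without exploiting symmetry; `hessian_bytes` is a hypothetical helper written for this illustration.

```python
# Hypothetical helper: memory needed for a dense n x n Hessian,
# assuming 4 bytes per entry (float32) and no use of symmetry
def hessian_bytes(n_params, bytes_per_entry=4):
    return n_params ** 2 * bytes_per_entry

for n in (100, 10**3, 10**6, 10**9):
    print(f"{n:>13,d} parameters -> {hessian_bytes(n):.1e} bytes")
```

The storage grows quadratically with the number of parameters, so a Hessian that is trivial for about 100 parameters becomes impractical to store for very large models.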

<br>

But we are not interested in ML applications in this set of lectures.

The Hessian matrix is also computed for parameter fitting in the profile likelihood fit step of a typical statistical analysis.

This Hessian matrix is typically $O(100)\times O(100)$.

We will explore this in the next lecture, but for now let us try and calculate Hessian matrices for arbitrary functions!

<br>

Let's say we have the following (toy) function:


In [9]:
# Defining a toy function - we will be working with passing arrays of arbitrary sizes

def fn(tuple_arr):
    
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)


Using the `jax.hessian` method, try calculating the Hessian matrix of this function at the point $(x,y) = (1,0)$

Documentation: https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html

In [10]:
# Calculate the Hessian using the jax.hessian method

from jax import hessian

hess_mat = hessian(fn)(jnp.array([1.0,0.0]))

print(hess_mat)

[[-4.  0.]
 [ 0.  2.]]
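
This result can be cross-checked by hand. Each component contributes $x_i^2 - x_i^3 - x_i$ independently, so the Hessian is diagonal with entries $2 - 6x_i$. A minimal sketch of that check, redefining the toy function so the block is self-contained:

```python
import jax
import jax.numpy as jnp

def fn(tuple_arr):
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)

x = jnp.array([1.0, 0.0])

hess_auto = jax.hessian(fn)(x)
# Analytic Hessian: d^2/dx_i^2 (x_i^2 - x_i^3 - x_i) = 2 - 6 x_i, zero off-diagonal
hess_analytic = jnp.diag(2.0 - 6.0 * x)

print(hess_auto)
print(hess_analytic)
print(jnp.allclose(hess_auto, hess_analytic))
```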


Now compute the Hessian for a 100-dimensional array of random numbers between 0 and 1.

In [13]:
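# Calculate the Hessian using the jax.hessian method

import numpy as np

# Note: uniform(1.0, 1.0) has low == high, so every entry is exactly 1.0
# (hence the -4 diagonal below); use uniform(0.0, 1.0) for values in [0, 1)
hess_mat = hessian(fn)(np.random.uniform(1.0,1.0,size=100))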

print(hess_mat)

[[-4.  0.  0. ...  0.  0.  0.]
 [ 0. -4.  0. ...  0.  0.  0.]
 [ 0.  0. -4. ...  0.  0.  0.]
 ...
 [ 0.  0.  0. ... -4.  0.  0.]
 [ 0.  0.  0. ...  0. -4.  0.]
 [ 0.  0.  0. ...  0.  0. -4.]]


<br>

Well, that was easy, right? JAX is a powerful and easy-to-use tool. But we have just scratched the surface of its capabilities.

Now let's explore a bit how the Hessian is calculated under the hood (https://jax.readthedocs.io/en/latest/_modules/jax/_src/api.html#hessian):

<br>

```
def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
            has_aux: bool = False, holomorphic: bool = False) -> Callable:
            
            
    return jacfwd(jacrev(fun, argnums, has_aux=has_aux, holomorphic=holomorphic),
                argnums, has_aux=has_aux, holomorphic=holomorphic)
```

It calculates the full Hessian matrix by first performing a reverse-mode differentiation, and then performing a forward-mode differentiation on the output.

For the scalar-valued function we use, this looks like:

$$\nabla f : \mathbb{R}^{100} \rightarrow \mathbb{R}^{100}  \\ \text{(Calculated using Reverse Mode)}$$ 

<br>

$$\nabla(\nabla f) : \mathbb{R}^{100} \rightarrow \mathbb{R}^{100 \times 100} \\ \text{(Calculated using Forward Mode)}$$ 
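
The two steps above can be sketched explicitly. This is a minimal illustration, assuming the toy function defined earlier and an arbitrary 100-dimensional input; the names `grad_fn` and `hess_fn` are introduced here only for readability.

```python
import jax
import jax.numpy as jnp

def fn(tuple_arr):
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)

x = jnp.linspace(0.0, 1.0, 100)

grad_fn = jax.jacrev(fn)       # reverse mode: R^100 -> R^100
hess_fn = jax.jacfwd(grad_fn)  # forward mode on the gradient: R^100 -> R^(100x100)

g = grad_fn(x)
H = hess_fn(x)

print(g.shape, H.shape)
print(jnp.allclose(H, jax.hessian(fn)(x)))
```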

Is this the most efficient way to compute it? Why not use `jacfwd` for both steps? Or `jacrev` for both? Or forward mode followed by reverse mode?

Let's try it out and measure the computation times.

In [16]:
%%timeit -r10 -n10

# Below, compute the Hessian with jax.hessian for a 200-dimensional array of ones

jax.hessian(fn)(jnp.array(np.ones(200))).block_until_ready()

18.2 ms ± 213 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


In [17]:
%%timeit -r10 -n10

#Below compute the hessian using a combination of forward mode differentiations only 

jax.jacfwd(jax.jacfwd(fn))(jnp.array(np.ones(200))).block_until_ready()

85.8 ms ± 3.3 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)


In [18]:
%%timeit -r10 -n10

#Below compute the hessian using a combination of reverse mode differentiations only 

jax.jacrev(jax.jacrev(fn))(jnp.array(np.ones(200))).block_until_ready()

22.5 ms ± 6 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)


In [19]:
%%timeit -r10 -n10

#Below compute the hessian after reversing the order of forward and backward modes in the jax hessian method

jax.jacrev(jax.jacfwd(fn))(jnp.array(np.ones(200))).block_until_ready()

102 ms ± 6.26 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)


In [20]:
%%timeit -r10 -n10

# Below, compute the Hessian using forward-over-reverse, the same order that jax.hessian uses

jax.jacfwd(jax.jacrev(fn))(jnp.array(np.ones(200))).block_until_ready()

17.6 ms ± 481 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


<br>

Can you guess a pattern in these results?

Think about it carefully, based on the calculation flow used by forward and backward modes of automatic differentiation. Which do you think would be optimal? And when?

I would recommend you discuss with your colleagues and come up with a satisfactory answer before moving down!

<br>

<br>

<br>

<br>

<br>

## Answer

In [149]:
%%timeit -r1 -n1

jax.jacrev(fn)(jnp.array(np.ones(1000))).block_until_ready()

46 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [150]:
%%timeit -r1 -n1

jax.jacfwd(fn)(jnp.array(np.ones(1000))).block_until_ready()

26.5 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
