# Optimized Learning

In this notebook, we are going to look at another one of JAX's functional transforms: the `grad` function.

## Autograd to JAX

Before they worked on JAX, the JAX core team worked on another Python package called autograd.
That was where the original idea of building an automatic differentiation system on top of NumPy started.

## Example: Transforming a function into its derivative

Just like `vmap`, `grad` takes in a function and transforms it into another function.
By default, the function returned by `grad` computes
the derivative of the original function with respect to its first argument.
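Here is a minimal sketch of that default, using a hypothetical two-argument function `scaled_square`. `grad` differentiates with respect to `x` only and treats `w` as a constant. The sketch also shows that `grad` expects floating-point input for the argument it differentiates, which is why the examples below pass values like `4.0` rather than `4`.

```python
from jax import grad


def scaled_square(x, w):
    # Two arguments, but grad differentiates only with respect to x (argument 0).
    return w * x ** 2


d_scaled_square = grad(scaled_square)

# d/dx (w * x**2) = 2 * w * x, so at x = 3.0, w = 2.0 the slope is 12.0.
slope = d_scaled_square(3.0, 2.0)
print(slope)

# grad expects floating-point inputs for the differentiated argument;
# passing an integer x raises a TypeError.
raised = False
try:
    d_scaled_square(3, 2.0)
except TypeError:
    raised = True
print("integer input raised TypeError:", raised)
```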

In [None]:
# Example 1:
from jax import grad


def func(x):
    return 3 * x + 1


df = grad(func)


# Pass in any float value of x; you should get back 3.0 as the _gradient_.
df(4.0)

In [None]:
# Example 2:


def polynomial(x):
    return 3 * x ** 2 + 4 * x - 3


dpolynomial = grad(polynomial)

# Pass in any float value of x.
# The result is the derivative 6x + 4 evaluated at that x
# (here 6 * 3.0 + 4 = 22.0).
dpolynomial(3.0)
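Because `grad` returns an ordinary function, it composes with `vmap`. The sketch below assumes we want to check the derivative at many points at once. It evaluates `dpolynomial` over an array of `x` values and compares the result with the hand-derived $6x + 4$.

```python
import jax.numpy as jnp
from jax import grad, vmap


def polynomial(x):
    return 3 * x ** 2 + 4 * x - 3


dpolynomial = grad(polynomial)

# vmap turns the scalar derivative into one that works on a whole array of x values.
xs = jnp.linspace(-2.0, 2.0, 5)
autodiff = vmap(dpolynomial)(xs)
analytic = 6 * xs + 4  # hand-derived derivative

print(autodiff)
print(jnp.allclose(autodiff, analytic))
```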

We can also use `grad` to solve minimization problems.

In [None]:
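# Example: find the minimum of the polynomial function.

# Start at x = 3.0 and repeatedly step against the gradient,
# using a learning rate (step size) of 0.01.
start = 3.0
for i in range(200):
    start -= dpolynomial(start) * 0.01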
start

How close is this to the true analytical value?

$$f(x) = 3x^{2} + 4x -3$$
$$\frac{df}{dx} = 6x + 4$$

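At the minimum, $\frac{df}{dx}$ is zero, so $6x + 4 = 0$ and therefore $x = -\frac{2}{3}$. The second derivative, $\frac{d^{2}f}{dx^{2}} = 6$, is positive, so this point is indeed a minimum.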

We're pretty darn close.

And that, my friends, is gradient descent!
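The same loop can be wrapped in a reusable function. The sketch below uses a hypothetical helper, `gradient_descent`, with a fixed learning rate and step count. It is not a JAX API. It also applies `jit` to the derivative so that it is compiled once and reused on every step.

```python
from jax import grad, jit


def polynomial(x):
    return 3 * x ** 2 + 4 * x - 3


def gradient_descent(f, x0, learning_rate=0.01, num_steps=200):
    """Hypothetical helper: fixed-step gradient descent on a scalar function f."""
    df = jit(grad(f))  # build and compile the derivative once, reuse it every step
    x = x0
    for _ in range(num_steps):
        x = x - learning_rate * df(x)
    return x


x_min = gradient_descent(polynomial, 3.0)
print(x_min, -2.0 / 3.0)
```

For this quadratic the derivative is $6x + 4 = 6\left(x + \frac{2}{3}\right)$. In exact arithmetic, each update therefore multiplies the distance to the minimum by $1 - 6 \times 0.01 = 0.94$. After 200 steps from $x = 3$, the remaining error is on the order of $10^{-5}$.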

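## Maximum likelihood estimation of parameters

We draw 1,000 samples from a normal distribution with known mean $\mu = -3$ and standard deviation $\sigma = 2$, then try to recover both parameters from the data alone.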

In [None]:
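import jax.numpy as np
from jax import random

key = random.PRNGKey(44)
real_mu = -3.0
real_log_sigma = np.log(2.0)  # the real sigma is 2.0

# Draw 1,000 samples from a normal distribution with mean real_mu and sigma 2.0.
data = random.normal(key, shape=(1000,)) * np.exp(real_log_sigma) + real_mu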

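## What are the maximum likelihood values of mu and sigma given the data?

Maximizing the likelihood is equivalent to minimizing the negative log likelihood, so that is the quantity we will minimize with gradient descent.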
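For a normal distribution, this minimization also has a well-known closed-form solution. The estimates are the sample mean and the sample standard deviation computed with a divisor of $N$, not $N - 1$. The sketch below draws data from the same distribution with the same key. It gives reference values that the gradient-descent estimates should approach.

```python
import jax.numpy as jnp
from jax import random

# Same data-generating process as above: Normal(mu=-3, sigma=2), 1,000 samples.
key = random.PRNGKey(44)
data = random.normal(key, shape=(1000,)) * 2.0 + (-3.0)

# Closed-form maximum likelihood estimates for a normal distribution.
mu_hat = jnp.mean(data)
sigma_hat = jnp.sqrt(jnp.mean((data - mu_hat) ** 2))  # divides by N, not N - 1

print(mu_hat, sigma_hat)
```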

In [None]:
from jax.scipy.stats import norm
import jax.numpy as np


def neglogp(mu, log_sigma, data):
    # Parametrize by log(sigma) so that sigma = exp(log_sigma) is always positive.
    return -np.sum(norm.logpdf(data, loc=mu, scale=np.exp(log_sigma)))
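As a sanity check, `neglogp` can be compared with the negative log likelihood of a normal distribution written out by hand:

$$-\log p(x \mid \mu, \sigma) = N \log\sigma + \frac{N}{2}\log(2\pi) + \frac{1}{2\sigma^{2}}\sum_{i=1}^{N}(x_i - \mu)^{2}$$

In the sketch below, `neglogp_closed_form` is a hypothetical helper written only for this comparison.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm


def neglogp(mu, log_sigma, data):
    return -jnp.sum(norm.logpdf(data, loc=mu, scale=jnp.exp(log_sigma)))


def neglogp_closed_form(mu, log_sigma, data):
    """Hypothetical helper: hand-written Gaussian negative log likelihood."""
    n = data.shape[0]
    sigma = jnp.exp(log_sigma)
    return (n * log_sigma
            + 0.5 * n * jnp.log(2 * jnp.pi)
            + jnp.sum((data - mu) ** 2) / (2 * sigma ** 2))


data = random.normal(random.PRNGKey(44), shape=(1000,)) * 2.0 - 3.0
a = neglogp(-6.0, jnp.log(2.0), data)
b = neglogp_closed_form(-6.0, jnp.log(2.0), data)
print(a, b)  # should agree up to float32 rounding
```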

Evaluate the negative log likelihood at an initial guess ($\mu = -6$, $\sigma = 2$) to check that the calculation works.

In [None]:
mu = -6.0
log_sigma = np.log(2.0)
neglogp(mu, log_sigma, data)

In [None]:
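from functools import partial

# Differentiate with respect to mu (argument 0) and log_sigma (argument 1).
dneglogp = grad(neglogp, argnums=(0, 1))

# Condition on the observed data, leaving a function of (mu, log_sigma) only.
dneglogp = partial(dneglogp, data=data)
dneglogp(mu, log_sigma)

In [None]:
# Gradient descent on both parameters, with learning rate 0.0001.
for i in range(300):
    dmu, dlog_sigma = dneglogp(mu, log_sigma)
    mu -= dmu * 0.0001
    log_sigma -= dlog_sigma * 0.0001
mu, np.exp(log_sigma)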
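The gradients that drive this loop can be checked against hand-derived formulas. Differentiating the Gaussian negative log likelihood gives $-\frac{1}{\sigma^{2}}\sum_i (x_i - \mu)$ with respect to $\mu$. With respect to $\log\sigma$, it gives $N - \frac{1}{\sigma^{2}}\sum_i (x_i - \mu)^{2}$. The sketch below rebuilds the notebook's `dneglogp` so that it runs on its own, then compares both values at the initial guess.

```python
from functools import partial

import jax.numpy as jnp
from jax import grad, random
from jax.scipy.stats import norm


def neglogp(mu, log_sigma, data):
    return -jnp.sum(norm.logpdf(data, loc=mu, scale=jnp.exp(log_sigma)))


data = random.normal(random.PRNGKey(44), shape=(1000,)) * 2.0 - 3.0
dneglogp = partial(grad(neglogp, argnums=(0, 1)), data=data)

mu, log_sigma = -6.0, jnp.log(2.0)
dmu, dlog_sigma = dneglogp(mu, log_sigma)

# Hand-derived partial derivatives of the Gaussian negative log likelihood.
sigma2 = jnp.exp(2 * log_sigma)
dmu_by_hand = -jnp.sum(data - mu) / sigma2
dlog_sigma_by_hand = data.shape[0] - jnp.sum((data - mu) ** 2) / sigma2

print(dmu, dmu_by_hand)
print(dlog_sigma, dlog_sigma_by_hand)
```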

## `grad` with multiple arguments

Where is the gold? It's at the minima!

In [None]:
def func(x, y):
    """All credit to https://www.analyzemath.com/calculus/multivariable/maxima_minima.html for this function."""
    return (2 * x ** 2) - (4 * x * y) + (y ** 4 + 2)

This function has two minima, which the surface plot at the end of this notebook makes visible.
Let's find out where they are.
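Setting both partial derivatives to zero locates the critical points analytically:

$$\frac{\partial f}{\partial x} = 4x - 4y = 0 \quad\Rightarrow\quad x = y$$

$$\frac{\partial f}{\partial y} = -4x + 4y^{3} = 0$$

Substituting $x = y$ into the second equation gives $y^{3} = y$, so $y \in \{-1, 0, 1\}$. The critical points are therefore $(0, 0)$, $(1, 1)$ and $(-1, -1)$. The Hessian tells us which ones are minima:

$$H = \begin{pmatrix} 4 & -4 \\ -4 & 12y^{2} \end{pmatrix}, \qquad \det H = 48y^{2} - 16$$

At $(0, 0)$, $\det H = -16 < 0$, so the origin is a saddle point. At $(1, 1)$ and $(-1, -1)$, $\det H = 32 > 0$ and $\frac{\partial^{2} f}{\partial x^{2}} = 4 > 0$, so both points are minima, each with $f = 1$.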

In [None]:
# Gradient with respect to both x and y; returns a tuple (df/dx, df/dy).
df = grad(func, argnums=(0, 1))
df(3.0, 4.0)
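By hand, the partial derivatives are $\frac{\partial f}{\partial x} = 4x - 4y$ and $\frac{\partial f}{\partial y} = -4x + 4y^{3}$. At $(3.0, 4.0)$ both approaches should give $(-4.0, 244.0)$. The sketch below compares `grad` with these formulas at a few arbitrarily chosen points. `df_by_hand` is a hypothetical helper written only for this check.

```python
import jax.numpy as jnp
from jax import grad


def func(x, y):
    return (2 * x ** 2) - (4 * x * y) + (y ** 4 + 2)


df = grad(func, argnums=(0, 1))


def df_by_hand(x, y):
    """Hypothetical helper: hand-derived partial derivatives (df/dx, df/dy)."""
    return 4 * x - 4 * y, -4 * x + 4 * y ** 3


points = [(3.0, 4.0), (0.5, -1.5), (-2.0, 0.25)]
matches = []
for x, y in points:
    auto = jnp.array(df(x, y))
    manual = jnp.array(df_by_hand(x, y))
    print((x, y), auto, manual)
    matches.append(bool(jnp.allclose(auto, manual)))
```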

In [None]:
# Start near the origin, which is a saddle point of this function
x, y = 0.1, -0.1
for i in range(300):
    dx, dy = df(x, y)
    x -= dx * 0.01
    y -= dy * 0.01
x, y
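Gradient descent only finds the minimum whose basin it starts in. The sketch below uses a hypothetical `descend` helper, which is not part of JAX. It runs the same update from several starting points at once by combining `grad` with `vmap` and `jax.lax.fori_loop`. Because $f(-x, -y) = f(x, y)$, starts in the positive quadrant should end near $(1, 1)$, and their mirror images should end near $(-1, -1)$.

```python
import jax.numpy as jnp
from jax import grad, jit, lax, vmap


def func(x, y):
    return (2 * x ** 2) - (4 * x * y) + (y ** 4 + 2)


df = grad(func, argnums=(0, 1))


def descend(start, learning_rate=0.01, num_steps=2000):
    """Hypothetical helper: fixed-step gradient descent from one (x, y) start."""
    def step(_, xy):
        dx, dy = df(xy[0], xy[1])
        return xy - learning_rate * jnp.stack([dx, dy])

    return lax.fori_loop(0, num_steps, step, start)


starts = jnp.array([[0.5, 0.5], [-0.5, -0.5], [2.0, 1.0], [-2.0, -1.0]])
# vmap runs descend on every row of starts; jit compiles the whole batch once.
ends = jit(vmap(descend))(starts)
print(ends)
```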

In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Evaluate func on a grid of (x, y) points covering both minima.
X = np.arange(-1.5, 1.5, 0.01)
Y = np.arange(-1.5, 1.5, 0.01)
X, Y = np.meshgrid(X, Y)
Z = func(X, Y)

# Plot the surface.
surf = ax.plot_surface(
    X, Y, Z, cmap=cm.coolwarm, linewidth=0, antialiased=False,
)
ax.view_init(elev=20., azim=20)