# Week 11 - Differentiable programming

Problem: to find a *minimum* of some function $J : \mathbb{R}^M \rightarrow \mathbb{R}$.

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np


def J(x, y):
    """Objective function to be minimized.

    Evaluates ``sin(x**2 - y) * exp(-x**2 - y**2) * x * y`` using NumPy
    ufuncs, so scalars or arrays (e.g. a meshgrid) may be supplied.
    """
    oscillation = np.sin(x ** 2 - y)
    envelope = np.exp(-x ** 2 - y ** 2)
    return oscillation * envelope * x * y


# Sample J on a uniform 201 x 201 grid covering [-2, 2] x [-2, 2].
x = np.linspace(-2, 2, 201)
y = np.linspace(-2, 2, 201)
# "ij" indexing: X varies along axis 0 and Y along axis 1.
X, Y = np.meshgrid(x, y, indexing="ij")

# Colour plot of J over the grid, with equal axis scaling and a colour bar.
fig, ax = plt.subplots()
p = ax.pcolormesh(X, Y, J(X, Y), cmap="plasma")
ax.set_xlim(x[0], x[-1])
ax.set_ylim(y[0], y[-1])
ax.set_aspect(1)
fig.colorbar(p)

## Gradient descent

If we have the derivative

$$F(x) = \nabla J (x),$$

with

$$\nabla J = \left( \begin{array}{c}
        \frac{\partial J}{\partial x_0} \\
        \frac{\partial J}{\partial x_1} \\
        \vdots \\
        \frac{\partial J}{\partial x_{M - 1}}
    \end{array} \right),$$

then we can take a step *downhill* (note the change in subscript notation: from here on $x_n$ denotes the $n$th iterate, not the $n$th component of $x$)

$$x_{n + 1} = x_n - \alpha_n \nabla J (x_n),$$

for some step size $\alpha_n > 0$. A more general approach is

$$x_{n + 1} = x_n - \alpha_n A_n \nabla J (x_n),$$

where $A_n$ is some symmetric positive definite matrix.

Compare with Newton's method for solving $F(x) = 0$, where $\nabla F$ is the Hessian matrix of $J$

$$x_{n + 1} = x_n - \left[ \nabla F ( x_n ) \right]^{-1} F ( x_n ).$$

**Key question:** How do we find the derivatives?

## Forward mode autodiff

Given some function $G : \mathbb{R}^M \rightarrow \mathbb{R}^N$ compute a 'Jacobian vector product' $(\nabla G) v$,

$$\left[ (\nabla G) v \right]_i = \sum_{j = 0}^{M - 1} \frac{\partial G_i}{\partial x_j} v_j,$$

for some $v \in \mathbb{R}^M$. Result is a vector in $\mathbb{R}^N$.

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def J(x, y):
    """Objective function to be minimized, written with ``jax.numpy`` so that
    JAX can differentiate it.

    Evaluates ``sin(x**2 - y) * exp(-x**2 - y**2) * x * y`` elementwise.
    """
    oscillation = jnp.sin(x ** 2 - y)
    envelope = jnp.exp(-x ** 2 - y ** 2)
    return oscillation * envelope * x * y


# Same plot as the NumPy version, but J is now evaluated with jax.numpy.
# Sample on a uniform 201 x 201 grid covering [-2, 2] x [-2, 2].
x = jnp.linspace(-2, 2, 201)
y = jnp.linspace(-2, 2, 201)
# "ij" indexing: X varies along axis 0 and Y along axis 1.
X, Y = jnp.meshgrid(x, y, indexing="ij")

# Colour plot of J over the grid, with equal axis scaling and a colour bar.
fig, ax = plt.subplots()
p = ax.pcolormesh(X, Y, J(X, Y), cmap="plasma")
ax.set_xlim(x[0], x[-1])
ax.set_ylim(y[0], y[-1])
ax.set_aspect(1)
fig.colorbar(p)

In [None]:
# Point at which the derivatives are evaluated, and the finite-difference step.
x_0, y_0 = -0.5, -1.0
eps = 1e-3

# jax.linearize returns J(x_0, y_0) together with a function computing
# Jacobian-vector products (forward mode) at that point.
J_val, jvp = jax.linearize(J, x_0, y_0)
print(f"{J_val=:.10g}")

# The tangent direction (1, 0) gives dJ/dx; (0, 1) gives dJ/dy.
dJdx = jvp(1.0, 0.0)
dJdy = jvp(0.0, 1.0)
print(f"{dJdx=:.10g}")
print(f"{dJdy=:.10g}")

# Check against central finite differences with step eps.
dJdx_fd = (J(x_0 + eps, y_0) - J(x_0 - eps, y_0)) / (2 * eps)
dJdy_fd = (J(x_0, y_0 + eps) - J(x_0, y_0 - eps)) / (2 * eps)
print(f"{dJdx_fd=:.10g}")
print(f"{dJdy_fd=:.10g}")

## Reverse mode autodiff

Given some function $G : \mathbb{R}^M \rightarrow \mathbb{R}^N$ compute a 'vector Jacobian product' $v^T ( \nabla G )$,

$$\left[ v^T (\nabla G) \right]_j = \sum_{i = 0}^{N - 1} v_i \frac{\partial G_i}{\partial x_j},$$

for some $v \in \mathbb{R}^N$. Result is a vector in $\mathbb{R}^M$.

In [None]:
# jax.vjp returns J(x_0, y_0) together with a function computing
# vector-Jacobian products (reverse mode) at that point.
J_val, vjp = jax.vjp(J, x_0, y_0)
print(f"{J_val=:.10g}")

# J is scalar valued, so a single product with the cotangent 1.0 gives the
# whole gradient: one entry per input, dJ[0] = dJ/dx and dJ[1] = dJ/dy.
dJ = vjp(1.0)
print(f"{dJ[0]=:.10g}")
print(f"{dJ[1]=:.10g}")

## Gradient-based optimization

In [None]:
from minimize import minimize

# Initial guess (x, y) for the optimizer.
x0 = jnp.array([-0.5, -1.0], dtype=float)
print(f"{x0=}")
# History of iterates, starting with the initial guess; the optimizer
# callback appends to this list.
x_i = [x0.copy()]


def callback(x):
    """Print the current iterate ``x`` and record a copy of it in ``x_i``.

    A copy is stored so that later in-place changes to ``x`` cannot alter the
    recorded history.
    """
    print(f"{x=}")
    snapshot = x.copy()
    x_i.append(snapshot)


# Minimize J with L-BFGS-B via the local `minimize` helper. The objective
# receives a single length-2 array, which is unpacked into J's two arguments.
# NOTE(review): presumably `minimize` supplies gradients using JAX and
# otherwise follows scipy.optimize.minimize's interface — confirm in
# minimize.py.
result = minimize(lambda x: J(*x), x0,
                  callback=callback,
                  method="L-BFGS-B",
                  options={"ftol": 1.0e-10,
                           "gtol": 1.0e-10,
                           "maxiter": 200})
assert result.success
print(f"{result.x=}")
# Stack the iterate history into an array of shape (n_iterates, 2).
x_i = jnp.array(x_i, dtype=float)

# Re-sample J on the 201 x 201 grid over [-2, 2] x [-2, 2].
x = jnp.linspace(-2, 2, 201)
y = jnp.linspace(-2, 2, 201)
X, Y = jnp.meshgrid(x, y, indexing="ij")

# Colour plot of J, overlaid with the optimizer's path (white crosses joined
# by dashed lines).
fig, ax = plt.subplots()
p = ax.pcolormesh(X, Y, J(X, Y), cmap="plasma")
ax.plot(x_i[:, 0], x_i[:, 1], "wx--", markersize=4)
ax.set_xlim(x[0], x[-1])
ax.set_ylim(y[0], y[-1])
ax.set_aspect(1)
fig.colorbar(p)

## Worked example: Brachistochrone problem

A ball rolls on a curve under gravity, starting from rest, and falls a height $H$ over a horizontal distance $L$.

<div style="text-align: center;"><img src="figures/week11/rolling.png" width="600" style="padding-top: 10px; padding-bottom: 30px;" alt="a ball rolling on a curve"/></div>

Break into a series of straight line pieces each covering a horizontal distance $h$.

<div style="text-align: center;"><img src="figures/week11/rolling_discrete.png" width="600" style="padding-top: 10px; padding-bottom: 30px;" alt="a ball rolling on a discretized curve consisting of straight line pieces"/></div>

<br/>
Now look at a single straight-line piece making an angle $\theta$ with the vertical, with initial along-curve velocity $u_0$.

<div style="text-align: center;"><img src="figures/week11/rolling_piece.png" width="600" style="padding-top: 10px; padding-bottom: 30px;" alt="a ball rolling on a single straight line piece"/></div>

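The piece has length $h / \sin \theta$ and the along-curve acceleration is $g \cos \theta$, so the time taken to roll along the single piece, $t$, satisfies

$$\frac{h}{\sin \theta} = u_0 t + \frac{1}{2} g t^2 \cos \theta.$$

We use the quadratic formula to compute $1 / t$, rather than $t$, to avoid divide-by-zeros, which gives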

$$\frac{1}{t} = \frac{u_0 \sin \theta + \sqrt{u_0^2 \sin^2 \theta + g h \sin 2 \theta}}{2 h}$$

after which the velocity is

$$u_1 = u_0 + g t \cos \theta.$$

We assume the ball remains on the curve, and that the ball's kinetic energy is unchanged when moving from one piece to the next, so that the ball's along-curve velocity is continuous in time.

If we have $N$ pieces then there are $(N - 1)$ free angles which define the discretized curve (the angle of the final piece is fixed by requiring the curve to end at height $-H$), and we can compute the time taken to roll down the curve.

In [None]:
%matplotlib inline

from functools import partial

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

from minimize import minimize

jax.config.update("jax_enable_x64", True)


def final_theta(theta, L, H):
    """Compute the angle the final piece of the discretized curve makes with
    the vertical.

    The final angle is not a free parameter. It is chosen so that the last
    piece ends exactly at height :math:`-H`.

    Parameters
    ----------

    theta : jax.numpy.ndarray
        Angles the first :math:`(N - 1)` pieces make with the vertical. Shape
        `(N - 1,)` with floating point data type.
    L : float
        Horizontal distance.
    H : float
        Height drop.

    Returns
    -------

    jax.numpy.ndarray
        The angle the final piece makes with the vertical.
    """

    (n_free,) = theta.shape
    # All N = n_free + 1 pieces span the same horizontal distance.
    h = L / (n_free + 1)

    # A piece at angle theta from the vertical drops h / tan(theta).
    drop = jnp.sum(h / jnp.tan(theta))
    # The final piece covers horizontal distance h and must make up the
    # remaining drop, H - drop.
    return jnp.arctan2(h, H - drop)


def total_time(theta, *, T_inf=1e10, eps=1e-10, L, H, g):
    """Compute the total time taken for the ball to roll down the discretized
    curve.

    The ball starts at rest. On each straight piece the travel time comes
    from the positive root of the quadratic for the inverse time, and the
    along-curve velocity is carried over to the next piece.

    Parameters
    ----------

    theta : jax.numpy.ndarray
        Angles the first :math:`(N - 1)` pieces make with the vertical. Shape
        `(N - 1,)` with floating point data type.
    T_inf : float
        Value to return if complex values, or non-positive times, are
        encountered.
    eps : float
        Non-positive value tolerance. A piece whose inverse travel time is at
        most ``eps`` is treated as invalid.
    L : float
        Horizontal distance.
    H : float
        Height drop.
    g : float
        Magnitude of the gravitational acceleration.

    Returns
    -------

    jax.numpy.ndarray or float
        The total time, or ``T_inf`` if any piece is invalid.
    """

    (n_free,) = theta.shape
    # All N = n_free + 1 pieces span the same horizontal distance.
    h = L / (n_free + 1)

    # Append the final angle, which is fixed by the end-point constraint.
    last = jnp.array((final_theta(theta, L, H),), dtype=theta.dtype)
    angles = jnp.concatenate((theta, last))

    speed = 0
    elapsed = 0
    for angle in angles:
        sin_angle = jnp.sin(angle)
        # Discriminant of the quadratic for the inverse travel time 1 / t.
        disc = (speed * sin_angle) ** 2 + g * h * jnp.sin(2 * angle)
        if disc < 0:
            return T_inf
        rate = (speed * sin_angle + jnp.sqrt(disc)) / (2 * h)
        if rate <= eps:
            return T_inf

        dt = 1 / rate
        # Speed gained under along-curve acceleration g cos(angle).
        speed += g * dt * jnp.cos(angle)
        elapsed += dt
    return elapsed


def curve(theta, L, H):
    """
    Return the coordinates of the vertices of the discretized curve.

    The curve starts at :math:`(0, 0)` and ends at :math:`(L, -H)`. The
    vertices are equally spaced horizontally.

    Parameters
    ----------

    theta : jax.numpy.ndarray
        Angles the first :math:`(N - 1)` pieces make with the vertical. Shape
        `(N - 1,)` with floating point data type.
    L : float
        Horizontal distance.
    H : float
        Height drop.

    Returns
    -------

    jnp.ndarray
        The x-coordinates, shape `(N + 1,)`.
    jnp.ndarray
        The y-coordinates, shape `(N + 1,)`.
    """

    (n_free,) = theta.shape
    n_pieces = n_free + 1
    h = L / n_pieces

    # Signed vertical change across each of the free pieces.
    steps = -h / jnp.tan(theta)
    start = jnp.array((0,), dtype=theta.dtype)
    end = jnp.array((-H,), dtype=theta.dtype)

    xs = jnp.linspace(0, L, n_pieces + 1, dtype=theta.dtype)
    ys = jnp.concatenate((start, jnp.cumsum(steps), end))
    return xs, ys


def plot_curve(theta, L, H):
    """Draw the discretized curve in a new figure, with a reference line at
    the final height.

    Parameters
    ----------

    theta : jax.numpy.ndarray
        Angles the first :math:`(N - 1)` pieces make with the vertical. Shape
        `(N - 1,)` with floating point data type.
    L : float
        Horizontal distance.
    H : float
        Height drop.
    """

    vertices_x, vertices_y = curve(theta, L, H)

    _, axes = plt.subplots(figsize=(10, 10))
    # Grey horizontal line marking the final height -H.
    axes.axhline(-H, color="#888888")
    axes.plot(vertices_x, vertices_y, "k-")
    axes.set_xlim(0, L)
    axes.set_aspect(1)


# Problem parameters: horizontal distance, height drop, gravitational
# acceleration, and number of straight-line pieces.
L = 1.0
H = 0.25
g = 10.0
N = 30
# Bind the problem parameters so that the objective depends only on theta.
total_time_config = partial(total_time, L=L, H=H, g=g)

# Constant slope
# Initial guess: a straight line from (0, 0) to (L, -H), i.e. every piece at
# angle arctan(L / H) from the vertical.
theta0 = jnp.full(N - 1, jnp.arctan2(L, H), dtype=float)
print(f"{total_time_config(theta0)=:.10g}")
plot_curve(theta0, L, H)

Now we seek to *minimize* the time taken.

In [None]:
def callback(x):
    """Print the total travel time for the angles ``x``. Passed to
    ``minimize`` as its ``callback``."""
    print(f"{total_time_config(x)=:.10g}")


# Minimize the (non-JIT) total travel time starting from the straight line.
# NOTE(review): "gtol": 0 presumably disables the gradient-based stopping
# test, so that termination is governed by ftol / maxiter — confirm against
# the `minimize` helper.
result = minimize(total_time_config, theta0,
                  callback=callback,
                  method="L-BFGS-B",
                  options={"ftol": 1e-14,
                           "gtol": 0,
                           "maxiter": 2000})
assert result.success

# Optimized curve
print(f"{total_time_config(result.x)=:.10g}")
plot_curve(result.x, L, H)

The code is surprisingly slow -- but we can make the code *much* faster using Just-In-Time (JIT) compilation.

In [None]:
@partial(jax.jit, static_argnames={"T_inf", "eps", "L", "H", "g"})
def total_time_jit(theta, *, T_inf=1e10, eps=1e-10, L, H, g):
    """As :func:`total_time`, but using JIT compilation.

    Data-dependent Python ``if`` statements cannot be traced by
    :func:`jax.jit`. The early returns of :func:`total_time` are therefore
    replaced by a boolean ``inf`` flag carried through a
    :func:`jax.lax.fori_loop`. Once the flag is set, the accumulated time is
    held at ``T_inf``.

    ``T_inf``, ``eps``, ``L``, ``H`` and ``g`` are static arguments, so they
    must be hashable Python values, and changing any of them triggers a
    recompilation.
    """

    # final_theta with L and H treated as static (compile-time) values.
    final_theta_jit = jax.jit(final_theta, static_argnames={"L", "H"})

    @partial(jax.jit, static_argnames={"T_inf", "eps", "h", "g"})
    def time_piece(theta, i, val, *, T_inf, eps, h, g):
        """Loop body: roll the ball along piece ``i``.

        ``val`` is the loop carry ``(u, T, inf)``. These are the along-curve
        velocity at the start of the piece, the time accumulated so far, and
        whether an invalid piece has been encountered. Returns the updated
        carry.
        """
        u, T, inf = val

        # Time taken to travel along the piece
        disc = (u * jnp.sin(theta[i])) ** 2 + g * h * jnp.sin(2 * theta[i])
        inf = jnp.logical_or(inf, disc < 0)
        # If inf then use 0.0 here to avoid potential square root of negative.
        # The input is masked *before* jnp.sqrt (not just the result
        # afterwards), so the discarded branch stays finite in the forward
        # pass and contributes no NaN to the gradient.
        # NOTE(review): jnp.sqrt has an infinite derivative at 0, so a disc of
        # exactly 0 while inf is still False may give non-finite gradients —
        # TODO confirm whether this case can arise in practice.
        disc = jax.lax.select(inf, 0.0, disc)
        dt_inv = (u * jnp.sin(theta[i]) + jnp.sqrt(disc)) / (2 * h)
        inf = jnp.logical_or(inf, dt_inv <= eps)
        # If inf then use 1.0 here to avoid potential division by zero
        dt_inv = jax.lax.select(inf, 1.0, dt_inv)

        # Velocity after travelling along the piece (reset to 0.0 once
        # invalid)
        u = jax.lax.select(inf, 0.0, u + g * (1 / dt_inv) * jnp.cos(theta[i]))
        # Total time taken (held at T_inf once invalid)
        T = jax.lax.select(inf, T_inf, T + 1 / dt_inv)

        return u, T, inf

    N, = theta.shape
    # Total number of pieces: the (N - 1) free angles plus the final one.
    N += 1
    h = L / N

    # Append the final angle, which is fixed by the end-point constraint.
    theta = jnp.concatenate(
        (theta,
         jnp.array((final_theta_jit(theta, L, H),), dtype=theta.dtype)))

    # Start at rest, with zero elapsed time and no invalid piece yet.
    _, T, _ = jax.lax.fori_loop(
        0, N, partial(time_piece, theta, T_inf=T_inf, eps=eps, h=h, g=g),
        (0.0, 0.0, False))
    return T


# Bind the problem parameters so that the JIT objective depends only on theta.
total_time_jit_config = partial(total_time_jit, L=L, H=H, g=g)

# Constant slope
# Initial guess: straight line from (0, 0) to (L, -H).
theta0 = jnp.full(N - 1, jnp.arctan2(L, H), dtype=float)
print(f"{total_time_jit_config(theta0)=:.10g}")
plot_curve(theta0, L, H)
# Display the initial-guess figure before the optimization output below.
plt.show()


def callback(x):
    """Print the JIT-computed total travel time for the angles ``x``. Passed
    to ``minimize`` as its ``callback``."""
    print(f"{total_time_jit_config(x)=:.10g}")


# Same optimization as before, but with the JIT-compiled objective.
result = minimize(total_time_jit_config, theta0,
                  method="L-BFGS-B",
                  callback=callback,
                  options={"ftol": 1e-14,
                           "gtol": 0,
                           "maxiter": 2000})
assert result.success

# Optimized curve
print(f"{total_time_jit_config(result.x)=:.10g}")
plot_curve(result.x, L, H)

Now we can increase the number of pieces, here using interpolation of the coarse resolution solution to construct an initial guess.

In [None]:
# Number of pieces for the refined discretization.
N_high_res = 2000

# Linearly interpolate the coarse optimized curve onto N_high_res + 1 equally
# spaced vertices.
x, y = curve(result.x, L, H)
x_high_res = jnp.linspace(0, L, N_high_res + 1, dtype=x.dtype)
y_high_res = jnp.interp(x_high_res, x, y)
# Angle of each high-resolution piece with the vertical,
# arctan2(horizontal step, drop). The last angle is discarded because the
# final angle is determined by final_theta.
theta0_high_res = jnp.arctan2(
    x_high_res[1:] - x_high_res[:-1],
    y_high_res[:-1] - y_high_res[1:])[:-1]

# Re-optimize at high resolution, starting from the interpolated guess.
result_high_res = minimize(total_time_jit_config, theta0_high_res,
                           callback=callback,
                           method="L-BFGS-B",
                           options={"ftol": 1e-14,
                                    "gtol": 0,
                                    "maxiter": 2000})
assert result_high_res.success

# Optimized curve
print(f"{total_time_jit_config(result_high_res.x)=:.10g}")
plot_curve(result_high_res.x, L, H)