# Need for Speed 
This notebook is a short example of how easy it can be to parallelize and speed up computations. It supplements a talk given at the [Insurance Data Science Conference](https://insurancedatascience.org/).

In [237]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from typing import Callable, Tuple, Any
from numpy.typing import NDArray

We use a simplified example with a stochastic rate: we value a cashflow of $1$ at time $t$ at a rate that has to be found with a gradient-based optimization algorithm.

In [228]:
N = 100  # Number of parallel paths
T = 120  # Number of timesteps

In [229]:
key = jax.random.PRNGKey(1)
rates = np.cumsum(jax.random.normal(key, shape=(N, T)), axis=1)/100

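Define a gradient-based minimizer for an arbitrary function. Note that JAX differentiates the function automatically. The update rule used here is a fixed-step gradient descent rather than a full Newton-Raphson step, which would also divide by the second derivative.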

In [230]:
def minimize(f: Callable[[float], float], x0: float) -> float:
    """Minimize a function f by fixed-step gradient descent, starting at x0"""
    df = jax.grad(f)  # Get the gradient
    
    def step(i: int, x: float):
        """Single optimization step with a fixed step size of 1e-2"""
        x = x - 1e-2 * df(x)
        return x
    
    # Apply the loop and return after 1000 steps
    return jax.lax.fori_loop(0, 1000, step, x0)

assert np.isclose(1., minimize(lambda x: (x-1)**2, 5.), 0.1)
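
The `step` function above moves against the gradient with a fixed step size. For comparison, a Newton-Raphson update for minimization divides the gradient by the second derivative. The following is a minimal sketch of that variant, not the notebook's own method; the names `newton_minimize` and `n_steps` are hypothetical, and it assumes `f` is a twice-differentiable scalar function with a non-zero second derivative along the iterates.

```python
import jax


def newton_minimize(f, x0, n_steps=20):
    """Minimize a scalar function f with Newton-Raphson steps, starting at x0."""
    df = jax.grad(f)    # first derivative f'
    d2f = jax.grad(df)  # second derivative f''

    def step(i, x):
        # Newton update for a stationary point: x - f'(x) / f''(x)
        return x - df(x) / d2f(x)

    return jax.lax.fori_loop(0, n_steps, step, x0)


x_min = newton_minimize(lambda x: (x - 1.0) ** 2, 5.0)
print(x_min)
```

For a quadratic such as `(x - 1)**2`, a single Newton step already lands on the minimum, so far fewer iterations are needed than with the fixed step size.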

In [231]:
%timeit  minimize(lambda x: (x-1)**2, 5.).block_until_ready()

31.4 ms ± 957 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [232]:
minimize_specific = jax.jit(partial(minimize, lambda x: (x-1)**2))
minimize_specific(5.)

DeviceArray(1.0000029, dtype=float32, weak_type=True)

In [233]:
%timeit minimize_specific(5.).block_until_ready()

6.74 µs ± 33.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


We keep the logic simple: one function calculates the value for a single timestep, and a second function computes all values for a single scenario.

In [241]:
def value_timestep(t: int, rate: NDArray[float]) -> Tuple[int, float]:
    """Value at time t for a given rate. Returns (carry, output) as expected by jax.lax.scan. """
    t+=1  # Increase the iteration counter
    value = jnp.exp(rate * (T-t))  # Exponential factor over the remaining T-t timesteps (T is the global defined above).
    return t, value

def value_single_path(rates: NDArray[np.float64]) -> NDArray[float]:
    """Apply a Jax-primitive to scan along the first axis of rates and pass these values one-by-one to the function. """
    _, values = jax.lax.scan(
        value_timestep,
        init=0,
        xs=rates
    )

    return values
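
To make the `(carry, output)` convention of `jax.lax.scan` concrete, the sketch below compares it with a plain Python loop on a toy running sum. The helpers `scan_reference` and `running_sum` are hypothetical names introduced only for this illustration, and the reference loop ignores details such as `length` and `reverse`.

```python
import jax
import jax.numpy as jnp


def scan_reference(f, init, xs):
    """Plain-Python analogue of jax.lax.scan over the first axis of xs."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)  # f maps (carry, x) to (new_carry, output)
        ys.append(y)
    return carry, jnp.stack(ys)


def running_sum(carry, x):
    """Toy step function: the carry is the sum so far, and it is also emitted."""
    carry = carry + x
    return carry, carry


xs = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]
init = jnp.zeros(())       # float32 scalar carry

final_ref, ys_ref = scan_reference(running_sum, init, xs)
final_scan, ys_scan = jax.lax.scan(running_sum, init, xs)
```

In `value_single_path` the carry is the timestep counter `t`, and the stacked outputs are the per-timestep values.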

Parallelizing is extremely easy: we just apply `vmap`, which vectorizes the function over the first axis (our scenarios 1...N).

In [242]:
value_all_paths = jax.vmap(value_single_path)
values = value_all_paths(rates)
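
As a small sanity check of what `vmap` does, the sketch below applies a toy per-row function both through `jax.vmap` and through an explicit Python loop over the first axis. The function `row_fn` and the array `batch` are hypothetical stand-ins, not part of the valuation example.

```python
import jax
import jax.numpy as jnp


def row_fn(row):
    """Toy per-row computation: sum of exponentials of one row."""
    return jnp.sum(jnp.exp(row))


batch = jnp.arange(6.0).reshape(2, 3) / 10  # 2 rows, 3 columns

vectorized = jax.vmap(row_fn)(batch)                # maps over axis 0 by default
looped = jnp.stack([row_fn(row) for row in batch])  # explicit loop over rows
```

Both give one result per row; `vmap` simply avoids the Python-level loop.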

In [243]:
%timeit value_all_paths(rates).block_until_ready()

642 µs ± 6.12 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Now we compile this function with `jax.jit`, which hands it to the XLA compiler (on the CPU, XLA generates code through LLVM). If no GPU backend is available, it compiles for the CPU:

In [244]:
value_all_paths_fast = jax.jit(jax.vmap(value_single_path))
values = value_all_paths_fast(rates)
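
To check which backend the compiled function will run on, and that the jitted version returns the same result as the uncompiled one, something like the following sketch can be used. The function `sum_of_exp` is a hypothetical toy computation; `jax.default_backend()` and `jax.devices()` report the platform JAX selected.

```python
import jax
import jax.numpy as jnp


def sum_of_exp(x):
    """Toy computation used only for this illustration."""
    return jnp.sum(jnp.exp(x))


# Compilation happens on the first call for a given input shape and dtype
sum_of_exp_jit = jax.jit(sum_of_exp)

x = jnp.linspace(0.0, 1.0, 8)
eager_result = sum_of_exp(x)
jit_result = sum_of_exp_jit(x)

backend = jax.default_backend()  # e.g. "cpu" when no accelerator is available
devices = jax.devices()
print(backend, devices)
```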

In [245]:
%timeit value_all_paths_fast(rates).block_until_ready()

20.6 µs ± 140 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


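For fun, run it for 5,000 scenarios over 100 years of monthly timesteps. Note that `value_timestep` still reads the global `T = 120`, so this run is meant as a timing benchmark only: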

In [246]:
rates_l = np.cumsum(jax.random.normal(key, shape=(5000, 1201)), axis=1)/10000

In [247]:
%timeit value_all_paths_fast(rates_l).block_until_ready()

27.5 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
