# Study of the computational cost of the autodiff calibration
After the first release of Tunax, we ran into a problem: the calibration was far too expensive to compute. The cost of a calibration experiment comes from computing the gradient of the cost function with JAX automatic differentiation. Part of this cost function is the whole forward model itself. Seen as a function, the model is composed of a very large number of operations, spanning every integration time step. This is why the gradient is so expensive: reverse-mode differentiation has to keep the intermediate results of all these operations in memory, and that memory footprint is too large. In this notebook, we use a simplified model to study how to compute the gradient of a cost function with JAX autodiff at a reasonable computational cost.

In [1]:
import os
import shutil
import jax.numpy as jnp
import equinox as eqx
import matplotlib.pyplot as plt
from matplotlib.pyplot import Figure, Axes
from typing import List, Tuple, TypeAlias
subplot_1D_type: TypeAlias = Tuple[Figure, List[Axes]]
subplot_2D_type: TypeAlias = Tuple[Figure, List[List[Axes]]]
latex_installed = shutil.which("latex") is not None
plt.rcParams.update({
    'text.usetex': latex_installed,
    'figure.figsize': (8, 5),
    'axes.titlesize': 18,
    'figure.titlesize': 18,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'lines.linewidth': 2,
    'lines.markersize': 6
})

In [2]:
import equinox as eqx
from typing import Tuple
from jax import lax, jit, random
from time import time as tt

### Notes on JAX
##### `scan` vs. `fori_loop`
`scan` is the more fundamental of the two primitives.

From the JAX documentation of `fori_loop`: if the trip count is static (meaning known at tracing time, perhaps because `lower` and `upper` are Python integer literals), then `fori_loop` is implemented in terms of `scan()` and reverse-mode autodiff is supported; otherwise, a `while_loop` is used and reverse-mode autodiff is not supported. See those functions' docstrings for more information.

If the iterations are independent, `fori_loop` is more efficient. For complex dependencies, `scan` is preferable but more resource-hungry.

In fact, `scan` can output the intermediate states, which makes it useful for time integration for instance, whereas `fori_loop` is meant for a simple iterative loop (where it is more efficient).

If we handle the accumulation of the outputs ourselves, `fori_loop` becomes more expensive, because the accumulation has to be written manually (e.g. by updating a preallocated array at each iteration).

When the bounds of `fori_loop` are static, it is implemented as a `scan`, so the cost is the same; the costs only differ when the bounds are not static.
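These differences can be checked on a toy recurrence. In this sketch (`relax`, `run_scan`, `run_fori` and `dynamic_bounds_grad_fails` are illustrative names, not Tunax functions), `scan` returns every intermediate state, `fori_loop` with Python-integer bounds only returns the final state and supports `jax.grad`, and the same `fori_loop` under `jit`, where the bound becomes a tracer, falls back to `while_loop` and is expected to fail in reverse mode.

```python
import jax
import jax.numpy as jnp
from jax import lax


def relax(carry):
    # One step of a toy linear recurrence: x_{k+1} = 0.9 x_k + 1
    return 0.9 * carry + 1.0


def run_scan(x0, n):
    # scan returns the final carry AND the stacked intermediate states
    def body(carry, _):
        new = relax(carry)
        return new, new
    final, history = lax.scan(body, x0, None, length=n)
    return final, history


def run_fori(x0, n):
    # fori_loop only returns the final carry; with n a Python int the trip
    # count is static, so JAX lowers the loop to a scan
    return lax.fori_loop(0, n, lambda i, c: relax(c), x0)


def dynamic_bounds_grad_fails():
    # Under jit, n becomes a tracer: the trip count is no longer static,
    # fori_loop falls back to while_loop and reverse-mode autodiff raises
    try:
        jax.jit(jax.grad(run_fori))(1.0, 10)
    except Exception:
        return True
    return False


final, history = run_scan(jnp.float32(0.0), 5)
print(history.shape)                             # (5,)
print(jax.grad(lambda x: run_fori(x, 10))(1.0))  # 0.9**10, about 0.3487
print(dynamic_bounds_grad_fails())               # True
```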

##### Checkpointing
The idea: when computing a reverse-mode gradient, we do not store the intermediate results of the forward pass, but recompute them during the backward pass. This costs extra time, since the same computations are done twice, but it saves memory and can remove a memory bottleneck.

Drawback: checkpointing has to be applied to sub-functions of our main function, in order to select the steps whose intermediate results should not be stored.

Checkpointing must be applied to the "inner" functions of the composition, i.e. the ones that are evaluated first, which come first in the order the code is written (in `f(g(x))`, `g` is the inner function).
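A minimal illustration, assuming an arbitrary composition `f(g(x))` where `g` is the inner function: wrapping `g` in `jax.checkpoint` does not change the gradient, only what is stored between the forward and the backward pass. The functions `f` and `g` are toy examples.

```python
import jax
import jax.numpy as jnp


def g(x):
    # inner function: evaluated first, its intermediates are the ones we avoid storing
    return jnp.sin(jnp.sin(x))


def f(y):
    # outer function
    return jnp.sum(jnp.cos(y) ** 2)


def loss_plain(x):
    return f(g(x))


def loss_remat(x):
    # Only the input of g is saved for the backward pass; the values computed
    # inside g are recomputed when the gradient is evaluated
    return f(jax.checkpoint(g)(x))


x = jnp.linspace(0.0, 1.0, 8)
print(jnp.allclose(jax.grad(loss_plain)(x), jax.grad(loss_remat)(x)))  # True
```

To see what each variant actually stores, `jax.ad_checkpoint.print_saved_residuals(loss_plain, x)` lists the saved residuals.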

**Policies**: normally, adding checkpoints means modifying our code, but *policies* let us specify what is saved without modifying the code of the function. We can also attach *names* to some intermediate values, acting as flags, so that a policy can target them or not.
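A sketch of policies and names on an arbitrary toy function `layer`: `checkpoint_name` tags intermediate values, and `jax.checkpoint_policies.save_only_these_names` tells `jax.checkpoint` to save only the value tagged `"hidden"` and recompute everything else. The body of `layer` only changes by the tags.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def layer(w, x):
    # Tag two intermediate values with names that a policy can refer to
    h = checkpoint_name(jnp.tanh(w @ x), "hidden")
    z = checkpoint_name(jnp.sin(h) * h, "activation")
    return jnp.sum(z ** 2)


# Save only the value tagged "hidden"; the rest of `layer` is recomputed
# during the backward pass
policy = jax.checkpoint_policies.save_only_these_names("hidden")
layer_remat = jax.checkpoint(layer, policy=policy)

w = jnp.eye(3) * 0.5
x = jnp.arange(3.0)
g_plain = jax.grad(layer)(w, x)
g_remat = jax.grad(layer_remat)(w, x)
print(jnp.allclose(g_plain, g_remat))  # True
```

Other built-in policies include `jax.checkpoint_policies.nothing_saveable`, `everything_saveable` and `dots_saveable`.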

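**Offload**: it is possible to move some of the saved residuals from device (GPU) memory to host (CPU) memory instead of keeping them on the device, through dedicated policies (e.g. `jax.checkpoint_policies.save_and_offload_only_these_names`).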

**Recursive checkpointing**: the idea is to get $\mathcal{O}(\log_2 D)$ memory instead of $\mathcal{O}(D)$ for recursive processes of depth $D$ (e.g. $D$ successive time steps). Roughly, it applies a bisection to the steps of the recursive process, at the price of extra recomputation.
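A sketch of recursive checkpointing for a process of `n` identical steps, in the spirit of the recursive example of the JAX gradient-checkpointing documentation. The toy `step` and the helpers `compose_recursive` and `compose_plain` are illustrative assumptions: the first half of the steps is wrapped in `jax.checkpoint`, and the same split is applied recursively to each half.

```python
import jax
import jax.numpy as jnp


def step(x):
    # toy "time step" of a recursive process
    return x + 0.1 * jnp.sin(x)


def compose_recursive(step_fn, n):
    """Apply step_fn n times, checkpointing by bisection.

    Only the input of each checkpointed half is stored; its internals are
    recomputed in the backward pass. Memory then grows roughly like log2(n)
    instead of n, at the cost of about n*log2(n) step evaluations.
    """
    if n == 1:
        return step_fn
    half = n // 2
    first = compose_recursive(step_fn, half)
    second = compose_recursive(step_fn, n - half)
    return lambda x: second(jax.checkpoint(first)(x))


def compose_plain(step_fn, n):
    # reference: plain unrolled loop, every intermediate is stored
    def run(x):
        for _ in range(n):
            x = step_fn(x)
        return x
    return run


n = 16
x0 = jnp.array([0.3, 1.2])
g_plain = jax.grad(lambda x: jnp.sum(compose_plain(step, n)(x)))(x0)
g_rec = jax.grad(lambda x: jnp.sum(compose_recursive(step, n)(x)))(x0)
print(jnp.allclose(g_plain, g_rec))  # True
```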

##### Checkpointing, `jit` and `scan`
In general, checkpointing is unnecessary when differentiating a function under `jit`, because XLA already optimizes the whole computation, including the decisions about which values to store and which to recompute.

One exception is `lax.scan`: the compiler does not optimize as thoroughly across the forward-pass scan and the corresponding backward-pass scan, so in this case it is worth applying checkpoints to the body function passed to `lax.scan`.
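A sketch of the `scan` case, assuming a toy `body` standing in for one model time step: the gradient is the same with and without `jax.checkpoint`, but with it the backward pass recomputes the internals of `body` instead of storing them at every step. `prevent_cse=False` follows the `jax.checkpoint` docstring, which notes that CSE prevention is unnecessary inside a `scan`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(t, _):
    # toy stand-in for an expensive model time step
    t_new = t + 0.01 * jnp.tanh(jnp.roll(t, 1) - 2.0 * t + jnp.roll(t, -1))
    return t_new, None


def final_loss(t0, n_steps, remat):
    # With remat=True, only the carry is kept at each step; the body's
    # intermediates are recomputed during the backward pass
    step = jax.checkpoint(body, prevent_cse=False) if remat else body
    t_final, _ = lax.scan(step, t0, None, length=n_steps)
    return jnp.sum(t_final ** 2)


grad_fn = jax.jit(jax.grad(final_loss), static_argnums=(1, 2))
t0 = jnp.linspace(0.0, 1.0, 32)
grad_plain = grad_fn(t0, 200, False)
grad_remat = grad_fn(t0, 200, True)
print(jnp.allclose(grad_plain, grad_remat, atol=1e-6))  # True
```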

### Implementation of preliminary functions

In [3]:
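def tridiag_solve_fori_loop(a: jnp.ndarray, b: jnp.ndarray, c: jnp.ndarray, f: jnp.ndarray) -> jnp.ndarray:
    """Solve the tridiagonal system with sub-diagonal a, diagonal b, super-diagonal c and
    right-hand side f (Thomas algorithm), both sweeps written with lax.fori_loop."""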
    n, = a.shape
    cff = 1.0 / b[0]
    f = f.at[0].multiply(cff)
    q = jnp.zeros(n)
    q = q.at[0].set(-c[0] * cff)

    def tri_forward_fori(k: int, x: jnp.ndarray):
        f = x[0, :]
        q = x[1, :]
        cff = 1.0 / (b[k] + a[k] * q[k-1])
        q = q.at[k].set(-cff * c[k])
        f = f.at[k].set(cff * (f[k] - a[k] * f[k-1]))
        return jnp.stack([f, q])
    f_q = jnp.stack([f, q])
    f_q = lax.fori_loop(1, n, tri_forward_fori, f_q)
    f = f_q[0, :]
    q = f_q[1, :]

    def tri_reverse_fori(k: int, x: jnp.ndarray):
        return x.at[n-1-k].add(q[n-1-k] * x[n-k])
    x = lax.fori_loop(1, n, tri_reverse_fori, f)

    return x

def tridiag_solve_scan(a: jnp.ndarray, b: jnp.ndarray, c: jnp.ndarray, f: jnp.ndarray) -> jnp.ndarray:
    """Same solver, both sweeps written with lax.scan over the sequences of coefficients."""
    def tri_forward_scan(carry: Tuple[float, float], x: jnp.ndarray):
        f_im1, q_im1 = carry
        a, b, c, f = x
        cff = 1./(b+a*q_im1)
        f_i = cff*(f-a*f_im1)
        q_i = -cff*c
        carry = f_i, q_i
        return carry, carry
    init = f[0]/b[0], -c[0]/b[0]
    xs = jnp.stack([a, b, c, f])[:, 1:].T
    _, (f, q) = lax.scan(tri_forward_scan, init, xs)
    f = jnp.concat([jnp.array([init[0]]), f])
    q = jnp.concat([jnp.array([init[1]]), q])

    def tri_reverse_scan(carry: float, x: jnp.ndarray):
        q_rev, f_rev = x
        carry = f_rev + q_rev*carry
        return carry, carry
    init = f[-1]
    xs = jnp.stack([q[::-1], f[::-1]])[:, 1:].T
    _, x = lax.scan(tri_reverse_scan, init, xs)
    x = jnp.concat([jnp.array([init]), x])

    return x[::-1]

def tridiag_solve_scan2(a: jnp.ndarray, b: jnp.ndarray, c: jnp.ndarray, f: jnp.ndarray) -> jnp.ndarray:
    """Same solver, with lax.scan iterating over indices and carrying the full arrays f and q."""
    n, = a.shape
    cff = 1.0 / b[0]
    f = f.at[0].multiply(cff)
    q = jnp.zeros(n)
    q = q.at[0].set(-c[0] * cff)

    # Forward pass with lax.scan
    def tri_forward_scan(carry, k):
        f, q = carry
        cff = 1.0 / (b[k] + a[k] * q[k-1])
        q = q.at[k].set(-cff * c[k])
        f = f.at[k].set(cff * (f[k] - a[k] * f[k-1]))
        return (f, q), None

    (f, q), _ = lax.scan(tri_forward_scan, (f, q), jnp.arange(1, n))

    # Reverse pass with lax.scan
    def tri_reverse_scan(carry, k):
        x = carry
        x = x.at[n-1-k].add(q[n-1-k] * x[n-k])
        return x, None

    x, _ = lax.scan(tri_reverse_scan, f, jnp.arange(1, n))

    return x

def add_boundaries(vec_btm: float, vec_in: jnp.ndarray, vec_sfc: float) -> jnp.ndarray:
    return jnp.concat([jnp.array([vec_btm]), vec_in, jnp.array([vec_sfc])])

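def diffusion_solver(ak: jnp.ndarray, hz: jnp.ndarray, f: jnp.ndarray, dt: float) -> jnp.ndarray:
    """Implicit diffusion step: build the tridiagonal matrix from the diffusivities `ak`
    (nz+1 values, at the cell interfaces) and the cell thicknesses `hz` (nz values),
    then solve it for the right-hand side `f` (nz values)."""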
    a_in = -2.0 * dt * ak[1:-2] / (hz[:-2] + hz[1:-1])
    c_in = -2.0 * dt * ak[2:-1] / (hz[2:] + hz[1:-1])
    b_in = hz[1:-1] - a_in - c_in

    c_btm = -2.0 * dt * ak[1] / (hz[1] + hz[0])
    b_btm = hz[0] - c_btm

    a_sfc = -2.0 * dt * ak[-2] / (hz[-2] + hz[-1])
    b_sfc = hz[-1] - a_sfc

    a = add_boundaries(0., a_in, a_sfc)
    b = add_boundaries(b_btm, b_in, b_sfc)
    c = add_boundaries(c_btm, c_in, 0.)

    x = tridiag_solve_scan(a, b, c, f)

    return x
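The benchmark below only measures speed (and draws `b` uniformly in [0, 1), which does not make the systems diagonally dominant). A quick correctness check could compare a solver against a dense solve. This sketch re-implements the same forward and backward recurrences as `tridiag_solve_scan` in compact form; `thomas_scan` and `dense_matrix` are hypothetical helper names introduced for the check, and the test system is made diagonally dominant so that the algorithm needs no pivoting.

```python
import jax
import jax.numpy as jnp
from jax import lax


def thomas_scan(a, b, c, f):
    # a: sub-diagonal (a[0] unused), b: diagonal, c: super-diagonal (c[-1] unused)
    def fwd(carry, coeffs):
        f_prev, q_prev = carry
        a_k, b_k, c_k, f_k = coeffs
        cff = 1.0 / (b_k + a_k * q_prev)
        out = (cff * (f_k - a_k * f_prev), -cff * c_k)
        return out, out

    init = (f[0] / b[0], -c[0] / b[0])
    _, (fs, qs) = lax.scan(fwd, init, (a[1:], b[1:], c[1:], f[1:]))
    fs = jnp.concatenate([init[0][None], fs])
    qs = jnp.concatenate([init[1][None], qs])

    # Back substitution x_k = f_k + q_k x_{k+1}, scanned from k = n-2 down to 0
    def bwd(x_next, coeffs):
        q_k, f_k = coeffs
        x_k = f_k + q_k * x_next
        return x_k, x_k

    _, xs = lax.scan(bwd, fs[-1], (qs[:-1], fs[:-1]), reverse=True)
    return jnp.concatenate([xs, fs[-1:]])


def dense_matrix(a, b, c):
    # Row k holds a[k], b[k], c[k] around the diagonal
    return jnp.diag(b) + jnp.diag(a[1:], k=-1) + jnp.diag(c[:-1], k=1)


key_a, key_b, key_c, key_f = jax.random.split(jax.random.PRNGKey(0), 4)
n = 50
a = jax.random.uniform(key_a, (n,))
c = jax.random.uniform(key_c, (n,))
b = 2.0 + jax.random.uniform(key_b, (n,))  # diagonally dominant
f = jax.random.uniform(key_f, (n,))
x = thomas_scan(a, b, c, f)
x_ref = jnp.linalg.solve(dense_matrix(a, b, c), f)
print(jnp.allclose(x, x_ref, atol=1e-5))  # True
```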

##### Comparison of `lax.scan` and `lax.fori_loop` for the tridiagonal inversion

In [4]:
def gen_rand_vectors(key, n):
    key, subkey = random.split(key)
    a = random.uniform(subkey, (n,))
    key, subkey = random.split(key)
    b = random.uniform(subkey, (n,))
    key, subkey = random.split(key)
    c = random.uniform(subkey, (n,))
    key, subkey = random.split(key)
    f = random.uniform(subkey, (n,))
    return a, b, c, f, key

def benchmark(n: int, num_trials: int = 10):
    tri_for_jit = jit(tridiag_solve_fori_loop)
    tri_scan_jit = jit(tridiag_solve_scan)
    tri_scan2_jit = jit(tridiag_solve_scan2)
    key = random.PRNGKey(0)

    print(f'Comparing strategies for tridiagonal inversion for vectors of size {n} and for a mean on {num_trials} trials\n')

    # JAX dispatches work asynchronously: block_until_ready() makes every timing
    # include the computation itself, not only its dispatch.
    # The timed loops also include the generation of the random inputs.

    t = tt()
    for _ in range(num_trials):
        a, b, c, f, key = gen_rand_vectors(key, n)
        _ = tridiag_solve_fori_loop(a, b, c, f).block_until_ready()
    for_time = (tt() - t) / num_trials
    print(f'lax.fori_loop\n{for_time}s\n')

    t = tt()
    _ = tri_for_jit(a, b, c, f).block_until_ready()
    first_for_jit_time = tt() - t
    print(f'lax.fori_loop + JIT (1st compilation) \n{first_for_jit_time}s\n')

    t = tt()
    for _ in range(num_trials):
        a, b, c, f, key = gen_rand_vectors(key, n)
        _ = tri_for_jit(a, b, c, f).block_until_ready()
    for_jit_time = (tt() - t) / num_trials
    print(f'lax.fori_loop + JIT\n{for_jit_time}s\n')

    t = tt()
    for _ in range(num_trials):
        a, b, c, f, key = gen_rand_vectors(key, n)
        _ = tridiag_solve_scan(a, b, c, f).block_until_ready()
    scan_time = (tt() - t) / num_trials
    print(f'lax.scan\n{scan_time}s\n')

    t = tt()
    _ = tri_scan_jit(a, b, c, f).block_until_ready()
    first_scan_jit_time = tt() - t
    print(f'lax.scan + JIT (1st compilation) \n{first_scan_jit_time}s\n')

    t = tt()
    for _ in range(num_trials):
        a, b, c, f, key = gen_rand_vectors(key, n)
        _ = tri_scan_jit(a, b, c, f).block_until_ready()
    scan_jit_time = (tt() - t) / num_trials
    print(f'lax.scan + JIT\n{scan_jit_time}s\n')

    t = tt()
    for _ in range(num_trials):
        a, b, c, f, key = gen_rand_vectors(key, n)
        _ = tridiag_solve_scan2(a, b, c, f).block_until_ready()
    scan2_time = (tt() - t) / num_trials
    print(f'lax.scan2\n{scan2_time}s\n')

    t = tt()
    _ = tri_scan2_jit(a, b, c, f).block_until_ready()
    first_scan2_jit_time = tt() - t
    print(f'lax.scan2 + JIT (1st compilation) \n{first_scan2_jit_time}s\n')

    t = tt()
    for _ in range(num_trials):
        a, b, c, f, key = gen_rand_vectors(key, n)
        _ = tri_scan2_jit(a, b, c, f).block_until_ready()
    scan2_jit_time = (tt() - t) / num_trials
    print(f'lax.scan2 + JIT\n{scan2_jit_time}s\n')

**Test results on a MacBook Pro M3 Pro, 18 GB**

In [6]:
benchmark(1000, 100)

Comparing strategies for tridiagonal inversion for vectors of size 1000 and for a mean on 100 trials

lax.fori_loop
0.04748637914657593s

lax.fori_loop + JIT (1st compilation) 
0.04959678649902344s

lax.fori_loop + JIT
0.0007703304290771484s

lax.scan
0.03587670087814331s

lax.scan + JIT (1st compilation) 
0.03472495079040527s

lax.scan + JIT
0.00022287845611572267s

lax.scan2
0.04254220008850098s

lax.scan2 + JIT (1st compilation) 
0.03572678565979004s

lax.scan2 + JIT
0.00021483898162841797s
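Such timings depend on where the clock is stopped: JAX dispatches computations asynchronously, and the first call of a jitted function also includes tracing and compilation. A small helper that separates the two and waits for each result with `block_until_ready()` could look like this (a sketch; `time_jitted` is a hypothetical name, and the timed function is arbitrary).

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_repeat=100):
    """Return (first_call_s, mean_steady_state_s) for jax.jit(fn)."""
    jitted = jax.jit(fn)
    t0 = time.perf_counter()
    jitted(*args).block_until_ready()   # tracing + XLA compilation + one run
    first = time.perf_counter() - t0
    t0 = time.perf_counter()
    for _ in range(n_repeat):
        jitted(*args).block_until_ready()  # wait for the actual computation
    steady = (time.perf_counter() - t0) / n_repeat
    return first, steady


first, steady = time_jitted(lambda v: jnp.cumsum(v) * 2.0, jnp.ones(1000))
print(f"first call: {first:.4f} s, steady state: {steady * 1e6:.1f} us")
```

Plotting `steady` for several sizes `n` would give the timing curve as a function of N.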



TODO: plot these timings as a function of N.

### Simplified model implementation
We built this simplified model with the same architecture as Tunax. The closure computations are designed to roughly match the computational cost of the k-epsilon closure, but they have no physical meaning.

In [2]:
class Grid(eqx.Module):
    nz: int
    zr: jnp.ndarray
    hz: jnp.ndarray
    
    def __init__(self, nz: int):
        self.nz = nz
        self.zr = jnp.linspace(-100, 0, nz)
        self.hz = jnp.full(nz, 100/nz)

class State(eqx.Module):
    grid: Grid
    t: jnp.ndarray

class Trajectory(eqx.Module):
    grid: Grid
    time: jnp.ndarray
    t: jnp.ndarray

class CloState(eqx.Module):
    grid: Grid
    diff: jnp.ndarray


    def __init__(self, grid: Grid):
        self.grid = grid
        self.diff = jnp.full(grid.nz+1, 1e-5)

class CloParams(eqx.Module):
    par1: float = 1.
    par2: float = 1.
    par3: float = 1.
    par4: float = 1.
    par5: float = 1.

class Case(eqx.Module):
    forc: float = .01

def clo_step(state: State, clo_state: CloState, dt: float, clo_params: CloParams, case: Case) -> CloState:
    t = state.t
    diff = clo_state.diff

    f1 = clo_params.par1 * jnp.sin(t) + clo_params.par2 * jnp.cos(diff[:-1])
    f2 = clo_params.par3 * jnp.log1p(jnp.abs(diff[:-1]))
    f3 = clo_params.par4 * t**2
    convolution = 0.5 * (f3[:-1] + f3[1:])
    f_combined = f1 + f2 + jnp.pad(convolution, (1, 0), mode='constant')
    diff_new = diffusion_solver(diff, state.grid.hz, f_combined, dt)
    # 8 additional solves, only to mimic the cost of the k-epsilon closure
    for _ in range(8):
        diff_new = diffusion_solver(diff, state.grid.hz, diff_new, dt)
    diff_new = clo_params.par5*jnp.concat([diff_new, jnp.array([0.])])

    clo_state = eqx.tree_at(lambda t: t.diff, clo_state, diff_new)
    return clo_state

class Model(eqx.Module):
    nt: int
    dt: float
    n_out: int
    init_state: State
    case: Case

    def step(self, clo_params: CloParams, state: State, clo_state: CloState) -> Tuple[State, CloState]:
        # Explicitly extract the attributes needed below
        grid = state.grid
        hz = grid.hz
        t = state.t
        diff = clo_state.diff

        # Update `clo_state`
        clo_state = clo_step(state, clo_state, self.dt, clo_params, self.case)

        # Compute `dft` and the new temperature
        ft = jnp.zeros(state.t.shape[0] + 1)
        ft = ft.at[-1].set(self.case.forc)
        dft = hz * t + self.dt * (ft[1:] - ft[:-1])
        new_t = diffusion_solver(diff, hz, dft, self.dt)

        # Update the state with `eqx.tree_at`
        state = eqx.tree_at(lambda s: s.t, state, new_t)

        return state, clo_state

    # JIT-compiled version of `step`
    compiled_step = jit(step)
    
    def run_partial(self, clo_params: CloParams, state0: State, clo_state0: CloState, n_steps: int)->Tuple[State, CloState]:
        state = state0
        clo_state = clo_state0
        for _ in range(n_steps):
            state, clo_state = self.compiled_step(clo_params, state, clo_state)
        return state, clo_state
    
    def run(self, clo_params: CloParams)->Trajectory:
        state = self.init_state
        clo_state = CloState(state.grid)
        state_list = [state]
        # self.n_out is the number of steps between two snapshots
        n_outputs = self.nt//self.n_out
        for _ in range(n_outputs):
            state, clo_state = self.run_partial(clo_params, state, clo_state, self.n_out)
            state_list.append(state)
        out_dt = self.n_out*self.dt
        time = out_dt*jnp.arange(n_outputs+1)
        t_list = [s.t for s in state_list]
        return Trajectory(state.grid, time, jnp.vstack(t_list))
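The `run` / `run_partial` structure above unrolls every time step in Python, so reverse-mode differentiation has to keep the residuals of each call. A scan-based analogue, sketched below under the assumption of a hypothetical toy step (`toy_step`, with parameters in a dict, standing in for `Model.step` and `CloParams`), uses an outer `scan` over snapshots and an inner `scan` over the `n_out` steps between them, with the inner body checkpointed.

```python
import jax
import jax.numpy as jnp
from jax import lax


def toy_step(params, t):
    # Hypothetical stand-in for Model.step: explicit periodic diffusion plus a
    # constant forcing; it only mimics the structure of the real model
    lap = jnp.roll(t, 1) - 2.0 * t + jnp.roll(t, -1)
    return t + params["kappa"] * lap + 1e-3 * params["forc"]


def run_scan(params, t0, nt, n_out, remat=True):
    """Scan-based analogue of Model.run / Model.run_partial (a sketch).

    The inner scan advances n_out time steps; the outer scan stores one
    snapshot after each block. nt and n_out must be Python ints.
    """
    def inner(t, _):
        return toy_step(params, t), None

    # Checkpointed inner body: the backward pass recomputes the step
    # internals instead of storing them for every time step
    inner_body = jax.checkpoint(inner, prevent_cse=False) if remat else inner

    def outer(t, _):
        t, _ = lax.scan(inner_body, t, None, length=n_out)
        return t, t

    _, snapshots = lax.scan(outer, t0, None, length=nt // n_out)
    return jnp.vstack([t0[None, :], snapshots])


params = {"kappa": 0.1, "forc": 1.0}
t0 = jnp.linspace(15.0, 20.0, 100)
traj = run_scan(params, t0, nt=3000, n_out=10)
print(traj.shape)  # (301, 100)


def loss(params):
    # squared misfit of the final profile against a uniform 15-degree profile
    return jnp.sum((run_scan(params, t0, 3000, 10)[-1] - 15.0) ** 2)


grads = jax.jit(jax.grad(loss))(params)
```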

In [3]:
nz = 100
grid = Grid(nz)
t0 = jnp.linspace(15, 20, nz)
init_state = State(grid, t0)
case = Case()
nt = 3000
dt = 100
n_out = 10

model = Model(nt, dt, n_out, init_state, case)
clo_params = CloParams()

tic = tt()
traj = model.run(clo_params)
print(tt()-tic)

1.5344898700714111
0.7596253871917724


### Study of the calculation of the gradient

In [7]:
from typing import List
class Observation(eqx.Module):
    traj: Trajectory
    case: Case
    dt: float

def loss(database: List[Observation], clo_params: CloParams):
    s = 0
    for obs in database:
        traj = obs.traj
        t0 = traj.t[0, :]
        init_state = State(traj.grid, t0)
        # round() rather than int(), so that a float time such as 99.99999 gives 1 step, not 0
        nt = round(float(traj.time[-1])/obs.dt)
        n_out = round(float(traj.time[1]-traj.time[0])/obs.dt)
        model = Model(nt, obs.dt, n_out, init_state, obs.case)

        traj_model = model.run(clo_params)

        s += jnp.sum((traj.t[-1, :] - traj_model.t[-1, :])**2)

    return s

In [8]:
nz = 100
nt = 300
dt = 100
case = Case()
grid = Grid(nz)
# nt+1 snapshots spaced by dt, from 0 to nt*dt
traj = Trajectory(grid, jnp.linspace(0, nt*dt, nt+1), jnp.full((nt+1, nz), 15.))
database = [Observation(traj, case, dt)]

def loss_wrapped(x: jnp.ndarray):
    return loss(database, CloParams(x[0], x[1], x[2], x[3], x[4]))

loss_wrapped(jnp.array([0., 0., 0., 0., 0.]))

# from jax import grad

# grad_loss = grad(loss_wrapped)


# grad_loss(jnp.array([0., 0., 0., 0., 0.]))

Array(3.5743142e-10, dtype=float32)