# Study of the computational cost of autodiff-based calibration
After the first release of Tunax, we ran into a problem: the computational cost of the calibration was far too high. The cost of a calibration experiment comes from computing the gradient of the cost function with JAX automatic differentiation. Indeed, part of this cost function is the whole forward model itself. The model, taken as a function, is composed of a large number of operations that account for all the time-integration steps. That is why the gradient is too expensive to compute: the memory needed to store all the operations it performs is too large. In this notebook we use a simplified model to study how to use JAX autodiff to compute the gradient of a cost function at a reasonable computational cost.

In [1]:
import os
import shutil
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import Figure, Axes
from typing import List, Tuple, TypeAlias
# Type aliases for a (figure, axes) pair with a flat list of axes (1D) or a
# nested list of axes (2D); presumably used to annotate plotting helpers in
# later cells of the notebook.
subplot_1D_type: TypeAlias = Tuple[Figure, List[Axes]]
subplot_2D_type: TypeAlias = Tuple[Figure, List[List[Axes]]]
# True only when a `latex` executable is found on the PATH, so that LaTeX
# text rendering is not requested on machines where it is unavailable.
latex_installed = shutil.which("latex") is not None
# Global matplotlib defaults applied to the figures created afterwards.
plt.rcParams.update({
    'text.usetex': latex_installed,
    'figure.figsize': (8, 5),
    'axes.titlesize': 18,
    'figure.titlesize': 18,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'lines.linewidth': 2,
    'lines.markersize': 6
})

### Simplified model implementation

In [28]:
import equinox as eqx
from typing import Tuple
from jax import lax
from tunax import add_boundaries, tridiag_solve

def diffusion_solver(ak: jnp.ndarray, hz: jnp.ndarray, f: jnp.ndarray, dt: float) -> jnp.ndarray:
    """Assemble the tridiagonal diffusion system and solve it for one step.

    The lower, main and upper diagonals are built from the diffusivity
    `ak` and the cell thicknesses `hz`, then handed to `tridiag_solve`
    together with the right-hand side `f`.

    Parameters
    ----------
    ak : jnp.ndarray
        Diffusivity; the slicing below only lines up with `hz` when `ak`
        has one more element than `hz`.
    hz : jnp.ndarray
        Thickness of the cells.
    f : jnp.ndarray
        Right-hand side of the system.
    dt : float
        Time step.

    Returns
    -------
    jnp.ndarray
        The solution returned by `tridiag_solve`.
    """
    # common prefactor of every off-diagonal coefficient
    scale = -2.0 * dt

    # off-diagonal and diagonal coefficients of the interior rows
    lower_in = scale * ak[1:-2] / (hz[:-2] + hz[1:-1])
    upper_in = scale * ak[2:-1] / (hz[2:] + hz[1:-1])
    diag_in = hz[1:-1] - lower_in - upper_in

    # bottom row: no lower neighbour
    upper_btm = scale * ak[1] / (hz[1] + hz[0])
    diag_btm = hz[0] - upper_btm

    # surface row: no upper neighbour
    lower_sfc = scale * ak[-2] / (hz[-2] + hz[-1])
    diag_sfc = hz[-1] - lower_sfc

    # full diagonals: boundary values joined to the interior ones by
    # add_boundaries (presumably first value, interior array, last value)
    lower = add_boundaries(0., lower_in, lower_sfc)
    diag = add_boundaries(diag_btm, diag_in, diag_sfc)
    upper = add_boundaries(upper_btm, upper_in, 0.)

    return tridiag_solve(lower, diag, upper, f)

class Grid(eqx.Module):
    """Vertical grid of `nz` cells spanning the range from -100 to 0.

    Attributes
    ----------
    nz : int
        Number of cells.
    zr : jnp.ndarray
        `nz` evenly spaced positions from -100 to 0, both ends included.
    hz : jnp.ndarray
        Thickness of each cell, the constant `100/nz`.
    """
    nz: int
    zr: jnp.ndarray
    hz: jnp.ndarray

    def __init__(self, nz: int):
        # total extent of the column
        depth = 100
        self.nz = nz
        self.hz = jnp.full(nz, depth/nz)
        # NOTE(review): the spacing of zr is depth/(nz-1) while hz is
        # depth/nz; confirm whether zr is meant to hold cell centres.
        self.zr = jnp.linspace(-depth, 0, nz)

class State(eqx.Module):
    """State of the model at one instant."""
    # grid on which the state is defined
    grid: Grid
    # prognostic variable advanced by the model at every step
    # (presumably the temperature, one value per grid cell - TODO confirm)
    t: jnp.ndarray

class Trajectory(eqx.Module):
    """Time series of the states stored during a run of the model."""
    # grid shared by all the stored states
    grid: Grid
    # time of each stored state
    time: jnp.ndarray
    # stored values of the state variable `t`, one row per stored state
    t: jnp.ndarray

class CloState(eqx.Module):
    """State of the closure: the diffusivity handed to the solver.

    Attributes
    ----------
    grid : Grid
        Grid on which the closure state is defined.
    diff : jnp.ndarray
        Diffusivity, initialised to the constant 1e-5 on `grid.nz + 1`
        points.
    """
    grid: Grid
    diff: jnp.ndarray

    def __init__(self, grid: Grid):
        # one more point than the number of cells
        n_points = grid.nz + 1
        self.diff = jnp.full(n_points, 1e-5)
        self.grid = grid

class CloParams(eqx.Module):
    """Parameters of the closure, i.e. the quantities to calibrate.

    The five parameters are placeholders: the simplified `clo_step` of
    this notebook receives them but does not use them.
    """
    par1: float = 1.
    par2: float = 1.
    par3: float = 1.
    par4: float = 1.
    par5: float = 1.

class Case(eqx.Module):
    """Physical case: the forcing applied to the model."""
    # value set at the last point of the flux array in `Model.step`
    forc: float = .01

def clo_step(state: State, clo_state: CloState, dt: float, clo_params: CloParams, case: Case) -> CloState:
    """Advance the closure by one time step (placeholder implementation).

    In this simplified model the closure does nothing: the closure state
    is returned unchanged and the other arguments are ignored. They are
    kept in the signature so that the call has the shape of a real
    closure step.

    Parameters
    ----------
    state : State
        Current state of the model (unused).
    clo_state : CloState
        Current state of the closure.
    dt : float
        Time step (unused).
    clo_params : CloParams
        Parameters of the closure (unused).
    case : Case
        Physical case (unused).

    Returns
    -------
    CloState
        The `clo_state` received as input, unchanged.
    """
    return clo_state

class Model(eqx.Module):
    """Simplified forward model: repeated diffusion steps on a column.

    Attributes
    ----------
    nt : int
        Total number of time steps requested for a run.
    dt : float
        Time step.
    n_out : int
        Number of time steps between two stored states.
    init_state : State
        State at the beginning of the run.
    case : Case
        Physical case providing the forcing.
    """
    nt: int
    dt: float
    n_out: int
    init_state: State
    case: Case

    def step(self, clo_params: CloParams, state: State, clo_state: CloState) -> Tuple[State, CloState]:
        """Advance the model by one time step.

        The closure is stepped first, then the state variable `t` is
        updated by `diffusion_solver` using the closure diffusivity.

        Returns
        -------
        Tuple[State, CloState]
            The updated state and the updated closure state.
        """
        grid = state.grid

        clo_state = clo_step(state, clo_state, self.dt, clo_params, self.case)

        # flux array on nz+1 points: zero everywhere except the last point,
        # which carries the forcing of the case
        ft = jnp.zeros(grid.nz+1)
        ft = ft.at[-1].set(self.case.forc)
        # right-hand side of the tridiagonal system
        dft = grid.hz*state.t + self.dt*(ft[1:] - ft[:-1])
        t = diffusion_solver(clo_state.diff, grid.hz, dft, self.dt)

        # State is immutable: build a copy with the new value of t
        state = eqx.tree_at(lambda tree: tree.t, state, t)

        return state, clo_state

    def run_partial(self, clo_params: CloParams, state0: State, clo_state0: CloState, n_steps: int) -> Tuple[State, CloState]:
        """Run `n_steps` consecutive time steps starting from `state0`.

        Returns
        -------
        Tuple[State, CloState]
            The state and the closure state after the last step.
        """
        state = state0
        clo_state = clo_state0
        # plain Python loop: every time step is executed one after the other
        for _ in range(n_steps):
            state, clo_state = self.step(clo_params, state, clo_state)
        return state, clo_state

    def run(self, clo_params: CloParams) -> Trajectory:
        """Run the model from `init_state` and store the trajectory.

        The state is stored at the initial time and then every
        `self.n_out` time steps. If `self.nt` is not a multiple of
        `self.n_out`, the remaining `self.nt % self.n_out` steps are not
        run.

        Returns
        -------
        Trajectory
            The stored states, with one row of `t` per stored time.
        """
        state = self.init_state
        clo_state = CloState(state.grid)
        state_list = [state]
        # number of stored states after the initial one
        n_records = self.nt//self.n_out
        for _ in range(n_records):
            state, clo_state = self.run_partial(clo_params, state, clo_state, self.n_out)
            state_list.append(state)
        out_dt = self.n_out*self.dt
        # integer range scaled by the output period: the length is always
        # n_records + 1, whereas an arange with a non-integer step can gain
        # or lose one element through rounding and then no longer match
        # the number of stored states
        time = jnp.arange(n_records + 1)*out_dt
        t_list = [s.t for s in state_list]
        return Trajectory(state.grid, time, jnp.vstack(t_list))

In [29]:
# grid of 100 cells and initial profile increasing linearly from 15 to 20
nz = 100
grid = Grid(nz)
t0 = jnp.linspace(15, 20, nz)
init_state = State(grid, t0)
# default case (forcing of .01)
case = Case()
# 1000 time steps of length 10, storing the state every 10 steps
nt = 1000
dt = 10
n_out = 10

model = Model(nt, dt, n_out, init_state, case)
# default closure parameters
clo_params = CloParams()

# forward run: the trajectory holds the initial state plus nt//n_out outputs
traj = model.run(clo_params)