## Reproducing configuration stats

This notebook reproduces the configuration statistics (computed from the DNS) that guide the choice of data-generation parameters:

* $Ro$: Rossby number
* $t_{L}$: Turnover time
* $\lambda$: Decorrelation error growth rate

We determine the size of the dataset $\mathrm{dim}(\mathbb{D})$ using these quantities.

In [None]:
import sys, os
sys.path.append(os.path.dirname(os.getcwd()))

import tqdm
import matplotlib.pyplot as plt

plt.rcParams.update({
  'mathtext.fontset': 'cm'
})

import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jnr

jax.config.update(
  'jax_enable_x64', True
)

import models.imex_solver as imex
from models.qg_annulus import (
    QgAnnulus, 
    dynamical_solver,
    cartesian_forcing,
    integral,
    reynolds
)
from utils import (
    into_m,
    from_m,
    plot_annulus
)

### Loading from a configuration snapshot

We load a pre-existing snapshot, previously generated using the `snapshot.py` script. 

In [None]:
# Configuration name (same value as the -n argument passed to snapshot.py)
cfg_name = 'i'
snapshot_path = f'../data/{cfg_name}/snapshot.h5'
eq, time, ps_m, us_m, up_m, om_m = QgAnnulus.load(snapshot_path)
print(eq)

# Quick-look plots of the loaded fields: (field, colormap, label, symmetric colour limit)
for field, cmap, label, lim in (
    (us_m, 'BrBG', r'$u_s$', 1e4),
    (up_m, 'RdBu_r', r'$u_\varphi$', 1e4),
    (om_m, 'BrBG_r', r'$\omega$', 7e5),
):
    fig = plot_annulus(eq, field, cmap=cmap, label=label, vmin=-lim, vmax=lim)

### Instantiating a new dynamical solver

To accumulate statistics, we need to run new iterations at steady state. We create a new dynamical solver with the same configuration parameters as those used in the `snapshot.py` script.

In [None]:
# Cartesian forcing parameters (intended to match those used by snapshot.py)
dx_f = 0.08        # target pump spacing
radius_f = 0.04    # Gaussian radius of each pump
amp_f = 2e10       # pump amplitude
cf_m = cartesian_forcing(eq, dx_f, radius_f, amp_f)

def source(
    ps_m: jnp.ndarray,
    us_m: jnp.ndarray,
    up_m: jnp.ndarray,
    om_m: jnp.ndarray
) -> jnp.ndarray:
    """
    Time-independent source term: always returns the precomputed forcing ``cf_m``.

    The state arguments are ignored; they are presumably part of the callback
    signature that ``dynamical_solver`` expects.
    """
    forcing = cf_m
    return forcing

# Discrete (fixed) time step and the JIT-compiled single-step solver
dt = 4e-8
time_scheme = imex.BPR353(dt)
solver = jax.jit(dynamical_solver(eq, time_scheme, source))

### Run a new simulation

We now run for $1/4$ of the transient time (given by the final time $T$ of the snapshot) and record the Reynolds number $Re(t)$. We derive our statistical metrics from it: $Ro = E \, \langle Re \rangle$ (with $E$ the Ekman number) and $t_L = 1 / \langle Re \rangle$.

In [None]:
# Run for a quarter of the snapshot's final time T, sampling Re(t) about `logs` times.
stats_time = time / 4
iters = int(stats_time / dt)
logs = 2500
# Clamp to at least 1: a zero stride would make `i % logs_freq` raise
# ZeroDivisionError whenever the run has fewer than `logs` iterations.
logs_freq = max(1, iters // logs)

# Reynolds number Re(t), sampled every `logs_freq` solver steps
re_t = []

pbar = tqdm.tqdm(range(iters), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
for i in pbar:
    c, ps_m, us_m, up_m, om_m = solver(ps_m, us_m, up_m, om_m)
    if i % logs_freq == 0:
        re_t.append(reynolds(eq, us_m, up_m))
        pbar.set_postfix(cfl=format(dt / c, ".2f"))

In [None]:
def get_stats(
    eq: QgAnnulus,
    re_t: np.ndarray,
    n_steps: int,
    dt_coarse: float
):
    """
    Print the configuration statistics derived from a Reynolds-number time series
    and return the turnover time.

    Parameters
    ----------
    eq : configuration object; only its ``E`` attribute is read here.
    re_t : samples of Re(t) collected at steady state.
    n_steps : number of coarse steps per sub-trajectory.
    dt_coarse : coarse (dataset) time step.

    Returns
    -------
    float
        Turnover time ``1 / mean(re_t)``.
    """
    re_mean = np.mean(re_t)
    print('Re =', re_mean, '±', np.std(re_t))
    # Rossby number from the mean Reynolds number: Ro = E * Re
    print('Ro =', eq.E * re_mean)

    t_turnover = 1 / re_mean
    print('Turnover time =', t_turnover)

    per_turnover = t_turnover / dt_coarse
    print('Samples per turnover =', per_turnover)

    # Number of whole n_steps-long sub-trajectories that fit in one turnover
    n_sub = int(round(per_turnover / n_steps))
    print('Sub-trajectories per turnover =', n_sub)
    print('Dataset size (samples) for continuous sub-trajectories =', n_sub * n_steps)
    return t_turnover

# Sub-trajectories of 25 coarse steps, with one coarse step every 5 solver steps
re_samples = np.array(re_t)
turnover_time = get_stats(eq, re_samples, n_steps=25, dt_coarse=5 * dt)

### Computing decorrelation times

To estimate the decorrelation rate $\lambda$, we measure the root mean squared error $\langle \delta \omega^2 \rangle^{1/2}$ between a reference trajectory $\omega(t)$ and a perturbed trajectory $\omega'(t)$, where $\delta \omega = \omega' - \omega$. This process is repeated for an ensemble of perturbations.

*Note*: the perturbed trajectories are initialized by adding a small perturbation to the vorticity. This perturbation is a single, randomly selected forcing pump of amplitude $10^{-10} a_{\mathcal{F}}$.

In [None]:
def random_pump(
    eq: QgAnnulus,
    dx_f: float,
    radius_f: float,
    amp_f: float,
    key: jnr.PRNGKey
):
    """
    Randomly sample one pump from the cartesian forcing described in

    Zonal jets experiments in the gas giants’ zonostrophic regime.
    D. Lemasquerier, B. Favier and M. Le Bars.
    Icarus 390 (2023).

    Pump centres lie on a square grid over [-s_o, s_o]^2 with spacing close to
    ``dx_f`` and alternating +/-1 signs (checkerboard). Only centres at least
    half a grid spacing away from both walls are kept. One of them is drawn at
    random and evaluated on the polar grid as
    ``amp_f * sign * exp(-|x - x_p|^2 / radius_f^2)``.

    Parameters
    ----------
    eq : annulus configuration; reads ``s_i``, ``s_o``, ``n_phi``, ``s_grid``
        and ``n_m``.
    dx_f : target pump spacing.
    radius_f : Gaussian radius of the pump.
    amp_f : pump amplitude.
    key : JAX PRNG key; it is split, so the caller should use the returned key
        for subsequent draws.

    Returns
    -------
    (newkey, f_m)
        A fresh PRNG key, and ``into_m`` of the pump field with row 0 set to
        zero (presumably the m = 0 / axisymmetric mode).

    Raises
    ------
    ValueError
        If no pump is eligible for sampling.
    """
    nx = int(2 * eq.s_o / dx_f + 1)
    ny = nx
    # Actual spacing, adjusted so that nx points span [-s_o, s_o] exactly
    dx = 2 * eq.s_o / (nx - 1)
    dy = dx

    # Checkerboard of +/-1 pump signs
    x_lins, y_lins = np.meshgrid(np.arange(nx), np.arange(ny), indexing='ij')
    amp_grid = (-1)**(x_lins + 1) * (-1)**(y_lins + 1)
    x_grid, y_grid = np.meshgrid(-eq.s_o + dx * np.arange(nx), -eq.s_o + dy * np.arange(ny), indexing='ij')
    iso_grid = np.sqrt(x_grid*x_grid + y_grid*y_grid)

    # Keep pumps whose centre is at least half a spacing inside both walls
    pump_position = (iso_grid >= eq.s_i + 0.5 * dx) & \
                    (iso_grid <= eq.s_o - 0.5 * dx)

    amp = amp_grid[pump_position]
    x = x_grid[pump_position]
    y = y_grid[pump_position]

    # Cartesian coordinates of the polar grid points; assumes eq.s_grid
    # broadcasts against an (n_phi, 1) column — TODO confirm it is 1-D (n_s,).
    phi = 2 * np.pi * np.arange(eq.n_phi) / eq.n_phi
    x_phi_grid = eq.s_grid * np.expand_dims(np.cos(phi), axis=1)
    y_phi_grid = eq.s_grid * np.expand_dims(np.sin(phi), axis=1)

    # Only an even number of pumps is eligible (the last one is dropped if odd)
    n_pumps = len(amp)
    if n_pumps % 2 != 0:
        n_pumps -= 1
    if n_pumps < 1:
        raise ValueError(
            'random_pump: no eligible pump inside the annulus; '
            'decrease dx_f or check eq.s_i / eq.s_o'
        )

    newkey, subkey = jnr.split(key)
    # Draw with shape (1,) and unwrap it to a plain Python integer index
    pump_i = int(jnr.randint(subkey, (1,), 0, n_pumps)[0])

    f_g = amp_f * amp[pump_i] \
        * np.exp(-(x[pump_i] - x_phi_grid)**2 / radius_f**2) \
        * np.exp(-(y[pump_i] - y_phi_grid)**2 / radius_f**2)
    # Zero row 0 of the spectral representation (presumably the m = 0 mode)
    f_m = into_m(f_g, eq.n_m).at[0].set(0)
    return (
        newkey,
        f_m
    )

# Number of turnovers (for configs (i), (ii) and (iii), 7 is enough to fully decorrelate)
turnovers = 7
iters = int(np.round(turnovers * turnover_time / dt, -1))
logs_freq = 10
# Number of logged samples = number of i in range(iters) with i % logs_freq == 0,
# i.e. ceil(iters / logs_freq). Plain floor division would overflow the
# `i // logs_freq` index whenever iters is not a multiple of logs_freq.
logs = -(-iters // logs_freq)

# Reference trajectory starts from the current (post-statistics) state
ps_m_ref = np.copy(ps_m)
us_m_ref = np.copy(us_m)
up_m_ref = np.copy(up_m)
om_m_ref = np.copy(om_m)

# Vorticity of the reference trajectory, stored every logs_freq steps
om_m_trj = np.zeros((logs, eq.n_m, eq.n_s), dtype=np.complex128)
print('Computing reference trajectory...')

pbar = tqdm.tqdm(range(iters), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
for i in pbar:
    c, ps_m_ref, us_m_ref, up_m_ref, om_m_ref = solver(ps_m_ref, us_m_ref, up_m_ref, om_m_ref)
    if i % logs_freq == 0:
        om_m_trj[i // logs_freq] = om_m_ref

        pbar.set_postfix(
            cfl=format(dt / c, ".2f"),
        )

# Number of ensemble members
ensemble = 3
ensemble_rmse = np.zeros((ensemble, logs))

# Fixed seed: the ensemble of perturbations is reproducible across runs
key = jnr.key(123)
for e in range(ensemble):
    # Perturb only the vorticity, with one random pump of amplitude 1e-10 * amp_f.
    # Every member starts from the same state as the reference trajectory.
    key, pert_m = random_pump(eq, dx_f, radius_f, amp_f=1e-10 * amp_f, key=key)
    ps_m_p, us_m_p, up_m_p = np.copy(ps_m), np.copy(us_m), np.copy(up_m)
    om_m_p = np.copy(om_m) + pert_m
    print(f'Computing perturbed trajectory for ensemble member {e}...')

    pbar = tqdm.tqdm(range(iters), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
    for i in pbar:
        c, ps_m_p, us_m_p, up_m_p, om_m_p = solver(ps_m_p, us_m_p, up_m_p, om_m_p)
        if i % logs_freq != 0:
            continue
        k = i // logs_freq
        # NOTE(review): this is sqrt(integral(eq, om_p - om_ref)). It equals
        # <δω²>^{1/2} only if `integral` integrates the squared modulus — confirm.
        ensemble_rmse[e, k] = np.sqrt(integral(eq, om_m_p - om_m_trj[k]))
        pbar.set_postfix(
            cfl=format(dt / c, ".2f"),
            rmse=format(ensemble_rmse[e, k], ".2e"),
        )

# Time of each logged sample, relative to the start of the decorrelation runs:
# sample k is recorded right after solver step i = k * logs_freq, i.e. at (i + 1) * dt,
# so consecutive samples are exactly logs_freq * dt apart.
simu_time = dt * (1 + logs_freq * np.arange(logs))
mean_rmse = np.mean(ensemble_rmse, axis=0)

fig, axs = plt.subplots(ncols=1, nrows=1, figsize=(3, 4.5), dpi=120)
# Individual ensemble members (faint) and the ensemble mean (black squares)
axs.semilogy(simu_time, ensemble_rmse.T, alpha=0.3)
axs.semilogy(simu_time, mean_rmse, color='k', marker='s', markevery=200, markersize=6, alpha=0.6)

axs.set_xlabel(r'$t$', fontsize=15)
# Secondary top axis expressing time in units of the turnover time
tax = axs.secondary_xaxis('top', functions=(lambda t: t / turnover_time, lambda t_l: t_l * turnover_time))
tax.set_xlabel(r'$t_L$', fontsize=15)
axs.set_ylabel(r'$\langle \delta \omega^2 \rangle^{1/2}$', fontsize=15)
axs.tick_params(reset=True, axis='y', which='both', direction='in')
fig.tight_layout()
plt.show()