In [None]:
import os
from typing import List

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Float, Array

from tunax import Obs, Database, FittableParameter, FittableParametersSet, Trajectory, Fitter

# Build the database of observations used for the calibration.
# NOTE(review): `timeframe_list`, `forcing_list`, `var_names`, `dims`,
# `forcing_passive_tracer` and `forcing_buoyancy_sunny` are not defined in
# this cell; they are expected to come from earlier cells.


def _make_pt_forcing(omega_p, lambda_c, zc, l_z):
    """Return ``z -> forcing_passive_tracer(z, omega_p, lambda_c, -zc, l_z)``.

    The parameters are bound when the closure is created. A bare ``lambda``
    written directly in the loop below would look ``omega_p``, ``lambda_c``,
    ``zc`` and ``l_z`` up only when it is *called*. Every observation would
    then use the values of the last loop iteration.
    """
    return lambda z: forcing_passive_tracer(z, omega_p, lambda_c, -zc, l_z)


def _make_b_sunny_forcing(jb, eps1, lambda1, lambda2):
    """Return ``z -> forcing_buoyancy_sunny(z, jb, eps1, lambda1, lambda2)``.

    The parameters are bound at creation time, for the same late-binding
    reason as in ``_make_pt_forcing``.
    """
    return lambda z: forcing_buoyancy_sunny(z, jb, eps1, lambda1, lambda2)


observations = []
weights = []
variables = []
space_step = 4
for timeframe in timeframe_list:
    for forcing in forcing_list:
        filename = os.path.join(
            'k-epsilon_Wagner_LES', 'Wagner_LES', f'{timeframe}_hour_suite',
            f'{space_step}m', f'{forcing}_with_tracer_instantaneous_statistics.jld2')
        obs = Obs.from_jld2(filename, var_names, None, dims, eos_tracers='b', do_pt=True)

        # case corrections: surface stress and surface buoyancy flux are
        # taken from the file metadata with their sign flipped, and gravity
        # is overwritten with a fixed value
        obs: Obs = eqx.tree_at(lambda t: t.case.ustr_sfc, obs, -obs.metadatas['u_str'])
        obs: Obs = eqx.tree_at(lambda t: t.case.b_forcing, obs, (0., -obs.metadatas['b_str']))
        obs: Obs = eqx.tree_at(lambda t: t.case.grav, obs, 9.80665)

        # passive tracer forcing, parameterized from the file metadata
        omega_p = 1/obs.metadatas['pt_timescale']
        lambda_c = obs.metadatas['pt_width']
        zc = obs.metadatas['pt_depth']
        l_z = obs.trajectory.grid.hbot
        # bind the current iteration's values now (see _make_pt_forcing)
        wrapped_forcing_pt = _make_pt_forcing(omega_p, lambda_c, zc, l_z)
        obs: Obs = eqx.tree_at(lambda t: t.case.pt_forcing, obs, wrapped_forcing_pt)
        # is_leaf treats None as a leaf so that a field currently holding
        # None can be replaced by eqx.tree_at
        obs: Obs = eqx.tree_at(lambda t: t.case.pt_forcing_type, obs, 'constant', is_leaf=lambda x: x is None)

        # sunny forcing: the buoyancy forcing is replaced by a depth profile
        if forcing == 'strong_wind_and_sunny':
            jb = obs.metadatas['sunny_flux']
            eps1 = 0.6
            lambda1 = 1.
            lambda2 = 16.
            # NOTE(review): b_str is read but never used; the (0., -b_str)
            # tuple set above is replaced by the sunny profile below.
            # Confirm that dropping that surface flux is intended.
            b_str = obs.case.b_forcing[1]
            # bind the current iteration's values now (see _make_b_sunny_forcing)
            wrapped_forcing_b_sunny = _make_b_sunny_forcing(jb, eps1, lambda1, lambda2)
            obs: Obs = eqx.tree_at(lambda t: t.case.b_forcing, obs, wrapped_forcing_b_sunny)
            obs: Obs = eqx.tree_at(lambda t: t.case.b_forcing_type, obs, 'constant', is_leaf=lambda x: x is None)

        # weight of the observation and variables compared in the loss;
        # weight * number of variables == 2 in every branch, so each
        # observation carries the same total weight
        if forcing == 'free_convection':
            weights.append(1.)
            variables.append(['b', 'pt'])
        elif forcing in ['strong_wind_no_rotation', 'strong_wind_and_sunny']:
            weights.append(2/3)
            variables.append(['b', 'pt', 'u'])
        else:
            weights.append(1/2)
            variables.append(['b', 'pt', 'u', 'v'])
        observations.append(obs)
metadatas = {'weights': weights, 'variables': variables}
database = Database(observations, metadatas)

# Array normalization helper (standard score).
def norm_array(x: Float[Array, 'nz']):
    """Return ``x`` shifted to zero mean and scaled to unit standard deviation.

    ``x.std()`` is JAX's default population standard deviation (``ddof=0``).
    A constant ``x`` has zero standard deviation, so the result is then
    NaN.
    """
    centered = x - x.mean()
    spread = x.std()
    return centered / spread

# Indices of the vertical window compared by the loss (for the given space
# step). The grid of the first observation is used for all of them, which
# assumes every observation shares the same grid.
grid = database.observations[0].trajectory.grid
# Window bounds in z. searchsorted assumes grid.zr is sorted ascending.
# NOTE(review): hbot/3 is only a meaningful lower bound if it lies inside
# the range of grid.zr; if hbot is stored as a positive depth, i_bot would
# be len(zr) and the window empty. Confirm the sign convention of Grid.hbot.
z_window_bot = grid.hbot/3
z_window_top = -4
# first level strictly above z_window_bot
i_bot = jnp.searchsorted(grid.zr, z_window_bot, side='right')
# first level at or above z_window_top (exclusive end of the slice)
i_top = jnp.searchsorted(grid.zr, z_window_top, side='left')

# Cost function
def loss(scm_set: List[Trajectory], database: Database):
    """Return the weighted shape misfit between model runs and observations.

    ``scm_set[i]`` is compared with ``database.observations[i].trajectory``
    for every observation.

    For each variable name in ``database.metadatas['variables'][i]``:

    * both profiles are taken at the last index along axis 0 (presumably
      the final time step) and restricted to levels ``i_bot:i_top``;
    * each profile is standardized with ``norm_array``;
    * the sum of squared differences is multiplied by
      ``database.metadatas['weights'][i]``.

    All such terms are summed. Because each profile is standardized
    separately, only differences in shape are penalized, not differences
    in mean or amplitude.

    Relies on the module-level ``i_bot``, ``i_top`` and ``norm_array``.

    Parameters
    ----------
    scm_set : list of Trajectory
        Model trajectories; ``scm_set[i]`` is paired with
        ``database.observations[i]``.
    database : Database
        Observations, with ``'weights'`` and ``'variables'`` lists in
        ``metadatas``.

    Returns
    -------
    The scalar cost (``0.`` if there is nothing to compare).
    """
    metadatas = database.metadatas
    cost = 0.
    # Pair every model run with its observation and metadata. The previous
    # loop was a leftover `range(1)` that silently ignored every observation
    # after the first one. zip stops at the shortest sequence.
    pairs = zip(scm_set, database.observations,
                metadatas['variables'], metadatas['weights'])
    for traj_scm, obs, obs_variables, weight in pairs:
        traj_obs = obs.trajectory
        for var in obs_variables:
            end_scm = getattr(traj_scm, var)[-1, i_bot:i_top]
            end_obs = getattr(traj_obs, var)[-1, i_bot:i_top]
            var_cost = jnp.sum((norm_array(end_scm) - norm_array(end_obs))**2)
            cost += weight*var_cost
    return cost

# Closure coefficients of the 'k-epsilon' model to calibrate, with their
# initial values. The leading True presumably marks each parameter as
# fitted — confirm against FittableParameter.
# NOTE(review): these initial values look like the usual k-epsilon default
# constants; confirm against the tunax k-epsilon defaults.
c_eps1_par = FittableParameter(True, val=1.44)
c_eps2_par = FittableParameter(True, val=1.92)
c_eps3m_par = FittableParameter(True, val=.4)
c_eps3p_par = FittableParameter(True, val=1.)
sig_k_par = FittableParameter(True, val=1.)
sig_eps_par = FittableParameter(True, val=1.3)

# Map each coefficient name to its parameter (same keys and insertion order).
coef_dico = dict(
    c_eps1=c_eps1_par,
    c_eps2=c_eps2_par,
    c_eps3m=c_eps3m_par,
    c_eps3p=c_eps3p_par,
    sig_k=sig_k_par,
    sig_eps=sig_eps_par,
)
coef_fit_params = FittableParametersSet(coef_dico, 'k-epsilon')