In [None]:
# Automatically reload edited Python modules before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import seaborn as sns

sns.set_theme(context="notebook", style="dark")

In [None]:
import torch

from qgsw.configs.core import Configuration
from qgsw.forcing.wind import WindForcing
from qgsw.logging.core import getLogger, setup_root_logger
from qgsw.spatial.core.discretization import SpaceDiscretization3D
from qgsw.specs import defaults

# Global run setup shared by every cell below: default specs (device, ...),
# logging, and the experiment configuration read from disk.
specs = defaults.get()

setup_root_logger(2)  # NOTE(review): argument presumably a verbosity level — confirm
logger = getLogger(__name__)


config = Configuration.from_toml("../output/g5k/param_optim/_config.toml")

# Layer depths H and reduced gravities g' from the model configuration.
# The reduced models defined later only use the first two entries.
H = config.model.h
g_prime = config.model.g_prime
H1, H2 = H[0], H[1]
g1, g2 = g_prime[0], g_prime[1]
beta_plane = config.physics.beta_plane
bottom_drag_coef = config.physics.bottom_drag_coefficient
slip_coef = config.physics.slip_coef

wind = WindForcing.from_config(
    config.windstress,
    config.space,
    config.physics,
)
tx, ty = wind.compute()  # presumably x / y wind-stress components — confirm
# Halo width (grid points) added around a sub-domain when extracting psi and
# its boundary values (see extract_psi_w_ / extract_psi_bc).
p = 4

space = SpaceDiscretization3D.from_config(
    config.space,
    config.model,
)
# NOTE(review): space_2d is not used by the visible cells (they call
# space.remove_z_h() again).
space_2d = space.remove_z_h()
dx,dy = space.dx,space.dy

In [None]:
from qgsw.solver.boundary_conditions.base import Boundaries


def get_psi_slices(imin: int, imax: int, jmin: int, jmax: int) -> tuple[slice, slice]:
    """Return the slices selecting the inclusive index ranges [imin, imax] and [jmin, jmax].

    Args:
        imin, imax: first and last index (inclusive) along the first sliced axis.
        jmin, jmax: first and last index (inclusive) along the second sliced axis.

    Returns:
        A 2-tuple ``(slice(imin, imax + 1), slice(jmin, jmax + 1))``, suitable for
        ``arr[..., *slices]`` or tuple-unpacking.
    """
    return slice(imin, imax + 1), slice(jmin, jmax + 1)

def extract_psi_w_(psi: torch.Tensor, imin: int, imax: int, jmin: int, jmax: int) -> torch.Tensor:
    """Return ``psi`` on [imin, imax] x [jmin, jmax] widened by the halo ``p`` on every side.

    Both bounds are inclusive. The two slices apply to the last two dimensions of
    ``psi``; all leading dimensions are kept.
    """
    i_window, j_window = get_psi_slices(imin - p, imax + p, jmin - p, jmax + p)
    return psi[..., i_window, j_window]


def extract_psi_bc(psi: torch.Tensor) -> Boundaries:
    """Extract the boundary values of ``psi`` on the rectangle inset by the halo ``p``.

    The boundary indices passed to ``Boundaries.extract`` are ``p`` and ``-p - 1``
    along both horizontal axes. NOTE(review): the trailing ``2`` is interpreted by
    ``Boundaries.extract``; confirm its meaning there.
    """
    inner_first, inner_last = p, -p - 1
    return Boundaries.extract(psi, inner_first, inner_last, inner_first, inner_last, 2)

In [None]:
# Time step in seconds (converted to days as dt / 24 / 3600 in the videos).
dt = 7200
# Number of model steps per cycle. The name keeps the "cyle" typo because it is
# referenced as-is further down.
n_steps_per_cyle = 250
# NOTE(review): not used by the visible cells.
comparison_interval = 1
# Number of cycles to run.
n_cycles = 3

In [None]:
# When True, the wrappers store psi1/psi2 frames at every step. These frames are
# needed for the videos and for ModelsManagerOBC.compute_loss_psi2.
save_videos=True

In [None]:
from qgsw.solver.finite_diff import grad


def rmse(f: torch.Tensor, f_ref: torch.Tensor) -> torch.Tensor:
    """Relative RMSE of ``f`` with respect to ``f_ref``.

    Returns ``sqrt(mean((f - f_ref)**2)) / sqrt(mean(f_ref**2))``: the RMS error
    normalised by the RMS of the reference, so the result is dimensionless.
    """
    error_rms = torch.sqrt(torch.mean(torch.square(f - f_ref)))
    reference_rms = torch.sqrt(torch.mean(torch.square(f_ref)))
    return error_rms / reference_rms

def compare_fields(f: torch.Tensor, f_ref: torch.Tensor) -> torch.Tensor:
    """Signed pointwise difference ``f - f_ref``.

    Returns a field with the (broadcast) shape of the inputs; positive where ``f``
    exceeds the reference.
    """
    return torch.sub(f, f_ref)

def _scaled_grad(f: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Finite-difference gradient of ``f`` divided by the grid spacings ``(dx, dy)``."""
    # Assumes grad returns (difference along x, difference along y) — confirm.
    gx, gy = grad(f)
    # Out-of-place division. In-place `/=` fails on integer tensors, and it would
    # mutate grad's outputs if they were ever views of `f`.
    return gx / dx, gy / dy


def grad_rmse(f:torch.Tensor, f_ref:torch.Tensor) -> torch.Tensor:
    """Relative RMSE of the gradient of ``f`` with respect to that of ``f_ref``.

    Both fields are differentiated with ``grad`` and scaled by ``dx`` / ``dy``.
    The result is ``sqrt(mean(|grad f - grad f_ref|^2)) / sqrt(mean(|grad f_ref|^2))``,
    with the mean of the squared norm taken as the sum of the per-component means.
    """
    gx, gy = _scaled_grad(f)
    gx_ref, gy_ref = _scaled_grad(f_ref)

    error_sq = (gx - gx_ref).square().mean() + (gy - gy_ref).square().mean()
    reference_sq = gx_ref.square().mean() + gy_ref.square().mean()
    return error_sq.sqrt() / reference_sq.sqrt()

In [None]:
from qgsw.fields.variables.tuples import UVH
from qgsw.masks import Masks
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.utils import covphys


# QG projector for the full configuration. It is used here to recover the
# start-up streamfunction from a saved (u, v, h) state.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),  # stretching matrix from H and g'
    H=H.unsqueeze(-1).unsqueeze(-1),  # two trailing singleton dims for broadcasting
    space=space,
    f0=beta_plane.f0,
    masks=Masks.empty(
        nx=config.space.nx,
        ny=config.space.ny,
    ),
)
uvh0 = UVH.from_file("../output/g5k/param_optim/_data_startup.pt")
# psi = p / f0. The state is converted with covphys.to_cov before projection.
# NOTE(review): index [0] is assumed to select the pressure from compute_p — confirm.
psi_start = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / beta_plane.f0

In [None]:
from qgsw.models.qg.psiq.core import QGPSIQ


# Reference model using every layer in H. It gets the full wind forcing and
# the configured bottom drag.
model_3l = QGPSIQ(
    space_2d=space.remove_z_h(),
    H=H,
    beta_plane=config.physics.beta_plane,
    g_prime=g_prime,
)
model_3l.set_wind_forcing(tx, ty)
model_3l.masks = Masks.empty_tensor(
    model_3l.space.nx,
    model_3l.space.ny,
    device=specs["device"],
)
model_3l.bottom_drag_coef = bottom_drag_coef
model_3l.slip_coef = slip_coef
model_3l.dt = dt
# y0 is reused by the wrappers (_set_params) and as the origin of the beta effect
# beta * (y - y0) on the sub-domain.
y0 = model_3l.y0

In [None]:
from collections.abc import Callable
from pathlib import Path
from typing import Generic, TypeVar

from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import numpy as np

from qgsw.analysis.qg_model import ModelWrapper, ModelsManager
from qgsw.models.qg.psiq.core import QGPSIQCore
from qgsw.spatial.core.discretization import SpaceDiscretization2D

T = TypeVar("T", bound=QGPSIQCore)

class ModelWrapperOBC(ModelWrapper[T], Generic[T]):
    """``ModelWrapper`` for sub-domain (open-boundary) runs, with loss tracking.

    Adds, on top of ``ModelWrapper``:

    - ``load``: reads per-sub-domain optimisation results from
      ``results_paths / f"{prefix}_{imin}_{imax}_{jmin}_{jmax}.pt"``. Subclasses
      must define ``prefix``.
    - a per-cycle loss history (``losses``), grown by ``new_cycle``/``add_loss``.
    - ``plot_loss``, which draws nothing when ``show`` is False.
    """

    results_paths = Path("../output/g5k/param_optim")

    # Loss name -> one list of recorded values per cycle. It is rebuilt per
    # instance in __init__, so wrappers never share lists.
    losses: dict[str, list[list[float]]] = {}

    # Sub-domain bounds (imin, imax, jmin, jmax); assigned by the manager.
    ijs: tuple[int, int, int, int] | None = None
    # When False, plot_loss draws nothing and returns an empty list.
    show = True
    # Per-step frames, filled by subclasses when save_videos is True.
    psi1_fields: list[torch.Tensor]
    psi2_fields: list[torch.Tensor]

    def __init__(self, space_2d: SpaceDiscretization2D) -> None:
        super().__init__(space_2d)
        self.losses = {
            "rmse": [], "grad_rmse": [], "rmse-psi2": [], "grad_rmse-psi2": []
        }

    def _set_params(self) -> None:
        """Apply the shared run parameters (y0, empty masks, slip, dt) to ``self.model``."""
        space = self.model.space
        self.model.y0 = y0
        self.model.masks = Masks.empty_tensor(
            space.nx,
            space.ny,
            device=specs["device"],
        )
        # NOTE(review): bottom drag is forced to 0 here, whereas model_3l uses
        # bottom_drag_coef. This is presumably intentional for the reduced models — confirm.
        self.model.bottom_drag_coef = 0
        self.model.wide = True
        self.model.slip_coef = slip_coef
        self.model.dt = dt

    def load(self, imin: int, imax: int, jmin: int, jmax: int) -> dict:
        """Load the optimisation results saved for the sub-domain bounds given.

        Returns:
            The object stored in the ``.pt`` file, indexed by cycle in the
            subclasses' ``setup``.
        """
        indices = f"_{imin}_{imax}_{jmin}_{jmax}.pt"
        file = self.results_paths.joinpath(self.prefix + indices)
        # NOTE(review): torch.load unpickles the file. Only load trusted results.
        return torch.load(file)

    def new_cycle(self) -> None:
        """Start a new cycle: run the base hook, then open an empty list per loss."""
        super().new_cycle()
        for loss in self.losses.values():
            loss.append([])

    def add_loss(self, loss_value: float, loss_name: str) -> None:
        """Append ``loss_value`` to the current (last) cycle of ``loss_name``.

        ``new_cycle`` must have been called at least once beforehand.
        """
        self.losses[loss_name][-1].append(loss_value)

    def plot_loss(self, *, loss_name: str, ax: plt.Axes | None = None, cycle: int | None = None) -> list[Line2D]:
        """Plot the history of ``loss_name``.

        Args:
            loss_name: key in ``losses``.
            ax: target axes; defaults to the current axes.
            cycle: plot only this cycle; by default all cycles are concatenated.

        Returns:
            The lines created by ``ax.plot``, or an empty list when ``show`` is
            False (so callers can always ``extend`` with the result).
        """
        if not self.show:
            return []
        if ax is None:
            ax = plt.gca()
        history = self.losses[loss_name]
        loss = np.concatenate(history) if cycle is None else np.array(history[cycle])
        return ax.plot(loss, **self.plot_kwargs)
        
M = TypeVar("M", bound=ModelWrapperOBC[QGPSIQCore])

class ModelsManagerOBC(ModelsManager[M], Generic[M]):
    """``ModelsManager`` for open-boundary wrappers sharing one sub-domain.

    ``loss_fn`` maps each tracked loss name to its metric, and ``losses`` lists
    those names. Second-layer losses are stored under ``<name>-psi2``.
    """

    loss_fn: dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {
        "rmse": rmse,
        "grad_rmse": grad_rmse,
    }

    losses = list(loss_fn.keys())

    def __init__(self, *mw: M) -> None:
        super().__init__(*mw)
        # Propagate the first wrapper's sub-domain bounds to every wrapper.
        self.ijs = self.model_wrappers[0].ijs

    @property
    def ijs(self) -> tuple[int, int, int, int]:
        """Sub-domain bounds (imin, imax, jmin, jmax), read from the first wrapper."""
        return self.model_wrappers[0].ijs

    @ijs.setter
    def ijs(self, ijs: tuple[int, int, int, int]) -> None:
        self.loop_over_models(lambda mw: setattr(mw, "ijs", ijs))

    @property
    def n_models(self) -> int:
        """Number of managed model wrappers."""
        return len(self.model_wrappers)

    def _record_losses(
        self,
        get_field: Callable[[M], torch.Tensor],
        ref: torch.Tensor,
        suffix: str,
    ) -> None:
        """Record every loss in ``losses`` between ``get_field(mw)`` and ``ref``.

        Each value is appended to the wrapper under the key ``loss_name + suffix``.
        """
        for loss_name in self.losses:
            metric = self.loss_fn[loss_name]
            key = loss_name + suffix
            self.loop_over_models(
                lambda mw: mw.add_loss(metric(get_field(mw), ref).cpu().item(), key)
            )

    def compute_loss(self, psi_ref: torch.Tensor) -> None:
        """Record top-layer losses of each model's current psi against ``psi_ref[0, 0]``.

        Assumes ``psi_ref`` is already on the models' grid (cropped by the caller).
        """
        self._record_losses(lambda mw: mw.model.psi[0, 0], psi_ref[0, 0], "")

    def compute_loss_psi2(self, psi_ref: torch.Tensor) -> None:
        """Record second-layer losses of each model's latest psi2 frame against ``psi_ref[0, 1]``.

        Requires ``save_videos`` to be True, since it reads ``mw.psi2_fields[-1]``.
        """
        self._record_losses(lambda mw: mw.psi2_fields[-1][0, 0], psi_ref[0, 1], "-psi2")

    def plot_loss(self, *, loss_name: str, ax: plt.Axes | None = None, cycle: int | None = None) -> list[Line2D]:
        """Plot ``loss_name`` for every wrapper and return all created lines."""
        lines: list[Line2D] = []
        for mw in self.model_wrappers:
            # A wrapper with show=False may return None. Treat it as "nothing
            # drawn" instead of crashing in extend().
            lines.extend(mw.plot_loss(loss_name=loss_name, ax=ax, cycle=cycle) or [])
        return lines

In [None]:
from matplotlib.animation import FuncAnimation
from torch._tensor import Tensor
from qgsw import plots
from qgsw.decomposition.coefficients import DecompositionCoefs
from qgsw.decomposition.core import build_basis_from_params_dict
from qgsw.models.qg.psiq.modified.core import QGPSIQMixed
from qgsw.plots.plt_wrapper import default_clim, retrieve_colorbar, retrieve_imshow_data
from qgsw.solver.finite_diff import laplacian
from qgsw.spatial.core.grid_conversion import interpolate
from qgsw.utils.interpolation import QuadraticInterpolation
from qgsw.utils.reshaping import crop
from qgsw.utils.tensor_dict import change_specs


class Mixed(ModelWrapperOBC[QGPSIQMixed]):
    """Two-layer model whose second-layer streamfunction is parametrised.

    psi2 is represented as ``alpha * psi1 + fpsi2(t)``. ``alpha`` and the
    decomposition coefficients of ``fpsi2`` are read from the current cycle's
    optimisation results in ``setup``.
    """

    prefix = "results_mixed_s"
    color = "violet"
    label = "Mixed Single Alpha"
    save_video = False

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        """Build the mixed model on ``space_2d`` from the first two layers."""
        self.model = QGPSIQMixed(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        self.alphas = {}
        self.coefs = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor, alpha: torch.Tensor, psi2: torch.Tensor) -> Tensor:
        """Top-layer PV with psi2 given by ``psi2 + alpha * psi`` in the coupling term.

        The result is the Laplacian minus ``f0^2 (1/(H1 g1) + 1/(H1 g2)) psi`` plus
        ``f0^2 / (H1 g2) (psi2 + alpha psi)``, interpolated, plus ``beta_effect``.
        """
        f0_sq = beta_plane.f0**2
        psi_inner = psi[..., 1:-1, 1:-1]
        psi2_inner = psi2[..., 1:-1, 1:-1]
        return interpolate(
            laplacian(psi, dx, dy)
            - f0_sq * (1 / H1 / g1 + 1 / H1 / g2) * psi_inner
            + f0_sq * (1 / H1 / g2) * (psi2_inner + alpha * psi_inner)
        ) + beta_effect

    def _fpsi2_field(self, time) -> Tensor:
        """Localized psi2 decomposition evaluated at ``time``, with two leading singleton dims."""
        return self._fpsi2(time)[None, None, ...]

    def _reconstructed_psi2(self) -> Tensor:
        """``alpha * psi1 + fpsi2(t)`` at the model's current time, on the model grid."""
        return self.model.alpha * self.model.psi[:, :1] + crop(self._fpsi2_field(self.model.time), p)

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's alpha / psi2 basis and initialise psi, q and boundary maps."""
        cycle_results = self.load(*self.ijs)[self.cycle]
        imin, imax, jmin, jmax = self.ijs

        widened = space.remove_z_h().slice(imin - p, imax + p + 1, jmin - p, jmax + p + 1)

        alpha: torch.Tensor = cycle_results["alpha"]
        coefs = DecompositionCoefs.from_dict(change_specs(cycle_results["coefs"], **specs))
        basis = build_basis_from_params_dict(cycle_results["config"]["basis"])
        basis.set_coefs(coefs)
        self._fpsi2 = basis.localize(widened.psi.xy.x, widened.psi.xy.y)

        if self.save_params:
            self.alphas[self.cycle] = alpha
            self.coefs[self.cycle] = coefs

        psi0_top = psis[0][:, :1]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        # The n-th reference frame is taken to be n steps of dt after the current time.
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi[:, :1], beta_effect_w, alpha, self._fpsi2_field(self.model.time + n * dt)),
                2, -3, 2, -3, 3,
            )
            for n, psi in enumerate(psis)
        ]

        q0 = self.compute_q(psi0_top, beta_effect_w, alpha, self._fpsi2_field(self.model.time))
        self.model.set_psiq(crop(psi0_top, p), crop(q0, p - 1))
        self.model.alpha = torch.ones_like(self.model.psi) * alpha
        self.model.basis = basis
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )

        if save_videos:
            self.psi2_fields = [self._reconstructed_psi2()]
            self.psi1_fields = [self.model.psi[:, :1]]

    def step(self) -> None:
        """Advance the model one step and record psi1/psi2 frames when enabled."""
        super().step()
        if not save_videos:
            return
        self.psi2_fields.append(self._reconstructed_psi2())
        self.psi1_fields.append(self.model.psi[:, :1])

In [None]:
from matplotlib.animation import FuncAnimation
from torch._tensor import Tensor
from qgsw import plots
from qgsw.decomposition.coefficients import DecompositionCoefs
from qgsw.decomposition.core import build_basis_from_params_dict
from qgsw.models.qg.psiq.modified.core import QGPSIQMixed
from qgsw.models.qg.psiq.modified.forced import QGPSIQPsi2TransportDR
from qgsw.plots.plt_wrapper import default_clim, retrieve_colorbar, retrieve_imshow_data
from qgsw.solver.finite_diff import laplacian
from qgsw.spatial.core.grid_conversion import interpolate
from qgsw.utils.interpolation import QuadraticInterpolation
from qgsw.utils.reshaping import crop
from qgsw.utils.tensor_dict import change_specs


class MixedDR(ModelWrapperOBC[QGPSIQPsi2TransportDR]):
    """Mixed model with an additional ``kappa`` factor on the layer-2 coupling.

    The effective psi2 is ``(1 - kappa) * (alpha * psi1 + fpsi2(t))``. ``alpha``,
    ``kappa`` and the decomposition coefficients of ``fpsi2`` are read from the
    current cycle's optimisation results.
    """

    prefix = "results_mixed_ro_ge_g5_dr"
    color = "palevioletred"
    linestyle = "--"
    label = "Mixed - GE - G5 - DR"
    save_video = False

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        """Build the DR model on ``space_2d`` from the first two layers."""
        self.model = QGPSIQPsi2TransportDR(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        self.alphas = {}
        self.coefs = {}
        self.kappas = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor, alpha: torch.Tensor, kappa: torch.Tensor, psi2: torch.Tensor) -> Tensor:
        """Top-layer PV with the layer-2 coupling weighted by ``1 - kappa``.

        The result is the Laplacian minus ``f0^2 (1/(H1 g1) + (1-kappa)/(H1 g2)) psi``
        plus ``f0^2 (1-kappa)/(H1 g2) (psi2 + alpha psi)``, interpolated, plus
        ``beta_effect``.
        """
        f0_sq = beta_plane.f0**2
        psi_inner = psi[..., 1:-1, 1:-1]
        psi2_inner = psi2[..., 1:-1, 1:-1]
        return interpolate(
            laplacian(psi, dx, dy)
            - f0_sq * (1 / H1 / g1 + (1 - kappa) / H1 / g2) * psi_inner
            + f0_sq * ((1 - kappa) / H1 / g2) * (psi2_inner + alpha * psi_inner)
        ) + beta_effect

    def _fpsi2_field(self, time) -> Tensor:
        """Localized psi2 decomposition evaluated at ``time``, with two leading singleton dims."""
        return self._fpsi2(time)[None, None, ...]

    def _reconstructed_psi2(self) -> Tensor:
        """``(1 - kappa) * (alpha * psi1 + fpsi2(t))`` at the model's current time."""
        parametrised = self.model.alpha * self.model.psi[:, :1] + crop(self._fpsi2_field(self.model.time), p)
        return (1 - self.model.kappa) * parametrised

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's alpha / kappa / psi2 basis and initialise psi, q and boundary maps."""
        cycle_results = self.load(*self.ijs)[self.cycle]
        imin, imax, jmin, jmax = self.ijs

        widened = space.remove_z_h().slice(imin - p, imax + p + 1, jmin - p, jmax + p + 1)

        alpha: torch.Tensor = cycle_results["alpha"]
        kappa: torch.Tensor = cycle_results["kappa"]
        coefs = DecompositionCoefs.from_dict(change_specs(cycle_results["coefs"], **specs))
        basis = build_basis_from_params_dict(cycle_results["config"]["basis"])
        basis.set_coefs(coefs)
        self._fpsi2 = basis.localize(widened.psi.xy.x, widened.psi.xy.y)

        if self.save_params:
            self.alphas[self.cycle] = alpha
            self.coefs[self.cycle] = coefs
            self.kappas[self.cycle] = kappa

        psi0_top = psis[0][:, :1]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        # The n-th reference frame is taken to be n steps of dt after the current time.
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi[:, :1], beta_effect_w, alpha, kappa, self._fpsi2_field(self.model.time + n * dt)),
                2, -3, 2, -3, 3,
            )
            for n, psi in enumerate(psis)
        ]

        q0 = self.compute_q(psi0_top, beta_effect_w, alpha, kappa, self._fpsi2_field(self.model.time))
        self.model.set_psiq(crop(psi0_top, p), crop(q0, p - 1))
        self.model.alpha = torch.ones_like(self.model.psi) * alpha
        self.model.kappa = kappa
        self.model.basis = basis
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )

        if save_videos:
            self.psi2_fields = [self._reconstructed_psi2()]
            self.psi1_fields = [self.model.psi[:, :1]]

    def step(self) -> None:
        """Advance the model one step and record psi1/psi2 frames when enabled."""
        super().step()
        if not save_videos:
            return
        self.psi2_fields.append(self._reconstructed_psi2())
        self.psi1_fields.append(self.model.psi[:, :1])

In [None]:
from qgsw.decomposition.wavelets import WaveletBasis
from qgsw.models.qg.psiq.modified.forced import QGPSIQForced


class ForcedRG(ModelWrapperOBC[QGPSIQForced]):
    """Single-layer top-layer model driven by a learned forcing field.

    The reduced gravity is ``g1 * g2 / (g1 + g2)``. The forcing is the
    decomposition stored in the cycle's results, evaluated on the model's q grid
    before every step. No second layer is modelled, so the recorded psi2 frames
    are zeros.
    """

    prefix = "results_forced_rg"
    color = "brown"
    label = "Forced"

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        """Build the single-layer forced model on ``space_2d``."""
        g_top, g_bottom = g_prime[:1], g_prime[1:2]
        self.model = QGPSIQForced(
            space_2d=space_2d,
            H=H[:1],
            beta_plane=beta_plane,
            g_prime=g_top * g_bottom / (g_top + g_bottom),
        )
        self._set_params()
        self.coefs = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """Top-layer PV: Laplacian minus ``f0^2 (1/(H1 g1) + 1/(H1 g2)) psi``, interpolated, plus beta effect."""
        stretching = beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2)
        return interpolate(laplacian(psi, dx, dy) - stretching * psi[..., 1:-1, 1:-1]) + beta_effect

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's forcing basis and initialise psi, q and boundary maps."""
        cycle_results = self.load(*self.ijs)[self.cycle]
        coefs = DecompositionCoefs.from_dict(change_specs(cycle_results["coefs"], **specs))
        self.basis = build_basis_from_params_dict(cycle_results["config"]["basis"])
        self.basis.set_coefs(coefs)
        if self.save_params:
            self.coefs[self.cycle] = coefs

        q_xy = self.model.space.remove_z_h().q.xy
        self.wv = self.basis.localize(q_xy.x, q_xy.y)

        psi0_top = psis[0][:, :1]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi[:, :1], beta_effect_w), 2, -3, 2, -3, 3)
            for psi in psis
        ]

        self.model.set_psiq(crop(psi0_top, p), crop(self.compute_q(psi0_top, beta_effect_w), p - 1))
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )
        if save_videos:
            top = self.model.psi[:, :1]
            self.psi1_fields = [top]
            self.psi2_fields = [torch.zeros_like(top)]

    def step(self) -> None:
        """Refresh the forcing at the current time, step, and record frames when enabled."""
        self.model.forcing = self.wv(self.model.time)
        super().step()
        if not save_videos:
            return
        top = self.model.psi[:, :1]
        self.psi1_fields.append(top)
        self.psi2_fields.append(torch.zeros_like(top))


In [None]:
from qgsw.decomposition.wavelets import WaveletBasis
from qgsw.models.qg.psiq.modified.forced import QGPSIQForced


class ForcedRGDR(ModelWrapperOBC[QGPSIQForced]):
    """Forced single-layer model on the equivalent depth ``H1 * H2 / (H1 + H2)`` ("DR").

    The reduced gravity is ``g2`` (the interface one), and the PV stretching term
    uses ``f0^2 / (Heq g2)``, consistent with the model's depth and gravity.
    The forcing is the decomposition stored in the cycle's results, evaluated on
    the model's q grid before every step. The recorded psi2 frames are zeros.

    NOTE: a later cell defines ``ForcedRGDR`` again and shadows this one. This
    version is kept consistent with it (same depth, gravity, prefix and label).
    """

    prefix = "results_forced_rg_dr"
    color = "brown"
    linestyle = "dashed"
    label = "Forced DR"
    save_video = False

    @staticmethod
    def _equivalent_depth() -> torch.Tensor:
        """Equivalent depth ``H1 * H2 / (H1 + H2)`` as a 1-element tensor."""
        return H[:1] * H[1:2] / (H[:1] + H[1:2])

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        """Build the equivalent-depth forced model on ``space_2d``."""
        self.model = QGPSIQForced(
            space_2d=space_2d,
            H=self._equivalent_depth(),
            beta_plane=beta_plane,
            # Interface reduced gravity. This was previously passed H[1:2], a depth.
            g_prime=g_prime[1:2],
        )
        # NOTE(review): presumably rescales the wind forcing to the true top-layer depth H1.
        self.model.wind_scaling = H[:1].item()
        self._set_params()
        self.coefs = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """PV: Laplacian minus ``f0^2 / (Heq g2) psi``, interpolated, plus beta effect."""
        h_eq = self._equivalent_depth()
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / h_eq / g2) * psi[..., 1:-1, 1:-1]
        ) + beta_effect

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's forcing basis and initialise psi, q and boundary maps."""
        res = self.load(*self.ijs)
        coefs = DecompositionCoefs.from_dict(change_specs(res[self.cycle]["coefs"], **specs))
        self.basis = build_basis_from_params_dict(res[self.cycle]["config"]["basis"])
        self.basis.set_coefs(coefs)
        if self.save_params:
            self.coefs[self.cycle] = coefs

        self.wv = self.basis.localize(
            self.model.space.remove_z_h().q.xy.x,
            self.model.space.remove_z_h().q.xy.y,
        )
        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi[:, :1], beta_effect_w), 2, -3, 2, -3, 3)
            for psi in psis
        ]

        self.model.set_psiq(crop(psi0[:, :1], p), crop(self.compute_q(psi0[:, :1], beta_effect_w), p - 1))
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))
        if save_videos:
            self.psi1_fields = [self.model.psi[:, :1]]
            self.psi2_fields = [torch.zeros_like(self.model.psi[:, :1])]

    def step(self) -> None:
        """Refresh the forcing at the current time, step, and record frames when enabled."""
        self.model.forcing = self.wv(self.model.time)
        super().step()
        if save_videos:
            self.psi1_fields.append(self.model.psi[:, :1])
            self.psi2_fields.append(torch.zeros_like(self.model.psi[:, :1]))


In [None]:
from qgsw.decomposition.wavelets import WaveletBasis
from qgsw.models.qg.psiq.modified.forced import QGPSIQForced

# Equivalent depth H1 * H2 / (H1 + H2) of the DR variant (1-element tensor).
Heq = H[:1] * H[1:2] / (H[:1] + H[1:2])

class ForcedRGDR(ModelWrapperOBC[QGPSIQForced]):
    """Forced single-layer model on the equivalent depth ``Heq`` with gravity ``g2``.

    The PV stretching term is ``f0^2 / (Heq g2)``. The forcing is the
    decomposition stored in the cycle's results, evaluated on the model's q grid
    before every step. The recorded psi2 frames are zeros.
    """

    prefix = "results_forced_rg_dr"
    color = "brown"
    linestyle = "dashed"
    label = "Forced DR"
    save_video = False

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        """Build the equivalent-depth forced model on ``space_2d``."""
        self.model = QGPSIQForced(
            space_2d=space_2d,
            H=Heq,
            beta_plane=beta_plane,
            g_prime=g_prime[1:2],
        )
        # NOTE(review): presumably rescales the wind forcing to the true top-layer depth H1.
        self.model.wind_scaling = H[:1].item()
        self._set_params()
        self.coefs = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """PV: Laplacian minus ``f0^2 / (Heq g2) psi``, interpolated, plus beta effect."""
        stretching = beta_plane.f0**2 * (1 / Heq / g2)
        return interpolate(laplacian(psi, dx, dy) - stretching * psi[..., 1:-1, 1:-1]) + beta_effect

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's forcing basis and initialise psi, q and boundary maps."""
        cycle_results = self.load(*self.ijs)[self.cycle]
        coefs = DecompositionCoefs.from_dict(change_specs(cycle_results["coefs"], **specs))
        self.basis = build_basis_from_params_dict(cycle_results["config"]["basis"])
        self.basis.set_coefs(coefs)
        if self.save_params:
            self.coefs[self.cycle] = coefs

        q_xy = self.model.space.remove_z_h().q.xy
        self.wv = self.basis.localize(q_xy.x, q_xy.y)

        psi0_top = psis[0][:, :1]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi[:, :1], beta_effect_w), 2, -3, 2, -3, 3)
            for psi in psis
        ]

        self.model.set_psiq(crop(psi0_top, p), crop(self.compute_q(psi0_top, beta_effect_w), p - 1))
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )
        if save_videos:
            top = self.model.psi[:, :1]
            self.psi1_fields = [top]
            self.psi2_fields = [torch.zeros_like(top)]

    def step(self) -> None:
        """Refresh the forcing at the current time, step, and record frames when enabled."""
        self.model.forcing = self.wv(self.model.time)
        super().step()
        if not save_videos:
            return
        top = self.model.psi[:, :1]
        self.psi1_fields.append(top)
        self.psi2_fields.append(torch.zeros_like(top))

In [None]:
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable

from qgsw.logging.utils import sec2text


def create_fig_axs(models:ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]]) -> tuple[Figure, dict[str,Axes]]:
    """Create the mosaic figure used by the comparison videos.

    The layout has one column per model plus a narrow colour-bar column, with
    rows ``ref-i`` / ``compare-i`` / ``psi-i`` (each ending with ``cbar-<row>``)
    followed by two full-width loss panels, ``rmse`` and ``grad_rmse``.

    Returns:
        The figure and the mapping from panel name to axes.
    """
    n = models.n_models

    def field_row(name: str) -> list[str]:
        return [f"{name}-{i}" for i in range(n)] + [f"cbar-{name}"]

    layout = [
        field_row("ref"),
        field_row("compare"),
        field_row("psi"),
        ["rmse"] * (n + 1),
        ["grad_rmse"] * (n + 1),
    ]
    fig, axs = plt.subplot_mosaic(
        layout,
        figsize=(min(12, 6 * n), 20),
        dpi=100,
        gridspec_kw={"width_ratios": [100] * n + [5]},
    )
    return fig, axs

def _compute_fields(
    psis: list[torch.Tensor],
    models: ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]],
    ref_layer: int,
    frames_attr: str,
) -> dict[str, list[torch.Tensor]]:
    """Build the ``ref-n`` / ``psi-n`` / ``compare-n`` frame lists for every model.

    ``ref-n`` is layer ``ref_layer`` of each reference frame, cropped by ``p``
    (the same list object is shared by all models). ``psi-n`` is the model's
    ``frames_attr`` frames, and ``compare-n`` their pointwise difference to the
    reference. ``zip`` truncates to the shorter of the two lists.
    """
    reference = [crop(psi, p)[0, ref_layer] for psi in psis]
    fields: dict[str, list[torch.Tensor]] = {}
    for n, mw in enumerate(models.model_wrappers):
        predicted = [frame[0, 0] for frame in getattr(mw, frames_attr)]
        fields[f"ref-{n}"] = reference
        fields[f"psi-{n}"] = predicted
        fields[f"compare-{n}"] = [compare_fields(a, b) for a, b in zip(predicted, reference)]
    return fields


def compute_fields_psi1(psis:list[torch.Tensor], models:ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]]) -> dict[str, list[torch.Tensor]]:
    """Top-layer frames: reference layer 0 against each model's ``psi1_fields``."""
    return _compute_fields(psis, models, ref_layer=0, frames_attr="psi1_fields")


def compute_fields_psi2(psis:list[torch.Tensor], models:ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]]) -> dict[str, list[torch.Tensor]]:
    """Second-layer frames: reference layer 1 against each model's ``psi2_fields``."""
    return _compute_fields(psis, models, ref_layer=1, frames_attr="psi2_fields")

def get_comparison_scales(models:ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]],fields:dict[str, list[torch.Tensor]]) -> dict[str, float]:
    """Max absolute difference over all frames, per ``compare-n`` entry of ``fields``."""
    return {
        f"compare-{n}": max(frame.abs().max() for frame in fields[f"compare-{n}"])
        for n in range(models.n_models)
    }

def get_clims(models:ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]],fields:dict[str, list[torch.Tensor]])-> dict[str,dict[str,float]]:
    """Symmetric colour limits for the ``ref`` and ``psi`` panels.

    Both use ``[-m, m]``, where ``m`` is the max absolute value over the
    ``ref-0`` frames. ``models`` is unused; it is kept for a uniform signature.
    """
    ref_max = max(frame.abs().max() for frame in fields["ref-0"])
    symmetric = {"vmin": -ref_max, "vmax": ref_max}
    return {"ref": symmetric, "psi": dict(symmetric)}

def rescale_comparison(fields: dict[str, list[torch.Tensor]], scales: dict[str, float]) -> None:
    """Normalise, in place, each frame list in ``fields`` named in ``scales`` by its scale.

    A zero scale means every frame of that list is identically zero (the scale
    is a max absolute value). Dividing would turn them into NaNs, so such lists
    are left untouched. Keys of ``fields`` absent from ``scales`` are not modified.
    """
    for key, scale in scales.items():
        if scale == 0:
            continue
        fields[key] = [frame / scale for frame in fields[key]]



def _make_video(
    fields: dict[str, list[torch.Tensor]],
    models: ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]],
    output_file: Path,
    cycle: int,
    loss_keys: tuple[str, str],
    y_maxes: tuple[float, float],
) -> None:
    """Render one comparison video (reference / difference / model + loss curves).

    Args:
        fields: ``ref-n`` / ``compare-n`` / ``psi-n`` frame lists. The
            ``compare-n`` lists are rescaled in place.
        models: manager whose wrappers are compared.
        output_file: destination of the animation.
        cycle: index of the cycle whose losses are plotted.
        loss_keys: keys of ``mw.losses`` shown in the RMSE and gradient-RMSE panels.
        y_maxes: upper y-limits of the RMSE and gradient-RMSE panels.
    """
    clims = get_clims(models, fields)
    scales = get_comparison_scales(models, fields)

    N = len(fields["ref-0"])
    fig, axs = create_fig_axs(models)
    # Always release the figure, even if rendering or saving fails.
    try:
        rescale_comparison(fields, scales)

        imshow_kwargs = {"cmap": "RdBu_r", "origin": "lower"}
        # Difference panels are normalised by their scale (shown in the title).
        panel_clims = {
            "ref": clims["ref"],
            "compare": {"vmin": -1, "vmax": 1},
            "psi": clims["psi"],
        }
        image_kinds = ("ref", "compare", "psi")
        artists: dict = {}

        for n, mw in enumerate(models.model_wrappers):
            titles = {
                "ref": "Ref",
                "compare": f"Diff {scales[f'compare-{n}']:1.1e}",
                "psi": mw.label,
            }
            for kind in image_kinds:
                key = f"{kind}-{n}"
                ax = axs[key]
                ax.set_title(titles[kind])
                ax.xaxis.set_ticks([])
                ax.yaxis.set_ticks([])
                artists[key] = ax.imshow(
                    retrieve_imshow_data(fields[key][0]),
                    **panel_clims[kind],
                    **imshow_kwargs,
                )

        for kind in image_kinds:
            plt.colorbar(artists[f"{kind}-0"], cax=axs[f"cbar-{kind}"])

        x_loss = np.arange(N) * dt / 24 / 3600  # frame times in days
        loss_panels = (("rmse", "RMSE"), ("grad_rmse", "Grad RMSE"))
        for (panel, title), y_max in zip(loss_panels, y_maxes):
            ax = axs[panel]
            artists[panel] = []
            for mw in models.model_wrappers:
                artists[panel] += ax.plot(x_loss, np.zeros(N), **mw.plot_kwargs)
            ax.set_xlabel("Time [days]")
            ax.set_title(title)
            ax.set_xlim(0, N * dt / 24 / 3600)
            ax.set_ylim(-y_max * 0.01, y_max)
            ax.legend(loc="upper left", prop={"size": 10})

        fig.tight_layout()

        def update(frame: int) -> list[Artist]:
            outs: list[Artist] = []
            for n, mw in enumerate(models.model_wrappers):
                for kind in image_kinds:
                    key = f"{kind}-{n}"
                    artists[key].set_array(retrieve_imshow_data(fields[key][frame]))
                    outs.append(artists[key])
                for (panel, _), loss_key in zip(loss_panels, loss_keys):
                    # Losses recorded before this frame; the rest is NaN-padded (not drawn).
                    history = mw.losses[loss_key][cycle][:frame]
                    artists[panel][n].set_data(x_loss, history + [np.nan] * (N - len(history)))
                    outs.append(artists[panel][n])
            return outs

        # Never request more frames than the shortest frame list provides.
        n_frames = min(n_steps_per_cyle - 1, *(len(frame_list) for frame_list in fields.values()))
        anim = FuncAnimation(fig, update, frames=n_frames, blit=True)
        anim.save(output_file, fps=20)
    finally:
        plt.close(fig)


def make_video_psi1(psis:list[torch.Tensor], models:ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]], output_file:Path, cycle:int)-> None:
    """Render the top-layer (psi1) comparison video of ``cycle`` to ``output_file``."""
    _make_video(
        compute_fields_psi1(psis, models),
        models,
        output_file,
        cycle,
        loss_keys=("rmse", "grad_rmse"),
        y_maxes=(0.1, 0.5),
    )


def make_video_psi2(psis:list[torch.Tensor], models:ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]], output_file:Path, cycle:int)-> None:
    """Render the second-layer (psi2) comparison video of ``cycle`` to ``output_file``."""
    _make_video(
        compute_fields_psi2(psis, models),
        models,
        output_file,
        cycle,
        loss_keys=("rmse-psi2", "grad_rmse-psi2"),
        y_maxes=(5, 5),
    )


In [None]:
# Compare the candidate OBC models against the 3-layer reference model on the
# sub-domain [imin, imax] x [jmin, jmax]. For each cycle: run the reference,
# record its windowed psi, replay the candidate models on the same cycle,
# compute their losses, render a comparison video and log peak GPU memory.
# NOTE(review): this cell is duplicated for several sub-domains; consider a
# single function parameterised by (imin, imax, jmin, jmax).
import gc
from pathlib import Path

from matplotlib.animation import FuncAnimation
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.image import AxesImage
from qgsw import plots
from qgsw.logging.utils import box, sec2text, step
from qgsw.plots.plt_wrapper import retrieve_colorbar
from qgsw.utils.reshaping import crop

imin, imax = 32, 96
jmin, jmax = 64, 192


def extract_psi_w(psi: torch.Tensor) -> torch.Tensor:
    """Extract psi on the current sub-domain extended by a halo of width ``p``."""
    return extract_psi_w_(psi, imin, imax, jmin, jmax)


# Restart the reference model from the same initial state.
model_3l.reset_time()
model_3l.set_psi(psi_start)

# Discretization of the sub-domain itself (psi points imin..imax, jmin..jmax).
space_slice = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
)

# Discretization of the window around the sub-domain, used for the beta term.
space_slice_w = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

forced = ForcedRG(space_slice)
forced_dr = ForcedRGDR(space_slice)
mixed = Mixed(space_slice)
mixed_ro = Mixed(space_slice)
mixed_ro.color = "blueviolet"
mixed_ro.label = "Mixed - TE"
mixed_ro.prefix = "results_mixed_ro"
mixed_ro_o3 = Mixed(space_slice)
mixed_ro_o3.color = "indigo"
mixed_ro_o3.label = "Mixed - TE - O3"
mixed_ro_o3.prefix = "results_mixed_ro_o3"
mixed_ro_ge = Mixed(space_slice)
mixed_ro_ge.color = "magenta"
mixed_ro_ge.label = "Mixed - GE"
mixed_ro_ge.prefix = "results_mixed_ro_ge"
mixed_ro_ge_g5 = Mixed(space_slice)
mixed_ro_ge_g5.color = "palevioletred"
mixed_ro_ge_g5.label = "Mixed - GE - O5"
mixed_ro_ge_g5.prefix = "results_mixed_ro_ge_g5"
mixed_ro_ge_g5_dr = MixedDR(space_slice)
mixed_ro_ge_g5_ = Mixed(space_slice)
mixed_ro_ge_g5_.color = "red"
mixed_ro_ge_g5_.label = "Mixed - GE - O5 - ɑ=0"
mixed_ro_ge_g5_.prefix = "results_forced_md_ge_g5"
mixed_ro_ge_g5_dr_ = MixedDR(space_slice)
# NOTE(review): same colour as mixed_ro_ge_g5_ above; both curves are plotted
# together and are only distinguishable by their legend label.
mixed_ro_ge_g5_dr_.color = "red"
mixed_ro_ge_g5_dr_.label = "Mixed - GE - O5 - DR - ɑ=0"
mixed_ro_ge_g5_dr_.prefix = "results_forced_md_ge_g5_dr"

models = ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]](
    forced,
    # forced_dr,
    # mixed_ro,
    # mixed,
    # mixed_ro_o3,
    # mixed_ro_ge,
    mixed_ro_ge_g5,
    mixed_ro_ge_g5_dr,
    mixed_ro_ge_g5_,
    mixed_ro_ge_g5_dr_,
)
models.ijs = (imin, imax, jmin, jmax)
models.save_params = True

models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1], ty[imin : imax + 1, jmin:jmax])

psi0s = {}
# NOTE(review): never populated in this cell; kept in case later cells read it.
dpsis = {}

# The output location does not depend on the cycle: create it once.
output_folder = Path("../output/videos/comparison")
output_folder.mkdir(parents=True, exist_ok=True)

# torch.cuda memory helpers raise without a CUDA device, so guard them.
use_cuda = torch.cuda.is_available()

for c in range(n_cycles):
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()
    models.new_cycle()

    # Run the reference model. Keep the first two entries along dim 1 of its
    # psi (presumably the top two layers), windowed around the sub-domain.
    times = [model_3l.time.item()]
    psi0 = extract_psi_w(model_3l.psi[:, :2])
    psi0s[c] = psi0
    psis = [psi0]
    for _ in range(1, n_steps_per_cyle):
        model_3l.step()
        times.append(model_3l.time.item())
        psi = extract_psi_w(model_3l.psi[:, :2])
        psis.append(psi)

    # Replay the candidate models over the same cycle. Score each step against
    # the reference cropped by the halo width p.
    models.reset_time()
    models.setup(psis, times, beta_effect_w)

    psi_ref = crop(psis[0], p)
    models.compute_loss(psi_ref)
    models.compute_loss_psi2(psi_ref)
    for n in range(1, n_steps_per_cyle):
        models.step()
        psi_ref = crop(psis[n], p)
        models.compute_loss(psi_ref)
        models.compute_loss_psi2(psi_ref)

    # make_video_psi1(psis,models,output_folder.joinpath(f"psi1_{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"),c)
    make_video_psi2(psis, models, output_folder.joinpath(f"psi2_test_{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"), c)

    # Collect unreachable Python objects first so their tensors' blocks are
    # freed, then return cached blocks to the driver.
    gc.collect()
    if use_cuda:
        torch.cuda.empty_cache()
        max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    else:
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: n/a (no CUDA)."
    logger.info(box(msg_mem, style="round"))

In [None]:
# Compare the candidate OBC models against the 3-layer reference model on the
# sub-domain [imin, imax] x [jmin, jmax]. For each cycle: run the reference,
# record its windowed psi, replay the candidate models on the same cycle,
# compute their losses, render a comparison video and log peak GPU memory.
# NOTE(review): this cell is duplicated for several sub-domains; consider a
# single function parameterised by (imin, imax, jmin, jmax).
import gc
from pathlib import Path

from matplotlib.animation import FuncAnimation
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.image import AxesImage
from qgsw import plots
from qgsw.logging.utils import box, step
from qgsw.plots.plt_wrapper import retrieve_colorbar
from qgsw.utils.reshaping import crop

imin, imax = 32, 96
jmin, jmax = 256, 384


def extract_psi_w(psi: torch.Tensor) -> torch.Tensor:
    """Extract psi on the current sub-domain extended by a halo of width ``p``."""
    return extract_psi_w_(psi, imin, imax, jmin, jmax)


# Restart the reference model from the same initial state.
model_3l.reset_time()
model_3l.set_psi(psi_start)

# Discretization of the sub-domain itself (psi points imin..imax, jmin..jmax).
space_slice = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
)

# Discretization of the window around the sub-domain, used for the beta term.
space_slice_w = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

forced = ForcedRG(space_slice)
forced_dr = ForcedRGDR(space_slice)
mixed = Mixed(space_slice)
mixed_ro = Mixed(space_slice)
mixed_ro.color = "blueviolet"
mixed_ro.label = "Mixed - TE"
mixed_ro.prefix = "results_mixed_ro"
mixed_ro_o3 = Mixed(space_slice)
mixed_ro_o3.color = "indigo"
mixed_ro_o3.label = "Mixed - TE - O3"
mixed_ro_o3.prefix = "results_mixed_ro_o3"
mixed_ro_ge = Mixed(space_slice)
mixed_ro_ge.color = "magenta"
mixed_ro_ge.label = "Mixed - GE"
mixed_ro_ge.prefix = "results_mixed_ro_ge"
mixed_ro_ge_g5 = Mixed(space_slice)
mixed_ro_ge_g5.color = "palevioletred"
mixed_ro_ge_g5.label = "Mixed - GE - O5"
mixed_ro_ge_g5.prefix = "results_mixed_ro_ge_g5"
mixed_ro_ge_g5_dr = MixedDR(space_slice)
mixed_ro_ge_g5_ = Mixed(space_slice)
mixed_ro_ge_g5_.color = "red"
mixed_ro_ge_g5_.label = "Mixed - GE - O5 - ɑ=0"
mixed_ro_ge_g5_.prefix = "results_forced_md_ge_g5"
mixed_ro_ge_g5_dr_ = MixedDR(space_slice)
# NOTE(review): same colour as mixed_ro_ge_g5_ above; both curves are plotted
# together and are only distinguishable by their legend label.
mixed_ro_ge_g5_dr_.color = "red"
mixed_ro_ge_g5_dr_.label = "Mixed - GE - O5 - DR - ɑ=0"
mixed_ro_ge_g5_dr_.prefix = "results_forced_md_ge_g5_dr"

models = ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]](
    forced,
    # forced_dr,
    # mixed_ro,
    # mixed,
    # mixed_ro_o3,
    # mixed_ro_ge,
    mixed_ro_ge_g5,
    mixed_ro_ge_g5_dr,
    mixed_ro_ge_g5_,
    mixed_ro_ge_g5_dr_,
)
models.ijs = (imin, imax, jmin, jmax)
models.save_params = True

models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1], ty[imin : imax + 1, jmin:jmax])

psi0s = {}
# NOTE(review): never populated in this cell; kept in case later cells read it.
dpsis = {}

# The output location does not depend on the cycle: create it once.
output_folder = Path("../output/videos/comparison")
output_folder.mkdir(parents=True, exist_ok=True)

# torch.cuda memory helpers raise without a CUDA device, so guard them.
use_cuda = torch.cuda.is_available()

for c in range(n_cycles):
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()
    models.new_cycle()

    # Run the reference model. Keep the first two entries along dim 1 of its
    # psi (presumably the top two layers), windowed around the sub-domain.
    times = [model_3l.time.item()]
    psi0 = extract_psi_w(model_3l.psi[:, :2])
    psi0s[c] = psi0
    psis = [psi0]
    for _ in range(1, n_steps_per_cyle):
        model_3l.step()
        times.append(model_3l.time.item())
        psi = extract_psi_w(model_3l.psi[:, :2])
        psis.append(psi)

    # Replay the candidate models over the same cycle. Score each step against
    # the reference cropped by the halo width p.
    models.reset_time()
    models.setup(psis, times, beta_effect_w)

    psi_ref = crop(psis[0], p)
    models.compute_loss(psi_ref)
    models.compute_loss_psi2(psi_ref)
    for n in range(1, n_steps_per_cyle):
        models.step()
        psi_ref = crop(psis[n], p)
        models.compute_loss(psi_ref)
        models.compute_loss_psi2(psi_ref)

    # make_video_psi1(psis,models,output_folder.joinpath(f"psi1_{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"),c)
    make_video_psi2(psis, models, output_folder.joinpath(f"psi2_test_{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"), c)

    # Collect unreachable Python objects first so their tensors' blocks are
    # freed, then return cached blocks to the driver.
    gc.collect()
    if use_cuda:
        torch.cuda.empty_cache()
        max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    else:
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: n/a (no CUDA)."
    logger.info(box(msg_mem, style="round"))

In [None]:
# Compare the candidate OBC models against the 3-layer reference model on the
# sub-domain [imin, imax] x [jmin, jmax]. For each cycle: run the reference,
# record its windowed psi, replay the candidate models on the same cycle,
# compute their losses, render a comparison video and log peak GPU memory.
# NOTE(review): this cell is duplicated for several sub-domains; consider a
# single function parameterised by (imin, imax, jmin, jmax).
import gc
from pathlib import Path

from matplotlib.animation import FuncAnimation
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.image import AxesImage
from qgsw import plots
from qgsw.logging.utils import box, step
from qgsw.plots.plt_wrapper import retrieve_colorbar
from qgsw.utils.reshaping import crop

imin, imax = 112, 176
jmin, jmax = 64, 192


def extract_psi_w(psi: torch.Tensor) -> torch.Tensor:
    """Extract psi on the current sub-domain extended by a halo of width ``p``."""
    return extract_psi_w_(psi, imin, imax, jmin, jmax)


# Restart the reference model from the same initial state.
model_3l.reset_time()
model_3l.set_psi(psi_start)

# Discretization of the sub-domain itself (psi points imin..imax, jmin..jmax).
space_slice = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
)

# Discretization of the window around the sub-domain, used for the beta term.
space_slice_w = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

forced = ForcedRG(space_slice)
forced_dr = ForcedRGDR(space_slice)
mixed = Mixed(space_slice)
mixed_ro = Mixed(space_slice)
mixed_ro.color = "blueviolet"
mixed_ro.label = "Mixed - TE"
mixed_ro.prefix = "results_mixed_ro"
mixed_ro_o3 = Mixed(space_slice)
mixed_ro_o3.color = "indigo"
mixed_ro_o3.label = "Mixed - TE - O3"
mixed_ro_o3.prefix = "results_mixed_ro_o3"
mixed_ro_ge = Mixed(space_slice)
mixed_ro_ge.color = "magenta"
mixed_ro_ge.label = "Mixed - GE"
mixed_ro_ge.prefix = "results_mixed_ro_ge"
mixed_ro_ge_g5 = Mixed(space_slice)
mixed_ro_ge_g5.color = "palevioletred"
mixed_ro_ge_g5.label = "Mixed - GE - O5"
mixed_ro_ge_g5.prefix = "results_mixed_ro_ge_g5"
mixed_ro_ge_g5_dr = MixedDR(space_slice)
mixed_ro_ge_g5_ = Mixed(space_slice)
mixed_ro_ge_g5_.color = "red"
mixed_ro_ge_g5_.label = "Mixed - GE - O5 - ɑ=0"
mixed_ro_ge_g5_.prefix = "results_forced_md_ge_g5"
mixed_ro_ge_g5_dr_ = MixedDR(space_slice)
# NOTE(review): same colour as mixed_ro_ge_g5_ above; both curves are plotted
# together and are only distinguishable by their legend label.
mixed_ro_ge_g5_dr_.color = "red"
mixed_ro_ge_g5_dr_.label = "Mixed - GE - O5 - DR - ɑ=0"
mixed_ro_ge_g5_dr_.prefix = "results_forced_md_ge_g5_dr"

models = ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]](
    forced,
    # forced_dr,
    # mixed_ro,
    # mixed,
    # mixed_ro_o3,
    # mixed_ro_ge,
    mixed_ro_ge_g5,
    mixed_ro_ge_g5_dr,
    mixed_ro_ge_g5_,
    mixed_ro_ge_g5_dr_,
)
models.ijs = (imin, imax, jmin, jmax)
models.save_params = True

models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1], ty[imin : imax + 1, jmin:jmax])

psi0s = {}
# NOTE(review): never populated in this cell; kept in case later cells read it.
dpsis = {}

# The output location does not depend on the cycle: create it once.
output_folder = Path("../output/videos/comparison")
output_folder.mkdir(parents=True, exist_ok=True)

# torch.cuda memory helpers raise without a CUDA device, so guard them.
use_cuda = torch.cuda.is_available()

for c in range(n_cycles):
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()
    models.new_cycle()

    # Run the reference model. Keep the first two entries along dim 1 of its
    # psi (presumably the top two layers), windowed around the sub-domain.
    times = [model_3l.time.item()]
    psi0 = extract_psi_w(model_3l.psi[:, :2])
    psi0s[c] = psi0
    psis = [psi0]
    for _ in range(1, n_steps_per_cyle):
        model_3l.step()
        times.append(model_3l.time.item())
        psi = extract_psi_w(model_3l.psi[:, :2])
        psis.append(psi)

    # Replay the candidate models over the same cycle. Score each step against
    # the reference cropped by the halo width p.
    models.reset_time()
    models.setup(psis, times, beta_effect_w)

    psi_ref = crop(psis[0], p)
    models.compute_loss(psi_ref)
    models.compute_loss_psi2(psi_ref)
    for n in range(1, n_steps_per_cyle):
        models.step()
        psi_ref = crop(psis[n], p)
        models.compute_loss(psi_ref)
        models.compute_loss_psi2(psi_ref)

    # make_video_psi1(psis,models,output_folder.joinpath(f"psi1_{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"),c)
    make_video_psi2(psis, models, output_folder.joinpath(f"psi2_test_{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"), c)

    # Collect unreachable Python objects first so their tensors' blocks are
    # freed, then return cached blocks to the driver.
    gc.collect()
    if use_cuda:
        torch.cuda.empty_cache()
        max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    else:
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: n/a (no CUDA)."
    logger.info(box(msg_mem, style="round"))

In [None]:
# Compare the candidate OBC models against the 3-layer reference model on the
# sub-domain [imin, imax] x [jmin, jmax]. For each cycle: run the reference,
# record its windowed psi, replay the candidate models on the same cycle,
# compute their losses, render a comparison video and log peak GPU memory.
# NOTE(review): this cell is duplicated for several sub-domains; consider a
# single function parameterised by (imin, imax, jmin, jmax).
import gc
from pathlib import Path

from matplotlib.animation import FuncAnimation
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.image import AxesImage
from qgsw import plots
from qgsw.logging.utils import box, step
from qgsw.plots.plt_wrapper import retrieve_colorbar
from qgsw.utils.reshaping import crop

imin, imax = 112, 176
jmin, jmax = 256, 384


def extract_psi_w(psi: torch.Tensor) -> torch.Tensor:
    """Extract psi on the current sub-domain extended by a halo of width ``p``."""
    return extract_psi_w_(psi, imin, imax, jmin, jmax)


# Restart the reference model from the same initial state.
model_3l.reset_time()
model_3l.set_psi(psi_start)

# Discretization of the sub-domain itself (psi points imin..imax, jmin..jmax).
space_slice = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
)

# Discretization of the window around the sub-domain, used for the beta term.
space_slice_w = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

forced = ForcedRG(space_slice)
forced_dr = ForcedRGDR(space_slice)
mixed = Mixed(space_slice)
mixed_ro = Mixed(space_slice)
mixed_ro.color = "blueviolet"
mixed_ro.label = "Mixed - TE"
mixed_ro.prefix = "results_mixed_ro"
mixed_ro_o3 = Mixed(space_slice)
mixed_ro_o3.color = "indigo"
mixed_ro_o3.label = "Mixed - TE - O3"
mixed_ro_o3.prefix = "results_mixed_ro_o3"
mixed_ro_ge = Mixed(space_slice)
mixed_ro_ge.color = "magenta"
mixed_ro_ge.label = "Mixed - GE"
mixed_ro_ge.prefix = "results_mixed_ro_ge"
mixed_ro_ge_g5 = Mixed(space_slice)
mixed_ro_ge_g5.color = "palevioletred"
mixed_ro_ge_g5.label = "Mixed - GE - O5"
mixed_ro_ge_g5.prefix = "results_mixed_ro_ge_g5"
mixed_ro_ge_g5_dr = MixedDR(space_slice)
mixed_ro_ge_g5_ = Mixed(space_slice)
mixed_ro_ge_g5_.color = "red"
mixed_ro_ge_g5_.label = "Mixed - GE - O5 - ɑ=0"
mixed_ro_ge_g5_.prefix = "results_forced_md_ge_g5"
mixed_ro_ge_g5_dr_ = MixedDR(space_slice)
# NOTE(review): same colour as mixed_ro_ge_g5_ above; both curves are plotted
# together and are only distinguishable by their legend label.
mixed_ro_ge_g5_dr_.color = "red"
mixed_ro_ge_g5_dr_.label = "Mixed - GE - O5 - DR - ɑ=0"
mixed_ro_ge_g5_dr_.prefix = "results_forced_md_ge_g5_dr"

models = ModelsManagerOBC[ModelWrapperOBC[QGPSIQCore]](
    forced,
    # forced_dr,
    # mixed_ro,
    # mixed,
    # mixed_ro_o3,
    # mixed_ro_ge,
    mixed_ro_ge_g5,
    mixed_ro_ge_g5_dr,
    mixed_ro_ge_g5_,
    mixed_ro_ge_g5_dr_,
)
models.ijs = (imin, imax, jmin, jmax)
models.save_params = True

models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1], ty[imin : imax + 1, jmin:jmax])

psi0s = {}
# NOTE(review): never populated in this cell; kept in case later cells read it.
dpsis = {}

# The output location does not depend on the cycle: create it once.
output_folder = Path("../output/videos/comparison")
output_folder.mkdir(parents=True, exist_ok=True)

# torch.cuda memory helpers raise without a CUDA device, so guard them.
use_cuda = torch.cuda.is_available()

for c in range(n_cycles):
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()
    models.new_cycle()

    # Run the reference model. Keep the first two entries along dim 1 of its
    # psi (presumably the top two layers), windowed around the sub-domain.
    times = [model_3l.time.item()]
    psi0 = extract_psi_w(model_3l.psi[:, :2])
    psi0s[c] = psi0
    psis = [psi0]
    for _ in range(1, n_steps_per_cyle):
        model_3l.step()
        times.append(model_3l.time.item())
        psi = extract_psi_w(model_3l.psi[:, :2])
        psis.append(psi)

    # Replay the candidate models over the same cycle. Score each step against
    # the reference cropped by the halo width p.
    models.reset_time()
    models.setup(psis, times, beta_effect_w)

    psi_ref = crop(psis[0], p)
    models.compute_loss(psi_ref)
    models.compute_loss_psi2(psi_ref)
    for n in range(1, n_steps_per_cyle):
        models.step()
        psi_ref = crop(psis[n], p)
        models.compute_loss(psi_ref)
        models.compute_loss_psi2(psi_ref)

    # make_video_psi1(psis,models,output_folder.joinpath(f"psi1_{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"),c)
    make_video_psi2(psis, models, output_folder.joinpath(f"psi2_test_{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"), c)

    # Collect unreachable Python objects first so their tensors' blocks are
    # freed, then return cached blocks to the driver.
    gc.collect()
    if use_cuda:
        torch.cuda.empty_cache()
        max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    else:
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: n/a (no CUDA)."
    logger.info(box(msg_mem, style="round"))