In [None]:
%load_ext autoreload
%autoreload 2

## Setup

### Libraries

In [None]:
import torch
import seaborn as sns
from qgsw.logging import getLogger, setup_root_logger
from qgsw.specs import defaults
import gc

# Inference-only, reproducible setup: no autograd graphs are needed here and
# cuDNN is restricted to deterministic kernels.
torch.set_grad_enabled(False)
torch.backends.cudnn.deterministic = True

# Plot styling (the first positional argument of set_theme is the context).
sns.set_theme(context="notebook", style="dark")

# Default tensor specs from qgsw; used below both as `**specs` for tensor
# creation and as `specs["device"]`.
specs = defaults.get()

# Root logger (argument 1 is interpreted by qgsw.logging) and a module logger.
setup_root_logger(1)
logger = getLogger(__name__)

### Parameters

In [None]:
from qgsw.configs.core import Configuration
from qgsw.forcing.wind import WindForcing


# Experiment configuration shared with the g5k parameter-optimisation runs.
config = Configuration.from_toml("../output/g5k/param_optim/_config.toml")

# Layer thicknesses and reduced gravities of the reference model, with scalar
# aliases for the first two layers (used in the PV formulas further down).
H = config.model.h
g_prime = config.model.g_prime
H1 = H[0]
H2 = H[1]
g1 = g_prime[0]
g2 = g_prime[1]

# Physical parameters.
beta_plane = config.physics.beta_plane
bottom_drag_coef = config.physics.bottom_drag_coefficient
slip_coef = config.physics.slip_coef

# Wind stress on the full domain.
wind = WindForcing.from_config(config.windstress, config.space, config.physics)
tx, ty = wind.compute()

# Halo width in grid points: the reference psi is extracted this many points
# beyond each sub-domain edge and later cropped back by the same amount.
p = 4

### Space

In [None]:
from qgsw.spatial.core.discretization import SpaceDiscretization3D


# 3D discretisation of the full domain; dx and dy are the grid spacings used by
# the finite-difference metrics and PV computations below.
space = SpaceDiscretization3D.from_config(config.space, config.model)
dx = space.dx
dy = space.dy

In [None]:
from qgsw.solver.boundary_conditions.base import Boundaries


def get_psi_slices(imin: int, imax: int, jmin: int, jmax: int) -> tuple[slice, slice]:
    """Build the (x, y) slices selecting indices imin..imax and jmin..jmax inclusively.

    Args:
        imin, imax: Inclusive bounds along the first (x) grid axis.
        jmin, jmax: Inclusive bounds along the second (y) grid axis.

    Returns:
        tuple[slice, slice]: ``(slice(imin, imax + 1), slice(jmin, jmax + 1))``,
        ready to be unpacked into an index expression such as ``psi[..., *s]``.
    """
    return slice(imin, imax + 1), slice(jmin, jmax + 1)

def extract_psi_w_(psi: torch.Tensor,imin:int,imax:int,jmin:int,jmax:int) -> torch.Tensor:
    """Extract psi."""
    psi_slices_w = get_psi_slices(imin-p,imax+p,jmin-p,jmax+p)
    return psi[..., *psi_slices_w]


def extract_psi_bc(psi: torch.Tensor) -> Boundaries:
    """Extract the boundary data of psi used to drive an OBC model.

    Along both horizontal axes, the boundary is taken at index ``p`` from the
    start and ``-(p + 1)`` from the end, i.e. ``p`` points inside the halo-
    widened field. The trailing ``2`` is forwarded unchanged to
    ``Boundaries.extract`` (its meaning is defined there — TODO confirm).
    """
    first = p
    last = -(p + 1)
    return Boundaries.extract(psi, first, last, first, last, 2)

### Simulation

In [None]:
# Time step in seconds (2 hours); plots convert model.dt to days via /3600/24.
dt = 2 * 3600
# Steps per cycle (the misspelled name is referenced throughout the notebook).
n_steps_per_cyle = 250
# NOTE(review): not referenced in the cells shown here.
comparison_interval = 1
# Number of successive cycles to simulate.
n_cycles = 3

### Outputs

In [None]:
# Global video-export switch. NOTE(review): the model wrapper classes below read
# their own `save_video` class attribute instead; this flag is not referenced in
# the cells shown here.
save_videos = False

### Relative RMSE metrics

In [None]:
from qgsw.solver.finite_diff import grad, laplacian
from qgsw.utils.reshaping import crop


def rmse(f: torch.Tensor, f_ref: torch.Tensor) -> torch.Tensor:
    """Relative root-mean-square error of ``f`` with respect to ``f_ref``.

    Computes ``RMS(f - f_ref) / RMS(f_ref)``, where RMS is taken over all
    elements. Returns a 0-d tensor; the ratio is inf/nan when ``f_ref`` is
    identically zero.
    """
    error_rms = (f - f_ref).square().mean().sqrt()
    reference_rms = f_ref.square().mean().sqrt()
    return error_rms / reference_rms

def grad_rmse(f: torch.Tensor, f_ref: torch.Tensor) -> torch.Tensor:
    """Relative RMSE of the gradient of ``f`` with respect to that of ``f_ref``.

    Both fields are differentiated with ``grad`` and the components are scaled
    by the module-level spacings ``dx`` and ``dy``. The squared errors of the
    two components are averaged separately, summed, and square-rooted. The same
    is done for the reference gradient to normalise. Averaging separately
    presumably accommodates staggered components of different shapes.
    """
    dfdx, dfdy = grad(f)
    dfdx = dfdx / dx
    dfdy = dfdy / dy
    dfdx_ref, dfdy_ref = grad(f_ref)
    dfdx_ref = dfdx_ref / dx
    dfdy_ref = dfdy_ref / dy

    error_ms = (dfdx - dfdx_ref).square().mean() + (dfdy - dfdy_ref).square().mean()
    reference_ms = dfdx_ref.square().mean() + dfdy_ref.square().mean()
    return error_ms.sqrt() / reference_ms.sqrt()

def vorticity_rmse(f: torch.Tensor, f_ref: torch.Tensor) -> torch.Tensor:
    """Relative RMSE of the Laplacian of ``f`` with respect to that of ``f_ref``.

    For a streamfunction the Laplacian is the relative vorticity. Both Laplacians
    use the module-level ``dx``/``dy``. Each is passed through ``crop(..., 1)``
    (presumably trimming one boundary point per side) before the comparison.
    The comparison uses the same relative-RMSE formula as ``rmse``.
    """
    lap = crop(laplacian(f, dx, dy), 1)
    lap_ref = crop(laplacian(f_ref, dx, dy), 1)
    return rmse(lap, lap_ref)

## Initial condition

In [None]:
from qgsw.fields.variables.tuples import UVH
from qgsw.masks import Masks
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.utils import covphys


# QG projector on the full domain (no masks). It is used to turn the saved
# startup state into a streamfunction.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),
    H=H[..., None, None],
    space=space,
    f0=beta_plane.f0,
    masks=Masks.empty(nx=config.space.nx, ny=config.space.ny),
)
# Startup (u, v, h) state from the optimisation runs. It is converted to
# covariant variables, projected to pressure, and divided by f0 to give psi.
uvh0 = UVH.from_file("../output/g5k/param_optim/_data_startup.pt")
psi_start = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / beta_plane.f0

### Full-domain model

In [None]:
from qgsw.models.qg.psiq.core import QGPSIQ


# Full-domain reference model; its windowed psi is the target the OBC models
# are scored against in the cycle loop.
model_3l = QGPSIQ(
    space_2d=space.remove_h(),
    H=H,
    beta_plane=beta_plane,
    g_prime=g_prime,
)
model_3l.set_wind_forcing(tx, ty)
model_3l.masks = Masks.empty_tensor(
    model_3l.space.nx,
    model_3l.space.ny,
    device=specs["device"],
)
model_3l.bottom_drag_coef = bottom_drag_coef
model_3l.slip_coef = slip_coef
model_3l.dt = dt

# Reference y-coordinate, reused by the OBC models and the beta-effect term.
y0 = model_3l.y0

## OBC models

### Base

In [None]:
from collections.abc import Callable
from pathlib import Path
from typing import TypeVar

from matplotlib import pyplot as plt
import numpy as np

from qgsw.analysis.qg_model import ModelWrapper, ModelsManager
from qgsw.models.qg.psiq.core import QGPSIQCore
from qgsw.spatial.core.discretization import SpaceDiscretization2D

T = TypeVar("T", bound=QGPSIQCore)


class ModelWrapperOBC(ModelWrapper[T]):
    """Wrapper around a limited-area (open-boundary) QG model on a sub-domain.

    On top of ``ModelWrapper``, it:
    - keeps per-cycle loss histories;
    - optionally stores psi snapshots;
    - loads the per-sub-domain optimisation results stored under ``results_paths``.
    """

    # Directory holding the optimisation outputs read by ``load``.
    results_paths = Path("../output/g5k/param_optim")

    # loss name -> one list per cycle -> one float per recorded step.
    # Re-created per instance in __init__; this class attribute is only a default.
    losses: dict[str, list[list[float]]] = {}

    # Sub-domain indices (imin, imax, jmin, jmax); broadcast by ModelsManagerOBC.
    ijs: tuple[int, int, int, int] | None = None
    # Store psi snapshots while running.
    save_states = False
    # When False, plot_loss draws nothing for this wrapper.
    show = True

    def __init__(self, space_2d: SpaceDiscretization2D) -> None:
        super().__init__(space_2d)
        self.losses = {
            "rmse": [],
            "grad_rmse": [],
            "vorticity_rmse": [],
        }
        self.states: dict[str, list[list[torch.Tensor]]] = {
            "psi1": [],
        }

    def _set_params(self) -> None:
        """Apply the settings shared by every sub-domain model.

        Sets the reference ``y0``, empty masks, no bottom drag, the configured
        slip coefficient and the notebook time step ``dt``.
        """
        space = self.model.space
        self.model.y0 = y0
        self.model.masks = Masks.empty_tensor(
            space.nx,
            space.ny,
            device=specs["device"],
        )
        self.model.bottom_drag_coef = 0
        # NOTE(review): the meaning of `wide` is defined by the qgsw model — confirm.
        self.model.wide = True
        self.model.slip_coef = slip_coef
        self.model.dt = dt

    def load(self, imin: int, imax: int, jmin: int, jmax: int) -> dict:
        """Load the optimisation results for the given sub-domain.

        The file is ``results_paths / (prefix + f"_{imin}_{imax}_{jmin}_{jmax}.pt")``.
        Raises TypeError if ``prefix`` is None (string concatenation).
        """
        indices = f"_{imin}_{imax}_{jmin}_{jmax}.pt"
        file = self.results_paths.joinpath(self.prefix + indices)
        # NOTE(review): torch.load unpickles the file; only load trusted outputs.
        return torch.load(file)

    def new_cycle(self) -> None:
        """Start a new cycle: open a fresh per-cycle list for every loss (and state)."""
        super().new_cycle()
        if self.save_states:
            for s in self.states.values():
                s.append([])

        for loss in self.losses.values():
            loss.append([])

    def add_loss(self, loss_value: float, loss_name: str) -> None:
        """Append ``loss_value`` to the current cycle of ``loss_name``."""
        self.losses[loss_name][-1].append(loss_value)

    def plot_loss(
        self,
        *,
        loss_name: str,
        ax: plt.Axes | None = None,
        cycle: int | None = None,
    ) -> None:
        """Plot the history of ``loss_name`` against time in days.

        Cycles are drawn back to back: each starts one time step after the
        previous drawn cycle ended. Only the first drawn cycle keeps the legend
        label. Cycles without recorded values are skipped.

        Args:
            loss_name: Key of ``self.losses`` to plot.
            ax: Target axes; defaults to ``plt.gca()``.
            cycle: Plot only this cycle; all cycles when None.
        """
        if not self.show:
            return
        if ax is None:
            ax = plt.gca()
        history = self.losses[loss_name]
        cycles = [cycle] if cycle is not None else list(range(len(history)))
        time_offset = 0
        label_drawn = False
        for c in cycles:
            values = history[c]
            if not values:
                # Nothing recorded: `times[-1]` below would raise IndexError.
                continue
            times = self.model.dt * np.arange(len(values)) / 3600 / 24 + time_offset
            time_offset = times[-1] + self.model.dt / 3600 / 24
            kwargs = self.plot_kwargs.copy()
            if label_drawn:
                # Tolerate plot_kwargs without a "label" entry.
                kwargs.pop("label", None)
            ax.plot(times, np.array(values), **kwargs)
            label_drawn = True
        ax.set_xlabel("Time [day]")

    def step(self) -> None:
        """Advance the model one step, storing the top-layer psi if requested."""
        super().step()
        if self.save_states:
            # NOTE(review): appends to the outer list, not to the per-cycle list
            # opened in new_cycle; kept as-is because subclasses rely on it.
            self.states["psi1"].append(self.model.psi[:, :1])
            
        
M = TypeVar("M", bound=ModelWrapperOBC[QGPSIQCore])


class ModelsManagerOBC(ModelsManager[M]):
    """Drive several OBC model wrappers that share one sub-domain.

    The sub-domain indices are broadcast to every wrapper. ``compute_loss``
    records each metric of ``loss_fn`` for every wrapper.
    """

    # Metric name -> metric(f, f_ref); names match the keys of the wrappers' losses.
    loss_fn: dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {
        "rmse": rmse,
        "grad_rmse": grad_rmse,
        "vorticity_rmse": vorticity_rmse,
    }

    # Metric names, in evaluation order.
    losses = list(loss_fn)

    def __init__(self, *mw: M) -> None:
        super().__init__(*mw)
        # Propagate the first wrapper's indices to all wrappers.
        first_ijs = self.model_wrappers[0].ijs
        self.ijs = first_ijs

    @property
    def ijs(self) -> tuple[int, int, int, int]:
        """Sub-domain indices (imin, imax, jmin, jmax), read from the first wrapper."""
        return self.model_wrappers[0].ijs

    @ijs.setter
    def ijs(self, ijs: tuple[int, int, int, int]) -> None:
        def assign(mw: M) -> None:
            mw.ijs = ijs

        self.loop_over_models(assign)

    def compute_loss(self, psi_ref: torch.Tensor) -> None:
        """Record every metric for every wrapper against ``psi_ref``.

        Only index ``[0, 0]`` of ``psi_ref`` and of each model's psi is compared
        (presumably batch 0, top layer — confirm layout).
        """
        reference = psi_ref[0, 0]
        for loss_name in self.losses:
            metric = self.loss_fn[loss_name]

            def record(mw: M, metric=metric, loss_name=loss_name) -> None:
                # Defaults bind this iteration's metric/name at definition time.
                value = metric(mw.model.psi[0, 0], reference).cpu().item()
                mw.add_loss(value, loss_name)

            self.loop_over_models(record)

    def plot_loss(
        self,
        *,
        loss_name: str,
        ax: plt.Axes | None = None,
        cycle: int | None = None,
    ) -> None:
        """Forward ``plot_loss`` to every wrapper with the same arguments."""

        def draw(mw: M) -> None:
            mw.plot_loss(loss_name=loss_name, ax=ax, cycle=cycle)

        self.loop_over_models(draw)

In [None]:
# Perturbed two-layer parameters for the "Pert" model variants. The thicknesses
# are hand-picked (the reference would be H[0], H[1]), and the second reduced
# gravity is scaled down by a factor of 10.
H1_ = 600
H2_ = 900
H_ = torch.tensor([H1_, H2_], **specs)
g1_ = g_prime[0]
g2_ = g_prime[1] * 0.1
g_prime_ = torch.tensor([g1_, g2_], **specs)


### Reduced gravity

In [None]:
from torch import Tensor
from qgsw.solver.finite_diff import laplacian
from qgsw.spatial.core.discretization import SpaceDiscretization2D
from qgsw.spatial.core.grid_conversion import interpolate
from qgsw.utils.interpolation import QuadraticInterpolation
from qgsw.utils.reshaping import crop


class ReducedGravity(ModelWrapperOBC[QGPSIQ]):
    """One-layer reduced-gravity model of the top layer.

    The layer has thickness ``H[0]`` and an equivalent reduced gravity
    g1 * g2 / (g1 + g2) built from the first two reference reduced gravities.
    """

    prefix = None
    color = "blue"
    label = "Reduced Gravity"

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        g_top = g_prime[:1]
        g_bottom = g_prime[1:2]
        self.model = QGPSIQ(
            space_2d=space_2d,
            H=H[:1],
            beta_plane=beta_plane,
            g_prime=g_top * g_bottom / (g_top + g_bottom),
        )
        self._set_params()

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """PV of the layer: interpolate(lap(psi) - f0^2 (1/(H1 g1) + 1/(H1 g2)) psi) + beta_effect.

        The stretching term uses the interior points ``psi[..., 1:-1, 1:-1]``.
        """
        stretching = beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2)
        return interpolate(laplacian(psi, dx, dy) - stretching * psi[..., 1:-1, 1:-1]) + beta_effect

    def setup(self, psis: list[torch.Tensor], times: list[torch.Tensor], beta_effect_w: torch.Tensor) -> None:
        """Set initial psi/q and time-interpolated boundary maps from reference snapshots."""
        tops = [snapshot[:, :1] for snapshot in psis]
        psi_bcs = [extract_psi_bc(top) for top in tops]
        q_bcs = [
            Boundaries.extract(self.compute_q(top, beta_effect_w), 2, -3, 2, -3, 3)
            for top in tops
        ]
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )
        first = tops[0]
        self.model.set_psiq(crop(first, p), crop(self.compute_q(first, beta_effect_w), p - 1))

        if self.save_states:
            self.states["psi1"].append(self.model.psi[:, :1])
class ReducedGravityPerturbed(ReducedGravity):
    """Reduced-gravity top-layer model built from the perturbed parameters.

    Uses ``H_[0]`` and the equivalent reduced gravity of ``g_prime_``. Only
    model construction and the PV stretching coefficient differ from
    ReducedGravity. ``setup`` is inherited, since it was identical.
    """

    prefix = None
    color = "blue"
    # Distinct default label so legends do not show two "Reduced Gravity" curves.
    label = "Reduced Gravity - Pert"

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQ(
            space_2d=space_2d,
            H=H_[:1],
            beta_plane=beta_plane,
            g_prime=g_prime_[:1] * g_prime_[1:2] / (g_prime_[:1] + g_prime_[1:2]),
        )
        self._set_params()

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """PV with perturbed stretching: f0^2 (1/(H1_ g1_) + 1/(H1_ g2_)) on interior psi."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1_ / g1_ + 1 / H1_ / g2_) * psi[..., 1:-1, 1:-1]
        ) + beta_effect

### Mixed

In [None]:
from matplotlib.animation import FuncAnimation
from torch._tensor import Tensor
from qgsw import plots
from qgsw.decomposition.coefficients import DecompositionCoefs
from qgsw.decomposition.core import build_basis_from_params_dict
from qgsw.models.qg.psiq.modified.forced import QGPSIQRGPsi2Transport, QGPSIQRGPsi2TransportDR
from qgsw.models.qg.stretching_matrix import compute_A_tilde
from qgsw.plots.plt_wrapper import default_clim, retrieve_colorbar, retrieve_imshow_data
from qgsw.spatial.core.discretization import SpaceDiscretization2D
from qgsw.utils.tensor_dict import change_specs


class RGPsi2Transport(ModelWrapperOBC[QGPSIQRGPsi2Transport]):
    """Top-layer model coupled to a parametrised second-layer streamfunction.

    For each cycle, ``setup`` loads from the optimisation results:
    - the basis configuration and its coefficients (``coefs``);
    - an optional ``alpha`` (0 when absent).

    It then initialises psi/q and the boundary maps from the reference snapshots.
    """

    prefix = "results_mixed_rg_ro_ge"
    color = "navy"
    label = "GaussBarotropic - RG"
    save_video = False

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQRGPsi2Transport(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Per-cycle records of the loaded parameters (filled when save_params is set).
        self.alphas = {}
        self.coefs = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """Top-layer PV with stretching f0^2 (1/(H1 g1) + 1/(H1 g2)) on interior psi."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
        ) + beta_effect

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's parameters and initialise the model from reference psi.

        Args:
            psis: Halo-widened reference psi snapshots for the cycle.
            times: Times of the snapshots.
            beta_effect_w: Beta-effect term on the widened grid.
        """
        cycle_results = self.load(*self.ijs)[self.cycle]

        imin, imax, jmin, jmax = self.ijs
        space_slice = space.remove_h().slice(imin, imax + 1, jmin, jmax + 1)

        # alpha is optional in the saved results.
        if "alpha" in cycle_results:
            alpha: torch.Tensor = cycle_results["alpha"]
        else:
            alpha = torch.tensor(0, **specs)
        coefs = DecompositionCoefs.from_dict(change_specs(cycle_results["coefs"], **specs))

        basis = build_basis_from_params_dict(cycle_results["config"]["basis"])
        try:
            basis.freeze_time_normalization(self.model.dt * torch.tensor([n_steps_per_cyle], **specs))
        except Exception:  # noqa: BLE001
            # Best effort: not every basis supports time normalisation. Unlike the
            # former bare `except:`, this no longer swallows KeyboardInterrupt.
            pass
        basis.set_coefs(coefs)
        self._fpsi2 = basis.localize(space_slice.psi.xy.x, space_slice.psi.xy.y)

        if self.save_params:
            self.alphas[self.cycle] = alpha
            self.coefs[self.cycle] = coefs

        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi[:, :1], beta_effect_w), 2, -3, 2, -3, 3)
            for psi in psis
        ]

        self.model.set_psiq(crop(psi0[:, :1], p), crop(self.compute_q(psi0[:, :1], beta_effect_w), p - 1))
        self.model.alpha = torch.ones_like(self.model.psi) * alpha
        self.model.basis = basis
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))

class RGPsi2TransportPerturbed(RGPsi2Transport):
    """RGPsi2Transport variant built from the perturbed parameters H_ / g_prime_.

    Only model construction and the PV stretching coefficients differ.
    ``setup`` is inherited from RGPsi2Transport; it was an identical copy.
    """

    prefix = "results_mixed_rg_ro_ge_perturbed"
    color = "navy"
    label = "GaussBarotropic - RG - Pert"
    save_video = False

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQRGPsi2Transport(
            space_2d=space_2d,
            H=H_[:2],
            beta_plane=beta_plane,
            g_prime=g_prime_[:2],
        )
        self._set_params()
        # Per-cycle records of the loaded parameters (filled when save_params is set).
        self.alphas = {}
        self.coefs = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """Top-layer PV with perturbed stretching f0^2 (1/(H1_ g1_) + 1/(H1_ g2_))."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1_ / g1_ + 1 / H1_ / g2_) * psi[..., 1:-1, 1:-1]
        ) + beta_effect

class RGPsi2TransportPerturbedDR(ModelWrapperOBC[QGPSIQRGPsi2TransportDR]):
    """QGPSIQRGPsi2TransportDR wrapper using the perturbed parameters.

    The top-layer stretching coefficient is ``A[0, 0]`` of
    ``compute_A_tilde(H_, g_prime_, alpha)``. It is recomputed every cycle
    from the loaded ``alpha`` (0 when absent).
    """

    prefix = "results_mixed_rg_ro_ge_perturbed_dr"
    color = "orange"
    label = "GaussBarotropic - RG - Pert - DR"
    save_video = False

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQRGPsi2TransportDR(
            space_2d=space_2d,
            H=H_[:2],
            beta_plane=beta_plane,
            g_prime=g_prime_[:2],
        )
        self._set_params()
        # Per-cycle records of the loaded parameters (filled when save_params is set).
        self.alphas = {}
        self.coefs = {}

    def compute_q(self, psi: Tensor, A11: torch.Tensor, beta_effect: torch.Tensor) -> Tensor:
        """Top-layer PV with stretching f0^2 * A11 applied to interior psi."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * A11 * psi[..., 1:-1, 1:-1]
        ) + beta_effect

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's parameters, build A, and initialise from reference psi."""
        cycle_results = self.load(*self.ijs)[self.cycle]

        imin, imax, jmin, jmax = self.ijs
        space_slice = space.remove_h().slice(imin, imax + 1, jmin, jmax + 1)

        # alpha is optional in the saved results.
        if "alpha" in cycle_results:
            alpha: torch.Tensor = cycle_results["alpha"]
        else:
            alpha = torch.tensor(0, **specs)
        self.A = compute_A_tilde(H_, g_prime_, alpha, **specs)
        coefs = DecompositionCoefs.from_dict(change_specs(cycle_results["coefs"], **specs))

        basis = build_basis_from_params_dict(cycle_results["config"]["basis"])
        try:
            basis.freeze_time_normalization(self.model.dt * torch.tensor([n_steps_per_cyle], **specs))
        except Exception:  # noqa: BLE001
            # Best effort: not every basis supports time normalisation
            # (narrowed from a bare `except:` that also caught KeyboardInterrupt).
            pass
        basis.set_coefs(coefs)
        self._fpsi2 = basis.localize(space_slice.psi.xy.x, space_slice.psi.xy.y)

        if self.save_params:
            self.alphas[self.cycle] = alpha
            self.coefs[self.cycle] = coefs

        A11 = self.A[:1, :1]
        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi[:, :1], A11, beta_effect_w), 2, -3, 2, -3, 3)
            for psi in psis
        ]

        self.model.set_psiq(crop(psi0[:, :1], p), crop(self.compute_q(psi0[:, :1], A11, beta_effect_w), p - 1))
        self.model.alpha = alpha
        self.model.basis = basis
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))

class RGPsi2TransportPerturbedDR_(ModelWrapperOBC[QGPSIQRGPsi2TransportDR]):
    """QGPSIQRGPsi2TransportDR wrapper using the *reference* (unperturbed) parameters.

    The top-layer stretching coefficient is ``A[0, 0]`` of
    ``compute_A_tilde(H[:2], g_prime[:2], alpha)``, recomputed every cycle.
    """

    # NOTE(review): this prefix is the perturbed runs' prefix; instances in this
    # notebook always override it.
    prefix = "results_mixed_rg_ro_ge_perturbed_dr"
    color = "orange"
    # Default label no longer claims "Pert": this class uses H / g_prime.
    label = "GaussBarotropic - RG - DR"
    save_video = False

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQRGPsi2TransportDR(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Per-cycle records of the loaded parameters (filled when save_params is set).
        self.alphas = {}
        self.coefs = {}

    def compute_q(self, psi: Tensor, A11: torch.Tensor, beta_effect: torch.Tensor) -> Tensor:
        """Top-layer PV with stretching f0^2 * A11 applied to interior psi."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * A11 * psi[..., 1:-1, 1:-1]
        ) + beta_effect

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's parameters, build A, and initialise from reference psi."""
        cycle_results = self.load(*self.ijs)[self.cycle]

        imin, imax, jmin, jmax = self.ijs
        space_slice = space.remove_h().slice(imin, imax + 1, jmin, jmax + 1)

        # alpha is optional in the saved results.
        if "alpha" in cycle_results:
            alpha: torch.Tensor = cycle_results["alpha"]
        else:
            alpha = torch.tensor(0, **specs)
        self.A = compute_A_tilde(H[:2], g_prime[:2], alpha, **specs)
        coefs = DecompositionCoefs.from_dict(change_specs(cycle_results["coefs"], **specs))

        basis = build_basis_from_params_dict(cycle_results["config"]["basis"])
        try:
            basis.freeze_time_normalization(self.model.dt * torch.tensor([n_steps_per_cyle], **specs))
        except Exception:  # noqa: BLE001
            # Best effort: not every basis supports time normalisation
            # (narrowed from a bare `except:` that also caught KeyboardInterrupt).
            pass
        basis.set_coefs(coefs)
        self._fpsi2 = basis.localize(space_slice.psi.xy.x, space_slice.psi.xy.y)

        if self.save_params:
            self.alphas[self.cycle] = alpha
            self.coefs[self.cycle] = coefs

        A11 = self.A[:1, :1]
        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi[:, :1], A11, beta_effect_w), 2, -3, 2, -3, 3)
            for psi in psis
        ]

        self.model.set_psiq(crop(psi0[:, :1], p), crop(self.compute_q(psi0[:, :1], A11, beta_effect_w), p - 1))
        self.model.alpha = alpha
        self.model.basis = basis
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))

### Forced

In [None]:
from qgsw.decomposition.wavelets import WaveletBasis
from qgsw.models.qg.psiq.modified.forced import QGPSIQForced

# Equivalent thickness H1 * H2 / (H1 + H2) of the first two reference layers.
Heq = H[:1] * H[1:2] / (H[:1] + H[1:2])


class ForcedRGDR(ModelWrapperOBC[QGPSIQForced]):
    """One-layer model of equivalent thickness ``Heq``, forced by a fitted basis.

    ``setup`` loads the current cycle's basis and coefficients and localises
    the basis on the model's q grid. ``step`` sets ``model.forcing`` from that
    localised basis at the current model time before advancing.
    """

    prefix = "results_forced_rg_dr"
    color = "brown"
    label = "Forced DR"
    save_video = False

    def __init__(self, space_2d: SpaceDiscretization2D) -> None:
        super().__init__(space_2d)
        self.states["forcing"] = []

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQForced(
            space_2d=space_2d,
            H=Heq,
            beta_plane=beta_plane,
            g_prime=g_prime[1:2],
        )
        # Wind scaling is set to the top-layer thickness H[0].
        self.model.wind_scaling = H[:1].item()
        self._set_params()
        self.coefs = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """PV with stretching f0^2 / (Heq * g2) applied to interior psi."""
        stretching = beta_plane.f0**2 * (1 / Heq / g2)
        return interpolate(laplacian(psi, dx, dy) - stretching * psi[..., 1:-1, 1:-1]) + beta_effect

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's forcing basis and initialise from reference psi."""
        cycle_results = self.load(*self.ijs)[self.cycle]
        coefs = DecompositionCoefs.from_dict(change_specs(cycle_results["coefs"], **specs))
        self.basis = build_basis_from_params_dict(cycle_results["config"]["basis"])
        self.basis.set_coefs(coefs)
        if self.save_params:
            self.coefs[self.cycle] = coefs

        q_grid = self.model.space.remove_h().q.xy
        self.wv = self.basis.localize(q_grid.x, q_grid.y)

        tops = [snapshot[:, :1] for snapshot in psis]
        psi_bcs = [extract_psi_bc(top) for top in tops]
        q_bcs = [
            Boundaries.extract(self.compute_q(top, beta_effect_w), 2, -3, 2, -3, 3)
            for top in tops
        ]

        first = tops[0]
        self.model.set_psiq(crop(first, p), crop(self.compute_q(first, beta_effect_w), p - 1))
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))

        if self.save_states:
            # Per-cycle snapshot lists are restarted here.
            self.states["psi1"] = [self.model.psi[:, :1]]
            self.states["forcing"] = [crop(self.wv(self.model.time)[None, None, ...], p)]

    def step(self) -> None:
        """Apply the forcing at the pre-step time, then advance one step."""
        self.model.forcing = self.wv(self.model.time)
        super().step()
        if self.save_states:
            forcing_now = self.wv(self.model.time)[None, None, ...]
            self.states["forcing"].append(crop(forcing_now, p))

# Equivalent thickness H1_ * H2_ / (H1_ + H2_) of the perturbed layers.
Heq_ = H_[:1] * H_[1:2] / (H_[:1] + H_[1:2])


class ForcedRGDRPerturbed(ForcedRGDR):
    """ForcedRGDR variant built from the perturbed parameters (Heq_, g_prime_, H_).

    ``__init__``, ``setup`` and ``step`` are inherited from ForcedRGDR; they
    were identical copies. Only model construction and the PV stretching
    coefficient differ.
    """

    prefix = "results_forced_rg_dr_perturbed"
    color = "brown"
    label = "Forced DR - Pert"
    save_video = False

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQForced(
            space_2d=space_2d,
            H=Heq_,
            beta_plane=beta_plane,
            g_prime=g_prime_[1:2],
        )
        # Wind scaling is set to the perturbed top-layer thickness H_[0].
        self.model.wind_scaling = H_[:1].item()
        self._set_params()
        self.coefs = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """PV with perturbed stretching f0^2 / (Heq_ * g2_) applied to interior psi."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / Heq_ / g2_) * psi[..., 1:-1, 1:-1]
        ) + beta_effect

## [32, 96] x [64, 192]

### Model runs

In [None]:
from matplotlib.animation import FuncAnimation
from qgsw import plots
from qgsw.logging.utils import box, step
from qgsw.plots.plt_wrapper import retrieve_colorbar
from qgsw.utils.reshaping import crop

# Sub-domain [imin, imax] x [jmin, jmax] (grid indices) on which the OBC models run.
imin, imax = 32, 96
jmin, jmax = 64, 192


def extract_psi_w(psi: torch.Tensor) -> torch.Tensor:
    """Reference psi on the current window (imin..jmax globals) widened by the halo."""
    return extract_psi_w_(psi, imin, imax, jmin, jmax)


# Restart the reference model from the shared initial streamfunction.
model_3l.reset_time()
model_3l.set_psi(psi_start)

# Sub-domain grid: omega points imin..imax x jmin..jmax.
space_slice = SpaceDiscretization2D.from_coords(
    x_1d=P.space.remove_h().omega.xy.x[imin : imax + 1, 0],
    y_1d=P.space.remove_h().omega.xy.y[0, jmin : jmax + 1],
)

# Widened grid (p - 1 extra omega points per side) used for boundary PV.
space_slice_w = SpaceDiscretization2D.from_coords(
    x_1d=P.space.remove_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y_1d=P.space.remove_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0][None]
beta_effect_w = beta_plane.beta * (y_w - y0)

rg = ReducedGravity(space_slice)
rg.linestyle="dotted"

rg_ = ReducedGravityPerturbed(space_slice)
# Fixed: this label was assigned to `rg`, mislabelling the unperturbed model.
rg_.label = "Reduced Gravity - Pert"

forced = ForcedRGDR(space_slice)
forced.label = "Forced"
forced.linestyle = "dotted"
forced.prefix = "results_forced_rg_dr_gamma1000_obstrack"

forced_ = ForcedRGDRPerturbed(space_slice)
forced_.label = "Forced - Pert"
forced_.prefix = "results_forced_rg_dr_pert_gamma1000_obstrack"

gauss_barotropic_rg = RGPsi2Transport(space_slice)
gauss_barotropic_rg.label = "GaussBarotropic"
gauss_barotropic_rg.color = "navy"
gauss_barotropic_rg.linestyle = "dotted"
gauss_barotropic_rg.prefix = "results_mixed_rg_ro_ge_g5_gamma0_1_obstrack"

gauss_barotropic_rg_noalpha = RGPsi2Transport(space_slice)
gauss_barotropic_rg_noalpha.label = "GaussBarotropic - NoAlpha"
gauss_barotropic_rg_noalpha.color = "lightblue"
gauss_barotropic_rg_noalpha.linestyle = "dotted"
gauss_barotropic_rg_noalpha.prefix = "results_mixed_rg_ro_ge_g5_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_pert_noalpha_lr = RGPsi2TransportPerturbedDR(space_slice)
gauss_barotropic_rg_dr_pert_noalpha_lr.label = "GaussBarotropic - DR - Pert - NoAlpha"
gauss_barotropic_rg_dr_pert_noalpha_lr.color = "lightblue"
gauss_barotropic_rg_dr_pert_noalpha_lr.prefix = "results_mixed_rg_ro_ge_pert_lr_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_pert_lr = RGPsi2TransportPerturbedDR(space_slice)
gauss_barotropic_rg_dr_pert_lr.label = "GaussBarotropic - DR - Pert"
gauss_barotropic_rg_dr_pert_lr.color = "navy"
gauss_barotropic_rg_dr_pert_lr.prefix = "results_mixed_rg_ro_ge_pert_lr5_gamma0_1_obstrack"

gauss_barotropic_rg_dr_nopert_noalpha_lr = RGPsi2TransportPerturbedDR_(space_slice)
gauss_barotropic_rg_dr_nopert_noalpha_lr.label = "GaussBarotropic - DR - NoPert - NoAlpha"
gauss_barotropic_rg_dr_nopert_noalpha_lr.color = "lightblue"
gauss_barotropic_rg_dr_nopert_noalpha_lr.linestyle = "dashed"
gauss_barotropic_rg_dr_nopert_noalpha_lr.prefix = "results_mixed_rg_ro_ge_nopert_lr_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_nopert_lr = RGPsi2TransportPerturbedDR_(space_slice)
gauss_barotropic_rg_dr_nopert_lr.label = "GaussBarotropic - DR - NoPert"
gauss_barotropic_rg_dr_nopert_lr.color = "navy"
gauss_barotropic_rg_dr_nopert_lr.linestyle = "dashed"
gauss_barotropic_rg_dr_nopert_lr.prefix = "results_mixed_rg_ro_ge_nopert_lr5_gamma0_1_obstrack"

# NOTE(review): this wrapper is built but not passed to ModelsManagerOBC below,
# so it is never run or plotted.
gauss_barotropic_psi2 = RGPsi2Transport(space_slice)
gauss_barotropic_psi2.label = "Psi2"
gauss_barotropic_psi2.color = "green"
gauss_barotropic_psi2.prefix = "results_psi2"


# Group the wrappers that are run and plotted together; order determines plot order
# as long as loop_over_models iterates in registration order (presumably — confirm).
models = ModelsManagerOBC(
    rg,
    rg_,
    forced,
    forced_,
    gauss_barotropic_rg,
    gauss_barotropic_rg_noalpha,
    gauss_barotropic_rg_dr_nopert_lr,
    gauss_barotropic_rg_dr_nopert_noalpha_lr,
    gauss_barotropic_rg_dr_pert_lr,
    gauss_barotropic_rg_dr_pert_noalpha_lr,
)
# Broadcast the sub-domain indices to every wrapper (used by their load/setup).
models.ijs = (imin,imax,jmin,jmax)
# Keep the loaded alphas/coefs per cycle in each wrapper.
models.save_params = True

# Wind stress restricted to the sub-domain. tx and ty are sliced with different
# extents, presumably matching their staggered grids — TODO confirm.
models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1],ty[imin : imax + 1, jmin:jmax])

# Per-cycle reference data: initial windowed psi and mean psi tendency.
psi0s = {}
dpsis = {}
gc.collect()
for c in range(n_cycles):
    # Peak-memory statistics exist only on CUDA; resetting them raises on CPU-only setups.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    models.new_cycle()

    # 1) Advance the reference model through one cycle, recording the windowed
    #    psi of the top two layers and the corresponding times.
    times = [model_3l.time.item()]

    psi0 = extract_psi_w(model_3l.psi[:,:2])
    psi0s[c] = psi0

    psis = [psi0]

    for _ in range(1, n_steps_per_cyle):
        model_3l.step()

        times.append(model_3l.time.item())

        psi = extract_psi_w(model_3l.psi[:,:2])

        psis.append(psi)

    # Mean time derivative of psi over the cycle.
    dpsis[c] = (psis[-1]-psis[0])/(n_steps_per_cyle-1)/dt

    # 2) Run every OBC model over the same cycle and score it at each step
    #    against the halo-cropped reference.
    models.reset_time()

    models.setup(psis,times,beta_effect_w)

    models.compute_loss(crop(psis[0],p))
    for n in range(1,n_steps_per_cyle):
        models.step()
        models.compute_loss(crop(psis[n],p))

    torch.cuda.empty_cache()
    gc.collect()

    if torch.cuda.is_available():
        max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    else:
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | CUDA unavailable: memory not tracked."
    logger.info(box(msg_mem, style="round"))

### Results

In [None]:
from matplotlib import pyplot as plt

from qgsw import plots

# Figure with every registered model. Each row pairs the initial top-layer psi
# (with the sub-domain outlined) and the history of one loss.
show_grad = True
show_vort = True

loss_rows = [("rmse", "RMSE")]
if show_grad:
    loss_rows.append(("grad_rmse", "Gradient RMSE"))
if show_vort:
    loss_rows.append(("vorticity_rmse", "Vorticity RMSE"))

fig, axs = plots.subplots(
    len(loss_rows),
    2,
    gridspec_kw={"width_ratios": [1, 7.5]},
    figsize=(21, 5 * len(loss_rows)),
)
plots.set_rowtitles([title for _, title in loss_rows], axs=axs)
for row, (loss_name, _row_title) in enumerate(loss_rows):
    plots.imshow(psi_start[0, 0], ax=axs[row, 0])
    axs[row, 0].hlines([jmin, jmax], imin, imax, color="black")
    axs[row, 0].vlines([imin, imax], jmin, jmax, color="black")
    models.plot_loss(loss_name=loss_name, ax=axs[row, 1])
    plots.clamp_ylims(0, 1, axs[row, 1])
    axs[row, 1].legend(loc="upper left", prop={"size": 8})

# Hide the reduced-gravity and forced wrappers for this second figure
# (their plot_loss becomes a no-op; the flags persist for later cells).
rg.show = False
rg_.show = False
forced.show = False
forced_.show = False

# Same layout as the previous figure, restricted to the remaining models.
show_grad = True
show_vort = True

loss_rows = [("rmse", "RMSE")]
if show_grad:
    loss_rows.append(("grad_rmse", "Gradient RMSE"))
if show_vort:
    loss_rows.append(("vorticity_rmse", "Vorticity RMSE"))

fig, axs = plots.subplots(
    len(loss_rows),
    2,
    gridspec_kw={"width_ratios": [1, 7.5]},
    figsize=(21, 5 * len(loss_rows)),
)
fig.suptitle("Without RG & Forced")
plots.set_rowtitles([title for _, title in loss_rows], axs=axs)
for row, (loss_name, _row_title) in enumerate(loss_rows):
    plots.imshow(psi_start[0, 0], ax=axs[row, 0])
    axs[row, 0].hlines([jmin, jmax], imin, imax, color="black")
    axs[row, 0].vlines([imin, imax], jmin, jmax, color="black")
    models.plot_loss(loss_name=loss_name, ax=axs[row, 1])
    plots.clamp_ylims(0, 1, axs[row, 1])
    axs[row, 1].legend(loc="upper left", prop={"size": 8})

## [32, 96] x [256, 384]

### Model runs

In [None]:
from functools import partial

from matplotlib.animation import FuncAnimation
from qgsw import plots
from qgsw.logging.utils import box, step
from qgsw.plots.plt_wrapper import retrieve_colorbar
from qgsw.utils.reshaping import crop

# Sub-domain [imin, imax] x [jmin, jmax] (inclusive indices on the full grid).
imin, imax = 32, 96
jmin, jmax = 256, 384
# Bind the window bounds now: a lambda would re-read the imin/imax/jmin/jmax
# globals at call time, so a later change of region would silently move it.
extract_psi_w = partial(extract_psi_w_, imin=imin, imax=imax, jmin=jmin, jmax=jmax)


# Restart the reference model from the shared initial condition.
model_3l.reset_time()
model_3l.set_psi(psi_start)

# Grid of the sub-domain itself (used by all reduced models).
space_slice = SpaceDiscretization2D.from_coords(
    x_1d=P.space.remove_h().omega.xy.x[imin : imax + 1, 0],
    y_1d=P.space.remove_h().omega.xy.y[0, jmin : jmax + 1],
)

# Halo-widened grid; here it only serves to compute the beta term on the window.
space_slice_w = SpaceDiscretization2D.from_coords(
    x_1d=P.space.remove_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y_1d=P.space.remove_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

# --- Reduced models compared against the reference ---
rg = ReducedGravity(space_slice)
rg.linestyle="dotted"

rg_ = ReducedGravityPerturbed(space_slice)
# Label belongs to the perturbed model (it was previously set on `rg` by mistake).
rg_.label = "Reduced Gravity - Pert"

forced = ForcedRGDR(space_slice)
forced.label = "Forced"
forced.linestyle = "dotted"
forced.prefix = "results_forced_rg_dr_gamma1000_obstrack"

forced_ = ForcedRGDRPerturbed(space_slice)
forced_.label = "Forced - Pert"
forced_.prefix = "results_forced_rg_dr_pert_gamma1000_obstrack"

gauss_barotropic_rg = RGPsi2Transport(space_slice)
gauss_barotropic_rg.label = "GaussBarotropic"
gauss_barotropic_rg.color = "navy"
gauss_barotropic_rg.linestyle = "dotted"
gauss_barotropic_rg.prefix = "results_mixed_rg_ro_ge_g5_gamma0_1_obstrack"

gauss_barotropic_rg_noalpha = RGPsi2Transport(space_slice)
gauss_barotropic_rg_noalpha.label = "GaussBarotropic - NoAlpha"
gauss_barotropic_rg_noalpha.color = "lightblue"
gauss_barotropic_rg_noalpha.linestyle = "dotted"
gauss_barotropic_rg_noalpha.prefix = "results_mixed_rg_ro_ge_g5_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_pert_noalpha_lr = RGPsi2TransportPerturbedDR(space_slice)
gauss_barotropic_rg_dr_pert_noalpha_lr.label = "GaussBarotropic - DR - Pert - NoAlpha"
gauss_barotropic_rg_dr_pert_noalpha_lr.color = "lightblue"
gauss_barotropic_rg_dr_pert_noalpha_lr.prefix = "results_mixed_rg_ro_ge_pert_lr_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_pert_lr = RGPsi2TransportPerturbedDR(space_slice)
gauss_barotropic_rg_dr_pert_lr.label = "GaussBarotropic - DR - Pert"
gauss_barotropic_rg_dr_pert_lr.color = "navy"
gauss_barotropic_rg_dr_pert_lr.prefix = "results_mixed_rg_ro_ge_pert_lr5_gamma0_1_obstrack"

gauss_barotropic_rg_dr_nopert_noalpha_lr = RGPsi2TransportPerturbedDR_(space_slice)
gauss_barotropic_rg_dr_nopert_noalpha_lr.label = "GaussBarotropic - DR - NoPert - NoAlpha"
gauss_barotropic_rg_dr_nopert_noalpha_lr.color = "lightblue"
gauss_barotropic_rg_dr_nopert_noalpha_lr.linestyle = "dashed"
gauss_barotropic_rg_dr_nopert_noalpha_lr.prefix = "results_mixed_rg_ro_ge_nopert_lr_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_nopert_lr = RGPsi2TransportPerturbedDR_(space_slice)
gauss_barotropic_rg_dr_nopert_lr.label = "GaussBarotropic - DR - NoPert"
gauss_barotropic_rg_dr_nopert_lr.color = "navy"
gauss_barotropic_rg_dr_nopert_lr.linestyle = "dashed"
gauss_barotropic_rg_dr_nopert_lr.prefix = "results_mixed_rg_ro_ge_nopert_lr5_gamma0_1_obstrack"

# NOTE(review): configured but not passed to ModelsManagerOBC below — confirm intent.
gauss_barotropic_psi2 = RGPsi2Transport(space_slice)
gauss_barotropic_psi2.label = "Psi2"
gauss_barotropic_psi2.color = "green"
gauss_barotropic_psi2.prefix = "results_psi2"


models = ModelsManagerOBC(
    rg,
    rg_,
    forced,
    forced_,
    gauss_barotropic_rg,
    gauss_barotropic_rg_noalpha,
    gauss_barotropic_rg_dr_nopert_lr,
    gauss_barotropic_rg_dr_nopert_noalpha_lr,
    gauss_barotropic_rg_dr_pert_lr,
    gauss_barotropic_rg_dr_pert_noalpha_lr,
)
models.ijs = (imin,imax,jmin,jmax)
models.save_params = True

# Wind stress restricted to the sub-domain; tx and ty are sliced with different
# extents (presumably staggered grids — TODO confirm).
models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1],ty[imin : imax + 1, jmin:jmax])

# Per-cycle diagnostics: initial windowed psi and its mean tendency over the cycle.
psi0s = {}
dpsis = {}
gc.collect()


for c in range(n_cycles):
    torch.cuda.reset_peak_memory_stats()
    models.new_cycle()
    times = [model_3l.time.item()]

    # 1) Advance the reference model through one cycle, recording the top two
    #    layers of psi on the halo-extended window at every step.
    psi0 = extract_psi_w(model_3l.psi[:,:2])
    psi0s[c] = psi0

    psis = [psi0]

    for _ in range(1, n_steps_per_cyle):
        model_3l.step()

        times.append(model_3l.time.item())

        psi = extract_psi_w(model_3l.psi[:,:2])

        psis.append(psi)

    # Mean rate of change of the windowed psi over the cycle (divided by dt).
    dpsis[c] = (psis[-1]-psis[0])/(n_steps_per_cyle-1)/dt

    # 2) Replay the cycle with every reduced model, scoring each step against the
    #    reference with the p-wide halo cropped off.
    models.reset_time()

    models.setup(psis,times,beta_effect_w)

    models.compute_loss(crop(psis[0],p))
    # Debug trace of the RG models' first-layer psi at every step.
    rgs = [rg.model.psi[0,0]]
    rgs_ = [rg_.model.psi[0,0]]
    for n in range(1,n_steps_per_cyle):
        models.step()
        models.compute_loss(crop(psis[n],p))
        rgs.append(rg.model.psi[0,0])
        rgs_.append(rg_.model.psi[0,0])

    torch.cuda.empty_cache()
    gc.collect()

    max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
    msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    logger.info(box(msg_mem, style="round"))

### Results

In [None]:
from matplotlib import pyplot as plt

from qgsw import plots

# Baselines hidden in the second figure. Their visibility is set explicitly for
# each figure so that re-running this cell does not carry `show=False` over into
# the first figure (assumes `show=True` is the fresh-model default — TODO confirm).
baselines = (rg, rg_, forced, forced_)

# Rows to draw, in order: loss name -> row title.
show_grad = True
show_vort = True

panels = {"rmse": "RMSE"}
if show_grad:
    panels["grad_rmse"] = "Gradient RMSE"
if show_vort:
    panels["vorticity_rmse"] = "Vorticity RMSE"

# Figure 1 includes the baselines; figure 2 repeats the panels without them.
for suptitle, show_baselines in ((None, True), ("Without RG & Forced", False)):
    for baseline in baselines:
        baseline.show = show_baselines

    fig, axs = plots.subplots(
        len(panels),
        2,
        gridspec_kw={"width_ratios": [1, 7.5]},
        figsize=(21, 5 * len(panels)),
    )
    if suptitle is not None:
        fig.suptitle(suptitle)
    plots.set_rowtitles(list(panels.values()), axs=axs)
    for row, loss_name in enumerate(panels):
        # Left: initial psi with the sub-domain outlined; right: loss curves.
        map_ax, loss_ax = axs[row, 0], axs[row, 1]
        plots.imshow(psi_start[0, 0], ax=map_ax)
        map_ax.hlines([jmin, jmax], imin, imax, color="black")
        map_ax.vlines([imin, imax], jmin, jmax, color="black")
        models.plot_loss(loss_name=loss_name, ax=loss_ax)
        plots.clamp_ylims(0, 1, loss_ax)
        loss_ax.legend(loc="upper left", prop={"size": 8})

## [112, 176] x [64, 192]

### Model runs

In [None]:
from functools import partial

from qgsw.logging.utils import box, step
from qgsw.utils.reshaping import crop

# Sub-domain [imin, imax] x [jmin, jmax] (inclusive indices on the full grid).
imin, imax = 112, 176
jmin, jmax = 64, 192
# Bind the window bounds now: a lambda would re-read the imin/imax/jmin/jmax
# globals at call time, so a later change of region would silently move it.
extract_psi_w = partial(extract_psi_w_, imin=imin, imax=imax, jmin=jmin, jmax=jmax)


# Restart the reference model from the shared initial condition.
model_3l.reset_time()
model_3l.set_psi(psi_start)

# Grid of the sub-domain itself (used by all reduced models).
space_slice = SpaceDiscretization2D.from_coords(
    x_1d=P.space.remove_h().omega.xy.x[imin : imax + 1, 0],
    y_1d=P.space.remove_h().omega.xy.y[0, jmin : jmax + 1],
)

# Halo-widened grid; here it only serves to compute the beta term on the window.
space_slice_w = SpaceDiscretization2D.from_coords(
    x_1d=P.space.remove_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y_1d=P.space.remove_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

# --- Reduced models compared against the reference ---
rg = ReducedGravity(space_slice)
rg.linestyle="dotted"

rg_ = ReducedGravityPerturbed(space_slice)
# Label belongs to the perturbed model (it was previously set on `rg` by mistake).
rg_.label = "Reduced Gravity - Pert"

forced = ForcedRGDR(space_slice)
forced.label = "Forced"
forced.linestyle = "dotted"
forced.prefix = "results_forced_rg_dr_gamma1000_obstrack"

forced_ = ForcedRGDRPerturbed(space_slice)
forced_.label = "Forced - Pert"
forced_.prefix = "results_forced_rg_dr_pert_gamma1000_obstrack"

gauss_barotropic_rg = RGPsi2Transport(space_slice)
gauss_barotropic_rg.label = "GaussBarotropic"
gauss_barotropic_rg.color = "navy"
gauss_barotropic_rg.linestyle = "dotted"
gauss_barotropic_rg.prefix = "results_mixed_rg_ro_ge_g5_gamma0_1_obstrack"

gauss_barotropic_rg_noalpha = RGPsi2Transport(space_slice)
gauss_barotropic_rg_noalpha.label = "GaussBarotropic - NoAlpha"
gauss_barotropic_rg_noalpha.color = "lightblue"
gauss_barotropic_rg_noalpha.linestyle = "dotted"
gauss_barotropic_rg_noalpha.prefix = "results_mixed_rg_ro_ge_g5_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_pert_noalpha_lr = RGPsi2TransportPerturbedDR(space_slice)
gauss_barotropic_rg_dr_pert_noalpha_lr.label = "GaussBarotropic - DR - Pert - NoAlpha"
gauss_barotropic_rg_dr_pert_noalpha_lr.color = "lightblue"
gauss_barotropic_rg_dr_pert_noalpha_lr.prefix = "results_mixed_rg_ro_ge_pert_lr_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_pert_lr = RGPsi2TransportPerturbedDR(space_slice)
gauss_barotropic_rg_dr_pert_lr.label = "GaussBarotropic - DR - Pert"
gauss_barotropic_rg_dr_pert_lr.color = "navy"
gauss_barotropic_rg_dr_pert_lr.prefix = "results_mixed_rg_ro_ge_pert_lr5_gamma0_1_obstrack"

gauss_barotropic_rg_dr_nopert_noalpha_lr = RGPsi2TransportPerturbedDR_(space_slice)
gauss_barotropic_rg_dr_nopert_noalpha_lr.label = "GaussBarotropic - DR - NoPert - NoAlpha"
gauss_barotropic_rg_dr_nopert_noalpha_lr.color = "lightblue"
gauss_barotropic_rg_dr_nopert_noalpha_lr.linestyle = "dashed"
gauss_barotropic_rg_dr_nopert_noalpha_lr.prefix = "results_mixed_rg_ro_ge_nopert_lr_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_nopert_lr = RGPsi2TransportPerturbedDR_(space_slice)
gauss_barotropic_rg_dr_nopert_lr.label = "GaussBarotropic - DR - NoPert"
gauss_barotropic_rg_dr_nopert_lr.color = "navy"
gauss_barotropic_rg_dr_nopert_lr.linestyle = "dashed"
gauss_barotropic_rg_dr_nopert_lr.prefix = "results_mixed_rg_ro_ge_nopert_lr5_gamma0_1_obstrack"

# NOTE(review): configured but not passed to ModelsManagerOBC below — confirm intent.
gauss_barotropic_psi2 = RGPsi2Transport(space_slice)
gauss_barotropic_psi2.label = "Psi2"
gauss_barotropic_psi2.color = "green"
gauss_barotropic_psi2.prefix = "results_psi2"


models = ModelsManagerOBC(
    rg,
    rg_,
    forced,
    forced_,
    gauss_barotropic_rg,
    gauss_barotropic_rg_noalpha,
    gauss_barotropic_rg_dr_nopert_lr,
    gauss_barotropic_rg_dr_nopert_noalpha_lr,
    gauss_barotropic_rg_dr_pert_lr,
    gauss_barotropic_rg_dr_pert_noalpha_lr,
)
models.ijs = (imin,imax,jmin,jmax)
models.save_params = True

# Wind stress restricted to the sub-domain; tx and ty are sliced with different
# extents (presumably staggered grids — TODO confirm).
models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1],ty[imin : imax + 1, jmin:jmax])

# Per-cycle diagnostics: initial windowed psi and its mean tendency over the cycle.
psi0s = {}
dpsis = {}
gc.collect()
for c in range(n_cycles):
    torch.cuda.reset_peak_memory_stats()
    models.new_cycle()
    times = [model_3l.time.item()]

    # 1) Advance the reference model through one cycle, recording the top two
    #    layers of psi on the halo-extended window at every step.
    psi0 = extract_psi_w(model_3l.psi[:,:2])
    psi0s[c] = psi0

    psis = [psi0]

    for _ in range(1, n_steps_per_cyle):
        model_3l.step()

        times.append(model_3l.time.item())

        psi = extract_psi_w(model_3l.psi[:,:2])

        psis.append(psi)

    # Mean rate of change of the windowed psi over the cycle (divided by dt).
    dpsis[c] = (psis[-1]-psis[0])/(n_steps_per_cyle-1)/dt

    # 2) Replay the cycle with every reduced model, scoring each step against the
    #    reference with the p-wide halo cropped off.
    models.reset_time()

    models.setup(psis,times,beta_effect_w)

    models.compute_loss(crop(psis[0],p))
    for n in range(1,n_steps_per_cyle):
        models.step()
        models.compute_loss(crop(psis[n],p))

    torch.cuda.empty_cache()
    gc.collect()

    max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
    msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    logger.info(box(msg_mem, style="round"))

### Results

In [None]:
from matplotlib import pyplot as plt

from qgsw import plots

# Baselines hidden in the second figure. Their visibility is set explicitly for
# each figure so that re-running this cell does not carry `show=False` over into
# the first figure (assumes `show=True` is the fresh-model default — TODO confirm).
baselines = (rg, rg_, forced, forced_)

# Rows to draw, in order: loss name -> row title.
show_grad = True
show_vort = True

panels = {"rmse": "RMSE"}
if show_grad:
    panels["grad_rmse"] = "Gradient RMSE"
if show_vort:
    panels["vorticity_rmse"] = "Vorticity RMSE"

# Figure 1 includes the baselines; figure 2 repeats the panels without them.
for suptitle, show_baselines in ((None, True), ("Without RG & Forced", False)):
    for baseline in baselines:
        baseline.show = show_baselines

    fig, axs = plots.subplots(
        len(panels),
        2,
        gridspec_kw={"width_ratios": [1, 7.5]},
        figsize=(21, 5 * len(panels)),
    )
    if suptitle is not None:
        fig.suptitle(suptitle)
    plots.set_rowtitles(list(panels.values()), axs=axs)
    for row, loss_name in enumerate(panels):
        # Left: initial psi with the sub-domain outlined; right: loss curves.
        map_ax, loss_ax = axs[row, 0], axs[row, 1]
        plots.imshow(psi_start[0, 0], ax=map_ax)
        map_ax.hlines([jmin, jmax], imin, imax, color="black")
        map_ax.vlines([imin, imax], jmin, jmax, color="black")
        models.plot_loss(loss_name=loss_name, ax=loss_ax)
        plots.clamp_ylims(0, 1, loss_ax)
        loss_ax.legend(loc="upper left", prop={"size": 8})

## [112, 176] x [256, 384]

### Model runs

In [None]:
from functools import partial

from qgsw.logging.utils import box, step
from qgsw.utils.reshaping import crop

# Sub-domain [imin, imax] x [jmin, jmax] (inclusive indices on the full grid).
imin, imax = 112, 176
jmin, jmax = 256, 384
# Bind the window bounds now: a lambda would re-read the imin/imax/jmin/jmax
# globals at call time, so a later change of region would silently move it.
extract_psi_w = partial(extract_psi_w_, imin=imin, imax=imax, jmin=jmin, jmax=jmax)


# Restart the reference model from the shared initial condition.
model_3l.reset_time()
model_3l.set_psi(psi_start)

# Grid of the sub-domain itself (used by all reduced models).
space_slice = SpaceDiscretization2D.from_coords(
    x_1d=P.space.remove_h().omega.xy.x[imin : imax + 1, 0],
    y_1d=P.space.remove_h().omega.xy.y[0, jmin : jmax + 1],
)

# Halo-widened grid; here it only serves to compute the beta term on the window.
space_slice_w = SpaceDiscretization2D.from_coords(
    x_1d=P.space.remove_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y_1d=P.space.remove_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

# --- Reduced models compared against the reference ---
rg = ReducedGravity(space_slice)
rg.linestyle="dotted"

rg_ = ReducedGravityPerturbed(space_slice)
# Label belongs to the perturbed model (it was previously set on `rg` by mistake).
rg_.label = "Reduced Gravity - Pert"

forced = ForcedRGDR(space_slice)
forced.label = "Forced"
forced.linestyle = "dotted"
forced.prefix = "results_forced_rg_dr_gamma1000_obstrack"

forced_ = ForcedRGDRPerturbed(space_slice)
forced_.label = "Forced - Pert"
forced_.prefix = "results_forced_rg_dr_pert_gamma1000_obstrack"

gauss_barotropic_rg = RGPsi2Transport(space_slice)
gauss_barotropic_rg.label = "GaussBarotropic"
gauss_barotropic_rg.color = "navy"
gauss_barotropic_rg.linestyle = "dotted"
gauss_barotropic_rg.prefix = "results_mixed_rg_ro_ge_g5_gamma0_1_obstrack"

gauss_barotropic_rg_noalpha = RGPsi2Transport(space_slice)
gauss_barotropic_rg_noalpha.label = "GaussBarotropic - NoAlpha"
gauss_barotropic_rg_noalpha.color = "lightblue"
gauss_barotropic_rg_noalpha.linestyle = "dotted"
gauss_barotropic_rg_noalpha.prefix = "results_mixed_rg_ro_ge_g5_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_pert_noalpha_lr = RGPsi2TransportPerturbedDR(space_slice)
gauss_barotropic_rg_dr_pert_noalpha_lr.label = "GaussBarotropic - DR - Pert - NoAlpha"
gauss_barotropic_rg_dr_pert_noalpha_lr.color = "lightblue"
gauss_barotropic_rg_dr_pert_noalpha_lr.prefix = "results_mixed_rg_ro_ge_pert_lr_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_pert_lr = RGPsi2TransportPerturbedDR(space_slice)
gauss_barotropic_rg_dr_pert_lr.label = "GaussBarotropic - DR - Pert"
gauss_barotropic_rg_dr_pert_lr.color = "navy"
gauss_barotropic_rg_dr_pert_lr.prefix = "results_mixed_rg_ro_ge_pert_lr5_gamma0_1_obstrack"

gauss_barotropic_rg_dr_nopert_noalpha_lr = RGPsi2TransportPerturbedDR_(space_slice)
gauss_barotropic_rg_dr_nopert_noalpha_lr.label = "GaussBarotropic - DR - NoPert - NoAlpha"
gauss_barotropic_rg_dr_nopert_noalpha_lr.color = "lightblue"
gauss_barotropic_rg_dr_nopert_noalpha_lr.linestyle = "dashed"
gauss_barotropic_rg_dr_nopert_noalpha_lr.prefix = "results_mixed_rg_ro_ge_nopert_lr_noalpha_gamma0_1_obstrack"

gauss_barotropic_rg_dr_nopert_lr = RGPsi2TransportPerturbedDR_(space_slice)
gauss_barotropic_rg_dr_nopert_lr.label = "GaussBarotropic - DR - NoPert"
gauss_barotropic_rg_dr_nopert_lr.color = "navy"
gauss_barotropic_rg_dr_nopert_lr.linestyle = "dashed"
gauss_barotropic_rg_dr_nopert_lr.prefix = "results_mixed_rg_ro_ge_nopert_lr5_gamma0_1_obstrack"

# NOTE(review): configured but not passed to ModelsManagerOBC below — confirm intent.
gauss_barotropic_psi2 = RGPsi2Transport(space_slice)
gauss_barotropic_psi2.label = "Psi2"
gauss_barotropic_psi2.color = "green"
gauss_barotropic_psi2.prefix = "results_psi2"


models = ModelsManagerOBC(
    rg,
    rg_,
    forced,
    forced_,
    gauss_barotropic_rg,
    gauss_barotropic_rg_noalpha,
    gauss_barotropic_rg_dr_nopert_lr,
    gauss_barotropic_rg_dr_nopert_noalpha_lr,
    gauss_barotropic_rg_dr_pert_lr,
    gauss_barotropic_rg_dr_pert_noalpha_lr,
)
models.ijs = (imin,imax,jmin,jmax)
models.save_params = True

# Wind stress restricted to the sub-domain; tx and ty are sliced with different
# extents (presumably staggered grids — TODO confirm).
models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1],ty[imin : imax + 1, jmin:jmax])

# Per-cycle diagnostics: initial windowed psi and its mean tendency over the cycle.
psi0s = {}
dpsis = {}
gc.collect()
for c in range(n_cycles):
    torch.cuda.reset_peak_memory_stats()
    models.new_cycle()
    times = [model_3l.time.item()]

    # 1) Advance the reference model through one cycle, recording the top two
    #    layers of psi on the halo-extended window at every step.
    psi0 = extract_psi_w(model_3l.psi[:,:2])
    psi0s[c] = psi0

    psis = [psi0]

    for _ in range(1, n_steps_per_cyle):
        model_3l.step()

        times.append(model_3l.time.item())

        psi = extract_psi_w(model_3l.psi[:,:2])

        psis.append(psi)

    # Mean rate of change of the windowed psi over the cycle (divided by dt).
    dpsis[c] = (psis[-1]-psis[0])/(n_steps_per_cyle-1)/dt

    # 2) Replay the cycle with every reduced model, scoring each step against the
    #    reference with the p-wide halo cropped off.
    models.reset_time()

    models.setup(psis,times,beta_effect_w)

    models.compute_loss(crop(psis[0],p))
    for n in range(1,n_steps_per_cyle):
        models.step()
        models.compute_loss(crop(psis[n],p))

    torch.cuda.empty_cache()
    gc.collect()

    max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
    msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    logger.info(box(msg_mem, style="round"))

In [None]:
from matplotlib import pyplot as plt

from qgsw import plots

# Baselines hidden in the second figure. Their visibility is set explicitly for
# each figure so that re-running this cell does not carry `show=False` over into
# the first figure (assumes `show=True` is the fresh-model default — TODO confirm).
baselines = (rg, rg_, forced, forced_)

# Rows to draw, in order: loss name -> row title.
show_grad = True
show_vort = True

panels = {"rmse": "RMSE"}
if show_grad:
    panels["grad_rmse"] = "Gradient RMSE"
if show_vort:
    panels["vorticity_rmse"] = "Vorticity RMSE"

# Figure 1 includes the baselines; figure 2 repeats the panels without them.
for suptitle, show_baselines in ((None, True), ("Without RG & Forced", False)):
    for baseline in baselines:
        baseline.show = show_baselines

    fig, axs = plots.subplots(
        len(panels),
        2,
        gridspec_kw={"width_ratios": [1, 7.5]},
        figsize=(21, 5 * len(panels)),
    )
    if suptitle is not None:
        fig.suptitle(suptitle)
    plots.set_rowtitles(list(panels.values()), axs=axs)
    for row, loss_name in enumerate(panels):
        # Left: initial psi with the sub-domain outlined; right: loss curves.
        map_ax, loss_ax = axs[row, 0], axs[row, 1]
        plots.imshow(psi_start[0, 0], ax=map_ax)
        map_ax.hlines([jmin, jmax], imin, imax, color="black")
        map_ax.vlines([imin, imax], jmin, jmax, color="black")
        models.plot_loss(loss_name=loss_name, ax=loss_ax)
        plots.clamp_ylims(0, 1, loss_ax)
        loss_ax.legend(loc="upper left", prop={"size": 8})

In [None]:
from qgsw.models.qg.stretching_matrix import compute_deformation_radii

# Deformation radii from the `A` matrix of the perturbed DR model, divided by
# 1000 (presumably m -> km; TODO confirm the unit returned).
(
    compute_deformation_radii(
        gauss_barotropic_rg_dr_pert_lr.model.A,
        beta_plane.f0,
    )
    / 1000
)

In [None]:
# Deformation radii from the `A` matrix of the non-perturbed DR model, / 1000.
(
    compute_deformation_radii(
        gauss_barotropic_rg_dr_nopert_lr.model.A,
        beta_plane.f0,
    )
    / 1000
)

In [None]:
# Deformation radii from the reference model's `A` matrix, / 1000.
(
    compute_deformation_radii(
        model_3l.A,
        beta_plane.f0,
    )
    / 1000
)