In [None]:
# Automatically reload modified modules (e.g. qgsw) before executing each cell.
%load_ext autoreload
%autoreload 2

## Setup

### Libraries

In [None]:
import torch
import seaborn as sns
from qgsw.logging import getLogger, setup_root_logger
from qgsw.specs import defaults
import gc

# Reproducible cuDNN kernels; this notebook only runs models, so autograd is off globally.
torch.backends.cudnn.deterministic = True
torch.set_grad_enabled(mode=False)
sns.set_theme(context="notebook", style="dark")

# Default tensor specs used below (e.g. `specs["device"]`, `tensor.to(**specs)`).
specs = defaults.get()

# Logging: configure the root logger, then get this notebook's logger.
setup_root_logger(1)
logger = getLogger(__name__)

### Parameters

In [None]:
from qgsw.configs.core import Configuration
from qgsw.forcing.wind import WindForcing


config = Configuration.from_toml("../output/g5k/param_optim/_config.toml")

# Physical parameters of the reference configuration.
beta_plane = config.physics.beta_plane
bottom_drag_coef = config.physics.bottom_drag_coefficient
slip_coef = config.physics.slip_coef

# Layer thicknesses H and reduced gravities g', plus per-layer shortcuts.
H, g_prime = config.model.h, config.model.g_prime
H1, H2 = H[0], H[1]
g1, g2 = g_prime[0], g_prime[1]

# Wind stress forcing (tx, ty) on the full domain.
wind = WindForcing.from_config(config.windstress, config.space, config.physics)
tx, ty = wind.compute()

# Width (in grid points) of the margin added around each OBC sub-domain
# (see `extract_psi_w_`, `crop(..., p)` and the `imin - p` slicing below).
p = 4

### Space

In [None]:
from qgsw.spatial.core.discretization import SpaceDiscretization3D


# Full-domain 3D discretization built from the space and model sections of the config.
space = SpaceDiscretization3D.from_config(config.space, config.model)
# Grid spacings reused by every finite-difference operator in this notebook.
dx = space.dx
dy = space.dy

In [None]:
from qgsw.solver.boundary_conditions.base import Boundaries


def get_psi_slices(imin: int, imax: int, jmin: int, jmax: int) -> tuple[slice, slice]:
    """Build the slices selecting psi indices ``imin..imax`` and ``jmin..jmax``.

    Both bounds are inclusive, hence the ``+ 1`` on each stop value.

    Args:
        imin, imax: Inclusive bounds along the first spatial index.
        jmin, jmax: Inclusive bounds along the second spatial index.

    Returns:
        tuple[slice, slice]: ``(slice(imin, imax + 1), slice(jmin, jmax + 1))``,
        usable as ``psi[..., *slices]``.
    """
    return slice(imin, imax + 1), slice(jmin, jmax + 1)

def extract_psi_w_(psi: torch.Tensor,imin:int,imax:int,jmin:int,jmax:int) -> torch.Tensor:
    """Extract psi on the sub-domain ``[imin, imax] x [jmin, jmax]`` widened by ``p`` points.

    The widening is applied on every side; bounds are inclusive.
    """
    i_slice, j_slice = get_psi_slices(imin - p, imax + p, jmin - p, jmax + p)
    return psi[..., i_slice, j_slice]


def extract_psi_bc(psi: torch.Tensor) -> Boundaries:
    """Extract psi boundaries with ``Boundaries.extract`` at offsets ``p`` / ``-p - 1``.

    Presumably this is the boundary of the inner sub-domain once the
    ``p``-wide margin is removed — confirm against ``Boundaries.extract``.
    """
    start, stop = p, -p - 1
    return Boundaries.extract(psi, start, stop, start, stop, 2)

### Simulation

In [None]:
# Model time step (presumably in seconds: 2 * 60 * 60 = 7200 s, i.e. 2 h).
dt = 2 * 60 * 60
# Number of model steps per cycle (name kept as-is: referenced elsewhere).
n_steps_per_cyle = 250
comparison_interval = 1
n_cycles = 3

### Outputs

In [None]:
# Global switch for writing parameter videos; each model also needs `save_video = True`.
save_videos = False

### RMSE

In [None]:
def rmse(f: torch.Tensor, f_ref: torch.Tensor) -> torch.Tensor:
    """Relative RMSE of ``f`` with respect to ``f_ref``.

    Computes ``sqrt(mean((f - f_ref)**2)) / sqrt(mean(f_ref**2))``: the RMS of
    the error normalised by the RMS of the reference field.

    Returns:
        torch.Tensor: 0-d tensor. It is not finite when ``f_ref`` is identically zero.
    """
    error_rms = torch.sqrt(torch.mean((f - f_ref) ** 2))
    reference_rms = torch.sqrt(torch.mean(f_ref ** 2))
    return error_rms / reference_rms

## Initial condition

In [None]:
from qgsw.fields.variables.tuples import UVH
from qgsw.masks import Masks
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.utils import covphys


# QG projector on the full domain without land mask, used to turn the
# start-up (u, v, h) state into a streamfunction.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),
    H=H[..., None, None],
    space=space,
    f0=beta_plane.f0,
    masks=Masks.empty(nx=config.space.nx, ny=config.space.ny),
)
uvh0 = UVH.from_file("../output/g5k/param_optim/_data_startup.pt")
# compute_p presumably returns pressure first; psi = p / f0.
psi_start = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / beta_plane.f0

### Full-domain model

In [None]:
from qgsw.models.qg.psiq.core import QGPSIQ


# Reference model on the full domain using every layer of H (named `_3l`).
model_3l = QGPSIQ(
    space_2d=space.remove_z_h(),
    H=H,
    beta_plane=config.physics.beta_plane,
    g_prime=g_prime,
)
model_3l.set_wind_forcing(tx, ty)
model_3l.masks = Masks.empty_tensor(model_3l.space.nx, model_3l.space.ny, device=specs["device"])
model_3l.bottom_drag_coef, model_3l.slip_coef, model_3l.dt = bottom_drag_coef, slip_coef, dt
# Shared with every OBC model (see ModelWrapperOBC._set_params).
y0 = model_3l.y0

## OBC models

### Base

In [None]:
from pathlib import Path
from typing import TypeVar

from matplotlib import pyplot as plt
import numpy as np

from qgsw.analysis.qg_model import ModelWrapper, ModelsManager
from qgsw.models.qg.psiq.core import QGPSIQCore
from qgsw.spatial.core.discretization import SpaceDiscretization2D

T = TypeVar("T", bound=QGPSIQCore)

class ModelWrapperOBC(ModelWrapper[T]):
    """Base wrapper for models run on an open-boundary (OBC) sub-domain.

    On top of ``ModelWrapper`` it does three things:
    - keeps one list of scalar losses per cycle;
    - loads the per-sub-domain results saved under ``results_paths``;
    - exposes a ``save_param_video`` hook that subclasses may override.
    """

    # Folder containing the "<prefix>_<imin>_<imax>_<jmin>_<jmax>.pt" result files.
    results_paths = Path("../output/g5k/param_optim")
    # losses[c] holds the scalar losses recorded during cycle c (see `add_loss`).
    losses: list[list[float]]
    # Sub-domain indices (imin, imax, jmin, jmax); must be set before `load`.
    ijs: tuple[int, int, int, int] | None = None
    # Per-model video switch, combined with the global `save_videos` flag.
    save_video = False
    # When False, `plot_loss` draws nothing for this model.
    show = True

    def __init__(self, space_2d: SpaceDiscretization2D) -> None:
        super().__init__(space_2d)
        self.losses = []

    def _set_params(self) -> None:
        """Apply the settings shared by all OBC models to ``self.model``.

        These are:
        - ``y0`` from the full-domain model;
        - empty masks;
        - no bottom drag;
        - the configured slip coefficient and time step;
        - ``wide = True``.
        """
        space = self.model.space
        self.model.y0 = y0
        self.model.masks = Masks.empty_tensor(
            space.nx,
            space.ny,
            device=specs["device"],
        )
        # NOTE(review): bottom drag is disabled here while the full-domain model
        # uses `bottom_drag_coef`; presumably intentional since OBC models only
        # keep the upper layer(s) of psi — confirm.
        self.model.bottom_drag_coef = 0
        self.model.wide = True
        self.model.slip_coef = slip_coef
        self.model.dt = dt

    def load(self, imin: int, imax: int, jmin: int, jmax: int) -> dict:
        """Load the saved results of sub-domain ``(imin, imax, jmin, jmax)``.

        The file read is ``results_paths / f"{prefix}_{imin}_{imax}_{jmin}_{jmax}.pt"``.
        Subclasses whose ``prefix`` is None cannot use this method.
        """
        indices = f"_{imin}_{imax}_{jmin}_{jmax}.pt"
        file = self.results_paths.joinpath(self.prefix + indices)
        # torch.load unpickles the file: only use it on trusted, locally produced results.
        return torch.load(file)

    def new_cycle(self) -> None:
        """Start a new cycle and open an empty loss list for it."""
        super().new_cycle()
        self.losses.append([])

    def add_loss(self, loss: float) -> None:
        """Record ``loss`` for the current (most recently started) cycle."""
        self.losses[-1].append(loss)

    def plot_loss(self, *, ax: plt.Axes | None = None, cycle: int | None = None) -> None:
        """Plot the recorded losses.

        Args:
            ax: Axes to draw on; ``plt.gca()`` when None.
            cycle: Index of a single cycle to plot; all cycles are concatenated
                when None.
        """
        if not self.show:
            return
        if cycle is None and not self.losses:
            # Nothing recorded yet: np.concatenate([]) would raise.
            return
        if ax is None:
            ax = plt.gca()
        loss = np.concatenate(self.losses) if cycle is None else np.array(self.losses[cycle])
        ax.plot(loss, **self.plot_kwargs)

    def save_param_video(self) -> None:
        """Hook for writing parameter videos; the base implementation writes nothing."""
        return
        
M = TypeVar("M", bound=ModelWrapperOBC[QGPSIQCore])

class ModelsManagerOBC(ModelsManager[M]):
    """Models manager that forwards OBC-specific operations to every wrapper."""

    def __init__(self, *mw: M) -> None:
        super().__init__(*mw)
        # Propagate the first wrapper's sub-domain indices to all wrappers.
        self.ijs = self.model_wrappers[0].ijs

    @property
    def ijs(self) -> tuple[int, int, int, int]:
        """Sub-domain indices, as held by the first wrapped model."""
        return self.model_wrappers[0].ijs

    @ijs.setter
    def ijs(self, ijs: tuple[int, int, int, int]) -> None:
        def _assign(wrapper: M) -> None:
            wrapper.ijs = ijs

        self.loop_over_models(_assign)

    def compute_loss(self, psi_ref: torch.Tensor) -> None:
        """Record, for each model, the relative RMSE of ``psi[0, 0]`` against ``psi_ref[0, 0]``."""
        reference = psi_ref[0, 0]

        def _record(wrapper: M) -> None:
            error = rmse(wrapper.model.psi[0, 0], reference)
            wrapper.add_loss(error.cpu().item())

        self.loop_over_models(_record)

    def plot_loss(self, *, ax: plt.Axes | None = None, cycle: int | None = None) -> None:
        """Plot the losses of every model on the same axes."""
        def _plot(wrapper: M) -> None:
            wrapper.plot_loss(ax=ax, cycle=cycle)

        self.loop_over_models(_plot)

    def save_param_video(self) -> None:
        """Ask every model to write its parameter video (if enabled)."""
        def _save(wrapper: M) -> None:
            wrapper.save_param_video()

        self.loop_over_models(_save)

### Reduced gravity

In [None]:
from torch import Tensor
from qgsw.solver.finite_diff import laplacian
from qgsw.spatial.core.discretization import SpaceDiscretization2D
from qgsw.spatial.core.grid_conversion import interpolate
from qgsw.utils.interpolation import QuadraticInterpolation
from qgsw.utils.reshaping import crop


class ReducedGravity(ModelWrapperOBC[QGPSIQ]):
    """One-layer reduced-gravity model.

    Its single layer has thickness H1. Its reduced gravity combines the first
    two layers of the reference model: g1 * g2 / (g1 + g2).
    """

    prefix = None
    color = "blue"
    label = "Reduced Gravity"

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        g_top, g_bottom = g_prime[:1], g_prime[1:2]
        self.model = QGPSIQ(
            space_2d=space_2d,
            H=H[:1],
            beta_plane=beta_plane,
            g_prime=g_top * g_bottom / (g_top + g_bottom),
        )
        self._set_params()

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """PV: interpolate(Δψ - f0² (1/(H1 g1) + 1/(H1 g2)) ψ_interior) + beta_effect."""
        stretching = beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2)
        return interpolate(laplacian(psi, dx, dy) - stretching * psi[..., 1:-1, 1:-1]) + beta_effect

    def setup(self, psis: list[torch.Tensor], times: list[torch.Tensor], beta_effect_w: torch.Tensor) -> None:
        """Build psi/q boundary interpolators from ``psis`` and set the initial state.

        Only the top layer of each reference field is used.
        """
        top_layers = [psi[:, :1] for psi in psis]
        psi_bcs = [extract_psi_bc(psi_top) for psi_top in top_layers]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi_top, beta_effect_w), 2, -3, 2, -3, 3)
            for psi_top in top_layers
        ]
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )
        first = top_layers[0]
        self.model.set_psiq(crop(first, p), crop(self.compute_q(first, beta_effect_w), p - 1))

### Two-layers

In [None]:
from torch._tensor import Tensor

from qgsw.utils.reshaping import crop


class TwoLayers(ModelWrapperOBC[QGPSIQ]):
    """Two-layer QG model using the first two layers of the reference configuration."""

    prefix = None
    color = "black"
    label = "Two Layers"

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQ(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """PV of both layers: interpolate(Δψ - f0² A ψ_interior) + beta_effect."""
        stretching = torch.einsum("lm,...mxy->...lxy", self.model.A, psi[..., 1:-1, 1:-1])
        return interpolate(laplacian(psi, dx, dy) - beta_plane.f0**2 * stretching) + beta_effect

    def setup(self, psis: list[Tensor], times: list[torch.Tensor], beta_effect_w: Tensor) -> None:
        """Build psi/q boundary interpolators from the first two layers of ``psis`` and set the initial state."""
        two_layers = [psi[:, :2] for psi in psis]
        psi_bcs = [extract_psi_bc(psi_l) for psi_l in two_layers]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi_l, beta_effect_w), 2, -3, 2, -3, 3)
            for psi_l in two_layers
        ]
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )
        first = two_layers[0]
        self.model.set_psiq(crop(first, p), crop(self.compute_q(first, beta_effect_w), p - 1))

### Alpha

In [None]:
from torch._tensor import Tensor
from qgsw.models.qg.psiq.modified.core import QGPSIQCollinearSF
from qgsw.utils.reshaping import crop


class Alpha(ModelWrapperOBC[QGPSIQCollinearSF]):
    """Single-layer model whose lower-layer streamfunction is ``alpha * psi1``.

    ``alpha`` is loaded per cycle from the ``results_single_alpha`` files.
    """

    prefix = "results_single_alpha"
    color = "red"
    label = "Collinear single"

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQCollinearSF(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Loaded alpha per cycle, filled when `save_params` is set.
        self.alphas = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor, alpha: torch.Tensor) -> Tensor:
        """Top-layer PV with the lower layer given by ``alpha * psi``."""
        f0_sq = beta_plane.f0**2
        interior = psi[..., 1:-1, 1:-1]
        return interpolate(
            laplacian(psi, dx, dy)
            - f0_sq * (1 / H1 / g1 + 1 / H1 / g2) * interior
            + f0_sq * (1 / H1 / g2) * alpha * interior
        ) + beta_effect

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's alpha, then set boundary interpolators and the initial state."""
        results = self.load(*self.ijs)
        alpha = results[self.cycle]["alpha"]
        if self.save_params:
            self.alphas[self.cycle] = alpha
        self.model.alpha = torch.ones_like(self.model.psi) * alpha
        top_layers = [psi[:, :1] for psi in psis]
        psi_bcs = [extract_psi_bc(psi_top) for psi_top in top_layers]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi_top, beta_effect_w, alpha), 2, -3, 2, -3, 3)
            for psi_top in top_layers
        ]
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )
        first = top_layers[0]
        self.model.set_psiq(crop(first, p), crop(self.compute_q(first, beta_effect_w, alpha), p - 1))


### Affine

In [None]:
from torch._tensor import Tensor
from qgsw.models.qg.psiq.modified.core import QGPSIQFixeddSF2
from qgsw.plots.heatmaps import AnimatedHeatmaps
from qgsw.utils.reshaping import crop


class Affine(ModelWrapperOBC[QGPSIQFixeddSF2]):
    """Single-layer model with a prescribed lower-layer streamfunction.

    On the boundaries, the lower layer is ``psi2 + n * dt * dpsi2`` for the n-th
    reference snapshot. ``psi2`` and ``dpsi2`` are loaded per cycle from the
    ``results_psi2`` files.
    """

    prefix = "results_psi2"
    color = "orange"
    label = "Affine"

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQFixeddSF2(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Loaded parameters per cycle, filled when `save_params` is set.
        self.psi2s = {}
        self.dpsi2s = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor, psi2: torch.Tensor) -> Tensor:
        """Top-layer PV with the lower-layer streamfunction ``psi2``."""
        f0_sq = beta_plane.f0**2
        return interpolate(
            laplacian(psi, dx, dy)
            - f0_sq * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
            + f0_sq * (1 / H1 / g2) * psi2[..., 1:-1, 1:-1]
        ) + beta_effect

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's psi2/dpsi2, then set boundary interpolators and the initial state."""
        cycle_results = self.load(*self.ijs)[self.cycle]
        psi2 = cycle_results["psi2"].to(**specs)
        dpsi2 = cycle_results["dpsi2"].to(**specs)
        if self.save_params:
            self.dpsi2s[self.cycle] = dpsi2
            self.psi2s[self.cycle] = psi2
        self.model.dpsi2 = crop(dpsi2, p)
        top_layers = [psi[:, :1] for psi in psis]
        psi_bcs = [extract_psi_bc(psi_top) for psi_top in top_layers]
        # NOTE(review): assumes consecutive reference snapshots are `dt` apart.
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi_top, beta_effect_w, psi2 + n * dt * dpsi2), 2, -3, 2, -3, 3
            )
            for n, psi_top in enumerate(top_layers)
        ]
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )
        first = top_layers[0]
        self.model.set_psiq(crop(first, p), crop(self.compute_q(first, beta_effect_w, psi2), p - 1))

    def save_param_video(self) -> None:
        """No video is written for this model, whatever the video switches say."""
        return

### Mixed

In [None]:
from matplotlib.animation import FuncAnimation
from torch._tensor import Tensor
from qgsw import plots
from qgsw.models.qg.psiq.modified.core import QGPSIQMixed
from qgsw.plots.plt_wrapper import default_clim, retrieve_colorbar, retrieve_imshow_data


class Mixed(ModelWrapperOBC[QGPSIQMixed]):
    """Single-layer model with a mixed parametrisation of the lower layer.

    In the PV, the lower-layer streamfunction is ``psi2 + alpha * psi1``, with
    ``psi2`` drifting in time through ``dpsi2``. The parameters ``alpha``,
    ``psi2`` and ``dpsi2`` are loaded per cycle from ``results_mixed_s`` files.
    """

    prefix = "results_mixed_s"
    color = "violet"
    label = "Mixed Single Alpha"
    save_video = True

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQMixed(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Loaded parameters per cycle, filled when `save_params` is set.
        self.alphas = {}
        self.psi2s = {}
        self.dpsi2s = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor, alpha: torch.Tensor, psi2: torch.Tensor) -> Tensor:
        """Top-layer PV with the lower-layer streamfunction ``psi2 + alpha * psi``."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
            + beta_plane.f0**2 * (1 / H1 / g2) * (psi2[..., 1:-1, 1:-1] + alpha * psi[..., 1:-1, 1:-1])
        ) + beta_effect

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's parameters, then set the initial state and boundary interpolators."""
        res = self.load(*self.ijs)
        alpha: torch.Tensor = res[self.cycle]["alpha"]
        psi2: torch.Tensor = res[self.cycle]["psi2"].to(**specs)
        dpsi2: torch.Tensor = res[self.cycle]["dpsi2"].to(**specs)
        if self.save_params:
            self.alphas[self.cycle] = alpha
            self.dpsi2s[self.cycle] = dpsi2
            self.psi2s[self.cycle] = psi2
        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        # NOTE(review): assumes consecutive reference snapshots are `dt` apart.
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi[:, :1], beta_effect_w, alpha, psi2 + n * dt * dpsi2), 2, -3, 2, -3, 3
            )
            for n, psi in enumerate(psis)
        ]

        self.model.set_psiq(crop(psi0[:, :1], p), crop(self.compute_q(psi0[:, :1], beta_effect_w, alpha, psi2), p - 1))
        self.model.alpha = torch.ones_like(self.model.psi) * alpha
        self.model.psi2 = crop(psi2, p)
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))
        self.model.dpsi2 = crop(dpsi2, p)
        if self.save_video and save_videos:
            # Videos show psi2 without the alpha * psi1 contribution.
            self.psi2_fields = [self.model.psi2]
            self.psi1_fields = [self.model.psi[:, :1]]

    def step(self) -> None:
        """Advance the model one step and record video frames if enabled."""
        super().step()
        if self.save_video and save_videos:
            self.psi2_fields.append(
                self.model.psi2 + self.model.time.item() * self.model.dpsi2
            )
            self.psi1_fields.append(self.model.psi[:, :1])

    def save_param_video(self) -> None:
        """Write an mp4 of psi1 and psi2 for the current cycle, when videos are enabled."""
        if not (self.save_video and save_videos):
            return
        imin, imax, jmin, jmax = self.ijs
        output_folder = Path(f"../output/videos/{self.prefix}")
        output_folder.mkdir(parents=True, exist_ok=True)

        fig, axs = plots.subplots(1, 2)
        try:
            im1 = plots.imshow(self.psi1_fields[0][0, 0], title="ѱ₁", ax=axs[0, 0])
            im2 = plots.imshow(self.psi2_fields[0][0, 0], title="ѱ₂", ax=axs[0, 1])

            def _refresh(image, data) -> None:
                # Replace the image data and rescale both its colour limits and colorbar.
                data_array = retrieve_imshow_data(data)
                image.set_array(data_array)
                image.set_clim(*default_clim(data_array))
                retrieve_colorbar(image, image.axes).update_normal(image)

            def func(frame: int) -> tuple:
                _refresh(im1, self.psi1_fields[frame][0, 0])
                _refresh(im2, self.psi2_fields[frame][0, 0])
                fig.canvas.draw_idle()
                return im1, im2

            FuncAnimation(fig, func, frames=len(self.psi1_fields), blit=True).save(
                output_folder.joinpath(f"{imin}_{imax}_{jmin}_{jmax}_{self.cycle}.mp4"),
                fps=25,
            )
        finally:
            # One video per model and per cycle: release each figure to avoid leaking memory.
            plt.close(fig)

### Forced

In [None]:
from qgsw.decomposition.wavelets import WaveletBasis
from qgsw.models.qg.psiq.modified.forced import QGPSIQForced


class ForcedRG(ModelWrapperOBC[QGPSIQForced]):
    """Reduced-gravity layer driven by a wavelet-parametrised forcing.

    The wavelet coefficients are loaded per cycle from ``results_forced_rg``
    files. At each step, the forcing is the wavelet field evaluated at the
    model time.
    """

    prefix = "results_forced_rg"
    color = "brown"
    label = "Forced"
    save_video = True

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQForced(
            space_2d=space_2d,
            H=H[:1],
            beta_plane=beta_plane,
            g_prime=g_prime[:1] * g_prime[1:2] / (g_prime[:1] + g_prime[1:2]),
        )
        self._set_params()
        # Loaded wavelet coefficients per cycle, filled when `save_params` is set.
        self.coefs = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """Reduced-gravity PV: interpolate(Δψ - f0² (1/(H1 g1) + 1/(H1 g2)) ψ_interior) + beta_effect."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
        ) + beta_effect

    def _parse_coefs(self, res: dict) -> dict[int, torch.Tensor]:
        """Collect the ``coefs_<order>`` entries of ``res`` as ``{order: tensor}`` on ``specs``."""
        prefix = "coefs_"
        return {
            int(key[len(prefix):]): value.to(**specs)
            for key, value in res.items()
            if key.startswith(prefix)
        }

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Build the wavelet forcing for this cycle, then set the initial state and boundary interpolators."""
        res = self.load(*self.ijs)
        wv_space = res[self.cycle]["config"]["wv_space"]
        wv_time = res[self.cycle]["config"]["wv_time"]
        self.basis = WaveletBasis(
            wv_space, wv_time,
        )
        self.basis.n_theta = res[self.cycle]["config"]["n_theta"]
        coefs: dict[int, torch.Tensor] = self._parse_coefs(res[self.cycle])
        if self.save_params:
            self.coefs[self.cycle] = coefs
        self.basis.set_coefs(coefs)
        # Forcing evaluated on the model's q-grid points.
        self.wv = self.basis.localize(
            self.model.space.remove_z_h().q.xy.x,
            self.model.space.remove_z_h().q.xy.y,
        )
        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi[:, :1], beta_effect_w), 2, -3, 2, -3, 3
            )
            for psi in psis
        ]

        self.model.set_psiq(crop(psi0[:, :1], p), crop(self.compute_q(psi0[:, :1], beta_effect_w), p - 1))
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))
        if self.save_video and save_videos:
            self.psi1_fields = [self.model.psi[:, :1]]

    def step(self) -> None:
        """Update the forcing at the current model time, step, and record a frame if enabled."""
        self.model.forcing = self.wv(self.model.time)
        super().step()
        if self.save_video and save_videos:
            self.psi1_fields.append(self.model.psi[:, :1])

    def save_param_video(self) -> None:
        """Write an mp4 of psi1 and of the forcing for the current cycle, when videos are enabled."""
        if not (self.save_video and save_videos):
            return
        imin, imax, jmin, jmax = self.ijs
        # psi1_fields holds the initial state plus one entry per step. Evaluate the
        # forcing for every recorded frame so that both panels have the same length.
        eps_fields = [
            self.wv(torch.tensor([n * self.model.dt], **specs))
            for n in range(len(self.psi1_fields))
        ]
        output_folder = Path(f"../output/videos/{self.prefix}")
        output_folder.mkdir(parents=True, exist_ok=True)

        fig, axs = plots.subplots(1, 2)
        try:
            im1 = plots.imshow(self.psi1_fields[0][0, 0], title="ѱ₁", ax=axs[0, 0])
            im2 = plots.imshow(eps_fields[0], title="ε", ax=axs[0, 1])

            def _refresh(image, data) -> None:
                # Replace the image data and rescale both its colour limits and colorbar.
                data_array = retrieve_imshow_data(data)
                image.set_array(data_array)
                image.set_clim(*default_clim(data_array))
                retrieve_colorbar(image, image.axes).update_normal(image)

            def func(frame: int) -> tuple:
                _refresh(im1, self.psi1_fields[frame][0, 0])
                _refresh(im2, eps_fields[frame])
                fig.canvas.draw_idle()
                return im1, im2

            FuncAnimation(fig, func, frames=len(self.psi1_fields), blit=True).save(
                output_folder.joinpath(f"{imin}_{imax}_{jmin}_{jmax}_{self.cycle}.mp4"),
                fps=25,
            )
        finally:
            # One video per model and per cycle: release each figure to avoid leaking memory.
            plt.close(fig)

### Forced MD

In [None]:
from qgsw.decomposition.wavelets import WaveletBasis
from qgsw.models.qg.psiq.modified.forced import QGPSIQForcedMDWV


class ForcedMD(ModelWrapperOBC[QGPSIQForcedMDWV]):
    """Single-layer model whose lower-layer streamfunction is a wavelet field.

    The wavelet coefficients are loaded per cycle from ``results_forced_md``
    files. The wavelets are evaluated on the psi points of the widened
    sub-domain.
    """

    prefix = "results_forced_md"
    color = "green"
    label = "Forced MD"
    save_video = True

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQForcedMDWV(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Loaded wavelet coefficients per cycle, filled when `save_params` is set.
        self.coefs = {}

    def _parse_coefs(self, res: dict) -> dict[int, torch.Tensor]:
        """Collect the ``coefs_<order>`` entries of ``res`` as ``{order: tensor}`` on ``specs``."""
        prefix = "coefs_"
        return {
            int(key[len(prefix):]): value.to(**specs)
            for key, value in res.items()
            if key.startswith(prefix)
        }

    def _widened_space(self):
        """2D space of the sub-domain ``self.ijs`` widened by ``p`` points on each side."""
        imin, imax, jmin, jmax = self.ijs
        return space.remove_z_h().slice(imin - p, imax + p + 1, jmin - p, jmax + p + 1)

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Build the wavelet basis for this cycle, then set the initial state and boundary interpolators."""
        res = self.load(*self.ijs)
        space_ww = self._widened_space()
        wv_space = res[self.cycle]["config"]["wv_space"]
        wv_time = res[self.cycle]["config"]["wv_time"]
        self.basis = WaveletBasis(
            wv_space, wv_time,
        )
        self.basis.n_theta = res[self.cycle]["config"]["n_theta"]
        coefs: dict[int, torch.Tensor] = self._parse_coefs(res[self.cycle])
        if self.save_params:
            self.coefs[self.cycle] = coefs
        self.basis.set_coefs(coefs)
        wv_loc = self.basis.localize(space_ww.psi.xy.x, space_ww.psi.xy.y)
        self.model.wavelets = self.basis
        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        # NOTE(review): assumes consecutive reference snapshots are `model.dt` apart.
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi[:, :1], wv_loc(torch.tensor([n * self.model.dt], **specs)), beta_effect_w), 2, -3, 2, -3, 3
            )
            for n, psi in enumerate(psis)
        ]

        self.model.set_psiq(crop(psi0[:, :1], p), crop(self.compute_q(psi0[:, :1], wv_loc(torch.tensor([0], **specs)), beta_effect_w), p - 1))
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))
        if self.save_video and save_videos:
            self.psi1_fields = [self.model.psi[:, :1]]

    def compute_q(self, psi: Tensor, psi2: torch.Tensor, beta_effect: torch.Tensor) -> Tensor:
        """Top-layer PV with the lower-layer streamfunction ``psi2``."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
            + beta_plane.f0**2 * (1 / H1 / g2) * psi2[..., 1:-1, 1:-1]
        ) + beta_effect

    def step(self) -> None:
        """Advance the model one step and record a video frame if enabled."""
        super().step()
        if self.save_video and save_videos:
            self.psi1_fields.append(self.model.psi[:, :1])

    def save_param_video(self) -> None:
        """Write an mp4 of psi1 and of the wavelet psi2 for the current cycle, when videos are enabled."""
        if not (self.save_video and save_videos):
            return
        imin, imax, jmin, jmax = self.ijs
        space_slice = self._widened_space()
        wv = self.model.wavelets.localize(space_slice.psi.xy.x, space_slice.psi.xy.y)
        # psi1_fields holds the initial state plus one entry per step. Evaluate psi2
        # for every recorded frame so that both panels have the same length.
        psi2_fields = [
            wv(torch.tensor([n * self.model.dt], **specs))
            for n in range(len(self.psi1_fields))
        ]
        output_folder = Path(f"../output/videos/{self.prefix}")
        output_folder.mkdir(parents=True, exist_ok=True)

        fig, axs = plots.subplots(1, 2)
        try:
            im1 = plots.imshow(self.psi1_fields[0][0, 0], title="ѱ₁", ax=axs[0, 0])
            im2 = plots.imshow(psi2_fields[0], title="ѱ₂", ax=axs[0, 1])

            def _refresh(image, data) -> None:
                # Replace the image data and rescale both its colour limits and colorbar.
                data_array = retrieve_imshow_data(data)
                image.set_array(data_array)
                image.set_clim(*default_clim(data_array))
                retrieve_colorbar(image, image.axes).update_normal(image)

            def func(frame: int) -> tuple:
                _refresh(im1, self.psi1_fields[frame][0, 0])
                _refresh(im2, psi2_fields[frame])
                fig.canvas.draw_idle()
                return im1, im2

            FuncAnimation(fig, func, frames=len(self.psi1_fields), blit=True).save(
                output_folder.joinpath(f"{imin}_{imax}_{jmin}_{jmax}_{self.cycle}.mp4"),
                fps=25,
            )
        finally:
            # One video per model and per cycle: release each figure to avoid leaking memory.
            plt.close(fig)

### Forced Col MD

In [None]:
from qgsw.decomposition.wavelets import WaveletBasis
from qgsw.models.qg.psiq.modified.forced import QGPSIQForcedMDWV
from qgsw.pv import compute_q1_interior


class ForcedColMD(ModelWrapperOBC[QGPSIQForcedMDWV]):
    """Single-layer model whose lower-layer streamfunction is ``wavelets + alpha * psi1``.

    ``alpha`` and the wavelet coefficients are loaded per cycle from
    ``results_forced_colmd`` files.
    """

    prefix = "results_forced_colmd"
    color = "yellowgreen"
    label = "Forced Col MD"
    save_video = True

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        self.model = QGPSIQForcedMDWV(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Loaded parameters per cycle, filled when `save_params` is set.
        self.coefs = {}
        self.alphas = {}

    def _parse_coefs(self, res: dict) -> dict[int, torch.Tensor]:
        """Collect the ``coefs_<order>`` entries of ``res`` as ``{order: tensor}`` on ``specs``."""
        prefix = "coefs_"
        return {
            int(key[len(prefix):]): value.to(**specs)
            for key, value in res.items()
            if key.startswith(prefix)
        }

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's alpha and wavelets, then set the initial state and boundary interpolators."""
        res = self.load(*self.ijs)
        imin, imax, jmin, jmax = self.ijs
        # Widened sub-domain: the OBC sub-domain plus `p` points on each side.
        space_ww = space.remove_z_h().slice(imin - p, imax + p + 1, jmin - p, jmax + p + 1)
        wv_space = res[self.cycle]["config"]["wv_space"]
        wv_time = res[self.cycle]["config"]["wv_time"]
        self.basis = WaveletBasis(
            wv_space, wv_time,
        )
        self.basis.n_theta = res[self.cycle]["config"]["n_theta"]
        coefs: dict[int, torch.Tensor] = self._parse_coefs(res[self.cycle])
        alpha: float = res[self.cycle]["alpha"]
        if self.save_params:
            self.coefs[self.cycle] = coefs
            self.alphas[self.cycle] = alpha
        self.basis.set_coefs(coefs)
        self.wv = self.basis.localize(space_ww.psi.xy.x, space_ww.psi.xy.y)
        self.model.wavelets = self.basis
        self.model.alpha = torch.ones_like(self.model.psi) * alpha
        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        # NOTE(review): assumes consecutive reference snapshots are `model.dt` apart.
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi[:, :1], self.wv(torch.tensor([n * self.model.dt], **specs)) + alpha * psi[:, :1], beta_effect_w), 2, -3, 2, -3, 3
            )
            for n, psi in enumerate(psis)
        ]

        self.model.set_psiq(crop(psi0[:, :1], p), crop(self.compute_q(psi0[:, :1], self.wv(torch.tensor([0], **specs)) + alpha * psi0[:, :1], beta_effect_w), p - 1))
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))
        if self.save_video and save_videos:
            # Videos show only the wavelet part of psi2 (alpha * psi1 not included).
            self.psi2_fields = [crop(self.wv(self.model.time), p)]
            self.psi1_fields = [self.model.psi[:, :1]]

    def compute_q(self, psi: Tensor, psi2: torch.Tensor, beta_effect: torch.Tensor) -> Tensor:
        """Top-layer PV from ``psi`` and the lower-layer streamfunction ``psi2``.

        Delegates to ``qgsw.pv.compute_q1_interior``. It replaced an inline
        expression: interpolate(Δψ - f0² (1/(H1 g1) + 1/(H1 g2)) ψ + f0²/(H1 g2) ψ2)
        + beta_effect, taken on interior points.
        """
        return compute_q1_interior(
            psi,
            psi2,
            H1,
            g1,
            g2,
            dx,
            dy,
            beta_plane.f0,
            beta_effect,
        )

    def step(self) -> None:
        """Advance the model one step and record video frames if enabled."""
        super().step()
        if self.save_video and save_videos:
            self.psi2_fields.append(
                crop(self.wv(self.model.time), p)
            )
            self.psi1_fields.append(self.model.psi[:, :1])

    def save_param_video(self) -> None:
        """Write an mp4 of psi1 and of the wavelet psi2 for the current cycle, when videos are enabled."""
        if not (self.save_video and save_videos):
            return
        imin, imax, jmin, jmax = self.ijs
        output_folder = Path(f"../output/videos/{self.prefix}")
        output_folder.mkdir(parents=True, exist_ok=True)

        fig, axs = plots.subplots(1, 2)
        try:
            im1 = plots.imshow(self.psi1_fields[0][0, 0], title="ѱ₁", ax=axs[0, 0])
            im2 = plots.imshow(self.psi2_fields[0], title="ѱ₂", ax=axs[0, 1])

            def _refresh(image, data) -> None:
                # Replace the image data and rescale both its colour limits and colorbar.
                data_array = retrieve_imshow_data(data)
                image.set_array(data_array)
                image.set_clim(*default_clim(data_array))
                retrieve_colorbar(image, image.axes).update_normal(image)

            def func(frame: int) -> tuple:
                _refresh(im1, self.psi1_fields[frame][0, 0])
                _refresh(im2, self.psi2_fields[frame])
                fig.canvas.draw_idle()
                return im1, im2

            FuncAnimation(fig, func, frames=len(self.psi1_fields), blit=True).save(
                output_folder.joinpath(f"{imin}_{imax}_{jmin}_{jmax}_{self.cycle}.mp4"),
                fps=25,
            )
        finally:
            # One video per model and per cycle: release each figure to avoid leaking memory.
            plt.close(fig)

### Forced RG MD

In [None]:
from qgsw.decomposition.wavelets import WaveletBasis
from qgsw.models.qg.psiq.modified.forced import QGPSIQForcedRGMDWV


class ForcedRGMD(ModelWrapperOBC[QGPSIQForcedRGMDWV]):
    """Wrapper around ``QGPSIQForcedRGMDWV`` driven by a wavelet forcing.

    For every cycle, ``setup`` loads the wavelet configuration and the
    ``coefs_<order>`` coefficients saved for the current sub-domain and cycle,
    attaches them to the model as a ``WaveletBasis`` and initializes the
    first-layer ψ/q state and boundary interpolators from reference fields.
    """

    prefix = "results_forced_rgmd"
    color = "red"
    label = "Forced RG MD"
    save_video = True

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        """Build the two-layer ``QGPSIQForcedRGMDWV`` model on ``space_2d``."""
        self.model = QGPSIQForcedRGMDWV(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Per-cycle wavelet coefficients, filled by ``setup`` when
        # ``save_params`` is enabled.
        self.coefs = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """Potential vorticity associated with ``psi``.

        Laplacian of ``psi`` minus the stretching term
        ``f0² (1/(H1 g1) + 1/(H1 g2)) ψ`` on interior points, interpolated,
        plus ``beta_effect``.
        """
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
        ) + beta_effect

    def _parse_coefs(self, res: dict) -> dict[int, torch.Tensor]:
        """Collect ``coefs_<order>`` entries of ``res`` into ``{order: tensor}``.

        Other keys (e.g. ``"config"``) are ignored; tensors are moved to the
        notebook-wide ``specs`` (device/dtype).
        """
        prefix = "coefs_"
        return {
            int(key[len(prefix):]): value.to(**specs)
            for key, value in res.items()
            if key.startswith(prefix)
        }

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Prepare the model for the current cycle.

        Args:
            psis: Reference stream functions over the padded window, one per
                entry of ``times``; only the first layer (``[:, :1]``) is used.
            times: Times associated with ``psis``.
            beta_effect_w: Beta-effect term on the padded window, added to q.
        """
        cycle_res = self.load(*self.ijs)[self.cycle]
        config = cycle_res["config"]
        self.basis = WaveletBasis(config["wv_space"], config["wv_time"])
        self.basis.n_theta = config["n_theta"]
        coefs: dict[int, torch.Tensor] = self._parse_coefs(cycle_res)
        if self.save_params:
            self.coefs[self.cycle] = coefs
        self.basis.set_coefs(coefs)
        self.model.wavelets = self.basis

        psi0 = psis[0]
        # Boundary values of ψ and q at every reference time; they are wrapped
        # in quadratic interpolators over ``times`` below.
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi[:, :1], beta_effect_w), 2, -3, 2, -3, 3)
            for psi in psis
        ]

        self.model.set_psiq(
            crop(psi0[:, :1], p),
            crop(self.compute_q(psi0[:, :1], beta_effect_w), p - 1),
        )
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )
        if self.save_video and save_videos:
            self.psi1_fields = [self.model.psi[:, :1]]

    def step(self) -> None:
        """Advance the model one step, recording ψ₁ when videos are enabled."""
        super().step()
        if self.save_video and save_videos:
            self.psi1_fields.append(self.model.psi[:, :1])

    def save_param_video(self) -> None:
        """Render the cycle's ψ₁ frames and the wavelet ψ₂ field to an MP4.

        Does nothing unless both ``self.save_video`` and the notebook-level
        ``save_videos`` switch are enabled. ψ₂ is evaluated from the localized
        wavelet basis at ``n * dt`` for each step of the cycle and cropped by
        ``p`` so that both panels cover the same domain as ``psi1_fields``.
        """
        if not (self.save_video and save_videos):
            return
        from matplotlib import pyplot as plt

        imin, imax, jmin, jmax = self.ijs
        space_slice = space.remove_z_h().slice(imin - p, imax + p + 1, jmin - p, jmax + p + 1)
        wv = self.model.wavelets.localize(space_slice.psi.xy.x, space_slice.psi.xy.y)
        # The wavelet field lives on the p-padded slice; crop it like the model
        # state (and like ForcedColRGMD does) so the panels are comparable.
        psi2s = [
            crop(wv(torch.tensor([n * self.model.dt], **specs)), p)
            for n in range(n_steps_per_cyle)
        ]
        output_folder = Path(f"../output/videos/{self.prefix}")
        output_folder.mkdir(parents=True, exist_ok=True)

        fig, axs = plots.subplots(1, 2)
        im1 = plots.imshow(self.psi1_fields[0][0, 0], title="ѱ₁", ax=axs[0, 0])
        im2 = plots.imshow(psi2s[0], title="ѱ₂", ax=axs[0, 1])

        def update_image(im, data) -> None:
            """Swap ``im``'s data and rescale its colour limits and colorbar."""
            data_array = retrieve_imshow_data(data)
            im.set_array(data_array)
            im.set_clim(*default_clim(data_array))
            retrieve_colorbar(im, im.axes).update_normal(im)

        def func(frame: int) -> tuple:
            update_image(im1, self.psi1_fields[frame][0, 0])
            update_image(im2, psi2s[frame])
            fig.canvas.draw_idle()
            return im1, im2

        # Never request more frames than both sequences provide.
        n_frames = min(len(self.psi1_fields), len(psi2s))
        try:
            FuncAnimation(fig, func, frames=n_frames, blit=True).save(
                output_folder.joinpath(f"{imin}_{imax}_{jmin}_{jmax}_{self.cycle}.mp4"),
                fps=25,
            )
        finally:
            # Release the figure so it is not kept alive by pyplot or shown
            # inline at the end of the cell.
            plt.close(fig)

### Forced Col RG MD

In [None]:
from qgsw.decomposition.wavelets import WaveletBasis
from qgsw.models.qg.psiq.modified.forced import QGPSIQForcedRGMDWV


class ForcedColRGMD(ModelWrapperOBC[QGPSIQForcedRGMDWV]):
    """Wavelet-forced ``QGPSIQForcedRGMDWV`` wrapper with a per-cycle ``alpha``.

    For every cycle, ``setup`` loads the wavelet configuration, the
    ``coefs_<order>`` coefficients and the scalar ``alpha`` saved for the
    current sub-domain and cycle. It sets the model's ``alpha`` field to that
    value everywhere and initializes the first-layer ψ/q state and boundary
    interpolators from reference fields.
    """

    prefix = "results_forced_colrgmd"
    color = "coral"
    # Must differ from ForcedRGMD's "Forced RG MD" so the two wrappers can be
    # told apart in loss legends and column titles.
    label = "Forced Col RG MD"
    save_video = True

    def _init_model(self, space_2d: SpaceDiscretization2D) -> None:
        """Build the two-layer ``QGPSIQForcedRGMDWV`` model on ``space_2d``."""
        self.model = QGPSIQForcedRGMDWV(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Per-cycle parameters, filled by ``setup`` when ``save_params`` is on.
        self.coefs = {}
        self.alphas = {}

    def compute_q(self, psi: Tensor, beta_effect: torch.Tensor) -> Tensor:
        """Potential vorticity associated with ``psi``.

        Laplacian of ``psi`` minus the stretching term
        ``f0² (1/(H1 g1) + 1/(H1 g2)) ψ`` on interior points, interpolated,
        plus ``beta_effect``.
        """
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
        ) + beta_effect

    def _parse_coefs(self, res: dict) -> dict[int, torch.Tensor]:
        """Collect ``coefs_<order>`` entries of ``res`` into ``{order: tensor}``.

        Other keys (e.g. ``"config"``, ``"alpha"``) are ignored; tensors are
        moved to the notebook-wide ``specs`` (device/dtype).
        """
        prefix = "coefs_"
        return {
            int(key[len(prefix):]): value.to(**specs)
            for key, value in res.items()
            if key.startswith(prefix)
        }

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Prepare the model for the current cycle.

        Args:
            psis: Reference stream functions over the padded window, one per
                entry of ``times``; only the first layer (``[:, :1]``) is used.
            times: Times associated with ``psis``.
            beta_effect_w: Beta-effect term on the padded window, added to q.
        """
        cycle_res = self.load(*self.ijs)[self.cycle]
        config = cycle_res["config"]
        alpha = cycle_res["alpha"]
        self.basis = WaveletBasis(config["wv_space"], config["wv_time"])
        self.basis.n_theta = config["n_theta"]
        coefs: dict[int, torch.Tensor] = self._parse_coefs(cycle_res)
        if self.save_params:
            self.coefs[self.cycle] = coefs
            self.alphas[self.cycle] = alpha
        self.basis.set_coefs(coefs)
        self.model.wavelets = self.basis

        # Localize the basis on the p-padded sub-domain; used for video frames.
        imin, imax, jmin, jmax = self.ijs
        space_slice = space.remove_z_h().slice(imin - p, imax + p + 1, jmin - p, jmax + p + 1)
        self.wv = self.model.wavelets.localize(space_slice.psi.xy.x, space_slice.psi.xy.y)
        self.model.alpha = torch.ones_like(self.model.psi) * alpha

        psi0 = psis[0]
        # Boundary values of ψ and q at every reference time; they are wrapped
        # in quadratic interpolators over ``times`` below.
        psi_bcs = [extract_psi_bc(psi[:, :1]) for psi in psis]
        q_bcs = [
            Boundaries.extract(self.compute_q(psi[:, :1], beta_effect_w), 2, -3, 2, -3, 3)
            for psi in psis
        ]

        self.model.set_psiq(
            crop(psi0[:, :1], p),
            crop(self.compute_q(psi0[:, :1], beta_effect_w), p - 1),
        )
        self.model.set_boundary_maps(
            QuadraticInterpolation(times, psi_bcs),
            QuadraticInterpolation(times, q_bcs),
        )
        if self.save_video and save_videos:
            self.psi2_fields = [crop(self.wv(self.model.time), p)]
            self.psi1_fields = [self.model.psi[:, :1]]

    def step(self) -> None:
        """Advance the model one step, recording ψ₁ and the wavelet ψ₂ frames."""
        super().step()
        if self.save_video and save_videos:
            self.psi2_fields.append(crop(self.wv(self.model.time), p))
            self.psi1_fields.append(self.model.psi[:, :1])

    def save_param_video(self) -> None:
        """Render the recorded ψ₁ / ψ₂ frames of the current cycle to an MP4.

        Does nothing unless both ``self.save_video`` and the notebook-level
        ``save_videos`` switch are enabled. The file is written to
        ``../output/videos/<prefix>/<imin>_<imax>_<jmin>_<jmax>_<cycle>.mp4``.
        """
        if not (self.save_video and save_videos):
            return
        from matplotlib import pyplot as plt

        imin, imax, jmin, jmax = self.ijs
        output_folder = Path(f"../output/videos/{self.prefix}")
        output_folder.mkdir(parents=True, exist_ok=True)

        fig, axs = plots.subplots(1, 2)
        im1 = plots.imshow(self.psi1_fields[0][0, 0], title="ѱ₁", ax=axs[0, 0])
        im2 = plots.imshow(self.psi2_fields[0], title="ѱ₂", ax=axs[0, 1])

        def update_image(im, data) -> None:
            """Swap ``im``'s data and rescale its colour limits and colorbar."""
            data_array = retrieve_imshow_data(data)
            im.set_array(data_array)
            im.set_clim(*default_clim(data_array))
            retrieve_colorbar(im, im.axes).update_normal(im)

        def func(frame: int) -> tuple:
            update_image(im1, self.psi1_fields[frame][0, 0])
            update_image(im2, self.psi2_fields[frame])
            fig.canvas.draw_idle()
            return im1, im2

        # Never request more frames than both sequences provide.
        n_frames = min(len(self.psi1_fields), len(self.psi2_fields))
        try:
            FuncAnimation(fig, func, frames=n_frames, blit=True).save(
                output_folder.joinpath(f"{imin}_{imax}_{jmin}_{jmax}_{self.cycle}.mp4"),
                fps=25,
            )
        finally:
            # Release the figure so it is not kept alive by pyplot or shown
            # inline at the end of the cell.
            plt.close(fig)

## [32, 96] x [64, 192]

### Model runs

In [None]:
from pathlib import Path

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from qgsw import plots
from qgsw.logging.utils import box, step
from qgsw.plots.plt_wrapper import retrieve_colorbar
from qgsw.utils.reshaping import crop

imin, imax = 32, 96
jmin, jmax = 64, 192
# Extract the [imin, imax] x [jmin, jmax] window padded by p on each side.
extract_psi_w = lambda psi: extract_psi_w_(psi,imin,imax,jmin,jmax)


model_3l.reset_time()
model_3l.set_psi(psi_start)

space_slice = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
)

space_slice_w = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

rg = ReducedGravity(space_slice)
tl = TwoLayers(space_slice)
alpha = Alpha(space_slice)
affine = Affine(space_slice)
mixed = Mixed(space_slice)
mixed_ro = Mixed(space_slice)
mixed_ro.color = "blueviolet"
mixed_ro.label = "Mixed - Exp field"
mixed_ro.prefix = "results_mixed_ro"
forced_rg = ForcedRG(space_slice)
forced_md = ForcedMD(space_slice)
forced_rgmd = ForcedRGMD(space_slice)
forced_colmd = ForcedColMD(space_slice)
forced_colrgmd = ForcedColRGMD(space_slice)

models = ModelsManagerOBC(
    # rg,
    # tl,
    # alpha,
    # single_alpha,
    # affine,
    mixed_ro,
    mixed,
    forced_rg,
    forced_md,
    forced_rgmd,
    forced_colmd,
    forced_colrgmd,
)
models.ijs = (imin,imax,jmin,jmax)
models.save_params = True

models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1],ty[imin : imax + 1, jmin:jmax])

psi0s = {}
dpsis = {}

for c in range(n_cycles):
    # Peak-memory statistics only exist on CUDA; skip the reset elsewhere.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    models.new_cycle()
    times = [model_3l.time.item()]

    # Run the reference model over one cycle, keeping the padded window.
    psi0 = extract_psi_w(model_3l.psi[:,:2])
    psi0s[c] = psi0

    psis = [psi0]

    for _ in range(1, n_steps_per_cyle):
        model_3l.step()

        times.append(model_3l.time.item())

        psi = extract_psi_w(model_3l.psi[:,:2])

        psis.append(psi)
    if save_videos:
        output_folder = Path("../output/videos/model_3l")
        output_folder.mkdir(parents=True, exist_ok=True)

        fig, axs = plots.subplots(1,2)

        im1=plots.imshow(psis[0][0,0],title="ѱ₁", ax=axs[0,0])
        im2 =plots.imshow(psis[0][0,1],title="ѱ₂",ax=axs[0,1])

        def func(frame: int) -> tuple:

            data1 = psis[frame][0,0]
            data2 = psis[frame][0,1]

            data_array = retrieve_imshow_data(data1)
            im1.set_array(data_array)
            clim = default_clim(data_array)
            im1.set_clim(*clim)
            retrieve_colorbar(im1, im1.axes).update_normal(im1)

            data_array = retrieve_imshow_data(data2)
            im2.set_array(data_array)
            clim = default_clim(data_array)
            im2.set_clim(*clim)
            retrieve_colorbar(im2, im2.axes).update_normal(im2)

            fig.canvas.draw_idle()

            return im1, im2

        FuncAnimation(fig,func, frames=len(psis),blit=True).save(output_folder.joinpath(f"{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"), fps=25)
        # Close the per-cycle figure so it is not kept alive or shown inline.
        plt.close(fig)

    # Mean rate of change of ψ over the cycle ((n_steps - 1) intervals of dt).
    dpsis[c] = (psis[-1]-psis[0])/(n_steps_per_cyle-1)/dt

    models.reset_time()

    models.setup(psis,times,beta_effect_w)

    models.compute_loss(crop(psis[0],p))
    for n in range(1,n_steps_per_cyle):
        models.step()
        models.compute_loss(crop(psis[n],p))
    models.save_param_video()

    # Collect unreachable objects first so their tensors go back to PyTorch's
    # caching allocator, then release the cache.
    gc.collect()
    torch.cuda.empty_cache()

    max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
    msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    logger.info(box(msg_mem, style="round"))

### Results

In [None]:
from matplotlib import pyplot as plt

from qgsw import plots

# Left panel: first layer of `psi_start` with the sub-domain outlined.
# Right panel: loss curves drawn by `models.plot_loss`.
fig, axs = plots.subplots(
    1,
    2,
    gridspec_kw={"width_ratios": [1, 5]},
    figsize=(15, 4),
)
ax_domain = axs[0, 0]
ax_loss = axs[0, 1]

plots.imshow(psi_start[0, 0], ax=ax_domain)
ax_domain.hlines([jmin, jmax], imin, imax, color="black")
ax_domain.vlines([imin, imax], jmin, jmax, color="black")

models.plot_loss(ax=ax_loss)
ax_loss.legend(loc="upper left")

In [None]:
# Column layout: "Reference" first, then one column per tracked wrapper in the
# order alpha, affine, mixed (matching the column titles).
col_alpha = 1
col_affine = col_alpha + int(alpha.tracked)
col_mixed = col_affine + int(affine.tracked)
n_cols = col_mixed + int(mixed.tracked)
col_titles = ["Reference"] + [mw.label for mw in [alpha, affine, mixed] if mw.tracked]
row_titles = [f"Cycle {c+1}" for c in range(n_cycles)]

# The same grid twice: the per-cycle mean rate of change (`dpsis`) and the
# cycle-start field (`psi0s`), with the matching wrapper attributes.
for title, reference, field_name in (
    ("dѱ₂", dpsis, "dpsi2s"),
    ("ѱ₂", psi0s, "psi2s"),
):
    fig, axs = plots.subplots(n_cycles, n_cols)
    fig.suptitle(title)
    plots.set_coltitles(col_titles, axs)
    plots.set_rowtitles(row_titles, axs)
    for c in range(n_cycles):
        plots.imshow(crop(reference[c][0,1],p),ax=axs[c,0])
        if alpha.tracked:
            plots.imshow(crop(alpha.alphas[c]*reference[c][0,0],p),ax=axs[c,col_alpha])
        if affine.tracked:
            plots.imshow(crop(getattr(affine, field_name)[c][0,0],p),ax=axs[c,col_affine])
        if mixed.tracked:
            plots.imshow(
                crop(getattr(mixed, field_name)[c][0,0]+mixed.alphas[c]*reference[c][0,0],p+2),
                ax=axs[c,col_mixed],
            )
    plots.show()

## [32, 96] x [256, 384]

### Model runs

In [None]:
from pathlib import Path

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from qgsw import plots
from qgsw.logging.utils import box, step
from qgsw.plots.plt_wrapper import retrieve_colorbar
from qgsw.utils.reshaping import crop

imin, imax = 32, 96
jmin, jmax = 256, 384
# Extract the [imin, imax] x [jmin, jmax] window padded by p on each side.
extract_psi_w = lambda psi: extract_psi_w_(psi,imin,imax,jmin,jmax)


model_3l.reset_time()
model_3l.set_psi(psi_start)

space_slice = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
)

space_slice_w = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

rg = ReducedGravity(space_slice)
tl = TwoLayers(space_slice)
alpha = Alpha(space_slice)
affine = Affine(space_slice)
mixed = Mixed(space_slice)
mixed_ro = Mixed(space_slice)
mixed_ro.color = "blueviolet"
mixed_ro.label = "Mixed - Exp field"
mixed_ro.prefix = "results_mixed_ro"
forced_rg = ForcedRG(space_slice)
forced_md = ForcedMD(space_slice)
forced_rgmd = ForcedRGMD(space_slice)
forced_colmd = ForcedColMD(space_slice)
forced_colrgmd = ForcedColRGMD(space_slice)

models = ModelsManagerOBC(
    # rg,
    # tl,
    # alpha,
    # single_alpha,
    # affine,
    mixed_ro,
    mixed,
    forced_rg,
    forced_md,
    forced_rgmd,
    forced_colmd,
    forced_colrgmd,
)
models.ijs = (imin,imax,jmin,jmax)
models.save_params = True

models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1],ty[imin : imax + 1, jmin:jmax])

psi0s = {}
dpsis = {}

for c in range(n_cycles):
    # Peak-memory statistics only exist on CUDA; skip the reset elsewhere.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    models.new_cycle()
    times = [model_3l.time.item()]

    # Run the reference model over one cycle, keeping the padded window.
    psi0 = extract_psi_w(model_3l.psi[:,:2])
    psi0s[c] = psi0

    psis = [psi0]

    for _ in range(1, n_steps_per_cyle):
        model_3l.step()

        times.append(model_3l.time.item())

        psi = extract_psi_w(model_3l.psi[:,:2])

        psis.append(psi)
    if save_videos:
        output_folder = Path("../output/videos/model_3l")
        output_folder.mkdir(parents=True, exist_ok=True)

        fig, axs = plots.subplots(1,2)

        im1=plots.imshow(psis[0][0,0],title="ѱ₁", ax=axs[0,0])
        im2 =plots.imshow(psis[0][0,1],title="ѱ₂",ax=axs[0,1])

        def func(frame: int) -> tuple:

            data1 = psis[frame][0,0]
            data2 = psis[frame][0,1]

            data_array = retrieve_imshow_data(data1)
            im1.set_array(data_array)
            clim = default_clim(data_array)
            im1.set_clim(*clim)
            retrieve_colorbar(im1, im1.axes).update_normal(im1)

            data_array = retrieve_imshow_data(data2)
            im2.set_array(data_array)
            clim = default_clim(data_array)
            im2.set_clim(*clim)
            retrieve_colorbar(im2, im2.axes).update_normal(im2)

            fig.canvas.draw_idle()

            return im1, im2

        FuncAnimation(fig,func, frames=len(psis),blit=True).save(output_folder.joinpath(f"{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"), fps=25)
        # Close the per-cycle figure so it is not kept alive or shown inline.
        plt.close(fig)

    # Mean rate of change of ψ over the cycle ((n_steps - 1) intervals of dt).
    dpsis[c] = (psis[-1]-psis[0])/(n_steps_per_cyle-1)/dt

    models.reset_time()

    models.setup(psis,times,beta_effect_w)

    models.compute_loss(crop(psis[0],p))
    for n in range(1,n_steps_per_cyle):
        models.step()
        models.compute_loss(crop(psis[n],p))
    models.save_param_video()

    # Collect unreachable objects first so their tensors go back to PyTorch's
    # caching allocator, then release the cache.
    gc.collect()
    torch.cuda.empty_cache()

    max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
    msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    logger.info(box(msg_mem, style="round"))

### Results

In [None]:
from matplotlib import pyplot as plt

from qgsw import plots

# Left panel: first layer of `psi_start` with the sub-domain outlined.
# Right panel: loss curves drawn by `models.plot_loss`, y-range fixed to [0, 0.2].
fig, axs = plots.subplots(
    1,
    2,
    gridspec_kw={"width_ratios": [1, 5]},
    figsize=(15, 4),
)
ax_domain = axs[0, 0]
ax_loss = axs[0, 1]

plots.imshow(psi_start[0, 0], ax=ax_domain)
ax_domain.hlines([jmin, jmax], imin, imax, color="black")
ax_domain.vlines([imin, imax], jmin, jmax, color="black")

ax_loss.set_ylim(0, 0.2)
models.plot_loss(ax=ax_loss)
ax_loss.legend(loc="upper left")

In [None]:
# Column layout: "Reference" first, then one column per tracked wrapper in the
# order alpha, affine, mixed (matching the column titles).
col_alpha = 1
col_affine = col_alpha + int(alpha.tracked)
col_mixed = col_affine + int(affine.tracked)
n_cols = col_mixed + int(mixed.tracked)
col_titles = ["Reference"] + [mw.label for mw in [alpha, affine, mixed] if mw.tracked]
row_titles = [f"Cycle {c+1}" for c in range(n_cycles)]

# The same grid twice: the per-cycle mean rate of change (`dpsis`) and the
# cycle-start field (`psi0s`), with the matching wrapper attributes.
for title, reference, field_name in (
    ("dѱ₂", dpsis, "dpsi2s"),
    ("ѱ₂", psi0s, "psi2s"),
):
    fig, axs = plots.subplots(n_cycles, n_cols)
    fig.suptitle(title)
    plots.set_coltitles(col_titles, axs)
    plots.set_rowtitles(row_titles, axs)
    for c in range(n_cycles):
        plots.imshow(crop(reference[c][0,1],p),ax=axs[c,0])
        if alpha.tracked:
            plots.imshow(crop(alpha.alphas[c]*reference[c][0,0],p),ax=axs[c,col_alpha])
        if affine.tracked:
            plots.imshow(crop(getattr(affine, field_name)[c][0,0],p),ax=axs[c,col_affine])
        if mixed.tracked:
            plots.imshow(
                crop(getattr(mixed, field_name)[c][0,0]+mixed.alphas[c]*reference[c][0,0],p+2),
                ax=axs[c,col_mixed],
            )
    plots.show()

## [112, 176] x [64, 192]

### Model runs

In [None]:
from pathlib import Path

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from qgsw import plots
from qgsw.logging.utils import box, step
from qgsw.plots.plt_wrapper import retrieve_colorbar
from qgsw.utils.reshaping import crop

imin, imax = 112, 176
jmin, jmax = 64, 192
# Extract the [imin, imax] x [jmin, jmax] window padded by p on each side.
extract_psi_w = lambda psi: extract_psi_w_(psi,imin,imax,jmin,jmax)


model_3l.reset_time()
model_3l.set_psi(psi_start)

space_slice = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
)

space_slice_w = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

rg = ReducedGravity(space_slice)
tl = TwoLayers(space_slice)
alpha = Alpha(space_slice)
affine = Affine(space_slice)
mixed = Mixed(space_slice)
mixed_ro = Mixed(space_slice)
mixed_ro.color = "blueviolet"
mixed_ro.label = "Mixed - Exp field"
mixed_ro.prefix = "results_mixed_ro"
forced_rg = ForcedRG(space_slice)
forced_md = ForcedMD(space_slice)
forced_rgmd = ForcedRGMD(space_slice)
forced_colmd = ForcedColMD(space_slice)
forced_colrgmd = ForcedColRGMD(space_slice)

models = ModelsManagerOBC(
    # rg,
    # tl,
    # alpha,
    # single_alpha,
    # affine,
    mixed_ro,
    mixed,
    forced_rg,
    forced_md,
    forced_rgmd,
    forced_colmd,
    forced_colrgmd,
)
models.ijs = (imin,imax,jmin,jmax)
models.save_params = True

models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1],ty[imin : imax + 1, jmin:jmax])

psi0s = {}
dpsis = {}

for c in range(n_cycles):
    # Peak-memory statistics only exist on CUDA; skip the reset elsewhere.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    models.new_cycle()
    times = [model_3l.time.item()]

    # Run the reference model over one cycle, keeping the padded window.
    psi0 = extract_psi_w(model_3l.psi[:,:2])
    psi0s[c] = psi0

    psis = [psi0]

    for _ in range(1, n_steps_per_cyle):
        model_3l.step()

        times.append(model_3l.time.item())

        psi = extract_psi_w(model_3l.psi[:,:2])

        psis.append(psi)
    if save_videos:
        output_folder = Path("../output/videos/model_3l")
        output_folder.mkdir(parents=True, exist_ok=True)

        fig, axs = plots.subplots(1,2)

        im1=plots.imshow(psis[0][0,0],title="ѱ₁", ax=axs[0,0])
        im2 =plots.imshow(psis[0][0,1],title="ѱ₂",ax=axs[0,1])

        def func(frame: int) -> tuple:

            data1 = psis[frame][0,0]
            data2 = psis[frame][0,1]

            data_array = retrieve_imshow_data(data1)
            im1.set_array(data_array)
            clim = default_clim(data_array)
            im1.set_clim(*clim)
            retrieve_colorbar(im1, im1.axes).update_normal(im1)

            data_array = retrieve_imshow_data(data2)
            im2.set_array(data_array)
            clim = default_clim(data_array)
            im2.set_clim(*clim)
            retrieve_colorbar(im2, im2.axes).update_normal(im2)

            fig.canvas.draw_idle()

            return im1, im2

        FuncAnimation(fig,func, frames=len(psis),blit=True).save(output_folder.joinpath(f"{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"), fps=25)
        # Close the per-cycle figure so it is not kept alive or shown inline.
        plt.close(fig)

    # Mean rate of change of ψ over the cycle ((n_steps - 1) intervals of dt).
    dpsis[c] = (psis[-1]-psis[0])/(n_steps_per_cyle-1)/dt

    models.reset_time()

    models.setup(psis,times,beta_effect_w)

    models.compute_loss(crop(psis[0],p))
    for n in range(1,n_steps_per_cyle):
        models.step()
        models.compute_loss(crop(psis[n],p))
    models.save_param_video()

    # Collect unreachable objects first so their tensors go back to PyTorch's
    # caching allocator, then release the cache.
    gc.collect()
    torch.cuda.empty_cache()

    max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
    msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    logger.info(box(msg_mem, style="round"))

### Results

In [None]:
from matplotlib import pyplot as plt

from qgsw import plots

# Left panel: first layer of `psi_start` with the sub-domain outlined.
# Right panel: loss curves drawn by `models.plot_loss`.
fig, axs = plots.subplots(
    1,
    2,
    gridspec_kw={"width_ratios": [1, 5]},
    figsize=(15, 4),
)
ax_domain = axs[0, 0]
ax_loss = axs[0, 1]

plots.imshow(psi_start[0, 0], ax=ax_domain)
ax_domain.hlines([jmin, jmax], imin, imax, color="black")
ax_domain.vlines([imin, imax], jmin, jmax, color="black")

models.plot_loss(ax=ax_loss)
ax_loss.legend(loc="upper left")

In [None]:
# Column layout: "Reference" first, then one column per tracked wrapper in the
# order alpha, affine, mixed (matching the column titles).
col_alpha = 1
col_affine = col_alpha + int(alpha.tracked)
col_mixed = col_affine + int(affine.tracked)
n_cols = col_mixed + int(mixed.tracked)
col_titles = ["Reference"] + [mw.label for mw in [alpha, affine, mixed] if mw.tracked]
row_titles = [f"Cycle {c+1}" for c in range(n_cycles)]

# The same grid twice: the per-cycle mean rate of change (`dpsis`) and the
# cycle-start field (`psi0s`), with the matching wrapper attributes.
for title, reference, field_name in (
    ("dѱ₂", dpsis, "dpsi2s"),
    ("ѱ₂", psi0s, "psi2s"),
):
    fig, axs = plots.subplots(n_cycles, n_cols)
    fig.suptitle(title)
    plots.set_coltitles(col_titles, axs)
    plots.set_rowtitles(row_titles, axs)
    for c in range(n_cycles):
        plots.imshow(crop(reference[c][0,1],p),ax=axs[c,0])
        if alpha.tracked:
            plots.imshow(crop(alpha.alphas[c]*reference[c][0,0],p),ax=axs[c,col_alpha])
        if affine.tracked:
            plots.imshow(crop(getattr(affine, field_name)[c][0,0],p),ax=axs[c,col_affine])
        if mixed.tracked:
            plots.imshow(
                crop(getattr(mixed, field_name)[c][0,0]+mixed.alphas[c]*reference[c][0,0],p+2),
                ax=axs[c,col_mixed],
            )
    plots.show()

## [112, 176] x [256, 384]

### Model runs

In [None]:
from pathlib import Path

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from qgsw import plots
from qgsw.logging.utils import box, step
from qgsw.plots.plt_wrapper import retrieve_colorbar
from qgsw.utils.reshaping import crop

imin, imax = 112, 176
jmin, jmax = 256, 384
# Extract the [imin, imax] x [jmin, jmax] window padded by p on each side.
extract_psi_w = lambda psi: extract_psi_w_(psi,imin,imax,jmin,jmax)


model_3l.reset_time()
model_3l.set_psi(psi_start)

space_slice = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
)

space_slice_w = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

rg = ReducedGravity(space_slice)
tl = TwoLayers(space_slice)
alpha = Alpha(space_slice)
affine = Affine(space_slice)
mixed = Mixed(space_slice)
mixed_ro = Mixed(space_slice)
mixed_ro.color = "blueviolet"
mixed_ro.label = "Mixed - Exp field"
mixed_ro.prefix = "results_mixed_ro"
forced_rg = ForcedRG(space_slice)
forced_md = ForcedMD(space_slice)
forced_rgmd = ForcedRGMD(space_slice)
forced_colmd = ForcedColMD(space_slice)
forced_colrgmd = ForcedColRGMD(space_slice)

models = ModelsManagerOBC(
    # rg,
    # tl,
    # alpha,
    # single_alpha,
    # affine,
    mixed_ro,
    mixed,
    forced_rg,
    forced_md,
    forced_rgmd,
    forced_colmd,
    forced_colrgmd,
)
models.ijs = (imin,imax,jmin,jmax)
models.save_params = True

models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1],ty[imin : imax + 1, jmin:jmax])

psi0s = {}
dpsis = {}

for c in range(n_cycles):
    # Peak-memory statistics only exist on CUDA; skip the reset elsewhere.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    models.new_cycle()
    times = [model_3l.time.item()]

    # Run the reference model over one cycle, keeping the padded window.
    psi0 = extract_psi_w(model_3l.psi[:,:2])
    psi0s[c] = psi0

    psis = [psi0]

    for _ in range(1, n_steps_per_cyle):
        model_3l.step()

        times.append(model_3l.time.item())

        psi = extract_psi_w(model_3l.psi[:,:2])

        psis.append(psi)
    if save_videos:
        output_folder = Path("../output/videos/model_3l")
        output_folder.mkdir(parents=True, exist_ok=True)

        fig, axs = plots.subplots(1,2)

        im1=plots.imshow(psis[0][0,0],title="ѱ₁", ax=axs[0,0])
        im2 =plots.imshow(psis[0][0,1],title="ѱ₂",ax=axs[0,1])

        def func(frame: int) -> tuple:

            data1 = psis[frame][0,0]
            data2 = psis[frame][0,1]

            data_array = retrieve_imshow_data(data1)
            im1.set_array(data_array)
            clim = default_clim(data_array)
            im1.set_clim(*clim)
            retrieve_colorbar(im1, im1.axes).update_normal(im1)

            data_array = retrieve_imshow_data(data2)
            im2.set_array(data_array)
            clim = default_clim(data_array)
            im2.set_clim(*clim)
            retrieve_colorbar(im2, im2.axes).update_normal(im2)

            fig.canvas.draw_idle()

            return im1, im2

        FuncAnimation(fig,func, frames=len(psis),blit=True).save(output_folder.joinpath(f"{imin}_{imax}_{jmin}_{jmax}_{c}.mp4"), fps=25)
        # Close the per-cycle figure so it is not kept alive or shown inline.
        plt.close(fig)

    # Mean rate of change of ψ over the cycle ((n_steps - 1) intervals of dt).
    dpsis[c] = (psis[-1]-psis[0])/(n_steps_per_cyle-1)/dt

    models.reset_time()

    models.setup(psis,times,beta_effect_w)

    models.compute_loss(crop(psis[0],p))
    for n in range(1,n_steps_per_cyle):
        models.step()
        models.compute_loss(crop(psis[n],p))
    models.save_param_video()

    # Collect unreachable objects first so their tensors go back to PyTorch's
    # caching allocator, then release the cache.
    gc.collect()
    torch.cuda.empty_cache()

    max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
    msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    logger.info(box(msg_mem, style="round"))

### Results

In [None]:
from matplotlib import pyplot as plt

from qgsw import plots

# Left panel: first layer of `psi_start` with the sub-domain outlined.
# Right panel: loss curves drawn by `models.plot_loss`.
fig, axs = plots.subplots(
    1,
    2,
    gridspec_kw={"width_ratios": [1, 5]},
    figsize=(15, 4),
)
ax_domain = axs[0, 0]
ax_loss = axs[0, 1]

plots.imshow(psi_start[0, 0], ax=ax_domain)
ax_domain.hlines([jmin, jmax], imin, imax, color="black")
ax_domain.vlines([imin, imax], jmin, jmax, color="black")

models.plot_loss(ax=ax_loss)
ax_loss.legend(loc="upper left")

In [None]:
# Column layout: "Reference" first, then one column per tracked wrapper in the
# order alpha, affine, mixed (matching the column titles).
col_alpha = 1
col_affine = col_alpha + int(alpha.tracked)
col_mixed = col_affine + int(affine.tracked)
n_cols = col_mixed + int(mixed.tracked)
col_titles = ["Reference"] + [mw.label for mw in [alpha, affine, mixed] if mw.tracked]
row_titles = [f"Cycle {c+1}" for c in range(n_cycles)]

# The same grid twice: the per-cycle mean rate of change (`dpsis`) and the
# cycle-start field (`psi0s`), with the matching wrapper attributes.
for title, reference, field_name in (
    ("dѱ₂", dpsis, "dpsi2s"),
    ("ѱ₂", psi0s, "psi2s"),
):
    fig, axs = plots.subplots(n_cycles, n_cols)
    fig.suptitle(title)
    plots.set_coltitles(col_titles, axs)
    plots.set_rowtitles(row_titles, axs)
    for c in range(n_cycles):
        plots.imshow(crop(reference[c][0,1],p),ax=axs[c,0])
        if alpha.tracked:
            plots.imshow(crop(alpha.alphas[c]*reference[c][0,0],p),ax=axs[c,col_alpha])
        if affine.tracked:
            plots.imshow(crop(getattr(affine, field_name)[c][0,0],p),ax=axs[c,col_affine])
        if mixed.tracked:
            plots.imshow(
                crop(getattr(mixed, field_name)[c][0,0]+mixed.alphas[c]*reference[c][0,0],p+2),
                ax=axs[c,col_mixed],
            )
    plots.show()