In [None]:
# Re-import edited modules (e.g. qgsw) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import seaborn as sns

# Seaborn theme: "notebook" plotting context with the dark background style.
sns.set_theme(context="notebook", style="dark")

In [None]:
import datetime
from pathlib import Path

import torch

from qgsw.cli import ScriptArgs
from qgsw.configs.core import Configuration
from qgsw.fields.variables.tuples import UVH
from qgsw.forcing.wind import WindForcing
from qgsw.logging import getLogger, setup_root_logger
from qgsw.masks import Masks
from qgsw.models.qg.psiq.core import QGPSIQ
from qgsw.models.qg.psiq.filtered.core import (
    QGPSIQCollinearSF,
    QGPSIQFixeddSF2,
)
from qgsw.optim.utils import EarlyStop, RegisterParams
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.solver.boundary_conditions.base import Boundaries
from qgsw.spatial.core.discretization import (
    SpaceDiscretization2D,
    SpaceDiscretization3D,
)
import matplotlib.pyplot as plt
from qgsw.specs import defaults
from qgsw.utils import covphys
from qgsw.utils.interpolation import QuadraticInterpolation

# Global torch settings: deterministic cuDNN kernels and autograd disabled
# (no backward passes are performed in this notebook).
torch.backends.cudnn.deterministic = True
torch.set_grad_enabled(False)
import gc

# Default tensor specs from qgsw; used below both as `specs["device"]` and
# unpacked as `.to(**specs)` keyword arguments.
specs = defaults.get()

setup_root_logger(1)
logger = getLogger(__name__)

# Configuration stored alongside the param_optim results loaded further down
# (same output directory).
config = Configuration.from_toml("../output/g5k/param_optim/_config.toml")


# Parameters

# Layer depths and reduced gravities from the config; H1, g1 and g2 feed the
# PV helpers (compute_q_*_) defined in a later cell.
H = config.model.h
g_prime = config.model.g_prime
H1, H2 = H[0], H[1]
g1, g2 = g_prime[0], g_prime[1]
beta_plane = config.physics.beta_plane
bottom_drag_coef = config.physics.bottom_drag_coefficient
slip_coef = config.physics.slip_coef

space = SpaceDiscretization3D.from_config(
    config.space,
    config.model,
)
# Projector used to recover the start-up stream function below; its `space`
# also provides the coordinates of every sub-domain later on.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),
    H=H.unsqueeze(-1).unsqueeze(-1),
    space=space,
    f0=beta_plane.f0,
    masks=Masks.empty(
        nx=config.space.nx,
        ny=config.space.ny,
    ),
)
dx, dy = space.dx, space.dy
nx, ny = space.nx, space.ny

wind = WindForcing.from_config(
    config.windstress,
    config.space,
    config.physics,
)
tx, ty = wind.compute()

# Start-up (u, v, h) state converted to covariant form; the first output of
# compute_p (presumably the pressure) divided by f0 gives the stream function
# used to initialise the reference model.
uvh0 = UVH.from_file("../output/g5k/param_optim/_data_startup.pt")
psi_start = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / beta_plane.f0

## Areas


def compute_slices(
    imin: int, imax: int, jmin: int, jmax: int
) -> tuple[list[slice], list[slice]]:
    """Compute horizontal slices selecting a rectangular sub-domain.

    Args:
        imin: First index along x.
        imax: Last index along x (inclusive for psi).
        jmin: First index along y.
        jmax: Last index along y (inclusive for psi).

    Returns:
        ``(psi_slices, q_slices)``, each a ``[x_slice, y_slice]`` list.
        ``psi_slices`` includes both end indices. ``q_slices`` excludes the
        upper ones, so it has one point fewer per direction (assumes q is
        staggered between psi points — TODO confirm).
    """
    psi_slices = [slice(imin, imax + 1), slice(jmin, jmax + 1)]
    q_slices = [slice(imin, imax), slice(jmin, jmax)]

    return psi_slices, q_slices

## Simulation parameters

# Time step assigned to every model (`model.dt = dt`).
dt = 7200
# Not used in the cells shown here.
optim_max_step = 200
str_optim_len = len(str(optim_max_step))
# Snapshots per comparison cycle: the initial state plus
# n_steps_per_cyle - 1 model steps.
n_steps_per_cyle = 250
# An RMSE value is recorded when (n + 1) % comparison_interval == 0.
comparison_interval = 1

## Error


def rmse(f: torch.Tensor, f_ref: torch.Tensor) -> torch.Tensor:
    """Relative root-mean-square error of ``f`` with respect to ``f_ref``.

    Computes ``sqrt(mean((f - f_ref)**2)) / sqrt(mean(f_ref**2))`` over all
    elements. A value of 1 means the error is as large as the reference
    signal itself.

    Args:
        f: Field to evaluate.
        f_ref: Reference field (broadcastable with ``f``).

    Returns:
        A 0-d tensor; callers convert it with ``.cpu().item()``. The result
        is inf/nan when ``f_ref`` is identically zero.
    """
    return (f - f_ref).square().mean().sqrt() / f_ref.square().mean().sqrt()


# Models
## Three Layer model

# Full-domain reference model. Its trajectory provides both the reference
# fields and the boundary conditions for the sub-domain models.
model_3l = QGPSIQ(
    space_2d=space.remove_z_h(),
    H=H,
    beta_plane=config.physics.beta_plane,
    g_prime=g_prime,
)
model_3l.set_wind_forcing(tx, ty)
model_3l.masks = Masks.empty_tensor(
    model_3l.space.nx,
    model_3l.space.ny,
    device=specs["device"],
)
model_3l.bottom_drag_coef = bottom_drag_coef
model_3l.slip_coef = slip_coef
model_3l.dt = dt
# Reused as every sub-domain model's y0 (set_inhomogeneous_model) and in the
# beta term beta * (y - y0).
y0 = model_3l.y0


## Inhomogeneous models
def set_inhomogeneous_model(
    model: QGPSIQ | QGPSIQCollinearSF | QGPSIQFixeddSF2,
) -> QGPSIQ | QGPSIQCollinearSF | QGPSIQFixeddSF2:
    """Configure a sub-domain (inhomogeneous) model in place.

    Applies the global ``y0``, ``slip_coef`` and ``dt`` shared with the
    reference model, empty masks on the model's own grid, zero bottom drag
    and ``wide = True``.

    Args:
        model: Freshly constructed sub-domain model.

    Returns:
        The same ``model`` object.
    """
    grid = model.space
    model.y0 = y0
    model.masks = Masks.empty_tensor(grid.nx, grid.ny, device=specs["device"])
    model.bottom_drag_coef = 0
    model.wide = True
    model.slip_coef = slip_coef
    model.dt = dt
    return model

In [None]:
from qgsw.solver.finite_diff import laplacian
from qgsw.spatial.core.grid_conversion import interpolate


def compute_q_alpha_(psi: torch.Tensor, alpha: torch.Tensor, beta_effect: torch.Tensor) -> torch.Tensor:
    """Top-layer PV when the second-layer stream function is ``alpha * psi``.

    q = interp(lap(psi) - f0^2 (1/(H1 g1) + 1/(H1 g2)) psi
               + f0^2 / (H1 g2) * alpha * psi) + beta_effect,
    where the stretching terms use ``psi`` with one point trimmed on each side.
    """
    f0_sq = beta_plane.f0**2
    psi_in = psi[..., 1:-1, 1:-1]
    stretching = f0_sq * (1 / H1 / g1 + 1 / H1 / g2) * psi_in
    coupling = f0_sq * (1 / H1 / g2) * alpha * psi_in
    return interpolate(laplacian(psi, dx, dy) - stretching + coupling) + beta_effect

def compute_q_mixed_(
    psi: torch.Tensor,
    alpha: torch.Tensor,
    psi2: torch.Tensor,
    beta_effect: torch.Tensor,
) -> torch.Tensor:
    """Compute top-layer PV with a second layer given by ``psi2 + alpha * psi``.

    Combines the "psi2" and "alpha" closures: the coupling term uses
    ``psi2 + alpha * psi``, both restricted to the interior (one point
    trimmed per side).

    Args:
        psi: Top-layer stream function (including halo).
        alpha: Collinearity coefficient applied to ``psi``.
        psi2: Prescribed second-layer stream function (same grid as ``psi``).
        beta_effect: Beta term added after interpolation.

    Returns:
        Potential vorticity interpolated by ``interpolate``, plus ``beta_effect``.
    """
    return interpolate(
        laplacian(psi, dx, dy)
        - beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
        + beta_plane.f0**2 * (1 / H1 / g2) * (psi2[..., 1:-1, 1:-1] + alpha * psi[..., 1:-1, 1:-1])
    ) + beta_effect


def compute_q_psi2_(psi: torch.Tensor, psi2: torch.Tensor, beta_effect: torch.Tensor) -> torch.Tensor:
    """Top-layer PV with a prescribed second-layer stream function ``psi2``.

    q = interp(lap(psi) - f0^2 (1/(H1 g1) + 1/(H1 g2)) psi
               + f0^2 / (H1 g2) * psi2) + beta_effect,
    with both stream functions trimmed by one point per side.
    """
    f0_sq = beta_plane.f0**2
    top = psi[..., 1:-1, 1:-1]
    second = psi2[..., 1:-1, 1:-1]
    vorticity = laplacian(psi, dx, dy)
    stretching = f0_sq * (1 / H1 / g1 + 1 / H1 / g2) * top
    coupling = f0_sq * (1 / H1 / g2) * second
    return interpolate(vorticity - stretching + coupling) + beta_effect

def compute_q_rg_(psi: torch.Tensor, beta_effect: torch.Tensor) -> torch.Tensor:
    """Top-layer PV with no second-layer contribution (psi2 = 0).

    Uses the same stretching coefficient as the other compute_q_*_ helpers,
    without any coupling term.
    """
    stretching = beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
    return interpolate(laplacian(psi, dx, dy) - stretching) + beta_effect

# Stretching matrix built from the first two layers only; used by compute_q_2l_.
A_2l = compute_A(
    H[:2],g_prime[:2]
)

def compute_q_2l_(psi: torch.Tensor, beta_effect: torch.Tensor) -> torch.Tensor:
    """Two-layer PV: interp(lap(psi) - f0^2 * A_2l @ psi) + beta_effect.

    ``A_2l`` is contracted with the axis just before the two horizontal axes
    (the layer axis) of ``psi`` trimmed by one point per side.
    """
    psi_in = psi[..., 1:-1, 1:-1]
    coupled = torch.einsum("lm,...mxy->...lxy", A_2l, psi_in)
    return interpolate(laplacian(psi, dx, dy) - beta_plane.f0**2 * coupled) + beta_effect

In [None]:
# Pairs of (alpha-closure results, regularised psi2-closure results) for the
# same sub-domain; uncomment entries to evaluate more sub-domains.
output_files = [
    # ("../output/g5k/param_optim/results_alpha_32_96_64_192.pt","../output/g5k/param_optim/results_psi2_reg_32_96_64_192.pt"),
    (
        "../output/g5k/param_optim/results_alpha_32_96_256_384.pt",
        "../output/g5k/param_optim/results_psi2_reg_32_96_256_384.pt",
    ),
    # ("../output/g5k/param_optim/results_alpha_112_176_64_192.pt","../output/g5k/param_optim/results_psi2_reg_112_176_64_192.pt"),
    # ("../output/g5k/param_optim/results_alpha_112_176_256_384.pt","../output/g5k/param_optim/results_psi2_reg_112_176_256_384.pt"),
]

# One independent list per diagnostic; each gets one entry per sub-domain.
outputs = {
    key: []
    for key in (
        "indices",
        "losses_1l",
        "losses_2l",
        "losses_alpha",
        "losses_dpsi",
        "alphas",
        "dalphas",
        "psi2s",
        "dpsi2s",
        "dpsis",
        "psi2s_ref",
    )
}

for f_alpha, f_psi2 in output_files[:]:

    # Each results file is a per-cycle list; only the first cycle's "coords"
    # are compared between the two files.
    res_alpha = torch.load(f_alpha)
    res_psi2 = torch.load(f_psi2)
    if res_alpha[0]["coords"] != res_psi2[0]["coords"]:
        msg = "Mismatching indices"
        raise ValueError(msg)
    if len(res_alpha) != len(res_psi2):
        msg = "Mismatching number of cyles indices"
        raise ValueError(msg)

    indices = res_alpha[0]["coords"]
    n_cycles = len(res_alpha)

    logger.info(f"Indices: {indices}")
    
    outputs["indices"].append(indices)

    imin, imax, jmin, jmax = indices
    # Halo width (grid points) kept around the sub-domain in the reference
    # fields; trimmed with [p:-p] when comparing with sub-domain models.
    p = 4
    # NOTE(review): psi_slices is reassigned below and not used afterwards.
    psi_slices = [slice(imin, imax + 1), slice(jmin, jmax + 1)]
    psi_slices_w = [slice(imin - p, imax + p + 1), slice(jmin - p, jmax + p + 1)]
    
    def extract_psi_w(psi: torch.Tensor) -> torch.Tensor:
        """Extract psi on the sub-domain extended by a halo of width p."""
        return psi[..., psi_slices_w[0], psi_slices_w[1]]


    def extract_psi_bc(psi: torch.Tensor) -> Boundaries:
        """Extract sub-domain boundary values from a halo-extended psi.

        The last argument (2) is presumably the boundary width — TODO confirm.
        """
        return Boundaries.extract(psi, p, -p - 1, p, -p - 1, 2)

    # Restart the reference model from the same start-up state for every sub-domain.
    model_3l.reset_time()
    model_3l.set_psi(psi_start)

    psi_slices, _ = compute_slices(imin, imax, jmin, jmax)

    # Sub-domain grid (no halo) on which the reduced models run.
    space_slice = SpaceDiscretization2D.from_tensors(
        x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
        y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
    )

    # Grid whose q points appear to match the halo-extended PV returned by
    # the compute_q_*_ helpers; used only for the beta term beta * (y - y0).
    space_slice_w = SpaceDiscretization2D.from_tensors(
        x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
        y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
    )
    y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
    beta_effect_w = beta_plane.beta * (y_w - y0)

    # PV helpers bound to this sub-domain's beta term.
    compute_q_rg = lambda psi : compute_q_rg_(psi, beta_effect_w)
    compute_q_2l = lambda psi : compute_q_2l_(psi, beta_effect_w)
    compute_q_alpha = lambda psi, alpha : compute_q_alpha_(psi,alpha, beta_effect_w)
    compute_q_psi2 = lambda psi,psi2 : compute_q_psi2_(psi,psi2, beta_effect_w)

    # Reduced models on the sub-domain: one layer with equivalent depth
    # H1*H2/(H1+H2), two layers, and (per their class names) collinear and
    # fixed-dpsi2 closures for the second layer.
    model_1l = QGPSIQ(
        space_2d=space_slice,
        H=H[:1] * H[1:2] / (H[:1] + H[1:2]),
        beta_plane=beta_plane,
        g_prime=g_prime[1:2],
    )
    model_2l = QGPSIQ(space_2d=space_slice, H=H[:2], beta_plane=beta_plane, g_prime=g_prime[:2])
    model_alpha = QGPSIQCollinearSF(space_2d=space_slice, H=H[:2], beta_plane=beta_plane, g_prime=g_prime[:2])
    model_dpsi = QGPSIQFixeddSF2(space_2d=space_slice, H=H[:2], beta_plane=beta_plane, g_prime=g_prime[:2])

    sub_models = (model_1l, model_2l, model_alpha, model_dpsi)
    # set_inhomogeneous_model mutates its argument and returns the same object.
    for sub_model in sub_models:
        set_inhomogeneous_model(sub_model)
    # Wind stress restricted to the sub-domain (fresh slices for each model).
    for sub_model in sub_models:
        sub_model.set_wind_forcing(
            tx[imin:imax, jmin : jmax + 1], ty[imin : imax + 1, jmin:jmax]
        )

    # Per-cycle diagnostics for this sub-domain.
    losses_alpha, losses_dpsi, losses_2l, losses_1l = [], [], [], []
    alphas, dalphas, psi2s, dpsi2s, dpsis, psi2s_ref = [], [], [], [], [], []

    for c in range(n_cycles):

        dpsis.append([])

        # Snapshot times of the reference trajectory for this cycle.
        times = [model_3l.time.item()]

        # Top two reference layers on the halo-extended sub-domain.
        psi0 = extract_psi_w(model_3l.psi[:,:2])
        psi2s_ref.append(psi0[:,1:2])
        psi_bc = extract_psi_bc(psi0)

        psis = [psi0]
        psi_bcs = [psi_bc]

        # The reference model is not reset between cycles: cycle c starts
        # where cycle c - 1 ended. n_steps_per_cyle snapshots are stored.
        for _ in range(1, n_steps_per_cyle):
            model_3l.step()

            times.append(model_3l.time.item())

            psi = extract_psi_w(model_3l.psi[:,:2])
            psi_bc = extract_psi_bc(psi)

            psis.append(psi)
            psi_bcs.append(psi_bc)
            # dpsis[-1].append(extract_psi_w(model_3l._dpsi))
        # Mean rate of change of psi over the cycle: finite difference across
        # n_steps_per_cyle - 1 steps of length dt.
        dpsis[-1] = (psis[-1]-psis[0])/(n_steps_per_cyle-1)/dt #torch.stack(dpsis[-1],dim=0).mean(dim=0)
            

        # Time interpolants of the psi boundary values: top layer only, and
        # both layers. (`p` inside the comprehension is a loop variable, not
        # the halo width.)
        psi_bc_interp_1l = QuadraticInterpolation(times, [p[:,:1] for p in psi_bcs])
        psi_bc_interp_2l = QuadraticInterpolation(times, psi_bcs)

        # --- One-layer (equivalent-depth) model ---
        model_1l.reset_time()

        # PV boundary values for every snapshot, from the top reference layer only.
        q_bcs = [
                Boundaries.extract(
                    compute_q_rg(psi[:, :1]), 2, -3, 2, -3, 3
                )
                for psi in psis
            ]

        # Initial state: halo-trimmed psi and the matching PV (the PV halo is
        # presumably p - 1 = 3 points, hence [3:-3]).
        model_1l.set_psiq(psi0[:,:1,p:-p,p:-p], compute_q_rg(psi0[:,:1])[...,3:-3,3:-3])
        model_1l.set_boundary_maps(psi_bc_interp_1l, QuadraticInterpolation(times, q_bcs))


        # Relative RMSE of the top layer against the reference, at t=0 then after each step.
        losses_1l.append(rmse(model_1l.psi[0, 0], psis[0][0, 0,p:-p,p:-p]).cpu().item())

        for n in range(1, n_steps_per_cyle):
            model_1l.step()

            # NOTE(review): tests n + 1 but compares with psis[n]; equivalent
            # only while comparison_interval == 1.
            if (n + 1) % comparison_interval == 0:
                losses_1l.append(rmse(model_1l.psi[0, 0], psis[n][0, 0,p:-p,p:-p]).cpu().item())

        # --- Two-layer model on the sub-domain ---
        model_2l.reset_time()

        # PV boundary values for every snapshot, from both reference layers.
        q_bcs = [
            Boundaries.extract(compute_q_2l(psi[:, :2]), 2, -3, 2, -3, 3)
            for psi in psis
        ]

        model_2l.set_psiq(
            psi0[:, :2, p:-p, p:-p], compute_q_2l(psi0[:, :2])[..., 3:-3, 3:-3]
        )
        # Both layers are evolved, so the psi boundary maps must carry both
        # layers (psi_bc_interp_2l), consistent with the two-layer PV
        # boundaries above. psi_bc_interp_1l only holds the top layer.
        model_2l.set_boundary_maps(
            psi_bc_interp_2l, QuadraticInterpolation(times, q_bcs)
        )

        # Relative RMSE of the top layer against the reference, at t=0 then after each step.
        losses_2l.append(
            rmse(model_2l.psi[0, 0], psis[0][0, 0, p:-p, p:-p]).cpu().item()
        )

        for n in range(1, n_steps_per_cyle):
            model_2l.step()

            # NOTE(review): tests n + 1 but compares with psis[n]; equivalent
            # only while comparison_interval == 1.
            if (n + 1) % comparison_interval == 0:
                J = rmse(model_2l.psi[0, 0], psis[n][0, 0, p:-p, p:-p]).cpu().item()
                # RMSE above 3 is recorded as NaN (presumably a diverged run).
                losses_2l.append(torch.nan if J > 3 else J)
        
        # --- Collinear closure: second layer = alpha * top layer (see compute_q_alpha_) ---
        # NOTE(review): unlike psi2/dpsi2 below and the mixed-closure cell,
        # alpha/dalpha are not moved with `.to(**specs)`; confirm they are
        # already on the working device/dtype.
        alpha:torch.Tensor = res_alpha[c]["alpha"]
        alphas.append(alpha)
        dalpha:torch.Tensor = res_alpha[c]["dalpha"]
        dalphas.append(dalpha)
        
        model_alpha.reset_time()

        # PV boundary values for every snapshot, using the cycle's alpha.
        q_bcs = [
                Boundaries.extract(
                    compute_q_alpha(psi[:, :1],alpha), 2, -3, 2, -3, 3
                )
                for psi in psis
            ]

        model_alpha.set_psiq(psi0[:,:1,p:-p,p:-p], compute_q_alpha(psi0[:,:1],alpha)[...,3:-3,3:-3])

        # NOTE(review): the model's `alpha` field is filled with `dalpha`
        # (not `alpha`, which is used for the PV above) — presumably the
        # model treats it as a rate of change; confirm.
        model_alpha.alpha = torch.ones_like(psi0[:,:1,p:-p,p:-p]) * dalpha
        model_alpha.set_boundary_maps(psi_bc_interp_1l, QuadraticInterpolation(times, q_bcs))
        
        losses_alpha.append(rmse(model_alpha.psi[0, 0], psis[0][0, 0,p:-p,p:-p]).cpu().item())
        
        for n in range(1, n_steps_per_cyle):
            model_alpha.step()

            if (n + 1) % comparison_interval == 0:
                losses_alpha.append(rmse(model_alpha.psi[0, 0], psis[n][0, 0,p:-p,p:-p]).cpu().item())

        # --- Fixed-dpsi2 closure: second layer psi2 + t * dpsi2 (see compute_q_psi2_) ---
        psi2:torch.Tensor = res_psi2[c]["psi2"].to(**specs)
        psi2s.append(psi2)
        dpsi2:torch.Tensor = res_psi2[c]["dpsi2"].to(**specs)
        dpsi2s.append(dpsi2)

        model_dpsi.reset_time()

        # PV boundary values with the second layer advanced linearly in time
        # (assumes snapshot n is n * dt after the cycle start).
        q_bcs = [
                Boundaries.extract(
                    compute_q_psi2(psi[:, :1],psi2+n*dt*dpsi2), 2, -3, 2, -3, 3
                )
                for n,psi in enumerate(psis)
            ]

        model_dpsi.set_psiq(psi0[:,:1,p:-p,p:-p], compute_q_psi2(psi0[:,:1],psi2)[...,3:-3,3:-3])
        model_dpsi.set_boundary_maps(psi_bc_interp_1l, QuadraticInterpolation(times, q_bcs))
        # Halo-trimmed second-layer tendency handed to the model.
        model_dpsi.dpsi2 = dpsi2[...,p:-p,p:-p]

        losses_dpsi.append(rmse(model_dpsi.psi[0, 0], psis[0][0, 0,p:-p,p:-p]).cpu().item())
        for n in range(1, n_steps_per_cyle):
            model_dpsi.step()

            if (n + 1) % comparison_interval == 0:
                losses_dpsi.append(rmse(model_dpsi.psi[0, 0], psis[n][0, 0,p:-p,p:-p]).cpu().item())
        
    # Store this sub-domain's diagnostics (one entry per sub-domain per key).
    for key, value in (
        ("alphas", alphas),
        ("dalphas", dalphas),
        ("psi2s", psi2s),
        ("dpsi2s", dpsi2s),
        ("dpsis", dpsis),
        ("losses_1l", losses_1l),
        ("losses_2l", losses_2l),
        ("losses_alpha", losses_alpha),
        ("losses_dpsi", losses_dpsi),
        ("psi2s_ref", psi2s_ref),
    ):
        outputs[key].append(value)

    # Release cached GPU memory and collect garbage before the next sub-domain.
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
from qgsw import plots

fig, axs = plots.subplots(
    len(output_files),
    2,
    gridspec_kw={"width_ratios": [1, 5]},
    figsize=(12, 4 * len(output_files)),
)
for i, indices in enumerate(outputs["indices"]):
    imin, imax, jmin, jmax = indices
    print("ɑ: ", outputs["alphas"][i])
    print("dɑ: ", outputs["dalphas"][i])

    ax_map, ax_loss = axs[i, 0], axs[i, 1]

    # Left panel: start-up stream function (first layer) with the sub-domain outlined.
    plots.imshow(psi_start[0, 0], ax=ax_map)
    ax_map.hlines([jmin, jmax], imin, imax, color="black")
    ax_map.vlines([imin, imax], jmin, jmax, color="black")

    # Right panel: relative RMSE of each reduced model; dashed line marks 1.
    ax_loss.hlines(1, 0, len(outputs["losses_1l"][i]), linestyle="dashed", color="brown", alpha=0.75)
    ax_loss.plot(outputs["losses_1l"][i], color="blue", label="1 layer")
    ax_loss.plot(outputs["losses_2l"][i], color="black", label="2 layers")
    ax_loss.plot(outputs["losses_alpha"][i], color="red", label="ɑ")
    # NOTE(review): the last dpsi value is dropped here but not for the other
    # curves — confirm this is intended.
    ax_loss.plot(outputs["losses_dpsi"][i][:-1], color="orange", label="dѱ2")
    ax_loss.legend(loc="upper left")
    # Same axis label as the detailed per-sub-domain figure.
    ax_loss.set_ylabel("RMSE Evolution")

plots.show()

In [None]:
# Non-regularised psi2 optimisation results, aligned index-by-index with
# `output_files`.
non_reg = [
    # "../output/g5k/param_optim/results_psi2_32_96_64_192.pt",
    "../output/g5k/param_optim/results_psi2_32_96_256_384.pt",
    # "../output/g5k/param_optim/results_psi2_112_176_64_192.pt",
    # "../output/g5k/param_optim/results_psi2_112_176_256_384.pt",
]

for i, indices in enumerate(outputs["indices"]):
    fig = plt.figure(figsize=(11, 25))
    grid_shape = (6, 3)
    # Row 0: overview map (1 column) and RMSE curves (2 columns).
    ax0 = plt.subplot2grid(grid_shape, (0, 0))
    ax1 = plt.subplot2grid(grid_shape, (0, 1), colspan=2)
    # Rows 1-5: one column for each of the first three cycles.
    field_axes = [
        [plt.subplot2grid(grid_shape, (row, col)) for col in range(3)]
        for row in range(1, 6)
    ]

    imin, imax, jmin, jmax = indices

    plots.imshow(psi_start[0, 0], ax=ax0)
    ax0.hlines([jmin, jmax], imin, imax, color="black")
    ax0.vlines([imin, imax], jmin, jmax, color="black")

    ax1.hlines(1, 0, len(outputs["losses_1l"][i]), linestyle="dashed", color="brown", alpha=0.75)
    ax1.plot(outputs["losses_1l"][i], color="blue", label="1 layer")
    ax1.plot(outputs["losses_2l"][i], color="black", label="2 layers")
    ax1.plot(outputs["losses_alpha"][i], color="red", label="ɑ")
    ax1.plot(outputs["losses_dpsi"][i][:-1], color="orange", label="dѱ2")
    ax1.legend(loc="upper left")
    ax1.set_ylabel("RMSE Evolution")

    out = torch.load(non_reg[i])
    # (row label, field for cycle k); every field drops a 4-point halo.
    rows = [
        ("dΨ₂ without regularization", lambda k: out[k]["dpsi2"][0, 0, 4:-4, 4:-4]),
        ("Reference dΨ₂", lambda k: outputs["dpsis"][i][k][0, 1, 4:-4, 4:-4]),
        ("dΨ₂ with regularization", lambda k: outputs["dpsi2s"][i][k][0, 0, 4:-4, 4:-4]),
        (" Reference Ψ₂⁰", lambda k: outputs["psi2s_ref"][i][k][0, 0, 4:-4, 4:-4]),
        ("Reconstructed Ψ₂⁰", lambda k: outputs["psi2s"][i][k][0, 0, 4:-4, 4:-4]),
    ]
    for (row_label, field_of_cycle), row_axes in zip(rows, field_axes):
        for cycle, ax in enumerate(row_axes):
            plots.imshow(field_of_cycle(cycle), ax=ax)
        row_axes[0].set_ylabel(row_label)

    plt.tight_layout()
    plots.show()

In [None]:
from qgsw.models.qg.psiq.filtered.core import QGPSIQMixed


# Results of the mixed-closure optimisation (one file per sub-domain).
output_files_mixed = [
    "../output/local/param_optim/results_psi2_reg_mixed_32_96_256_384.pt"
]

# One independent list per diagnostic; each gets one entry per sub-domain.
outputs_mixed = {
    key: []
    for key in (
        "indices",
        "losses_1l",
        "losses_2l",
        "losses_mixed",
        "alphas",
        "dalphas",
        "psi2s",
        "dpsi2s",
        "dpsis",
        "psi0_ref",
    )
}

for f_mixed in output_files_mixed[:]:

    # Per-cycle list of optimisation results for one sub-domain.
    res_mixed=torch.load(f_mixed)

    indices = res_mixed[0]["coords"]
    n_cycles = len(res_mixed)

    logger.info(f"Indices: {indices}")
    
    outputs_mixed["indices"].append(indices)

    imin, imax, jmin, jmax = indices
    # Halo width (grid points) kept around the sub-domain in the reference fields.
    p = 4
    # NOTE(review): psi_slices is reassigned below and not used afterwards.
    psi_slices = [slice(imin, imax + 1), slice(jmin, jmax + 1)]
    psi_slices_w = [slice(imin - p, imax + p + 1), slice(jmin - p, jmax + p + 1)]
    
    def extract_psi_w(psi: torch.Tensor) -> torch.Tensor:
        """Extract psi on the sub-domain extended by a halo of width p."""
        return psi[..., psi_slices_w[0], psi_slices_w[1]]


    def extract_psi_bc(psi: torch.Tensor) -> Boundaries:
        """Extract sub-domain boundary values from a halo-extended psi.

        The last argument (2) is presumably the boundary width — TODO confirm.
        """
        return Boundaries.extract(psi, p, -p - 1, p, -p - 1, 2)

    # Restart the reference model from the same start-up state for every sub-domain.
    model_3l.reset_time()
    model_3l.set_psi(psi_start)

    psi_slices, _ = compute_slices(imin, imax, jmin, jmax)

    # Sub-domain grid (no halo) on which the reduced models run.
    space_slice = SpaceDiscretization2D.from_tensors(
        x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
        y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
    )

    # Grid whose q points appear to match the halo-extended PV of the
    # compute_q_*_ helpers; used only for the beta term beta * (y - y0).
    space_slice_w = SpaceDiscretization2D.from_tensors(
        x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
        y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
    )
    y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
    beta_effect_w = beta_plane.beta * (y_w - y0)

    # PV helpers bound to this sub-domain's beta term.
    compute_q_rg = lambda psi : compute_q_rg_(psi, beta_effect_w)
    compute_q_2l = lambda psi : compute_q_2l_(psi, beta_effect_w)
    compute_q_mixed = lambda psi, alpha, psi2 : compute_q_mixed_(psi,alpha,psi2, beta_effect_w)

    # Reduced models on the sub-domain: one layer with equivalent depth
    # H1*H2/(H1+H2), two layers, and the mixed closure.
    model_1l = QGPSIQ(
        space_2d=space_slice,
        H=H[:1] * H[1:2] / (H[:1] + H[1:2]),
        beta_plane=beta_plane,
        g_prime=g_prime[1:2],
    )
    model_2l = QGPSIQ(space_2d=space_slice, H=H[:2], beta_plane=beta_plane, g_prime=g_prime[:2])
    model_mixed = QGPSIQMixed(space_2d=space_slice, H=H[:2], beta_plane=beta_plane, g_prime=g_prime[:2])

    sub_models = (model_1l, model_2l, model_mixed)
    # set_inhomogeneous_model mutates its argument and returns the same object.
    for sub_model in sub_models:
        set_inhomogeneous_model(sub_model)
    # Wind stress restricted to the sub-domain (fresh slices for each model).
    for sub_model in sub_models:
        sub_model.set_wind_forcing(
            tx[imin:imax, jmin : jmax + 1], ty[imin : imax + 1, jmin:jmax]
        )

    # Per-cycle diagnostics for this sub-domain.
    losses_mixed, losses_2l, losses_1l = [], [], []
    alphas, dalphas, psi2s, dpsi2s, dpsis, psi0_ref = [], [], [], [], [], []

    for c in range(n_cycles):

        # Snapshot times of the reference trajectory for this cycle.
        times = [model_3l.time.item()]

        # Top two reference layers on the halo-extended sub-domain.
        psi0 = extract_psi_w(model_3l.psi[:, :2])
        psi0_ref.append(psi0[:, :2])
        psi_bc = extract_psi_bc(psi0)

        psis = [psi0]
        psi_bcs = [psi_bc]

        # The reference model is not reset between cycles: cycle c starts
        # where cycle c - 1 ended. n_steps_per_cyle snapshots are stored.
        for _ in range(1, n_steps_per_cyle):
            model_3l.step()

            times.append(model_3l.time.item())

            psi = extract_psi_w(model_3l.psi[:, :2])
            psi_bc = extract_psi_bc(psi)

            psis.append(psi)
            psi_bcs.append(psi_bc)

        # Mean rate of change of psi over the cycle, as in the alpha/psi2
        # cell; previously an empty placeholder list was stored instead.
        dpsis.append((psis[-1] - psis[0]) / (n_steps_per_cyle - 1) / dt)

        # Time interpolants of the psi boundary values: top layer only, and both layers.
        psi_bc_interp_1l = QuadraticInterpolation(times, [bc[:, :1] for bc in psi_bcs])
        psi_bc_interp_2l = QuadraticInterpolation(times, psi_bcs)

        # --- One-layer (equivalent-depth) model ---
        model_1l.reset_time()

        # PV boundary values for every snapshot, from the top reference layer only.
        q_bcs = [
                Boundaries.extract(
                    compute_q_rg(psi[:, :1]), 2, -3, 2, -3, 3
                )
                for psi in psis
            ]

        # Initial state: halo-trimmed psi and the matching PV (the PV halo is
        # presumably p - 1 = 3 points, hence [3:-3]).
        model_1l.set_psiq(psi0[:,:1,p:-p,p:-p], compute_q_rg(psi0[:,:1])[...,3:-3,3:-3])
        model_1l.set_boundary_maps(psi_bc_interp_1l, QuadraticInterpolation(times, q_bcs))

        # Relative RMSE of the top layer against the reference, at t=0 then after each step.
        losses_1l.append(rmse(model_1l.psi[0, 0], psis[0][0, 0,p:-p,p:-p]).cpu().item())

        for n in range(1, n_steps_per_cyle):
            model_1l.step()

            # NOTE(review): tests n + 1 but compares with psis[n]; equivalent
            # only while comparison_interval == 1.
            if (n + 1) % comparison_interval == 0:
                losses_1l.append(rmse(model_1l.psi[0, 0], psis[n][0, 0,p:-p,p:-p]).cpu().item())

        # --- Two-layer model on the sub-domain ---
        model_2l.reset_time()

        # PV boundary values for every snapshot, from both reference layers.
        q_bcs = [
            Boundaries.extract(compute_q_2l(psi[:, :2]), 2, -3, 2, -3, 3)
            for psi in psis
        ]

        model_2l.set_psiq(
            psi0[:, :2, p:-p, p:-p], compute_q_2l(psi0[:, :2])[..., 3:-3, 3:-3]
        )
        # Both layers are evolved, so the psi boundary maps must carry both
        # layers (psi_bc_interp_2l), consistent with the two-layer PV
        # boundaries above. psi_bc_interp_1l only holds the top layer.
        model_2l.set_boundary_maps(
            psi_bc_interp_2l, QuadraticInterpolation(times, q_bcs)
        )

        # Relative RMSE of the top layer against the reference, at t=0 then after each step.
        losses_2l.append(
            rmse(model_2l.psi[0, 0], psis[0][0, 0, p:-p, p:-p]).cpu().item()
        )

        for n in range(1, n_steps_per_cyle):
            model_2l.step()

            # NOTE(review): tests n + 1 but compares with psis[n]; equivalent
            # only while comparison_interval == 1.
            if (n + 1) % comparison_interval == 0:
                J = rmse(model_2l.psi[0, 0], psis[n][0, 0, p:-p, p:-p]).cpu().item()
                # RMSE above 3 is recorded as NaN (presumably a diverged run).
                losses_2l.append(torch.nan if J > 3 else J)
    

        # --- Mixed closure: second layer = psi2 + alpha * top layer (see compute_q_mixed_) ---
        alpha:torch.Tensor = res_mixed[c]["alpha"].to(**specs)
        alphas.append(alpha)
        dalpha:torch.Tensor = res_mixed[c]["dalpha"].to(**specs)
        dalphas.append(dalpha)
        psi2:torch.Tensor = res_mixed[c]["psi2"].to(**specs)
        psi2s.append(psi2)
        dpsi2:torch.Tensor = res_mixed[c]["dpsi2"].to(**specs)
        dpsi2s.append(dpsi2)

        model_mixed.reset_time()

        # PV boundary values: psi2 advanced linearly in time (assumes snapshot
        # n is n * dt after the cycle start) while alpha is held at its
        # initial value. NOTE(review): dalpha is not applied here — confirm.
        q_bcs = [
                Boundaries.extract(
                    compute_q_mixed(psi[:, :1],alpha,psi2+n*dt*dpsi2), 2, -3, 2, -3, 3
                )
                for n,psi in enumerate(psis)
            ]

        model_mixed.set_psiq(psi0[:,:1,p:-p,p:-p], compute_q_mixed(psi0[:,:1],alpha,psi2)[...,3:-3,3:-3])
        # NOTE(review): the model's `alpha` field is filled with `dalpha`, not
        # `alpha` — presumably treated as a rate of change by the model; confirm.
        model_mixed.alpha = torch.ones_like(model_mixed.psi)*dalpha
        model_mixed.set_boundary_maps(psi_bc_interp_1l, QuadraticInterpolation(times, q_bcs))
        # Halo-trimmed second-layer tendency handed to the model.
        model_mixed.dpsi2 = dpsi2[...,p:-p,p:-p]

        losses_mixed.append(rmse(model_mixed.psi[0, 0], psis[0][0, 0,p:-p,p:-p]).cpu().item())
        for n in range(1, n_steps_per_cyle):
            model_mixed.step()

            if (n + 1) % comparison_interval == 0:
                losses_mixed.append(rmse(model_mixed.psi[0, 0], psis[n][0, 0,p:-p,p:-p]).cpu().item())
        
    # Store this sub-domain's diagnostics (one entry per sub-domain per key).
    for key, value in (
        ("alphas", alphas),
        ("dalphas", dalphas),
        ("psi2s", psi2s),
        ("dpsi2s", dpsi2s),
        ("dpsis", dpsis),
        ("losses_1l", losses_1l),
        ("losses_2l", losses_2l),
        ("losses_mixed", losses_mixed),
        ("psi0_ref", psi0_ref),
    ):
        outputs_mixed[key].append(value)

    # Release cached GPU memory and collect garbage before the next sub-domain.
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
from qgsw import plots

fig, axs = plots.subplots(
    len(output_files_mixed),
    2,
    gridspec_kw={"width_ratios": [1, 5]},
    figsize=(12, 4 * len(output_files_mixed)),
)
for i, indices in enumerate(outputs_mixed["indices"]):
    imin, imax, jmin, jmax = indices
    print("ɑ: ", outputs_mixed["alphas"][i])
    print("dɑ: ", outputs_mixed["dalphas"][i])

    ax_map, ax_loss = axs[i, 0], axs[i, 1]

    # Left panel: start-up stream function (first layer) with the sub-domain outlined.
    plots.imshow(psi_start[0, 0], ax=ax_map)
    ax_map.hlines([jmin, jmax], imin, imax, color="black")
    ax_map.vlines([imin, imax], jmin, jmax, color="black")

    # Right panel: relative RMSE of each reduced model; dashed line marks 1.
    ax_loss.hlines(1, 0, len(outputs_mixed["losses_1l"][i]), linestyle="dashed", color="brown", alpha=0.75)
    ax_loss.plot(outputs_mixed["losses_1l"][i], color="blue", label="1 layer")
    ax_loss.plot(outputs_mixed["losses_2l"][i], color="black", label="2 layers")
    ax_loss.plot(outputs_mixed["losses_mixed"][i], color="purple", label="Mixed")
    ax_loss.legend(loc="upper left")
    # Same axis label as the detailed per-sub-domain figure.
    ax_loss.set_ylabel("RMSE Evolution")

plots.show()