In [None]:
# Reload edited modules (e.g. the local `qgsw` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import seaborn as sns

# Seaborn theme for every figure in this notebook: "notebook" context, dark style.
sns.set_theme(context="notebook", style="dark")

In [None]:
import datetime
from pathlib import Path

import torch

from qgsw.spatial.core.grid_conversion import interpolate
from qgsw import plots
from qgsw.configs.core import Configuration
from qgsw.fields.variables.tuples import UVH
from qgsw.forcing.wind import WindForcing
from qgsw.logging import getLogger, setup_root_logger
from qgsw.masks import Masks
from qgsw.models.qg.psiq.core import QGPSIQ
from qgsw.models.qg.psiq.filtered.core import (
    QGPSIQCollinearSF,
    QGPSIQFixeddSF2,
)
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.solver.boundary_conditions.base import Boundaries
from qgsw.spatial.core.discretization import (
    SpaceDiscretization2D,
    SpaceDiscretization3D,
)
from qgsw.specs import defaults
from qgsw.utils import covphys
from qgsw.utils.interpolation import QuadraticInterpolation

import gc

# Autograd is switched off globally for the session; cuDNN is asked for
# deterministic kernels so repeated runs give identical results.
torch.set_grad_enabled(False)
torch.backends.cudnn.deterministic = True

# Default tensor specs, used below both as `specs["device"]` and `.to(**specs)`.
specs = defaults.get()

# Root logger (argument 1 is presumably a verbosity level — confirm) and a
# module-level logger for this notebook.
setup_root_logger(1)
logger = getLogger(__name__)

# NOTE(review): relative paths assume the notebook runs from its own directory.
config = Configuration.from_toml("../output/g5k/param_optim/_config.toml")


# Parameters

# Layer thicknesses and reduced gravities from the config. Only the first two
# layers are unpacked into scalars; H1, g1 and g2 are used by `compute_q_`
# (H2 is not used in this notebook as visible).
H = config.model.h
g_prime = config.model.g_prime
H1, H2 = H[0], H[1]
g1, g2 = g_prime[0], g_prime[1]
beta_plane = config.physics.beta_plane
bottom_drag_coef = config.physics.bottom_drag_coefficient
slip_coef = config.physics.slip_coef

space = SpaceDiscretization3D.from_config(
    config.space,
    config.model,
)
# Projector on the full, mask-free domain; only its `compute_p` is used here,
# to turn the startup state into a streamfunction.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),
    # Two trailing singleton axes, presumably so H broadcasts over (nx, ny) — confirm.
    H=H.unsqueeze(-1).unsqueeze(-1),
    space=space,
    f0=beta_plane.f0,
    masks=Masks.empty(
        nx=config.space.nx,
        ny=config.space.ny,
    ),
)
dx, dy = space.dx, space.dy
nx, ny = space.nx, space.ny

# Wind stress components on the full grid; sliced per sub-domain later.
wind = WindForcing.from_config(
    config.windstress,
    config.space,
    config.physics,
)
tx, ty = wind.compute()

# Startup state: convert to covariant form, take the first output of
# `compute_p` and divide by f0. Presumably compute_p returns (pressure, ...)
# and p / f0 is the streamfunction — TODO confirm.
uvh0 = UVH.from_file("../output/g5k/param_optim/_data_startup.pt")
psi_start = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / beta_plane.f0

## Areas


def compute_slices(
    imin: int, imax: int, jmin: int, jmax: int
) -> tuple[list[slice], list[slice]]:
    """Compute horizontal slices selecting a rectangular sub-domain.

    The streamfunction slices include the end indices ``imax`` / ``jmax``.
    The PV slices stop just before them, so they select one index fewer
    along each axis.

    Args:
        imin: First index along the first horizontal axis.
        imax: Last index along the first horizontal axis (inclusive for psi).
        jmin: First index along the second horizontal axis.
        jmax: Last index along the second horizontal axis (inclusive for psi).

    Returns:
        ``(psi_slices, q_slices)``, each a ``[slice_x, slice_y]`` list.
        They can be unpacked into an index, e.g. ``psi[..., *psi_slices]``.
    """
    psi_slices = [slice(imin, imax + 1), slice(jmin, jmax + 1)]
    q_slices = [slice(imin, imax), slice(jmin, jmax)]
    return psi_slices, q_slices

## Simulation parameters

# Model time step (presumably seconds, i.e. 2 hours — TODO confirm units).
dt = 7_200
# Maximum number of optimisation steps, and the number of digits needed to
# print it (presumably for zero-padded labels).
optim_max_step = 200
str_optim_len = len(f"{optim_max_step}")
# Number of recorded states per cycle; each cycle advances the models by
# n_steps_per_cyle - 1 steps.
n_steps_per_cyle = 2
comparison_interval = 1

## Error


def rmse(f: torch.Tensor, f_ref: torch.Tensor) -> torch.Tensor:
    """Relative root-mean-square error of ``f`` with respect to ``f_ref``.

    Computes ``sqrt(mean((f - f_ref)**2)) / sqrt(mean(f_ref**2))``. This is
    the RMSE normalised by the RMS of the reference field.

    Args:
        f: Field to evaluate.
        f_ref: Reference field (must broadcast against ``f``).

    Returns:
        A 0-d tensor holding the relative RMSE; call ``.item()`` for a Python
        float. The result is ``nan``/``inf`` when ``f_ref`` is identically zero.
    """
    return (f - f_ref).square().mean().sqrt() / f_ref.square().mean().sqrt()


# Models
## Three Layer model

# Reference simulation on the full domain, using every layer of H / g_prime.
model_3l = QGPSIQ(
    space_2d=space.remove_z_h(),
    H=H,
    beta_plane=config.physics.beta_plane,
    g_prime=g_prime,
)
model_3l.set_wind_forcing(tx, ty)
# Empty masks on the model grid (presumably no land points — confirm).
model_3l.masks = Masks.empty_tensor(
    model_3l.space.nx,
    model_3l.space.ny,
    device=specs["device"],
)
# The reference model keeps the configured bottom drag; the sub-domain models
# set it to 0 in `set_inhomogeneous_model`.
model_3l.bottom_drag_coef = bottom_drag_coef
model_3l.slip_coef = slip_coef
model_3l.dt = dt
# Reused by the sub-domain models (`model.y0 = y0`) and in the beta term
# `beta * (y - y0)` of the widened sub-domain.
y0 = model_3l.y0


## Inhomogeneous models
def set_inhomogeneous_model(
    model: QGPSIQ | QGPSIQCollinearSF | QGPSIQFixeddSF2,
) -> QGPSIQ | QGPSIQCollinearSF | QGPSIQFixeddSF2:
    """Configure a sub-domain model in place and hand it back.

    The model receives:
    - the reference ``y0``;
    - empty masks sized to its own grid;
    - zero bottom drag;
    - ``wide = True``;
    - the notebook-level ``slip_coef`` and ``dt``.

    Args:
        model: Model to configure; it is mutated in place.

    Returns:
        The same ``model`` object, for chaining.
    """
    grid = model.space
    model.y0 = y0
    empty_masks = Masks.empty_tensor(grid.nx, grid.ny, device=specs["device"])
    model.masks = empty_masks
    model.bottom_drag_coef = 0
    model.wide = True
    model.slip_coef, model.dt = slip_coef, dt
    return model

In [None]:
from qgsw.pv import compute_q1_interior, compute_q2_2l_interior

def compute_q_(
    psi1: torch.Tensor, psi2: torch.Tensor, beta_effect: torch.Tensor
) -> torch.Tensor:
    """Interior PV from two streamfunction layers via ``compute_q1_interior``.

    This is a thin wrapper binding the notebook-level ``H1``, ``g1``, ``g2``,
    ``dx``, ``dy`` and ``beta_plane.f0``. Those names are looked up at call
    time, exactly as the previous lambda did.

    Args:
        psi1: First-layer streamfunction (presumably the top layer — confirm).
        psi2: Second-layer streamfunction.
        beta_effect: Beta term matching the output grid, e.g.
            ``beta * (y - y0)``.

    Returns:
        Whatever ``compute_q1_interior`` returns for these inputs.
    """
    return compute_q1_interior(
        psi1,
        psi2,
        H1,
        g1,
        g2,
        dx,
        dy,
        beta_plane.f0,
        beta_effect,
    )

In [None]:
from qgsw.models.qg.psiq.filtered.core import QGPSIQMixed


# Result files of the mixed-model parameter optimisation to post-process.
output_files = [
    "../output/g5k/param_optim/results_mixed_32_96_64_192.pt",
]

# Per-file metadata gathered by the loop below.
outputs = dict(indices=[])

for f in output_files:
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    res_mixed=torch.load(f)

    # res_mixed is indexed by cycle (res_mixed[c]); the sub-domain indices
    # are read from the first entry.
    indices = res_mixed[0]["coords"]
    n_cycles = len(res_mixed)

    logger.info(f"Indices: {indices}")

    outputs["indices"].append(indices)

    imin, imax, jmin, jmax = indices
    # Halo width (grid points) added on each side of the sub-domain.
    p = 4
    psi_slices_w = [slice(imin - p, imax + p + 1), slice(jmin - p, jmax + p + 1)]

    def extract_psi_w(psi: torch.Tensor) -> torch.Tensor:
        """Extract psi on the sub-domain widened by ``p`` points on each side."""
        return psi[..., psi_slices_w[0], psi_slices_w[1]]


    def extract_psi_bc(psi: torch.Tensor) -> Boundaries:
        """Extract boundaries of a widened psi, offset ``p`` points from its edges.

        NOTE(review): presumably these are the sub-domain edges; the meaning
        of the final argument (2) is not visible here — confirm.
        """
        return Boundaries.extract(psi, p, -p - 1, p, -p - 1, 2)

    # Restart the reference model from the initial state for this file.
    model_3l.reset_time()
    model_3l.set_psi(psi_start)

    psi_slices, q_slices = compute_slices(imin, imax, jmin, jmax)

    # 2D grid of the sub-domain (vertex indices imin..imax, jmin..jmax).
    space_slice = SpaceDiscretization2D.from_tensors(
        x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
        y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
    )

    # Grid of the sub-domain extended by p - 1 points per side; the y of its
    # cell centres builds the beta term used for PV on the widened domain.
    space_slice_w = SpaceDiscretization2D.from_tensors(
        x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
        y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
    )
    y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
    beta_effect_w = beta_plane.beta * (y_w - y0)

    # PV of widened (psi1, psi2) fields with the matching beta term.
    compute_q_mixed = lambda psi, psi2 : compute_q_(psi,psi2, beta_effect_w)

    # Two-layer mixed model on the sub-domain, using the top two layers.
    model_mixed = QGPSIQMixed(
        space_2d=space_slice,
        H=H[:2],
        beta_plane=beta_plane,
        g_prime=g_prime[:2],
    )

    model_mixed: QGPSIQMixed = set_inhomogeneous_model(model_mixed)
    # tx is sliced one index shorter along x and ty one index shorter along y
    # (presumably staggered u/v locations — confirm).
    model_mixed.set_wind_forcing(
        tx[imin:imax, jmin : jmax + 1], ty[imin : imax + 1, jmin:jmax]
    )

    # TODO(review): only the first cycle is processed; n_cycles is otherwise
    # unused. Restore `range(n_cycles)` to process every cycle.
    # for c in range(n_cycles):
    for c in range(1):

        times = [model_3l.time.item()]

        # Reference psi of the top two layers on the widened sub-domain.
        psi0_ref = extract_psi_w(model_3l.psi[:,:2])
        # NOTE(review): q0_ref is not used afterwards.
        q0_ref = model_3l.q[:,:1,*q_slices]
        psi_bc = extract_psi_bc(psi0_ref)

        psis = [psi0_ref]
        psi_bcs = [psi_bc]

        # Advance the reference model, recording widened psi and its
        # boundaries after every step.
        for n in range(1, n_steps_per_cyle):
            model_3l.step()

            times.append(model_3l.time.item())

            psi = extract_psi_w(model_3l.psi[:,:2])
            psi_bc = extract_psi_bc(psi)

            psis.append(psi)
            psi_bcs.append(psi_bc)

        # Time interpolation of the top-layer psi boundaries. The
        # comprehension variable `p` is local to the comprehension, so the
        # halo width `p` above is not overwritten (Python 3 scoping).
        psi_bc_interp_1l = QuadraticInterpolation(times, [p[:,:1] for p in psi_bcs])

        # Optimised parameters stored for this cycle.
        alpha:torch.Tensor = res_mixed[c]["alpha"]
        dalpha:torch.Tensor = res_mixed[c]["dalpha"]
        psi2:torch.Tensor = res_mixed[c]["psi2"].to(**specs)
        dpsi2:torch.Tensor = res_mixed[c]["dpsi2"].to(**specs)

        model_mixed.reset_time()

        # PV boundaries at each recorded time. The second layer is
        # psi2 + n*dt*dpsi2 (linear in time, step index n) plus alpha times
        # the top-layer psi.
        q_bcs = [
                Boundaries.extract(
                    compute_q_mixed(psi[:, :1],psi2+n*dt*dpsi2 + alpha*psi[:,:1]), 2, -3, 2, -3, 3
                )
                for n,psi in enumerate(psis)
            ]
        # Initial top-layer psi on the sub-domain (halo of width p removed).
        psi0 = psi0_ref[:,:1,p:-p,p:-p]
        # Initial PV cropped by 3 points per side, presumably to match the
        # sub-domain cell count — confirm.
        q0 = compute_q_mixed(psi0_ref[:,:1],psi2+alpha*psi0_ref[:,:1])[...,3:-3,3:-3]
        model_mixed.set_psiq(psi0, q0)
        # NOTE(review): the model's `alpha` is filled with `dalpha`, while the
        # PV above uses `alpha` — confirm this is intended.
        model_mixed.alpha = torch.ones_like(model_mixed.psi)*dalpha
        model_mixed.set_boundary_maps(psi_bc_interp_1l, QuadraticInterpolation(times, q_bcs))
        model_mixed.dpsi2 = dpsi2[...,p:-p,p:-p]
        # NOTE(review): dead assignment — immediately rebound by the loop below.
        n=0
        for n in range(1, n_steps_per_cyle):
            model_mixed.step()


    # Release cached GPU memory and collect garbage between result files.
    torch.cuda.empty_cache()
    gc.collect()


In [None]:
from qgsw.models.core.flux import div_flux_5pts_no_pad
from qgsw.solver.finite_diff import grad_perp


psi1 = psi_start[:,:1,*psi_slices]
u1,v1 = grad_perp(psi1)
u1/=dy
v1/=dx

psi2_ = interpolate(psi_start[:,1:2,*psi_slices])

j = div_flux_5pts_no_pad(
    psi2_,u1[...,1:-1,:],v1[...,1:-1],dx,dy
)

psi2_tilde_ = interpolate(psi2[...,p:-p,p:-p]+alpha*psi1)
j_ = div_flux_5pts_no_pad(
    psi2_tilde_,u1[...,1:-1,:],v1[...,1:-1],dx,dy
)

In [None]:
# Shape check of the flux inputs (space-separated, one line).
print(*(tensor.shape for tensor in (psi2_, u1, v1)))

In [None]:
fig, axs = plots.subplots(1,3)

dtpsi2 = (model_3l.psi-psi_start)[:,1:2,*psi_slices]/model_3l.time.item()
dtpsi1 = (model_mixed.psi-psi_start[:,:1,*psi_slices])/model_mixed.time.item()

plots.imshow((j[0,0]+interpolate(dtpsi2)[0,0,1:-1,1:-1])[3:-3,3:-3],ax=axs[0,0])
plots.imshow((j[0,0]+interpolate(dtpsi2)[0,0,1:-1,1:-1]-j_[0,0]-interpolate(dpsi2[0,0,p:-p,p:-p]+dalpha*dtpsi1[0,0])[1:-1,1:-1])[3:-3,3:-3],ax=axs[0,1])
plots.imshow((j_[0,0]+interpolate(dpsi2[0,0,p:-p,p:-p]+dalpha*dtpsi1[0,0])[1:-1,1:-1])[3:-3,3:-3],ax=axs[0,2])