In [None]:
%load_ext autoreload
%autoreload 2

## Setup

### Libraries

In [None]:
import torch
import seaborn as sns
from qgsw.logging import getLogger, setup_root_logger
from qgsw.specs import defaults
import gc

# Ask cuDNN for deterministic algorithms so repeated runs give identical results.
torch.backends.cudnn.deterministic = True
# Disable autograd globally: nothing below computes gradients.
torch.set_grad_enabled(False)
sns.set_theme("notebook",style="dark")

# Default tensor specs; used below both as ``.to(**specs)`` and ``specs["device"]``.
specs = defaults.get()

# NOTE(review): meaning of the level argument (1) is defined by qgsw.logging — TODO confirm.
setup_root_logger(1)
logger = getLogger(__name__)

### Parameters

In [None]:
from qgsw.configs.core import Configuration
from qgsw.forcing.wind import WindForcing


# Configuration of the reference run whose optimisation results are compared below.
config = Configuration.from_toml("../output/g5k/param_optim/_config.toml")

# Presumably layer thicknesses and reduced gravities (per the config field names).
H = config.model.h
g_prime = config.model.g_prime
# NOTE(review): H2 is not used anywhere in this notebook.
H1, H2 = H[0], H[1]
g1, g2 = g_prime[0], g_prime[1]
beta_plane = config.physics.beta_plane
bottom_drag_coef = config.physics.bottom_drag_coefficient
slip_coef = config.physics.slip_coef

# Wind stress on the full domain; sliced to each sub-domain later.
wind = WindForcing.from_config(
    config.windstress,
    config.space,
    config.physics,
)
tx, ty = wind.compute()
# Padding (grid points) added on every side when extracting the reference psi
# around a sub-domain (see extract_psi_w_ / extract_psi_bc).
p = 4

### Space

In [None]:
from qgsw.spatial.core.discretization import SpaceDiscretization3D


# Full-domain 3D discretisation built from the space and model sections of the config.
space = SpaceDiscretization3D.from_config(config.space, config.model)
dx = space.dx
dy = space.dy

In [None]:
from qgsw.solver.boundary_conditions.base import Boundaries


def get_psi_slices(imin:int,imax:int,jmin:int,jmax:int) -> tuple[slice, slice]:
    """Return the slices selecting the inclusive index ranges [imin, imax] x [jmin, jmax].

    Args:
        imin, imax: first and last (inclusive) index along the first sliced axis.
        jmin, jmax: first and last (inclusive) index along the second sliced axis.

    Returns:
        ``(slice(imin, imax + 1), slice(jmin, jmax + 1))``; usable as
        ``psi[..., *slices]`` or unpacked as ``sx, sy = get_psi_slices(...)``.
    """
    return slice(imin, imax + 1), slice(jmin, jmax + 1)

def extract_psi_w_(
    psi: torch.Tensor, imin: int, imax: int, jmin: int, jmax: int
) -> torch.Tensor:
    """Crop ``psi`` to [imin, imax] x [jmin, jmax] widened by ``p`` points on each side.

    The ranges are inclusive (as produced by ``get_psi_slices``) and apply to
    the last two dimensions of ``psi``; leading dimensions are kept whole.
    """
    rows, cols = get_psi_slices(imin - p, imax + p, jmin - p, jmax + p)
    return psi[..., rows, cols]


def extract_psi_bc(psi: torch.Tensor) -> Boundaries:
    """Extract the boundary values of a padded ``psi`` as ``Boundaries``.

    The bounds handed to ``Boundaries.extract`` skip the padding ``p`` added by
    ``extract_psi_w_`` (``p`` .. ``-(p + 1)`` on both axes). The trailing ``2`` is
    forwarded unchanged; its meaning is defined by ``Boundaries.extract``.
    """
    start = p
    stop = -(p + 1)
    return Boundaries.extract(psi, start, stop, start, stop, 2)

### Simulation

In [None]:
# Number of cycles simulated back to back (each cycle restarts the sub-domain models).
n_cycles = 3
# Reference snapshots (and sub-domain steps) per cycle; name kept as used below.
n_steps_per_cyle = 250
# Time step shared by every model (presumably seconds: 7200 s = 2 h — TODO confirm).
dt = 7_200
# NOTE(review): not used anywhere below; the loss is computed at every step.
comparison_interval = 1

### Relative RMSE

In [None]:
def rmse(f: torch.Tensor, f_ref: torch.Tensor) -> torch.Tensor:
    """Relative root-mean-square error of ``f`` against ``f_ref``.

    Computes ``sqrt(mean((f - f_ref)**2)) / sqrt(mean(f_ref**2))`` over all
    elements. A value of 0 means a perfect match, and 1 means the error is as
    large as the reference itself. Returns a 0-d tensor. An all-zero ``f_ref``
    divides by zero (``inf``/``nan``).
    """
    err_rms = torch.sqrt(torch.mean((f - f_ref) ** 2))
    ref_rms = torch.sqrt(torch.mean(f_ref ** 2))
    return err_rms / ref_rms

## Initial condition

In [None]:
from qgsw.fields.variables.tuples import UVH
from qgsw.masks import Masks
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.utils import covphys


# Projector on the full 3D space with empty masks. It converts the saved
# startup state to psi here, and its ``space`` is reused later to cut
# sub-domain grids.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),
    # Two trailing singleton dims so H broadcasts over the horizontal grid.
    H=H.unsqueeze(-1).unsqueeze(-1),
    space=space,
    f0=beta_plane.f0,
    masks=Masks.empty(
        nx=config.space.nx,
        ny=config.space.ny,
    ),
)
uvh0 = UVH.from_file("../output/g5k/param_optim/_data_startup.pt")
# Convert to covariant variables, take the first output of compute_p
# (presumably the pressure — TODO confirm) and divide by f0 to get psi.
psi_start = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / beta_plane.f0

### Full-domain model

In [None]:
from qgsw.models.qg.psiq.core import QGPSIQ


# Reference model on the full domain. Its trajectory provides the boundary
# data and the target psi for every sub-domain model below.
model_3l = QGPSIQ(
    space_2d=space.remove_z_h(),
    H=H,
    beta_plane=config.physics.beta_plane,
    g_prime=g_prime,
)
model_3l.set_wind_forcing(tx, ty)
model_3l.masks = Masks.empty_tensor(
    model_3l.space.nx,
    model_3l.space.ny,
    device=specs["device"],
)
# The reference keeps the configured bottom drag; the sub-domain models set
# bottom_drag_coef = 0 in ModelWrapper._set_params.
model_3l.bottom_drag_coef = bottom_drag_coef
model_3l.slip_coef = slip_coef
model_3l.dt = dt
# Reference y for the beta effect; copied onto every sub-domain model.
y0 = model_3l.y0

## Open-boundary-condition (OBC) models

In [None]:
from collections.abc import Callable
from pathlib import Path
from typing import Generic, TypeVar

from matplotlib import pyplot as plt
import numpy as np

from qgsw.models.qg.psiq.core import QGPSIQCore
from qgsw.spatial.core.discretization import SpaceDiscretization2D

T = TypeVar("T", bound=QGPSIQCore)

class ModelWrapper(Generic[T]):
    """Common scaffolding around one sub-domain model compared to the reference.

    Subclasses build ``self.model`` in ``__init__`` (then call ``_set_params``)
    and implement ``compute_q`` and ``setup``. The wrapper keeps one list of
    losses per cycle. It loads pre-computed results from
    ``results_paths / f"{prefix}_{imin}_{imax}_{jmin}_{jmax}.pt"``.
    """

    model: T
    color: str  # matplotlib color used by plot_loss
    prefix: str | None  # results-file prefix; None for wrappers that load nothing
    label: str  # legend label used by plot_loss
    no_wind = False  # when True, set_wind_forcing is a no-op
    results_paths = Path("../output/g5k/param_optim")
    losses: list[list[float]]  # losses[cycle][step], filled by add_loss
    ijs: tuple[int, int, int, int] | None = None  # (imin, imax, jmin, jmax) of the sub-domain
    cycle = -1  # incremented by new_cycle; -1 means no cycle started yet
    save_params = False  # when True, subclasses keep the loaded parameters per cycle

    def __init__(self, space_2d: SpaceDiscretization2D) -> None:
        # space_2d is not used here; subclasses use it to build self.model.
        self.losses = []

    def _set_params(self) -> None:
        """Copy the shared run parameters onto ``self.model``."""
        space = self.model.space
        self.model.y0 = y0  # same beta-effect reference y as the full-domain model
        self.model.masks = Masks.empty_tensor(
            space.nx,
            space.ny,
            device=specs["device"],
        )
        # Sub-domain models run without bottom drag (the reference keeps it).
        self.model.bottom_drag_coef = 0
        self.model.wide = True
        self.model.slip_coef = slip_coef
        self.model.dt = dt

    def compute_q(self, psi: torch.Tensor, beta_effect: torch.Tensor) -> torch.Tensor:
        """Compute the PV of ``psi``; must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement compute_q.")

    def set_wind_forcing(self, tx: torch.Tensor, ty: torch.Tensor) -> None:
        """Forward the wind stress to the model unless ``no_wind`` is set."""
        if self.no_wind:
            return
        self.model.set_wind_forcing(tx, ty)

    def load(self, imin: int, imax: int, jmin: int, jmax: int) -> dict:
        """Load the results saved for the sub-domain [imin, imax] x [jmin, jmax].

        Subclasses index the result by cycle, e.g. ``res[self.cycle]["alpha"]``.

        Raises:
            ValueError: if this wrapper has no results ``prefix``.
        """
        if self.prefix is None:
            msg = f"{type(self).__name__} has no results prefix; nothing to load."
            raise ValueError(msg)
        file = self.results_paths.joinpath(f"{self.prefix}_{imin}_{imax}_{jmin}_{jmax}.pt")
        # NOTE(review): torch.load unpickles the file; only load trusted results.
        return torch.load(file)

    def new_cycle(self) -> None:
        """Advance the cycle counter and start a new (empty) loss list."""
        self.cycle += 1
        self.losses.append([])

    def add_loss(self, loss: float) -> None:
        """Append ``loss`` to the current cycle's losses."""
        self.losses[-1].append(loss)

    def setup(self, psis: list[torch.Tensor], times: list[torch.Tensor], beta_effect_w: torch.Tensor) -> None:
        """Initialise the model for one cycle; must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement setup.")

    def step(self) -> None:
        """Advance the wrapped model by one time step."""
        self.model.step()

    def plot_loss(self, *, ax: plt.Axes | None = None, cycle: int | None = None) -> None:
        """Plot the losses of one ``cycle``, or of all cycles concatenated.

        Args:
            ax: axes to draw on; defaults to the current axes.
            cycle: index into ``self.losses``; ``None`` plots every cycle end to end.
        """
        if ax is None:
            ax = plt.gca()
        if cycle is not None:
            loss = np.array(self.losses[cycle])
        elif self.losses:
            loss = np.concatenate(self.losses)
        else:
            # np.concatenate([]) raises; plot an empty line instead.
            loss = np.array([])
        ax.plot(loss, color=self.color, label=self.label)
        

class ModelsManager:
    """Broadcast operations to a group of ``ModelWrapper`` instances.

    ``ijs`` and ``save_params`` read from the first wrapper. Assigning either
    property writes the value to every wrapper.
    """

    def __init__(self, *mw: ModelWrapper[QGPSIQCore]) -> None:
        self.model_wrappers = mw

        # Assign through the setters so all wrappers share the first wrapper's values.
        leader = self.model_wrappers[0]
        self.ijs = leader.ijs
        self.save_params = leader.save_params

    @property
    def ijs(self) -> tuple[int, int, int, int]:
        return self.model_wrappers[0].ijs

    @ijs.setter
    def ijs(self, ijs: tuple[int, int, int, int]) -> None:
        for wrapper in self.model_wrappers:
            wrapper.ijs = ijs

    @property
    def save_params(self) -> bool:
        return self.model_wrappers[0].save_params

    @save_params.setter
    def save_params(self, save_params: bool) -> None:
        for wrapper in self.model_wrappers:
            wrapper.save_params = save_params

    def loop_over_models(self, func: Callable[[ModelWrapper], None]) -> None:
        """Call ``func`` on every wrapper, in construction order."""
        for wrapper in self.model_wrappers:
            func(wrapper)

    def step(self) -> None:
        for wrapper in self.model_wrappers:
            wrapper.step()

    def new_cycle(self) -> None:
        for wrapper in self.model_wrappers:
            wrapper.new_cycle()

    def compute_loss(self, psi_ref: torch.Tensor) -> None:
        """Record, for each wrapper, the relative RMSE of ``psi[0, 0]`` against ``psi_ref[0, 0]``."""
        ref_top = psi_ref[0, 0]
        for wrapper in self.model_wrappers:
            loss = rmse(wrapper.model.psi[0, 0], ref_top)
            wrapper.add_loss(loss.cpu().item())

    def reset_time(self) -> None:
        for wrapper in self.model_wrappers:
            wrapper.model.reset_time()

    def set_wind_forcing(self, tx: torch.Tensor, ty: torch.Tensor) -> None:
        for wrapper in self.model_wrappers:
            wrapper.set_wind_forcing(tx, ty)

    def setup(self, psis: list[torch.Tensor], times: list[torch.Tensor], beta_effect_w: torch.Tensor) -> None:
        for wrapper in self.model_wrappers:
            wrapper.setup(psis, times, beta_effect_w)

    def plot_loss(self, *, ax: plt.Axes | None = None, cycle: int | None = None) -> None:
        for wrapper in self.model_wrappers:
            wrapper.plot_loss(ax=ax, cycle=cycle)

### Reduced gravity

In [None]:
from torch import Tensor
from qgsw.solver.finite_diff import laplacian
from qgsw.spatial.core.discretization import SpaceDiscretization2D
from qgsw.spatial.core.grid_conversion import interpolate
from qgsw.utils.interpolation import QuadraticInterpolation
from qgsw.utils.reshaping import crop


class ReducedGravity(ModelWrapper[QGPSIQ]):
    """Single-layer model of the top layer, driven only by the top-layer psi.

    The model is built with thickness ``H1 * H2 / (H1 + H2)`` and reduced
    gravity ``g_prime[1]``. ``compute_q`` contains no lower-layer streamfunction.
    """
    prefix = None
    color = "blue"
    label="Reduced Gravity"
    def __init__(self, space_2d:SpaceDiscretization2D) -> None:
        super().__init__(space_2d)
        self.model= QGPSIQ(
            space_2d=space_2d,
            H=H[:1]*H[1:2]/(H[:1]+H[1:2]),
            beta_plane=beta_plane,
            g_prime=g_prime[1:2],
        )
        self._set_params()
    def compute_q(self,psi: Tensor, beta_effect:torch.Tensor) -> Tensor:
        """Top-layer PV: Laplacian(psi) - f0^2 (1/(H1 g1) + 1/(H1 g2)) psi, interpolated, plus ``beta_effect``.

        ``psi[..., 1:-1, 1:-1]`` matches the Laplacian's output, which is one
        point smaller on each side. ``beta_effect`` must match the grid that
        ``interpolate`` maps to.
        """
        return interpolate(laplacian(psi,dx,dy) - beta_plane.f0**2 * (1/H1/g1+1/H1/g2)*psi[...,1:-1,1:-1]) + beta_effect
    
    def setup(self, psis: list[torch.Tensor],times:list[torch.Tensor],beta_effect_w:torch.Tensor) -> None:
        """Initialise the model and its time-interpolated boundary maps for one cycle.

        Args:
            psis: reference psi snapshots on the padded sub-domain, one per
                entry of ``times``. Only the first layer (``[:, :1]``) is used.
            times: snapshot times (the main loop passes floats from
                ``model_3l.time.item()``).
            beta_effect_w: beta-effect term on the padded grid, added to the PV.
        """
        psi0 = psis[0]
        # psi boundary values at every snapshot.
        psi_bcs = [extract_psi_bc(psi[:,:1]) for psi in psis]
        # q boundary values at every snapshot.
        # NOTE(review): offsets (2, -3) and the trailing 3 are literals, unlike
        # extract_psi_bc which derives its offsets from p; confirm they stay
        # consistent if p changes.
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi[:, :1],beta_effect_w), 2, -3, 2, -3, 3
            )
            for psi in psis
        ]
        # Boundary values are interpolated in time between snapshots.
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))
        # Initial interior state from the first snapshot: psi cropped by p,
        # q (computed from the padded psi, already smaller) cropped by p - 1.
        self.model.set_psiq(crop(psi0[:,:1],p), crop(self.compute_q(psi0[:,:1],beta_effect_w),p-1))

### Two layers

In [None]:
from torch._tensor import Tensor

from qgsw.utils.reshaping import crop


class TwoLayers(ModelWrapper[QGPSIQ]):
    """Two-layer model on the sub-domain, using the first two layers of the reference.

    Both layers of the reference psi feed the boundary maps and the initial state.
    """
    prefix = None
    color = "black"
    label="Two Layers"
    def __init__(self, space_2d:SpaceDiscretization2D) -> None:
        super().__init__(space_2d)
        self.model= QGPSIQ(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
    def compute_q(self,psi: Tensor, beta_effect:torch.Tensor) -> Tensor:
        """Layered PV: Laplacian(psi) - f0^2 A psi (layers coupled by the model's A), interpolated, plus ``beta_effect``.

        The einsum contracts the layer index of ``psi`` with ``self.model.A``.
        ``psi[..., 1:-1, 1:-1]`` matches the Laplacian's one-point-per-side reduction.
        """
        return interpolate(laplacian(psi,dx,dy) - beta_plane.f0**2 * torch.einsum("lm,...mxy->...lxy",self.model.A,psi[...,1:-1,1:-1]))+beta_effect
    def setup(self, psis: list[Tensor], times:list[torch.Tensor],beta_effect_w: Tensor) -> None:
        """Initialise the model and its boundary maps for one cycle from the two upper layers of ``psis``.

        ``times`` are the snapshot times (floats from ``model_3l.time.item()`` in
        the main loop). ``beta_effect_w`` is the beta-effect term on the padded grid.
        """
        psi0 = psis[0]
        # psi boundary values at every snapshot (both layers).
        psi_bcs = [extract_psi_bc(psi[:,:2]) for psi in psis]
        # q boundary values at every snapshot.
        # NOTE(review): offsets (2, -3) and the trailing 3 are literals, not derived from p.
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi[:, :2],beta_effect_w), 2, -3, 2, -3, 3
            )
            for psi in psis
        ]
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))
        # Initial interior state: psi cropped by p, q cropped by p - 1.
        self.model.set_psiq(crop(psi0[:,:2],p), crop(self.compute_q(psi0[:,:2],beta_effect_w),p-1))

### Alpha

In [None]:
from torch._tensor import Tensor
from qgsw.models.qg.psiq.filtered.core import QGPSIQCollinearSF
from qgsw.utils.reshaping import crop


class Alpha(ModelWrapper[QGPSIQCollinearSF]):
    """Top-layer model where the lower-layer psi is replaced by ``alpha * psi`` (collinear).

    ``alpha`` and ``dalpha`` are read per cycle from the ``results_alpha`` files.
    """
    prefix = "results_alpha"
    color = "red"
    label="Collinear"
    def __init__(self, space_2d:SpaceDiscretization2D) -> None:
        super().__init__(space_2d)
        self.model= QGPSIQCollinearSF(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Loaded parameters per cycle (filled only when save_params is True).
        self.alphas = {}
        self.dalphas = {}
    def compute_q(self,psi: Tensor, beta_effect:torch.Tensor, alpha:torch.Tensor) -> Tensor:
        """Top-layer PV with the coupling term f0^2/(H1 g2) * alpha * psi, interpolated, plus ``beta_effect``."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
            + beta_plane.f0**2 * (1 / H1 / g2) * alpha * psi[..., 1:-1, 1:-1]
        ) + beta_effect
    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's alpha/dalpha, then initialise the model and its boundary maps.

        ``psis`` are reference snapshots on the padded grid (only layer 0 is
        used). ``times`` are their times. ``beta_effect_w`` is the beta-effect
        term on the padded grid.
        """
        res = self.load(*self.ijs)
        # NOTE(review): alpha/dalpha are not moved with .to(**specs), unlike
        # psi2/dpsi2 in Affine and Mixed — confirm device/dtype are already right.
        dalpha = res[self.cycle]["dalpha"]
        alpha= res[self.cycle]["alpha"]
        if self.save_params:
            self.dalphas[self.cycle] = dalpha
            self.alphas[self.cycle] = alpha
        # NOTE(review): the model's ``alpha`` attribute is filled with dalpha
        # (not alpha), as in Mixed.setup — confirm this is the intended field.
        self.model.alpha = torch.ones_like(self.model.psi)*dalpha
        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:,:1]) for psi in psis]
        # NOTE(review): the same alpha is used for every snapshot, whereas
        # Affine advances psi2 by n*dt*dpsi2 — confirm alpha should not evolve.
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi[:, :1],beta_effect_w,alpha), 2, -3, 2, -3, 3
            )
            for psi in psis
        ]
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))
        # Initial interior state: psi cropped by p, q cropped by p - 1.
        self.model.set_psiq(crop(psi0[:,:1],p), crop(self.compute_q(psi0[:,:1],beta_effect_w,alpha),p-1))


### Affine

In [None]:
from torch._tensor import Tensor
from qgsw.models.qg.psiq.filtered.core import QGPSIQFixeddSF2
from qgsw.utils.reshaping import crop


class Affine(ModelWrapper[QGPSIQFixeddSF2]):
    """Top-layer model where the lower-layer psi is a prescribed field ``psi2`` evolving at rate ``dpsi2``.

    ``psi2`` and ``dpsi2`` are read per cycle from the ``results_psi2`` files.
    """
    prefix = "results_psi2"
    color="orange"
    label="Affine"
    def __init__(self, space_2d:SpaceDiscretization2D) -> None:
        super().__init__(space_2d)
        self.model= QGPSIQFixeddSF2(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Loaded parameters per cycle (filled only when save_params is True).
        self.psi2s = {}
        self.dpsi2s = {}
    def compute_q(self,psi: Tensor, beta_effect:torch.Tensor, psi2:torch.Tensor) -> Tensor:
        """Top-layer PV with the coupling term f0^2/(H1 g2) * psi2, interpolated, plus ``beta_effect``.

        ``psi2`` must be on the same (padded) grid as ``psi``; both are cropped by one point per side.
        """
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
            + beta_plane.f0**2 * (1 / H1 / g2) * psi2[..., 1:-1, 1:-1]
        ) + beta_effect
    
    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's psi2/dpsi2, then initialise the model and its boundary maps.

        ``psis`` are reference snapshots on the padded grid (only layer 0 is
        used). ``times`` are their times. ``beta_effect_w`` is the beta-effect
        term on the padded grid.
        """
        res = self.load(*self.ijs)
        psi2 = res[self.cycle]["psi2"].to(**specs)
        dpsi2 = res[self.cycle]["dpsi2"].to(**specs)
        if self.save_params:
            self.dpsi2s[self.cycle] = dpsi2
            self.psi2s[self.cycle] = psi2
        # The model only sees the interior (padding cropped) rate of change.
        self.model.dpsi2 = crop(dpsi2,p)
        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:,:1]) for psi in psis]
        # psi2 is advanced linearly to snapshot n as psi2 + n*dt*dpsi2, which
        # assumes consecutive snapshots are dt apart (one per model_3l step in
        # the main loop).
        q_bcs = [
            Boundaries.extract(
                self.compute_q(psi[:, :1],beta_effect_w,psi2+n*dt*dpsi2), 2, -3, 2, -3, 3
            )
            for n,psi in enumerate(psis)
        ]
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))
        # Initial interior state: psi cropped by p, q cropped by p - 1.
        self.model.set_psiq(crop(psi0[:,:1],p), crop(self.compute_q(psi0[:,:1],beta_effect_w,psi2),p-1))

### Mixed

In [None]:
from torch._tensor import Tensor
from qgsw.models.qg.psiq.filtered.core import QGPSIQMixed


class Mixed(ModelWrapper[QGPSIQMixed]):
    """Top-layer model where the lower-layer psi is ``psi2 + alpha * psi`` (affine + collinear).

    ``alpha``/``dalpha`` and ``psi2``/``dpsi2`` are read per cycle from the
    ``results_mixed`` files.
    """
    prefix = "results_mixed"
    color="purple"
    label="Mixed"
    def __init__(self, space_2d:SpaceDiscretization2D) -> None:
        super().__init__(space_2d)
        self.model= QGPSIQMixed(
            space_2d=space_2d,
            H=H[:2],
            beta_plane=beta_plane,
            g_prime=g_prime[:2],
        )
        self._set_params()
        # Loaded parameters per cycle (filled only when save_params is True).
        self.alphas = {}
        self.dalphas = {}
        self.psi2s = {}
        self.dpsi2s = {}
    def compute_q(self,psi: Tensor, beta_effect:torch.Tensor, alpha:torch.Tensor, psi2:torch.Tensor) -> Tensor:
        """Top-layer PV with the coupling term f0^2/(H1 g2) * (psi2 + alpha * psi), interpolated, plus ``beta_effect``."""
        return interpolate(
            laplacian(psi, dx, dy)
            - beta_plane.f0**2 * (1 / H1 / g1 + 1 / H1 / g2) * psi[..., 1:-1, 1:-1]
            + beta_plane.f0**2 * (1 / H1 / g2) * (psi2[...,1:-1,1:-1]+alpha * psi[..., 1:-1, 1:-1])
        ) + beta_effect
    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Load this cycle's parameters, then initialise the model and its boundary maps.

        ``psis`` are reference snapshots on the padded grid (only layer 0 is
        used). ``times`` are their times. ``beta_effect_w`` is the beta-effect
        term on the padded grid.
        """
        res = self.load(*self.ijs)
        # NOTE(review): alpha/dalpha are not moved with .to(**specs), unlike psi2/dpsi2.
        alpha:torch.Tensor = res[self.cycle]["alpha"]
        dalpha:torch.Tensor = res[self.cycle]["dalpha"]
        psi2:torch.Tensor = res[self.cycle]["psi2"].to(**specs)
        dpsi2:torch.Tensor = res[self.cycle]["dpsi2"].to(**specs)
        if self.save_params:
            self.dalphas[self.cycle] = dalpha
            self.alphas[self.cycle] = alpha
            self.dpsi2s[self.cycle] = dpsi2
            self.psi2s[self.cycle] = psi2
        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:,:1]) for psi in psis]
        # psi2 advances linearly (n*dt*dpsi2 at snapshot n) while alpha is kept
        # constant over the cycle — NOTE(review): confirm alpha should not evolve.
        q_bcs = [
                Boundaries.extract(
                    self.compute_q(psi[:, :1],beta_effect_w,alpha,psi2+n*dt*dpsi2), 2, -3, 2, -3, 3
                )
                for n,psi in enumerate(psis)
            ]

        # Initial interior state: psi cropped by p, q cropped by p - 1.
        self.model.set_psiq(crop(psi0[:,:1],p), crop(self.compute_q(psi0[:,:1],beta_effect_w,alpha,psi2),p-1))
        # NOTE(review): model.alpha is filled with dalpha (not alpha), as in Alpha.setup.
        self.model.alpha = torch.ones_like(self.model.psi)*dalpha
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))
        # The model only sees the interior (padding cropped) rate of change of psi2.
        self.model.dpsi2 = crop(dpsi2,p)

### Forced

In [None]:
from qgsw.decomposition.sine import STSineBasis
from qgsw.models.qg.psiq.filtered.core import QGPSIQForced


class Forced(ModelWrapper[QGPSIQForced]):
    """Top-layer model driven by a space-time sine-basis forcing.

    ``compute_q`` uses the top-layer PV with a lower-layer term in ``psi2``.
    ``setup`` feeds it the reference lower-layer psi (``psi[:, 1:2]``). The
    forcing coefficients are read per cycle (keys ``"coefs_<order>"``). The
    forcing is re-evaluated at the model time before every step.
    """
    prefix = "results_forced"
    color="brown"
    label="Forced"
    order=5  # order passed to STSineBasis
    # NOTE(review): unlike the other wrappers, results are read from
    # ../output/local rather than ../output/g5k — confirm this is intended.
    results_paths = Path("../output/local/param_optim")
    def __init__(self, space_2d:SpaceDiscretization2D) -> None:
        super().__init__(space_2d)
        # NOTE(review): the generic parameter (and the QGPSIQForced import in
        # this cell) suggests a QGPSIQForced model, but a plain QGPSIQ with the
        # ReducedGravity parameters is built here. ``step`` assigns
        # ``self.model.forcing``; confirm QGPSIQ actually consumes it,
        # otherwise the forcing is silently ignored.
        self.model= QGPSIQ(
            space_2d=space_2d,
            H=H[:1]*H[1:2]/(H[:1]+H[1:2]),
            beta_plane=beta_plane,
            g_prime=g_prime[1:2],
        )
        self._set_params()
        # Forcing coefficients per cycle (filled only when save_params is True).
        self.coefs = {}

    def compute_q(self,psi: Tensor, beta_effect:torch.Tensor, psi2:torch.Tensor) -> Tensor:
        """Top-layer PV with the coupling term f0^2/(H1 g2) * psi2, interpolated, plus ``beta_effect``."""
        return interpolate(
            laplacian(psi,dx,dy)
            - beta_plane.f0**2 * (1/H1/g1+1/H1/g2)*psi[...,1:-1,1:-1]
            + beta_plane.f0**2 * (1/H1/g2)*psi2[...,1:-1,1:-1]
        ) + beta_effect

    def _parse_coefs(self, res:dict) -> dict[int, torch.Tensor]:
        """Collect the ``"coefs_<order>"`` entries of ``res`` as ``{order: tensor}``.

        Other keys are ignored. Tensors are moved with ``.to(**specs)``.
        """
        key_prefix = "coefs_"
        return {
            int(key[len(key_prefix):]): value.to(**specs)
            for key, value in res.items()
            if key.startswith(key_prefix)
        }

    def setup(self, psis: list[Tensor], times: list[Tensor], beta_effect_w: Tensor) -> None:
        """Build this cycle's forcing basis, then initialise the model and its boundary maps.

        ``psis`` are reference snapshots on the padded grid. Layer 0 is
        modelled, and layer 1 enters the PV as ``psi2``. ``times`` are the
        snapshot times. ``beta_effect_w`` is the beta-effect term on the padded grid.
        """
        res = self.load(*self.ijs)
        # Compute the 2D space once; x, y and t are all shifted to start at zero.
        sub_space = self.model.space.remove_z_h()
        x_q = sub_space.q.xy.x
        y_q = sub_space.q.xy.y
        self.basis = STSineBasis(
            x_q - x_q[:1, :],
            y_q - y_q[:, :1],
            torch.tensor(times,**specs)-times[0],
            order=self.order,
        )
        coefs:dict[int,torch.Tensor] = self._parse_coefs(res[self.cycle])
        if self.save_params:
            self.coefs[self.cycle] = coefs
        self.basis.set_coefs(coefs)

        psi0 = psis[0]
        psi_bcs = [extract_psi_bc(psi[:,:1]) for psi in psis]
        q_bcs = [
                Boundaries.extract(
                    self.compute_q(psi[:, :1],beta_effect_w,psi[:,1:2]), 2, -3, 2, -3, 3
                )
                for psi in psis
            ]

        # Initial interior state: psi cropped by p, q cropped by p - 1.
        self.model.set_psiq(crop(psi0[:,:1],p), crop(self.compute_q(psi0[:,:1],beta_effect_w,psi0[:,1:2]),p-1))
        self.model.set_boundary_maps(QuadraticInterpolation(times, psi_bcs), QuadraticInterpolation(times, q_bcs))

    def step(self) -> None:
        """Evaluate the forcing at the current model time, then step the model."""
        self.model.forcing = self.basis.at_time(self.model.time)
        return super().step()

## Sub-domain [32, 96] × [256, 384]

In [None]:
# Inclusive index bounds of the sub-domain along each axis.
imin, imax = 32, 96
jmin, jmax = 256, 384


def extract_psi_w(psi: torch.Tensor) -> torch.Tensor:
    """Crop ``psi`` to this sub-domain widened by the padding ``p``."""
    return extract_psi_w_(psi, imin, imax, jmin, jmax)

In [None]:
from qgsw.logging.utils import box, step
from qgsw.utils.reshaping import crop


model_3l.reset_time()
model_3l.set_psi(psi_start)

# Sub-domain grid: omega points imin..imax x jmin..jmax (inclusive).
space_slice = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin : imax + 1, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin : jmax + 1],
)

# Widened grid with p - 1 extra omega points on each side. Its q-grid y gives
# the beta effect added to the PV in every compute_q.
space_slice_w = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin - p + 1 : imax + p, 0],
    y=P.space.remove_z_h().omega.xy.y[0, jmin - p + 1 : jmax + p],
)
y_w = space_slice_w.q.xy.y[0, :].unsqueeze(0)
beta_effect_w = beta_plane.beta * (y_w - y0)

rg = ReducedGravity(space_slice)
tl = TwoLayers(space_slice)
alpha = Alpha(space_slice)
affine = Affine(space_slice)
mixed = Mixed(space_slice)
forced = Forced(space_slice)

models = ModelsManager(
    rg,
    tl,
    alpha,
    affine,
    mixed,
    forced,
)
models.ijs = (imin,imax,jmin,jmax)
models.save_params = True

models.set_wind_forcing(tx[imin:imax, jmin : jmax + 1],ty[imin : imax + 1, jmin:jmax])

# CUDA memory statistics raise on machines without CUDA; only track them when
# CUDA is available so the notebook also runs on CPU.
track_cuda_memory = torch.cuda.is_available()

for c in range(n_cycles):
    if track_cuda_memory:
        torch.cuda.reset_peak_memory_stats()
    models.new_cycle()
    times = [model_3l.time.item()]

    # Reference snapshots on the padded sub-domain (first two layers only).
    psi0 = extract_psi_w(model_3l.psi[:,:2])

    psis = [psi0]

    for _ in range(1, n_steps_per_cyle):
        model_3l.step()

        times.append(model_3l.time.item())

        psi = extract_psi_w(model_3l.psi[:,:2])

        psis.append(psi)

    # Restart every sub-domain model from the first snapshot of this cycle.
    models.reset_time()

    models.setup(psis,times,beta_effect_w)

    # Loss at every step against the reference interior (padding cropped).
    models.compute_loss(crop(psis[0],p))
    for n in range(1,n_steps_per_cyle):
        models.step()
        models.compute_loss(crop(psis[n],p))

    # Collect unreachable Python objects first so that empty_cache can
    # actually return their CUDA blocks.
    gc.collect()
    if track_cuda_memory:
        torch.cuda.empty_cache()
        max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | Max memory allocated: {max_mem:.1f} MB."
    else:
        msg_mem = f"Cycle {step(c + 1, n_cycles)} | CUDA unavailable: memory not tracked."
    logger.info(box(msg_mem, style="round"))

In [None]:
from matplotlib import pyplot as plt

from qgsw import plots

fig,axs = plots.subplots(1,2, gridspec_kw = {"width_ratios":[1,5]},figsize=(15, 4))
# Left panel: initial psi (index [0, 0] of psi_start) with the sub-domain outlined.
# NOTE(review): lines are drawn at y=jmin/jmax and x=imin/imax; confirm this
# matches the axis orientation used by plots.imshow.
plots.imshow(psi_start[0,0],ax=axs[0,0])
axs[0,0].hlines([jmin,jmax],imin,imax,color="black")
axs[0,0].vlines([imin,imax],jmin,jmax,color="black")
# Right panel: relative RMSE of every model over all cycles. The dashed line
# marks 1, i.e. an error as large as the reference itself.
axs[0,1].hlines(1,0,n_cycles*n_steps_per_cyle,linestyle="dashed",color="grey",alpha=0.75)
models.plot_loss(ax=axs[0,1])
axs[0,1].legend(loc="upper left")

In [None]:
# The cycle loop leaves the forced model's basis with the last cycle's
# coefficients; put the first cycle's back, then display them (only the last
# expression of a cell is shown, so it must come last).
basis = forced.basis
basis.set_coefs(forced.coefs[0])
forced.coefs[0]


In [None]:
from qgsw.plots.heatmaps import AnimatedHeatmaps


# Evaluate the forcing currently set on the basis at every step of one cycle
# and render it as a video. Times are relative to the cycle start, like the
# time axis the basis is built with in Forced.setup. Using dt and
# n_steps_per_cyle keeps this in sync with the simulation settings.
datas = [
    forced.basis.at_time(torch.tensor([i * dt], **specs)).cpu().T
    for i in range(n_steps_per_cyle)
]

plot = AnimatedHeatmaps([datas])
plot.save_video("out.mp4")