In [None]:
# Reload edited project modules (e.g. qgsw) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import seaborn as sns

# Global seaborn/matplotlib look for every figure of the notebook.
sns.set_theme(context="notebook", style="dark")

## Model setup

In [None]:
from qgsw import plots
import matplotlib.pyplot as plt
import torch
from qgsw.fields.variables.tuples import UVH
from qgsw.forcing.wind import WindForcing
from qgsw.masks import Masks
from qgsw.models.qg.psiq.core import QGPSIQ
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.output import RunOutput
from qgsw.spatial.core.discretization import SpaceDiscretization2D, SpaceDiscretization3D
from qgsw.spatial.core.grid_conversion import points_to_surfaces
from qgsw.specs import defaults
from qgsw.utils import covphys
from qgsw.filters.gaussian import GaussianFilter2D
from qgsw.solver.boundary_conditions.base import Boundaries
from qgsw.solver.finite_diff import laplacian
from qgsw.utils.interpolation import LinearInterpolation, QuadraticInterpolation

# Reference run whose stored states are used as initial/boundary conditions.
run = RunOutput("../output/g5k/sw_double_gyre_long_hr")
cfg = run.summary.configuration

# Vertical structure and physical constants of the reference run.
H = cfg.model.h
g_prime = cfg.model.g_prime
f0 = cfg.physics.f0
beta = cfg.physics.beta

# Projector used below to turn the stored (u, v, h) state into pressure.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),
    H=H.unsqueeze(-1).unsqueeze(-1),
    space=SpaceDiscretization3D.from_config(cfg.space, cfg.model),
    f0=cfg.physics.f0,
    masks=Masks.empty(nx=cfg.space.nx, ny=cfg.space.ny),
)
A = P.A
space = P.space
dx, dy = space.dx, space.dy
nx, ny = space.nx, space.ny

wind = WindForcing.from_config(cfg.windstress, cfg.space, cfg.physics)
tx, ty = wind.compute()

# The first stored snapshot, converted to pressure and scaled by 1/f0, gives the
# initial streamfunction sf_init.
outputs = run.outputs()
uvh0: UVH = next(outputs).read()
sf_init = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / f0

# Global (full-domain) psi-q model with the reference vertical structure.
model_3l = QGPSIQ(
    space_2d=space.remove_z_h(),
    H=H,
    beta_plane=cfg.physics.beta_plane,
    g_prime=g_prime,
)
model_3l.set_wind_forcing(tx, ty)
model_3l.masks = Masks.empty_tensor(
    model_3l.space.nx,
    model_3l.space.ny,
    device=defaults.get_device(),
)
model_3l.bottom_drag_coef = cfg.physics.bottom_drag_coefficient
model_3l.slip_coef = cfg.physics.slip_coef

# Alternative stepper: "euler", which is run with a 10x smaller time step.
time_stepper = "rk3"
if time_stepper == "rk3":
    dt = 3600
else:
    dt = 360

model_3l.dt = dt
model_3l.time_stepper = time_stepper

### Sub-domain slices

In [None]:
# Size (in grid cells) of every sub-domain along the i (first) and j (second) axes.
SUBDOMAIN_NI = 64
SUBDOMAIN_NJ = 128

# Start indices of the four sub-domains along the i and j grid axes.
imins = [32, 32, 112, 112]
jmins = [64, 256, 64, 256]

imaxs = [imin + SUBDOMAIN_NI for imin in imins]
jmaxs = [jmin + SUBDOMAIN_NJ for jmin in jmins]

In [None]:
# Initial streamfunction (index [0, 0]) with the outline of every sub-domain drawn on top.
plots.imshow(sf_init[0, 0])
for i_lo, i_hi, j_lo, j_hi in zip(imins, imaxs, jmins, jmaxs):
    plt.hlines(y=[j_lo, j_hi], xmin=i_lo, xmax=i_hi)
    plt.vlines(x=[i_lo, i_hi], ymin=j_lo, ymax=j_hi)
plots.show()

#### Local (sub-domain) models

In [None]:
from qgsw.fields.variables.coefficients.core import UniformCoefficient
from qgsw.models.qg.psiq.filtered.core import QGPSIQCollinearSF

# Per-layer thicknesses and reduced gravities; the unpacking requires exactly
# three layers in the reference configuration.
h1, h2, h3 = H
g1, g2, g3 = g_prime
# Equivalent thickness h1*h2 / (h1 + h2) of the top two layers, kept as a
# length-1 slice so it can be used as the H of the reduced-gravity model.
Heq = (H[:1] * H[1:2]) / (H[:1] + H[1:2])

def compute_slices(imin:int,imax:int,jmin:int,jmax:int) -> tuple[list[slice],list[slice]]:
    """Return the (psi, q) index slices of the sub-domain [imin, imax] x [jmin, jmax].

    psi slices include the upper bounds (one more point than q along each
    axis); q slices stop at ``imax`` / ``jmax`` (exclusive).

    Args:
        imin, imax: first and last psi index along the first grid axis.
        jmin, jmax: first and last psi index along the second grid axis.

    Returns:
        ``(psi_slices, q_slices)``, each a ``[x_slice, y_slice]`` list meant to be
        unpacked into the last two indices of a field, e.g. ``psi[0, 0, *psi_slices]``.
    """
    psi_slices = [slice(imin, imax + 1), slice(jmin, jmax + 1)]
    q_slices = [slice(imin, imax), slice(jmin, jmax)]
    return psi_slices, q_slices

def _configure_local_model(model: QGPSIQ, with_bottom_drag: bool) -> QGPSIQ:
    """Apply the settings shared by every local model, in place, and return the model.

    Args:
        model: freshly constructed local model.
        with_bottom_drag: when True keep the reference bottom drag coefficient,
            otherwise use that coefficient multiplied by 0.

    Returns:
        The same ``model`` instance, with empty masks on its own grid, the global
        model's ``y0``, ``wide = True``, the selected bottom drag, the reference
        slip coefficient and the notebook-wide ``dt`` / ``time_stepper``.
    """
    physics = run.summary.configuration.physics
    model.masks = Masks.empty_tensor(model.space.nx, model.space.ny, device=defaults.get_device())
    model.y0 = model_3l.y0
    model.wide = True
    model.bottom_drag_coef = (
        physics.bottom_drag_coefficient
        if with_bottom_drag
        else physics.bottom_drag_coefficient * 0
    )
    model.slip_coef = physics.slip_coef
    model.dt = dt
    model.time_stepper = time_stepper
    return model


def build_models(imin:int,imax:int,jmin:int,jmax:int) -> tuple[QGPSIQ,...]:
    """Build the local models of the sub-domain [imin, imax] x [jmin, jmax].

    All models share a 2D grid cut from the global one on the psi points
    (upper bounds included, as in ``compute_slices``) and the settings applied
    by ``_configure_local_model``. Only the two- and three-layer models keep
    the reference bottom drag.

    Returns:
        ``(model_rg, model_1l, model_1l_alpha, model_1l_alpha_mf, model_2l, model_3l_)``:
        reduced gravity (thickness ``Heq``, reduced gravity ``g_prime[1]``),
        one layer (thickness h1 + h2), two collinear-streamfunction models,
        two layers and three layers.
    """
    physics = run.summary.configuration.physics
    full_space_2d = P.space.remove_z_h()
    space_2d = SpaceDiscretization2D.from_tensors(
        x=full_space_2d.omega.xy.x[imin:imax + 1, 0],
        y=full_space_2d.omega.xy.y[0, jmin:jmax + 1],
    )

    def new_model(model_cls, H_model, g_prime_model):
        # Every local model shares the sub-domain grid and the reference beta-plane.
        return model_cls(
            space_2d=space_2d,
            H=H_model,
            beta_plane=physics.beta_plane,
            g_prime=g_prime_model,
        )

    model_1l = _configure_local_model(
        new_model(QGPSIQ, H[:1] + H[1:2], g_prime[:1]), with_bottom_drag=False
    )
    model_1l_alpha = _configure_local_model(
        new_model(QGPSIQCollinearSF, H[:2], g_prime[:2]), with_bottom_drag=False
    )
    model_1l_alpha_mf = _configure_local_model(
        new_model(QGPSIQCollinearSF, H[:2], g_prime[:2]), with_bottom_drag=False
    )
    model_rg = _configure_local_model(
        new_model(QGPSIQ, Heq, g_prime[1:2]), with_bottom_drag=False
    )
    model_2l = _configure_local_model(
        new_model(QGPSIQ, H[:2], g_prime[:2]), with_bottom_drag=True
    )
    model_3l_ = _configure_local_model(
        new_model(QGPSIQ, H[:3], g_prime[:3]), with_bottom_drag=True
    )

    return model_rg, model_1l, model_1l_alpha, model_1l_alpha_mf, model_2l, model_3l_

In [None]:
from matplotlib.axes import Axes

from qgsw.utils.interpolation import ConstantInterpolation


def rmse(psi:torch.Tensor, psi_ref:torch.Tensor) -> float:
    """Relative root-mean-square error of ``psi`` against ``psi_ref``.

    The RMSE is normalised by the mean absolute value of ``psi_ref``, so a
    reference that is identically zero yields inf/nan.

    Returns:
        The relative error as a Python float.
    """
    squared_error = (psi - psi_ref) ** 2
    root_mean_square = squared_error.mean().sqrt()
    reference_scale = psi_ref.abs().mean()
    return (root_mean_square / reference_scale).cpu().item()


model_3l.set_psi(sf_init)
model_3l.reset_time()

for _ in range(1,1):
    model_3l.step()

model_3l.reset_time()
sf_0,q_0 = model_3l.prognostic.psiq


times: list[float] = [model_3l.time.item()]

psis_3l: list[torch.Tensor] = [model_3l.psi]
qs_3l: list[torch.Tensor] = [model_3l.q]

n_steps = 500

for _ in range(1,n_steps):
    model_3l.step()
    times.append(model_3l.time.item())

    psis_3l.append(model_3l.psi)
    qs_3l.append(model_3l.q)

res_persistency = []
res_mf = []
res_rg = []
res_1l = []
res_1l_alpha = []
res_1l_alpha_mf = []
res_2l = []
res_3l_ = []


filt = GaussianFilter2D(sigma=10)
k = filt.window_radius
p = 4

mfs = []

for indices in zip(imins,imaxs,jmins,jmaxs):
    model_rg, model_1l, model_1l_alpha, model_1l_alpha_mf, model_2l, model_3l_ = build_models(*indices)
    psi_slices, q_slices = compute_slices(*indices)

    
    psi_mean_slice = [slice(s.start-k-p,s.stop+k+p) for s in psi_slices]

    psi_bar = torch.mean(torch.stack([torch.stack(
        [
            filt(psi[0,0,*psi_mean_slice])[None,...],
            filt(psi[0,1,*psi_mean_slice])[None,...],
            filt(psi[0,2,*psi_mean_slice])[None,...],
        ],dim=1
    ) for psi in psis_3l],dim=0),dim=0)

    q_bar = points_to_surfaces(laplacian(psi_bar,dx,dy) - f0**2*torch.einsum("lm,...mxy->...lxy",model_3l.A,psi_bar[...,1:-1,1:-1]))

    q_bar_bc = Boundaries.extract(q_bar,k+2,-k-3,k+2,-k-3,3)
    psi_bar_bc = Boundaries.extract(psi_bar, k+p,-k-p-1,k+p,-k-p-1,2)
    q_bar = q_bar[...,k+3:-k-3,k+3:-k-3]
    psi_bar = psi_bar[...,k+p:-k-p,k+p:-k-p]

    mfs.append(psi_bar[0,0])

    imin,imax,jmin,jmax = indices

    psis_3l_bc: list[Boundaries] = [Boundaries.extract(psi, imin,imax+1,jmin,jmax+1,width=2)for psi in psis_3l]
    qs_3l_bc : list[Boundaries]= [Boundaries.extract(q, imin-1,imax+1,jmin-1,jmax+1,width=3)for q in qs_3l]

    model_rg.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])
    model_rg.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psis_3l_bc]),QuadraticInterpolation(times,[bc[:,:1] for bc in qs_3l_bc]))
    model_rg.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
    model_rg.reset_time()

    model_1l.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])
    model_1l.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psis_3l_bc]),QuadraticInterpolation(times,[bc[:,:1] for bc in qs_3l_bc]))
    model_1l.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
    model_1l.reset_time()

    alpha = UniformCoefficient.compute_optimal_values(
        psis_3l[0][0,0,*psi_slices],
        psis_3l[0][0,1,*psi_slices]
    )
    # alpha=UniformCoefficient.compute_optimal_values((psis_3l[1]-psis_3l[0])[0,0,*psi_slices],(psis_3l[1]-psis_3l[0])[0,1,*psi_slices])
    print(alpha)
    model_1l_alpha.alpha =torch.ones_like(model_1l_alpha.psi)*alpha
    model_1l_alpha.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])
    model_1l_alpha.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psis_3l_bc]),QuadraticInterpolation(times,[bc[:,:1] for bc in qs_3l_bc]))
    model_1l_alpha.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
    model_1l_alpha.reset_time()

    alpha_mf = UniformCoefficient.compute_optimal_values(
        psis_3l[0][0,0,*psi_slices]-psi_bar[0,0],
        psis_3l[0][0,1,*psi_slices]-psi_bar[0,1]
    )
    
    print(alpha_mf)
    model_1l_alpha_mf.alpha =torch.ones_like(model_1l_alpha_mf.psi)*alpha_mf
    model_1l_alpha_mf.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])
    model_1l_alpha_mf.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psis_3l_bc]),QuadraticInterpolation(times,[bc[:,:1] for bc in qs_3l_bc]))
    model_1l_alpha_mf.set_mean_flow(
        ConstantInterpolation(psi_bar[:,:1]),
        ConstantInterpolation(q_bar[:,:1]),
        ConstantInterpolation(psi_bar_bc[:,:1]),
        ConstantInterpolation(q_bar_bc[:,:1])
        )
    model_1l_alpha_mf.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
    model_1l_alpha_mf.reset_time()

    model_2l.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])
    model_2l.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:2] for bc in psis_3l_bc]),QuadraticInterpolation(times,[bc[:,:2] for bc in qs_3l_bc]))
    model_2l.set_psiq(sf_0[:,:2,*psi_slices],q_0[:,:2,*q_slices])
    model_2l.reset_time()

    model_3l_.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])
    model_3l_.set_boundary_maps(QuadraticInterpolation(times,psis_3l_bc),QuadraticInterpolation(times,qs_3l_bc))
    model_3l_.set_psiq(sf_0[:,:3,*psi_slices],q_0[:,:3,*q_slices])
    model_3l_.reset_time()


    errs_persistency :list[float] = [rmse(sf_0[0,0,*psi_slices],psis_3l[0][0,0,*psi_slices])]
    errs_mf :list[float] = [rmse(psi_bar[0,0],psis_3l[0][0,0,*psi_slices])]
    errs_rg:list[float] = [rmse(model_rg.psi[0,0],psis_3l[0][0,0,*psi_slices])]
    errs_1l:list[float] = [rmse(model_1l.psi[0,0],psis_3l[0][0,0,*psi_slices])]
    errs_1l_alpha:list[float] = [rmse(model_1l_alpha.psi[0,0],psis_3l[0][0,0,*psi_slices])]
    errs_1l_alpha_mf:list[float] = [rmse(model_1l_alpha_mf.psi[0,0],psis_3l[0][0,0,*psi_slices])]
    errs_2l:list[float] = [rmse(model_2l.psi[0,0],psis_3l[0][0,0,*psi_slices])]
    errs_3l:list[float] = [rmse(model_3l_.psi[0,0],psis_3l[0][0,0,*psi_slices])]
    err_times:list[float] = [model_rg.time.item()]

    for i in range(1,n_steps):

        model_rg.step()
        model_1l.step()
        model_1l_alpha.step()
        model_1l_alpha_mf.step()
        model_2l.step()
        model_3l_.step()
        
        err_times.append(model_rg.time.item())

        errs_persistency.append(rmse(sf_0[0,0,*psi_slices],psis_3l[i][0,0,*psi_slices]))
        errs_mf.append(rmse(psi_bar[0,0],psis_3l[i][0,0,*psi_slices]))
        errs_rg.append(rmse(model_rg.psi[0,0],psis_3l[i][0,0,*psi_slices]))
        errs_1l.append(rmse(model_1l.psi[0,0],psis_3l[i][0,0,*psi_slices]))
        errs_1l_alpha.append(rmse(model_1l_alpha.psi[0,0],psis_3l[i][0,0,*psi_slices]))
        errs_1l_alpha_mf.append(rmse(model_1l_alpha_mf.psi[0,0],psis_3l[i][0,0,*psi_slices]))
        errs_2l.append(rmse(model_2l.psi[0,0],psis_3l[i][0,0,*psi_slices]))
        errs_3l.append(rmse(model_3l_.psi[0,0],psis_3l[i][0,0,*psi_slices]))
    
    res_persistency.append(errs_persistency)
    res_mf.append(errs_persistency)
    res_1l.append(errs_1l)
    res_1l_alpha.append(errs_1l_alpha)
    res_1l_alpha_mf.append(errs_1l_alpha_mf)
    res_rg.append(errs_rg)
    res_2l.append(errs_2l)
    res_3l_.append(errs_3l)

In [None]:
# One row per sub-domain: its location on the initial streamfunction, the error
# curves of every forecast, and the mean flow computed for it.
days = [t / 3600 / 24 for t in err_times]  # seconds -> days, shared by all curves
curves = [
    (res_rg, "Reduced gravity"),
    (res_1l, "One layer"),
    (res_1l_alpha, "ɑ"),
    (res_1l_alpha_mf, "ɑ Mean flow"),
    (res_2l, "Two layers"),
    (res_3l_, "Three layers"),
    (res_persistency, "Persistency"),
    (res_mf, "Persistency mf"),
]

fig, axs = plots.subplots(len(imins), 3, figsize=(20, 15))

for row, (imin, imax, jmin, jmax) in enumerate(zip(imins, imaxs, jmins, jmaxs)):
    plots.imshow(sf_init[0, 0], ax=axs[row, 0])
    axs[row, 0].hlines([jmin, jmax], imin, imax)
    axs[row, 0].vlines([imin, imax], jmin, jmax)

    for series, label in curves:
        axs[row, 1].plot(days, series[row], label=label)
    axs[row, 1].set_xlabel("Time [days]")
    axs[row, 1].set_ylabel("Relative RMSE")
    axs[row, 1].legend()

    plots.imshow(mfs[row], ax=axs[row, 2])

plt.show()

In [None]:
UniformCoefficient.compute_optimal_values((psis_3l[1]-psis_3l[0])[0,0,*psi_slices],(psis_3l[1]-psis_3l[0])[0,1,*psi_slices])