In [None]:
# Re-import edited modules (e.g. qgsw) automatically before executing each cell.
%load_ext autoreload
%autoreload 2

## Diagnostics of the mean flow

### 3-layer model

In [None]:
from qgsw import plots
import torch
from qgsw.fields.variables.tuples import UVH
from qgsw.forcing.wind import WindForcing
from qgsw.masks import Masks
from qgsw.models.qg.psiq.core import QGPSIQ
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.output import RunOutput
from qgsw.spatial.core.discretization import SpaceDiscretization2D, SpaceDiscretization3D
from qgsw.spatial.core.grid_conversion import interpolate
from qgsw.specs import defaults
from qgsw.utils import covphys
from qgsw.filters.gaussian import GaussianFilter2D
from qgsw.solver.boundary_conditions.base import Boundaries
from qgsw.solver.finite_diff import laplacian
from qgsw.utils.interpolation import LinearInterpolation

run = RunOutput("../output/g5k/sw_double_gyre_long_hr")
cfg = run.summary.configuration

# Physical parameters of the reference shallow-water run.
H = cfg.model.h
g_prime = cfg.model.g_prime
f0 = cfg.physics.f0
beta = cfg.physics.beta

# QG projector built on the run's 3D grid, without masks.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),
    H=H.unsqueeze(-1).unsqueeze(-1),
    space=SpaceDiscretization3D.from_config(cfg.space, cfg.model),
    f0=f0,
    masks=Masks.empty(nx=cfg.space.nx, ny=cfg.space.ny),
)
A = P.A
space = P.space
dx, dy = space.dx, space.dy
nx, ny = space.nx, space.ny

# Wind stress used by the reference run.
wind = WindForcing.from_config(cfg.windstress, cfg.space, cfg.physics)
tx, ty = wind.compute()

# Initial condition: first saved snapshot, converted to covariant variables and
# projected; compute_p presumably returns a pressure, so dividing by f0 gives ψ.
outputs = run.outputs()
uvh0: UVH = next(outputs).read()
sf_init = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / f0

# β contribution to PV, measured from the middle of the basin in y (shape (1, ny)).
beta_effect = beta * (space.q.xyh.y[0, 0, :][None, :] - space.ly / 2)

In [None]:
# Sub-domain, in full-grid indices, used for the mean-flow diagnostics and model_mf.
(imin, imax), (jmin, jmax) = (32, 96), (256, 384)

In [None]:
# Reference run on the full domain. Restart the QG model from sf_init and record,
# at every step:
#   - a window of psi / q around the sub-domain [imin, imax] x [jmin, jmax];
#   - the boundary data the sub-domain model (model_mf) is driven with.
filt = GaussianFilter2D(sigma=10)
k = filt.window_radius  # halo consumed by the Gaussian filter, cropped after filtering

p = 4 # To be able to compute boundary conditions
n_steps = 100  # number of reference steps; len(times) == n_steps + 1

model = QGPSIQ(
    space_2d=space.remove_z_h(),
    H=H,
    beta_plane=run.summary.configuration.physics.beta_plane,
    g_prime=g_prime,
)
model.set_wind_forcing(tx, ty)
model.masks = Masks.empty_tensor(model.space.nx, model.space.ny, device=defaults.get_device())
model.dt = 360
model.time_stepper = "euler"
model.bottom_drag_coef = run.summary.configuration.physics.bottom_drag_coefficient
model.slip_coef = run.summary.configuration.physics.slip_coef
model.set_psi(sf_init)

sfs: list[torch.Tensor] = []
qs: list[torch.Tensor] = []
times: list[float] = []
sf_bcs: list[Boundaries] = []
pv_bcs: list[Boundaries] = []


def record_snapshot() -> None:
    """Append the current state of ``model`` to sfs, qs, sf_bcs, pv_bcs and times.

    The windows are cloned. A sliced view keeps the whole full-domain tensor
    alive for every snapshot, and it would alias ``model.psi`` / ``model.q``
    if ``step()`` updated them in place.
    """
    # psi (corner points): sub-domain + filter halo k + BC halo p on each side.
    sfs.append(
        model.psi[..., imin - k - p : imax + 1 + k + p, jmin - k - p : jmax + 1 + k + p].clone()
    )
    # q: sub-domain cells + (p - 1) cells on each side.
    qs.append(
        model.q[..., imin - (p - 1) : imax + (p - 1), jmin - (p - 1) : jmax + (p - 1)].clone()
    )
    sf_bcs.append(Boundaries.extract(model.psi, imin, imax + 1, jmin, jmax + 1, width=2))
    pv_bcs.append(Boundaries.extract(model.q, imin - 1, imax + 1, jmin - 1, jmax + 1, width=3))
    times.append(model.time.item())


record_snapshot()  # initial state
for _ in range(n_steps):
    model.step()
    record_snapshot()

sf_bc_interp = LinearInterpolation(times, sf_bcs)
pv_bc_interp = LinearInterpolation(times, pv_bcs)

# Time-mean streamfunction, Gaussian-filtered layer by layer (any number of
# layers). The filter halo k is then dropped, so sf_bar keeps only the p halo.
sf_mean = torch.stack(sfs).mean(dim=0)
sf_bar_wide = torch.stack(
    [filt(sf_mean[0, layer]) for layer in range(sf_mean.shape[1])]
).unsqueeze(0)
sf_bar = sf_bar_wide[..., k:-k, k:-k]

# Mean PV from sf_bar: laplacian - f0^2 * A @ psi (interior points), passed
# through interpolate, plus the beta term on the same rows as qs.
q_bar = interpolate(
    laplacian(sf_bar, dx, dy)
    - f0**2 * torch.einsum("lm,...mxy->...lxy", A, sf_bar[..., 1:-1, 1:-1])
) + beta_effect[..., jmin - (p - 1) : jmax + (p - 1)]
q_mean = torch.stack(qs).mean(dim=0)

# Anomalies with respect to the filtered mean, on the sub-domain (psi) and on
# the q window.
sf_anoms = [sf[..., k + p : -k - p, k + p : -k - p] - sf_bar[..., p:-p, p:-p] for sf in sfs]
q_anoms = [q - q_bar for q in qs]

In [None]:
fig, axs = plots.subplots(3, 4)
fig.suptitle("Mean flows")

# One row per layer, with four panels:
#   - filtered mean streamfunction;
#   - PV derived from it;
#   - its difference with the time-averaged model PV;
#   - the time-averaged model PV itself.
q_diff = q_bar - q_mean
for layer, position in enumerate(("top", "middle", "bottom")):
    plots.imshow(sf_bar[0, layer], ax=axs[layer, 0], title=f"ѱ_bar - {position}")
    plots.imshow(q_bar[0, layer], ax=axs[layer, 1], title=f"q_bar - {position}")
    plots.imshow(q_diff[0, layer], ax=axs[layer, 2], title=f"q_bar-q_mean - {position}")
    plots.imshow(q_mean[0, layer], ax=axs[layer, 3], title=f"q_mean - {position}")

plots.show()

In [None]:
# The mean flow is stationary: the same fields are supplied at every snapshot
# time, so these interpolators are constant in time. Each field is built once.
n_snapshots = len(times)

# Mean fields on the model_mf domain:
#   - sf_bar: drop its p-wide halo (corner points);
#   - q_bar: drop its (p - 1)-wide halo. q_bar already is the mean PV derived
#     from sf_bar (same expression), so it is reused rather than recomputed.
sf_bar_interp = LinearInterpolation(times, [sf_bar[..., p:-p, p:-p]] * n_snapshots)
pv_bar_interp = LinearInterpolation(
    times, [q_bar[..., p - 1 : -(p - 1), p - 1 : -(p - 1)]] * n_snapshots
)

# Mean-flow boundaries must sit at the same physical locations as the
# perturbation boundaries sf_bcs / pv_bcs. The full-grid indices used there
# (imin, imax + 1, ...) are therefore shifted by the full-grid origin of each
# cropped array.
# NOTE(review): this assumes Boundaries.extract's imax/jmax arguments are
# exclusive stops, as the `imax + 1` used for sf_bcs suggests.
sf_i0, sf_j0 = imin - p, jmin - p  # full-grid index of sf_bar[..., 0, 0]
pv_i0, pv_j0 = imin - (p - 1), jmin - (p - 1)  # full-grid index of q_bar[..., 0, 0]
sf_bar_bc = Boundaries.extract(
    sf_bar, imin - sf_i0, imax + 1 - sf_i0, jmin - sf_j0, jmax + 1 - sf_j0, width=2
)
pv_bar_bc = Boundaries.extract(
    q_bar, imin - 1 - pv_i0, imax + 1 - pv_i0, jmin - 1 - pv_j0, jmax + 1 - pv_j0, width=3
)
sf_bar_bc_interp = LinearInterpolation(times, [sf_bar_bc] * n_snapshots)
pv_bar_bc_interp = LinearInterpolation(times, [pv_bar_bc] * n_snapshots)

In [None]:
# Sub-domain grid: omega-grid coordinates of columns imin..imax and rows
# jmin..jmax of the full-domain grid.
full_space_2d = P.space.remove_z_h()
space_2d = SpaceDiscretization2D.from_tensors(
    x=full_space_2d.omega.xy.x[imin : imax + 1, 0],
    y=full_space_2d.omega.xy.y[0, jmin : jmax + 1],
)

physics = run.summary.configuration.physics
model_mf = QGPSIQ(space_2d=space_2d, H=H, beta_plane=physics.beta_plane, g_prime=g_prime)
# Wind stress restricted to the sub-domain. tx and ty are cut with different
# extents, presumably because they live on different staggered points.
model_mf.set_wind_forcing(tx[imin:imax, jmin : jmax + 1], ty[imin : imax + 1, jmin:jmax])
model_mf.masks = Masks.empty_tensor(model_mf.space.nx, model_mf.space.ny, device=defaults.get_device())
model_mf.time_stepper = "euler"
model_mf.y0 = space.ly / 2  # same y reference as beta_effect (middle of the full basin)
model_mf.dt = 360
model_mf.bottom_drag_coef = physics.bottom_drag_coefficient
model_mf.slip_coef = physics.slip_coef
model_mf.wide = True
# Perturbation boundaries from the reference run, then the stationary mean flow.
model_mf.set_boundary_maps(sf_bc_interp, pv_bc_interp)
model_mf.set_mean_flow(sf_bar_interp, pv_bar_interp, sf_bar_bc_interp, pv_bar_bc_interp)
model_mf.set_psi(sf_init[..., imin : imax + 1, jmin : jmax + 1])

In [None]:
# Advance the sub-domain model by a single step (dt = 360, as in the reference run).
n_mf_steps = 1
for _ in range(n_mf_steps):
    model_mf.step()

In [None]:
# Normalised top-layer difference between the sub-domain model and the reference run.
# - Both models start from sf_init. sfs[1] is the reference state after one
#   step, so this index must equal the number of model_mf steps taken above.
# - The reference window is cropped back to the sub-domain (halo k + p per side).
# - The error is scaled by max |psi|. A signed max is negative or close to zero
#   when psi is mostly negative over the window, which would flip or blow up the map.
halo = k + p
sf_ref_top = sfs[1][0, 0, halo:-halo, halo:-halo]
plots.imshow((model_mf.psi[0, 0] - sf_ref_top) / sf_ref_top.abs().max())

In [None]:
# Top layer (index 0) of model_mf.perturbation.psi.
perturbation_psi_top = model_mf.perturbation.psi[0, 0]
plots.imshow(perturbation_psi_top)

In [None]:
# Top layer (index 0) of model_mf.mean_flow.psi.
mean_flow_psi_top = model_mf.mean_flow.psi[0, 0]
plots.imshow(mean_flow_psi_top)