In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# reloaded before each cell execution and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Diagnostics of the mean flow

### 3-layer model

In [None]:
from matplotlib import pyplot as plt
from qgsw import plots
import torch
from qgsw.configs.space import SpaceConfig
from qgsw.fields.variables.physical import PotentialVorticity, QGPressure
from qgsw.fields.variables.tuples import UVH
from qgsw.forcing.wind import WindForcing
from qgsw.masks import Masks
from qgsw.models.qg.psiq.core import QGPSIQ
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.core import QG
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.output import RunOutput
from qgsw.physics.coriolis.beta_plane import BetaPlane
from qgsw.solver.finite_diff import laplacian_h
from qgsw.spatial.core.discretization import SpaceDiscretization2D, SpaceDiscretization3D
from qgsw.spatial.core.grid_conversion import points_to_surfaces
from qgsw.specs import defaults
from qgsw.utils import covphys
from qgsw.utils.units._units import Unit
import torch.nn.functional as F

# Saved run used as the starting point, and a shortcut to its configuration.
run = RunOutput("../output/g5k/sw_double_gyre_long_hr")
config = run.summary.configuration

# Physical parameters read from the run configuration.
H = config.model.h
g_prime = config.model.g_prime
f0 = config.physics.f0
beta = config.physics.beta

# Projector built from the run configuration; it is used below only to turn the
# first saved (u, v, h) state into a pressure field via `compute_p`.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),
    H=H.unsqueeze(-1).unsqueeze(-1),
    space=SpaceDiscretization3D.from_config(config.space, config.model),
    f0=f0,
    masks=Masks.empty(nx=config.space.nx, ny=config.space.ny),
)
A, space = P.A, P.space
dx, dy = space.dx, space.dy
nx, ny = space.nx, space.ny

# Wind forcing components computed from the run configuration.
wind = WindForcing.from_config(config.windstress, config.space, config.physics)
tx, ty = wind.compute()

# Initial stream function: first element returned by `compute_p` for the first
# saved output (converted with `covphys.to_cov`), divided by f0.
outputs = run.outputs()
uvh0: UVH = next(outputs).read()
sf_init = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / f0

# beta * (y - ly / 2), kept as a (1, n) row so it broadcasts along the last axis.
beta_effect = beta * (space.q.xyh.y[0, 0, :][None, :] - space.ly / 2)

# Stream-function / potential-vorticity model configured like the loaded run.
model = QGPSIQ(
    space_2d=space.remove_z_h(),
    H=H,
    beta_plane=config.physics.beta_plane,
    g_prime=g_prime,
)
model.set_wind_forcing(tx, ty)
model.masks = Masks.empty_tensor(
    model.space.nx,
    model.space.ny,
    device=defaults.get_device(),
)
model.dt = 3600  # presumably seconds (one hour) — TODO confirm the unit
model.bottom_drag_coef = config.physics.bottom_drag_coefficient
model.slip_coef = config.physics.slip_coef

In [None]:
# Index bounds of the sub-domain studied below. The (imin, imax) pair slices the
# second-to-last tensor axis and (jmin, jmax) the last one; the last axis is the
# one `beta_effect` (built from the y coordinate) is sliced along.
imin = 16
imax = 48
jmin = 250
jmax = 314

In [None]:
from qgsw.filters.gaussian import GaussianFilter2D
from qgsw.solver.finite_diff import laplacian

# Gaussian filter applied to the time-mean stream function to define the mean flow.
filt = GaussianFilter2D(sigma=12)
# Halo width: the stream function is extracted with k extra points on every side
# and the filtered field is cropped by k afterwards.
k = filt.window_radius
# Number of model steps to integrate (n_steps + 1 snapshots are stored).
n_steps = 500

model.set_psi(sf_init)

# Stream-function window: the q window [imin:imax, jmin:jmax] widened by one
# point below and two above, plus the filter halo. The extra points are
# presumably consumed by the Laplacian stencil and `points_to_surfaces` when
# q_bar is rebuilt below — TODO confirm against those helpers.
sf_i = slice(imin - 1 - k, imax + 2 + k)
sf_j = slice(jmin - 1 - k, jmax + 2 + k)

# A negative slice start silently wraps around and an oversized stop is silently
# clipped; both would yield a wrongly shaped window, so fail early instead.
n_i, n_j = model.psi.shape[-2:]
if min(sf_i.start, sf_j.start) < 0 or sf_i.stop > n_i or sf_j.stop > n_j:
    msg = (
        f"Window [{sf_i.start}:{sf_i.stop}, {sf_j.start}:{sf_j.stop}] "
        f"(filter halo k={k}) does not fit in the ({n_i}, {n_j}) psi grid."
    )
    raise ValueError(msg)

# Store one snapshot before the first step, then one after every step.
# Snapshots are cloned: a slice is only a view on the model's full field, so
# keeping the view would keep the whole field alive (and would alias every
# snapshot if the model updates its tensors in place).
sfs, qs, times = [], [], []
for step in range(n_steps + 1):
    if step:
        model.step()
    sfs.append(model.psi[..., sf_i, sf_j].clone())
    qs.append(model.q[..., imin:imax, jmin:jmax].clone())
    # model.time / 3600 / 24: days, assuming model.time is in seconds — TODO confirm.
    times.append(model.time.item() / 3600 / 24)

# Time mean of the stream function (computed once), then filtered layer by layer.
sf_mean_wide = torch.stack(sfs).mean(dim=0)
sf_bar_wide = torch.stack(
    [filt(sf_mean_wide[0, layer]) for layer in range(sf_mean_wide.shape[1])]
).unsqueeze(0)

# Drop the filter halo. Explicit end indices are used because `k:-k` would
# produce an empty tensor when k == 0.
wide_i, wide_j = sf_bar_wide.shape[-2:]
sf_bar = sf_bar_wide[..., k : wide_i - k, k : wide_j - k]

# Potential vorticity rebuilt from the filtered mean stream function:
# Laplacian term minus f0^2 * A applied across layers, interpolated with
# `points_to_surfaces`, plus the beta term restricted to the j window.
lap_sf = laplacian(sf_bar, dx, dy)
layer_coupling = f0**2 * torch.einsum(
    "lm,...mxy->...lxy",
    A,
    sf_bar[..., 1:-1, 1:-1],
)
q_bar = points_to_surfaces(lap_sf - layer_coupling) + beta_effect[..., jmin:jmax]
# Plain time mean of the model's own q over the same window, for comparison.
q_mean = torch.stack(qs).mean(dim=0)

# Anomalies with respect to the filtered mean flow; the stream-function
# snapshots are cropped by the halo plus one point to match sf_bar's interior.
sf_anoms = [
    sf[..., k + 1 : -k - 1, k + 1 : -k - 1] - sf_bar[..., 1:-1, 1:-1] for sf in sfs
]
q_anoms = [q - q_bar for q in qs]


In [None]:
# 3 x 4 figure: one row per layer; columns show the filtered mean stream
# function, the q rebuilt from it, its difference with the time-mean q, and the
# time-mean q itself.
fig, axs = plots.subplots(3, 4)
fig.suptitle("Mean flows")

q_gap = q_bar - q_mean
for row, layer_name in enumerate(("top", "middle", "bottom")):
    panels = (
        (sf_bar, f"ѱ_bar - {layer_name}"),
        (q_bar, f"q_bar - {layer_name}"),
        (q_gap, f"q_bar-q_mean - {layer_name}"),
        (q_mean, f"q_mean - {layer_name}"),
    )
    for col, (field, panel_title) in enumerate(panels):
        plots.imshow(field[0, row], ax=axs[row, col], title=panel_title)

plots.show()

In [None]:
from pathlib import Path
from qgsw.plots.heatmaps import AnimatedHeatmaps



# One label per stored snapshot, shared by both animations.
frame_labels = [f"Time: {t:.2f} day{'s' if t>1 else ''}" for t in times]

# Stream-function anomalies: one animated heatmap per layer; each frame is
# transposed and moved to the CPU.
n_layers = sf_anoms[0].shape[1]
plot = AnimatedHeatmaps(
    [[sf[0, layer].T.cpu() for sf in sf_anoms] for layer in range(n_layers)]
)
plot.set_frame_labels(frame_labels)

# Create the output directory first so the export does not depend on it
# already existing.
sf_video_path = "../output/videos/sf_anom.mp4"
Path(sf_video_path).parent.mkdir(parents=True, exist_ok=True)
plot.save_video(sf_video_path, width=1500)


# Potential-vorticity anomalies, laid out the same way. The video export is
# left disabled, as in the original cell.
n_layers = q_anoms[0].shape[1]
plot = AnimatedHeatmaps(
    [[q[0, layer].T.cpu() for q in q_anoms] for layer in range(n_layers)]
)
plot.set_frame_labels(frame_labels)
# plot.save_video("../output/videos/q_anom.mp4", width=1500, fps=15)