In [None]:
# Automatically re-import edited modules before each cell runs, so changes to the
# imported project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import torch

from qgsw import fft, plots
from qgsw.fields.variables.tuples import PSIQT
from qgsw.models.qg.stretching_matrix import compute_A, compute_layers_to_mode_decomposition
from qgsw.output import RunOutput
from qgsw.solver.boundary_conditions.base import Boundaries
from qgsw.solver.boundary_conditions.interpolation import BilinearExtendedBoundary, TimeLinearInterpolation
from qgsw.solver.finite_diff import laplacian
from qgsw.solver.helmholtz import compute_laplace_dstI
from qgsw.solver.pv_inversion import InhomogeneousPVInversion
from qgsw.spatial.core.discretization import SpaceDiscretization3D
from qgsw.spatial.core.grid_conversion import points_to_surfaces
from qgsw.specs import defaults
from qgsw.utils.sorting import sort_files
import torch.nn.functional as F
from qgsw import specs

In [None]:
from qgsw.masks import Masks
from qgsw.models.qg.uvh.projectors.core import QGProjector


# Reference run: its saved configuration provides the physical and grid parameters.
run = RunOutput("../output/g5k/sw_double_gyre_long_hr")
config = run.summary.configuration

H = config.model.h
g_prime = config.model.g_prime
f0 = config.physics.f0

# Stretching matrices: one for the layered model, and a single-layer equivalent
# built from the summed layer thicknesses and the first reduced gravity only.
A = compute_A(H=H, g_prime=g_prime)
A_1l = compute_A(H=H.sum(dim=-1, keepdim=True), g_prime=g_prime[:1])

# Build the space discretization once so the projector and the finite-difference
# computations below all share the very same grid object.
space = SpaceDiscretization3D.from_config(config.space, config.model)
dx = space.dx
dy = space.dy

P = QGProjector(
    A=A,
    H=H.unsqueeze(-1).unsqueeze(-1),
    space=space,
    f0=f0,
    masks=Masks.empty(nx=config.space.nx, ny=config.space.ny),
)

# y coordinates offset by half a cell from both walls (ny values), centred on y0,
# used to build the beta-plane term beta * (y - y0).
y = torch.linspace(
    0.5 * dy,
    space.ly - 0.5 * dy,
    space.ny,
    **defaults.get(),
).unsqueeze(0)
y0 = 0.5 * space.ly
beta_effect = config.physics.beta_plane.beta * (y - y0)


In [None]:
# Sub-domain bounds. Later cells slice psi_big[..., imin:imax+1, jmin:jmax+1],
# i.e. stream-function points imin..imax and jmin..jmax inclusive.
imin, imax = 16, 48
jmin, jmax = 250, 314

# Saved single-layer outputs named "results_step_<step>.pt".
ref_folder = Path("../output/local/qgpsiq_1L_small")
result_files = ref_folder.glob("results_step_*.pt")
ts, fs = sort_files(
    files=result_files,
    prefix="results_step_",
    suffix=".pt",
)

# Inversion of dq

In [None]:
## Load data
psi_big, q, _ = PSIQT.from_file(fs[0])

# Only the first layer is studied (A_1l comes from a single summed layer). Select it
# once so the Laplacian and the stretching term act on the same layer; previously
# the Laplacian used every layer of psi_big and would have broadcast silently
# against the one-layer stretching term for a multi-layer file.
psi_big_1l = psi_big[..., :1, :, :]


def compute_pv_1l(psi):
    """Single-layer PV from ``psi``: laplacian(psi) - f0**2 * A_1l @ psi on interior points.

    NOTE(review): assumes ``laplacian`` returns values on the interior points only,
    matching ``psi[..., 1:-1, 1:-1]`` — consistent with the original subtraction.
    """
    psi_interior = psi[..., 1:-1, 1:-1]
    stretching = torch.einsum("lm,...mxy->...lxy", A_1l, psi_interior)
    return laplacian(psi, dx, dy) - f0**2 * stretching


psi_big_int = psi_big_1l[..., 1:-1, 1:-1]
q_big = compute_pv_1l(psi_big_1l)

# Sub-domain stream function (points imin..imax, jmin..jmax inclusive) and its PV.
psi_small = psi_big_1l[..., imin : imax + 1, jmin : jmax + 1]
psi_small_int = psi_small[..., 1:-1, 1:-1]
q_small = compute_pv_1l(psi_small)
# Big-domain PV restricted to the sub-domain interior: psi_small interior point k
# corresponds to big-domain interior index imin + k (resp. jmin + k).
q_small_from_big = q_big[..., imin : imax - 1, jmin : jmax - 1]

In [None]:
from qgsw.solver.pv_inversion import HomogeneousPVInversion


def plot_comparison(reference, candidate, suptitle, reference_title, candidate_title):
    """Show ``reference``, ``reference - candidate`` and ``candidate`` in a 1x3 figure.

    The middle panel is always titled "Comparison".
    """
    fig, axs = plots.subplots(1, 3)
    fig.suptitle(suptitle)
    plots.imshow(reference, ax=axs[0, 0], title=reference_title)
    plots.imshow(reference - candidate, ax=axs[0, 1], title="Comparison")
    plots.imshow(candidate, ax=axs[0, 2], title=candidate_title)
    plots.show()


# --- Inversion over the big domain (HomogeneousPVInversion) ---
psi_big_inv = HomogeneousPVInversion(A_1l, f0, dx, dy).compute_stream_function(
    q_big, ensure_mass_conservation=True
)[..., 1:-1, 1:-1]
# Restrict to the interior of the small sub-domain (same indexing as q_small_from_big).
psi_big_inv_slice = psi_big_inv[..., imin : imax - 1, jmin : jmax - 1]

fig, axs = plots.subplots(1, 3)
fig.suptitle("Big domain")
plots.imshow(q_big[0, 0], ax=axs[0, 0], title="PV anomaly")
plots.imshow(psi_big_inv[0, 0], ax=axs[0, 1], title="ѱ After inversion")
plots.imshow(psi_big_inv_slice[0, 0], ax=axs[0, 2], title="Slice")
plots.show()

# Full big-domain fields (not a slice): inverted vs original interior stream function.
plot_comparison(
    psi_big_inv[0, 0],
    psi_big_int[0, 0],
    suptitle="Comparison with the original ѱ field",
    reference_title="ѱ after inversion in big domain",
    candidate_title="ѱ from big domain",
)

# --- Inversion over the small domain, with boundary values taken from psi_small ---
# NOTE(review): assumes the last two axes are (x, y), so top/bottom are the last/first
# y index and left/right the first/last x index — confirm against Boundaries.
boundary = Boundaries(
    top=psi_small[..., :, -1],
    bottom=psi_small[..., :, 0],
    left=psi_small[..., 0, :],
    right=psi_small[..., -1, :],
)

solver = InhomogeneousPVInversion(A_1l, f0, dx, dy)
solver.set_boundaries(boundary)
psi_small_inv = solver.compute_stream_function(q_small)[..., 1:-1, 1:-1]

fig, axs = plots.subplots(1, 4)
fig.suptitle("Small domain")
plots.imshow(q_small[0, 0], ax=axs[0, 0], title="PV")
plots.imshow(solver.psiq_h.psi[0, 0], ax=axs[0, 1], title="ѱ_h from inversion")
plots.imshow(solver.psiq_b.psi[0, 0], ax=axs[0, 2], title="ѱ_bc from boundary")
plots.imshow(psi_small_inv[0, 0], ax=axs[0, 3], title="ѱ = ѱ_h + ѱ_bc")
plots.show()

plot_comparison(
    psi_big_int[0, 0, imin : imax - 1, jmin : jmax - 1],
    psi_small_inv[0, 0],
    suptitle="Comparison with the original ѱ field",
    reference_title="ѱ slice from big domain",
    candidate_title="ѱ from small domain",
)

plot_comparison(
    psi_big_inv_slice[0, 0],
    psi_small_inv[0, 0],
    suptitle="Comparison with the reconstructed ѱ in the big domain",
    reference_title="ѱ slice after inversion in big domain",
    candidate_title="ѱ from small domain",
)

## Advection

In [None]:
from qgsw.fields.variables.tuples import PSIQ
from qgsw.models.core.flux import div_flux_5pts, div_flux_5pts_only, div_flux_5pts_with_bc
from qgsw.solver.finite_diff import grad_perp


def velocities_from_psi(stream_function):
    """Return the two ``grad_perp`` components of ``stream_function`` divided by dy and dx."""
    u_comp, v_comp = grad_perp(stream_function)
    return u_comp / dy, v_comp / dx


# Reference: flux divergence over the whole domain, then cut down to the sub-domain.
psi, q = PSIQ.from_file(fs[0])
u, v = velocities_from_psi(psi)
div_flux = div_flux_5pts(q, u[..., 1:-1, :], v[..., :, 1:-1], dx, dy)
div_flux_slice = div_flux[..., imin:imax, jmin:jmax]

# Sub-domain inputs: psi on points imin..imax (inclusive), q on index range imin:imax,
# and q extended by `wide_halo` / `narrow_halo` extra entries on every side.
wide_halo, narrow_halo = 3, 1
psi_slice = psi[..., imin : imax + 1, jmin : jmax + 1]
q_slice = q[..., imin:imax, jmin:jmax]
q_slice_wide = q[
    ..., imin - wide_halo : imax + wide_halo, jmin - wide_halo : jmax + wide_halo
]
q_slice_narrow = q[
    ..., imin - narrow_halo : imax + narrow_halo, jmin - narrow_halo : jmax + narrow_halo
]

u_small, v_small = velocities_from_psi(psi_slice)

# Sub-domain-only evaluations of the flux divergence, to be compared with the reference:
# halo taken from the full-domain q (wide / narrow), no halo, and a replicated halo.
div_flux_small_wide = div_flux_5pts_only(q_slice_wide, u_small, v_small, dx, dy)
div_flux_small_narrow = div_flux_5pts_with_bc(q_slice_narrow, u_small, v_small, dx, dy)
div_flux_small = div_flux_5pts(
    q_slice, u_small[..., 1:-1, :], v_small[..., :, 1:-1], dx, dy
)
q_slice_replicated = F.pad(q_slice, (wide_halo,) * 4, mode="replicate")
div_flux_small_replicate = div_flux_5pts_only(q_slice_replicated, u_small, v_small, dx, dy)

In [None]:
# One row per sub-domain strategy: reference | relative error | candidate.
# The error is divided by the largest |reference| value rather than the signed max,
# so the scale cannot shrink or flip sign when the reference is mostly negative.
reference_scale = div_flux_slice.abs().max()
candidates = [
    ("Small", div_flux_small),
    ("Wide replicate", div_flux_small_replicate),
    ("Narrow", div_flux_small_narrow),
    ("Wide", div_flux_small_wide),
]

fig, axs = plots.subplots(len(candidates), 3)
fig.suptitle("Flux divergence: sub-domain vs reference")
for row, (label, div_flux_candidate) in enumerate(candidates):
    plots.imshow(div_flux_slice[0, 0], ax=axs[row, 0], title="Ref")
    plots.imshow(
        (div_flux_candidate[0, 0] - div_flux_slice[0, 0]) / reference_scale,
        ax=axs[row, 1],
        title=f"{label} vs Ref",
    )
    plots.imshow(div_flux_candidate[0, 0], ax=axs[row, 2], title=label)
plots.show()
