In [None]:
# Reload edited modules (e.g. the local qgsw package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import seaborn as sns

from qgsw import logging

# Plot styling: seaborn "notebook" context combined with the "dark" style.
sns.set_theme(context="notebook", style="dark")

# Notebook-level logger (from qgsw.logging) used for progress messages.
logger = logging.getLogger(__name__)

In [None]:
from qgsw import plots
import matplotlib.pyplot as plt
import torch
from qgsw.fields.variables.tuples import UVH
from qgsw.forcing.wind import WindForcing
from qgsw.masks import Masks
from qgsw.models.qg.psiq.core import QGPSIQ
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.output import RunOutput
from qgsw.spatial.core.discretization import SpaceDiscretization2D, SpaceDiscretization3D
from qgsw.spatial.core.grid_conversion import interpolate
from qgsw.specs import defaults
from qgsw.utils import covphys
from qgsw.filters.gaussian import GaussianFilter2D
from qgsw.solver.boundary_conditions.base import Boundaries
from qgsw.solver.finite_diff import laplacian
from qgsw.utils.interpolation import LinearInterpolation, QuadraticInterpolation

run = RunOutput("../output/g5k/sw_double_gyre_long_hr")
config = run.summary.configuration

# Physical parameters of the stored reference run.
H = config.model.h
g_prime = config.model.g_prime
f0 = config.physics.f0
beta = config.physics.beta

# QG projector over the full domain, built with empty masks.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),
    H=H.unsqueeze(-1).unsqueeze(-1),
    space=SpaceDiscretization3D.from_config(config.space, config.model),
    f0=f0,
    masks=Masks.empty(nx=config.space.nx, ny=config.space.ny),
)
A = P.A
space = P.space
dx, dy = space.dx, space.dy
nx, ny = space.nx, space.ny

wind = WindForcing.from_config(config.windstress, config.space, config.physics)
tx, ty = wind.compute()

# The first stored output is the initial condition; its projected pressure
# divided by f0 is used as the initial streamfunction.
outputs = run.outputs()
uvh0: UVH = next(outputs).read()
sf_init = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / f0

# Full-domain 3-layer reference model.
model_3l = QGPSIQ(
    space_2d=space.remove_z_h(),
    H=H,
    beta_plane=config.physics.beta_plane,
    g_prime=g_prime,
)
model_3l.set_wind_forcing(tx, ty)
model_3l.masks = Masks.empty_tensor(model_3l.space.nx, model_3l.space.ny, device=defaults.get_device())
model_3l.bottom_drag_coef = config.physics.bottom_drag_coefficient
model_3l.slip_coef = config.physics.slip_coef

# RK3 runs with a much larger step than forward Euler.
time_stepper = "rk3"  # alternative: "euler"
if time_stepper == "rk3":
    dt = 7200
else:
    dt = 360

model_3l.dt = dt
model_3l.time_stepper = time_stepper

In [None]:
# Candidate sub-domain lower corners (grid indices); only the first one is kept.
imins = [32, 32, 112, 112][:1]
jmins = [64, 256, 64, 256][:1]

# Each sub-domain spans 64 cells along i and 128 cells along j.
imaxs = [lo + 64 for lo in imins]
jmaxs = [lo + 128 for lo in jmins]

imin, imax = imins[0], imaxs[0]
jmin, jmax = jmins[0], jmaxs[0]

In [None]:
from qgsw.fields.variables.coefficients.core import UniformCoefficient
from qgsw.models.qg.psiq.filtered.core import QGPSIQCollinearSF

# Per-layer thicknesses and reduced gravities of the 3-layer reference run.
h1, h2, h3 = H
g1, g2, g3 = g_prime
# Equivalent thickness of the two upper layers, H1*H2 / (H1 + H2), kept as a 1-element tensor.
Heq = (H[:1] * H[1:2]) / (H[:1] + H[1:2])

def compute_slices(imin:int,imax:int,jmin:int,jmax:int) -> tuple[list[slice],list[slice]]:
    """Build the index slices selecting a rectangular sub-domain.

    Args:
        imin, imax: lower/upper sub-domain bounds along the first horizontal axis.
        jmin, jmax: lower/upper sub-domain bounds along the second horizontal axis.

    Returns:
        (psi_slices, q_slices): two ``[slice_i, slice_j]`` lists. The streamfunction
        slices include one extra point per axis (``imax + 1``, ``jmax + 1``)
        compared to the PV slices.

    Raises:
        ValueError: if a range is empty or inverted (``imax <= imin`` or ``jmax <= jmin``).
    """
    if imax <= imin or jmax <= jmin:
        msg = f"Empty sub-domain: i in [{imin}, {imax}], j in [{jmin}, {jmax}]."
        raise ValueError(msg)

    psi_slices = [slice(imin, imax + 1), slice(jmin, jmax + 1)]
    q_slices = [slice(imin, imax), slice(jmin, jmax)]

    return psi_slices, q_slices

def build_models(imin:int,imax:int,jmin:int,jmax:int) -> QGPSIQCollinearSF:
    """Create the collinear-streamfunction model restricted to a sub-domain.

    The model uses the two upper layers of the reference run (``H[:2]``,
    ``g_prime[:2]``) on the horizontal grid points ``[imin, imax]`` x ``[jmin, jmax]``
    of the full domain. Bottom drag is switched off, and ``y0``, ``dt`` and the
    time stepper are copied from the 3-layer reference setup.

    Returns:
        A single configured ``QGPSIQCollinearSF`` instance.
    """
    sub_space = SpaceDiscretization2D.from_tensors(
        x=P.space.remove_z_h().omega.xy.x[imin:imax + 1, 0],
        y=P.space.remove_z_h().omega.xy.y[0, jmin:jmax + 1],
    )
    physics = run.summary.configuration.physics

    model = QGPSIQCollinearSF(
        space_2d=sub_space,
        H=H[:2],
        beta_plane=physics.beta_plane,
        g_prime=g_prime[:2],
    )
    model.masks = Masks.empty_tensor(model.space.nx, model.space.ny, device=defaults.get_device())
    model.y0 = model_3l.y0
    model.wide = True
    # Multiplying by 0 disables bottom drag while keeping the config value's type.
    model.bottom_drag_coef = physics.bottom_drag_coefficient * 0
    model.slip_coef = physics.slip_coef

    model.dt = dt
    model.time_stepper = time_stepper

    return model

In [None]:
def rmse(psi:torch.Tensor, psi_ref:torch.Tensor) -> torch.Tensor:
    """Relative root-mean-square error of ``psi`` with respect to ``psi_ref``.

    Computed as ``RMS(psi - psi_ref) / RMS(psi_ref)`` over all elements.

    Returns:
        A 0-d tensor, not a Python float. Callers accumulate it into a loss that
        is back-propagated and call ``.cpu()`` on it, so it must stay a tensor.
    """
    error_rms = (psi - psi_ref).square().mean().sqrt()
    reference_rms = psi_ref.square().mean().sqrt()
    return error_rms / reference_rms

In [None]:
model_3l.set_psi(sf_init)
model_3l.reset_time()

for _ in range(1,1):
    model_3l.step()

model_3l.reset_time()


indices = (imin,imax,jmin,jmax)
psi_slices, q_slices = compute_slices(*indices)

logger.info(f"Area: \n\ti: [{indices[0]}, {indices[1]}]\n\tj: [{indices[2]}, {indices[3]}]")

sf_0, q_0 = model_3l.prognostic.psiq

times: list[float] = [model_3l.time.item()]

slices = [slice(s.start-4,s.stop+4) for s in psi_slices]

psis_3l: list[torch.Tensor] = [model_3l.psi[...,*slices]]
psi_bc_1l: list[Boundaries] = [Boundaries.extract(model_3l.psi[:,:1],imin,imax+1,jmin,jmax+1,2)]
qs_3l: list[torch.Tensor] = [model_3l.q[...,*q_slices]]
q_bc_1l: list[Boundaries] = [Boundaries.extract(model_3l.q[:,:1],imin-1,imax+1,jmin-1,jmax+1,3)]

n_steps = 250

for _ in range(1,n_steps):
    model_3l.step()

    times.append(model_3l.time.item())

    psis_3l.append(model_3l.psi[...,*slices])
    psi_bc_1l.append(Boundaries.extract(model_3l.psi[:,:1],imin,imax+1,jmin,jmax+1,2))
    qs_3l.append(model_3l.q[...,*q_slices])
    q_bc_1l.append(Boundaries.extract(model_3l.q[:,:1],imin-1,imax+1,jmin-1,jmax+1,3))


In [None]:
psi = psis_3l[0]

# Sanity check: PV rebuilt from the halo-extended streamfunction and cropped by
# 3 points per side must have the same horizontal shape as the PV sub-domain.
# Top-layer stretching coefficient: 1/(h1 g1) + 1/(h1 g2).
q_check = interpolate(laplacian(psi,dx,dy) - f0**2 * (1/h1/g1+1/h1/g2)*psi[...,1:-1,1:-1])[...,3:-3,3:-3]
assert q_check.shape[-2:] == qs_3l[0].shape[-2:], (q_check.shape, qs_3l[0].shape)
q_check.shape

In [None]:
A_2l = compute_A(
    H[:2],g_prime[:2]
)


def _q1_without_coupling(psi:torch.Tensor) -> torch.Tensor:
    """Top-layer PV operator without the layer-2 coupling, on the inner grid of ``psi``.

    Returns ``lap(psi) - f0**2 * (1/(h1 g1) + 1/(h1 g2)) * psi`` (one point trimmed
    per side). The coefficient is the top diagonal entry of the 2-layer stretching
    operator. NOTE(review): this assumes compute_A's first row is
    [1/(h1 g1) + 1/(h1 g2), -1/(h1 g2)], the standard QG stretching matrix, so that
    these helpers agree with compute_q_2l.
    """
    return laplacian(psi,dx,dy) - f0**2 * (1/h1/g1+1/h1/g2)*psi[...,1:-1,1:-1]


def compute_q_alpha(psi:torch.Tensor, alpha:torch.Tensor) -> torch.Tensor:
    """Top-layer PV assuming a collinear lower layer, ``psi2 = alpha * psi``."""
    return interpolate(_q1_without_coupling(psi) + f0**2 * (1/h1/g2)*alpha*psi[...,1:-1,1:-1])


def compute_q_psi2(psi:torch.Tensor, psi2:torch.Tensor) -> torch.Tensor:
    """Top-layer PV given an explicit lower-layer streamfunction ``psi2`` (same grid as ``psi``)."""
    return interpolate(_q1_without_coupling(psi) + f0**2 * (1/h1/g2)*psi2[...,1:-1,1:-1])


def compute_q_rg(psi:torch.Tensor) -> torch.Tensor:
    """Top-layer PV with a motionless lower layer (``psi2 = 0``)."""
    return interpolate(_q1_without_coupling(psi))


def compute_q_2l(psi:torch.Tensor) -> torch.Tensor:
    """PV of both layers of the 2-layer system, using the stretching matrix ``A_2l``."""
    return interpolate(laplacian(psi,dx,dy) - f0**2 * torch.einsum("lm,...mxy->...lxy",A_2l,psi[...,1:-1,1:-1]))

In [None]:
model_1l_alpha = build_models(*indices)


imin,imax,jmin,jmax = indices

model_1l_alpha.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])

Js = []
alpha0 = 0.5
print("\tɑ: ", alpha0)
alpha = torch.tensor(alpha0,requires_grad=True)
alpha_ = torch.tensor(alpha0,requires_grad=True)

optimizer = torch.optim.Adam([alpha, alpha_],lr=1e-1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5,patience=5)

alphas = []
alphas_ = []
Js = []

best_J = float("inf")
previous_J = torch.tensor(0)
best_alpha = None
best_alpha_ = None
stable_since = 0
max_stable_iteration = 10
eps = 1e-8

n_optim = 10

torch.set_grad_enabled(False)

for n in range(n_optim):
    optimizer.zero_grad()
    
    model_1l_alpha.reset_time()
    
    psi_bc = QuadraticInterpolation(times,[Boundaries.extract(psi[:,:1],4,-5,4,-5,2) for psi in psis_3l])

    with torch.enable_grad():
        
        q_0_ = compute_q_alpha(sf_0[:,:1,*slices],alpha)[...,3:-3,3:-3]
        q_bc = QuadraticInterpolation(times,[Boundaries.extract(compute_q_alpha(psi[:,:1],alpha),2,-3,2,-3,3) for psi in psis_3l])

        model_1l_alpha.set_psiq(sf_0[:,:1,*psi_slices],q_0_)

        model_1l_alpha.alpha = torch.ones_like(model_1l_alpha.psi)*alpha_
        model_1l_alpha.set_boundary_maps(psi_bc,q_bc)

        J = torch.tensor(0, **defaults.get())
        
        for i in range(1,n_steps):
            model_1l_alpha.step()

            if (i+1) % 1 == 0:
                J += rmse(model_1l_alpha.psi[0,0],psis_3l[i][0,0,4:-4,4:-4])
    
    if J < best_J:
        best_alpha = alpha.cpu().item()
        best_alpha_ = alpha_.cpu().item()
        best_J = J.cpu().item()
    
    if abs(J - previous_J)/abs(previous_J) < eps:
        stable_since +=1
    else:
        stable_since = 0
        previous_J = J

    alphas.append(alpha.cpu().item())
    alphas_.append(alpha_.cpu().item())
    Js.append(J.cpu().item())
    
    print(f"[{str(n+1).zfill(len(str(n_optim)))}/{str(n_optim).zfill(len(str(n_optim)))}]: ɑ = {alpha.cpu().item():3.5f} - ɑ_ = {alpha_.cpu().item():3.5f} - Loss: {J.cpu().item():3.5f}")

    if stable_since >= max_stable_iteration:
        print(f"Convergence reached after {n+1} iterations.")
        break
    
    J.backward()
    optimizer.step()
    scheduler.step(J)

print(f"ɑ = {best_alpha}")
print(f"Loss: {best_J}")

In [None]:
from qgsw.models.qg.psiq.filtered.core import QGPSIQFixeddSF2

space_2d = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin:imax+1,0],
    y=P.space.remove_z_h().omega.xy.y[0,jmin:jmax+1],
)


model_dpsi2 = QGPSIQFixeddSF2(
    space_2d=space_2d,
    H = H[:2],
    beta_plane=run.summary.configuration.physics.beta_plane,
    g_prime=g_prime[:2],
)
model_dpsi2.masks = Masks.empty_tensor(model_dpsi2.space.nx,model_dpsi2.space.ny,device=defaults.get_device())
model_dpsi2.y0 = model_3l.y0
model_dpsi2.wide = True
model_dpsi2.bottom_drag_coef = run.summary.configuration.physics.bottom_drag_coefficient*0
model_dpsi2.slip_coef = run.summary.configuration.physics.slip_coef
model_dpsi2.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])

model_dpsi2.dt = dt
model_dpsi2.time_stepper = time_stepper


Js = []

psi2_ = (torch.nn.functional.pad(torch.ones_like(model_dpsi2.psi),(4,4,4,4),value=1)*torch.mean(model_dpsi2.psi)).requires_grad_()
dpsi2_ = (torch.ones_like(psi2_)*1e-2).requires_grad_()

optimizer = torch.optim.Adam([psi2_, dpsi2_],lr=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5,patience=5)

psi2s = []
dpsi2s = []
Js = []

best_J = float("inf")
previous_J = torch.tensor(0)
best_psi2 = None
best_dpsi2 = None
stable_since = 0
max_stable_iteration = 10
eps = 1e-8
n_steps = 250
n_optim = 100

for n in range(n_optim):
    optimizer.zero_grad()
    J = torch.tensor(0, **defaults.get())
    model_dpsi2.reset_time()

    psi_bc = QuadraticInterpolation(times,[Boundaries.extract(psi[:,:1],4,-5,4,-5,2) for psi in psis_3l])

    with torch.enable_grad():
        
        q_0_ = compute_q_psi2(sf_0[:,:1,*slices],psi2_)[...,3:-3,3:-3]
        q_bc = QuadraticInterpolation(times,[Boundaries.extract(compute_q_psi2(psi[:,:1],psi2_+n*dt*dpsi2_),2,-3,2,-3,3) for n,psi in enumerate(psis_3l)])
        
        model_dpsi2.set_psiq(sf_0[:,:1,*psi_slices],q_0_)
        model_dpsi2.set_boundary_maps(psi_bc,q_bc)

        model_dpsi2.dpsi2 = dpsi2_[...,4:-4,4:-4]

        for i in range(1,n_steps):
            model_dpsi2.step()

            if (i+1) % 1 == 0:
                J += rmse(model_dpsi2.psi[0,0],psis_3l[i][0,0,4:-4,4:-4])
    
    if torch.isnan(J):
        print(f"Loss diverges at step {n+1}.")
        break

    if J < best_J:
        best_psi2 = psi2_.detach().cpu()
        best_dpsi2 = dpsi2_.detach().cpu()
        best_J = J.cpu().item()
    
    if abs(J - previous_J)/abs(previous_J) < eps:
        stable_since +=1
    else:
        stable_since = 0
        previous_J = J

    dpsi2s.append(dpsi2_.detach().cpu())
    Js.append(J.cpu().item())
    
    print(f"[{str(n+1).zfill(len(str(n_optim)))}/{str(n_optim).zfill(len(str(n_optim)))}]: Loss: {J.cpu().item():3.5f}")

    if stable_since >= max_stable_iteration:
        print(f"Convergence reached after {n+1} iterations.")
        break

    J.backward()
    optimizer.step()
    scheduler.step(J)

print(f"Loss: {best_J}")


In [None]:
# Scratch check: iterating over a dict walks its keys in insertion order.
for e in iter({"a": 1, "b": 3}.keys()):
    print(e)

In [None]:
from qgsw import specs

space_2d = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin:imax+1,0],
    y=P.space.remove_z_h().omega.xy.y[0,jmin:jmax+1],
)


model_2l = QGPSIQ(
    space_2d=space_2d,
    H = H[:2],
    beta_plane=run.summary.configuration.physics.beta_plane,
    g_prime=g_prime[:2],
)
model_2l.masks = Masks.empty_tensor(model_2l.space.nx,model_2l.space.ny,device=defaults.get_device())
model_2l.y0 = model_3l.y0
model_2l.wide = True
model_2l.bottom_drag_coef = run.summary.configuration.physics.bottom_drag_coefficient*0
model_2l.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])
model_2l.slip_coef = run.summary.configuration.physics.slip_coef

model_2l.dt = dt
model_2l.time_stepper = time_stepper
model_2l.reset_time()

q_0_ = compute_q_2l(sf_0[:,:2,*slices])[...,3:-3,3:-3]
model_2l.set_psiq(sf_0[:,:2,*psi_slices],q_0_)

psi_bc = QuadraticInterpolation(times,[Boundaries.extract(psi[:,:2],4,-5,4,-5,2) for psi in psis_3l])
q_bc = QuadraticInterpolation(times,[Boundaries.extract(compute_q_2l(psi[:,:2]),2,-3,2,-3,3) for psi in psis_3l])

model_2l.set_boundary_maps(psi_bc, q_bc)

losses = [rmse(model_2l.psi[0,0],psis_3l[0][0,0,4:-4,4:-4]).cpu()]

for i in range(1,n_steps):
    model_2l.step()
    losses.append(rmse(model_2l.psi[0,0],psis_3l[i][0,0,4:-4,4:-4]).cpu())

plt.plot(losses, label=f"2 layers", color="pink")

model_1l = QGPSIQ(
    space_2d=space_2d,
    H = Heq,
    beta_plane=run.summary.configuration.physics.beta_plane,
    g_prime=g_prime[1:2],
)
model_1l.masks = Masks.empty_tensor(model_1l.space.nx,model_1l.space.ny,device=defaults.get_device())
model_1l.y0 = model_3l.y0
model_1l.wide = True
model_1l.bottom_drag_coef = run.summary.configuration.physics.bottom_drag_coefficient*0
model_1l.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])
model_1l.slip_coef = run.summary.configuration.physics.slip_coef

model_1l.dt = dt
model_1l.time_stepper = time_stepper
model_1l.reset_time()

q_0_ = compute_q_rg(sf_0[:,:1,*slices])[...,3:-3,3:-3]
model_1l.set_psiq(sf_0[:,:1,*psi_slices],q_0_)

psi_bc = QuadraticInterpolation(times,[Boundaries.extract(psi[:,:1],4,-5,4,-5,2) for psi in psis_3l])
q_bc = QuadraticInterpolation(times,[Boundaries.extract(compute_q_rg(psi[:,:1]),2,-3,2,-3,3) for psi in psis_3l])

model_1l.set_boundary_maps(psi_bc, q_bc)

losses = [rmse(model_1l.psi[0,0],psis_3l[0][0,0,4:-4,4:-4]).cpu()]

for i in range(1,n_steps):
    model_1l.step()
    losses.append(rmse(model_1l.psi[0,0],psis_3l[i][0,0,4:-4,4:-4]).cpu())

plt.plot(losses, label=f"Reduced gravity", color="black")

with torch.no_grad():
    model_1l_alpha.reset_time()
    model_1l_alpha.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])

    q_0_ = compute_q_alpha(sf_0[:,:1,*slices],best_alpha)[...,3:-3,3:-3]
    model_1l_alpha.set_psiq(sf_0[:,:1,*psi_slices],q_0_)

    psi_bc = QuadraticInterpolation(times,[Boundaries.extract(psi[:,:1],4,-5,4,-5,2) for psi in psis_3l])
    q_bc = QuadraticInterpolation(times,[Boundaries.extract(compute_q_alpha(psi[:,:1],best_alpha),2,-3,2,-3,3) for psi in psis_3l])

    model_1l_alpha.alpha =torch.ones_like(model_1l_alpha.psi)*best_alpha
    model_1l_alpha.set_boundary_maps(psi_bc,q_bc)

    losses = [rmse(model_1l_alpha.psi[0,0],psis_3l[0][0,0,4:-4,4:-4]).cpu()]

    for i in range(1,n_steps):
        model_1l_alpha.step()
        losses.append(rmse(model_1l_alpha.psi[0,0],psis_3l[i][0,0,4:-4,4:-4]).cpu())

plt.plot(losses, label=f"Best ɑ: {best_alpha:.2f}, ɑ_: {best_alpha_:.2f}", color="green")

with torch.no_grad():
    model_dpsi2.reset_time()
    
    q_0_ = compute_q_psi2(sf_0[:,:1,*slices],best_psi2.to(**specs.from_tensor(model_dpsi2.psi)))[...,3:-3,3:-3]
    model_dpsi2.set_psiq(sf_0[:,:1,*psi_slices],q_0_)

    psi_bc = QuadraticInterpolation(times,[Boundaries.extract(psi[:,:1],4,-5,4,-5,2) for psi in psis_3l])
    q_bc = QuadraticInterpolation(times,[Boundaries.extract(compute_q_psi2(psi[:,:1],best_psi2.to(**specs.from_tensor(model_dpsi2.psi))),2,-3,2,-3,3) for psi in psis_3l])
    
    model_dpsi2.dpsi2 = best_dpsi2.to(**specs.from_tensor(model_dpsi2.psi))
    model_dpsi2.set_boundary_maps(psi_bc,q_bc)

    losses = [rmse(model_dpsi2.psi[0,0],psis_3l[0][0,0,4:-4,4:-4]).cpu()]

    for i in range(1,n_steps):
        model_dpsi2.step()
        losses.append(rmse(model_dpsi2.psi[0,0],psis_3l[i][0,0,4:-4,4:-4]).cpu())


plt.plot(losses, label = f"dψ2 optim", color = "orange")

# plt.gca().set_ylim(0,1)
# plt.hlines(y= 1,xmin=0,xmax=n_steps,color="grey",linestyle="dotted")
plt.legend()
plt.grid()
plt.show()