In [None]:
# Reload edited modules (e.g. qgsw) automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import seaborn as sns

# Global figure styling: "notebook" plotting context with the dark seaborn style.
sns.set_theme(context="notebook", style="dark")

In [None]:
from qgsw import plots
import matplotlib.pyplot as plt
import torch
from qgsw.fields.variables.tuples import UVH
from qgsw.forcing.wind import WindForcing
from qgsw.masks import Masks
from qgsw.models.qg.psiq.core import QGPSIQ
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.output import RunOutput
from qgsw.spatial.core.discretization import SpaceDiscretization2D, SpaceDiscretization3D
from qgsw.spatial.core.grid_conversion import interpolate
from qgsw.specs import defaults
from qgsw.utils import covphys
from qgsw.filters.gaussian import GaussianFilter2D
from qgsw.solver.boundary_conditions.base import Boundaries
from qgsw.solver.finite_diff import laplacian
from qgsw.utils.interpolation import LinearInterpolation, QuadraticInterpolation

# Load the reference run; all physical and grid parameters below come from its saved configuration.
run = RunOutput("../output/g5k/sw_double_gyre_long_hr")
cfg = run.summary.configuration

H = cfg.model.h
g_prime = cfg.model.g_prime
f0 = cfg.physics.f0
beta = cfg.physics.beta

# Projector used to recover the QG pressure from the saved (u, v, h) snapshots.
P = QGProjector(
    A=compute_A(H=H, g_prime=g_prime),
    H=H.unsqueeze(-1).unsqueeze(-1),
    space=SpaceDiscretization3D.from_config(cfg.space, cfg.model),
    f0=f0,
    masks=Masks.empty(nx=cfg.space.nx, ny=cfg.space.ny),
)
A = P.A
space = P.space
dx, dy = space.dx, space.dy
nx, ny = space.nx, space.ny

wind = WindForcing.from_config(cfg.windstress, cfg.space, cfg.physics)
tx, ty = wind.compute()

# Initial streamfunction: pressure of the first saved snapshot divided by f0.
outputs = run.outputs()
uvh0: UVH = next(outputs).read()
sf_init = P.compute_p(covphys.to_cov(uvh0, dx, dy))[0] / f0

# Full-domain 3-layer QG model: the reference the sub-domain models are compared against.
model_3l = QGPSIQ(
    space_2d=space.remove_z_h(),
    H=H,
    beta_plane=cfg.physics.beta_plane,
    g_prime=g_prime,
)
model_3l.set_wind_forcing(tx, ty)
model_3l.masks = Masks.empty_tensor(model_3l.space.nx, model_3l.space.ny, device=defaults.get_device())
model_3l.bottom_drag_coef = cfg.physics.bottom_drag_coefficient
model_3l.slip_coef = cfg.physics.slip_coef

time_stepper = "rk3"  # alternative: "euler"

# The Euler time step is 10x smaller than the RK3 one (presumably for stability — confirm).
dt = 3600 if time_stepper == "rk3" else 360

model_3l.dt = dt
model_3l.time_stepper = time_stepper

In [None]:
# Candidate sub-domains, as grid index ranges: area n spans [imins[n], imaxs[n]] along i
# and [jmins[n], jmaxs[n]] along j.
AREA_NI = 64  # extent of every area along i
AREA_NJ = 128  # extent of every area along j
AREA_INDEX = 0  # which candidate area the rest of the notebook works on

imins = [32, 32, 112, 112]
imaxs = [i + AREA_NI for i in imins]

jmins = [64, 256, 64, 256]
jmaxs = [j + AREA_NJ for j in jmins]

imin, imax = imins[AREA_INDEX], imaxs[AREA_INDEX]
jmin, jmax = jmins[AREA_INDEX], jmaxs[AREA_INDEX]

In [None]:
from qgsw.fields.variables.coefficients.core import UniformCoefficient
from qgsw.models.qg.psiq.filtered.core import QGPSIQCollinearSF

# Per-layer thicknesses and reduced gravities; unpacking into three names also
# checks that the reference run has exactly three layers.
h1, h2, h3 = H
g1, g2, g3 = g_prime
# Equivalent depth of the top two layers, H1*H2 / (H1 + H2), kept as a 1-element tensor.
Heq = (H[:1] * H[1:2]) / (H[:1] + H[1:2])

def compute_slices(imin: int, imax: int, jmin: int, jmax: int) -> tuple[list[slice], list[slice]]:
    """Build the index slices that select a rectangular sub-domain.

    Args:
        imin, imax: first and last index of the area along the first spatial axis.
        jmin, jmax: first and last index of the area along the second spatial axis.

    Returns:
        ``(psi_slices, q_slices)``. ``psi_slices`` include the end index
        (``imax + 1`` / ``jmax + 1``), so they select one more point per axis
        than ``q_slices`` (presumably because psi lives on cell vertices and
        q on cell centres).
    """
    psi_slices = [slice(imin, imax + 1), slice(jmin, jmax + 1)]
    q_slices = [slice(imin, imax), slice(jmin, jmax)]
    return psi_slices, q_slices

def build_models(imin: int, imax: int, jmin: int, jmax: int) -> QGPSIQCollinearSF:
    """Build the collinear-streamfunction QG model restricted to a sub-domain.

    The model grid is ``x[imin:imax+1]`` x ``y[jmin:jmax+1]`` of the full
    grid, and it uses the top two layers of the reference stratification.
    ``y0``, slip coefficient, time step and time stepper are copied from
    the 3-layer setup; bottom drag is switched off.

    Despite the plural name, a single model is returned.

    Relies on notebook globals: ``P``, ``H``, ``g_prime``, ``run``,
    ``model_3l``, ``dt`` and ``time_stepper``.
    """
    full_space_2d = P.space.remove_z_h()
    space_2d = SpaceDiscretization2D.from_tensors(
        x=full_space_2d.omega.xy.x[imin:imax + 1, 0],
        y=full_space_2d.omega.xy.y[0, jmin:jmax + 1],
    )
    physics = run.summary.configuration.physics

    model_1l_alpha = QGPSIQCollinearSF(
        space_2d=space_2d,
        H=H[:2],
        beta_plane=physics.beta_plane,
        g_prime=g_prime[:2],
    )
    model_1l_alpha.masks = Masks.empty_tensor(
        model_1l_alpha.space.nx, model_1l_alpha.space.ny, device=defaults.get_device()
    )
    model_1l_alpha.y0 = model_3l.y0
    model_1l_alpha.wide = True
    # Bottom drag deliberately disabled (zeroed rather than dropped, keeping the config value's type).
    model_1l_alpha.bottom_drag_coef = physics.bottom_drag_coefficient * 0
    model_1l_alpha.slip_coef = physics.slip_coef

    model_1l_alpha.dt = dt
    model_1l_alpha.time_stepper = time_stepper

    return model_1l_alpha

In [None]:
def rmse(psi: torch.Tensor, psi_ref: torch.Tensor) -> torch.Tensor:
    """Relative root-mean-square error of ``psi`` with respect to ``psi_ref``.

    Computes ``sqrt(mean((psi - psi_ref)**2)) / sqrt(mean(psi_ref**2))``.

    Returns:
        A 0-d tensor (not a Python float), so the result can be accumulated
        into a differentiable loss and back-propagated. The result is
        ``inf``/``nan`` when ``psi_ref`` is identically zero.
    """
    error_norm = (psi - psi_ref).square().mean().sqrt()
    reference_norm = psi_ref.square().mean().sqrt()
    return error_norm / reference_norm

In [None]:
model_3l.set_psi(sf_init)
model_3l.reset_time()

for _ in range(1,1):
    model_3l.step()
model_3l.reset_time()

filt = GaussianFilter2D(sigma=10)
k = filt.window_radius
p = 4
indices = (imin,imax,jmin,jmax)
print(f"Area: \n\ti: [{indices[0]}, {indices[1]}]\n\tj: [{indices[2]}, {indices[3]}]")
psi_slices, q_slices = compute_slices(*indices)
psi_mean_slice = [slice(s.start-k-p,s.stop+k+p) for s in psi_slices]

sf_0, q_0 = model_3l.prognostic.psiq

times: list[float] = [model_3l.time.item()]

dpsis = []
psis_3l: list[torch.Tensor] = [model_3l.psi[...,*psi_mean_slice]]
psi_bc_1l: list[Boundaries] = [Boundaries.extract(model_3l.psi[:,:1],imin,imax+1,jmin,jmax+1,2)]
qs_3l: list[torch.Tensor] = [model_3l.q[...,*q_slices]]
q_bc_1l: list[Boundaries] = [Boundaries.extract(model_3l.q[:,:1],imin-1,imax+1,jmin-1,jmax+1,3)]

n_steps = 500

for _ in range(1,n_steps):
    model_3l.step()
    times.append(model_3l.time.item())

    dpsis.append(model_3l._dpsi[...,*psi_mean_slice])

    psis_3l.append(model_3l.psi[...,*psi_mean_slice])
    qs_3l.append(model_3l.q[...,*q_slices])
    psi_bc_1l.append(Boundaries.extract(model_3l.psi[:,:1],imin,imax+1,jmin,jmax+1,2))
    q_bc_1l.append(Boundaries.extract(model_3l.q[:,:1],imin-1,imax+1,jmin-1,jmax+1,3))


In [None]:
fig, axs = plots.subplots(3, 3)

# Rows: first, middle and last recorded tendency.
# Columns: layer index 0, difference (layer 0 - layer 1), layer index 1.
selected = (dpsis[0], dpsis[len(dpsis) // 2], dpsis[-1])
for row, snapshot in enumerate(selected):
    top, second = snapshot[0, 0], snapshot[0, 1]
    plots.imshow(top, ax=axs[row, 0])
    plots.imshow(top - second, ax=axs[row, 1])
    plots.imshow(second, ax=axs[row, 2])

plots.show()

In [None]:
model_3l.set_psi(sf_init)
model_3l.reset_time()

for _ in range(1,1):
    model_3l.step()
model_3l.reset_time()

filt = GaussianFilter2D(sigma=10)
k = filt.window_radius
p = 4
indices = (imin,imax,jmin,jmax)
print(f"Area: \n\ti: [{indices[0]}, {indices[1]}]\n\tj: [{indices[2]}, {indices[3]}]")
psi_slices, q_slices = compute_slices(*indices)
psi_mean_slice = [slice(s.start-k-p,s.stop+k+p) for s in psi_slices]

sf_0, q_0 = model_3l.prognostic.psiq

times: list[float] = [model_3l.time.item()]

dpsis = []
psis_3l: list[torch.Tensor] = [model_3l.psi[...,*psi_mean_slice]]
psi_bc_1l: list[Boundaries] = [Boundaries.extract(model_3l.psi[:,:1],imin,imax+1,jmin,jmax+1,2)]
qs_3l: list[torch.Tensor] = [model_3l.q[...,*q_slices]]
q_bc_1l: list[Boundaries] = [Boundaries.extract(model_3l.q[:,:1],imin-1,imax+1,jmin-1,jmax+1,3)]

n_steps = 5#00

for _ in range(1,n_steps):
    model_3l.step()
    times.append(model_3l.time.item())

    dpsis.append(model_3l._dpsi[...,*psi_mean_slice])

    psis_3l.append(model_3l.psi[...,*psi_mean_slice])
    qs_3l.append(model_3l.q[...,*q_slices])
    psi_bc_1l.append(Boundaries.extract(model_3l.psi[:,:1],imin,imax+1,jmin,jmax+1,2))
    q_bc_1l.append(Boundaries.extract(model_3l.q[:,:1],imin-1,imax+1,jmin-1,jmax+1,3))


model_1l_alpha = build_models(*indices)


psi_bar = torch.mean(torch.stack([torch.stack(
    [
        filt(psi[0,0])[None,...],
        filt(psi[0,1])[None,...],
        filt(psi[0,2])[None,...],
    ],dim=1
) for psi in psis_3l],dim=0),dim=0)

q_bar = interpolate(laplacian(psi_bar,dx,dy) - f0**2*torch.einsum("lm,...mxy->...lxy",model_3l.A,psi_bar[...,1:-1,1:-1]))

q_bar_bc = Boundaries.extract(q_bar,k+2,-k-3,k+2,-k-3,3)
psi_bar_bc = Boundaries.extract(psi_bar, k+p,-k-p-1,k+p,-k-p-1,2)
q_bar = q_bar[...,k+3:-k-3,k+3:-k-3]
psi_bar = psi_bar[...,k+p:-k-p,k+p:-k-p]

imin,imax,jmin,jmax = indices

alpha0 = UniformCoefficient.compute_optimal_values(
    psis_3l[0][0,0,k+p:-k-p,k+p:-k-p] - psi_bar[0,0],
    psis_3l[0][0,1,k+p:-k-p,k+p:-k-p] - psi_bar[0,1]
)
model_1l_alpha.set_wind_forcing(tx[imin:imax,jmin:jmax+1],ty[imin:imax+1,jmin:jmax])

Js = []
alpha0=0.5
print("\tɑ: ", alpha0)
alpha_ = torch.tensor(alpha0,requires_grad=True)

optimizer = torch.optim.Adam([alpha_],lr=1e-1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5,patience=5)

alphas = []
Js = []

best_J = float("inf")
previous_J = torch.tensor(0)
best_alpha = None
stable_since = 0
max_stable_iteration = 10
eps = 1e-8

n_optim = 100

torch.set_grad_enabled(False)

for n in range(n_optim):
    optimizer.zero_grad()
    
    
    model_1l_alpha.reset_time()
    model_1l_alpha.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
    
    
    with torch.enable_grad():

        model_1l_alpha.alpha = torch.ones_like(model_1l_alpha.psi)*alpha_
        model_1l_alpha.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))

        J = torch.tensor(0, **defaults.get())
        
        for i in range(1,n_steps):
            model_1l_alpha.step()

            if (i+1) % 1 == 0:
                J += rmse(model_1l_alpha.psi[0,0],psis_3l[i][0,0,k+p:-k-p,k+p:-k-p])
    
    if J < best_J:
        best_alpha = alpha_.cpu().item()
        best_J = J.cpu().item()
    
    if abs(J - previous_J)/abs(previous_J) < eps:
        stable_since +=1
    else:
        stable_since = 0
        previous_J = J

    alphas.append(alpha_.cpu().item())
    Js.append(J.cpu().item())
    
    print(f"[{str(n+1).zfill(len(str(n_optim)))}/{str(n_optim).zfill(len(str(n_optim)))}]: ɑ = {alpha_.cpu().item():3.5f} - Loss: {J.cpu().item():3.5f}")

    if stable_since >= max_stable_iteration:
        print(f"Convergence reached after {n+1} iterations.")
        break
    
    J.backward()
    optimizer.step()
    scheduler.step(J)

print(f"ɑ = {best_alpha}")
print(f"Loss: {best_J}")

In [None]:
results_path = "../output/local/param_optim/results_area_1.pt"
# NOTE: torch.load unpickles the file; only load trusted, locally produced results.
torch.load(results_path)

In [None]:
# Loss versus ɑ along the optimisation trajectory (points joined in visiting order).
fig, ax = plt.subplots()
ax.plot(alphas, Js, marker="o")
ax.set(xlabel="ɑ", ylabel="Loss (summed relative RMSE)", title="ɑ optimisation trajectory")
plt.show()

In [None]:
# space_2d = SpaceDiscretization2D.from_tensors(
#     x=P.space.remove_z_h().omega.xy.x[imin:imax+1,0],
#     y=P.space.remove_z_h().omega.xy.y[0,jmin:jmax+1],
# )
# model_1l = QGPSIQ(
#     space_2d=space_2d,
#     H = Heq,
#     beta_plane=run.summary.configuration.physics.beta_plane,
#     g_prime=g_prime[1:2],
# )
# model_1l.masks = Masks.empty_tensor(model_1l.space.nx,model_1l.space.ny,device=defaults.get_device())
# model_1l.y0 = model_3l.y0
# model_1l.wide = True
# model_1l.bottom_drag_coef = run.summary.configuration.physics.bottom_drag_coefficient*0
# model_1l.slip_coef = run.summary.configuration.physics.slip_coef

# model_1l.dt = dt
# model_1l.time_stepper = time_stepper
# model_1l.reset_time()
# model_1l.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
# model_1l.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))

# losses = [rmse(model_1l.psi[0,0],psis_3l[0][0,0,k+p:-k-p,k+p:-k-p]).cpu()]

# for i in range(1,n_steps):
#     model_1l.step()
#     if i == 1:
#         dpsi = model_1l.dpsi

#     losses.append(rmse(model_1l.psi[0,0],psis_3l[i][0,0,k+p:-k-p,k+p:-k-p]).cpu())

# plt.plot(losses, label=f"Reduced gravity", color="black")

# with torch.no_grad():
#     model_1l_alpha.reset_time()
#     model_1l_alpha.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
#     model_1l_alpha.alpha =torch.ones_like(model_1l_alpha.psi)*best_alpha
#     model_1l_alpha.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))

#     losses = [rmse(model_1l_alpha.psi[0,0],psis_3l[0][0,0,k+p:-k-p,k+p:-k-p]).cpu()]

#     for i in range(1,n_steps):
#         model_1l_alpha.step()

#         losses.append(rmse(model_1l_alpha.psi[0,0],psis_3l[i][0,0,k+p:-k-p,k+p:-k-p]).cpu())

# plt.plot(losses, label=f"Best ɑ: {best_alpha:.2f}", color="green")
# with torch.no_grad():
#     model_1l_alpha.reset_time()
#     model_1l_alpha.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
#     model_1l_alpha.alpha =torch.ones_like(model_1l_alpha.psi)*alpha0
#     model_1l_alpha.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))

#     losses = [rmse(model_1l_alpha.psi[0,0],psis_3l[0][0,0,k+p:-k-p,k+p:-k-p]).cpu()]

#     for i in range(1,n_steps):
#         model_1l_alpha.step()

#         losses.append(rmse(model_1l_alpha.psi[0,0],psis_3l[i][0,0,k+p:-k-p,k+p:-k-p]).cpu())

# plt.plot(losses, label = f"Initial ɑ: {alpha0:.2f}", color = "red",linestyle="dashed")
# plt.gca().set_ylim(0,1)
# plt.legend()
# plt.grid()
# plt.show()

In [None]:
from qgsw.models.qg.psiq.filtered.core import QGPSIQFixeddSF2

space_2d = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin:imax+1,0],
    y=P.space.remove_z_h().omega.xy.y[0,jmin:jmax+1],
)

model_1l = QGPSIQ(
    space_2d=space_2d,
    H = Heq,
    beta_plane=run.summary.configuration.physics.beta_plane,
    g_prime=g_prime[1:2],
)
model_1l.masks = Masks.empty_tensor(model_1l.space.nx,model_1l.space.ny,device=defaults.get_device())
model_1l.y0 = model_3l.y0
model_1l.wide = True
model_1l.bottom_drag_coef = run.summary.configuration.physics.bottom_drag_coefficient*0
model_1l.slip_coef = run.summary.configuration.physics.slip_coef

model_1l.dt = dt
model_1l.time_stepper = time_stepper
model_1l.reset_time()
model_1l.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
model_1l.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))

losses = [rmse(model_1l.psi[0,0],psis_3l[0][0,0,k+p:-k-p,k+p:-k-p]).cpu()]

model_1l.step()
dpsi = model_1l._dpsi

model_dpsi2 = QGPSIQFixeddSF2(
    space_2d=space_2d,
    H = H[:2],
    beta_plane=run.summary.configuration.physics.beta_plane,
    g_prime=g_prime[:2],
)
model_dpsi2.masks = Masks.empty_tensor(model_dpsi2.space.nx,model_dpsi2.space.ny,device=defaults.get_device())
model_dpsi2.y0 = model_3l.y0
model_dpsi2.wide = True
model_dpsi2.bottom_drag_coef = run.summary.configuration.physics.bottom_drag_coefficient*0
model_dpsi2.slip_coef = run.summary.configuration.physics.slip_coef

model_dpsi2.dt = dt
model_dpsi2.time_stepper = time_stepper
model_dpsi2.reset_time()
model_dpsi2.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
model_dpsi2.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))


Js = []

dpsi2_ = torch.clone(dpsi).requires_grad_(True)
optimizer = torch.optim.Adam([dpsi2_],lr=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5,patience=5)

dpsi2s = []
Js = []

best_J = float("inf")
previous_J = torch.tensor(0)
best_dpsi2 = None
stable_since = 0
max_stable_iteration = 10
eps = 1e-8
n_steps = 5
n_optim = 100

for n in range(n_optim):
    optimizer.zero_grad()
    J = torch.tensor(0, **defaults.get())
    model_dpsi2.reset_time()
    model_dpsi2.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
    model_dpsi2.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))

    with torch.enable_grad():

        model_dpsi2.dpsi2 = dpsi2_
        for i in range(1,n_steps):
            model_dpsi2.step()

            if (i+1) % 1 == 0:
                J += rmse(model_dpsi2.psi[0,0],psis_3l[i][0,0,k+p:-k-p,k+p:-k-p])
    
    if torch.isnan(J):
        print(f"Loss diverges at step {n+1}.")
        break

    if J < best_J:
        best_dpsi2 = dpsi2_.detach().cpu()
        best_J = J.cpu().item()
    
    if abs(J - previous_J)/abs(previous_J) < eps:
        stable_since +=1
    else:
        stable_since = 0
        previous_J = J

    dpsi2s.append(dpsi2_.detach().cpu())
    Js.append(J.cpu().item())
    
    print(f"[{str(n+1).zfill(len(str(n_optim)))}/{str(n_optim).zfill(len(str(n_optim)))}]: Loss: {J.cpu().item():3.5f}")

    if stable_since >= max_stable_iteration:
        print(f"Convergence reached after {n+1} iterations.")
        break

    J.backward()
    optimizer.step()
    scheduler.step(J)

print(f"Loss: {best_J}")


In [None]:
from qgsw import specs


space_2d = SpaceDiscretization2D.from_tensors(
    x=P.space.remove_z_h().omega.xy.x[imin:imax+1,0],
    y=P.space.remove_z_h().omega.xy.y[0,jmin:jmax+1],
)
model_1l = QGPSIQ(
    space_2d=space_2d,
    H = Heq,
    beta_plane=run.summary.configuration.physics.beta_plane,
    g_prime=g_prime[1:2],
)
model_1l.masks = Masks.empty_tensor(model_1l.space.nx,model_1l.space.ny,device=defaults.get_device())
model_1l.y0 = model_3l.y0
model_1l.wide = True
model_1l.bottom_drag_coef = run.summary.configuration.physics.bottom_drag_coefficient*0
model_1l.slip_coef = run.summary.configuration.physics.slip_coef

model_1l.dt = dt
model_1l.time_stepper = time_stepper
model_1l.reset_time()
model_1l.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
model_1l.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))

losses = [rmse(model_1l.psi[0,0],psis_3l[0][0,0,k+p:-k-p,k+p:-k-p]).cpu()]

for i in range(1,n_steps):
    model_1l.step()
    losses.append(rmse(model_1l.psi[0,0],psis_3l[i][0,0,k+p:-k-p,k+p:-k-p]).cpu())

plt.plot(losses, label=f"Reduced gravity", color="black")

with torch.no_grad():
    model_1l_alpha.reset_time()
    model_1l_alpha.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
    model_1l_alpha.alpha =torch.ones_like(model_1l_alpha.psi)*best_alpha
    model_1l_alpha.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))

    losses = [rmse(model_1l_alpha.psi[0,0],psis_3l[0][0,0,k+p:-k-p,k+p:-k-p]).cpu()]

    for i in range(1,n_steps):
        model_1l_alpha.step()
        losses.append(rmse(model_1l_alpha.psi[0,0],psis_3l[i][0,0,k+p:-k-p,k+p:-k-p]).cpu())

plt.plot(losses, label=f"Best ɑ: {best_alpha:.2f}", color="green")
with torch.no_grad():
    model_1l_alpha.reset_time()
    model_1l_alpha.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
    model_1l_alpha.alpha =torch.ones_like(model_1l_alpha.psi)*alpha0
    model_1l_alpha.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))

    losses = [rmse(model_1l_alpha.psi[0,0],psis_3l[0][0,0,k+p:-k-p,k+p:-k-p]).cpu()]

    for i in range(1,n_steps):
        model_1l_alpha.step()
        losses.append(rmse(model_1l_alpha.psi[0,0],psis_3l[i][0,0,k+p:-k-p,k+p:-k-p]).cpu())

plt.plot(losses, label = f"Initial ɑ: {alpha0:.2f}", color = "red",linestyle="dashed")

with torch.no_grad():
    model_dpsi2.reset_time()
    model_dpsi2.set_psiq(sf_0[:,:1,*psi_slices],q_0[:,:1,*q_slices])
    model_dpsi2.dpsi2 = best_dpsi2.to(**specs.from_tensor(model_dpsi2.psi))
    model_dpsi2.set_boundary_maps(QuadraticInterpolation(times,[bc[:,:1] for bc in psi_bc_1l]),QuadraticInterpolation(times,[bc[:,:1] for bc in q_bc_1l]))

    losses = [rmse(model_dpsi2.psi[0,0],psis_3l[0][0,0,k+p:-k-p,k+p:-k-p]).cpu()]

    for i in range(1,n_steps):
        model_dpsi2.step()
        losses.append(rmse(model_dpsi2.psi[0,0],psis_3l[i][0,0,k+p:-k-p,k+p:-k-p]).cpu())


plt.plot(losses, label = f"dψ2 optim", color = "orange")

# plt.gca().set_ylim(0,1)
plt.hlines(y= 1,xmin=0,xmax=n_steps,color="grey",linestyle="dotted")
plt.legend()
plt.grid()
plt.show()

In [None]:
results_path = "../output/local/param_optim/results_area_1.pt"
# NOTE: torch.load unpickles the file; only load trusted, locally produced results.
torch.load(results_path)

In [None]:
# Scratch file from a previous session; display its first entry.
# NOTE: torch.load unpickles the file; only load trusted, locally produced data.
tmp_results = torch.load("tmp.pt")
tmp_results[0]

In [None]:
# Time mean of the recorded tendency of layer index 1 over the area.
mean_dpsi_layer1 = torch.stack([snapshot[0, 1] for snapshot in dpsis], dim=0).mean(dim=0)
plots.imshow(mean_dpsi_layer1)

In [None]:
# Debug inspection of the model's `_A12` attribute (leading underscore suggests it is private).
model_dpsi2._A12