In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations


import numpy as np
import matplotlib.pyplot as plt
import torch

from qgsw import verbose
from qgsw.configs.core import Configuration
from qgsw.fields.variables.prognostic_tuples import UVH
from qgsw.forcing.wind import WindForcing
from qgsw.models.instantiation import instantiate_model
from qgsw.models.names import ModelName
from qgsw.output import RunOutput
from qgsw.perturbations.core import Perturbation
from qgsw.simulation.steps import Steps
from qgsw.spatial.core.discretization import (
    SpaceDiscretization2D,
)
from qgsw.specs import DEVICE
from qgsw.utils import time_params

# Ask cuDNN for deterministic algorithms so repeated GPU runs give the same numbers.
torch.backends.cudnn.deterministic = True

import warnings
# Hide FutureWarning messages to keep cell outputs readable.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
## Configuration

from qgsw.configs.models import ModelConfig


# Experiment configuration, parsed with Configuration(**config_dict) in the next cell.
config_dict = {
    "io": {
        "name": "Assimilation.",
        "output": {
            "save": True,
            "type": "interval",
            "interval_duration": 24*3600,  # seconds
            "directory": "tmp",
        },
    },
    "physics" :{
        "rho": 1000,
        "slip_coef": 1.0,
        "f0": 9.375e-5,           # mean coriolis (s^-1)
        "beta": 1.754e-11,                # coriolis gradient (m^-1 s^-1)
        "bottom_drag_coefficient": 3.60577e-8,
    },
    "simulation": {
        "type": "assimilation",
        "duration":1*24*3600,    # seconds
        "dt": 3600,                   # seconds
        "fork_interval": 1*24*3600,   # seconds
        # Initial state of the reference run (read with UVH.from_file).
        "startup_file": "../output/g5k/double_gyre_qg_long/results_step_876000.pt",
        # Reference run: 3-layer QG model.
        "reference": {
            "type": "QG",
            "prefix": "reference_step_",
            "layers": [400,1100,2600],
            "reduced_gravity": [9.81,0.025,0.0125],
        },
    },
    # Reduced model: 2-layer QGCollinearFilteredSF sharing the reference's top two layers.
    "model":{
        "type": "QGCollinearFilteredSF",
        "prefix": "results_step_",
        "layers": [400, 1100],
        "reduced_gravity": [9.81, 0.025],
        "sigma": 20.35,
        "collinearity_coef": {
            "type": "smooth-non-uniform",
            # One initial value per center below (8 of each).
            "initial": [0,0,0,0,0,0,0,0],
            "centers": [[32,32],[32,96],[32,160],[32,224],[96,32],[96,96],[96,160],[96,224]],
            "sigma": 20.35,
            "use_optimal": True,
        },
    },
    # 128 x 256 grid over a 2560 km x 5120 km domain (unit: m).
    "space":{
        "nx": 128,
        "ny": 256,
        "unit": "m",
        "x_min": 0,
        "x_max": 2_560_000,
        "y_min": 0,
        "y_max": 5_120_000,
    },
    "windstress": {
        "type": "cosine",
        "magnitude": 0.08,
        "drag_coefficient": 0.0013,
    },
    # No perturbation is applied to the initial state.
    "perturbation":{
        "type": "none",
        "perturbation_magnitude": 1e-3,
    }
}
# 1-layer QG baseline: its single reduced gravity combines the first two
# reduced gravities of the 2-layer setup as g1*g2/(g1+g2).
model_1l_config = {
    "type": "QG",
    "prefix": "results_step_",
    "layers": [400],
    "reduced_gravity": [9.81*0.025/(9.81+0.025)],

}
# 2-layer QG baseline with the same layers and reduced gravities as the reduced model.
model_2l_config = {
    "type": "QG",
    "prefix": "results_step_",
    "layers": [400, 1100],
    "reduced_gravity": [9.81, 0.025],

}

In [None]:
config = Configuration(**config_dict)
# Shortcuts to the physical and grid parameters used throughout the notebook.
f0 = config.physics.f0
nx, ny = config.space.nx, config.space.ny
dx, dy, ds = config.space.dx, config.space.dy, config.space.ds

In [None]:
from qgsw.fields.variables.coefficients.core import SmoothNonUniformCoefficient
from qgsw.filters.high_pass import GaussianHighPass2D
from qgsw.models.qg.projected.modified.filtered.core import QGCollinearFilteredSF


config_1l = ModelConfig(**model_1l_config)
config_2l = ModelConfig(**model_2l_config)
## Wind Forcing
wind = WindForcing.from_config(config.windstress, config.space, config.physics)
taux, tauy = wind.compute()
## Rossby
# Single source of truth for the Rossby number passed to instantiate_model.
Ro = 0.1
## Vortex
perturbation = Perturbation.from_config(
    perturbation_config=config.perturbation,
)
space_2d = SpaceDiscretization2D.from_config(config.space)

# Reference model built from config.simulation.reference (3 layers).
model_ref = instantiate_model(
    config.simulation.reference,
    config.physics.beta_plane,
    space_2d,
    perturbation,
    Ro=Ro,
)

model_ref.slip_coef = config.physics.slip_coef
model_ref.bottom_drag_coef = config.physics.bottom_drag_coefficient
# A NaN dt in the configuration means: derive dt from the model state with
# time_params.compute_dt; otherwise use the configured value.
if np.isnan(config.simulation.dt):
    model_ref.dt = time_params.compute_dt(
        model_ref.prognostic.uvh,
        model_ref.space,
        model_ref.g_prime,
        model_ref.H,
    )
else:
    model_ref.dt = config.simulation.dt
model_ref.compute_time_derivatives(model_ref.prognostic.uvh)
model_ref.set_wind_forcing(taux, tauy)


ref = RunOutput("../output/local/assimilation_ref")

# Time-averaged reference pressure of layers 0 and 1, used as offsets by the
# "Col Filt & Time avg" model. Each saved output is read and converted to
# pressure only once, and the outputs are iterated a single time.
p_ref_sum = 0
n_ref_outputs = 0
for o_ref in ref.outputs():
    p_ref_sum = p_ref_sum + model_ref.P.compute_p(o_ref.read())[1][0, :2]
    n_ref_outputs += 1
if n_ref_outputs == 0:
    msg = "No reference outputs found to compute the time-averaged pressure."
    raise ValueError(msg)

p0_mean = p_ref_sum[0] / n_ref_outputs
p1_mean = p_ref_sum[1] / n_ref_outputs
# Add leading (ensemble, layer) singleton dimensions.
offset_p0 = p0_mean.unsqueeze(0).unsqueeze(0)
offset_p1 = p1_mean.unsqueeze(0).unsqueeze(0)

def _build_qg_model(model_config: ModelConfig):
    """Instantiate a QG baseline sharing the reference physics, time step and wind.

    Args:
        model_config: Layer configuration of the baseline (e.g. ``config_1l``).

    Returns:
        The configured model, after ``compute_time_derivatives`` has been
        called on its initial state.
    """
    qg_model = instantiate_model(
        model_config,
        config.physics.beta_plane,
        space_2d,
        perturbation,
        Ro=Ro,
    )
    qg_model.slip_coef = config.physics.slip_coef
    qg_model.bottom_drag_coef = config.physics.bottom_drag_coefficient
    qg_model.dt = model_ref.dt
    qg_model.compute_time_derivatives(qg_model.prognostic.uvh)
    qg_model.set_wind_forcing(taux, tauy)
    return qg_model


# Collinearity-coefficient centers and filter width, read from config_dict so
# the notebook cannot drift from the configuration.
COEF_CENTERS = config_dict["model"]["collinearity_coef"]["centers"]
SIGMA_COL_FILT = config_dict["model"]["sigma"]
# Wider filter width used by the "Col Filt & Time avg" variant.
SIGMA_TIME_AVG = 40


def _build_collinear_filtered_model(sigma: float):
    """Build a QGCollinearFilteredSF model and its smooth non-uniform alpha.

    The model's filter and the coefficient share the same ``sigma``.

    Args:
        sigma: Width of the model filter and of the coefficient kernels.

    Returns:
        ``(model, coefficient)``; the coefficient values are set later with
        ``with_optimal_values``.
    """
    collinear_model = QGCollinearFilteredSF(
        space_2d=SpaceDiscretization2D.from_config(config.space),
        H=config.model.h,
        g_prime=config.model.g_prime,
        beta_plane=config.physics.beta_plane,
        optimize=True,
    )
    collinear_model.P.filter.sigma = sigma
    collinear_coef = SmoothNonUniformCoefficient(nx=nx, ny=ny)
    collinear_coef.sigma = sigma
    # Fresh nested lists so two coefficients never share the same objects.
    collinear_coef.centers = [list(center) for center in COEF_CENTERS]
    collinear_model.slip_coef = config.physics.slip_coef
    collinear_model.bottom_drag_coef = config.physics.bottom_drag_coefficient
    collinear_model.dt = model_ref.dt
    collinear_model.set_wind_forcing(taux, tauy)
    return collinear_model, collinear_coef


model_1l = _build_qg_model(config_1l)
model_2l = _build_qg_model(config_2l)

model, coef = _build_collinear_filtered_model(SIGMA_COL_FILT)
model_, coef_ = _build_collinear_filtered_model(SIGMA_TIME_AVG)

verbose.display("\n[Reference Model]", trigger_level=1)
verbose.display(msg=repr(model_ref), trigger_level=1)
verbose.display("\n[Model]", trigger_level=1)
verbose.display(msg=repr(model), trigger_level=1)

# Layer counts of the reference and of the reduced model.
nl_ref = model_ref.space.nl
nl = model.space.nl
if model.get_type() == ModelName.QG_SANITY_CHECK:
    nl += 1

nx, ny = model.space.nx, model.space.ny

dtype = torch.float64
device = DEVICE.get()

# Initial state: loaded from the startup file when one is configured,
# otherwise UVH.steady with the reference layer count.
startup_file = config.simulation.startup_file
if startup_file is not None:
    uvh0 = UVH.from_file(startup_file, dtype=dtype, device=device)
    horizontal_shape = uvh0.h.shape[-2:]
    if horizontal_shape != (nx, ny):
        msg = (
            f"Horizontal shape {horizontal_shape} from {startup_file}"
            f" should be ({nx},{ny})."
        )
        raise ValueError(msg)
else:
    uvh0 = UVH.steady(
        n_ens=1,
        nl=nl_ref,
        nx=config.space.nx,
        ny=config.space.ny,
        dtype=torch.float64,
        device=DEVICE.get(),
    )

# The reference receives clones of uvh0's fields.
model_ref.set_uvh(*(torch.clone(field) for field in (uvh0.u, uvh0.v, uvh0.h)))

dt = model.dt
t_end = config.simulation.duration

steps = Steps(t_end=t_end, dt=dt)
print(steps)

ns, forks, saves = (
    steps.simulation_steps(),
    steps.steps_from_interval(interval=config.simulation.fork_interval),
    config.io.output.get_saving_steps(steps),
)

t = 0

prefix_ref, prefix, output_dir = (
    config.simulation.reference.prefix,
    config.model.prefix,
    config.io.output.directory,
)

In [None]:
from qgsw.fields.errors.point_wise import RMSE
from qgsw.fields.variables.dynamics import PhysicalVorticity, StreamFunctionFromVorticity, Vorticity
from qgsw.models.qg.projected.modified.filtered.variables import CollinearFilteredPsi2


# Diagnostic variable sets of each model, used to build the psi2 error metrics.
model_vars = model.get_variable_set(config.space, config.physics, config.model)
model_ref_vars = model_ref.get_variable_set(config.space, config.physics, config.simulation.reference)
model_1l_vars = model_1l.get_variable_set(config.space, config.physics, config_1l)
model_2l_vars = model_2l.get_variable_set(config.space, config.physics, config_2l)

# psi2 of the "Col Filt & Time avg" model: rebuilt from the top-layer
# streamfunction with model_'s filter and the time-averaged offsets (p / f0).
psi1_from_vorticity = StreamFunctionFromVorticity(
    PhysicalVorticity(Vorticity(), dx*dy), nx, ny, dx, dy
)
psi2_col_filt_avg = CollinearFilteredPsi2(
    psi1_from_vorticity, model_.P.filter, offset_p0/f0, offset_p1/f0
)

error = RMSE(model_vars["psi2"], model_ref_vars["psi2"])
error_ = RMSE(psi2_col_filt_avg, model_ref_vars["psi2"])
error_1l = RMSE(model_1l_vars["psi2"], model_ref_vars["psi2"])
error_2l = RMSE(model_2l_vars["psi2"], model_ref_vars["psi2"])

# Every metric is restricted to index 0 along axis 1.
for rmse in (error, error_, error_1l, error_2l):
    rmse.slices = [slice(None,None), slice(0,1),...]

In [None]:
errors = []
errors_ = []
errors_1l = []
errors_2l = []

prognostic = model_ref.prognostic
pressure = model_ref.P.compute_p(prognostic.uvh)[1]

# "Col Filt": fit alpha from the filtered top-layer pressure and the
# second-layer pressure of the reference state.
coef.with_optimal_values(model.P.filter(pressure[0,0]), pressure[0,1])
plt.imshow(coef.get()[0,0].cpu().T)
plt.colorbar()
plt.title(r"$\alpha$ - Col Filt")
plt.show()
model.alpha = coef.get()

# "Col Filt & Time avg": same fit on anomalies w.r.t. the time-averaged
# reference pressure, using model_'s own filter.
coef_.with_optimal_values(model_.P.filter(pressure[0,0]-offset_p0[0,0]), pressure[0,1]-offset_p1[0,0])
plt.imshow(coef_.get()[0,0].cpu().T)
plt.colorbar()
plt.title(r"$\alpha$ - Col Filt & Time avg")
plt.show()
# model_ must use the coefficient fitted for it (coef_), not model's coef.
model_.alpha = coef_.get()

# Initialise every model from the reference pressure (top layer, or top two
# layers for the 2-layer baseline). compute_p is called once per model rather
# than sharing one result, to avoid possible aliasing between models.
model.set_p(
    model_ref.P.compute_p(prognostic)[0][:, :1],
)
model_.set_p(
    model_ref.P.compute_p(prognostic)[0][:, :1],
    offset_p0=offset_p0,
    offset_p1=offset_p1,
)
model_1l.set_p(
    model_ref.P.compute_p(prognostic)[0][:, :1],
)
model_2l.set_p(
    model_ref.P.compute_p(prognostic)[0][:, :2],
)

# (error buffer, metric, model) triples compared with the reference at each step.
tracked_errors = (
    (errors, error, model),
    (errors_, error_, model_),
    (errors_1l, error_1l, model_1l),
    (errors_2l, error_2l, model_2l),
)

for n, fork, save in zip(ns, forks, saves):
    for error_buffer, metric, tracked_model in tracked_errors:
        error_buffer.append(
            metric.compute_ensemble_wise(
                tracked_model.prognostic,
                model_ref.prognostic,
            ).cpu().item()
        )

    if save:
        # Saving is currently disabled; to re-enable:
        # model_ref.io.save(output_dir.joinpath(f"{prefix_ref}{n}.pt"))
        # model.io.save(output_dir.joinpath(f"{prefix}{n}.pt"))
        ...

    model_ref.step()
    model.step()
    model_.step()
    model_1l.step()
    model_2l.step()

errors = np.array(errors)
errors_ = np.array(errors_)
errors_1l = np.array(errors_1l)
errors_2l = np.array(errors_2l)

In [None]:
from qgsw.plots.scatter import ScatterPlot

# Every error curve is normalised by the 1-layer baseline, so "1L vs 3L" is the unit line.
error_curves = [errors_1l, errors_2l, errors, errors_]
plot = ScatterPlot([curve / errors_1l for curve in error_curves])
plot.figure.update_layout(template="plotly")
plot.set_traces_name("1L vs 3L", "2L vs 3L", "Col Filt vs 3L", "Col Filt & Time avg vs 3L")
plot.show()

In [None]:
import matplotlib.pyplot as plt

sf = StreamFunctionFromVorticity(PhysicalVorticity(Vorticity(),ds),nx, ny, dx, dy)

# "Col Filt" model: top-layer streamfunction and its collinear reconstruction.
psi1_col_filt = sf.compute(model.prognostic)[0,0]
plt.imshow(psi1_col_filt.cpu().T)
plt.colorbar()
plt.title(r"$\tilde{\psi}_1$")
plt.show()

plt.imshow((model.P.filter(psi1_col_filt)*model.alpha[0,0]).cpu().T)
plt.colorbar()
plt.title(r"$\tilde{\psi}_2 = \alpha K * \tilde{\psi_1}$")
plt.show()

# "Col Filt & Time avg" model: alpha scales only the filtered anomaly; the
# time-averaged second-layer streamfunction is added afterwards, matching
# both the title and the anomaly-based fit of alpha.
psi1_col_filt_avg = sf.compute(model_.prognostic)[0,0]
psi1_bar = offset_p0[0,0]/f0
psi2_bar = offset_p1[0,0]/f0
plt.imshow(psi1_col_filt_avg.cpu().T)
plt.colorbar()
plt.title(r"$\tilde{\psi}_1$")
plt.show()

plt.imshow((model_.P.filter(psi1_col_filt_avg - psi1_bar)*model_.alpha[0,0] + psi2_bar).cpu().T)
plt.colorbar()
plt.title(r"$\tilde{\psi}_2 = \alpha K * \left(\tilde{\psi}_1 - \bar{\psi}_1\right) + \bar{\psi}_2$")
plt.show()


plt.imshow(sf.compute(model_2l.prognostic)[0,0].cpu().T)
plt.title(r"$\psi_1^{2L}$")
plt.colorbar()
plt.show()

plt.imshow(sf.compute(model_2l.prognostic)[0,1].cpu().T)
plt.title(r"$\psi_2^{2L}$")
plt.colorbar()
plt.show()

plt.imshow(sf.compute(model_ref.prognostic)[0,0].cpu().T)
plt.title(r"$\psi_1^{3L}$")
plt.colorbar()
plt.show()

plt.imshow(sf.compute(model_ref.prognostic)[0,1].cpu().T)
plt.title(r"$\psi_2^{3L}$")
plt.colorbar()
plt.show()

In [None]:
# 3-layer PSIQ reference: same layers and reduced gravities as the 3-layer QG reference.
model_psiq_3l_config = {
    "type": "QGPSIQ",
    "prefix": "results_step_",
    "layers": [400, 1100, 2600],
    "reduced_gravity": [9.81, 0.025, 0.0125],

}
# 2-layer PSIQ baseline (top two layers of the reference).
model_psiq_2l_config = {
    "type": "QGPSIQ",
    "prefix": "results_step_",
    "layers": [400, 1100],
    "reduced_gravity": [9.81, 0.025],
}
# 1-layer PSIQ baseline; its reduced gravity combines the first two as g1*g2/(g1+g2).
model_psiq_1l_config = {
    "type": "QGPSIQ",
    "prefix": "results_step_",
    "layers": [400],
    "reduced_gravity": [9.81*0.025/(9.81+0.025)],
}

In [None]:
from qgsw.fields.variables.coefficients.core import SmoothNonUniformCoefficient
from qgsw.filters.high_pass import GaussianHighPass2D
from qgsw.masks import Masks
from qgsw.models.qg.usual.filtered.core import QGPSIQCollinearFilteredSF
from qgsw.specs import defaults

config = Configuration(**config_dict)
config_psiq_3l = ModelConfig(**model_psiq_3l_config)
config_psiq_2l = ModelConfig(**model_psiq_2l_config)
config_psiq_1l = ModelConfig(**model_psiq_1l_config)

ref = RunOutput("../output/local/assimilation_psiq_ref")

# Time-averaged reference streamfunction of layers 0 and 1, converted to
# pressure by multiplying by f0. Each saved output is read only once.
psi_ref_sum = 0
n_psiq_outputs = 0
for o_ref in ref.outputs():
    psi_ref_sum = psi_ref_sum + o_ref.read().psi[0, :2]
    n_psiq_outputs += 1
if n_psiq_outputs == 0:
    msg = "No PSIQ reference outputs found to compute the time-averaged streamfunction."
    raise ValueError(msg)

p0_psiq_mean: torch.Tensor = psi_ref_sum[0] / n_psiq_outputs * config.physics.f0
p1_psiq_mean: torch.Tensor = psi_ref_sum[1] / n_psiq_outputs * config.physics.f0

# Add leading (ensemble, layer) singleton dimensions.
offset_p0_psiq = p0_psiq_mean.unsqueeze(0).unsqueeze(0)
offset_p1_psiq = p1_psiq_mean.unsqueeze(0).unsqueeze(0)
## Wind Forcing
wind = WindForcing.from_config(config.windstress, config.space, config.physics)
taux, tauy = wind.compute()
## Rossby
# Single source of truth for the Rossby number passed to instantiate_model.
Ro = 0.1
## Vortex
perturbation = Perturbation.from_config(
    perturbation_config=config.perturbation,
)
space_2d = SpaceDiscretization2D.from_config(config.space)

# 3-layer PSIQ reference model.
model_psiq_ref = instantiate_model(
    config_psiq_3l,
    config.physics.beta_plane,
    space_2d,
    perturbation,
    Ro=Ro,
)

model_psiq_ref.slip_coef = config.physics.slip_coef
model_psiq_ref.bottom_drag_coef = config.physics.bottom_drag_coefficient
# A NaN dt in the configuration means: derive dt with time_params.compute_dt.
if np.isnan(config.simulation.dt):
    model_psiq_ref.dt = time_params.compute_dt(
        model_psiq_ref.prognostic.uvh,
        model_psiq_ref.space,
        model_psiq_ref.g_prime,
        model_psiq_ref.H,
    )
else:
    model_psiq_ref.dt = config.simulation.dt
model_psiq_ref.set_wind_forcing(taux, tauy)


def _build_psiq_model(model_config: ModelConfig):
    """Instantiate a PSIQ baseline sharing the reference physics, time step and wind.

    Args:
        model_config: Layer configuration of the baseline (e.g. ``config_psiq_2l``).

    Returns:
        The configured model.
    """
    psiq_model = instantiate_model(
        model_config,
        config.physics.beta_plane,
        space_2d,
        perturbation,
        Ro=Ro,
    )
    psiq_model.slip_coef = config.physics.slip_coef
    psiq_model.bottom_drag_coef = config.physics.bottom_drag_coefficient
    psiq_model.dt = model_psiq_ref.dt
    psiq_model.set_wind_forcing(taux, tauy)
    return psiq_model


model_psiq_2l = _build_psiq_model(config_psiq_2l)
model_psiq_1l = _build_psiq_model(config_psiq_1l)


# Grid sizes taken from the configuration instead of hard-coded literals.
# Masks use (nx, ny); alpha, p and the coefficient use (nx + 1, ny + 1).
PSIQ_NX = config.space.nx
PSIQ_NY = config.space.ny
PSIQ_CENTERS = config_dict["model"]["collinearity_coef"]["centers"]
SIGMA_PSIQ_COL_FILT = config_dict["model"]["sigma"]
# Wider filter width used by the "Col Filt & Time avg" variant.
SIGMA_PSIQ_TIME_AVG = 40


def _build_psiq_collinear_model(sigma: float):
    """Build a 2-layer-geometry QGPSIQCollinearFilteredSF model and its alpha coefficient.

    alpha and p start at zero; they are expected to be replaced later by the
    fitted coefficient and the initial pressure.

    Args:
        sigma: Width of the model filter and of the coefficient kernels.

    Returns:
        ``(model, coefficient)``.
    """
    psiq_model = QGPSIQCollinearFilteredSF(
        space_2d=SpaceDiscretization2D.from_config(config.space),
        H=config_psiq_2l.h,
        g_prime=config_psiq_2l.g_prime,
        beta_plane=config.physics.beta_plane,
        optimize=True,
    )
    psiq_model.sigma = sigma
    psiq_model.masks = Masks.empty_tensor(PSIQ_NX, PSIQ_NY, device=defaults.get_device())
    psiq_model.alpha = torch.zeros((1, 1, PSIQ_NX + 1, PSIQ_NY + 1), **defaults.get())
    psiq_coef = SmoothNonUniformCoefficient(nx=PSIQ_NX + 1, ny=PSIQ_NY + 1)
    psiq_coef.sigma = sigma
    # Fresh nested lists so two coefficients never share the same objects.
    psiq_coef.centers = [list(center) for center in PSIQ_CENTERS]
    psiq_model.set_p(torch.zeros((1, 1, PSIQ_NX + 1, PSIQ_NY + 1), **defaults.get()))
    psiq_model.slip_coef = config.physics.slip_coef
    psiq_model.bottom_drag_coef = config.physics.bottom_drag_coefficient
    psiq_model.dt = model_psiq_ref.dt
    psiq_model.set_wind_forcing(taux, tauy)
    return psiq_model, psiq_coef


model_psiq, coef_psiq = _build_psiq_collinear_model(SIGMA_PSIQ_COL_FILT)
# NOTE(review): no offsets are set on model_psiq_ itself; the time-averaged
# offsets only enter the diagnostics (error_psiq_ and the plots). Disabled:
# model_psiq_.offset_psi0 = offset_p0_psiq/config.physics.f0
# model_psiq_.offset_psi1 = offset_p1_psiq/config.physics.f0
model_psiq_, coef_psiq_ = _build_psiq_collinear_model(SIGMA_PSIQ_TIME_AVG)



verbose.display("\n[Reference Model]", trigger_level=1)
verbose.display(msg=repr(model_psiq_ref), trigger_level=1)
verbose.display("\n[Model]", trigger_level=1)
verbose.display(msg=repr(model_psiq_1l), trigger_level=1)

# Layer counts of the PSIQ reference and of the 1-layer PSIQ model.
nl_ref = model_psiq_ref.space.nl
nl = model_psiq_1l.space.nl
if model_psiq_1l.get_type() == ModelName.QG_SANITY_CHECK:
    nl += 1

nx, ny = model_psiq_1l.space.nx, model_psiq_1l.space.ny

dtype = torch.float64
device = DEVICE.get()

# Initial state: loaded from the startup file when one is configured,
# otherwise UVH.steady with the reference layer count.
startup_file = config.simulation.startup_file
if startup_file is not None:
    uvh0 = UVH.from_file(startup_file, dtype=dtype, device=device)
    horizontal_shape = uvh0.h.shape[-2:]
    if horizontal_shape != (nx, ny):
        msg = (
            f"Horizontal shape {horizontal_shape} from {startup_file}"
            f" should be ({nx},{ny})."
        )
        raise ValueError(msg)
else:
    uvh0 = UVH.steady(
        n_ens=1,
        nl=nl_ref,
        nx=config.space.nx,
        ny=config.space.ny,
        dtype=torch.float64,
        device=DEVICE.get(),
    )

# NOTE(review): the initial pressure is computed with model_ref (built in an
# earlier cell), so this cell depends on that cell having run first.
model_psiq_ref.set_p(
    model_ref.P.compute_p(uvh0)[0][:, :],
)


dt = model_psiq_1l.dt
t_end = config.simulation.duration

steps = Steps(t_end=t_end, dt=dt)
print(steps)

ns, forks, saves = (
    steps.simulation_steps(),
    steps.steps_from_interval(interval=config.simulation.fork_interval),
    config.io.output.get_saving_steps(steps),
)

t = 0

prefix_ref, prefix, output_dir = (
    config.simulation.reference.prefix,
    config_psiq_1l.prefix,
    config.io.output.directory,
)

In [None]:
from qgsw.models.qg.usual.variables import Psi21L
from qgsw.models.qg.usual.filtered.variables import CollinearFilteredPsi2
from qgsw.fields.errors.point_wise import RMSE


# Each model's variable set is built from its own layer configuration
# (previously both used config.model, the 2-layer collinear config).
model_psiq_ref_vars = model_psiq_ref.get_variable_set(config.space, config.physics, config_psiq_3l)
model_psiq_2l_vars = model_psiq_2l.get_variable_set(config.space, config.physics, config_psiq_2l)


error_psiq_1l = RMSE(Psi21L(), model_psiq_ref_vars["psi2"])
error_psiq_2l = RMSE(model_psiq_2l_vars["psi2"], model_psiq_ref_vars["psi2"])
error_psiq = RMSE(CollinearFilteredPsi2(model_psiq.filter,), model_psiq_ref_vars["psi2"])
# The time-averaged variant must be diagnosed with model_psiq_'s own filter,
# the one its dynamics (and its alpha) use.
error_psiq_ = RMSE(
    CollinearFilteredPsi2(
        model_psiq_.filter,
        offset_p0_psiq/config.physics.f0,
        offset_p1_psiq/config.physics.f0,
    ),
    model_psiq_ref_vars["psi2"],
)

# Every metric is restricted to index 0 along axis 1.
for psiq_rmse in (error_psiq_1l, error_psiq_2l, error_psiq, error_psiq_):
    psiq_rmse.slices = [slice(None,None), slice(0,1),...]


In [None]:
errors_psiq = []
errors_psiq_ = []
errors_psiq_1l = []
errors_psiq_2l = []

prognostic = model_psiq_ref.prognostic
psi = model_psiq_ref.psi


# "Col Filt": fit alpha from the filtered top-layer streamfunction and the
# second-layer streamfunction of the reference.
coef_psiq.with_optimal_values(
    model_psiq.filter(psi[0,0]), psi[0,1]
)
plt.imshow(coef_psiq.get()[0,0].cpu().T)
plt.colorbar()
plt.title(r"$\alpha$ - Col Filt")
plt.show()
model_psiq.alpha = coef_psiq.get()

# "Col Filt & Time avg": same fit on anomalies w.r.t. the time average,
# with model_psiq_'s own filter (it has a different sigma than model_psiq).
coef_psiq_.with_optimal_values(
    model_psiq_.filter(psi[0,0]-offset_p0_psiq[0,0]/f0), psi[0,1]-offset_p1_psiq[0,0]/f0
)
plt.imshow(coef_psiq_.get()[0,0].cpu().T)
plt.colorbar()
plt.title(r"$\alpha$ - Col Filt & Time avg")
plt.show()
model_psiq_.alpha = coef_psiq_.get()

# Initialise every model with the reference pressure (psi * f0); each call
# builds its own tensor.
model_psiq_2l.set_p(psi[:,:2]*f0)

model_psiq_1l.set_p(psi[:,:1]*f0)

model_psiq.set_p(psi[:,:1]*f0)

model_psiq_.set_p(psi[:,:1]*f0)

# (error buffer, metric, model) triples compared with the reference at each step.
tracked_psiq_errors = (
    (errors_psiq, error_psiq, model_psiq),
    (errors_psiq_, error_psiq_, model_psiq_),
    (errors_psiq_1l, error_psiq_1l, model_psiq_1l),
    (errors_psiq_2l, error_psiq_2l, model_psiq_2l),
)

for n, fork, save in zip(ns, forks, saves):
    for error_buffer, metric, tracked_model in tracked_psiq_errors:
        error_buffer.append(
            metric.compute_ensemble_wise(
                tracked_model.prognostic,
                model_psiq_ref.prognostic,
            ).cpu().item()
        )

    if save:
        # Saving is currently disabled; to re-enable:
        # model_ref.io.save(output_dir.joinpath(f"{prefix_ref}{n}.pt"))
        # model.io.save(output_dir.joinpath(f"{prefix}{n}.pt"))
        ...

    model_psiq_ref.step()
    model_psiq_2l.step()
    model_psiq_1l.step()
    model_psiq.step()
    model_psiq_.step()

errors_psiq_2l = np.array(errors_psiq_2l)
errors_psiq_1l = np.array(errors_psiq_1l)
errors_psiq = np.array(errors_psiq)
errors_psiq_ = np.array(errors_psiq_)

In [None]:
from qgsw.plots.scatter import ScatterPlot

# Every error curve is normalised by the 1-layer baseline, so "1L vs 3L" is the unit line.
psiq_error_curves = [errors_psiq_1l, errors_psiq_2l, errors_psiq, errors_psiq_]
plot = ScatterPlot([curve / errors_psiq_1l for curve in psiq_error_curves])
plot.figure.update_layout(template="plotly")
plot.set_traces_name("1L vs 3L", "2L vs 3L", "Col Filt vs 3L", "Col Filt & Time avg vs 3L")
plot.show()

In [None]:
import matplotlib.pyplot as plt


plt.imshow(model_psiq.psi[0,0].cpu().T)
plt.colorbar()
plt.title(r"$\tilde{\psi}_1$")
plt.show()

plt.imshow((model_psiq.filter(model_psiq.psi[0,0])*model_psiq.alpha[0,0]).cpu().T)
plt.colorbar()
plt.title(r"$\tilde{\psi}_2 = \alpha K * \tilde{\psi_1}$")
plt.show()

plt.imshow(model_psiq_.psi[0,0].cpu().T)
plt.colorbar()
plt.title(r"$\tilde{\psi}_1$")
plt.show()

# alpha scales only the filtered anomaly; the time-averaged second-layer
# streamfunction is added afterwards, matching the title and the alpha fit.
psi1_bar_psiq = offset_p0_psiq[0,0]/config.physics.f0
psi2_bar_psiq = offset_p1_psiq[0,0]/config.physics.f0
plt.imshow((model_psiq_.filter(model_psiq_.psi[0,0] - psi1_bar_psiq)*model_psiq_.alpha[0,0] + psi2_bar_psiq).cpu().T)
plt.colorbar()
plt.title(r"$\tilde{\psi}_2 = \alpha K * \left(\tilde{\psi}_1 - \bar{\psi}_1\right) + \bar{\psi}_2$")
plt.show()

plt.imshow(model_psiq_2l.psi[0,0].cpu().T)
plt.colorbar()
plt.title(r"$\psi_1^{2L}$")
plt.show()

plt.imshow(model_psiq_2l.psi[0,1].cpu().T)
plt.title(r"$\psi_2^{2L}$")
plt.colorbar()
plt.show()

plt.imshow(model_psiq_ref.psi[0,0].cpu().T)
plt.colorbar()
plt.title(r"$\psi_1^{3L}$")
plt.show()

plt.imshow(model_psiq_ref.psi[0,1].cpu().T)
plt.title(r"$\psi_2^{3L}$")
plt.colorbar()
plt.show()

In [None]:
import matplotlib.pyplot as plt


# Top-layer potential vorticity (layer index 0) of each model.
plt.imshow(model_psiq.q[0,0].cpu().T)
plt.colorbar()
plt.title(r"$\tilde{q}_1$")
plt.show()

plt.imshow(model_psiq_.q[0,0].cpu().T)
plt.colorbar()
plt.title(r"$\tilde{q}_1$")
plt.show()

plt.imshow(model_psiq_2l.q[0,0].cpu().T)
plt.colorbar()
plt.title(r"$q_1^{2L}$")
plt.show()

# Layer index 0 is the first layer, so the label is q_1 (was mislabelled q_2).
plt.imshow(model_psiq_ref.q[0,0].cpu().T)
plt.title(r"$q_1^{3L}$")
plt.colorbar()
plt.show()