In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations


import numpy as np
import matplotlib.pyplot as plt
import torch

from qgsw import verbose
from qgsw.configs.core import Configuration
from qgsw.fields.variables.prognostic_tuples import UVH
from qgsw.forcing.wind import WindForcing
from qgsw.models.instantiation import instantiate_model
from qgsw.models.names import ModelName
from qgsw.output import RunOutput
from qgsw.perturbations.core import Perturbation
from qgsw.simulation.steps import Steps
from qgsw.spatial.core.discretization import (
    SpaceDiscretization2D,
)
from qgsw.models.synchronization import ModelSync
from qgsw.specs import DEVICE
from qgsw.utils import time_params

torch.backends.cudnn.deterministic = True

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
## Configuration

from qgsw.configs.models import ModelConfig


# One day expressed in seconds (all durations below are in seconds).
SECONDS_PER_DAY = 24 * 3600
# Gaussian width shared by the model filter and the coefficient kernels.
FILTER_SIGMA = 20.35
# Centres of the smooth non-uniform collinearity coefficient
# (presumably grid indices on the nx=128, ny=256 grid — NOTE(review): confirm).
COEF_CENTERS = [
    [32, 32], [32, 96], [32, 160], [32, 224],
    [96, 32], [96, 96], [96, 160], [96, 224],
]

# Experiment: 2-layer collinear-filtered QG model vs. a 3-layer QG reference.
config_dict = {
    "io": {
        "name": "Assimilation.",
        "output": {
            "save": True,
            "type": "interval",
            "interval_duration": SECONDS_PER_DAY,  # seconds
            "directory": "tmp",
        },
    },
    "physics": {
        "rho": 1000,
        "slip_coef": 1.0,
        "f0": 9.375e-5,  # mean Coriolis parameter (s^-1)
        "beta": 1.754e-11,  # Coriolis gradient (m^-1 s^-1)
        "bottom_drag_coefficient": 3.60577e-8,
    },
    "simulation": {
        "type": "assimilation",
        "duration": 1 * SECONDS_PER_DAY,  # seconds
        "dt": 3000,  # seconds
        "fork_interval": 1 * SECONDS_PER_DAY,  # seconds
        "startup_file": "../output/g5k/double_gyre_qg_long/results_step_876000.pt",
        "reference": {
            "type": "QG",
            "prefix": "reference_step_",
            "layers": [400, 1100, 2600],
            "reduced_gravity": [9.81, 0.025, 0.0125],
        },
    },
    "model": {
        "type": "QGCollinearFilteredSF",
        "prefix": "results_step_",
        "layers": [400, 1100],
        "reduced_gravity": [9.81, 0.025],
        "sigma": FILTER_SIGMA,
        "collinearity_coef": {
            "type": "smooth-non-uniform",
            # One initial value per centre.
            "initial": [0] * len(COEF_CENTERS),
            "centers": [list(center) for center in COEF_CENTERS],
            "sigma": FILTER_SIGMA,
            "use_optimal": True,
        },
    },
    "space": {
        "nx": 128,
        "ny": 256,
        "unit": "m",
        "x_min": 0,
        "x_max": 2_560_000,
        "y_min": 0,
        "y_max": 5_120_000,
    },
    "windstress": {
        "type": "cosine",
        "magnitude": 0.08,
        "drag_coefficient": 0.0013,
    },
    "perturbation": {
        "type": "none",
        "perturbation_magnitude": 1e-3,
    },
}

# 1-layer QG baseline: its single reduced gravity is g1*g2/(g1+g2) built from
# the two upper-layer reduced gravities of the 2-layer model.
model_1l_config = {
    "type": "QG",
    "prefix": "results_step_",
    "layers": [400],
    "reduced_gravity": [9.81 * 0.025 / (9.81 + 0.025)],
}

# 2-layer QG baseline with the same stratification as the filtered model.
model_2l_config = {
    "type": "QG",
    "prefix": "results_step_",
    "layers": [400, 1100],
    "reduced_gravity": [9.81, 0.025],
}

In [None]:
# Validate the raw dict and expose the most used values under short names.
config = Configuration(**config_dict)
f0 = config.physics.f0
nx, ny = config.space.nx, config.space.ny
dx, dy, ds = config.space.dx, config.space.dy, config.space.ds

In [None]:
from qgsw.fields.variables.coefficients.core import SmoothNonUniformCoefficient
from qgsw.filters.high_pass import GaussianHighPass2D
from qgsw.models.qg.uvh.modified.filtered.core import QGCollinearFilteredSF


# Build the 3-layer QG reference, the 1-/2-layer QG baselines and two
# collinear-filtered models, then load the initial state.
config_1l = ModelConfig(**model_1l_config)
config_2l = ModelConfig(**model_2l_config)

# Saved outputs of an earlier reference run, used for time-mean pressures.
REFERENCE_RUN_DIR = "../output/local/assimilation_ref"
# Centres of the smooth non-uniform collinearity coefficient.
COEF_CENTERS = [[32, 32], [32, 96], [32, 160], [32, 224], [96, 32], [96, 96], [96, 160], [96, 224]]
# Gaussian widths used by `model` (plain) and `model_` (time-averaged variant).
SIGMA_COL_FILT = 20.35
SIGMA_COL_FILT_TIME_AVG = 40


def _instantiate(model_config):
    """Instantiate ``model_config`` on the shared beta-plane, grid and perturbation."""
    return instantiate_model(
        model_config,
        config.physics.beta_plane,
        space_2d,
        perturbation,
        Ro=Ro,
    )


def _apply_common_setup(qg_model, dt, *, with_time_derivatives=True):
    """Configure ``qg_model`` in place: drag coefficients, ``dt``, wind forcing.

    Args:
        qg_model: model to configure.
        dt: time step [s].
        with_time_derivatives: call ``compute_time_derivatives`` on the
            current state before setting the wind (the collinear models skip it).
    """
    qg_model.slip_coef = config.physics.slip_coef
    qg_model.bottom_drag_coef = config.physics.bottom_drag_coefficient
    qg_model.dt = dt
    if with_time_derivatives:
        qg_model.compute_time_derivatives(qg_model.prognostic.uvh)
    qg_model.set_wind_forcing(taux, tauy)


def _time_mean_pressure(run_output, projector):
    """Time-average the ``[0, 0]`` and ``[0, 1]`` slices of ``compute_p(...)[1]``.

    Every saved output is read and projected exactly once.

    Args:
        run_output: a ``RunOutput`` whose ``outputs()`` yields saved states.
        projector: object exposing ``compute_p`` (e.g. ``model_ref.P``).

    Returns:
        ``(p0_mean, p1_mean)`` tensors.

    Raises:
        ValueError: if ``run_output`` contains no saved outputs.
    """
    p0_sum = 0
    p1_sum = 0
    count = 0
    for saved in run_output.outputs():
        pressure_field = projector.compute_p(saved.read())[1]
        p0_sum = p0_sum + pressure_field[0, 0]
        p1_sum = p1_sum + pressure_field[0, 1]
        count += 1
    if count == 0:
        msg = f"No saved outputs found in {run_output!r}; cannot time-average."
        raise ValueError(msg)
    return p0_sum / count, p1_sum / count


def _make_collinear_model(sigma, dt):
    """Create a QGCollinearFilteredSF model and a matching (unfitted) coefficient.

    Args:
        sigma: Gaussian width for both the model filter and the coefficient.
        dt: time step [s].

    Returns:
        ``(collinear_model, coefficient)``.
    """
    collinear_model = QGCollinearFilteredSF(
        space_2d=SpaceDiscretization2D.from_config(config.space),
        H=config.model.h,
        g_prime=config.model.g_prime,
        beta_plane=config.physics.beta_plane,
        optimize=True,
    )
    collinear_model.P.filter.sigma = sigma
    coefficient = SmoothNonUniformCoefficient(nx=nx, ny=ny)
    coefficient.sigma = sigma
    coefficient.centers = [list(center) for center in COEF_CENTERS]
    _apply_common_setup(collinear_model, dt, with_time_derivatives=False)
    return collinear_model, coefficient


## Wind Forcing
wind = WindForcing.from_config(config.windstress, config.space, config.physics)
taux, tauy = wind.compute()
## Rossby
Ro = 0.1
## Vortex
perturbation = Perturbation.from_config(
    perturbation_config=config.perturbation,
)
space_2d = SpaceDiscretization2D.from_config(config.space)

# 3-layer QG reference; dt is derived from the state when the config gives NaN.
model_ref = _instantiate(config.simulation.reference)
model_ref.slip_coef = config.physics.slip_coef
model_ref.bottom_drag_coef = config.physics.bottom_drag_coefficient
if np.isnan(config.simulation.dt):
    model_ref.dt = time_params.compute_dt(
        model_ref.prognostic.uvh,
        model_ref.space,
        model_ref.g_prime,
        model_ref.H,
    )
else:
    model_ref.dt = config.simulation.dt
model_ref.compute_time_derivatives(model_ref.prognostic.uvh)
model_ref.set_wind_forcing(taux, tauy)

# Time-mean pressures of the two upper layers from the saved reference run.
ref = RunOutput(REFERENCE_RUN_DIR)
p0_mean, p1_mean = _time_mean_pressure(ref, model_ref.P)
# Two leading singleton dims so the offsets index like the model fields
# (later cells use offset_p*[0, 0]).
offset_p0 = p0_mean.unsqueeze(0).unsqueeze(0)
offset_p1 = p1_mean.unsqueeze(0).unsqueeze(0)

# QG baselines sharing the reference time step.
model_1l = _instantiate(config_1l)
_apply_common_setup(model_1l, model_ref.dt)
model_2l = _instantiate(config_2l)
_apply_common_setup(model_2l, model_ref.dt)

# Collinear-filtered models: plain (`model`) and time-averaged (`model_`).
model, coef = _make_collinear_model(SIGMA_COL_FILT, model_ref.dt)
model_, coef_ = _make_collinear_model(SIGMA_COL_FILT_TIME_AVG, model_ref.dt)

verbose.display("\n[Reference Model]", trigger_level=1)
verbose.display(msg=model_ref.__repr__(), trigger_level=1)
verbose.display("\n[Model]", trigger_level=1)
verbose.display(msg=model.__repr__(), trigger_level=1)

nl_ref = model_ref.space.nl
nl = model.space.nl
if model.get_type() == ModelName.QG_SANITY_CHECK:
    nl += 1

nx = model.space.nx
ny = model.space.ny

dtype = torch.float64
device = DEVICE.get()

# Initial state: either a steady state or a restart file matching the grid.
if (startup_file := config.simulation.startup_file) is None:
    uvh0 = UVH.steady(
        n_ens=1,
        nl=nl_ref,
        nx=config.space.nx,
        ny=config.space.ny,
        dtype=dtype,
        device=device,
    )
else:
    uvh0 = UVH.from_file(startup_file, dtype=dtype, device=device)
    horizontal_shape = uvh0.h.shape[-2:]
    if horizontal_shape != (nx, ny):
        msg = (
            f"Horizontal shape {horizontal_shape} from {startup_file}"
            f" should be ({nx},{ny})."
        )
        raise ValueError(msg)

model_ref.set_uvh(
    torch.clone(uvh0.u),
    torch.clone(uvh0.v),
    torch.clone(uvh0.h),
)

dt = model.dt
t_end = config.simulation.duration

steps = Steps(t_end=t_end, dt=dt)
print(steps)

ns = steps.simulation_steps()
forks = steps.steps_from_interval(interval=config.simulation.fork_interval)
saves = config.io.output.get_saving_steps(steps)

t = 0

prefix_ref = config.simulation.reference.prefix
prefix = config.model.prefix
output_dir = config.io.output.directory


In [None]:
from qgsw.fields.errors.point_wise import RMSE
from qgsw.fields.variables.dynamics import PhysicalVorticity, StreamFunctionFromVorticity, Vorticity
from qgsw.models.qg.uvh.modified.filtered.variables import CollinearFilteredPsi2


# Diagnostic variable sets (indexed by variable name) for each model.
model_vars = model.get_variable_set(config.space, config.physics, config.model)
model_ref_vars = model_ref.get_variable_set(
    config.space, config.physics, config.simulation.reference
)
model_1l_vars = model_1l.get_variable_set(config.space, config.physics, config_1l)
model_2l_vars = model_2l.get_variable_set(config.space, config.physics, config_2l)

reference_psi2 = model_ref_vars["psi2"]

# psi2 of the time-averaged collinear model, built from its filter and the
# time-mean offsets divided by f0.
time_avg_psi2 = CollinearFilteredPsi2(
    StreamFunctionFromVorticity(
        PhysicalVorticity(Vorticity(), dx * dy), nx, ny, dx, dy
    ),
    model_.P.filter,
    offset_p0 / f0,
    offset_p1 / f0,
)

error = RMSE(model_vars["psi2"], reference_psi2)
error_ = RMSE(time_avg_psi2, reference_psi2)
error_1l = RMSE(model_1l_vars["psi2"], reference_psi2)
error_2l = RMSE(model_2l_vars["psi2"], reference_psi2)

# Restrict every metric to index 0:1 along dim 1 (one fresh list per metric).
for rmse in (error, error_, error_1l, error_2l):
    rmse.slices = [slice(None, None), slice(0, 1), ...]

In [None]:
from qgsw.models.synchronization import ModelSync


# Per-step psi2 RMSE against the 3-layer QG reference, one history per model.
errors = []
errors_ = []
errors_1l = []
errors_2l = []

prognostic = model_ref.prognostic
pressure = model_ref.P.compute_p(prognostic.uvh)[1]

# Fit alpha for the plain collinear-filtered model (no time averaging).
coef.with_optimal_values(model.P.filter(pressure[0, 0]), pressure[0, 1])
plt.imshow(coef.get()[0, 0].cpu().T)
plt.colorbar()
plt.title(r"$\alpha$ - Col Filt")
plt.show()
model.alpha = coef.get()

# Fit alpha on anomalies with respect to the time-mean pressures.
coef_.with_optimal_values(
    model_.P.filter(pressure[0, 0] - offset_p0[0, 0]),
    pressure[0, 1] - offset_p1[0, 0],
)
plt.imshow(coef_.get()[0, 0].cpu().T)
plt.colorbar()
plt.title(r"$\alpha$ - Col Filt & Time avg")
plt.show()
# Use the coefficient fitted for model_ (previously model_ got coef's alpha).
model_.alpha = coef_.get()

ModelSync(model_ref, model)()
ModelSync(model_ref, model_)()
ModelSync(model_ref, model_1l)()
ModelSync(model_ref, model_2l)()

# Each model gets its own freshly computed pressure tensor (no shared views).
model.set_p(
    model_ref.P.compute_p(prognostic)[0][:, :1],
)
model_.set_p(
    model_ref.P.compute_p(prognostic)[0][:, :1],
)
model_1l.set_p(
    model_ref.P.compute_p(prognostic)[0][:, :1],
)
model_2l.set_p(
    model_ref.P.compute_p(prognostic)[0][:, :2],
)

# (metric, compared model, history), evaluated in this order at every step.
tracked = (
    (error, model, errors),
    (error_, model_, errors_),
    (error_1l, model_1l, errors_1l),
    (error_2l, model_2l, errors_2l),
)

for n, fork, save in zip(ns, forks, saves):
    for metric, compared, history in tracked:
        history.append(
            metric.compute_ensemble_wise(
                compared.prognostic,
                model_ref.prognostic,
            ).cpu().item()
        )

    # Saving is disabled in this exploration (the save calls were commented
    # out); `save`, `output_dir`, `prefix` and `prefix_ref` are kept for it.

    model_ref.step()
    model.step()
    model_.step()
    model_1l.step()
    model_2l.step()

errors = np.array(errors)
errors_ = np.array(errors_)
errors_1l = np.array(errors_1l)
errors_2l = np.array(errors_2l)

In [None]:
from qgsw.plots.scatter import ScatterPlot

# Errors relative to the 1-layer QG error (so the 1L trace is identically 1).
normalised = [
    errors_1l / errors_1l,
    errors_2l / errors_1l,
    errors / errors_1l,
    errors_ / errors_1l,
]
# Elapsed time of each sample, in days.
time_days = [k * config.simulation.dt / 3600 / 24 for k in range(len(errors_1l))]

plot = ScatterPlot(normalised)
plot.figure.update_layout(template="plotly")
plot.set_xs(*(list(time_days) for _ in normalised))
plot.set_xaxis_title("Times [day]")
plot.set_traces_name(
    "1L vs 3L",
    "2L vs 3L",
    "Col Filt vs 3L",
    "Col Filt & Time avg vs 3L",
)
plot.show()

In [None]:
import matplotlib.pyplot as plt

sf = StreamFunctionFromVorticity(PhysicalVorticity(Vorticity(), ds), nx, ny, dx, dy)

psi1_col = sf.compute(model.prognostic)[0, 0]
psi1_col_tavg = sf.compute(model_.prognostic)[0, 0]
psi_2l = sf.compute(model_2l.prognostic)
psi_3l = sf.compute(model_ref.prognostic)

# Reconstructed psi2 of each collinear model. The time-averaged variant mirrors
# its alpha fit on anomalies: alpha * K*(psi1 - mean psi1) + mean psi2, so the
# mean psi2 is added AFTER scaling by alpha (as the title states).
psi2_col = model.P.filter(psi1_col) * model.alpha[0, 0]
psi2_col_tavg = (
    model_.P.filter(psi1_col_tavg - offset_p0[0, 0] / f0) * model_.alpha[0, 0]
    + offset_p1[0, 0] / f0
)

panels = (
    (psi1_col, r"$\tilde{\psi}_1$ - Col Filt"),
    (psi2_col, r"$\tilde{\psi}_2 = \alpha K * \tilde{\psi_1}$"),
    (psi1_col_tavg, r"$\tilde{\psi}_1$ - Col Filt & Time avg"),
    (
        psi2_col_tavg,
        r"$\tilde{\psi}_2 = \alpha K * \left(\tilde{\psi}_1 - \bar{\psi}_1\right) + \bar{\psi}_2$",
    ),
    (psi_2l[0, 0], r"$\psi_1^{2L}$"),
    (psi_2l[0, 1], r"$\psi_2^{2L}$"),
    (psi_3l[0, 0], r"$\psi_1^{3L}$"),
    (psi_3l[0, 1], r"$\psi_2^{3L}$"),
)
for field, title in panels:
    plt.imshow(field.cpu().T)
    plt.colorbar()
    plt.title(title)
    plt.show()

In [None]:
## Configuration

from qgsw.configs.models import ModelConfig


# One day expressed in seconds (all durations below are in seconds).
SECONDS_PER_DAY = 24 * 3600


def _layered_spec(model_type, prefix, layers, reduced_gravity):
    """Return a model-config dict with fresh ``layers``/``reduced_gravity`` lists."""
    return {
        "type": model_type,
        "prefix": prefix,
        "layers": list(layers),
        "reduced_gravity": list(reduced_gravity),
    }


# Experiment: models compared against a 3-layer shallow-water run.
config_dict = {
    "io": {
        "name": "Assimilation.",
        "output": {
            "save": True,
            "type": "interval",
            "interval_duration": SECONDS_PER_DAY,  # seconds
            "directory": "tmp",
        },
    },
    "physics": {
        "rho": 1000,
        "slip_coef": 1.0,
        "f0": 9.375e-5,  # mean Coriolis parameter (s^-1)
        "beta": 1.754e-11,  # Coriolis gradient (m^-1 s^-1)
        "bottom_drag_coefficient": 3.60577e-8,
    },
    "simulation": {
        "type": "assimilation",
        "duration": 20 * SECONDS_PER_DAY,  # seconds
        "dt": 3000,  # seconds
        "fork_interval": 20 * SECONDS_PER_DAY,  # seconds
        "startup_file": "../output/g5k/double_gyre_qg_long/results_step_876000.pt",
        "reference": _layered_spec(
            "QG", "reference_step_", [400, 1100, 2600], [9.81, 0.025, 0.0125]
        ),
    },
    "model": {
        "type": "QGCollinearFilteredSF",
        "prefix": "results_step_",
        "layers": [400, 1100],
        "reduced_gravity": [9.81, 0.025],
        "sigma": 20.35,
        "collinearity_coef": {
            "type": "smooth-non-uniform",
            "initial": [0, 0, 0, 0, 0, 0, 0, 0],
            "centers": [
                [32, 32], [32, 96], [32, 160], [32, 224],
                [96, 32], [96, 96], [96, 160], [96, 224],
            ],
            "sigma": 20.35,
            "use_optimal": True,
        },
    },
    "space": {
        "nx": 128,
        "ny": 256,
        "unit": "m",
        "x_min": 0,
        "x_max": 2_560_000,
        "y_min": 0,
        "y_max": 5_120_000,
    },
    "windstress": {
        "type": "cosine",
        "magnitude": 0.08,
        "drag_coefficient": 0.0013,
    },
    "perturbation": {
        "type": "none",
        "perturbation_magnitude": 1e-3,
    },
}

# 1-layer QG baseline: reduced gravity g1*g2/(g1+g2) from the 2-layer values.
model_1l_config = _layered_spec(
    "QG", "results_step_", [400], [9.81 * 0.025 / (9.81 + 0.025)]
)
# 2-layer QG baseline.
model_2l_config = _layered_spec("QG", "results_step_", [400, 1100], [9.81, 0.025])
# 3-layer shallow-water reference and its 3-layer QG counterpart.
model_sw_config = _layered_spec(
    "SWFilterBarotropicSpectral", "reference_step_", [400, 1100, 2600], [9.81, 0.025, 0.0125]
)
model_qg_config = _layered_spec(
    "QG", "reference_step_", [400, 1100, 2600], [9.81, 0.025, 0.0125]
)

In [None]:
# Validate the experiment config and each per-model config.
config = Configuration(**config_dict)
config_1l, config_2l, config_sw, config_qg = (
    ModelConfig(**spec)
    for spec in (model_1l_config, model_2l_config, model_sw_config, model_qg_config)
)

# Short names for frequently used values.
f0 = config.physics.f0
nx, ny = config.space.nx, config.space.ny
dx, dy, ds = config.space.dx, config.space.dy, config.space.ds

In [None]:
from qgsw.fields.variables.coefficients.core import SmoothNonUniformCoefficient
from qgsw.filters.high_pass import GaussianHighPass2D
from qgsw.masks import Masks
from qgsw.models.qg.psiq.filtered.core import QGPSIQCollinearFilteredSF
from qgsw.specs import defaults

## Wind Forcing
wind = WindForcing.from_config(config.windstress, config.space, config.physics)
taux, tauy = wind.compute()
## Rossby
Ro = 0.1
## Vortex
perturbation = Perturbation.from_config(
    perturbation_config=config.perturbation,
)
space_2d = SpaceDiscretization2D.from_config(config.space)

# Centres and Gaussian width of the collinearity coefficient / model filter.
COEF_CENTERS = [[32, 32], [32, 96], [32, 160], [32, 224], [96, 32], [96, 96], [96, 160], [96, 224]]
FILTER_SIGMA = 20.35


def _instantiate_sw_run_model(model_config):
    """Instantiate ``model_config`` on the shared beta-plane, grid and perturbation."""
    return instantiate_model(
        model_config,
        config.physics.beta_plane,
        space_2d,
        perturbation,
        Ro=Ro,
    )


def _setup_sw_run_model(target, dt, *, with_time_derivatives=True):
    """Configure ``target`` in place: drag coefficients, ``dt``, wind forcing.

    Args:
        target: model to configure.
        dt: time step [s].
        with_time_derivatives: call ``compute_time_derivatives`` on the current
            state before setting the wind (skipped for SW and collinear models,
            as in the original setup).
    """
    target.slip_coef = config.physics.slip_coef
    target.bottom_drag_coef = config.physics.bottom_drag_coefficient
    target.dt = dt
    if with_time_derivatives:
        target.compute_time_derivatives(target.prognostic.uvh)
    target.set_wind_forcing(taux, tauy)


# 3-layer shallow-water reference.
model_sw = _instantiate_sw_run_model(config_sw)
_setup_sw_run_model(model_sw, config.simulation.dt, with_time_derivatives=False)

# QG models, all sharing the SW time step.
model_qg = _instantiate_sw_run_model(config_qg)
_setup_sw_run_model(model_qg, model_sw.dt)
model_1l = _instantiate_sw_run_model(config_1l)
_setup_sw_run_model(model_1l, model_sw.dt)
model_2l = _instantiate_sw_run_model(config_2l)
_setup_sw_run_model(model_2l, model_sw.dt)

# 2-layer collinear-filtered model and its coefficient (fitted in a later cell).
model = QGCollinearFilteredSF(
    space_2d=SpaceDiscretization2D.from_config(config.space),
    H=config.model.h,
    g_prime=config.model.g_prime,
    beta_plane=config.physics.beta_plane,
    optimize=True,
)
model.P.filter.sigma = FILTER_SIGMA
coef = SmoothNonUniformCoefficient(nx=nx, ny=ny)
coef.sigma = FILTER_SIGMA
coef.centers = [list(center) for center in COEF_CENTERS]
_setup_sw_run_model(model, model_sw.dt, with_time_derivatives=False)

verbose.display("\n[Reference Model]", trigger_level=1)
verbose.display(msg=model_sw.__repr__(), trigger_level=1)
verbose.display("\n[Model]", trigger_level=1)
verbose.display(msg=model.__repr__(), trigger_level=1)

nl_ref = model_sw.space.nl
nl = model.space.nl
if model.get_type() == ModelName.QG_SANITY_CHECK:
    nl += 1

nx = model.space.nx
ny = model.space.ny

# Restart state; it must match the model grid (same guard as the QG-reference setup).
startup_file = config.simulation.startup_file
uvh0 = UVH.from_file(startup_file, **defaults.get())
horizontal_shape = uvh0.h.shape[-2:]
if horizontal_shape != (nx, ny):
    msg = (
        f"Horizontal shape {horizontal_shape} from {startup_file}"
        f" should be ({nx},{ny})."
    )
    raise ValueError(msg)

model_sw.set_uvh(
    torch.clone(uvh0.u),
    torch.clone(uvh0.v),
    torch.clone(uvh0.h),
)

dt = model.dt
t_end = config.simulation.duration

steps = Steps(t_end=t_end, dt=dt)
print(steps)

ns = steps.simulation_steps()
forks = steps.steps_from_interval(interval=config.simulation.fork_interval)
saves = config.io.output.get_saving_steps(steps)

t = 0

In [None]:
from torch._tensor import Tensor
from qgsw.fields.variables.dynamics import PhysicalVorticity, Psi2, StreamFunctionFromVorticity
from qgsw.fields.variables.prognostic_tuples import BasePrognosticUVH
from qgsw.models.qg.uvh.projectors.core import QGProjector


class SWPsi(StreamFunctionFromVorticity):
    """Stream function of a state that is first passed through a QG projector.

    ``_compute`` projects the incoming prognostic variables with ``P.project``
    and hands the projected state to the parent implementation.
    """

    def __init__(
        self,
        vorticity: PhysicalVorticity,
        nx: int,
        ny: int,
        dx: float,
        dy: float,
        P: QGProjector,
    ) -> None:
        super().__init__(vorticity, nx, ny, dx, dy)
        # Projector applied to every state before deriving the stream function.
        self._P = P

    def _compute(self, prognostic: BasePrognosticUVH) -> Tensor:
        projected = self._P.project(prognostic)
        return super()._compute(projected)


class SWPsi2(Psi2):
    """``Psi2`` built on top of an :class:`SWPsi` (narrows the accepted type)."""

    def __init__(self, psi: SWPsi) -> None:
        super().__init__(psi)
 

In [None]:
from qgsw.fields.errors.point_wise import RMSE
from qgsw.fields.variables.dynamics import PhysicalZonalVelocity, PhysicalZonalVelocity2, PhysicalZonalVelocityFromPsi2, PhysicalMeridionalVelocity, PhysicalMeridionalVelocity2, PhysicalMeridionalVelocityFromPsi2
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.spatial.core.coordinates import Coordinates1D
from qgsw.utils.units._units import Unit
   
# 3-layer QG projector matching the shallow-water reference configuration.
P = QGProjector(
    A=compute_A(config_sw.h, config_sw.g_prime, **defaults.get()),
    H=config_sw.h.unsqueeze(-1).unsqueeze(-1),
    space=space_2d.add_h(Coordinates1D(points=config_sw.h, unit=Unit.M)),
    f0=f0,
    masks=Masks.empty(nx, ny, defaults.get_device()),
)

model_vars = model.get_variable_set(config.space, config.physics, config.model)
model_1l_vars = model_1l.get_variable_set(config.space, config.physics, config_1l)
model_2l_vars = model_2l.get_variable_set(config.space, config.physics, config_2l)

# psi2 of the SW run after projection with P (psi2-based errors against it
# are currently disabled).
sw_psi2 = SWPsi2(SWPsi(PhysicalVorticity(Vorticity(), ds), nx, ny, dx, dy, P))


def _first_slice_rmse(estimate, reference):
    """RMSE between two variables, restricted to index 0:1 along dim 1."""
    rmse = RMSE(estimate, reference)
    rmse.slices = [slice(None, None), slice(0, 1), ...]
    return rmse


# Second-layer zonal velocity, per model.
sw_u2 = PhysicalZonalVelocity2(PhysicalZonalVelocity(dx))
qg_u2 = PhysicalZonalVelocity2(PhysicalZonalVelocity(dx))
model_u2 = PhysicalZonalVelocityFromPsi2(model_vars["psi2"], dy)
model_1l_u2 = PhysicalZonalVelocityFromPsi2(model_1l_vars["psi2"], dy)
model_2l_u2 = PhysicalZonalVelocity2(PhysicalZonalVelocity(dx))

error_u = _first_slice_rmse(model_u2, sw_u2)
error_qg_u = _first_slice_rmse(qg_u2, sw_u2)
error_1l_u = _first_slice_rmse(model_1l_u2, sw_u2)
error_2l_u = _first_slice_rmse(model_2l_u2, sw_u2)

# Second-layer meridional velocity, per model.
sw_v2 = PhysicalMeridionalVelocity2(PhysicalMeridionalVelocity(dy))
qg_v2 = PhysicalMeridionalVelocity2(PhysicalMeridionalVelocity(dy))
model_v2 = PhysicalMeridionalVelocityFromPsi2(model_vars["psi2"], dx)
model_1l_v2 = PhysicalMeridionalVelocityFromPsi2(model_1l_vars["psi2"], dx)
model_2l_v2 = PhysicalMeridionalVelocity2(PhysicalMeridionalVelocity(dy))

error_v = _first_slice_rmse(model_v2, sw_v2)
error_qg_v = _first_slice_rmse(qg_v2, sw_v2)
error_1l_v = _first_slice_rmse(model_1l_v2, sw_v2)
error_2l_v = _first_slice_rmse(model_2l_v2, sw_v2)


In [None]:
from qgsw.models.synchronization import ModelSync


# Per-step second-layer velocity RMSE against the 3-layer SW run.
errors_u = []
errors_qg_u = []
errors_1l_u = []
errors_2l_u = []

errors_v = []
errors_qg_v = []
errors_1l_v = []
errors_2l_v = []

prognostic = model_sw.prognostic
pressure = P.compute_p(prognostic.uvh)[1]

# Fit alpha directly on the instantaneous pressures (no time averaging here).
coef.with_optimal_values(model.P.filter(pressure[0, 0]), pressure[0, 1])
plt.imshow(coef.get()[0, 0].cpu().T)
plt.colorbar()
plt.title(r"$\alpha$ - Col Filt")
plt.show()
model.alpha = coef.get()

ModelSync(model_sw, model)()
ModelSync(model_sw, model_1l)()
ModelSync(model_sw, model_2l)()
ModelSync(model_sw, model_qg)()

# (metric, compared model, history), evaluated in this order at every step.
tracked = (
    (error_u, model, errors_u),
    (error_qg_u, model_qg, errors_qg_u),
    (error_1l_u, model_1l, errors_1l_u),
    (error_2l_u, model_2l, errors_2l_u),
    (error_v, model, errors_v),
    (error_qg_v, model_qg, errors_qg_v),
    (error_1l_v, model_1l, errors_1l_v),
    (error_2l_v, model_2l, errors_2l_v),
)

for n, fork, save in zip(ns, forks, saves):
    for metric, compared, history in tracked:
        history.append(
            metric.compute_ensemble_wise(
                compared.prognostic,
                model_sw.prognostic,
            ).cpu().item()
        )

    # Saving is disabled in this exploration (the save calls were commented out).

    model_sw.step()
    model_qg.step()
    model.step()
    model_1l.step()
    model_2l.step()

errors_u = np.array(errors_u)
errors_qg_u = np.array(errors_qg_u)
errors_1l_u = np.array(errors_1l_u)
errors_2l_u = np.array(errors_2l_u)

errors_v = np.array(errors_v)
errors_qg_v = np.array(errors_qg_v)
errors_1l_v = np.array(errors_1l_v)
errors_2l_v = np.array(errors_2l_v)

In [None]:
from qgsw.plots.scatter import ScatterPlot

# One figure per velocity component; errors are relative to the 1L-QG error.
trace_names = (
    "1L-QG vs 3L-SW",
    "3L-QG vs 3L-SW",
    "2L-QG vs 3L-SW",
    "Col Filt vs 3L-SW",
)
for component, (err_1l, err_qg, err_2l, err_model) in (
    ("U", (errors_1l_u, errors_qg_u, errors_2l_u, errors_u)),
    ("V", (errors_1l_v, errors_qg_v, errors_2l_v, errors_v)),
):
    plot = ScatterPlot(
        [err_1l / err_1l, err_qg / err_1l, err_2l / err_1l, err_model / err_1l]
    )
    plot.figure.update_layout(template="plotly")
    days = [k * config.simulation.dt / 3600 / 24 for k in range(len(err_1l))]
    plot.set_xs(*(list(days) for _ in range(4)))
    plot.set_xaxis_title("Times [day]")
    plot.set_traces_name(*trace_names)
    plot.figure.update_layout(title_text=component)
    plot.show()

In [None]:
import matplotlib.pyplot as plt

sf = StreamFunctionFromVorticity(PhysicalVorticity(Vorticity(), ds), nx, ny, dx, dy)

# (field, title) pairs, displayed transposed so that y is vertical.
panels = [
    (sf.compute(model.prognostic)[0, 0], r"$\tilde{\psi}_1$"),
    (
        model.P.filter(sf.compute(model.prognostic)[0, 0]) * model.alpha[0, 0],
        r"$\tilde{\psi}_2 = \alpha K * \tilde{\psi_1}$",
    ),
    (sf.compute(model_2l.prognostic)[0, 0], r"$\psi_1^{2L-QG}$"),
    (sf.compute(model_2l.prognostic)[0, 1], r"$\psi_2^{2L-QG}$"),
    (sf.compute(model_qg.prognostic)[0, 0], r"$\psi_1^{3L-QG}$"),
    (sf.compute(model_qg.prognostic)[0, 1], r"$\psi_2^{3L-QG}$"),
    (sf.compute(model_sw.prognostic)[0, 0], r"$\psi_1^{3L-SW}$"),
    (sf.compute(model_sw.prognostic)[0, 1], r"$\psi_2^{3L-SW}$"),
]
for field, title in panels:
    plt.imshow(field.cpu().T)
    plt.colorbar()
    plt.title(title)
    plt.show()

In [None]:
import matplotlib.pyplot as plt

from qgsw.spatial.core.grid_conversion import interpolate

sf = StreamFunctionFromVorticity(PhysicalVorticity(Vorticity(), ds), nx, ny, dx, dy)

# Second-layer stream function of the collinear model: alpha * K * psi1.
psi2_col = model.P.filter(sf.compute(model.prognostic)[0, 0]) * model.alpha[0, 0]
# u = -d(psi)/dy: differenced along the last axis and divided by dy.
u2_col = -torch.diff(psi2_col, dim=-1).cpu().T / dy.item()

# Second-layer zonal velocity fields (prognostic u divided by dx).
for field, title in (
    (u2_col, r"$\tilde{u}_2 = - \partial_y \left(\alpha K * \tilde{\psi_1}\right)$"),
    (model_2l.prognostic.u[0, 1].cpu().T / dx.item(), r"$u_2^{2L-QG}$"),
    (model_qg.prognostic.uvh.u[0, 1].cpu().T / dx.item(), r"$u_2^{3L-QG}$"),
    (model_sw.prognostic.uvh.u[0, 1].cpu().T / dx.item(), r"$u_2^{3L-SW}$"),
):
    plt.imshow(field)
    plt.colorbar()
    plt.title(title)
    plt.show()

# SW reference zonal velocity (previously misnamed `v_ref`).
u_ref = interpolate(model_sw.prognostic.uvh.u)[0, 1].cpu().T / dx.item()

# RMSE vs. the SW run, normalised by the last recorded 1L-QG u-error,
# followed by a map of the absolute error.
for u, title in (
    (u2_col, r"$\| \tilde{u}_2 -u_2^{3L-SW}\|$"),
    (
        interpolate(model_2l.prognostic.u)[0, 1].cpu().T / dx.item(),
        r"$\|u_2^{2L-QG}-u_2^{3L-SW}\|$",
    ),
    (
        interpolate(model_qg.prognostic.u)[0, 1].cpu().T / dx.item(),
        r"$\|u_2^{3L-QG}-u_2^{3L-SW}\|$",
    ),
):
    print((u - u_ref).square().mean().sqrt() / errors_1l_u[-1])
    plt.imshow((u - u_ref).abs())
    plt.colorbar()
    plt.title(title)
    plt.show()

In [None]:
import matplotlib.pyplot as plt

from qgsw.spatial.core.grid_conversion import interpolate

sf = StreamFunctionFromVorticity(PhysicalVorticity(Vorticity(), ds), nx, ny, dx, dy)

# Second-layer stream function of the collinear model: alpha * K * psi1.
psi2_col = model.P.filter(sf.compute(model.prognostic)[0, 0]) * model.alpha[0, 0]
# v = +d(psi)/dx: differenced along the second-to-last axis and divided by dx.
v2_col = torch.diff(psi2_col, dim=-2).cpu().T / dx.item()

# Second-layer meridional velocity fields (prognostic v divided by dy).
for field, title in (
    (v2_col, r"$\tilde{v}_2 = \partial_x \left(\alpha K * \tilde{\psi_1}\right)$"),
    (model_2l.prognostic.v[0, 1].cpu().T / dy.item(), r"$v_2^{2L-QG}$"),
    (model_qg.prognostic.uvh.v[0, 1].cpu().T / dy.item(), r"$v_2^{3L-QG}$"),
    (model_sw.prognostic.uvh.v[0, 1].cpu().T / dy.item(), r"$v_2^{3L-SW}$"),
):
    plt.imshow(field)
    plt.colorbar()
    plt.title(title)
    plt.show()

v_ref = interpolate(model_sw.prognostic.uvh.v[0, 1]).cpu().T / dy.item()

# Every compared field is interpolated BEFORE transposing/scaling, exactly like
# `v_ref` (the 2L/3L-QG fields used to be interpolated after `.T`).
for v, title in (
    (v2_col, r"$\| \tilde{v}_2 -v_2^{3L-SW}\|$"),
    (
        interpolate(model_2l.prognostic.v[0, 1]).cpu().T / dy.item(),
        r"$\|v_2^{2L-QG}-v_2^{3L-SW}\|$",
    ),
    (
        interpolate(model_qg.prognostic.v[0, 1]).cpu().T / dy.item(),
        r"$\|v_2^{3L-QG}-v_2^{3L-SW}\|$",
    ),
):
    print((v - v_ref).square().mean().sqrt() / errors_1l_v[-1])
    plt.imshow((v - v_ref).abs())
    plt.colorbar()
    plt.title(title)
    plt.show()

In [None]:
from qgsw.models.synchronization import ModelSync

# Keep a ModelSync(model_sw, model_1l) handle so it can be inspected in the
# next cell and applied later with `sync()` (presumably copies the SW state
# into the 1-layer QG model — NOTE(review): confirm direction).
sync = ModelSync(model_sw,model_1l)

In [None]:
# Debug: display the synchronizer's repr as the cell output.
sync

In [None]:
# Top-layer u of the SW model, of the 1-layer QG model, and their |difference|.
for field in (
    model_sw.u[0, 0],
    model_1l.u[0, 0],
    (model_1l.u[0, 0] - model_sw.u[0, 0]).abs(),
):
    plt.imshow(field.cpu().T)
    plt.colorbar()
    plt.show()

In [None]:
# Apply the synchronizer, then re-draw the same three top-layer u panels.
sync()
for field in (
    model_sw.u[0, 0],
    model_1l.u[0, 0],
    (model_1l.u[0, 0] - model_sw.u[0, 0]).abs(),
):
    plt.imshow(field.cpu().T)
    plt.colorbar()
    plt.show()