In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from torch.nn import functional as F

from qgsw.perturbations.vortex import BaroclinicVortex, BarotropicVortex
from qgsw.run_summary import RunSummary
from qgsw.spatial.core.grid import Grid3D
from qgsw.spatial.units._units import METERS
from qgsw.utils.gaussian_filtering import GaussianFilter2D
from qgsw.utils.sorting import sort_files

In [None]:
# Coriolis parameter: pressures are divided by f0 to obtain stream functions below.
f0 = 9.375e-5

# Layer thicknesses (m, as passed to Grid3D with zh_unit=METERS below).
H1 = 200
H2 = 800
# Presumably reduced gravities at the surface and at the interface — confirm.
g1 = 10
g2 = 0.05

# Two-layer coupling (stretching) matrix.
A = torch.tensor(
    [
        [1/H1/g1+1/H1/g2, -1/H1/g2],
        [-1/H2/g2, 1/H2/g2]
    ],
    dtype=torch.float64
)

# A has real eigenvalues, so the imaginary parts returned by eig are zero.
eigvals, eigvects = (t.real for t in torch.linalg.eig(A))

# torch.linalg.eig does not order its eigenpairs. Sort them by decreasing
# eigenvalue so mode index 0 is always the large eigenvalue; the small one is
# ≈ 1/(g1 (H1 + H2)), i.e. the barotropic mode, and ends up at index 1.
_mode_order = torch.argsort(eigvals, descending=True)
eigvals, eigvects = eigvals[_mode_order], eigvects[:, _mode_order]

# Columns of Pm2l are the eigenvectors: modes -> layers; Pl2m maps layers -> modes.
Pm2l = eigvects
Pl2m = torch.linalg.inv(eigvects)
D = torch.diag(eigvals)

def plot_3d(
        data: np.ndarray,
        x: np.ndarray,
        y: np.ndarray,
        dtick: int,
        layer_offset: int,
        show_axis: bool,
        show_background: bool,
        show_legend: bool,
        show_scale: bool,
        colorscale: list,
        color_field: str,
        zrange: list[float] | None = None) -> go.Figure:
    """Draw one or several 2-D fields as vertically stacked 3-D surfaces.

    Args:
        data: A single 2-D field, or a stack of 2-D fields along axis 0 (one
            surface per entry). Each field is transposed before plotting
            (``z = layer.T``), so it is expected indexed ``[x, y]``, with
            ``len(x)`` rows and ``len(y)`` columns.
        x: 1-D coordinates along axis 0 of each field.
        y: 1-D coordinates along axis 1 of each field.
        dtick: Tick spacing of the x and y axes (only used when ``show_axis``).
        layer_offset: Surface ``i`` is lowered by ``i * layer_offset``; its
            colour range is shifted by the same amount, so every surface uses
            the same colour mapping relative to its own values.
        show_axis: If True, show titled x/y axes (the z axis stays hidden);
            otherwise hide all three axes.
        show_background: If False, make paper and plot backgrounds transparent.
        show_legend: If False, hide the legend.
        show_scale: Show the colour bar (attached to the first surface only).
        colorscale: Plotly colorscale, e.g. ``px.colors.sequential.RdBu_r``.
        color_field: Title of the colour bar.
        zrange: Optional ``[zmin, zmax]`` for the scene's z axis.

    Returns:
        The configured figure; it is neither shown nor written.
    """
    # Promote a single 2-D field to a one-layer stack (a reshape with two
    # inferred dimensions such as (1, -2, -1) is rejected by NumPy).
    if data.ndim == 2:
        data = data[np.newaxis, ...]

    # Symmetric colour range shared by every layer.
    cmax = np.max(np.abs(data))

    fig = go.Figure()
    fig.update_layout(
        autosize=True,
        margin=dict(l=20, r=20, t=20, b=20),
        width= 1200,
        height = 1000,
        font={"size": 20, "color":"black"},
        xaxis={"scaleanchor": "y", "constrain": "domain"},
        yaxis={"scaleanchor": "x", "constrain": "domain"},
    )

    if not show_background :
        fig.update_layout(
            {
                "paper_bgcolor": "rgba(0, 0, 0, 0)",
                "plot_bgcolor": "rgba(0, 0, 0, 0)",
            },
        )

    if not show_axis:
        fig.update_scenes(xaxis_visible=False, yaxis_visible=False,zaxis_visible=False )
    else:
        fig.update_layout(scene = dict(
            xaxis = dict(
                title = "X (km)",
                gridcolor="white",
                showbackground=True,
                zerolinecolor="white",
                tickangle=0,
                title_font = {"size":30,},
                dtick = dtick,
                tick0 = 0,
                ),
            yaxis = dict(
                title = "Y (km)",
                gridcolor="white",
                showbackground=True,
                zerolinecolor="white",
                tickangle=0,
                title_font = {"size":30,},
                dtick = dtick,
                tick0 = 0,
                ),
            zaxis_visible=False,
        ),
            margin=dict(
            r=10, l=10,
            b=10, t=10)
        )
    if zrange is not None:
        fig.update_layout(
            scene = dict( zaxis = dict(range=zrange))
        )

    if not show_legend:
        fig.update_layout(showlegend=False)

    colorbar = go.surface.ColorBar(
        exponentformat="e",
        showexponent="all",
        title= go.surface.colorbar.Title(
            text = color_field,
            side = "right",
            font = go.surface.colorbar.title.Font(
                size = 40,
            ),
        ),
        thickness = 100,
        tickfont=go.surface.colorbar.Tickfont(
            size=40
        )
    )

    for i, layer in enumerate(data):
        # go.Surface expects z shaped (len(y), len(x)), hence the transpose.
        fig.add_trace(
            go.Surface(
                x=x,
                y=y,
                z=layer.T - i*layer_offset,
                colorscale=colorscale,
                cmin=-cmax- i*layer_offset,
                cmax=cmax- i*layer_offset,
                colorbar = colorbar,
                showscale = i==0 and show_scale
            ),
        )

    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=-1.5, y=-1.5, z=1.25)
    )

    fig.update_layout(scene_camera=camera)

    return fig


# Observations

In [None]:
size = 500

torch.random.manual_seed(0)

# Named `gaussian_filter` rather than `filter` so the builtin is not shadowed
# for the rest of the kernel session.
gaussian_filter = GaussianFilter2D(0.5, 25)

# Two smoothed random fields: the upper ("surface") and lower ("bottom") layers.
surface = gaussian_filter.smooth(torch.rand((size, size)) - 0.5)
bottom = gaussian_filter.smooth(torch.rand((size, size)) - 0.5)

# Symmetric colour range shared by both layers.
cmax = max(
    torch.max(torch.abs(surface)).cpu().item(),
    torch.max(torch.abs(bottom)).cpu().item(),
)

fig = go.Figure()
fig.update_layout(
    width = 2000,
    height = 2000,
)

fig.update_layout(
    {
        "paper_bgcolor": "rgba(0, 0, 0, 0)",
        "plot_bgcolor": "rgba(0, 0, 0, 0)",
    },
)
fig.update_layout(showlegend=False)

fig.add_trace(
    go.Surface(
        z=surface,
        colorscale=px.colors.sequential.Blues,
        cmin=-cmax,
        cmax=cmax,
        showscale = False,
    ),
)

# Vertical gap between the two layers. The bottom layer's colour range is
# shifted by an extra 0.1 so its shading differs slightly from the top layer.
offset = 1
fig.add_trace(
    go.Surface(
        z=bottom - offset,
        colorscale=px.colors.sequential.Blues,
        cmin=-cmax - offset - 0.1,
        cmax=cmax - offset - 0.1,
        showscale=False,
    ),
)

fig.write_image("../output/presentation/two_layers_grid.png")

fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)

fig.write_image("../output/presentation/two_layers.png")

# 25 random in-situ observation points on the top layer. A Surface built from
# z alone indexes z[row, column] as z[y, x].
x_r = torch.randint(0, size, (25,))
y_r = torch.randint(0, size, (25,))
z_r = surface[y_r, x_r]

fig.add_trace(
    go.Scatter3d(
        name = "Observations Campaigns data",
        x = x_r,
        y = y_r,
        z = z_r,
        mode="markers",
        marker = {
            "size": 15,
            "color": "red",
        }
    )
)

# Every 5th point is also observed at depth. Its height must come from the
# bottom field (shifted like the bottom surface) so the marker sits on it.
x_deep, y_deep = x_r[::5], y_r[::5]
fig.add_trace(
    go.Scatter3d(
        name = "Observations Campaigns data",
        x = x_deep,
        y = y_deep,
        z = bottom[y_deep, x_deep] - offset,
        mode="markers",
        marker = {
            "size": 8,
            "color": "red",
        }
    )
)

fig.write_image("../output/presentation/two_layers_obs.png")

# Regular satellite sampling grid on the top layer, one point every size // 20 cells.
x_sat, y_sat = torch.meshgrid(
    torch.arange(0, size, size // 20),
    torch.arange(0, size, size // 20),
    indexing='ij',
)
z_sat = surface[y_sat, x_sat]

fig.add_trace(
    go.Scatter3d(
        name = "Satellite Observations",
        x = x_sat.flatten(),
        y = y_sat.flatten(),
        z = z_sat.flatten(),
        mode="markers",
        marker = {
            "size": 15,
            "color": "#00502e"
        }
    )
)

fig.write_image("../output/presentation/two_layers_obs_sat.png")

# Streamlines

In [None]:
import plotly.figure_factory as ff

import numpy as np

size = 200

x = torch.linspace(-500, 500, size).to(torch.float64)
y = torch.linspace(-500, 500, size).to(torch.float64)

grid_3d = Grid3D.from_tensors(
    x_unit=METERS,
    y_unit=METERS,
    zh_unit=METERS,
    x = x,
    y = y,
    h = torch.tensor([200,800]).to(torch.float64),
)

# Upper-layer stream function. Assumed indexed [x, y] like the other qgsw
# fields in this notebook (which are transposed before plotting) — TODO confirm.
sf = BaroclinicVortex(0.001).compute_stream_function(grid_3d)[0, 0]

# Velocities from the [x, y]-indexed stream function, scaled for display:
# u = -dψ/dy (axis 1), v = dψ/dx (axis 0).
u = - torch.diff(F.pad(sf, (1,0,0,0)), dim=1) / 9.375e-5 * 10
v = torch.diff(F.pad(sf, (0,0,1,0)), dim=0) / 9.375e-5 * 10

s_s = slice(1, -1, 1)
s_q = slice(1, -1, 10)

# create_streamline and Heatmap expect arrays shaped (len(y), len(x)), hence
# the transposes. Without them the vortex is rendered as a strain pattern.
streamline = ff.create_streamline(x[s_s], y[s_s], u[s_s, s_s].T, v[s_s, s_s].T, arrow_scale=10)
# create_quiver takes point-wise coordinates, so [x, y]-indexed arrays are used directly.
quiver = ff.create_quiver(grid_3d.xyh.x[0, s_q, s_q], grid_3d.xyh.y[0, s_q, s_q], u[s_q, s_q], v[s_q, s_q], scale=0.00001)
fig = go.Figure()
fig.update_layout(
    width = 800,
    height = 800,
)

fig.update_layout(
    {
        "paper_bgcolor": "rgba(0, 0, 0, 0)",
        "plot_bgcolor": "rgba(0, 0, 0, 0)",
    },
)
fig.update_layout(showlegend=False)
fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)
fig.add_trace(go.Heatmap(x=x, y=y, z=sf.T, colorscale=px.colors.diverging.RdBu_r, showscale=False))
fig.add_trace(streamline.data[0])
fig.add_trace(quiver.data[0])
fig.show()

In [None]:
import plotly.figure_factory as ff

import numpy as np

size = 200

# Named to avoid shadowing the builtin `filter`.
gaussian_filter = GaussianFilter2D(0.5, 10)

x = np.arange(0, size, 1)
y = np.arange(0, size, 1)

# Default 'xy' meshgrid: xs[i, j] = x[j], ys[i, j] = y[i]. Arrays are indexed
# [y, x], as Heatmap and create_streamline expect.
xs, ys = np.meshgrid(x, y)

# Local seeded generator: the field does not depend on which cells ran before.
generator = torch.Generator().manual_seed(0)
surface = gaussian_filter.smooth(torch.rand((size, size), generator=generator) - 0.5)

# Velocities from the [y, x]-indexed field, scaled for display:
# u = -dψ/dy (axis 0), v = dψ/dx (axis 1), so streamlines follow its contours.
u = - torch.diff(F.pad(surface, (0,0,1,0)), dim=0) / 9.375e-5 * 10
v = torch.diff(F.pad(surface, (1,0,0,0)), dim=1) / 9.375e-5 * 10

s_s = slice(1, -1, 1)
s_q = slice(1, -1, 10)

streamline = ff.create_streamline(x[s_s], y[s_s], u[s_s, s_s], v[s_s, s_s], arrow_scale=3)
quiver = ff.create_quiver(xs[s_q, s_q], ys[s_q, s_q], u[s_q, s_q], v[s_q, s_q], scale=0.001)
fig = go.Figure()
fig.update_layout(
    width = 800,
    height = 800,
)

fig.update_layout(
    {
        "paper_bgcolor": "rgba(0, 0, 0, 0)",
        "plot_bgcolor": "rgba(0, 0, 0, 0)",
    },
)
fig.update_layout(showlegend=False)
fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)
fig.add_trace(go.Heatmap(x=x, y=y, z=surface, colorscale=px.colors.sequential.Blues, showscale=False))
fig.add_trace(streamline.data[0])
fig.add_trace(quiver.data[0])
fig.show()

# Perturbations

In [None]:
def pv_from_pressure(pressure: torch.Tensor) -> torch.Tensor:
    """Potential-vorticity field used for the perturbation figures.

    Sum of the second differences of ``pressure`` along its last two axes
    (zero-padded at the edges, unit grid spacing) minus
    ``f0 * (A @ pressure) / [H1, H2]`` applied over the layer axis.
    Assumes ``pressure`` is shaped (batch, 2, nx, ny) — the layer axis is the
    one contracted with ``A`` and broadcast against the (1, 2, 1, 1) divisor.
    """
    d2_last = torch.diff(torch.diff(F.pad(pressure, (1, 1, 0, 0)), dim=-1), dim=-1)
    d2_second_last = torch.diff(torch.diff(F.pad(pressure, (0, 0, 1, 1)), dim=-2), dim=-2)
    stretching = f0 * torch.einsum("nk,...kij -> ...nij", A, pressure) / torch.tensor([[[[H1]], [[H2]]]])
    return d2_last + d2_second_last - stretching


def square_grid(half_width: int, n: int = 192) -> tuple[torch.Tensor, torch.Tensor, Grid3D]:
    """Return (x, y, grid): n points per axis over [-half_width, half_width] m, layers 200 m / 800 m."""
    x = torch.linspace(-half_width, half_width, n).to(torch.float64)
    y = torch.linspace(-half_width, half_width, n).to(torch.float64)
    grid = Grid3D.from_tensors(
        x_unit=METERS,
        y_unit=METERS,
        zh_unit=METERS,
        x = x,
        y = y,
        h = torch.tensor([200,800]).to(torch.float64),
    )
    return x, y, grid


def save_pv_surface(pv_layers: torch.Tensor, x: torch.Tensor, y: torch.Tensor, dtick: int, path: str) -> None:
    """Render stacked PV layers with plot_3d (axes converted from m to km) and write them to `path`."""
    plot_3d(
        pv_layers.cpu().numpy(),
        x = x.cpu().numpy() / 1000,
        y = y.cpu().numpy() / 1000,
        dtick = dtick,
        layer_offset = 10,
        show_axis=True,
        show_background=False,
        show_legend=False,
        show_scale = True,
        colorscale=px.colors.sequential.RdBu_r,
        color_field="Vorticité Potentielle (s⁻¹)"
    ).write_image(path)


# Baroclinic vortex on a ±150 km domain.
x, y, grid_3d = square_grid(150_000)
pressure = BaroclinicVortex(0.001).compute_initial_pressure(grid_3d, f0, 0.1)
omega = pv_from_pressure(pressure)
save_pv_surface(omega[0], x, y, 50, "../output/presentation/perturbations/baroclinic.png")

# Same field, zero-padded by one cell and framed in a ±500 km domain.
frame_x = torch.cat([torch.tensor([-500_000]), x, torch.tensor([500_000])])
frame_y = torch.cat([torch.tensor([-500_000]), y, torch.tensor([500_000])])
save_pv_surface(F.pad(omega[0], (1,1,1,1), value=0), frame_x, frame_y, 150, "../output/presentation/perturbations/SDM1.png")

# Baroclinic and barotropic vortices on a ±500 km domain.
x, y, grid_3d = square_grid(500_000)
baroclinic_pv = pv_from_pressure(BaroclinicVortex(0.001).compute_initial_pressure(grid_3d, f0, 0.1))
barotropic_pv = pv_from_pressure(BarotropicVortex(0.001).compute_initial_pressure(grid_3d, f0, 0.1))

save_pv_surface(baroclinic_pv[0], x, y, 50, "../output/presentation/perturbations/SDM1_tmp.png")
save_pv_surface(barotropic_pv[0], x, y, 150, "../output/presentation/perturbations/barotropic.png")
# NOTE(review): SDM2.png is rendered from the same barotropic field as
# barotropic.png — confirm this is intended.
save_pv_surface(barotropic_pv[0], x, y, 150, "../output/presentation/perturbations/SDM2.png")
save_pv_surface(baroclinic_pv[0], x, y, 150, "../output/presentation/perturbations/MDM2.png")


# Modified Model

In [None]:
import os
from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from qgsw.run_summary import RunSummary
import toml

ROOT = Path(os.path.abspath('')).parent

file = ROOT.joinpath("archive/qg_1l_SDM1/results_step_90062.npz")
field = "p"

summary = RunSummary.from_file(file.parent.joinpath("_summary.toml"))
config = summary.configuration
x_min, x_max = config.space.box.x_min, config.space.box.x_max
y_min, y_max = config.space.box.y_min, config.space.box.y_max

# Number of cells cropped from each edge before plotting.
offset = 24

x = np.linspace(x_min, x_max, 192)[offset:-offset]
y = np.linspace(y_min, y_max, 192)[offset:-offset]

with np.load(file) as npz:
    # Top layer of the first entry, divided by f0 (pressure -> stream function).
    p = npz[field][0, 0, ...][offset:-offset, offset:-offset] / f0

alphas = [0, 0.5, 1]

for alpha in alphas:
    # Two stacked surfaces: p on top and alpha * p beneath it.
    plot_3d(
        np.stack([p, alpha * p]),
        x=x/1000,
        y=y/1000,
        dtick = 50,
        layer_offset=1000000,
        show_axis=True,
        show_background=False,
        show_legend=False,
        show_scale=True,
        colorscale=px.colors.diverging.RdBu_r,
        color_field="Fonction de Courant (m².s⁻¹)",
        zrange = None,
    ).write_image(f"../output/presentation/sf_{alpha}_.png")



# PV Profiles

In [None]:
offset = 24


def plot_pv_profile(file: str, Lx: int, dtick: int, output: str, layer: int, zmax: float | None, field: str, color_field: str, norm: float, offset: int = 24) -> None:
    """Plot one layer of a saved field as a 2-D heatmap, write it to `output` and show it.

    Args:
        file: ``.npz`` snapshot; ``npz[field][0, layer]`` is plotted.
        Lx: Domain width in km; both axes span ``[-Lx // 2, Lx // 2]``.
        dtick: Tick spacing on both axes.
        output: Image path the figure is written to (parent folders are created).
        layer: Layer index within the snapshot.
        zmax: Symmetric colour limit; if None, the maximum absolute value of the
            whole (uncropped) layer is used.
        field: Array name in the ``.npz`` file (e.g. ``"pv"`` or ``"p"``).
        color_field: Colour-bar title.
        norm: The field is divided by this value before plotting.
        offset: Number of cells cropped from each edge before plotting.
    """
    with np.load(file) as npz:
        data = npz[field][0, layer] / norm

    # `-offset or None` keeps the slice valid when offset == 0.
    crop = slice(offset, -offset or None)

    # The field is transposed for the heatmap: axis 0 runs along x, axis 1 along y.
    x = np.linspace(-Lx // 2, Lx // 2, data.shape[0])[crop]
    y = np.linspace(-Lx // 2, Lx // 2, data.shape[1])[crop]

    if zmax is None:
        zmax = np.max(np.abs(data))

    fig = go.Figure()

    fig.update_layout(
        autosize=True,
        margin=dict(l=20, r=20, t=20, b=20),
        width= 1500,
        height = 1000,
        font={"size": 40, "color":"black"},
        xaxis={"scaleanchor": "y", "constrain": "domain"},
        yaxis={"scaleanchor": "x", "constrain": "domain"},
    )

    fig.add_trace(
        go.Heatmap(
            x=x,
            y=y,
            z = data[crop, crop].T,
            zmin=-zmax,
            zmax=zmax,
            colorscale=px.colors.diverging.RdBu_r,
            colorbar = dict(exponentformat="e",
                showexponent="all",
                title={"text": color_field, "side": "right"},
                thickness=100,
                lenmode='fraction', len=0.95, tickwidth=20)
        )
    )

    fig.update_xaxes(
        tick0=0,
        dtick = dtick,
        title = "X (km)",
    )
    fig.update_yaxes(
        tick0=0,
        dtick = dtick,
        title = "Y (km)",
        ticksuffix = "  ",
    )

    # Black frame around the plotted domain.
    fig.add_shape(
        type="rect",
        xref="x",
        yref="y",
        x0=x[0],
        x1=x[-1],
        y0=y[0],
        y1=y[-1],
        line=dict(color="black", width=2)
    )

    Path(output).parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(output)
    fig.show()

In [None]:
PV_LABEL = "Vorticité Potentielle (s⁻¹)"
SF_LABEL = "Fonction de Courant (m².s⁻¹)"


def _profile_config(file: str, output: str, Lx: int = 1_000, dtick: int = 100, layer: int = 0,
                    zmax: float | None = None, field: str = "pv", color_field: str = PV_LABEL,
                    norm: float = 1) -> dict:
    """Build one keyword-argument set for plot_pv_profile (defaults: 1000 km PV, top layer)."""
    return {
        "file": file,
        "Lx": Lx,
        "dtick": dtick,
        "output": output,
        "layer": layer,
        "zmax": zmax,
        "field": field,
        "color_field": color_field,
        "norm": norm,
    }


configs = [
    # Final PV of the one- and two-layer runs.
    _profile_config("../output/g5k/qg_1l_MDM2/results_step_91898.npz", "../output/presentation/qg_1l_MDM2_end.png"),
    _profile_config("../output/g5k/qg_1l_SDM1/results_step_90062.npz", "../output/presentation/qg_1l_SDM1_end.png", Lx=300, dtick=50),
    _profile_config("../output/g5k/qg_2l_MDM2/results_step_91898.npz", "../output/presentation/qg_2l_MDM2_end.png"),
    _profile_config("../output/g5k/qg_2l_SDM1/results_step_91898.npz", "../output/presentation/qg_2l_SDM1_end.png", Lx=300, dtick=50),
    # PV snapshots of the sf_changing_0_5_MDM2 run.
    _profile_config("../output/g5k/sf_changing_0_5_MDM2/results_step_0.npz", "../output/presentation/sf_changing/0.png"),
    _profile_config("../output/g5k/sf_changing_0_5_MDM2/results_step_30327.npz", "../output/presentation/sf_changing/30327.png"),
    _profile_config("../output/g5k/sf_changing_0_5_MDM2/results_step_60654.npz", "../output/presentation/sf_changing/60654.png"),
    _profile_config("../output/g5k/sf_changing_0_5_MDM2/results_step_91898.npz", "../output/presentation/sf_changing/91898.png"),
    # Stream function (p / f0) of qg_2l_MDM2: top layer, then bottom layer with a fixed colour range.
    _profile_config("../output/g5k/qg_2l_MDM2/results_step_0.npz", "../output/presentation/qg_2L_MDM2_0_top.png", field="p", color_field=SF_LABEL, norm=f0),
    _profile_config("../output/g5k/qg_2l_MDM2/results_step_91898.npz", "../output/presentation/qg_2L_MDM2_91898_top.png", field="p", color_field=SF_LABEL, norm=f0),
    _profile_config("../output/g5k/qg_2l_MDM2/results_step_0.npz", "../output/presentation/qg_2L_MDM2_0_bot.png", layer=1, zmax=7000, field="p", color_field=SF_LABEL, norm=f0),
    _profile_config("../output/g5k/qg_2l_MDM2/results_step_91898.npz", "../output/presentation/qg_2L_MDM2_91898_bot.png", layer=1, zmax=7000, field="p", color_field=SF_LABEL, norm=f0),
]

# Dictionary keys match plot_pv_profile's parameter names.
for config in configs:
    plot_pv_profile(**config)

# 3D Profiles

In [None]:

def plot_pv_profile_3d(file: str, Lx: int, dtick: int, output: str, offset: int = 24) -> None:
    """Write a stacked 3-D surface plot of every layer's stream function.

    Loads ``p`` from the first entry of ``file`` and divides it by ``f0``
    (pressure -> stream function, as labelled on the colour bar). It crops
    ``offset`` cells from each horizontal edge and renders the result with
    ``plot_3d``.

    Args:
        file: ``.npz`` snapshot containing a ``p`` array.
        Lx: Domain width in km; both axes span ``[-Lx // 2, Lx // 2]``.
        dtick: Axis tick spacing.
        output: Image path the figure is written to.
        offset: Number of cells cropped from each horizontal edge.
    """
    with np.load(file) as npz:
        data = npz["p"][0] / f0

    # `-offset or None` keeps the slice valid when offset == 0.
    crop = slice(offset, -offset or None)
    x = np.linspace(-Lx // 2, Lx // 2, data.shape[-2])[crop]
    y = np.linspace(-Lx // 2, Lx // 2, data.shape[-1])[crop]
    data = data[..., crop, crop]

    plot_3d(
        data = data,
        x= x,
        y=y,
        dtick = dtick,
        # Large vertical gap so the layers are drawn well apart.
        layer_offset = 10000000,
        show_axis=True,
        show_background=False,
        show_legend=True,
        show_scale=True,
        colorscale=px.colors.diverging.RdBu_r,
        color_field="Fonction de Courant (m².s⁻¹)"
    ).write_image(output)

In [None]:
# (snapshot, output image) pairs for the two-layer MDM2 run at its first and last saved steps.
_snapshots_3d = [
    ("../output/g5k/qg_2l_MDM2/results_step_0.npz", "../output/presentation/qg_2l_MDM2_0_3D.png"),
    ("../output/g5k/qg_2l_MDM2/results_step_91898.npz", "../output/presentation/qg_2l_MDM2_91898_3D.png"),
]

configs = [
    {"file": snapshot, "Lx": 1_000, "dtick": 100, "output": image}
    for snapshot, image in _snapshots_3d
]

# Dictionary keys match plot_pv_profile_3d's parameter names.
for config in configs:
    plot_pv_profile_3d(**config)

# Modes

In [None]:
def plot_modes(folder: str) -> None:
    """Plot the relative intensity of the two vertical modes over time for one run.

    Each snapshot's ``p`` field is projected onto the modal basis with ``Pl2m``.
    For each mode, the spatial mean of its squared amplitude is divided by the
    sum over both modes.

    Args:
        folder: Run directory containing ``_summary.toml`` and the
            ``<prefix>*.npz`` snapshots.

    Raises:
        ValueError: If no snapshot matching the run's prefix is found.
    """
    path = Path(folder)
    run = RunSummary.from_file(path.joinpath("_summary.toml"))
    prefix = run.configuration.model.prefix
    steps, files = sort_files(list(path.glob(f"{prefix}*.npz")), prefix, ".npz")
    if len(steps) == 0:
        raise ValueError(f"No '{prefix}*.npz' snapshot found in {path}.")

    dt = run.configuration.simulation.dt

    modes_1 = []
    modes_2 = []
    times = []
    for step, file in zip(steps, files):
        with np.load(file) as npz:
            psi = torch.tensor(npz["p"], dtype=torch.float64)
        modes = torch.einsum("lm,...mxy->...lxy", Pl2m, psi)

        energy_1 = torch.mean(torch.square(modes[0, 0, ...]))
        energy_2 = torch.mean(torch.square(modes[0, 1, ...]))
        total = energy_1 + energy_2

        modes_1.append((energy_1 / total).item())
        modes_2.append((energy_2 / total).item())
        times.append(step * dt / 3600 / 24)

    fig = go.Figure()
    fig.update_layout(
        template="plotly_white",
        autosize=False,
        margin=dict(l=270, r=20, t=20, b=20),
        width=1800,
        height=1000 ,
        font={"size": 60, "color":"black"},
    )
    fig.update_yaxes(
        range=[-0.1,1.1],
        title = {"text": "Intensité du Mode"},
        ticksuffix = "  ",
        tick0=0,
        dtick=0.25,
        visible = True,
    )
    fig.update_xaxes(
        range=[times[0], times[-1]],
        title = {"text": "Temps (jours)"}
    )
    fig.add_shape(
        type="rect", xref="x", yref="y", x0=times[0], x1=times[-1], y0=-0.1, y1=1.1, line=dict(color="black", width=3)
    )
    # NOTE(review): assumes mode 0 of Pl2m is the baroclinic mode, i.e. the
    # eigenpairs were ordered by decreasing eigenvalue when Pl2m was built.
    fig.add_trace(
        go.Scatter(
            name = "Mode Barocline",
            x = times,
            y = modes_1,
            mode = "lines",
            line= dict(width=15, color="#1d4e89"),
            showlegend=True,
        ),
    )
    fig.add_trace(
        go.Scatter(
            name="Mode Barotrope",
            x = times,
            y = modes_2,
            mode = "lines",
            line= dict(width=15, color = "#a41328"),
            showlegend=True
        ),
    )

    fig.show()

In [None]:
# Runs whose modal decomposition is plotted.
configs = [
    {"folder": "../output/g5k/qg_2l_MDM2"},
]

for config in configs:
    plot_modes(folder=config["folder"])

# RMSE

In [None]:
# Display Plotly's Viridis palette (a list of hex colours). The line colours used
# below (#440154, #26828e, #b5de2b) appear to be taken from it.
px.colors.sequential.Viridis

In [None]:
def loss(x:np.ndarray, y:np.ndarray) -> float:
    """Root-mean-square difference between `x` and `y`, divided by 9.375e-5 (the notebook's f0)."""
    return np.sqrt(np.mean(np.square(x - y))) / 9.375e-5

config = {
    "lines" : [
        {
            "folder" : "../output/g5k/sf_0_MDM2",
            "color" : "#440154",
            "name" : "α = 0",
            "dash" : "solid",
        },
        {
            "folder" : "../output/g5k/sf_0_25_MDM2",
            "color" : "#26828e",
            "name" : "α = 0.25",
            "dash" : "solid",
        },
        {
            "folder" : "../output/g5k/sf_0_5_MDM2",
            "color" : "#b5de2b",
            "name" : "α = 0.5",
            "dash" : "solid",
        },
        {
            "folder" : "../output/g5k/qg_1l_MDM2",
            "color" : "#a41328",
            "name" : "Référence",
            "dash" : "dashdot",
        },
    ],
    "baseline" : "../output/g5k/qg_2l_MDM2",
    "layer" : 0,
    "field" : "pv",
}

layer = config["layer"]
field = config["field"]

folder_baseline = Path(config["baseline"])
summary_baseline = RunSummary.from_file(folder_baseline.joinpath("_summary.toml"))
config_baseline = summary_baseline.configuration
steps_baseline, files_baseline = sort_files(list(folder_baseline.glob(f"{config_baseline.model.prefix}*.npz")),config_baseline.model.prefix,".npz")

# Load the baseline once, keyed by step. Each run is then compared with the
# baseline at the same step, not at the same position in its file list, and a
# run with more snapshots than the baseline no longer raises IndexError.
baseline_by_step = {}
for step_baseline, file_baseline in zip(steps_baseline, files_baseline):
    with np.load(file_baseline) as npz:
        baseline_by_step[step_baseline] = npz[field][0, layer, ...]

fig = go.Figure()

for line in config["lines"]:

    folder = Path(line["folder"])
    summary = RunSummary.from_file(folder.joinpath("_summary.toml"))
    steps, files = sort_files(list(folder.glob(f"{summary.configuration.model.prefix}*.npz")),summary.configuration.model.prefix,".npz")

    losses = []
    times = []

    for step, file in zip(steps, files):
        data_baseline = baseline_by_step.get(step)
        if data_baseline is None:
            # No baseline snapshot at this step: nothing to compare against.
            continue

        with np.load(file) as npz:
            data = npz[field][0, layer, ...]

        losses.append(loss(data, data_baseline))
        times.append(step * summary.configuration.simulation.dt / 3600 / 24)

    scatter = go.Scatter(
        x=times,
        y=losses,
        name = line["name"],
        mode = "lines",
        line= dict(width=11, color = line["color"], dash = line["dash"]),
    )

    fig.add_trace(scatter)

fig.update_layout(
    template="plotly_white",
    autosize=False,
    margin=dict(l=270, r=20, t=20, b=20),
    width=1800,
    height=1000 ,
    font={"size": 60, "color":"black"},
)

fig.update_xaxes(
    title={"text": "Temps (jours)"},
    exponentformat="e",
    mirror=True,
    linewidth=3,
    showgrid=False,
    linecolor="black",
)

fig.update_yaxes(
    title={"text": "RMSE"},
    exponentformat="none",
    ticksuffix = "  ",
    tick0 = 0,
    nticks = 5,
    mirror=True,
    showgrid=True,
    gridwidth=2,
    linewidth=3,
    linecolor="black",

)
fig.show()


In [None]:
def loss(x:np.ndarray, y:np.ndarray) -> float:
    """Root-mean-square difference between `x` and `y`, divided by 9.375e-5 (the notebook's f0)."""
    return np.sqrt(np.mean(np.square(x - y))) / 9.375e-5

config = {
    "lines" : [
        {
            "folder" : "../output/g5k/qg_1l_MDM2",
            "color" : "#a41328",
            "name" : "Référence",
            "dash" : "dashdot",
        },
        {
            "folder" : "../output/g5k/sf_0_25_MDM2",
            "color" : "#FECB52",
            "name" : "α = 0.25",
            "dash" : "dashdot",
        },
        {
            "folder" : "../output/g5k/sf_changing_0_25_MDM2",
            "color" : "#440154",
            "name" : "θ = 0.25",
            "dash" : "solid",
        },
        {
            "folder" : "../output/g5k/sf_changing_0_5_MDM2",
            "color" : "#26828e",
            "name" : "θ = 0.5",
            "dash" : "solid",
        },
        {
            "folder" : "../output/g5k/sf_changing_1_5_MDM2",
            "color" : "#b5de2b",
            "name" : "θ = 1.5",
            "dash" : "solid",
        },
    ],
    "baseline" : "../output/g5k/qg_2l_MDM2",
    "layer" : 0,
    "field" : "pv",
}

layer = config["layer"]
field = config["field"]

folder_baseline = Path(config["baseline"])
summary_baseline = RunSummary.from_file(folder_baseline.joinpath("_summary.toml"))
config_baseline = summary_baseline.configuration
steps_baseline, files_baseline = sort_files(list(folder_baseline.glob(f"{config_baseline.model.prefix}*.npz")),config_baseline.model.prefix,".npz")

# Load the baseline once, keyed by step. Each run is then compared with the
# baseline at the same step, not at the same position in its file list, and a
# run with more snapshots than the baseline no longer raises IndexError.
baseline_by_step = {}
for step_baseline, file_baseline in zip(steps_baseline, files_baseline):
    with np.load(file_baseline) as npz:
        baseline_by_step[step_baseline] = npz[field][0, layer, ...]

fig = go.Figure()

for line in config["lines"]:

    folder = Path(line["folder"])
    summary = RunSummary.from_file(folder.joinpath("_summary.toml"))
    steps, files = sort_files(list(folder.glob(f"{summary.configuration.model.prefix}*.npz")),summary.configuration.model.prefix,".npz")

    losses = []
    times = []

    for step, file in zip(steps, files):
        data_baseline = baseline_by_step.get(step)
        if data_baseline is None:
            # No baseline snapshot at this step: nothing to compare against.
            continue

        with np.load(file) as npz:
            data = npz[field][0, layer, ...]

        losses.append(loss(data, data_baseline))
        times.append(step * summary.configuration.simulation.dt / 3600 / 24)

    scatter = go.Scatter(
        x=times,
        y=losses,
        name = line["name"],
        mode = "lines",
        line= dict(width=11, color = line["color"], dash = line["dash"]),
    )

    fig.add_trace(scatter)

fig.update_layout(
    template="plotly_white",
    autosize=False,
    margin=dict(l=270, r=20, t=20, b=20),
    width=1800,
    height=1000 ,
    font={"size": 60, "color":"black"},
)

fig.update_xaxes(
    title={"text": "Temps (jours)"},
    exponentformat="e",
    mirror=True,
    linewidth=3,
    showgrid=False,
    linecolor="black",
)

fig.update_yaxes(
    title={"text": "RMSE"},
    exponentformat="none",
    ticksuffix = "  ",
    tick0 = 0,
    nticks = 5,
    mirror=True,
    showgrid=True,
    gridwidth=2,
    linewidth=3,
    linecolor="black",

)
fig.show()


# Alpha

In [None]:
from qgsw.utils.gaussian_filtering import GaussianFilter1D


config = {
    "lines" : [
        {
            "file" : "../data/coefficients_0_25_.npz",
            "color" : "#440154",
            "name" : "θ = 0.25",
            "dash" : "solid",
        },
        {
            "file" : "../data/coefficients_0_5.npz",
            "color" : "#26828e",
            "name" : "θ = 0.5",
            "dash" : "solid",
        },
        {
            "file" : "../data/coefficients_1_5.npz",
            "color" : "#b5de2b",
            "name" : "θ = 1.5",
            "dash" : "solid",
        },
    ],
}

# Number of leading samples plotted for each run.
n_samples = 101

# Named to avoid shadowing the builtin `filter`.
gaussian_filter_1d = GaussianFilter1D(sigma=0.25, radius=30)

fig = go.Figure()

for line in config["lines"]:

    with np.load(line["file"]) as data:
        alphas = gaussian_filter_1d.smooth(data["alpha"])[:n_samples]
        times = data["times"][:n_samples] / 3600 / 24

    scatter = go.Scatter(
        x=times,
        y=alphas,
        name = line["name"],
        mode = "lines",
        line= dict(width=11, color = line["color"], dash = line["dash"]),
    )

    fig.add_trace(scatter)

fig.update_layout(
    template="plotly_white",
    autosize=False,
    margin=dict(l=270, r=20, t=20, b=20),
    width=1800,
    height=1000 ,
    font={"size": 60, "color":"black"},
)

fig.update_xaxes(
    title={"text": "Temps (jours)"},
    exponentformat="e",
    mirror=True,
    linewidth=3,
    showgrid=False,
    linecolor="black",
)

fig.update_yaxes(
    title={"text": "α"},
    exponentformat="none",
    ticksuffix = "  ",
    tick0 = 0,
    nticks = 5,
    mirror=True,
    showgrid=True,
    gridwidth=2,
    linewidth=3,
    linecolor="black",

)
fig.show()