In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
from pathlib import Path
from qgsw.run_summary import RunSummary
from qgsw.utils.sorting import sort_files
import torch
import numpy as np
from qgsw.perturbations.vortex import BaroclinicVortex, BarotropicVortex
from qgsw.spatial.core.grid import Grid3D
from qgsw.utils.units._units import Unit
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from torch.nn import functional as F
from qgsw.filters.low_pass import GaussianFilter2D
from __future__ import annotations

In [4]:
f0 = 9.375e-5

H1 = 200
H2 = 800
g1 = 10
g2 = 0.05

A = torch.tensor(
    [
        [1/H1/g1+1/H1/g2, -1/H1/g2],
        [-1/H2/g2, 1/H2/g2]
    ],
    dtype=torch.float64
)

eigvals, eigvects = (_.real for _ in torch.linalg.eig(A))
Pm2l = eigvects
Pl2m = eigvects.inverse()
D = torch.diag(eigvals)

def plot_3d(
        data: np.ndarray,
        x:np.ndarray,
        y:np.ndarray,
        dtick:int,
        layer_offset: int,
        show_axis: bool,
        show_background:bool,
        show_legend:bool,
        show_scale:bool,
        colorscale:list[list],
        color_field: str,
        zrange: list[float] | None = None) -> go.Figure:

    if len(data.shape) == 2:
        data = data.reshape((1,-2,-1))


    cmax = np.max(np.abs(data))


    fig = go.Figure()
    fig.update_layout(
        autosize=True,
        margin=dict(l=20, r=20, t=20, b=20),
        width= 1200,
        height = 1000,
        font={"size": 20, "color":"black"},
        xaxis={"scaleanchor": "y", "constrain": "domain"},
        yaxis={"scaleanchor": "x", "constrain": "domain"},
    )

    if not show_background :
        fig.update_layout(
            {
                "paper_bgcolor": "rgba(0, 0, 0, 0)",
                "plot_bgcolor": "rgba(0, 0, 0, 0)",
            },
        )

    if not show_axis:
        fig.update_scenes(xaxis_visible=False, yaxis_visible=False,zaxis_visible=False )
    else:
        fig.update_layout(scene = dict(
            xaxis = dict(
                title = "X (km)",
                gridcolor="white",
                showbackground=True,
                zerolinecolor="white",
                tickangle=0,
                title_font = {"size":30,},
                dtick = dtick,
                tick0 = 0,
                ),
            yaxis = dict(
                title = "Y (km)",
                gridcolor="white",
                showbackground=True,
                zerolinecolor="white",
                tickangle=0,
                title_font = {"size":30,},
                dtick = dtick,
                tick0 = 0,
                ),
            zaxis_visible=False,
        ),
            margin=dict(
            r=10, l=10,
            b=10, t=10)
        )
    if zrange is not None:
        fig.update_layout(
            scene = dict( zaxis = dict(range=zrange))
        )

    if not show_legend:
        fig.update_layout(showlegend=False)


    colorbar = go.surface.ColorBar(
        exponentformat="e",
        showexponent="all",
        title= go.surface.colorbar.Title(
            text = color_field,
            side = "right",
            font = go.surface.colorbar.title.Font(
                size = 40,
            ),
        ),
        thickness = 100,
        tickfont=go.surface.colorbar.Tickfont(
            size=40
        )
    )

    for i, layer in enumerate(data):

        fig.add_trace(
            go.Surface(
                x=x,
                y=y,
                z=layer.T - i*layer_offset,
                colorscale=colorscale,
                cmin=-cmax- i*layer_offset,
                cmax=cmax- i*layer_offset,
                colorbar = colorbar,
                showscale = i==0 and show_scale
            ),
        )



    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=-1.5, y=-1.5, z=1.25)
    )

    fig.update_layout(scene_camera=camera)

    return fig


# Obs

In [6]:
from matplotlib import pyplot as plt

from qgsw.specs import defaults


size = 500

torch.random.manual_seed(0)

filter = GaussianFilter2D(12.5)

surface = filter(torch.rand((size,size),**defaults.get())-0.5)
bottom = filter(torch.rand((size,size),**defaults.get())-0.5)

cmax = max(torch.max(torch.abs(surface)).cpu().item(),torch.max(torch.abs(bottom)).cpu().item())

fig = go.Figure()
fig.update_layout(
    width = 2000,
    height = 2000,
)

fig.update_layout(
    {
        "paper_bgcolor": "rgba(0, 0, 0, 0)",
        "plot_bgcolor": "rgba(0, 0, 0, 0)",
    },
)
fig.update_layout(showlegend=False)

fig.add_trace(
    go.Surface(
        z=surface,
        colorscale=px.colors.sequential.Blues,
        cmin=-cmax,
        cmax=cmax,
        showscale = False,
    ),
)
offset = 1
fig.add_trace(
    go.Surface(
        z=bottom - offset,
        colorscale=px.colors.sequential.Blues,
        cmin=-cmax - offset - 0.075,
        cmax=cmax - offset - 0.075,
        showscale=False,
    ),
)

fig.write_image("../output/presentation/two_layers_grid.png")

fig.update_scenes(xaxis_visible=False, yaxis_visible=False,zaxis_visible=False )

fig.write_image("../output/presentation/two_layers.png")

x_r = torch.randint(0,size,(25,))
y_r = torch.randint(0,size,(25,))
z_r = surface[y_r,x_r]

fig.add_trace(
    go.Scatter3d(
        name = "Observations Campaigns data",
        x = x_r,
        y = y_r,
        z = z_r,
        mode="markers",
        marker = {
            "size": 15,
            "color": "#7e0723"
            # "color": "red"
        }
    )
)

fig.add_trace(
    go.Scatter3d(
        name = "Observations Campaigns data",
        x = x_r[::5],
        y = y_r[::5],
        z = z_r[::5] - offset,
        mode="markers",
        marker = {
            "size": 8,
            "color": "#800F20"
            # "color" : "red"
        }
    )
)

fig.write_image("../output/presentation/two_layers_obs.png")

x_sat, y_sat = torch.meshgrid(
    torch.arange(0,size,size//20),
    torch.arange(0,size,size//20),
    indexing= 'ij'
)
z_sat = surface[y_sat,x_sat]

fig.add_trace(
    go.Scatter3d(
        name = "Satellite Observations",
        x = x_sat.flatten(),
        y = y_sat.flatten(),
        z = z_sat.flatten(),
        mode="markers",
        marker = {
            "size": 15,
            "color": "#00502e"
        }
    )
)

fig.write_image("../output/presentation/two_layers_obs_sat.png")
# fig.show()

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [None]:
from qgsw.fields.variables.prognostic_tuples import UVH
from qgsw.configs.core import Configuration
from qgsw.fields.variables.dynamics import PhysicalVorticity, StreamFunctionFromVorticity, Vorticity
from qgsw.specs import defaults

config=Configuration.from_toml("../output/g5k/double_gyre_qg_long/_config.toml")

uvh = UVH.from_file("../output/g5k/double_gyre_qg_long/results_step_876000.pt", **defaults.get())
sf = StreamFunctionFromVorticity(PhysicalVorticity(Vorticity(), config.space.ds ),config.space.nx,config.space.ny,config.space.nx,config.space.ny)




In [161]:
from qgsw.forcing.wind import WindForcing


X = torch.arange(0,256, dtype=torch.float64)
X -= (X.max()-X.min())/2
Y = torch.cos(X/((X.max()-X.min())/2)*torch.pi)*30

wind = WindForcing.from_config(config.windstress,config.space,config.physics)
taux = wind.compute()[0]


fig = go.Figure()

fig.update_layout(
    width=1500,
    height=2000
)

fig.add_trace(
    go.Heatmap(
        z=(sf.compute(uvh)[0,0]).cpu().T,
        colorscale=px.colors.diverging.RdBu_r,
        zmin=-torch.max(sf.compute(uvh)[0,0].abs()).item(),
        zmax=torch.max(sf.compute(uvh)[0,0].abs()).item(),
        colorbar={"title":{"text":"Stream function [m².s⁻¹]", "side":"right", "font":{"size":75}}, "tickfont":{"size":75}}
    )
)

fig.add_trace(
    go.Scatter(
        x = (taux.T[:,0]-torch.min(taux)).cpu()*100000,
        y = torch.arange(taux.T.shape[0]),
        line={"color":"red", "width":30},
    )
)
fig.write_image("../output/presentation/windstress.png")

In [207]:
from pathlib import Path

from qgsw.fields.errors.error_sets import create_errors_set
from qgsw.models.instantiation import get_model_class
from qgsw.output import RunOutput
import numpy as np
from qgsw.plots.scatter import ScatterPlot


paths = {
    # "../output/local/assimilation_1l": "1L",
    "../output/local/assimilation_1l_400": "1L",
    "../output/local/assimilation_2l": "2L",
    "../output/local/assimilation_colfilt_smooth_optim": "Modified",
}

output = {
    p:None for p in paths.items()
}

variable = "psi_from_omega"
values = []
for path in paths:


    folder=Path(path)
    run = RunOutput(folder)

    config = run.summary.configuration

    errors = create_errors_set()

    vars_dict = get_model_class(config.model).get_variable_set(
        config.space,
        config.physics,
        config.model,
    )

    model_config_ref = run.summary.configuration.simulation.reference
    run_ref = RunOutput(folder, model_config=model_config_ref)

    vars_dict_ref = get_model_class(config.simulation.reference).get_variable_set(
        config.space,
        config.physics,
        config.simulation.reference,
    )
    errors = create_errors_set()
    rmse = errors["rmse"](vars_dict[variable],vars_dict_ref[variable])
    rmse.slices = [slice(None,None),slice(0,1),...]
    rmses = np.array([rmse.compute_ensemble_wise(o.read(),o_ref.read()).cpu().item() for o,o_ref in zip(run.outputs(),run_ref.outputs())])
    output[paths[path]] = rmses

fig = go.Figure()
fig.update_layout(
    width=2000,
    height = 750
)
fig.update_layout(
    xaxis = {"tickfont":{"size":50}, "title":{"text":"Time [days]", "font":{"size":60}}}
)
fig.update_layout(
    yaxis = {"tickfont":{"size":50}, "title":{"text":"Error", "font":{"size":60}}}
)
fig.update_layout(legend={"font":{"size":75}})
fig.add_trace(
    go.Scatter(
        x = [s/3600/24 for s in run.seconds()],
        y=output["1L"],
        line={
            "color":"#800F20",
            "width":6
        },
        name = "1L",
    ),
)
fig.add_trace(
    go.Scatter(
        x = [s/3600/24 for s in run.seconds()],
        y=output["Modified"],
        line={
            "color":"#1D4E89",
            "width":6
        },
        name = "Modified",
    )
)
fig.add_trace(
    go.Scatter(
        y=output["2L"],
        line={
            "color":"#E09F3E",
            "width":6
        },
        name = "2L",
    )
)
fig.show()

# plot = ScatterPlot(values)
# plot.figure.update_layout(template="plotly")
# plot.figure.update_yaxes(range=[0,None])
# plot.set_yaxis_title(f"{rmse.name} on '{vars_dict[variable].name}'")
# plot.set_traces_name(
#     *(" ".join(v.split("_")[1:]) for v in paths)
# )
# plot.show()

In [181]:

sf_top =  sf.compute(uvh)[0,0]
sf_bot =  sf.compute(uvh)[0,1]
zmax = max(sf_top.abs().max().item(), sf_bot.abs().max().item())

fig = go.Figure()

fig.update_layout(
    width=1500,
    height=2000
)

fig.add_trace(
    go.Heatmap(
        z=sf_top.cpu().T,
        colorscale=px.colors.diverging.RdBu_r,
        zmin=-zmax,
        zmax=zmax,
        colorbar={"title":{"text":"Stream function [m².s⁻¹]", "side":"right", "font":{"size":75}}, "tickfont":{"size":75}}
    )
)
fig.show()
fig = go.Figure()

fig.update_layout(
    width=1500,
    height=2000
)

fig.add_trace(
    go.Heatmap(
        z=sf_bot.cpu().T,
        colorscale=px.colors.diverging.RdBu_r,
        zmin=-zmax,
        zmax=zmax,
        colorbar={"title":{"text":"Stream function [m².s⁻¹]", "side":"right", "font":{"size":75}}, "tickfont":{"size":75}}
    )
)
fig.show()

In [185]:
from qgsw.fields.variables.coefficients.core import SmoothNonUniformCoefficient
from qgsw.filters.high_pass import SpectralGaussianHighPass2D
from qgsw.masks import Masks
from qgsw.models.qg.projected.projectors.core import QGProjector
from qgsw.models.qg.projected.projectors.filtered import CollinearFilteredQGProjector
from qgsw.models.qg.stretching_matrix import compute_A
from qgsw.spatial.core.coordinates import Coordinates1D
from qgsw.spatial.core.discretization import SpaceDiscretization2D


nx,ny = config.space.nx, config.space.ny

n_alpha_x = 2
n_alpha_y = 4

ly = ny // n_alpha_y
lx = nx // n_alpha_x


centers = []
for i in range(n_alpha_x):
    for j in range(n_alpha_y):
        centers.append([(i*lx+(i+1)*lx)//2, (j*ly+(j+1)*ly)//2])


P=QGProjector(
    compute_A(config.simulation.reference.h,config.simulation.reference.g_prime, **defaults.get()),
    config.simulation.reference.h.unsqueeze(1).unsqueeze(1),
    SpaceDiscretization2D.from_config(config.space).add_h(Coordinates1D(points=config.simulation.reference.h,unit=Unit.M)),
    config.physics.f0,
    Masks.empty(config.space.nx,config.space.ny,defaults.get_device())
)

p, p_i = P.compute_p(uvh)

coef = SmoothNonUniformCoefficient(nx=128, ny=256)
coef.sigma = 20.35
coef.with_optimal_values(
    CollinearFilteredQGProjector.create_filter(20.35)(p_i[0,0]),p_i[0,1],
    centers=centers
)

In [211]:
fig = go.Figure()

fig.update_layout(
    width=1500,
    height=2000
)

fig.add_trace(
    go.Heatmap(
        z=coef.get()[0,0].cpu().T,
        colorscale=px.colors.sequential.Reds,
        colorbar={"title":{"text":"α [ ]", "side":"right", "font":{"size":75}}, "tickfont":{"size":75}},
        zmin=0,
        zmax=1
    )
)
fig.show()

fig = go.Figure()

fig.update_layout(
    width=1500,
    height=2000
)
fig.add_trace(
    go.Heatmap(
        z=sf_top.cpu().T,
        colorscale=px.colors.diverging.RdBu_r,
        coloraxis='coloraxis',
        zmin=-zmax,
        zmax=zmax,
        colorbar={"title":{"text":"Stream function [m².s⁻¹]", "side":"right", "font":{"size":75}}, "tickfont":{"size":75}},

    )
)

fig.add_trace(
    go.Contour(
        z=coef.get()[0,0].cpu().T,
        colorscale=px.colors.sequential.Reds,
        colorbar={"title":{"text":"α [ ]", "side":"right", "font":{"size":75}}, "tickfont":{"size":75}},
        contours_coloring="lines",
        coloraxis='coloraxis1',
        line={"width":5},
    )
)
fig.show()

In [210]:
from qgsw.fields.variables.prognostic_tuples import UVH
from qgsw.configs.core import Configuration
from qgsw.fields.variables.dynamics import PhysicalVorticity, StreamFunctionFromVorticity, Vorticity
from qgsw.specs import defaults

run = RunOutput("../output/local/assimilation_colfilt_smooth_optim")
config=Configuration.from_toml("../output/local/assimilation_colfilt_smooth_optim/_config.toml")

sf = StreamFunctionFromVorticity(PhysicalVorticity(Vorticity(), config.space.ds ),config.space.nx,config.space.ny,config.space.nx,config.space.ny)

In [287]:
maxs = [sf.compute(o.read()).abs().max().item() for o in run.outputs()] 
zmax = max(maxs)
f=lambda x:"" if x==0 else f((x-1)//26)+chr((x-1)%26+ord("A"))
for i,output in enumerate(run.outputs()):
    uvh = output.read()
    fig = go.Figure()

    fig.update_layout(
        width=750,
        height=1000
    )

    fig.add_trace(
        go.Heatmap(
            z=sf.compute(uvh)[0,0].cpu().T,
            colorscale=px.colors.diverging.RdBu_r,
            colorbar={"title":{"text":"Stream function [m².s⁻¹]", "side":"right", "font":{"size":35}}, "tickfont":{"size":35}},
            zmin=-zmax,
            zmax=zmax
        )
    )
    fig.write_image(f"../output/presentation/gif/{number_to_string(i+1)}.png")

In [278]:
def number_to_string(n, width=10):
    """
    Convert a number to a string with zero-padding to preserve sorting order.
    
    Args:
        n (int): The number to convert.
        width (int): The fixed width for zero-padding (default: 10).
    
    Returns:
        str: A zero-padded string representation of the number.
    """
    return str(n).zfill(width)

In [275]:
num_to_let(28)

'AC'

In [252]:
ord("A")+(1/26)//26

65.0

In [262]:
chr(ord(txt[1])+1-26*1//26+1)

NameError: name 'txt' is not defined

In [276]:
chr(int(ord("A")+27-26*((27/26)//26)))

'\\'

In [277]:
27-26*((27/26)//26)

27.0