In [None]:
import os
from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from qgsw.run_summary import RunSummary

# Project root: one level above the kernel's working directory
# (assumes the kernel's cwd is the notebook's own folder — confirm).
ROOT = Path(os.path.abspath(os.curdir)).parent

In [None]:
# Snapshot figures to render. Each entry selects a field and layer of one run
# (a folder under output/g5k), the folder under output/snapshots that receives
# the images, the time steps to draw and the colorscale name.
plots = [
    {
        "field": "pv",
        "layer": layer,
        "input": run,
        "output": destination,
        # A fresh list per entry, so mutating one entry never affects another.
        "steps": [33084, 62492, 91898],
        "colorscale": "RdBu_r",
    }
    for run, layer, destination in (
        ("one_layer_baroclinic_30km", 0, "1L_baroclinic_30km"),
        ("two_layers_baroclinic_30km", 0, "2L_baroclinic_30km"),
        ("two_layers_baroclinic_30km", 1, "2L_baroclinic_30km_bottom"),
        ("one_layer_baroclinic_100km", 0, "1L_baroclinic_100km"),
        ("two_layers_baroclinic_100km", 0, "2L_baroclinic_100km"),
        ("two_layers_baroclinic_100km", 1, "2L_baroclinic_100km_bottom"),
        ("one_layer_barotropic_100km", 0, "1L_barotropic_100km"),
        ("two_layers_barotropic_100km", 0, "2L_barotropic_100km"),
        ("two_layers_barotropic_100km", 1, "2L_barotropic_100km_bottom"),
    )
]

In [None]:
import os
from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from qgsw.run_summary import RunSummary
import toml

ROOT = Path(os.path.abspath('')).parent


# NOTE(review): `plots_config` is not read anywhere in this cell; kept so that
# other cells relying on it keep working.
plots_config = toml.load(ROOT.joinpath("config/save_plots.toml"))

# Grid points trimmed from every edge of a snapshot before plotting.
offset = 24
# `slice(0, -0)` would be empty, so an offset of 0 keeps the full axis instead.
trim = slice(offset, -offset if offset else None)


def _axis_dtick(span_km: float) -> float:
    """Return a tick spacing of about a sixth of ``span_km``, rounded down to a multiple of 50.

    For spans shorter than 300 km the rounded value would be 0 (an invalid
    tick spacing), so the unrounded sixth is returned instead.
    """
    sixth = span_km / 6
    rounded = sixth - sixth % 50
    return rounded if rounded > 0 else sixth


for plot_config in plots:
    field = plot_config["field"]
    layer = plot_config["layer"]
    steps = plot_config["steps"]

    # Honour the configured colorscale. Names such as "RdBu_r" are looked up in
    # px.colors.diverging (the scale previously hard-coded here); anything else
    # is handed to plotly unchanged.
    colorscale_name = plot_config.get("colorscale", "RdBu_r")
    colorscale = (
        getattr(px.colors.diverging, colorscale_name, colorscale_name)
        if isinstance(colorscale_name, str)
        else colorscale_name
    )

    input_folder = ROOT.joinpath(f"output/g5k/{plot_config['input']}")
    output_folder = ROOT.joinpath(f"output/snapshots/{plot_config['output']}")
    output_folder.mkdir(parents=True, exist_ok=True)

    summary = RunSummary.from_file(input_folder.joinpath("_summary.toml"))
    config = summary.configuration
    x_min, x_max = config.space.box.x_min, config.space.box.x_max
    y_min, y_max = config.space.box.y_min, config.space.box.y_max

    # Load every requested snapshot first, so all panels share one colour range.
    datas = []
    for step in steps:
        file = input_folder.joinpath(f"{config.model.prefix}{step}.npz")
        # The context manager closes the .npz archive once the array is read.
        with np.load(file) as npz:
            datas.append(npz[field][0, layer, ...][trim, trim])

    # Symmetric limits keep zero at the centre of the diverging colorscale.
    zmax = max(np.max(np.abs(data)) for data in datas)
    zmin = -zmax

    # Axis coordinates in km (box bounds / 1000), trimmed like the data. They
    # are the same for every step of this run, so they are computed once.
    x = np.linspace(x_min / 1000, x_max / 1000, config.space.nx)[trim]
    y = np.linspace(y_min / 1000, y_max / 1000, config.space.ny)[trim]
    x_dtick = _axis_dtick((x_max - x_min) / 1000)
    y_dtick = _axis_dtick((y_max - y_min) / 1000)

    # NOTE(review): the colorbar title assumes the plotted field is "pv".
    colorbar = go.heatmap.ColorBar(
        exponentformat="e",
        showexponent="all",
        title={"text": "Potential Vorticity (s⁻¹)", "side": "right"},
        thickness=100,
    )

    last = len(datas) - 1
    for i, (step, data) in enumerate(zip(steps, datas)):
        # Only the first panel shows the y axis; only the last carries the
        # colorbar. Works for any number of steps, not just three.
        is_first = i == 0
        is_last = i == last

        heatmap = go.Heatmap(
            z=data.T,
            x=x,
            y=y,
            colorscale=colorscale,
            zmin=zmin,
            zmax=zmax,
            colorbar=colorbar,
            showscale=is_last,
        )
        fig = go.Figure()
        fig.add_trace(heatmap)
        # Black frame around the plotted (trimmed) domain.
        fig.add_shape(
            type="rect",
            xref="x",
            yref="y",
            x0=x[0],
            x1=x[-1],
            y0=y[0],
            y1=y[-1],
            line=dict(color="black", width=2),
        )

        fig.update_layout(
            autosize=True,
            # Extra width for the colorbar (last panel) and y-axis labels (first panel).
            width=1000 + 400 * is_last + 100 * is_first,
            height=1000,
            font={"size": 60, "color": "black"},
            xaxis={"scaleanchor": "y", "constrain": "domain"},
            yaxis={"scaleanchor": "x", "constrain": "domain"},
            margin=dict(l=20, r=20, t=20, b=20),
        )

        fig.update_xaxes(
            title={"text": "X (km)"},
            exponentformat="none",
            dtick=x_dtick,
            tick0=0,
        )

        fig.update_yaxes(
            title={"text": "Y (km)"},
            exponentformat="none",
            dtick=y_dtick,
            tick0=0,
            ticksuffix="  ",
            visible=is_first,
        )
        fig.write_image(output_folder.joinpath(f"snapshot_{step}.png"))

In [None]:
# Pairs of runs whose fields are compared over time. "input" holds the two run
# folders, "name" is the legend label and "output" the snapshot sub-folder.
# NOTE(review): the three SF entries reuse "compare_baroclinic_100km" as
# "output", which looks like a copy-paste. Confirm the intended folder names.
plots_compare = [
    {
        "field": "omega",
        "layer": 0,
        "name": label,
        "input": [reference, other],
        "output": destination,
    }
    for label, reference, other, destination in (
        ("1l vs 2l baroclinic", "one_layer_baroclinic_30km", "two_layers_baroclinic_30km", "compare_baroclinic_30km"),
        ("1l vs 2l barotropic", "one_layer_barotropic_100km", "two_layers_barotropic_100km", "compare_barotropic_100km"),
        ("1l vs 2l", "one_layer_baroclinic_100km", "two_layers_baroclinic_100km", "compare_baroclinic_100km"),
        ("SF0 vs 2l baroclinic", "sf_alpha_0_30km", "two_layers_baroclinic_30km", "compare_baroclinic_100km"),
        ("SF1 vs 2l barotropic", "sf_alpha_1", "two_layers_barotropic_100km", "compare_baroclinic_100km"),
        ("SF0.1 vs 2l barotropic", "sf_alpha_0_1", "two_layers_barotropic_100km", "compare_baroclinic_100km"),
    )
]

In [None]:
def loss(x: np.ndarray, y: np.ndarray, scale: float = 9.375e-5) -> float:
    """Root-mean-square difference between two fields, in units of ``scale``.

    Args:
        x: First field.
        y: Second field; must have exactly the same shape as ``x``.
        scale: Normalisation applied to the RMSE. The default 9.375e-5 is
            presumably the Coriolis parameter f0 (the comparison plot labels
            this quantity "RMSE (f0)"). Confirm against the run configuration.

    Returns:
        ``sqrt(mean((x - y) ** 2)) / scale``.

    Raises:
        ValueError: If ``x`` and ``y`` have different shapes. Without this
            check, numpy broadcasting could silently compare incompatible grids.
    """
    if np.shape(x) != np.shape(y):
        raise ValueError(f"Shape mismatch: {np.shape(x)} vs {np.shape(y)}")
    rmse = np.sqrt(np.mean(np.square(np.subtract(x, y))))
    return float(rmse) / scale

In [None]:
import os
from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from qgsw.run_summary import RunSummary
import toml
from qgsw.utils.sorting import sort_files

ROOT = Path(os.path.abspath('')).parent

# NOTE(review): `plots_config` is not read anywhere in this cell; kept so that
# other cells relying on it keep working.
plots_config = toml.load(ROOT.joinpath("config/save_plots.toml"))

# Grid points trimmed from every edge before comparing fields.
offset = 24
# `slice(0, -0)` would be empty, so an offset of 0 keeps the full axis instead.
trim = slice(offset, -offset if offset else None)

# One curve per comparison, all drawn on the same axes.
fig = go.Figure()
for plot_config in plots_compare:
    field = plot_config["field"]
    layer = plot_config["layer"]

    input_folder1 = ROOT.joinpath(f"output/g5k/{plot_config['input'][0]}")
    input_folder2 = ROOT.joinpath(f"output/g5k/{plot_config['input'][1]}")
    output_folder = ROOT.joinpath(f"output/snapshots/{plot_config['output']}")
    output_folder.mkdir(parents=True, exist_ok=True)

    summary1 = RunSummary.from_file(input_folder1.joinpath("_summary.toml"))
    config1 = summary1.configuration
    summary2 = RunSummary.from_file(input_folder2.joinpath("_summary.toml"))
    config2 = summary2.configuration

    steps_1, files_1 = sort_files(list(input_folder1.glob(f"{config1.model.prefix}*.npz")), config1.model.prefix, ".npz")
    steps_2, files_2 = sort_files(list(input_folder2.glob(f"{config2.model.prefix}*.npz")), config2.model.prefix, ".npz")
    # Pair snapshots by step number rather than by list position. A second run
    # with fewer (or differently spaced) snapshots then neither raises
    # IndexError nor gets compared at the wrong step.
    # NOTE(review): equal steps mean equal times only if both runs share dt.
    files_2_by_step = dict(zip(steps_2, files_2))

    x_min, x_max = config1.space.box.x_min, config1.space.box.x_max
    y_min, y_max = config1.space.box.y_min, config1.space.box.y_max

    losses = []
    times = []
    for step, file1 in zip(steps_1, files_1):
        file2 = files_2_by_step.get(step)
        if file2 is None:
            # The second run has no snapshot at this step: nothing to compare.
            continue

        # The context managers close both .npz archives once the arrays are read.
        with np.load(file1) as npz1, np.load(file2) as npz2:
            data1 = npz1[field][0, layer, ...][trim, trim]
            data2 = npz2[field][0, layer, ...][trim, trim]

        losses.append(loss(data1, data2))
        times.append(step * config1.simulation.dt)

    scatter = go.Scatter(
        x=times,
        y=losses,
        name=plot_config["name"],
    )
    fig.add_trace(scatter)

fig.update_layout(
    autosize=True,
    width=1500,
    height=750,
    margin=dict(l=20, r=20, t=20, b=20),
)

fig.update_xaxes(
    title={"text": "Time (s)"},
    exponentformat="e",
)

fig.update_yaxes(
    # Plotly only typesets LaTeX when the whole label is wrapped in $...$, so
    # "RMSE ($f_0$)" showed the dollar signs literally. An HTML subscript
    # renders correctly on screen and in static exports.
    title={"text": "RMSE (f<sub>0</sub>)"},
    exponentformat="none",
    ticksuffix="  ",
)
fig.show()