##  Gridded Model Verification

This script verifies the output of an ML-based foundation model against a
traditional NWP system for the atmosphere. The defaults set at the top of
this script are tailored to the Alps-Clariden HPC system at CSCS.
- The NWP-model is called COSMO-E and is initialised with the ensemble mean of the analysis. Only surface level data is available in the archive at MeteoSwiss.
- The ML-model is called Neural-LAM and is initialised with the deterministic analysis.
- The Ground Truth is the same deterministic analysis as was used to train the ML-model.
- The boundary data for both models is IFS HRES from ECMWF; the NWP model received 6-hourly boundary updates and the ML model 12-hourly updates.

For more info about the COSMO model see:
- https://www.cosmo-model.org/content/model/cosmo/coreDocumentation/cosmo_io_guide_6.00.pdf
- https://www.research-collection.ethz.ch/handle/20.500.11850/720460

In [1]:
import random
from pathlib import Path

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import dask
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from dask.distributed import Client, LocalCluster
from pysteps.verification.salscores import sal  # requires scikit-image
from scipy.stats import wasserstein_distance, kurtosis, skew
from scores.continuous import (
    mae,
    mse,
    rmse,
)
from scores.continuous.correlation import pearsonr
from scores.spatial import fss_2d

  _set_context_ca_bundle_path(ca_bundle_path)


Pysteps configuration file found at: /users/sadamov/miniforge3/envs/neural-lam/lib/python3.12/site-packages/pysteps/pystepsrc



**--------> Enter all your user settings in the cell below. <--------**

In [2]:
### DEFAULTS ###
# User-editable configuration. The datasets are sliced and indexed according
# to these values when they are loaded, i.e. before any metric or plot is made.

# Zarr stores for the ground truth, the NWP forecast, the ML forecast and the
# boundary forcing (each is opened with xr.open_zarr further below).
PATH_GROUND_TRUTH = "/iopsstor/scratch/cscs/sadamov/pyprojects_data/neural-lam/cosmo.datastore.zarr"
PATH_NWP = "/capstor/store/cscs/swissai/a01/sadamov/cosmo_e_forecast.zarr"
PATH_ML = "/iopsstor/scratch/cscs/sadamov/pyprojects_data/neural-lam/eval_results/preds_7_19_margin_interior_lr_0001_ar_12.zarr"
PATH_BOUNDARY = "/iopsstor/scratch/cscs/sadamov/pyprojects_data/neural-lam/ifs_7_19_margin_interior.datastore.zarr"
# Positional indices (used with .isel) of the forecast steps to keep;
# index 0 refers to the first forecast step.
ELAPSED_FORECAST_DURATION = list(range(0, 120, 1))
# [first, last] forecast start time, expanded into slice(*START_TIMES);
# None leaves that end of the range open.
START_TIMES = [None, None]
# Time used for the map plots (parsed with pd.to_datetime). Set to None to let
# the plotting functions pick a random time instead.
PLOT_TIME = "2020-02-13T00:00:00"
# [min, max] extent of the spatial grid in projection coordinates, expanded
# into slice(*X) / slice(*Y); None leaves that end open.
X = [None, None]
Y = [None, None]
# Map projection used for the plot axes and for transforming the domain bounds
PROJECTION = ccrs.RotatedPole(
    pole_longitude=190,
    pole_latitude=43,
    central_rotated_longitude=10,
)
# Define how variables map between different data sources
# Mapping from the feature names stored in the ground-truth / ML datasets to
# the harmonised variable names used throughout this notebook.

# Surface and near-surface variables are listed explicitly.
_GT_SURFACE_VARIABLES = {
    "T_2M": "temperature_2m",
    "U_10M": "wind_u_10m",
    "V_10M": "wind_v_10m",
    "PMSL": "pressure_sea_level",
    "PS": "surface_pressure",
    "TOT_PREC": "precipitation",
    "ASHFL_S": "surface_sensible_heat_flux",
    "ASOB_S": "surface_net_shortwave_radiation",
    "ATHB_S": "surface_net_longwave_radiation",
}
# Upper-air variables exist on every model level below; each entry is
# (source prefix, harmonised prefix).
_GT_MODEL_LEVELS = (6, 12, 20, 27, 31, 39, 45, 60)
_GT_UPPER_AIR_VARIABLES = (
    ("U", "wind_u_level"),
    ("V", "wind_v_level"),
    ("PP", "pressure_level"),
    ("T", "temperature_level"),
    ("RELHUM", "relative_humidity_level"),
    ("W", "vertical_velocity_level"),
)
VARIABLES_GROUND_TRUTH = {
    **_GT_SURFACE_VARIABLES,
    # Grouped by variable first, then by level (e.g. "U_lev_6" ... "U_lev_60")
    **{
        f"{source}_lev_{level}": f"{target}_{level}"
        for source, target in _GT_UPPER_AIR_VARIABLES
        for level in _GT_MODEL_LEVELS
    },
}
# The ML predictions use the same feature names as the ground truth, so both
# names refer to the very same mapping object.
VARIABLES_ML = VARIABLES_GROUND_TRUTH
# NWP archive names -> harmonised names. Everything except the hourly
# precipitation already carries the harmonised name.
VARIABLES_NWP = {
    name: name
    for name in (
        "wind_u_10m",
        "wind_v_10m",
        "precipitation_1hr",
        "pressure_sea_level",
        "surface_pressure",
        "temperature_2m",
    )
}
VARIABLES_NWP["precipitation_1hr"] = "precipitation"

# Boundary (forcing) feature names -> harmonised names.
_BOUNDARY_SURFACE_VARIABLES = {
    "mean_sea_level_pressure": "pressure_sea_level",
    "2m_temperature": "temperature_2m",
    "10m_u_component_of_wind": "wind_u_10m",
    "10m_v_component_of_wind": "wind_v_10m",
    "surface_pressure": "surface_pressure",
}
# Pressure level of the boundary data (hPa) -> level index used in the
# harmonised variable names.
_BOUNDARY_PRESSURE_TO_LEVEL = {
    100: 6,
    200: 12,
    400: 20,
    600: 27,
    700: 31,
    850: 39,
    925: 45,
    1000: 60,
}
# (source prefix, harmonised prefix) of the upper-air boundary variables
_BOUNDARY_UPPER_AIR_VARIABLES = (
    ("u_component_of_wind", "wind_u_level"),
    ("v_component_of_wind", "wind_v_level"),
    ("temperature", "temperature_level"),
    ("vertical_velocity", "vertical_velocity_level"),
)
VARIABLES_BOUNDARY = {
    **_BOUNDARY_SURFACE_VARIABLES,
    # Grouped by variable first, then by pressure level
    **{
        f"{source}{pressure}hPa": f"{target}_{level}"
        for source, target in _BOUNDARY_UPPER_AIR_VARIABLES
        for pressure, level in _BOUNDARY_PRESSURE_TO_LEVEL.items()
    },
}

# Name prefixes of the upper-air variables; the level-specific variable names
# are these prefixes followed by "_<level>".
VARIABLES_3D = [
    "wind_u_level",
    "wind_v_level",
    "pressure_level",
    "temperature_level",
    "relative_humidity_level",
    "vertical_velocity_level",
]

# Physical unit of each harmonised variable name (used e.g. as the colorbar
# label of the maps).
# NOTE(review): surface pressures are listed in "Pa" while "pressure_level" is
# listed in "hPa" - confirm this matches the data.
VARIABLE_UNITS = {
    # Surface and near-surface variables
    "temperature_2m": "K",
    "wind_u_10m": "m/s",
    "wind_v_10m": "m/s",
    "pressure_sea_level": "Pa",
    "surface_pressure": "Pa",
    "precipitation": "mm/h",
    "surface_sensible_heat_flux": "W/m²",
    "surface_net_shortwave_radiation": "W/m²",
    "surface_net_longwave_radiation": "W/m²",
    # Upper air variables
    "wind_u_level": "m/s",
    "wind_v_level": "m/s",
    "pressure_level": "hPa",
    "temperature_level": "K",
    "relative_humidity_level": "%",
    "vertical_velocity_level": "Pa/s",
}

# Register a unit for every level-specific variable name. The levels are
# read from the numeric suffix of the "..._lev_<n>" keys of
# VARIABLES_GROUND_TRUTH.
required_levels = {
    int(key.split("_")[-1])
    for key in VARIABLES_GROUND_TRUTH.keys()
    if "lev_" in key
}

# (harmonised prefix, unit) of each upper-air variable
_LEVEL_VARIABLE_UNITS = (
    ("wind_u_level", "m/s"),
    ("wind_v_level", "m/s"),
    ("pressure_level", "hPa"),
    ("temperature_level", "K"),
    ("relative_humidity_level", "%"),
    ("vertical_velocity_level", "Pa/s"),
)
for level in required_levels:
    VARIABLE_UNITS.update(
        {f"{prefix}_{level}": unit for prefix, unit in _LEVEL_VARIABLE_UNITS}
    )

# For some plots a random time step sample is selected; this seed is
# presumably meant to make that selection reproducible - verify at the call
# sites.
RANDOM_SEED = 42
# Resolution (dots per inch) of the figures created with plt.subplots
DPI = 100
# Subsample the data for faster plotting, 10 refers to every 10th element
SUBSAMPLE_HISTOGRAM = 10
# Subsample the data for FSS threshold calculation, 1e7 refers to the number of elements
# NOTE(review): 1e7 is a float; wrap it in int(...) wherever an integer count
# is required - confirm at the point of use.
SUBSAMPLE_FSS_THRESHOLD = 1e7

# When True, the missing-value check cell counts NaNs per variable in every
# dataset.
CHECK_MISSING = False

In [3]:
# Make sure the output folders for figures and tables exist
for _output_dir in ("plots", "tables"):
    Path(_output_dir).mkdir(exist_ok=True)

# Colorblind-friendly colors, one per data source (plus one for errors)
COLORS = dict(
    ground_truth="#000000",  # Black
    ml="#E69F00",  # Orange
    nwp="#56B4E9",  # Light blue
    error="#CC79A7",  # Pink
)

# (line style, marker) per data source so that curves remain distinguishable
# without relying on color alone
LINE_STYLES = dict(
    ground_truth=("solid", "o"),
    ml=("dashed", "s"),
    nwp=("dotted", "^"),
)

# Colorblind-friendly colormap for 2D plots
COLORMAP = "viridis"


def save_plot(fig, name, time=None):
    """Save ``fig`` as ``plots/<name>.pdf``.

    Parameters
    ----------
    fig : object
        Figure to write; only its ``savefig`` method is used.
    name : str
        File name without directory and extension.
    time : optional
        If given, ``time.dt.strftime('%Y%m%d_%H').values`` is appended to
        ``name`` (separated by an underscore), so the object must offer the
        ``.dt`` accessor.
    """
    stem = name
    if time is not None:
        stamp = time.dt.strftime("%Y%m%d_%H").values
        stem = f"{name}_{stamp}"
    # Only the PDF is written; PNG export is currently disabled.
    fig.savefig(f"plots/{stem}.pdf", bbox_inches="tight", dpi=300)


def export_table(df, name, caption=""):
    """Export ``df`` to ``tables/<name>.tex`` and ``tables/<name>.csv``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to export.
    name : str
        File name without directory and extension; also used for the LaTeX
        label ``tab:<name>``.
    caption : str, default ""
        Caption passed to ``DataFrame.to_latex``.
    """
    # Export to LaTeX with caption
    latex_str = df.to_latex(
        float_format="%.4f", caption=caption, label=f"tab:{name}"
    )
    # Write UTF-8 explicitly: the platform default encoding may not be able to
    # represent non-ASCII text (e.g. "²" in a unit), and to_csv below already
    # defaults to UTF-8.
    with open(f"tables/{name}.tex", "w", encoding="utf-8") as f:
        f.write(latex_str)

    # Export to CSV
    df.to_csv(f"tables/{name}.csv")

In [None]:
# Load the ML predictions and bring them into the layout used for the
# verification: one data variable per harmonised name on the dimensions
# (start_time, elapsed_forecast_duration, x, y).
ds_ml = xr.open_zarr(PATH_ML)
# Keep only the configured features, spatial window and start-time window
ds_ml = ds_ml.sel(state_feature=list(VARIABLES_ML.keys()))
ds_ml = ds_ml.sel(y=slice(*Y), x=slice(*X))
ds_ml = ds_ml.sel(start_time=slice(*START_TIMES))
# Split the stacked "state" array into one variable per feature, stored under
# its harmonised name
for feature in ds_ml.state_feature.values:
    ds_ml[VARIABLES_ML[feature]] = ds_ml["state"].sel(state_feature=feature)
# Valid time of every forecast step: start_time + elapsed_forecast_duration,
# broadcast to shape (start_time, elapsed_forecast_duration)
forecast_times = (
    ds_ml.start_time.values[:, None] + ds_ml.elapsed_forecast_duration.values
)
ds_ml = ds_ml.assign_coords(
    forecast_time=(("start_time", "elapsed_forecast_duration"), forecast_times)
)
ds_ml = ds_ml.drop_vars(["state", "state_feature", "time"])
ds_ml = ds_ml.transpose("start_time", "elapsed_forecast_duration", "x", "y")
ds_ml = ds_ml[
    [
        "start_time",
        "elapsed_forecast_duration",
        "x",
        "y",
        *VARIABLES_ML.values(),
    ]
]
# Positional selection of the configured forecast steps, then load into memory
ds_ml = ds_ml.isel(elapsed_forecast_duration=ELAPSED_FORECAST_DURATION)
ds_ml = ds_ml.compute()

# First retained forecast step of every forecast, re-indexed by its valid time
# ("time") instead of by start_time.
ds_ml_first_timestep = (
    ds_ml.isel(elapsed_forecast_duration=0)
    .rename({"forecast_time": "time"})
    .swap_dims({"start_time": "time"})
    .drop_vars(["start_time", "elapsed_forecast_duration"])
).compute()

# Display the dataset as the cell output
ds_ml

In [None]:
# All distinct valid times covered by the ML forecasts
TIMES = np.unique(ds_ml.forecast_time.values.ravel())
# From here on START_TIMES holds the actual forecast start times (it replaces
# the [first, last] selection window from the settings cell).
START_TIMES = ds_ml.start_time
# Spacing between the first two retained forecast steps
_elapsed = ds_ml.elapsed_forecast_duration
STEP_SIZE = _elapsed.diff("elapsed_forecast_duration").values[0]

In [None]:
# Load the ground truth and bring it into the layout used for the
# verification: one data variable per harmonised name on (time, x, y).
ds_gt = xr.open_zarr(PATH_GROUND_TRUTH)
# The store keeps the grid as one stacked "grid_index"; rebuild the y/x axes
ds_gt = ds_gt.set_index(grid_index=["y", "x"]).unstack("grid_index")
ds_gt = ds_gt.sel(y=slice(*Y), x=slice(*X))
# Use the ground-truth mapping consistently (selection, renaming and final
# subsetting) instead of borrowing VARIABLES_ML, which only works while both
# names point at the same dict.
ds_gt = ds_gt.sel(state_feature=list(VARIABLES_GROUND_TRUTH.keys()))
ds_gt = ds_gt.sel(split_name="test").drop_dims([
    "forcing_feature",
    "static_feature",
    "split_part",
])
# Split the stacked "state" array into one variable per feature, stored under
# its harmonised name
for feature in ds_gt.state_feature.values:
    ds_gt[VARIABLES_GROUND_TRUTH[feature]] = ds_gt["state"].sel(
        state_feature=feature
    )
ds_gt = ds_gt.drop_vars([
    "state",
    "state_feature",
    "state_feature_units",
    "state_feature_long_name",
    "state_feature_source_dataset",
    "state__train__diff_mean",
    "state__train__diff_std",
    "state__train__mean",
    "state__train__std",
])
ds_gt = ds_gt.transpose("time", "x", "y")
ds_gt = ds_gt[
    [
        "time",
        "x",
        "y",
        *VARIABLES_GROUND_TRUTH.values(),
    ]
]
# Ground truth at the valid time of the first retained forecast step of every
# forecast, indexed by "time".
# NOTE(review): the offset is hard-coded as (index + 1) hours, i.e. it assumes
# hourly forecast steps - confirm against STEP_SIZE.
ds_gt_first_timestep = (
    ds_gt.sel(
        time=START_TIMES + np.timedelta64(ELAPSED_FORECAST_DURATION[0] + 1, "h")
    )
    .swap_dims({"start_time": "time"})
    .drop_vars(["start_time"])
)
ds_gt_first_timestep = ds_gt_first_timestep.compute()
# Keep only the times that are covered by the ML forecasts, then load
ds_gt = ds_gt.sel(time=TIMES)
ds_gt = ds_gt.compute()
# Display the dataset as the cell output
ds_gt


In [None]:
# Load the NWP forecasts and bring them into the same layout as the ML
# predictions: (start_time, elapsed_forecast_duration, x, y).
ds_nwp = xr.open_zarr(PATH_NWP)
# START_TIMES holds the ML forecast start times at this point, so only NWP
# forecasts with matching initialisation times are kept.
ds_nwp = ds_nwp.sel(y=slice(*Y), x=slice(*X), time=START_TIMES)
# The NWP data starts at lead time 0 = start_time
# (that step is dropped before the configured forecast steps are selected by
# position)
ds_nwp = ds_nwp.drop_isel(lead_time=0).isel(lead_time=ELAPSED_FORECAST_DURATION)
# Keep only the mapped variables and rename them to the harmonised names
ds_nwp = ds_nwp[VARIABLES_NWP.keys()].rename(VARIABLES_NWP)
ds_nwp = ds_nwp.rename_dims({"lead_time": "elapsed_forecast_duration"})
ds_nwp = ds_nwp.rename_vars({"lead_time": "elapsed_forecast_duration"})
# Valid time of every forecast step: start_time + elapsed_forecast_duration,
# broadcast to shape (start_time, elapsed_forecast_duration)
forecast_times = (
    ds_nwp.start_time.values[:, None] + ds_nwp.elapsed_forecast_duration.values
)
ds_nwp = ds_nwp.assign_coords(
    forecast_time=(("start_time", "elapsed_forecast_duration"), forecast_times)
)
ds_nwp = ds_nwp.drop_vars(["time"])
ds_nwp = ds_nwp.transpose("start_time", "elapsed_forecast_duration", "x", "y")
ds_nwp = ds_nwp[
    [
        "start_time",
        "elapsed_forecast_duration",
        "x",
        "y",
        *VARIABLES_NWP.values(),
    ]
]
# Load into memory
ds_nwp = ds_nwp.compute()

# First retained forecast step of every forecast, re-indexed by its valid time
# ("time") instead of by start_time.
ds_nwp_first_timestep = (
    ds_nwp.isel(elapsed_forecast_duration=0)
    .rename({"forecast_time": "time"})
    .swap_dims({"start_time": "time"})
    .drop_vars(["start_time", "elapsed_forecast_duration"])
).compute()

# Display the dataset as the cell output
ds_nwp

Check for missing data in any of the variables. If you have missing data, you need to handle it before running the verification.

In [None]:
if CHECK_MISSING:

    def _null_counts(dataset):
        """Map every data variable of ``dataset`` to its number of nulls."""
        # NOTE(review): `.values` materialises each count right here, so
        # dask.compute below probably receives finished numbers - confirm
        # whether the local cluster is still needed.
        return {
            name: dataset[name].isnull().sum().values
            for name in dataset.data_vars
        }

    with (
        LocalCluster(
            n_workers=16,
            threads_per_worker=1,
            memory_limit="16GB",
        ) as cluster,
        Client(cluster) as client,
    ):
        missing_counts = dask.compute(
            _null_counts(ds_gt),
            _null_counts(ds_nwp),
            _null_counts(ds_ml),
        )
    # Unpack results
    gt_missing, nwp_missing, ml_missing = missing_counts

    # Print one section per dataset
    for header, counts in (
        ("Ground Truth", gt_missing),
        ("\nNWP Model", nwp_missing),
        ("\nML Model", ml_missing),
    ):
        print(header)
        for var, count in counts.items():
            print(f"{var}: {count} missing values")

In [None]:
# Sanity checks: all three datasets must share the same horizontal grid size,
# and the ground truth must hold exactly one time per distinct ML valid time.
for _grid_dim in ("x", "y"):
    assert ds_gt.sizes[_grid_dim] == ds_ml.sizes[_grid_dim]
    assert ds_gt.sizes[_grid_dim] == ds_nwp.sizes[_grid_dim]
_n_valid_times = len(np.unique(ds_ml.forecast_time.values.flatten()))
assert ds_gt.sizes["time"] == _n_valid_times

In [None]:
# Get coordinates
if hasattr(ds_gt, "longitude") and hasattr(ds_gt, "latitude"):
    lons = ds_gt.longitude.values
    lats = ds_gt.latitude.values
elif hasattr(ds_gt, "lon") and hasattr(ds_gt, "lat"):
    lons = ds_gt.lon.values
    lats = ds_gt.lat.values

lon_min = lons.min()
lon_max = lons.max()
lat_min = lats.min()
lat_max = lats.max()

# Transform domain bounds to rotated coordinates
transformer = PROJECTION.transform_points(
    ccrs.PlateCarree(),
    np.array([lon_min, lon_max]),
    np.array([lat_min, lat_max]),
)

# Get rotated coordinate bounds
rot_lon_min, rot_lon_max = transformer[:, 0].min(), transformer[:, 0].max()
rot_lat_min, rot_lat_max = transformer[:, 1].min(), transformer[:, 1].max()

### 1. Maps

**Random Time Selection:** A random time step is selected to avoid bias in the comparison, ensuring that the assessment is representative of typical model performance.

**Consistent Color Scales:** By setting the same minimum and maximum values across all datasets for each variable, we ensure that color differences in the plots reflect true discrepancies, not artifacts of scaling.

**Spatial Patterns:** The plots reveal how the ML model and NWP model represent geographical features like weather fronts, high and low-pressure systems, and temperature gradients. Visual comparisons can immediately highlight areas where the models perform well or poorly, guiding further investigation.

**Edge Effects:** Near the boundaries, artifacts may occur as the model does not calculate a loss in the boundary region.

In [None]:
# Load the boundary forcing and bring it onto a regular latitude/longitude
# grid with one data variable per harmonised name.
ds_boundary = xr.open_zarr(PATH_BOUNDARY)

# The store is indexed either by "time" or by "analysis_time" (optionally
# combined with "elapsed_forecast_duration"); detect which dimensions exist.
temporal_dim = "time" if "time" in ds_boundary.dims else "analysis_time"
forecast_duration_dim = (
    "elapsed_forecast_duration"
    if "elapsed_forecast_duration" in ds_boundary.dims
    else None
)
# Target dimension order, skipping the forecast-duration axis when absent
dims_to_transpose = [
    dim
    for dim in [temporal_dim, forecast_duration_dim, "latitude", "longitude"]
    if dim is not None
]

ds_boundary = ds_boundary.sel(forcing_feature=list(VARIABLES_BOUNDARY.keys()))
ds_boundary = ds_boundary.sel(split_name="test").drop_dims([
    "split_part",
    "static_feature",
])
# Split the stacked "forcing" array into one variable per feature, stored
# under its harmonised name
for feature in ds_boundary.forcing_feature.values:
    ds_boundary[VARIABLES_BOUNDARY[feature]] = ds_boundary["forcing"].sel(
        forcing_feature=feature
    )
ds_boundary = ds_boundary.drop_vars([
    "forcing",
    "forcing_feature",
    "forcing_feature_units",
    "forcing_feature_long_name",
    "forcing_feature_source_dataset",
    "forcing__train__diff_mean",
    "forcing__train__diff_std",
    "forcing__train__mean",
    "forcing__train__std",
])
# Rebuild the 2D grid from the stacked "grid_index"
ds_boundary = ds_boundary.set_index(grid_index=["latitude", "longitude"])
ds_boundary = ds_boundary.unstack("grid_index")
ds_boundary = ds_boundary.transpose(*dims_to_transpose)
# Wrap longitudes above 180 into the negative range by subtracting 360, then
# sort so that both axes are increasing again
longitude_new = np.where(
    ds_boundary["longitude"] > 180,
    ds_boundary["longitude"] - 360,
    ds_boundary["longitude"],
)
ds_boundary = ds_boundary.assign_coords(longitude=longitude_new).sortby([
    "longitude",
    "latitude",
])


# 2D coordinate grids of the boundary data, used when drawing the boundary
# fields on the maps
lon_mesh, lat_mesh = np.meshgrid(ds_boundary.longitude, ds_boundary.latitude)
# Display the dataset as the cell output
ds_boundary

In [None]:
def _lead_time_hours(elapsed_forecast_duration):
    """Return an elapsed forecast duration as a whole number of hours.

    Accepts anything NumPy can convert to ``timedelta64[ns]`` (a 0-d
    ``xarray.DataArray``, ``numpy.timedelta64``, ...); plain integers are
    interpreted as nanoseconds.
    """
    raw = getattr(elapsed_forecast_duration, "values", elapsed_forecast_duration)
    duration = np.asarray(raw, dtype="timedelta64[ns]")
    return int(duration / np.timedelta64(1, "h"))


def _select_boundary_field(ds_boundary, var, time_selected):
    """Return the boundary field of ``var`` that applies at ``time_selected``."""
    if "elapsed_forecast_duration" not in ds_boundary:
        return ds_boundary[var].sel(time=time_selected, method="pad")

    # During evaluation the model has only seen forecasts that were
    # realistically available. The plotted time step is step t0 of
    # t-2, t-1 -> t0. Therefore, we need to subtract one step and then
    # select the closest boundary forecast from the past.
    # Exact matches are not allowed, the forecast analysis_time must
    # be in the past.
    previous_step = time_selected - STEP_SIZE
    if (
        ds_boundary.sel(
            analysis_time=previous_step, method="pad"
        ).analysis_time.values
        == previous_step
    ):
        # An analysis exists exactly at t-1: go back one more step so that
        # the padded lookup lands on the analysis before it.
        steps = 2
    else:
        steps = 1

    field = ds_boundary[var].sel(
        analysis_time=time_selected - steps * STEP_SIZE,
        method="pad",
    )
    # Index the chosen forecast by valid time so that the step valid at (or
    # closest before) the plotted time can be selected.
    forecast_times = (
        field.analysis_time.values + field.elapsed_forecast_duration.values
    )
    field = field.assign_coords(
        forecast_time=("elapsed_forecast_duration", forecast_times)
    ).set_xindex("forecast_time")
    field = field.sel(forecast_time=time_selected, method="pad")
    print(
        "Boundary Start_Time is: ",
        field.analysis_time.values,
    )
    return field


def _draw_field(ax, lons, lats, field, vmin, vmax, title):
    """Draw ``field`` on ``ax`` with the shared color scale and set the title."""
    mesh = ax.pcolormesh(
        lons,
        lats,
        field.values,
        transform=ccrs.PlateCarree(),
        vmin=vmin,
        vmax=vmax,
        cmap=COLORMAP,
        shading="auto",
    )
    ax.set_title(title)
    return mesh


def create_comparison_maps(
    ds_gt,
    ds_ml,
    ds_nwp,
    ds_boundary=None,
    var=None,
    elapsed_forecast_duration=None,
    plot_time=None,
    random_seed=42,
):
    """Create side-by-side maps of ground truth, NWP and ML fields.

    One figure is drawn, saved (via ``save_plot``) and shown per variable.
    All panels of a figure share one color scale spanning every plotted
    field, including the boundary field if there is one.

    Parameters:
    -----------
    ds_gt : xarray.Dataset
        Ground truth dataset, indexed by ``time``
    ds_ml : xarray.Dataset
        ML model predictions. Indexed by ``time``, or by ``start_time`` and
        ``elapsed_forecast_duration`` if the latter is present.
    ds_nwp : xarray.Dataset
        NWP model predictions with the same layout as ``ds_ml`` (can be None)
    ds_boundary : xarray.Dataset, optional
        Boundary condition dataset; where it contains the variable it is
        drawn semi-transparently underneath every panel
    var : str, optional
        Variable to plot (if None, plots all variables)
    elapsed_forecast_duration : timedelta-like, optional
        Elapsed forecast duration to plot, used as a label for
        ``ds_ml.sel`` / ``ds_nwp.sel`` and shown in hours in the title and
        file name. Only applies to datasets that have an
        ``elapsed_forecast_duration`` dimension.
    plot_time : xarray.DataArray, optional
        0-d datetime to plot. It must provide the ``.dt`` accessor, which is
        used for the title and the file name. If None, a time is drawn at
        random from ``ds_ml``.
    random_seed : int, default=42
        Random seed for time selection
    """
    # Handle variable selection
    variables = [var] if var else list(VARIABLES_GROUND_TRUTH.values())
    has_lead_time = "elapsed_forecast_duration" in ds_ml

    if plot_time is None:
        # Draw the index from the very coordinate it is applied to, and keep
        # the result a DataArray so that the `.dt` accessor used for the
        # title and by save_plot is available.
        candidates = ds_ml.start_time if has_lead_time else ds_ml.time
        random.seed(random_seed)
        time_selected = candidates[random.randint(0, len(candidates) - 1)]
    else:
        time_selected = plot_time

    if has_lead_time:
        ds_ml_time = ds_ml.sel(
            start_time=time_selected,
            elapsed_forecast_duration=elapsed_forecast_duration,
        )
        ds_nwp_time = (
            ds_nwp.sel(
                start_time=time_selected,
                elapsed_forecast_duration=elapsed_forecast_duration,
            )
            if ds_nwp is not None
            else None
        )
        ds_gt_time = ds_gt.sel(time=ds_ml_time.forecast_time)
    else:
        ds_ml_time = ds_ml.sel(time=time_selected)
        ds_nwp_time = (
            ds_nwp.sel(time=time_selected) if ds_nwp is not None else None
        )
        ds_gt_time = ds_gt.sel(time=time_selected)

    # Get coordinates
    lons = ds_gt.longitude if hasattr(ds_gt, "longitude") else ds_gt.lon
    lats = ds_gt.latitude if hasattr(ds_gt, "latitude") else ds_gt.lat

    for name in variables:
        # Select data; NWP and boundary fields are optional per variable
        truth_field = ds_gt_time[name]
        ml_field = ds_ml_time[name]
        nwp_field = (
            ds_nwp_time[name]
            if ds_nwp_time is not None and name in ds_nwp_time
            else None
        )
        boundary_field = (
            _select_boundary_field(ds_boundary, name, time_selected)
            if ds_boundary is not None and name in ds_boundary
            else None
        )

        # Determine number of subplots based on NWP data availability
        n_plots = 3 if nwp_field is not None else 2
        fig, axes = plt.subplots(
            1,
            n_plots,
            figsize=(7 * n_plots, 4),
            dpi=DPI,
            subplot_kw={"projection": PROJECTION},
        )
        axes = np.atleast_1d(axes)

        # Calculate global min/max over everything that will be drawn
        plotted_fields = [
            field
            for field in (truth_field, ml_field, nwp_field, boundary_field)
            if field is not None
        ]
        combined_array = np.concatenate([
            field.values.flatten() for field in plotted_fields
        ])
        vmin, vmax = np.nanmin(combined_array), np.nanmax(combined_array)

        # Plot boundaries if available. The mesh is built from the boundary
        # field itself rather than from notebook-level globals.
        if boundary_field is not None:
            boundary_lons, boundary_lats = np.meshgrid(
                boundary_field.longitude, boundary_field.latitude
            )
            for ax in axes:
                ax.contourf(
                    boundary_lons,
                    boundary_lats,
                    boundary_field.values,
                    transform=ccrs.PlateCarree(),
                    cmap=COLORMAP,
                    vmin=vmin,
                    vmax=vmax,
                    alpha=0.5,
                    levels=20,
                )

        # Panel order: ground truth, NWP (if available), ML (always last)
        im0 = _draw_field(
            axes[0], lons, lats, truth_field, vmin, vmax, "Ground Truth"
        )
        plot_idx = 1
        if nwp_field is not None:
            _draw_field(
                axes[plot_idx],
                lons,
                lats,
                nwp_field,
                vmin,
                vmax,
                "NWP Model Prediction",
            )
            plot_idx += 1
        _draw_field(
            axes[plot_idx],
            lons,
            lats,
            ml_field,
            vmin,
            vmax,
            "ML Model Prediction",
        )

        # Add common features to all plots
        for ax in axes:
            ax.coastlines(resolution="50m")
            ax.add_feature(cfeature.BORDERS, linestyle="-", alpha=0.7)
            gl = ax.gridlines(
                draw_labels=True, dms=True, x_inline=False, y_inline=False
            )
            gl.top_labels = False
            gl.right_labels = False

        # Add colorbar (shared by all panels, hence taken from the first one)
        cbar_ax = fig.add_axes([0.2, -0.05, 0.6, 0.05])
        cbar = fig.colorbar(im0, cax=cbar_ax, orientation="horizontal")
        cbar.set_label(VARIABLE_UNITS[name])

        # Adjust subplot spacing
        plt.subplots_adjust(bottom=0.15, hspace=0.05, wspace=0.05)

        title = f"{name} at {str(time_selected.dt.date.values)} - {time_selected.dt.hour.values:02d} UTC"
        save_name = f"map_{name}"
        if elapsed_forecast_duration is not None:
            # Use hours in both the title and the file name
            lead_hours = _lead_time_hours(elapsed_forecast_duration)
            title = f"{title} and Elapsed Forecast Duration +{lead_hours}h"
            save_name = f"{save_name}_leadtime_{lead_hours:02d}"

        plt.suptitle(title)
        save_plot(fig, f"{save_name}", time_selected)
        plt.show()
        # Close this specific figure so that it cannot accumulate in memory
        plt.close(fig)


In [None]:
# Pick the time to plot: None lets create_comparison_maps choose a random
# one; otherwise use the valid time of the first retained forecast step of
# the forecast started at PLOT_TIME.
if PLOT_TIME is None:
    time_selected = None
else:
    # ds_ml has already been reduced to the configured forecast steps, so its
    # first retained step sits at position 0 (the same position that
    # ds_ml_first_timestep is built from), whatever the value of
    # ELAPSED_FORECAST_DURATION[0] is.
    time_selected = ds_gt_first_timestep.sel(
        time=pd.to_datetime(PLOT_TIME)
        + ds_ml.isel(
            elapsed_forecast_duration=0
        ).elapsed_forecast_duration.values
    ).time

create_comparison_maps(
    ds_gt=ds_gt_first_timestep,
    ds_ml=ds_ml_first_timestep,
    ds_nwp=ds_nwp_first_timestep,
    ds_boundary=ds_boundary,
    plot_time=time_selected,
)


#### Mean Error Plot For The Same Time Step

In [None]:
def create_error_maps(
    ds_gt,
    ds_ml,
    ds_nwp=None,
    var=None,
    elapsed_forecast_duration=None,
    plot_time=None,
    random_seed=42,
):
    """Plot maps of the signed forecast error for a single time.

    For every selected variable one figure is drawn showing
    ``prediction - ground truth`` for the ML model and, when the variable is
    present in ``ds_nwp``, for the NWP model as well. All panels of a figure
    share one colour scale that is symmetric around zero. Each figure is
    passed to ``save_plot``, shown and closed.

    Parameters
    ----------
    ds_gt : xarray.Dataset
        Ground truth dataset.
    ds_ml : xarray.Dataset
        ML model predictions.
    ds_nwp : xarray.Dataset, optional
        NWP model predictions.
    var : str, optional
        Variable to plot (if None, all values of ``VARIABLES_GROUND_TRUTH``
        are plotted).
    elapsed_forecast_duration : optional
        Lead time used to select from ``ds_ml``/``ds_nwp`` when ``ds_ml``
        contains ``elapsed_forecast_duration``. Its ``.values`` attribute is
        read for the title and the file name, so a scalar timedelta
        ``xarray.DataArray`` is presumably expected -- TODO confirm.
    plot_time : xarray.DataArray, optional
        Scalar time to plot. Its ``.dt`` accessor is used to build the
        title, so a plain string is not sufficient. If None, a time is drawn
        at random from ``ds_ml.time``.
    random_seed : int, default=42
        Seed for the random time selection.
    """
    # Handle variable selection
    variables = [var] if var else VARIABLES_GROUND_TRUTH.values()

    if plot_time is None:
        random.seed(random_seed)
        # The index is applied to ds_ml.time, so it must be valid there too
        n_times = min(len(ds_gt.time), len(ds_ml.time))
        time_index = random.randint(0, n_times - 1)
        # Keep the DataArray (no ``.values``): the title code below needs the
        # ``.dt`` accessor, which a bare numpy datetime64 does not provide.
        time_selected = ds_ml.time[time_index]
    else:
        time_selected = plot_time

    if "elapsed_forecast_duration" in ds_ml:
        ds_ml_time = ds_ml.sel(
            start_time=time_selected,
            elapsed_forecast_duration=elapsed_forecast_duration,
        )
        ds_nwp_time = (
            ds_nwp.sel(
                start_time=time_selected,
                elapsed_forecast_duration=elapsed_forecast_duration,
            )
            if ds_nwp is not None
            else None
        )
        ds_gt_time = ds_gt.sel(time=ds_ml_time.forecast_time)
    else:
        ds_ml_time = ds_ml.sel(time=time_selected)
        ds_nwp_time = (
            ds_nwp.sel(time=time_selected) if ds_nwp is not None else None
        )
        ds_gt_time = ds_gt.sel(time=time_selected)

    # Get coordinates (either naming convention is accepted)
    lons = ds_gt.longitude if hasattr(ds_gt, "longitude") else ds_gt.lon
    lats = ds_gt.latitude if hasattr(ds_gt, "latitude") else ds_gt.lat

    for var_name in variables:
        # The NWP panel is only drawn when NWP data exist for this variable
        has_nwp = ds_nwp_time is not None and var_name in ds_nwp_time
        n_plots = 2 if has_nwp else 1
        fig, axes = plt.subplots(
            1,
            n_plots,
            figsize=(7 * n_plots, 4),
            dpi=DPI,
            subplot_kw={"projection": PROJECTION},
        )
        axes = np.atleast_1d(axes)

        # Signed errors: prediction minus ground truth
        truth = ds_gt_time[var_name]
        error_ml = ds_ml_time[var_name] - truth
        error_fields = [error_ml.values]
        if has_nwp:
            error_nwp = ds_nwp_time[var_name] - truth
            error_fields.append(error_nwp.values)

        # Symmetric colour range shared by all panels; NaNs are ignored so a
        # single missing value does not blank the whole colour scale.
        max_abs_error = max(
            np.nanmax(np.abs(field)) for field in error_fields
        )
        vmin, vmax = -max_abs_error, max_abs_error

        plot_idx = 0
        if has_nwp:
            axes[plot_idx].pcolormesh(
                lons,
                lats,
                error_nwp.values,
                transform=ccrs.PlateCarree(),
                cmap="RdBu",
                vmin=vmin,
                vmax=vmax,
                shading="auto",
            )
            axes[plot_idx].set_title("NWP Model Error")
            plot_idx += 1

        im1 = axes[plot_idx].pcolormesh(
            lons,
            lats,
            error_ml.values,
            transform=ccrs.PlateCarree(),
            cmap="RdBu",
            vmin=vmin,
            vmax=vmax,
            shading="auto",
        )
        axes[plot_idx].set_title("ML Model Error")

        # Add common features to all plots
        for ax in axes:
            ax.coastlines(resolution="50m")
            ax.add_feature(cfeature.BORDERS, linestyle="-", alpha=0.7)
            gl = ax.gridlines(
                draw_labels=True, dms=True, x_inline=False, y_inline=False
            )
            gl.top_labels = False
            gl.right_labels = False

        # One horizontal colorbar below the panels (all share vmin/vmax)
        cbar_ax = fig.add_axes([0.2, -0.05, 0.6, 0.05])
        cbar = fig.colorbar(im1, cax=cbar_ax, orientation="horizontal")
        cbar.set_label(f"Error in {VARIABLE_UNITS[var_name]}")

        # Adjust subplot spacing
        plt.subplots_adjust(bottom=0.15, hspace=0.05, wspace=0.05)

        if elapsed_forecast_duration is not None:
            title = f"Error in {var_name} at {str(time_selected.dt.date.values)} - {time_selected.dt.hour.values:02d} UTC and Elapsed Forecast Duration +{int(elapsed_forecast_duration.values / 1e9 / 3600)}h"
            # NOTE(review): the file name uses the raw integer value of the
            # lead time, while the title converts it to hours -- confirm
            # which unit is intended before relying on the file names.
            save_name = f"errormap_{var_name}_leadtime_{int(elapsed_forecast_duration.values):02d}"
        else:
            title = f"Error in {var_name} at {str(time_selected.dt.date.values)} - {time_selected.dt.hour.values:02d} UTC"
            save_name = f"errormap_{var_name}"

        plt.suptitle(title)
        save_plot(fig, save_name, time_selected)
        plt.show()
        plt.close(fig)


In [None]:
# Resolve the time stamp to plot: the requested PLOT_TIME shifted by the lead
# time found at index ELAPSED_FORECAST_DURATION[0] of ds_ml. Without a
# PLOT_TIME, create_error_maps receives None and draws a random time.
time_selected = None
if PLOT_TIME is not None:
    first_lead_time = ds_ml.isel(
        elapsed_forecast_duration=ELAPSED_FORECAST_DURATION[0]
    ).elapsed_forecast_duration.values
    requested_time = pd.to_datetime(PLOT_TIME) + first_lead_time
    time_selected = ds_gt_first_timestep.sel(time=requested_time).time

# Error maps (prediction minus ground truth) for the selected time
create_error_maps(
    ds_gt=ds_gt_first_timestep,
    ds_ml=ds_ml_first_timestep,
    ds_nwp=ds_nwp_first_timestep,
    plot_time=time_selected,
)


### 2. Histograms

By examining these distributions, we can assess whether the ML model and NWP model accurately capture the variability and frequency of different atmospheric states.

**Distribution Shape:** The histograms show whether the models replicate the skewness, kurtosis, and overall shape of the ground truth data distributions.

**Extreme Values:** Identifying how the models handle extreme conditions, such as unusually high or low temperatures, is crucial for weather prediction and risk assessment.

**Normalization Needs:** Differences in scale between variables suggest that normalization may be necessary for accurate comparisons.

In [None]:
# Only every SUBSAMPLE_HISTOGRAM-th point along time, x and y enters the
# histograms to keep them cheap to compute.
histogram_subsample = {
    dim: slice(None, None, SUBSAMPLE_HISTOGRAM) for dim in ("time", "x", "y")
}

# One figure per variable: overlaid, normalised histograms of ground truth,
# NWP (when available) and ML values, plus skewness/kurtosis of each.
for variable_name in VARIABLES_GROUND_TRUTH.values():
    fig, ax = plt.subplots(figsize=(10, 6), dpi=DPI)

    ds_gt_sampled = ds_gt[variable_name].isel(**histogram_subsample)
    ds_ml_sampled = ds_ml_first_timestep[variable_name].isel(
        **histogram_subsample
    )

    # Convert to flat numpy arrays
    data_gt = ds_gt_sampled.values.flatten()
    data_ml = ds_ml_sampled.values.flatten()

    # NWP data are optional and only cover a subset of the variables
    has_nwp = (
        ds_nwp_first_timestep is not None
        and variable_name in ds_nwp_first_timestep
    )

    # Drawing order matters for the overlay: ground truth, then NWP, then ML
    ax.hist(
        data_gt,
        bins=500,
        density=True,
        color=COLORS["ground_truth"],
        label="Ground Truth",
    )
    if has_nwp:
        ds_nwp_sampled = ds_nwp_first_timestep[variable_name].isel(
            **histogram_subsample
        )
        data_nwp = ds_nwp_sampled.values.flatten()
        ax.hist(
            data_nwp,
            bins=500,
            alpha=0.8,
            density=True,
            color=COLORS["nwp"],
            label="NWP Model Prediction",
        )
    ax.hist(
        data_ml,
        bins=500,
        alpha=0.8,
        density=True,
        color=COLORS["ml"],
        label="ML Model Prediction",
    )

    # Add labels and title
    units = VARIABLE_UNITS[variable_name]
    ax.set_title(f"Distribution of {variable_name} ({units})")
    ax.legend()

    # Skewness and kurtosis of each distribution, shown in a text box
    stats_gt = f"Ground Truth:\nSkewness: {skew(data_gt):.2f}\nKurtosis: {kurtosis(data_gt):.2f}"
    stats_ml = f"ML Model:\nSkewness: {skew(data_ml):.2f}\nKurtosis: {kurtosis(data_ml):.2f}"
    stats_text = stats_gt + "\n\n" + stats_ml
    if has_nwp:
        stats_nwp = f"NWP Model:\nSkewness: {skew(data_nwp):.2f}\nKurtosis: {kurtosis(data_nwp):.2f}"
        stats_text = stats_text + "\n\n" + stats_nwp

    ax.text(
        0.95,
        0.55,
        stats_text,
        transform=ax.transAxes,
        bbox=dict(alpha=0.8, facecolor="white", edgecolor="black"),
        color="black",
        verticalalignment="bottom",
        horizontalalignment="right",
        fontsize=10,
    )

    plt.tight_layout()
    save_plot(fig, f"histogram_{variable_name}")
    plt.show()
    # Release the figure; otherwise one open figure per variable accumulates
    plt.close(fig)


### 3. Energy Spectra

This chapter examines how energy is distributed across different spatial scales
in the atmosphere by computing and comparing the energy spectra of both models.
This analysis is critical in understanding the models' capabilities to simulate
atmospheric processes ranging from large-scale weather systems to small-scale
turbulence.

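**FFT Computation:** A two-dimensional Fast Fourier Transform (FFT) is used to transform each spatial field into the wavenumber domain, revealing how different scales contribute to the overall energy. The resulting power spectra are averaged over time and then azimuthally, i.e. over all directions within logarithmically spaced wavenumber bins.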

**Scale Representation:** The energy spectra show whether the ML model captures the correct amount of energy at various spatial scales.

**Effective Resolution:** Identifying the effective resolution helps understand the smallest scales that the model can reliably simulate.

**Numerical Artifacts:** Limitations in numerical precision can introduce artifacts in the spectra, especially at the smallest scales.

In [None]:
def calculate_energy_spectra(data):
    """Compute a time-mean, azimuthally averaged power spectrum via 2D FFT.

    Parameters
    ----------
    data : xarray.DataArray
        Field with the dimensions ``time``, ``y`` and ``x`` (in any order;
        it is transposed internally). The ``x`` and ``y`` coordinates must
        be evenly spaced; their spacing sets the wavenumber units
        (presumably metres, as the spectrum plots label the axis in 1/m).

    Returns
    -------
    wavenumber : np.ndarray
        Mean wavenumber magnitude of every non-empty bin, with the first
        two and the last two bins dropped.
    power : np.ndarray
        Mean power of the same bins (averaged over time first).
    effective_resolution : float
        The wavenumber ``1 / (4 * dx)``, i.e. that of a wave with a
        wavelength of four grid spacings in x.
    """
    # Grid spacing along both horizontal axes
    spacing_x = abs(float(data.x[1] - data.x[0]))
    spacing_y = abs(float(data.y[1] - data.y[0]))

    # Bring the field into (time, y, x) order so the FFT runs over (y, x)
    field = data.transpose("time", "y", "x")
    _, n_rows, n_cols = field.shape

    # Power of the real-input 2D FFT per time step, then the time mean
    spectral_power = np.abs(np.fft.rfft2(field, axes=(-2, -1))) ** 2
    mean_power = spectral_power.mean(axis=0)

    # Magnitude of the wavenumber vector for every FFT coefficient
    freq_x = np.fft.rfftfreq(n_cols, d=spacing_x)
    freq_y = np.fft.fftfreq(n_rows, d=spacing_y)
    grid_kx, grid_ky = np.meshgrid(freq_x, freq_y)
    radius = np.sqrt(grid_kx**2 + grid_ky**2)

    # 50 logarithmically spaced bin edges from the smallest non-zero
    # wavenumber to the largest one
    edges = np.logspace(
        np.log10(radius[radius > 0].min()), np.log10(radius.max()), num=50
    )

    # Azimuthal average: mean wavenumber and mean power per non-empty shell
    shells = (
        (radius >= lower) & (radius < upper)
        for lower, upper in zip(edges[:-1], edges[1:])
    )
    occupied = [shell for shell in shells if shell.any()]
    shell_wavenumber = np.array([np.mean(radius[shell]) for shell in occupied])
    shell_power = np.array([np.mean(mean_power[shell]) for shell in occupied])

    # Wavenumber belonging to a wavelength of 4 * dx
    cutoff = 1 / (4 * spacing_x)

    # Trim two bins at each end of the spectrum
    return (shell_wavenumber[2:-2], shell_power[2:-2], cutoff)


def plot_energy_spectra(ds_gt, ds_nwp, ds_ml, var, level=None):
    """Plot the energy spectra of ground truth, NWP and ML for one variable.

    The spectra come from ``calculate_energy_spectra`` and are drawn on
    log-log axes together with a vertical line at the effective-resolution
    wavenumber and a text box holding the Log Spectral Distance (LSD) of
    each model to the ground truth. The figure is saved via ``save_plot``
    and shown.

    Parameters
    ----------
    ds_gt : xarray.Dataset
        Ground truth dataset.
    ds_nwp : xarray.Dataset or None
        NWP predictions. The NWP curve is skipped when this is None or does
        not contain ``var``.
    ds_ml : xarray.Dataset
        ML model predictions.
    var : str
        Name of the variable to analyse.
    level : optional
        If given, the variable is first selected with ``.sel(z=level)``.

    Returns
    -------
    fig, ax
        The matplotlib figure and axes of the plot.
    """
    has_nwp = ds_nwp is not None and var in ds_nwp

    def _select(ds):
        # Pick the variable and, if requested, a single vertical level
        field = ds[var]
        return field if level is None else field.sel(z=level)

    # Calculate energy spectra
    wavenumber_gt, spectrum_gt, effective_resolution = calculate_energy_spectra(
        _select(ds_gt)
    )
    if has_nwp:
        wavenumber_nwp, spectrum_nwp, _ = calculate_energy_spectra(
            _select(ds_nwp)
        )
    else:
        spectrum_nwp = None
    wavenumber_ml, spectrum_ml, _ = calculate_energy_spectra(_select(ds_ml))

    # Create plot
    fig, ax = plt.subplots(figsize=(10, 6), dpi=DPI)

    # Plot spectra; markers are thinned out (markevery) for readability
    ax.loglog(
        wavenumber_gt,
        spectrum_gt,
        color=COLORS["ground_truth"],
        label="Ground Truth",
        linestyle=LINE_STYLES["ground_truth"][0],
        marker=LINE_STYLES["ground_truth"][1],
        markevery=5,
    )

    if has_nwp:
        ax.loglog(
            wavenumber_nwp,
            spectrum_nwp,
            color=COLORS["nwp"],
            label="NWP Model Prediction",
            linestyle=LINE_STYLES["nwp"][0],
            marker=LINE_STYLES["nwp"][1],
            markevery=3,
        )

    ax.loglog(
        wavenumber_ml,
        spectrum_ml,
        color=COLORS["ml"],
        label="ML Model Prediction",
        linestyle=LINE_STYLES["ml"][0],
        marker=LINE_STYLES["ml"][1],
        markevery=4,
    )

    # Plot effective resolution
    ax.axvline(
        effective_resolution,
        color="salmon",
        linestyle="--",
        label="Effective Model Resolution",
    )

    # Add LSD metric
    add_lsd_to_plot(ax, spectrum_gt, spectrum_nwp, spectrum_ml)

    # Customize plot
    ax.set_xlabel("Wavenumber [1/m]")
    unit = VARIABLE_UNITS.get(var, "")
    ax.set_ylabel(f"Power Spectral Density [{unit}²/m]")
    title = f"Energy Spectra Comparison for {var}"
    if level is not None:
        title += f" at Level {level} hPa"
    ax.set_title(title)
    ax.legend()
    ax.grid(True, which="both", ls="--", alpha=0.5)

    # Finalise the layout before saving so the saved file matches the
    # figure that is displayed
    plt.tight_layout()

    # Save plot
    plot_name = f"energy_spectra_{var}"
    if level is not None:
        plot_name += f"_level_{level}"
    save_plot(fig, plot_name)
    plt.show()
    return fig, ax


def calculate_log_spectral_distance(true_spectrum, nwp_spectrum, ml_spectrum):
    """
    Calculate the Log Spectral Distance (LSD) of two spectra to a reference.

    The LSD is the root-mean-square difference of the base-10 logarithms of
    two power spectra; a small offset is added before taking the logarithm
    so that zero power does not produce infinities.

    Returns the tuple ``(lsd_nwp, lsd_ml)``; ``lsd_nwp`` is None when
    ``nwp_spectrum`` is None.
    """
    floor = 1e-10
    log_truth = np.log10(true_spectrum + floor)

    def _distance_to_truth(spectrum):
        # RMS gap between the log-spectrum and the log of the reference
        gap = log_truth - np.log10(spectrum + floor)
        return np.sqrt(np.mean(gap**2))

    ml_distance = _distance_to_truth(ml_spectrum)
    nwp_distance = (
        None if nwp_spectrum is None else _distance_to_truth(nwp_spectrum)
    )
    return nwp_distance, ml_distance


def add_lsd_to_plot(ax, true_spectrum, nwp_spectrum, ml_spectrum):
    """
    Write the Log Spectral Distance of the model spectra into a text box.

    The box always lists the LSD of the ML spectrum; the NWP value is
    included only when ``nwp_spectrum`` is not None.
    """
    lsd_nwp, lsd_ml = calculate_log_spectral_distance(
        true_spectrum, nwp_spectrum, ml_spectrum
    )
    if nwp_spectrum is None:
        label = f"LSD ML = {lsd_ml:.4f}"
    else:
        label = f"LSD NWP = {lsd_nwp:.4f}, LSD ML = {lsd_ml:.4f}"

    # Positioned in axes coordinates, near the lower edge of the plot
    ax.text(
        0.4,
        0.05,
        label,
        transform=ax.transAxes,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
    )

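To interpret the Log-Spectral Distance (LSD) metric:
The LSD quantifies the difference between two spectra as the root-mean-square difference of their base-10 logarithms over all wavenumber bins, with lower values indicating better similarity (intuitively, it grows with the area between the two spectra in the log-log plot).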

Lower values indicate better similarity between spectra
- LSD = 0 means identical spectra

Typical values depend on the specific application, but generally:
- LSD < 1: Good similarity
- 1 < LSD < 2: Moderate differences
- LSD > 2: Significant differences

In [None]:
# Energy spectra plot for every ground-truth variable
for var in VARIABLES_GROUND_TRUTH.values():
    fig, ax = plot_energy_spectra(
        ds_gt=ds_gt,
        ds_nwp=ds_nwp_first_timestep,
        ds_ml=ds_ml_first_timestep,
        var=var,
    )
    plt.show()

In [None]:
def display_and_export_lsd_table(
    ds_gt, ds_ml, ds_nwp, variables, name, caption=""
):
    """
    Display and export a table of Log Spectral Distance (LSD) values.

    Rows are the variables, columns are "ML" and "NWP". The NWP entry stays
    empty for variables without NWP data. The displayed table is
    colour-coded (green below 1, orange below 2, red otherwise); the raw
    values are handed to ``export_table``.

    Args:
        ds_gt: Ground truth dataset
        ds_ml: ML model predictions
        ds_nwp: NWP predictions, or None when no NWP data are available
        variables: Iterable of variable names to analyze
        name: Name for the exported files
        caption: Caption for the LaTeX table
    """
    lsd_data = {}

    for var in variables:
        row = {"ML": None, "NWP": None}

        # The ground-truth spectrum is the reference for both models
        _, spectrum_gt, _ = calculate_energy_spectra(ds_gt[var])

        _, spectrum_ml, _ = calculate_energy_spectra(ds_ml[var])
        _, row["ML"] = calculate_log_spectral_distance(
            spectrum_gt, None, spectrum_ml
        )

        # NWP data are optional and may lack individual variables
        if ds_nwp is not None and var in ds_nwp:
            _, spectrum_nwp, _ = calculate_energy_spectra(ds_nwp[var])
            # The NWP spectrum is passed in the "ml" slot so the same code
            # path computes its distance to the ground truth
            _, row["NWP"] = calculate_log_spectral_distance(
                spectrum_gt, None, spectrum_nwp
            )

        lsd_data[var] = row

    df = pd.DataFrame(lsd_data).T

    # Display styled table
    styled_df = df.style.format(
        lambda x: f"{x:.3f}" if pd.notnull(x) else "-"
    ).map(
        lambda x: f"color: {'green' if x < 1 else 'orange' if x < 2 else 'red'}"
        if pd.notnull(x)
        else ""
    )
    display(styled_df)

    # Export raw dataframe
    export_table(df, name, caption)


# LSD table for all ground-truth variables, exported as "lsd_metrics"
display_and_export_lsd_table(
    ds_gt=ds_gt_first_timestep,
    ds_ml=ds_ml_first_timestep,
    ds_nwp=ds_nwp_first_timestep,
    variables=VARIABLES_GROUND_TRUTH.values(),
    name="lsd_metrics",
    caption="Log Spectral Distance (LSD) metrics comparison between ML and NWP models",
)


### 4. Vertical Profiles

In this chapter, the focus is on assessing how the relative error between the ML
model and the ground truth data varies across vertical levels.
Vertical profiles are essential for understanding the atmospheric structure and
processes at different pressure levels. Obviously these plots only work for 3D
variables.

**Relative Error Calculation:** Using percentage differences provides a
normalized measure of error that is comparable across variables and vertical
levels.

**Altitude-Specific Insights:** The plots reveal whether the ML model performs
consistently across different altitudes or if certain layers pose challenges.

**Atmospheric Dynamics:** Accurate representation of vertical profiles is
crucial for modeling phenomena like convection or jet stream anomalies.

**Pressure Level Interpretation:** Lower vertical levels correspond to higher
altitudes. Inverted axes help represent this correctly but can be
counterintuitive.

In [None]:
def plot_vertical_errors_multigrid(ds_gt, ds_ml, variables):
    """
    Plot vertical profiles of the relative ML error in a grid of subplots.

    Variables whose name starts with an entry of ``VARIABLES_3D`` are
    grouped by base name (the name without its last two ``_``-separated
    parts, i.e. without ``level_XX``) and stacked along a new ``level``
    dimension. For each group the mean relative error in percent (averaged
    over time, y and x) is plotted against the level.

    Args:
        ds_gt: Ground truth dataset
        ds_ml: ML model predictions
        variables: Mapping whose values are the variable names to consider

    Returns:
        Tuple ``(stacked_ds_gt, stacked_ds_ml)`` of dicts mapping each base
        name to the stacked ground-truth / ML DataArray.

    Raises:
        AssertionError: If none of the variables is a 3D variable.
    """
    # Group variables by their prefix (without level)
    var_groups = {}
    for var in variables.values():
        if not any(var.startswith(v) for v in VARIABLES_3D):
            continue
        parts = var.split("_")
        base_name = "_".join(parts[:-2])  # Remove 'level_XX'
        level = int(parts[-1])

        group = var_groups.setdefault(base_name, {"vars": [], "levels": []})
        group["vars"].append(var)
        group["levels"].append(level)
        group["unit"] = VARIABLE_UNITS[var]

    assert len(var_groups) > 0, "No 3D variables found in the dataset"

    # Stack variables for each dataset
    stacked_ds_gt = {}
    stacked_ds_ml = {}

    for base_name, group in var_groups.items():
        # Sort by level to ensure correct ordering
        sorted_idx = np.argsort(group["levels"])
        sorted_vars = [group["vars"][i] for i in sorted_idx]
        sorted_levels = [group["levels"][i] for i in sorted_idx]

        # Stack along new level dimension
        stacked_ds_gt[base_name] = xr.concat(
            [ds_gt[var] for var in sorted_vars],
            dim=pd.Index(sorted_levels, name="level"),
        )
        stacked_ds_ml[base_name] = xr.concat(
            [ds_ml[var] for var in sorted_vars],
            dim=pd.Index(sorted_levels, name="level"),
        )

    # Plot stacked variables on a two-column grid
    num_vars = len(var_groups)
    cols = 2
    rows = (num_vars + 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(12, 4 * rows), dpi=DPI)
    axes = axes.flatten()

    epsilon = 1e-6

    for i, (base_name, group) in enumerate(var_groups.items()):
        # epsilon is added to the magnitude (not to the signed value) so the
        # denominator can never become zero for small negative values
        relative_error_ml = (
            abs(stacked_ds_ml[base_name] - stacked_ds_gt[base_name])
            / (abs(stacked_ds_gt[base_name]) + epsilon)
        ).mean(dim=["time", "y", "x"]) * 100

        # Plot vertical profile
        axes[i].plot(
            relative_error_ml,
            relative_error_ml.level,
            color=COLORS["ml"],
            linestyle=LINE_STYLES["ml"][0],
            marker=LINE_STYLES["ml"][1],
            linewidth=2,
            label="ML Model",
        )
        # Use the full unit string; indexing a string with [0] would keep
        # only its first character. Non-string containers still yield their
        # first element.
        unit = group["unit"]
        if not isinstance(unit, str):
            unit = unit[0]
        axes[i].set_title(f"Relative Error for {base_name} [{unit}]", size=12)
        axes[i].set_xlabel("Relative Error (%)", size=10)
        axes[i].set_ylabel("Level", size=10)
        axes[i].grid(True, alpha=0.3)
        axes[i].invert_yaxis()
        axes[i].legend()

    # Hide unused subplots
    for j in range(i + 1, len(axes)):
        fig.delaxes(axes[j])

    # Finalise the layout before saving so the saved file matches the
    # displayed figure.
    plt.tight_layout()
    # NOTE(review): the file is named after the last plotted group although
    # the figure contains all groups -- name kept to avoid changing outputs.
    save_plot(fig, f"vertical_profile_{base_name}")
    plt.show()

    return (
        stacked_ds_gt,
        stacked_ds_ml,
    )  # Return stacked datasets for further use


# Vertical error profiles; the stacked fields are kept for further use
stacked_gt, stacked_ml = plot_vertical_errors_multigrid(
    ds_gt=ds_gt_first_timestep,
    ds_ml=ds_ml_first_timestep,
    variables=VARIABLES_GROUND_TRUTH,
)

### 5. Various Verification Metrics

The final chapter consolidates various statistical metrics to provide a broad
evaluation of the ML model's performance. By considering multiple metrics, we
gain a nuanced understanding of both the strengths and weaknesses of the model.

**Metric Diversity:** Including MAE, RMSE, MSE, Pearson correlation, and the
Fractions Skill Score (FSS) covers different aspects of model performance, from
average errors to spatial pattern accuracy.

**MAE, MSE and RMSE:** Offer insights into the average magnitude of errors, with
RMSE emphasizing larger discrepancies. The colors indicating high errors are
only implemented for these three metrics with standardization.

**Pearson Correlation:** Assesses the linear relationship, indicating whether
the model captures variability even if biases exist.

**FSS:** Evaluates spatial accuracy, which is particularly important for
predicting localized weather events.

**Wasserstein Distance:** Provides a holistic view of distributional similarity
across variables. Same as chapter 3.

**Holistic Assessment:** The combination of metrics provides a comprehensive
performance profile, essential for model validation and comparison. More complex metrics are explained in more detail.

#### Fractions Skill Score
Range: 0 to 1, where:
- 1 = perfect score
- 0 = no skill compared to random chance

**Key Properties:**
- FSS measures the spatial agreement between two fields, accounting for the spatial scale of the features
- It's particularly useful for assessing the spatial distribution of precipitation, cloud cover, or other fields with spatial structure

**Advantages:**
- More meaningful than simple correlation for spatial fields
- Accounts for the spatial scale of features
- Provides a single value for the entire field comparison

In [None]:
# Module-level settings read by calculate_all_metrics for the FSS:
# - window_size: square neighbourhood, one hundredth of the larger spatial
#   dimension per side
# - n_points: number of values sampled to estimate the FSS event threshold
max_spatial_dim = max(int(ds_gt.x.size), int(ds_gt.y.size))
# Grids with fewer than 100 points per side would give a window of 0 grid
# points; clamp to 1 so the window is always valid.
window_size = (max(1, max_spatial_dim // 100),) * 2
n_points = int(
    np.minimum(
        SUBSAMPLE_FSS_THRESHOLD,
        ds_ml[list(VARIABLES_GROUND_TRUTH.values())[0]]
        .isel(elapsed_forecast_duration=0)
        .size,
    )
)
print(f"Using window size for FSS: {window_size}")
print(f"Using n_points for FSS: {n_points}")

In [None]:
def calculate_all_metrics(ds_gt, ds_nwp, ds_ml, metrics_to_compute=None):
    """Calculate a set of metrics for each variable in the given datasets.

    Every variable of ``ds_gt`` is scored for the ML model and, when
    ``ds_nwp`` is given and contains the variable, for the NWP model. The
    FSS event threshold is the 90 % quantile of a random sample of the
    ground-truth values; sample size and FSS window come from the
    module-level ``n_points`` and ``window_size``. The resulting table is
    also passed to ``export_table`` under the name "verification_metrics".

    Args:
        ds_gt: Ground truth dataset
        ds_nwp: NWP predictions (optional, may be None)
        ds_ml: ML model predictions
        metrics_to_compute: List of metrics to compute (default: all metrics)
            Options: ['MAE', 'RMSE', 'MSE', 'RelativeMAE', 'RelativeRMSE',
                     'PearsonR', 'FSS', 'Wasserstein']

    Returns:
        pandas.DataFrame with one row per variable and one column per
        metric and model, named "<metric> ML" / "<metric> NWP".
    """
    if metrics_to_compute is None:
        metrics_to_compute = [
            "MAE",
            "RMSE",
            "MSE",
            "RelativeMAE",
            "RelativeRMSE",
            "PearsonR",
            "FSS",
            "Wasserstein",
        ]

    def _metrics_for(y_true, y_pred, quantile_90, label):
        """Return the requested metrics of ``y_pred``, keyed '<metric> <label>'."""
        scores_out = {}

        if "MAE" in metrics_to_compute:
            scores_out[f"MAE {label}"] = mae(y_true, y_pred).values
        if "RMSE" in metrics_to_compute:
            scores_out[f"RMSE {label}"] = rmse(y_true, y_pred).values
        if "MSE" in metrics_to_compute:
            scores_out[f"MSE {label}"] = mse(y_true, y_pred).values
        if "RelativeMAE" in metrics_to_compute:
            # The small offset keeps the ratio finite where the truth is zero
            scores_out[f"Relative MAE {label}"] = np.mean(
                np.abs(y_true.values - y_pred.values)
                / (np.abs(y_true.values) + 1e-6)
            )
        if "RelativeRMSE" in metrics_to_compute:
            scores_out[f"Relative RMSE {label}"] = np.sqrt(
                np.mean(
                    (y_true.values - y_pred.values) ** 2
                    / (y_true.values**2 + 1e-6)
                )
            )
        if "PearsonR" in metrics_to_compute:
            scores_out[f"Pearson R {label}"] = pearsonr(y_true, y_pred).values
        if "FSS" in metrics_to_compute:
            scores_out[f"FSS {label}"] = fss_2d(
                y_pred,
                y_true,
                event_threshold=quantile_90,
                window_size=window_size,
                spatial_dims=["y", "x"],
            ).values
        if "Wasserstein" in metrics_to_compute:
            scores_out[f"Wasserstein {label}"] = wasserstein_distance(
                y_true.values.flatten(), y_pred.values.flatten()
            )

        return scores_out

    metrics_dict = {}

    for var in list(ds_gt.data_vars):
        print(f"Calculating metrics for variable: {var}")
        y_true = ds_gt[var].compute()
        y_pred_ml = ds_ml[var].compute()

        # Event threshold for the FSS: 90 % quantile of a random subsample.
        # Sampling without replacement cannot draw more values than exist.
        truth_values = y_true.values.ravel()
        sample = np.random.choice(
            truth_values, min(n_points, truth_values.size), replace=False
        )
        quantile_90 = np.quantile(sample, 0.90)

        metrics_dict[var] = _metrics_for(y_true, y_pred_ml, quantile_90, "ML")

        if ds_nwp is not None and var in ds_nwp:
            y_pred_nwp = ds_nwp[var].compute()
            # Same metrics and same threshold, scored for the NWP forecast
            metrics_dict[var].update(
                _metrics_for(y_true, y_pred_nwp, quantile_90, "NWP")
            )

    metrics_df = pd.DataFrame.from_dict(metrics_dict, orient="index")

    export_table(
        metrics_df,
        "verification_metrics",
        caption="Performance metrics for ML model predictions",
    )

    return metrics_df


def get_formatters(has_nwp=True):
    """Return the format string for every metric column of the metrics table.

    Keys are column names of the form "<metric> ML" and, when ``has_nwp`` is
    true, additionally "<metric> NWP". All ML entries precede the NWP ones.
    """
    metric_formats = (
        ("MAE", "{:.4f}"),
        ("RMSE", "{:.4f}"),
        ("MSE", "{:.4f}"),
        ("Relative MAE", "{:.4f}"),
        ("Relative RMSE", "{:.3e}"),
        ("Pearson R", "{:.4f}"),
        ("FSS", "{:.4f}"),
        ("Wasserstein", "{:.4f}"),
    )
    models = ("ML", "NWP") if has_nwp else ("ML",)

    return {
        f"{metric} {model}": fmt
        for model in models
        for metric, fmt in metric_formats
    }


In [None]:
# Verification metrics for the variables available from the NWP model,
# computed while a temporary local Dask cluster is running.
nwp_variable_names = list(VARIABLES_NWP.values())

with LocalCluster(
    n_workers=32, threads_per_worker=1, memory_limit="8GB"
) as cluster, Client(cluster) as client:
    print(f"Dask dashboard available at: {client.dashboard_link}")

    # "MSE" and "Wasserstein" are deliberately left out of this run
    regular_metrics = calculate_all_metrics(
        ds_gt_first_timestep[nwp_variable_names],
        ds_nwp_first_timestep[nwp_variable_names],
        ds_ml_first_timestep[nwp_variable_names],
        metrics_to_compute=[
            "MAE",
            "RMSE",
            "RelativeMAE",
            "RelativeRMSE",
            "PearsonR",
            "FSS",
        ],
    )

    # Show the formatted table
    print("Regular Metrics:")
    display(regular_metrics.style.format(get_formatters()))

In [None]:
# Verification metrics for the variables that only the ML model provides
# (no NWP counterpart), computed while a temporary local Dask cluster runs.
ml_only_variable_names = [
    name
    for name in VARIABLES_ML.values()
    if name not in VARIABLES_NWP.values()
]

with LocalCluster(
    n_workers=32, threads_per_worker=1, memory_limit="8GB"
) as cluster, Client(cluster) as client:
    print(f"Dask dashboard available at: {client.dashboard_link}")

    # "MSE" and "Wasserstein" are deliberately left out of this run
    regular_metrics = calculate_all_metrics(
        ds_gt_first_timestep[ml_only_variable_names],
        None,
        ds_ml_first_timestep[ml_only_variable_names],
        metrics_to_compute=[
            "MAE",
            "RMSE",
            "RelativeMAE",
            "RelativeRMSE",
            "PearsonR",
            "FSS",
        ],
    )

    # Show the formatted table
    print("Regular Metrics:")
    display(regular_metrics.style.format(get_formatters(has_nwp=False)))


#### Equitable Threat Score (MeteoSwiss Modified Version)
Range: 0 to 1, where:
- 1 = perfect score
- 0.5 = no skill compared to random chance
- < 0.5 = worse than random chance

**Key Properties:**
- Modified ETS rescales the traditional ETS using: ETSrescaled = ETS/2 + 0.5
- Measures how well predicted events correspond to observed events, accounting for hits due to random chance
- Particularly useful for rare events (like precipitation above a high threshold)
- More equitable than simple Threat Score by accounting for hits due to random chance

**Advantages:**
- More intuitive scale from 0 to 1 compared to traditional ETS
- Reference point at 0.5 makes interpretation clearer
- Penalizes both misses and false alarms
- Accounts for random chance, making it more robust than basic threat scores
- Maintains original ETS properties while providing more intuitive scaling

#### Frequency Bias Index
Range: 0 to infinity, where:
- 1 = no bias
- < 1 = underforecasting
- > 1 = overforecasting

**Key Properties:**
- FBI measures the ratio of forecasted to observed events, indicating whether the model tends to over- or underforecast
- It's particularly useful for understanding systematic biases in event frequency

**Advantages:**
- Provides a clear indication of over- or underforecasting
- Easy to interpret: 1 indicates no bias, while values above or below 1 show the direction and magnitude of the bias

In [None]:
# Event thresholds per variable type -- presumably consumed by the
# threshold-based scores (ETS, FBI) defined below; TODO confirm against usage.
precip_thresholds = [0.1, 1, 5]  # mm/h
wind_thresholds = [2.5, 5, 10]  # m/s

In [None]:
# Render every float in pandas output with four decimal places
pd.set_option("display.float_format", "{:.4f}".format)


def frequency_bias(obs, pred, threshold):
    """Return the Frequency Bias Index (FBI) for a binary event.

    An event is a value greater than or equal to ``threshold``. The index is
    the number of forecast events divided by the number of observed events;
    NaN is returned when no event was observed.
    """
    n_observed = np.sum(obs >= threshold)
    if n_observed > 0:
        n_forecast = np.sum(pred >= threshold)
        return n_forecast / n_observed
    return np.nan


def mean_error(obs, pred):
    """Return the mean error (ME, i.e. bias) of ``pred`` relative to ``obs``.

    The sign convention is prediction minus observation, so a positive value
    means the prediction is too high on average.
    """
    signed_difference = pred - obs
    return float(np.mean(signed_difference))


def mean_absolute_error(obs, pred):
    """Return the mean absolute error (MAE) of ``pred`` relative to ``obs``."""
    absolute_difference = np.abs(pred - obs)
    return float(np.mean(absolute_difference))


def ets_for_threshold(obs, pred, threshold):
    """Calculate ETS with MeteoSwiss rescaling (ETSrescaled = ETS/2 + 0.5).

    The Equitable Threat Score is computed from the 2x2 contingency table of
    the binary event "value >= threshold" (the same event definition as
    ``frequency_bias``):

        hits_random = (hits + false_alarms) * (hits + misses) / total
        ETS = (hits - hits_random)
              / (hits + false_alarms + misses - hits_random)

    Args:
        obs: Observed field (array-like, e.g. numpy array or DataArray).
        pred: Forecast field, broadcastable against ``obs``.
        threshold: Event threshold.

    Returns:
        The rescaled score as a float clipped to [0, 1], where 0.5 means no
        skill. When the ETS denominator is zero (e.g. no event in either
        field) the score is reported as no skill (0.5). NaN is returned for
        empty input, for which no score is defined.
    """
    obs_event = np.asarray(obs) >= threshold
    pred_event = np.asarray(pred) >= threshold

    # Contingency table
    hits = float(np.count_nonzero(pred_event & obs_event))
    false_alarms = float(np.count_nonzero(pred_event & ~obs_event))
    misses = float(np.count_nonzero(~pred_event & obs_event))
    correct_zeros = float(np.count_nonzero(~pred_event & ~obs_event))
    total = hits + false_alarms + misses + correct_zeros

    # Nothing to score: avoid the 0/0 in hits_random and report "undefined"
    # the same way frequency_bias does.
    if total == 0:
        return float("nan")

    # Hits expected from a random forecast with the same event frequencies
    hits_random = ((hits + false_alarms) * (hits + misses)) / total

    denominator = hits - hits_random + false_alarms + misses
    ets = 0.0 if denominator == 0 else (hits - hits_random) / denominator

    # Apply MeteoSwiss rescaling
    ets_rescaled = ets / 2.0 + 0.5

    return max(0.0, min(1.0, ets_rescaled))


def stdev_error(obs, pred):
    """Return the standard deviation of the error ``pred - obs``."""
    error_spread = np.std(pred - obs)
    return abs(error_spread)


if "precipitation" in ds_gt_first_timestep:
    # --- Group 1: Precipitation Metrics ---
    # One table row per threshold with MAE, ME, FBI and ETS for the ML model
    # and, when it provides precipitation, the NWP model.
    results_precip = {}

    y_true_all = ds_gt_first_timestep["precipitation"].values
    y_ml_all = ds_ml_first_timestep["precipitation"].values
    y_nwp_all = (
        ds_nwp_first_timestep["precipitation"].values
        if "precipitation" in ds_nwp_first_timestep
        else None
    )

    # MAE and ME do not depend on the threshold: compute them once.
    mae_ml = mean_absolute_error(y_true_all, y_ml_all)
    me_ml = mean_error(y_true_all, y_ml_all)
    if y_nwp_all is not None:
        mae_nwp = mean_absolute_error(y_true_all, y_nwp_all)
        me_nwp = mean_error(y_true_all, y_nwp_all)
    else:
        mae_nwp = me_nwp = np.nan

    for thr in precip_thresholds:
        # Points where truth or forecast reaches the threshold. Only used to
        # detect the "no event anywhere" case, which is reported as NaN.
        mask_ml = (y_true_all >= thr) | (y_ml_all >= thr)
        mask_nwp = (
            (y_true_all >= thr) | (y_nwp_all >= thr)
            if y_nwp_all is not None
            else None
        )

        # ETS is evaluated on the full field: its random-hit correction
        # needs the correct negatives, which restricting the data to the
        # truth-or-forecast mask would discard.
        fbi_ml = frequency_bias(y_true_all, y_ml_all, thr)
        ets_ml = (
            ets_for_threshold(y_true_all, y_ml_all, thr)
            if mask_ml.any()
            else np.nan
        )

        if y_nwp_all is not None:
            fbi_nwp = frequency_bias(y_true_all, y_nwp_all, thr)
            ets_nwp = (
                ets_for_threshold(y_true_all, y_nwp_all, thr)
                if mask_nwp.any()
                else np.nan
            )
        else:
            fbi_nwp = ets_nwp = np.nan

        results_precip[f"{thr} mm/h"] = {
            "MAE ML": mae_ml,
            "ME ML": me_ml,
            "FBI ML": fbi_ml,
            "ETS ML": ets_ml,
            "MAE NWP": mae_nwp,
            "ME NWP": me_nwp,
            "FBI NWP": fbi_nwp,
            "ETS NWP": ets_nwp,
        }

    # Thresholds as rows, metrics as columns; rounded for the exported table
    results_precip_df = pd.DataFrame(results_precip).T.round(4)
    export_table(
        results_precip_df,
        "precipitation_metrics",
        caption="Precipitation verification metrics for different thresholds",
    )
    # Display results
    print("Precipitation Metrics:")
    display(pd.DataFrame(results_precip).T)


if (
    "wind_u_10m" in ds_gt_first_timestep
    and "wind_v_10m" in ds_gt_first_timestep
):
    # --- Group 2: Wind Metrics ---
    # One table per wind component, one row per threshold.
    # NOTE(review): thresholds are applied to the signed components, so only
    # positive component values can count as events -- confirm this is
    # intended rather than a threshold on wind speed.
    results_wind = {}

    for var in ["wind_u_10m", "wind_v_10m"]:
        var_results = {}
        y_true_all = ds_gt_first_timestep[var].values
        y_ml_all = ds_ml_first_timestep[var].values
        y_nwp_all = (
            ds_nwp_first_timestep[var].values
            if var in ds_nwp_first_timestep
            else None
        )

        # MAE and ME do not depend on the threshold: compute them once.
        mae_ml = mean_absolute_error(y_true_all, y_ml_all)
        me_ml = mean_error(y_true_all, y_ml_all)
        if y_nwp_all is not None:
            mae_nwp = mean_absolute_error(y_true_all, y_nwp_all)
            me_nwp = mean_error(y_true_all, y_nwp_all)
        else:
            mae_nwp = me_nwp = np.nan

        for thr in wind_thresholds:
            # Points where truth or forecast reaches the threshold. Only
            # used to detect the "no event anywhere" case (reported as NaN).
            mask_ml = (y_true_all >= thr) | (y_ml_all >= thr)
            mask_nwp = (
                (y_true_all >= thr) | (y_nwp_all >= thr)
                if y_nwp_all is not None
                else None
            )

            # ETS is evaluated on the full field: its random-hit correction
            # needs the correct negatives, which restricting the data to the
            # truth-or-forecast mask would discard.
            fbi_ml = frequency_bias(y_true_all, y_ml_all, thr)
            ets_ml = (
                ets_for_threshold(y_true_all, y_ml_all, thr)
                if mask_ml.any()
                else np.nan
            )

            if y_nwp_all is not None:
                fbi_nwp = frequency_bias(y_true_all, y_nwp_all, thr)
                ets_nwp = (
                    ets_for_threshold(y_true_all, y_nwp_all, thr)
                    if mask_nwp.any()
                    else np.nan
                )
            else:
                fbi_nwp = ets_nwp = np.nan

            var_results[f"{thr} m/s"] = {
                "MAE ML": mae_ml,
                "ME ML": me_ml,
                "FBI ML": fbi_ml,
                "ETS ML": ets_ml,
                "MAE NWP": mae_nwp,
                "ME NWP": me_nwp,
                "FBI NWP": fbi_nwp,
                "ETS NWP": ets_nwp,
            }
        results_wind[var] = var_results

    # Thresholds as rows, metrics as columns; rounded like the precipitation
    # table before export.
    results_wind_dfs = {
        var: pd.DataFrame(metrics).T.round(4)
        for var, metrics in results_wind.items()
    }

    for var, metrics_df in results_wind_dfs.items():
        export_table(
            metrics_df,
            f"wind_metrics_{var}",
            caption=f"Wind verification metrics for {var} at different thresholds",
        )
        print(f"\nWind Metrics for {var}:")
        display(metrics_df)

# Variables scored with MAE, ME and the standard deviation of the error.
# Defined outside the guard below because the lead-time functions later in
# the notebook read this list whether or not the table is produced.
vars_stdev = [
    "surface_net_shortwave_radiation",
    "surface_net_longwave_radiation",
    "temperature_2m",
]

if (
    "surface_net_shortwave_radiation" in ds_gt_first_timestep
    and "temperature_2m" in ds_gt_first_timestep
    and "surface_net_longwave_radiation" in ds_gt_first_timestep
):
    # --- Group 3: Radiation and Temperature Metrics ---
    results_stdev = {}

    for var in vars_stdev:
        y_true = ds_gt_first_timestep[var].values
        y_ml = ds_ml_first_timestep[var].values
        # NWP is optional per variable; missing variables are reported as NaN
        y_nwp = (
            ds_nwp_first_timestep[var].values
            if var in ds_nwp_first_timestep
            else None
        )

        # Calculate ML metrics
        mae_ml = mean_absolute_error(y_true, y_ml)
        me_ml = mean_error(y_true, y_ml)
        stdev_ml = stdev_error(y_true, y_ml)

        # Calculate NWP metrics
        if y_nwp is not None:
            mae_nwp = mean_absolute_error(y_true, y_nwp)
            me_nwp = mean_error(y_true, y_nwp)
            stdev_nwp = stdev_error(y_true, y_nwp)
        else:
            mae_nwp = me_nwp = stdev_nwp = np.nan

        results_stdev[var] = {
            "MAE ML": mae_ml,
            "ME ML": me_ml,
            "STDEV-ERR ML": stdev_ml,
            "MAE NWP": mae_nwp,
            "ME NWP": me_nwp,
            "STDEV-ERR NWP": stdev_nwp,
        }

    # Variables as rows, metrics as columns
    results_stdev_df = pd.DataFrame(results_stdev).T.round(4)
    print("\nMAE and STDEV Error Metrics for ASOB, ATHB, t_2m:")
    display(results_stdev_df)

In [None]:
def wind_vector_rmse(u_true, v_true, u_pred, v_pred):
    """Return the RMSE of the wind vector as a float.

    The per-component RMSEs are combined in quadrature:
    ``sqrt(rmse(u)**2 + rmse(v)**2)``.
    """
    component_u = rmse(u_true, u_pred)
    component_v = rmse(v_true, v_pred)

    # Sum of the squared component errors, then back to the original units
    squared_total = component_u**2 + component_v**2
    return float(np.sqrt(squared_total))


if (
    "wind_u_10m" in ds_gt_first_timestep
    and "wind_v_10m" in ds_gt_first_timestep
):
    # Ground-truth and ML wind components
    u_true, v_true = (
        ds_gt_first_timestep[name] for name in ("wind_u_10m", "wind_v_10m")
    )
    u_ml, v_ml = (
        ds_ml_first_timestep[name] for name in ("wind_u_10m", "wind_v_10m")
    )

    # Vector RMSE of the ML forecast
    wind_rmse_ml = wind_vector_rmse(u_true, v_true, u_ml, v_ml)

    # NWP is optional: without both components its score is reported as NaN
    if (
        "wind_u_10m" not in ds_nwp_first_timestep
        or "wind_v_10m" not in ds_nwp_first_timestep
    ):
        wind_rmse_nwp = np.nan
    else:
        u_nwp = ds_nwp_first_timestep["wind_u_10m"]
        v_nwp = ds_nwp_first_timestep["wind_v_10m"]
        wind_rmse_nwp = wind_vector_rmse(u_true, v_true, u_nwp, v_nwp)

    # Stored next to the per-component wind results under its own key
    results_wind["vector_metrics"] = {
        "RMSE ML": wind_rmse_ml,
        "RMSE NWP": wind_rmse_nwp,
    }

    # Metrics as rows, a single "Value" column
    vector_metrics_df = pd.DataFrame(
        results_wind["vector_metrics"], index=["Value"]
    ).T.round(4)

    export_table(
        vector_metrics_df,
        "wind_vector_metrics",
        caption="Wind vector RMSE metrics comparing ML and NWP predictions",
    )

    print("\nWind Vector RMSE Metrics:")
    display(vector_metrics_df)


Combined SAL = |S| + |A| + |L|
- Range: [0 to 6]
- 0: Perfect forecast
- Higher values indicate worse forecasts

1. Structure (S): [-2 to +2]
- Measures how well the spatial patterns match
- S = 0: Perfect structural agreement
- S > 0: Predicted patterns too large/flat
- S < 0: Predicted patterns too peaked/small

2. Amplitude (A): [-2 to +2]
- Measures the accuracy of domain-averaged values
- A = 0: Perfect amplitude match
- A > 0: Overestimation
- A < 0: Underestimation

3. Location (L): [0 to +2]
- Measures the accuracy of spatial placement
- L = 0: Perfect location match
- L increases with distance between predicted and observed centers of mass

SAL works best for:
- Fields with distinct objects/features
- Variables that can form coherent structures
- Fields with clear boundaries/gradients

In [None]:
def calculate_sal(ds_gt, ds_nwp, ds_ml, thr_factor=0.067, thr_quantile=0.9):
    """Calculate time-averaged SAL scores for precipitation (NWP and ML).

    For every index along ``time`` the SAL components (structure, amplitude,
    location) of the NWP and the ML forecast are computed against the ground
    truth. Time steps for which the computation fails are reported and
    skipped. Each component is then averaged over the scored time steps,
    ignoring NaN.

    NOTE: "Combined" is the sum of the absolute values of the *averaged*
    components, so structure/amplitude values of opposite sign at different
    time steps can cancel before they are combined.

    Args:
        ds_gt: Ground truth dataset containing ``precipitation``.
        ds_nwp: NWP forecast dataset containing ``precipitation``.
        ds_ml: ML forecast dataset containing ``precipitation``.
        thr_factor: Forwarded to ``sal``.
        thr_quantile: Forwarded to ``sal``.

    Returns:
        A pandas Styler (four decimals) of a table with one row per
        (variable, metric) and the columns ``ML`` and ``NWP``. The unstyled
        table is also handed to ``export_table``.

    Raises:
        ValueError: If ``precipitation`` is missing from any dataset, or if
            no time step could be scored.
    """
    var = "precipitation"
    if var not in ds_gt or var not in ds_nwp or var not in ds_ml:
        raise ValueError(f"Variable {var} not found in datasets.")

    # Per model: one (structure, amplitude, location) triple per scored step
    triples = {"NWP": [], "ML": []}

    for t in range(len(ds_gt.time)):
        y_true = ds_gt[var].isel(time=t).values
        forecasts = {
            "NWP": ds_nwp[var].isel(time=t).values,
            "ML": ds_ml[var].isel(time=t).values,
        }

        try:
            step = {}
            for model, y_pred in forecasts.items():
                score = sal(
                    y_pred,
                    y_true,
                    thr_factor=thr_factor,
                    thr_quantile=thr_quantile,
                )
                step[model] = (score[0], score[1], score[2])
        except Exception as e:
            # Best effort: report the failing time step and keep going
            print(f"Error calculating SAL for {var} at time {t}: {e}")
            continue

        # Append only after both models succeeded so the lists stay paired
        for model, triple in step.items():
            triples[model].append(triple)

    if not triples["NWP"] or not triples["ML"]:
        # Without this check the table reshaping below would fail with an
        # unrelated pandas error.
        raise ValueError(
            f"SAL could not be calculated for {var} at any time step."
        )

    means = {}
    for model, rows in triples.items():
        structure, amplitude, location = (
            np.nanmean(component) for component in zip(*rows)
        )
        means[model] = {
            "Structure": structure,
            "Amplitude": amplitude,
            "Location": location,
            "Combined": np.abs(structure)
            + np.abs(amplitude)
            + np.abs(location),
        }

    # Build a (variable, metric, model) indexed column, then pivot the model
    # level into the columns.
    df_dict = {
        (var, metric, model): means[model][metric]
        for metric in ("Structure", "Amplitude", "Location", "Combined")
        for model in ("NWP", "ML")
    }

    df = pd.DataFrame(df_dict, index=["Value"]).T
    df = df.unstack(level=2)
    df.columns = df.columns.droplevel(0)  # Remove 'Value' level

    export_table(
        df,
        "sal_metrics",
        caption="SAL metrics for precipitation prediction",
    )
    formatters = {col: "{:.4f}" for col in df.columns}

    return df.style.format(formatters)


# calculate_sal raises ValueError unless all three inputs contain
# precipitation, so check the datasets that are actually passed in.
if (
    "precipitation" in ds_gt_first_timestep
    and "precipitation" in ds_nwp_first_timestep
    and "precipitation" in ds_ml_first_timestep
):
    sal_metrics = calculate_sal(
        ds_gt_first_timestep, ds_nwp_first_timestep, ds_ml_first_timestep
    )
    display(sal_metrics)
else:
    print("Skipping SAL: precipitation is not available in all datasets.")

### 5.2 Time Series Across Elapsed_Forecast_Dimension

This chapter provides a detailed view of how the ML model and NWP model evolve
over time. By comparing time series data, we can identify when and where the
models diverge, offering insights into the underlying causes.



In [None]:
# Parameters for the FSS calculation, derived from the data dimensions.
max_spatial_dim = np.maximum(ds_gt.x.size, ds_gt.y.size)
# Square window of about 1 % of the larger spatial dimension, but never
# smaller than one grid point: domains with fewer than 100 points per side
# would otherwise produce a (0, 0) window.
window_size = (np.maximum(max_spatial_dim // 100, 1),) * 2
# At most SUBSAMPLE_FSS_THRESHOLD points, capped by the number of values of
# the first ground-truth variable at the first elapsed forecast duration.
n_points = int(
    np.minimum(
        SUBSAMPLE_FSS_THRESHOLD,
        ds_ml[list(VARIABLES_GROUND_TRUTH.values())[0]]
        .isel(elapsed_forecast_duration=0)
        .size,
    )
)
print(f"Using window size for FSS: {window_size}")
print(f"Using n_points for FSS: {n_points}")

In [None]:
# Event thresholds for the lead-time scores (FBI, ETS); same values as the
# thresholds declared for the first-timestep tables earlier in the notebook.
precip_thresholds = [
    0.1,
    1,
    5,
]  # mm/h
wind_thresholds = [
    2.5,
    5,
    10,
]  # m/s

In [None]:
def calculate_metrics_over_leadtimes(
    ds_gt,
    ds_nwp,
    ds_ml,
    elapsed_forecast_durations,
    variables,
    metrics_to_compute,
):
    """Run ``calculate_all_metrics`` once per elapsed forecast duration.

    For each duration the ML (and, if given, NWP) forecasts are selected at
    that duration, and the ground truth is restricted to ``variables`` and
    selected at the ML forecast times.

    Args:
        ds_gt: Ground truth dataset.
        ds_nwp: NWP forecast dataset, or None to skip NWP.
        ds_ml: ML forecast dataset.
        elapsed_forecast_durations: Durations to evaluate; each value must
            support ``.astype("timedelta64[h]")``.
        variables: Ground-truth variables to evaluate.
        metrics_to_compute: Forwarded to ``calculate_all_metrics``.

    Returns:
        Dictionary mapping each duration to the result of
        ``calculate_all_metrics`` for that duration.
    """
    results = {}
    nwp_available = ds_nwp is not None

    for efd in elapsed_forecast_durations:
        lt_hours = efd.astype("timedelta64[h]").astype(int)
        print(
            f"Calculating metrics for elapsed forecast duration {lt_hours} hours..."
        )

        # Forecasts valid at this elapsed forecast duration
        ml_at_lead = ds_ml.sel(elapsed_forecast_duration=efd)
        nwp_at_lead = None
        if nwp_available:
            nwp_at_lead = ds_nwp.sel(elapsed_forecast_duration=efd)

        # Ground truth at the times the ML forecast is valid for
        gt_at_lead = ds_gt[variables].sel(time=ml_at_lead.forecast_time)

        results[efd] = calculate_all_metrics(
            gt_at_lead,
            nwp_at_lead,
            ml_at_lead,
            metrics_to_compute=metrics_to_compute,
        )

    return results


def plot_metrics_over_leadtimes(
    metrics_by_leadtime, elapsed_forecast_durations
):
    """Create line plots showing how metrics evolve over elapsed forecast durations.

    One figure is drawn and saved per base metric and variable, with an ML
    line and, when the table has a matching "<metric> NWP" column, an NWP
    line.

    Args:
        metrics_by_leadtime: Mapping from elapsed forecast duration to a
            table whose index holds the variables and whose columns are
            named "<metric> ML" / "<metric> NWP".
        elapsed_forecast_durations: Keys of ``metrics_by_leadtime`` in
            plotting order. ``timedelta64`` values are converted to hours for
            the x-axis; other values are plotted unchanged.
    """
    # Variables and metrics are taken from the first elapsed forecast duration
    first_table = metrics_by_leadtime[elapsed_forecast_durations[0]]
    variables = first_table.index
    metrics = first_table.columns

    # Plot hours rather than raw duration values so the axis unit is known
    x_values = np.asarray(elapsed_forecast_durations)
    if np.issubdtype(x_values.dtype, np.timedelta64):
        x_values = x_values / np.timedelta64(1, "h")

    plotted_metrics = set()  # Base metrics that already have their figures

    for metric in metrics:
        if not metric.endswith(("ML", "NWP")):
            continue
        base_metric = metric.rsplit(" ", 1)[0]  # Remove ML/NWP suffix

        # "<metric> ML" and "<metric> NWP" share one figure
        if base_metric in plotted_metrics:
            continue
        plotted_metrics.add(base_metric)

        has_nwp = f"{base_metric} NWP" in metrics

        for var in variables:
            fig, ax = plt.subplots(figsize=(12, 6), dpi=DPI)

            ml_values = [
                metrics_by_leadtime[efd].loc[var, f"{base_metric} ML"]
                for efd in elapsed_forecast_durations
            ]
            ax.plot(
                x_values,
                ml_values,
                color=COLORS["ml"],
                linestyle=LINE_STYLES["ml"][0],
                marker=LINE_STYLES["ml"][1],
                label="ML",
            )

            # Plot NWP if available
            if has_nwp:
                nwp_values = [
                    metrics_by_leadtime[efd].loc[var, f"{base_metric} NWP"]
                    for efd in elapsed_forecast_durations
                ]
                ax.plot(
                    x_values,
                    nwp_values,
                    color=COLORS["nwp"],
                    linestyle=LINE_STYLES["nwp"][0],
                    marker=LINE_STYLES["nwp"][1],
                    label="NWP",
                )

            ax.set_title(
                f"{base_metric} Over Elapsed Forecast Durations for {var}"
            )
            ax.set_xlabel("Elapsed Forecast Duration [h]")
            # Omit the brackets when no unit is known for the variable
            units = VARIABLE_UNITS.get(var, "")
            ax.set_ylabel(f"{base_metric} [{units}]" if units else base_metric)
            ax.grid(True)
            ax.legend()
            plt.tight_layout()
            save_plot(fig, f"leadtime_{base_metric}_{var}")
            plt.show()


# Compute the lead-time metrics on a temporary local Dask cluster. Cluster
# and client are used as context managers, so both are closed on exit.
with (
    LocalCluster(
        n_workers=16, threads_per_worker=1, memory_limit="32GB"
    ) as cluster,
    Client(cluster) as client,
):
    print(f"Dask dashboard available at: {client.dashboard_link}")

    # Variables that are available from the NWP model
    metrics_by_leadtime_nwp = calculate_metrics_over_leadtimes(
        ds_gt=ds_gt,
        ds_nwp=ds_nwp,
        ds_ml=ds_ml,
        elapsed_forecast_durations=ds_ml.elapsed_forecast_duration.values,
        variables=list(VARIABLES_NWP.values()),
        metrics_to_compute=["RMSE", "FSS"],
    )

    # Disabled: the same evaluation for ML-only variables
    # ml_only_vars = [
    #     var
    #     for var in VARIABLES_ML.values()
    #     if var not in VARIABLES_NWP.values()
    # ]
    # metrics_by_leadtime_ml = calculate_metrics_over_leadtimes(
    #     ds_gt,
    #     None,
    #     ds_ml,
    #     ds_ml.elapsed_forecast_duration.values,
    #     ml_only_vars,
    #     metrics_to_compute=["RMSE", "FSS"],
    # )

# One figure per metric and variable, after the cluster has been closed
print("\nPlotting metrics for NWP variables:")
plot_metrics_over_leadtimes(
    metrics_by_leadtime=metrics_by_leadtime_nwp,
    elapsed_forecast_durations=ds_ml.elapsed_forecast_duration.values,
)

# Disabled: plots for the ML-only variables
# print("\nPlotting metrics for ML-only variables:")
# plot_metrics_over_leadtimes(
#     metrics_by_leadtime_ml, ds_ml.elapsed_forecast_duration.values
# )

In [None]:
def _forecast_values_by_model(ds_ml_lead, ds_nwp_lead, var):
    """Return ``{"ML": values}`` plus ``"NWP"`` when NWP provides ``var``."""
    forecasts = {"ML": ds_ml_lead[var].values}
    # NWP is optional, both as a whole and per variable
    if ds_nwp_lead is not None and var in ds_nwp_lead:
        forecasts["NWP"] = ds_nwp_lead[var].values
    return forecasts


def _thresholded_scores(y_true, forecasts, thresholds, unit):
    """Return MAE, FBI and ETS per threshold and model.

    Args:
        y_true: Ground truth values.
        forecasts: Mapping from model label to forecast values.
        thresholds: Event thresholds to evaluate.
        unit: Suffix appended to each threshold to form its result key.

    Returns:
        ``{f"{thr}{unit}": {model: {"MAE": ..., "FBI": ..., "ETS": ...}}}``
    """
    # MAE does not depend on the threshold: compute it once per model
    mae_by_model = {
        model: mean_absolute_error(y_true, y_pred)
        for model, y_pred in forecasts.items()
    }

    scores = {}
    for thr in thresholds:
        scores[f"{thr}{unit}"] = {
            model: {
                "MAE": mae_by_model[model],
                "FBI": frequency_bias(y_true, y_pred, thr),
                "ETS": ets_for_threshold(y_true, y_pred, thr),
            }
            for model, y_pred in forecasts.items()
        }
    return scores


def calculate_meteoswiss_metrics_over_leadtimes(
    ds_gt, ds_ml, ds_nwp, elapsed_forecast_durations
):
    """Calculate MeteoSwiss verification metrics for each elapsed forecast duration.

    Reads the module-level ``precip_thresholds``, ``wind_thresholds`` and
    ``vars_stdev``.

    Args:
        ds_gt: Ground truth dataset
        ds_ml: ML model predictions
        ds_nwp: NWP predictions (optional, may be None)
        elapsed_forecast_durations: Array of forecast durations

    Returns:
        Dictionary containing metrics for each lead time, structured as:
        - ``["precipitation"][f"{thr}mm/h"][model]`` -> MAE, FBI, ETS
        - ``["wind"][var][f"{thr}m/s"][model]`` -> MAE, FBI, ETS
        - ``["radiation_temp"][var]["ML"]`` -> MAE, STDEV
        where ``model`` is "ML" and, if NWP provides the variable, "NWP".
    """
    metrics_by_leadtime = {}

    for efd in elapsed_forecast_durations:
        # Select data for current lead time
        ds_ml_lead = ds_ml.sel(elapsed_forecast_duration=efd)
        ds_nwp_lead = (
            ds_nwp.sel(elapsed_forecast_duration=efd)
            if ds_nwp is not None
            else None
        )
        ds_gt_lead = ds_gt.sel(time=ds_ml_lead.forecast_time)

        lead_metrics = {}

        # --- Group 1: Precipitation Metrics ---
        if "precipitation" in ds_gt_lead:
            lead_metrics["precipitation"] = _thresholded_scores(
                ds_gt_lead["precipitation"].values,
                _forecast_values_by_model(
                    ds_ml_lead, ds_nwp_lead, "precipitation"
                ),
                precip_thresholds,
                "mm/h",
            )

        # --- Group 2: Wind Metrics ---
        if "wind_u_10m" in ds_gt_lead and "wind_v_10m" in ds_gt_lead:
            lead_metrics["wind"] = {
                var: _thresholded_scores(
                    ds_gt_lead[var].values,
                    _forecast_values_by_model(ds_ml_lead, ds_nwp_lead, var),
                    wind_thresholds,
                    "m/s",
                )
                for var in ["wind_u_10m", "wind_v_10m"]
            }

        # --- Group 3: Radiation and Temperature Metrics (ML only) ---
        lead_metrics["radiation_temp"] = {}
        for var in vars_stdev:
            if var in ds_gt_lead and var in ds_ml_lead:
                y_true = ds_gt_lead[var].values
                y_ml = ds_ml_lead[var].values

                lead_metrics["radiation_temp"][var] = {
                    "ML": {
                        "MAE": mean_absolute_error(y_true, y_ml),
                        "STDEV": stdev_error(y_true, y_ml),
                    }
                }

        metrics_by_leadtime[efd] = lead_metrics

    return metrics_by_leadtime


def _plot_model_lines(ax, x_values, groups, keys, metric):
    """Draw one ML line and, if present, one NWP line per threshold key.

    Args:
        ax: Axes to draw on.
        x_values: x-coordinates, one per lead time.
        groups: One mapping per lead time of the form
            ``{key: {"ML": {metric: value}, "NWP": {metric: value}}}``.
        keys: Threshold keys to plot, in legend order.
        metric: Name of the metric to plot.
    """
    for key in keys:
        ax.plot(
            x_values,
            [group[key]["ML"][metric] for group in groups],
            linestyle=LINE_STYLES["ml"][0],
            marker=LINE_STYLES["ml"][1],
            label=f"ML {key}",
        )

        # NWP is optional; its presence at the first lead time decides
        # whether the line is drawn.
        if key in groups[0] and "NWP" in groups[0][key]:
            ax.plot(
                x_values,
                [group[key]["NWP"][metric] for group in groups],
                linestyle=LINE_STYLES["nwp"][0],
                marker=LINE_STYLES["nwp"][1],
                label=f"NWP {key}",
            )


def _finalize_leadtime_figure(fig, ax, title, ylabel, filename):
    """Label, save and show one lead-time figure."""
    ax.set_title(title)
    ax.set_xlabel("Elapsed Forecast Duration [h]")
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.legend()
    plt.tight_layout()
    save_plot(fig, filename)
    plt.show()


def plot_meteoswiss_verification_metrics_over_leadtimes(
    metrics_by_leadtime, elapsed_forecast_durations
):
    """Create separate line plots for MeteoSwiss verification metrics over lead times.

    One figure is drawn per metric for precipitation, per wind component and
    metric for wind, and per variable and metric for radiation/temperature.
    Reads the module-level ``precip_thresholds`` and ``wind_thresholds``.

    Args:
        metrics_by_leadtime: Dictionary containing metrics for each lead time
        elapsed_forecast_durations: Keys of ``metrics_by_leadtime`` in
            plotting order. ``timedelta64`` values are converted to hours for
            the x-axis; other values are assumed to be hours already.
    """
    first = metrics_by_leadtime[elapsed_forecast_durations[0]]

    # The x-axis is labelled in hours, so plot hours rather than raw
    # duration values.
    x_values = np.asarray(elapsed_forecast_durations)
    if np.issubdtype(x_values.dtype, np.timedelta64):
        x_values = x_values / np.timedelta64(1, "h")

    # Plot precipitation metrics if data exists
    if "precipitation" in first:
        groups = [
            metrics_by_leadtime[efd]["precipitation"]
            for efd in elapsed_forecast_durations
        ]
        keys = [f"{thr}mm/h" for thr in precip_thresholds]
        for metric in ["MAE", "FBI", "ETS"]:
            fig = plt.figure(figsize=(10, 6))
            ax = fig.add_subplot(111)
            _plot_model_lines(ax, x_values, groups, keys, metric)
            _finalize_leadtime_figure(
                fig,
                ax,
                title=f"Precipitation {metric}",
                ylabel=metric,
                filename=f"leadtime_precip_{metric}",
            )

    # Plot wind metrics if data exists
    if "wind" in first:
        keys = [f"{thr}m/s" for thr in wind_thresholds]
        for var in ["wind_u_10m", "wind_v_10m"]:
            if var not in first["wind"]:
                continue
            groups = [
                metrics_by_leadtime[efd]["wind"][var]
                for efd in elapsed_forecast_durations
            ]
            for metric in ["MAE", "FBI", "ETS"]:
                fig = plt.figure(figsize=(10, 6))
                ax = fig.add_subplot(111)
                _plot_model_lines(ax, x_values, groups, keys, metric)
                _finalize_leadtime_figure(
                    fig,
                    ax,
                    title=f"{var} {metric}",
                    ylabel=metric,
                    filename=f"leadtime_wind_{var}_{metric}",
                )

    # Plot radiation and temperature metrics (ML only) if data exists.
    # The variables are taken from the metrics themselves, so no
    # module-level variable list is needed here.
    if "radiation_temp" in first:
        for var in first["radiation_temp"]:
            for metric in ["MAE", "STDEV"]:
                fig = plt.figure(figsize=(10, 6))
                ax = fig.add_subplot(111)

                ml_values = [
                    metrics_by_leadtime[efd]["radiation_temp"][var]["ML"][
                        metric
                    ]
                    for efd in elapsed_forecast_durations
                ]
                ax.plot(
                    x_values,
                    ml_values,
                    linestyle=LINE_STYLES["ml"][0],
                    marker=LINE_STYLES["ml"][1],
                    label="ML",
                )

                _finalize_leadtime_figure(
                    fig,
                    ax,
                    title=f"{var} {metric}",
                    ylabel=f"{metric} [{VARIABLE_UNITS.get(var, '')}]",
                    filename=f"leadtime_{var}_{metric}",
                )


# MeteoSwiss verification metrics for every elapsed forecast duration
meteoswiss_metrics = calculate_meteoswiss_metrics_over_leadtimes(
    ds_gt=ds_gt,
    ds_ml=ds_ml,
    ds_nwp=ds_nwp,
    elapsed_forecast_durations=ds_ml.elapsed_forecast_duration.values,
)

# One figure per metric (and variable) over the same durations
plot_meteoswiss_verification_metrics_over_leadtimes(
    metrics_by_leadtime=meteoswiss_metrics,
    elapsed_forecast_durations=ds_ml.elapsed_forecast_duration.values,
)


The following two cells are run only for the first variable; otherwise, if the user is not careful, they can take a long time to run. To remove this limitation, simply remove the `list()` and `[0]` from the variable names in the plotting functions.

In [None]:
# Start time to plot; None is passed through unchanged when PLOT_TIME is unset
time_selected = (
    None
    if PLOT_TIME is None
    else ds_ml.sel(start_time=pd.to_datetime(PLOT_TIME)).start_time
)

# Comparison maps of one variable for every elapsed forecast duration
for elapsed_forecast_duration in ds_ml.elapsed_forecast_duration:
    create_comparison_maps(
        # SELECT VARIABLE HERE
        var=list(VARIABLES_GROUND_TRUTH.values())[0],
        plot_time=time_selected,
        elapsed_forecast_duration=elapsed_forecast_duration,
        ds_gt=ds_gt,
        ds_ml=ds_ml,
        ds_nwp=ds_nwp,
        ds_boundary=ds_boundary,
    )


In [None]:
# Start time to plot; None is passed through unchanged when PLOT_TIME is unset
time_selected = (
    None
    if PLOT_TIME is None
    else ds_ml.sel(start_time=pd.to_datetime(PLOT_TIME)).start_time
)

# Error maps of one variable for every elapsed forecast duration
for elapsed_forecast_duration in ds_ml.elapsed_forecast_duration:
    create_error_maps(
        # SELECT VARIABLE HERE
        var=list(VARIABLES_GROUND_TRUTH.values())[0],
        plot_time=time_selected,
        elapsed_forecast_duration=elapsed_forecast_duration,
        ds_gt=ds_gt,
        ds_ml=ds_ml,
        ds_nwp=ds_nwp,
    )
