# Plot pixel-by-pixel CSI & RMSE

In [1]:
import argparse
from pathlib import Path
import xarray as xr

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pysteps.visualization.spectral import plot_spectrum1d
import geopandas as gpd
from matplotlib.collections import LineCollection
from matplotlib import colors, cm, gridspec, ticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from copy import copy
import cmcrameri  # noqa
import string


Pysteps configuration file found at: /home/users/ritvanen/conda/envs/jupyter/lib/python3.10/site-packages/pysteps/pystepsrc



In [2]:
from addict import Dict
import yaml


def load_yaml_config(path: str):
    """
    Load a YAML config file as an attribute-dictionary.

    The file is read as UTF-8 so that the result does not depend on the
    locale of the machine running the notebook.

    Args:
        path (str): Path to the YAML config file.

    Returns:
        Dict: Configuration loaded.
    """
    with open(path, "r", encoding="utf-8") as f:
        config = Dict(yaml.safe_load(f))
    return config


def save_figs(fig, outpath, name, extensions, subfolder=None):
    """Save a figure once per file extension and close it.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save. It is closed after saving.
    outpath : str or pathlib.Path
        Base output directory.
    name : str
        File name without extension.
    extensions : iterable of str
        File extensions (without the leading dot); one file is written per
        extension as ``{name}.{ext}``.
    subfolder : str, optional
        Subdirectory of `outpath` to write into, by default None.
    """
    outdir = Path(outpath)
    if subfolder:
        outdir = outdir / subfolder
    # Always make sure the target exists, not only when a subfolder is used.
    outdir.mkdir(parents=True, exist_ok=True)
    for ext in extensions:
        fig.savefig(outdir / f"{name}.{ext}", bbox_inches="tight")
    plt.close(fig)

def load_metrics(path, metric_name, timestep=5):
    """Load metrics from netCDF files and return as xr.Dataset.

    Parameters
    ----------
    path : str
        Path to directory containing the netCDF files.
    metric_name : str
        Metric name that is used to match the files with glob
        `*{metric_name}*.nc`.
    timestep : int, optional
        The leadtime timestep in minutes, by default 5. Applied to the `leadtime`
        coordinate.

    Returns
    -------
    xarray.Dataset
        Dataset containing the metrics.

    Raises
    ------
    FileNotFoundError
        If no file in `path` matches the glob pattern.
    """
    files = sorted(Path(path).glob(f"*{metric_name}*.nc"))
    if not files:
        # Fail early with a clear message instead of an obscure error on the
        # missing `leadtime` coordinate of an empty dataset.
        raise FileNotFoundError(
            f"No netCDF files matching '*{metric_name}*.nc' found in {path}"
        )

    try:
        # First try one DataArray per file, keyed by the array name.
        das = [xr.open_dataarray(p) for p in files]
        ds = xr.Dataset(data_vars={arr.name: arr for arr in das})
    except ValueError:
        # Fall back to opening all files as one multi-file dataset
        # (presumably for files holding several variables -- TODO confirm).
        ds = xr.open_mfdataset(files)

    # Change leadtime to minutes
    ds = ds.assign_coords(leadtime=ds.leadtime * timestep)

    return ds

def set_ax(ax, score_conf, leadtime_limits, leadtime_locator_multiples=(15, 5)):
    """Set axis limits and ticks.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to configure.
    score_conf : mapping
        Per-score settings. Keys read here:

        - ``"limits"``: ``None`` to autoscale the y-axis, otherwise the
          arguments for ``ax.set_ylim``. With ``"log_scale"`` the two values
          are used as base-10 exponents.
        - ``"ticks"``: three values are passed to ``np.arange`` to place the
          y-ticks; two values are used as the major and minor tick multiples.
        - ``"log_scale"`` (optional): if truthy, the y-axis is logarithmic.
    leadtime_limits : sequence of two numbers
        First and last leadtime; used as x-limits and added as x-ticks.
    leadtime_locator_multiples : sequence of two numbers, optional
        Major and minor x-tick multiples, by default (15, 5).
    """
    limits = score_conf["limits"]
    ticks = score_conf["ticks"]

    if limits is not None:
        ax.set_ylim(*limits)
    else:
        ax.autoscale(enable=True, axis="y", tight=True)

    if ticks and len(ticks) == 3:
        ax.set_yticks(np.arange(*ticks))
    elif ticks and len(ticks) == 2:
        ax.yaxis.set_major_locator(plt.MultipleLocator(ticks[0]))
        ax.yaxis.set_minor_locator(plt.MultipleLocator(ticks[1]))

    if score_conf.get("log_scale"):
        if limits is not None:
            # Limits are given as exponents for the logarithmic axis.
            ax.set_ylim([10 ** limits[0], 10 ** limits[1]])
        else:
            ax.autoscale(enable=True, axis="y", tight=True)

        ax.set_yscale("log")
        ax.yaxis.set_major_locator(plt.LogLocator(base=10.0, numticks=15))
        ax.yaxis.set_minor_locator(plt.NullLocator())

    ax.xaxis.set_major_locator(plt.MultipleLocator(leadtime_locator_multiples[0]))
    ax.xaxis.set_minor_locator(plt.MultipleLocator(leadtime_locator_multiples[1]))

    # Add first and last leadtime tick labels. The limits are copied into a
    # list so that tuples (or other sequences) are accepted as well.
    ax.set_xticks(list(ax.get_xticks()) + list(leadtime_limits))

    ax.set_xlim(*leadtime_limits)
    ax.set_xlabel("Leadtime [min]")

# Mathtext strings for the units, looked up by the unit key of the config.
UNIT_STRINGS = dict(
    mmh=r"$\mathrm{mm\,h}^{-1}$",
    dbz=r"$\mathrm{dBZ}$",
    meters=r"$\mathrm{m}$",
)

# Lowercase letters, indexed to build the "(a)", "(b)", ... panel labels.
alphabet = string.ascii_lowercase


def nested_list_to_tuple(lst):
    """Return `lst` as a tuple, converting lists nested in it recursively.

    Only items that are `list` instances are converted; every other item
    (including tuples and their contents) is kept as it is.
    """
    converted = []
    for item in lst:
        if isinstance(item, list):
            item = nested_list_to_tuple(item)
        converted.append(item)
    return tuple(converted)

In [3]:
config_path = "/home/users/ritvanen/koodaus/cell-tracking-article-code/config/swiss-data/plot_metrics.yaml"
config = load_yaml_config(config_path)

# Linestyles given as (nested) lists are turned into (nested) tuples before
# they are handed to the plotting calls.
for method in config.methods:
    method_conf = config.methods[method]
    if isinstance(method_conf.linestyle, list):
        method_conf.linestyle = nested_list_to_tuple(method_conf.linestyle)

# Result and figure directories are templated with the experiment id.
exp_id = config.exp_id
result_dir = config.path.result_dir.format(id=exp_id)
save_dir = Path(config.path.save_dir.format(id=exp_id))
save_dir.mkdir(parents=True, exist_ok=True)

if config.stylefile is not None:
    plt.style.use(config.stylefile)

# Categorical ("CAT") and continuous ("CONT") metrics are stored separately.
ds_cat = load_metrics(result_dir, "CAT")
ds_cont = load_metrics(result_dir, "CONT")

In [15]:
ds_cat

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.12 kiB 4.12 kiB Shape (4, 11, 12) (4, 11, 12) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",12  11  4,

Unnamed: 0,Array,Chunk
Bytes,4.12 kiB,4.12 kiB
Shape,"(4, 11, 12)","(4, 11, 12)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [16]:
ds_cont

In [4]:
# Use the configured legend order, falling back to the order of
# `config.methods` when none is set. The truthiness test also covers an empty
# value (NOTE(review): the attribute-dict may return an empty mapping rather
# than None for a missing key -- confirm). Previously `legend_order` was only
# assigned in the None case, so a configured order led to a NameError.
legend_order = config.legend_order
if not legend_order:
    legend_order = config.methods

# Legend labels in the requested order, restricted to models with results.
legend_label_order = [
    config.methods[model]["label"] for model in legend_order if model in ds_cat
]

In [6]:
# Plot CSI, RMSE
ncols = 2
nrows = 1

thr = 4.6

fig = plt.figure(
    figsize=(config.figures.col_width * ncols, config.figures.row_height * nrows),
    constrained_layout=True,
)
subfigs = fig.subfigures(nrows=nrows, ncols=ncols, squeeze=True)


def _plot_metric_panel(subfig, ds, metric_dim, metric, thr, panel_index):
    """Draw one metric-vs-leadtime panel with one line per configured method.

    Parameters
    ----------
    subfig : matplotlib.figure.SubFigure
        Subfigure that receives a single new axes.
    ds : xarray.Dataset
        Dataset with one variable per method.
    metric_dim : str
        Name of the dimension that `metric` is selected along.
    metric : str
        Metric name; also the key into `config.metric_conf`.
    thr : float
        Threshold to select; shown in the title when finite.
    panel_index : int
        Index into `alphabet` for the optional panel label.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    metric_conf = config.metric_conf[metric]
    ax = subfig.subplots(nrows=1, ncols=1, squeeze=True, sharey=False)
    for model in config.methods:
        ds[model].sel({metric_dim: metric, "threshold": thr}).plot.line(
            ax=ax,
            c=config.methods[model]["color"],
            label=config.methods[model]["label"],
            linestyle=config.methods[model]["linestyle"],
        )
    set_ax(
        ax,
        metric_conf,
        config.leadtime_limits,
        config.leadtime_locator_multiples,
    )
    ax.set_ylabel(metric_conf["label"])
    ax.grid(which="both", axis="both")

    panel_label = f"({alphabet[panel_index]}) " if config.write_panel_labels else ""
    if np.isfinite(thr):
        full_name = metric_conf["full_name"]
        unit = UNIT_STRINGS[config.unit]
        # Raw f-string: the backslashes belong to the mathtext markup and
        # must not be read as (invalid) Python escape sequences.
        ax.set_title(
            rf"{panel_label} {full_name} ($\mathrm{{R}}_\mathrm{{thr}} = {thr:.1f}~${unit})"
        )
    else:
        ax.set_title(f"{panel_label}No threshold")

    # Build the legend once, with the entries in the configured order.
    handles, labels = ax.get_legend_handles_labels()
    order = [labels.index(name) for name in legend_label_order]
    ax.legend(
        [handles[idx] for idx in order],
        [labels[idx] for idx in order],
        bbox_to_anchor=(0.0, 0.7, 1.0, 0.3),
    )
    return ax


_plot_metric_panel(subfigs[0], ds_cat, "cat_metric", "CSI", thr, 0)
_plot_metric_panel(subfigs[1], ds_cont, "cont_metric", "RMSE", thr, 1)

outputname = "pixel_csi_rmse"
save_figs(fig, save_dir, outputname, config.output_formats)

In [19]:
# Plot CSI for several thresholds, one panel per threshold
ncols = 2
nrows = 3

thrs = [1.0, 5.0, 10.0, 20.0, 30.0, 50.0]

fig = plt.figure(
    figsize=(config.figures.col_width * ncols, config.figures.row_height * nrows),
    constrained_layout=True,
)
subfigs = fig.subfigures(
    nrows=nrows,
    ncols=ncols,
    squeeze=False,
)

metric = "CSI"
# Panels are filled in flattened order of the subfigure grid.
panels = subfigs.flatten()
for i, thr in enumerate(thrs):
    axs = panels[i].subplots(nrows=1, ncols=1, squeeze=True, sharey=False)
    for model in config.methods:
        ds_cat[model].sel(cat_metric=metric, threshold=thr).plot.line(
            ax=axs,
            c=config.methods[model]["color"],
            label=config.methods[model]["label"],
            linestyle=config.methods[model]["linestyle"],
        )
    set_ax(
        axs,
        config.metric_conf[metric],
        config.leadtime_limits,
        config.leadtime_locator_multiples,
    )
    axs.set_ylabel(config.metric_conf[metric]["label"])
    axs.grid(which="both", axis="both")

    if config.write_panel_labels:
        label = f"({alphabet[i]}) "
    else:
        label = ""
    if np.isfinite(thr):
        # Raw f-string: the backslashes belong to the mathtext markup and
        # must not be read as (invalid) Python escape sequences.
        axs.set_title(
            rf"{label} {config.metric_conf[metric]['full_name']} ($\mathrm{{R}}_\mathrm{{thr}} = {thr:.1f}~${UNIT_STRINGS[config.unit]})"
        )
    else:
        axs.set_title(f"{label}No threshold")

    # Build the legend once, with the entries in the configured order.
    handles, labels = axs.get_legend_handles_labels()
    order = [labels.index(name) for name in legend_label_order]
    axs.legend(
        [handles[idx] for idx in order],
        [labels[idx] for idx in order],
        bbox_to_anchor=(0.0, 0.7, 1.0, 0.3),
    )

outputname = "pixel_csi_supplementary"
save_figs(fig, save_dir, outputname, config.output_formats)

In [20]:
# Plot RMSE for several thresholds, one panel per threshold.
# NOTE: `ncols`, `nrows` and `thrs` are reused from the CSI supplementary cell.
fig = plt.figure(
    figsize=(config.figures.col_width * ncols, config.figures.row_height * nrows),
    constrained_layout=True,
)
subfigs = fig.subfigures(
    nrows=nrows,
    ncols=ncols,
    squeeze=False,
)

metric = "RMSE"
# Panels are filled in flattened order of the subfigure grid.
panels = subfigs.flatten()
for i, thr in enumerate(thrs):
    axs = panels[i].subplots(nrows=1, ncols=1, squeeze=True, sharey=False)
    for model in config.methods:
        ds_cont[model].sel(cont_metric=metric, threshold=thr).plot.line(
            ax=axs,
            c=config.methods[model]["color"],
            label=config.methods[model]["label"],
            linestyle=config.methods[model]["linestyle"],
        )
    set_ax(
        axs,
        config.metric_conf[metric],
        config.leadtime_limits,
        config.leadtime_locator_multiples,
    )
    axs.set_ylabel(config.metric_conf[metric]["label"])
    axs.grid(which="both", axis="both")

    if config.write_panel_labels:
        label = f"({alphabet[i]}) "
    else:
        label = ""
    if np.isfinite(thr):
        # Raw f-string: the backslashes belong to the mathtext markup and
        # must not be read as (invalid) Python escape sequences.
        axs.set_title(
            rf"{label} {config.metric_conf[metric]['full_name']} ($\mathrm{{R}}_\mathrm{{thr}} = {thr:.1f}~${UNIT_STRINGS[config.unit]})"
        )
    else:
        axs.set_title(f"{label}No threshold")

    # Build the legend once, with the entries in the configured order.
    handles, labels = axs.get_legend_handles_labels()
    order = [labels.index(name) for name in legend_label_order]
    axs.legend(
        [handles[idx] for idx in order],
        [labels[idx] for idx in order],
        bbox_to_anchor=(0.0, 0.7, 1.0, 0.3),
    )

outputname = "pixel_rmse_supplementary"
save_figs(fig, save_dir, outputname, config.output_formats)