# Supplementary figures for the manuscript "Cell-based estimation of nowcast model skill for reproducing growth and decay of convective rainfall"


## Imports


In [1]:
import argparse
from pathlib import Path
import xarray as xr

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pysteps.visualization.spectral import plot_spectrum1d
import geopandas as gpd
from matplotlib.collections import LineCollection
from matplotlib import colors, cm, gridspec, ticker, patches
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from copy import copy
import cmcrameri  # noqa
import palettable  # noqa
import textwrap
import string
import pandas as pd
import xskillscore as xs
import matplotlib.patches as patches
from matplotlib.legend_handler import HandlerTuple
from flox.xarray import xarray_reduce

# Lowercase letters, indexed to build the "(a)", "(b)", ... prefixes of panel titles.
alphabet = string.ascii_lowercase

import seaborn as sns


Pysteps configuration file found at: /scratch/jritvane/miniconda3/envs/jupyter/lib/python3.11/site-packages/pysteps/pystepsrc



## General definitions for style etc.


In [2]:
# --- Line / marker property bundles -----------------------------------------
MEDIANPROPS = {"linestyle": "-", "linewidth": 2, "color": "k"}
MEANLINEPROPS = {"linestyle": (0, (1, 0.4)), "linewidth": 2, "color": "k"}
FLIERPROPS = {
    "marker": "o",
    "markersize": 0.3,
    "markerfacecolor": "gray",
    "markeredgecolor": "gray",
    "rasterized": True,
}
ZEROLINE_PROPS = {"linestyle": "--", "linewidth": 1.5, "color": "gray"}

# --- Panel titles per cell-state group ---------------------------------------
STATE_GROUP_TITLES = dict(
    growth="Growing cells",
    decay="Decaying cells",
    all="All cells",
)

# Colormap name used for hue-encoded bars.
HUE_CMAP = "cmc.hawaii_r"

PLOT_EXT = "png"

# --- Axis limits as (lower, upper) pairs --------------------------------------
MAX_RR_LIMITS = (-125, 125)
MEAN_RR_LIMITS = (-20, 20)
SUM_RR_LIMITS = (-10, 20)
LIFETIME_LIMITS = (-12, 12)
AREA_LIMITS = (-600, 1000)
COUNT_LIMITS = (0, 20000)
CENTROID_DISTANCE_LIMITS = (0, 30)

# --- Tick spacing for the quantities above ------------------------------------
MAX_RR_TICK_MULTIPLE = 25
MEAN_RR_TICK_MULTIPLE = 5
SUM_RR_TICK_MULTIPLE = 5
LIFETIME_TICK_MULTIPLE = 2
AREA_TICK_MULTIPLE = 200
COUNT_TICK_MULTIPLE = 1000

# --- Axis titles ---------------------------------------------------------------
MAX_RR_DIFF_TITLE = "Difference in maximum rainfall rate [mm h$^{-1}$]"
MEAN_RR_DIFF_TITLE = "Difference in mean rainfall rate [mm h$^{-1}$]"
SUM_RR_DIFF_TITLE = "Difference in volume rain rate [10$^6$ m$^3$h$^{-1}$]"
LIFETIME_TITLE = "Difference in lifetime [min]"
AREA_TITLE = "Difference in area [km$^2$]"
COUNT_TITLE = "Cell count [10$^3$]"
CENTROID_DISTANCE_TITLE = "Centroid distance [km]"

METHOD_X_LABEL = "Model"

# --- Figure geometry -----------------------------------------------------------
W_PER_METHOD_LT = 1.2
W_PER_METHOD_S = 0.8

FIG_HEIGHT = 6
FIG_WIDTH = 6

HIST_FIG_H = 2.5
HIST_FIG_W = 3

# Threshold on maximum rain rate, intended for cutting away saturated values.
MAX_RR_LIMIT = 122

# Tolerance around zero when comparing volume rain rate differences.
SUM_RR_ZERO_TOL = 0.0


def leadtime_to_minutes(x, pos):
    """Render a lead-time value, scaled by 5, as a string without decimals.

    Args:
        x: Lead time to convert; it is multiplied by 5 before formatting
            (presumably 5-minute time steps; verify against the data).
        pos: Unused. The two-argument ``(value, position)`` form is kept so
            the function can be called the same way everywhere.

    Returns:
        str: ``x * 5`` formatted with the ``.0f`` format specification.
    """
    minutes = x * 5
    return format(minutes, ".0f")


# Apply the matplotlib style sheet for the object-based figures.
# The path is relative, so it is resolved against the current working directory.
plt.style.use(
    "config/stylefiles/object_figs_article.mplstyle"
)


In [3]:
from addict import Dict
import yaml


def load_yaml_config(path: str):
    """
    Read a YAML configuration file into an attribute-style dictionary.

    Args:
        path (str): Location of the YAML file to read.

    Returns:
        Dict: The parsed content wrapped in an ``addict.Dict``.
    """
    with open(path, "r") as stream:
        parsed = yaml.safe_load(stream)
    return Dict(parsed)


def save_figs(fig, outpath, name, extensions, subfolder=None):
    """Save a figure once per file extension and then close it.

    Args:
        fig: Matplotlib figure to write.
        outpath: Target directory, given as a ``str`` or ``pathlib.Path``.
        name: File name without extension.
        extensions: Iterable of file extensions (without the dot). A single
            string such as ``"png"`` is treated as one extension rather than
            being iterated character by character.
        subfolder: Optional sub-directory of ``outpath`` to write into.

    The target directory is created if it does not exist yet, whether or not
    ``subfolder`` is given. Files are written with ``bbox_inches="tight"``.
    """
    outpath = Path(outpath)
    if subfolder:
        outpath = outpath / subfolder
    # Create the directory in every case, not only when a subfolder is requested.
    outpath.mkdir(parents=True, exist_ok=True)

    if isinstance(extensions, str):
        extensions = [extensions]
    for ext in extensions:
        fig.savefig(outpath / f"{name}.{ext}", bbox_inches="tight")
    plt.close(fig)


def create_fig_leadtime_groups(ngroups, nmethods):
    """Create a single row of panels, one column per group, sharing the y-axis.

    Args:
        ngroups: Number of columns; also scales the figure width.
        nmethods: Accepted but not used by the current implementation.

    Returns:
        The ``(figure, axes)`` result of ``plt.subplots`` for this layout.
    """
    subplot_kwargs = {
        "nrows": 1,
        "ncols": ngroups,
        "figsize": (FIG_WIDTH * ngroups, FIG_HEIGHT),
        "constrained_layout": True,
        "sharey": True,
        "squeeze": True,
    }
    return plt.subplots(**subplot_kwargs)

def create_fig_hist(ngroups):
    """Create a single row of histogram panels, one per group, sharing the y-axis.

    Args:
        ngroups: Number of columns; also scales the figure width.

    Returns:
        The ``(figure, axes)`` result of ``plt.subplots`` for this layout.
    """
    figure_size = (HIST_FIG_W * ngroups, HIST_FIG_H)
    return plt.subplots(
        nrows=1,
        ncols=ngroups,
        figsize=figure_size,
        sharey=True,
        constrained_layout=True,
    )


def plot_obs_counts(obs_counts, axs):
    """Add an extra group of bars, labelled "Target", to an existing bar plot.

    One bar is drawn per entry of ``obs_counts.values``. Each new bar copies
    its edge colour, line width and face colour from the last bar of the
    container with the same index, and its width from the last bar of the
    last container.

    Args:
        obs_counts: Object exposing ``.values`` with one bar height per entry
            (presumably observed cell counts; verify against the caller).
        axs: A single Matplotlib axes (despite the plural name) that already
            holds bar containers in ``axs.containers``.
    """
    # Left edge of the first existing bar, shifted one unit to the left.
    start = axs.containers[0].get_children()[0].xy[0] - 1
    for i, val in enumerate(obs_counts.values):
        # NOTE(review): axs.bar presumably appends a container on every call, so
        # containers[-1] changes between iterations - confirm the width lookup
        # still refers to a bar of the intended width.
        axs.bar(
            start, val, width=axs.containers[-1].get_children()[-1].get_width(),
            align="edge",
            edgecolor=axs.containers[i].get_children()[-1].get_edgecolor(),
            linewidth=axs.containers[i].get_children()[-1].get_linewidth(),
            color=axs.containers[i].get_children()[-1].get_facecolor(),
        )
        # Advance by one bar width so the next bar is placed right beside this one.
        start += axs.containers[-1].get_children()[-1].get_width()
    # Append a tick at x = -1 for the new group.
    xt = axs.get_xticks()
    xt = np.append(xt, -1)

    axs.set_xticks(xt)
    # Replace the label of the appended tick; all other labels are kept.
    xtl = axs.get_xticklabels()
    xtl[-1] = "Target"
    axs.set_xticklabels(xtl)


def get_labelstr(method, width=10):
    """Return the configured display label of a method, wrapped to ``width``.

    The label is read from the notebook-level ``conf.methods[method].label``
    and wrapped with ``textwrap.fill``.

    Args:
        method: Key of the method in ``conf.methods``.
        width: Maximum line width passed to ``textwrap.fill``.

    Returns:
        The wrapped label, or ``method`` unchanged when the label cannot be
        looked up or wrapped.
    """
    try:
        label = textwrap.fill(conf.methods[method].label, width)
    except Exception:
        # Best-effort fallback to the raw method name. ``Exception`` instead of
        # a bare ``except`` keeps KeyboardInterrupt/SystemExit propagating.
        label = method
    return label


In [4]:
def set_ax(ax, score_conf, leadtime_limits, leadtime_locator_multiples=(15, 5)):
    """Set axis limits and ticks.

    Args:
        ax: Matplotlib axes to configure.
        score_conf: Mapping with the keys ``"limits"`` (``None`` or a pair of
            y-limits) and ``"ticks"`` (three values passed to ``np.arange``, or
            two values used as major/minor tick multiples). An optional truthy
            ``"log_scale"`` entry, read with ``.get``, switches the y-axis to a
            log scale; in that case the limits are used as exponents of 10.
        leadtime_limits: Two x-axis limits, as a list or a tuple; both values
            are also added as x ticks.
        leadtime_locator_multiples: Major and minor tick multiples for the
            x-axis.
    """
    limits = score_conf["limits"]
    ticks = score_conf["ticks"]

    if limits is not None:
        ax.set_ylim(*limits)
    else:
        ax.autoscale(enable=True, axis="y", tight=True)

    if ticks and len(ticks) == 3:
        ax.set_yticks(np.arange(*ticks))
    elif ticks and len(ticks) == 2:
        ax.yaxis.set_major_locator(plt.MultipleLocator(ticks[0]))
        ax.yaxis.set_minor_locator(plt.MultipleLocator(ticks[1]))

    if score_conf.get("log_scale"):
        if limits is not None:
            ax.set_ylim([10 ** limits[0], 10 ** limits[1]])
        else:
            ax.autoscale(enable=True, axis="y", tight=True)

        ax.set_yscale("log")
        ax.yaxis.set_major_locator(plt.LogLocator(base=10.0, numticks=15))
        ax.yaxis.set_minor_locator(plt.NullLocator())

    ax.xaxis.set_major_locator(plt.MultipleLocator(leadtime_locator_multiples[0]))
    ax.xaxis.set_minor_locator(plt.MultipleLocator(leadtime_locator_multiples[1]))

    # Add first and last leadtime tick labels. ``list(...)`` lets callers pass
    # the limits as a tuple as well as a list.
    ax.set_xticks(list(ax.get_xticks()) + list(leadtime_limits))

    ax.set_xlim(*leadtime_limits)
    ax.set_xlabel("Leadtime [min]")

# Read verification configuration file and set up data


## Base data

In [5]:
# Load the plotting configuration as an attribute-style dictionary.
conf = "config/swiss-data/plot_metrics_objects.yaml"
conf = load_yaml_config(conf)

COLORS_METHODS = {m: conf.methods[m].color for m in conf.methods}

exp_id = conf.exp_id
result_dir = conf.path.result_dir.format(id=exp_id)
OUTPUT_DIR = Path(conf.path.save_dir.format(id=exp_id)) / "figs_article_supplementary"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

metric = "OBJECTS_ALL"
files = sorted(Path(result_dir).glob(f"*{metric}*.nc"))
if not files:
    # Fail with an actionable message instead of an IndexError on files[0].
    raise FileNotFoundError(f"No '*{metric}*.nc' result file found in {result_dir}")

# Only the first matching file (in sorted order) is used.
path = files[0]

DATASET = xr.open_dataset(path)
DATASET = DATASET.drop_duplicates(dim="sample")

# Split/merge conditions.
# "*_at_all": no split/merge flag set anywhere along the time dimension.
# Others: True at each time step until the first split/merge flag appears.
obs_no_split_or_merge_at_all_condition = (~((DATASET["obs_merged"].fillna(0) + DATASET["obs_from_split"].fillna(0) + DATASET["obs_will_merge"].fillna(0)) > 0).any(dim="leadtime"))
obs_no_split_or_merge_condition = (((DATASET["obs_merged"].fillna(0) + DATASET["obs_from_split"].fillna(0)) > 0).cumsum(dim="leadtime") == 0)
pred_no_split_merge_condition = (((DATASET["pred_merged"].fillna(0) + DATASET["pred_from_split"].fillna(0)) > 0).cumsum(dim="leadtime") == 0)
prev_no_split_merge_at_all_condition = (~((DATASET["prev_merged"].fillna(0) + DATASET["prev_from_split"].fillna(0) + DATASET["prev_will_merge"].fillna(0)) > 0).any(dim="prev_time"))
prev_no_split_merge_condition = (((DATASET["prev_merged"].fillna(0) + DATASET["prev_from_split"].fillna(0)) > 0).cumsum(dim="prev_time") == 0)

# Remove cells with splits or merges: the same mask (no split/merge in the
# observed track and none in the previous track) is applied to the observed,
# predicted and previous variables listed below.
_no_split_merge_mask = obs_no_split_or_merge_at_all_condition & prev_no_split_merge_at_all_condition
for _var in (
    # obs
    "obs_sum_rr", "obs_mean_rr", "obs_area", "obs_max_rr",
    # pred
    "pred_sum_rr", "pred_mean_rr", "pred_area", "pred_max_rr",
    # prev
    "prev_sum_rr", "prev_mean_rr", "prev_merged", "prev_from_split",
):
    DATASET[_var] = DATASET[_var].where(_no_split_merge_mask)

# Change unit of rr sum to 1e6 m^3/h
for _var in ("prev_sum_rr", "obs_sum_rr", "pred_sum_rr"):
    DATASET[_var] = DATASET[_var] * 1e-3
DATASET["sum_rr_diff"] = DATASET["pred_sum_rr"] - DATASET["obs_sum_rr"]

for _var in ("cell_match_obs_sum_rr", "cell_match_pred_sum_rr"):
    DATASET[_var] = DATASET[_var] * 1e-3

# Calculate differences (prediction minus observation)
for _quantity in ("max_rr", "mean_rr", "lifetime", "area"):
    DATASET[f"{_quantity}_diff"] = DATASET[f"pred_{_quantity}"] - DATASET[f"obs_{_quantity}"]

# Maximum area in track
DATASET["max_prev_area"] = DATASET["prev_area"].max(dim="prev_time", skipna=True)
DATASET["max_obs_area"] = DATASET["obs_area"].max(dim="leadtime", skipna=True)

DATASET["max_area"] = (["sample", "track"], np.nanmax([DATASET["max_prev_area"].values, DATASET["max_obs_area"].values], axis=0))

# Track lifetime: number of non-missing previous time steps plus the observed lifetime
DATASET["lifetime_prev"] = DATASET["prev_mean_rr"].count(dim="prev_time")
DATASET["lifetime_full"] = DATASET["lifetime_prev"] + DATASET["obs_lifetime"]

# Maximum RVR in track
DATASET["track_max_prev_rr"] = DATASET["prev_sum_rr"].max(dim="prev_time", skipna=True)
DATASET["track_max_obs_rr"] = DATASET["obs_sum_rr"].max(dim="leadtime", skipna=True)
DATASET["track_max_pred_rr"] = DATASET["pred_sum_rr"].max(dim="leadtime", skipna=True)
# Missing values are filled with -1000 so argmax can run; tracks without a
# positive observed maximum are masked out afterwards.
DATASET["track_argmax_obs_rr"] = DATASET["obs_sum_rr"].fillna(-1000).argmax(dim="leadtime", skipna=True).where(
    DATASET["track_max_obs_rr"] > 0)

# Minimum RVR in track
DATASET["track_min_prev_rr"] = DATASET["prev_sum_rr"].min(dim="prev_time", skipna=True)
DATASET["track_min_obs_rr"] = DATASET["obs_sum_rr"].min(dim="leadtime", skipna=True)
DATASET["track_min_pred_rr"] = DATASET["pred_sum_rr"].min(dim="leadtime", skipna=True)

# General variables
DATASET = DATASET.where(DATASET.method.isin(conf.legend_order))
N_METHODS = np.unique(DATASET.method.values).size
METHODS = np.unique(DATASET.method.values)

sorter = np.argsort(np.array(conf.legend_order))

DATASET_BASE = DATASET.copy()
DATASET_BASE

## Cell state

In [6]:
ds_ = DATASET.copy()


def _assign_track_state(dataset, name, conditions, values):
    """Create ``dataset[name]`` filled with NaN, then write ``values`` where ``conditions`` hold.

    The conditions are applied in the given order, so a later condition
    overrides an earlier one wherever both are True.
    """
    dataset[name] = xr.ones_like(dataset["track_max_prev_rr"]) * np.nan
    for _cond, _value in zip(conditions, values):
        dataset[name] = dataset[name].where(~_cond, _value)


# Volume rain rate of the last three previous time steps, put on the
# "leadtime" axis so that it can be concatenated with lead times 1 and 2.
_prev_window = DATASET["prev_sum_rr"].sel(prev_time=[-2, -1, 0]).rename({"prev_time": "leadtime"})

# State according to derivative definition
_obs_window = xr.concat(
    [_prev_window, DATASET["obs_sum_rr"].sel(leadtime=[1, 2])],
    dim="leadtime",
)
derivative_at_t0 = _obs_window.differentiate("leadtime").sel(leadtime=0)
num_point_in_derivative = _obs_window.count(dim="leadtime")

ds_["obs_derivative_at_t0"] = derivative_at_t0
ds_["num_points_in_obs_derivative"] = num_point_in_derivative

growth_cond = derivative_at_t0 > 0
# Decay: negative derivative, or a track that exists at t0 but has no derivative.
decay_cond = (
    (derivative_at_t0 < 0)
    | (
        (ds_["track_max_prev_rr"] > 0)
        & (ds_["prev_sum_rr"].sel(prev_time=0) > 0)
        & derivative_at_t0.isnull()
    )
)
stable_cond = (
    (np.abs(derivative_at_t0) == 0)
    & ((ds_["track_max_prev_rr"] > 0) & (ds_["track_max_obs_rr"] > 0))
)

_obs_conditions = (growth_cond, decay_cond, stable_cond)
_assign_track_state(ds_, "state", _obs_conditions, ("growth", "decay", "stable"))
# As integer for confusion matrix
_assign_track_state(ds_, "state_int", _obs_conditions, (1, 2, 3))

# Predicted state according to derivative
# Predicted state per track
_pred_window = xr.concat(
    [_prev_window, DATASET["pred_sum_rr"].sel(leadtime=[1, 2])],
    dim="leadtime",
)
derivative_pred_at_t0 = _pred_window.differentiate("leadtime").sel(leadtime=0)
num_point_in_pred_derivative = _pred_window.count(dim="leadtime")

ds_["pred_derivative_at_t0"] = derivative_pred_at_t0
ds_["num_points_in_pred_derivative"] = num_point_in_pred_derivative

growth_cond_pred = derivative_pred_at_t0 > 0
# Decay: negative derivative, or a track that exists at t0 but has no predicted rain at all.
decay_cond_pred = (
    (derivative_pred_at_t0 < 0)
    | (
        (ds_["track_max_prev_rr"] > 0)
        & (ds_["prev_sum_rr"].sel(prev_time=0) > 0)
        & ds_["track_max_pred_rr"].isnull()
    )
)
stable_cond_pred = (
    (np.abs(derivative_pred_at_t0) == 0)
    & ((ds_["track_max_prev_rr"] > 0) & (ds_["track_max_pred_rr"] > 0))
)

_pred_conditions = (growth_cond_pred, decay_cond_pred, stable_cond_pred)
_assign_track_state(ds_, "state_pred", _pred_conditions, ("growth", "decay", "stable"))
# As integer for confusion matrix
_assign_track_state(ds_, "state_pred_int", _pred_conditions, (1, 2, 3))

DATASET_CELL_STATE = ds_.copy()

DATASET_CELL_STATE

# Figures for cell tracking

## Categorical scores of cell existence

In [7]:
# Switch to the article style sheet for the figures created from here on.
plt.style.use("config/stylefiles/article.mplstyle")

In [8]:
def _existence_contingency(track_cond):
    """Build the cell-existence contingency table for tracks selected by ``track_cond``.

    A cell "exists" where the volume rain rate is positive. Entries outside
    ``track_cond`` are set to 2, which lies outside the category edges
    ``[0, 0.5, 1.0]``.

    Returns:
        tuple: ``(contingency, exists_obs, exists_nowcast)``.
    """
    exists_nowcast = (DATASET_CELL_STATE["pred_sum_rr"] > 0).where(track_cond, 2)
    exists_obs = (DATASET_CELL_STATE["obs_sum_rr"] > 0).where(track_cond, 2)
    contingency_ = xs.Contingency(
        exists_obs,
        exists_nowcast,
        np.array([0, 0.5, 1.0]),
        np.array([0, 0.5, 1.0]),
        dim=["sample", "track"],
    )
    return contingency_, exists_obs, exists_nowcast


# Track selections: all tracks present at t0, and tracks per observed state.
track_exists_cond = (
    (DATASET_CELL_STATE["track_max_prev_rr"] > 0)
    & (DATASET_CELL_STATE["prev_sum_rr"].sel(prev_time=0) > 0)
)
contingency_all = _existence_contingency(track_exists_cond)[0]

track_exists_growth_cond = (
    (DATASET_CELL_STATE["track_max_prev_rr"] > 0)
    & (DATASET_CELL_STATE["state"] == "growth")
)
contingency_growth = _existence_contingency(track_exists_growth_cond)[0]

track_exists_decay_cond = (
    (DATASET_CELL_STATE["track_max_prev_rr"] > 0)
    & (DATASET_CELL_STATE["state"] == "decay")
)
# The existence masks of the last (decay) selection stay available at cell level.
contingency_decay, cell_exists_obs, cell_exists_nowcast = _existence_contingency(
    track_exists_decay_cond
)
contingency_decay.table

In [9]:
# Bar plot of the counts of hits, misses, false alarms and correct negatives
# for each model and leadtime, one panel per cell-state group.
#
# The four categories are drawn as overlapping bars of cumulative height
# (tallest first), so the visible part of each layer is that category's count.

contingency_tables = {
    "all": contingency_all,
    "decay": contingency_decay,
    "growth": contingency_growth
}

fig, axs = plt.subplots(
    ncols=len(contingency_tables),
    nrows=1,
    figsize=((FIG_WIDTH+0.2) * len(contingency_tables), FIG_HEIGHT * 1),
    constrained_layout=True,
    sharey="row",
    squeeze=False
)

_legend_order = np.array(conf.legend_order)
sorter = np.argsort(_legend_order)

n_leadtimes = DATASET_BASE.leadtime.values.size

store_dfs = {}

for i, (name, contingency) in enumerate(contingency_tables.items()):
    ax = axs[0, i]

    hits = contingency.hits()
    misses = contingency.misses()
    false_alarms = contingency.false_alarms()
    correct_non_alarms = contingency.correct_negatives()

    df = pd.concat([hits.to_dataframe(), misses.to_dataframe(), false_alarms.to_dataframe(), correct_non_alarms.to_dataframe()], axis=1)
    df.columns = ["hits", "misses", "false_alarms", "correct_negatives"]

    df["total"] = df["hits"] + df["misses"] + df["false_alarms"] + df["correct_negatives"]

    # NOTE: despite the "_frac" suffix these columns hold raw counts (the
    # normalisation by "total" is disabled). The names are kept because the
    # frame is exported to CSV with these column names.
    df["hits_frac"] = df["hits"]
    df["misses_frac"] = df["misses"]
    df["false_alarms_frac"] = df["false_alarms"]
    df["correct_negatives_frac"] = df["correct_negatives"]

    # Cumulative heights of the stacked layers.
    df["hits_misses"] = df["hits_frac"] + df["misses_frac"]
    df["hits_misses_falsealarms"] = df["hits_frac"] + df["misses_frac"] + df["false_alarms_frac"]
    df["hits_misses_correctnegatives"] = df["hits_frac"] + df["misses_frac"] + df["false_alarms_frac"] + df["correct_negatives_frac"]

    # Order the rows by each method's position in conf.legend_order. The key
    # uses the values it is handed instead of a hard-coded index level.
    df.sort_values(
        by="method",
        key=lambda x: sorter[np.searchsorted(_legend_order, x, sorter=sorter)],
        inplace=True,
    )

    store_dfs[name] = df.copy()

    # This shows up as correct negatives
    g_cnegs = sns.barplot(
        ax=ax,
        data=df,
        x="method",
        hue="leadtime",
        y="hits_misses_correctnegatives",
        palette=["k"]*n_leadtimes,
        edgecolor="tab:gray",
        linewidth=0.5,
        legend=False,
    )
    # False alarms
    g_falarms = sns.barplot(
        ax=ax,
        data=df,
        x="method",
        hue="leadtime",
        y="hits_misses_falsealarms",
        palette=["w"]*n_leadtimes,
        edgecolor="black",
        linewidth=0.5,
        legend=False,
    )
    # Misses
    g_misses = sns.barplot(
        ax=ax,
        data=df,
        x="method",
        hue="leadtime",
        y="hits_misses",
        palette=["tab:gray"]*n_leadtimes,
        edgecolor="black",
        linewidth=0.5,
        legend=False,
    )
    # Hits
    g_hits = sns.barplot(
        ax=ax,
        data=df,
        x="method",
        hue="leadtime",
        y="hits_frac",
        palette=HUE_CMAP,
        edgecolor="black",
        linewidth=0.5,
        legend="full"
    )

    ax.set_title(f"({alphabet[i]}) {STATE_GROUP_TITLES[name]}")
    # Fix the tick positions before relabelling them; calling set_xticklabels
    # without fixed ticks makes matplotlib emit a UserWarning.
    _tick_labels = [get_labelstr(l.get_text()) for l in ax.get_xticklabels()]
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(_tick_labels)
    # The per-axes legend is replaced by the figure-level legends below.
    g_hits.axes.get_legend().remove()

# Leadtime legend, built from the hue handles of the last panel.
h, l = axs.flatten()[-1].get_legend_handles_labels()
l1 = fig.legend(
    h,
    [leadtime_to_minutes((int(s)), 0) for s in l],
    title="Leadtime [min]",
    bbox_to_anchor=(0.43, 1.07),
    loc="center left",
    frameon=True,
    bbox_transform=fig.transFigure,
    ncols=6,
)
# NOTE(review): fig.legend presumably already registers the legend on the
# figure, which would make this call redundant - confirm before removing.
fig.add_artist(l1)

for ax in axs.flatten():
    ax.set_autoscale_on(False)
    ax.set_ylim(0, 110e3)
    ax.set_ylabel(COUNT_TITLE)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(25e3))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(5e3))
    # Show counts in thousands, matching the unit in COUNT_TITLE.
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, p: f"{x / 1000:.0f}"))
    ax.grid(which="major", axis="y")
    ax.grid(which="minor", axis="y", alpha=0.2)
    ax.set_xlabel(METHOD_X_LABEL)

# Make legend for hits, misses, false alarms, correct negatives labels.
# "Hits" gets one swatch per leadtime colour, drawn side by side by HandlerTuple.
palette_hue = sns.color_palette(HUE_CMAP, n_leadtimes)
hits_patch = [patches.Patch(facecolor=c, edgecolor=c, label="Hits") for c in palette_hue]

misses_patch = patches.Patch(facecolor="tab:gray", edgecolor="tab:gray", label="Misses")
falarms_patch = patches.Patch(facecolor="white", edgecolor="k", label="False alarms")
cnegs_patch = patches.Patch(facecolor="k", edgecolor="tab:gray", label="Correct negatives")

other_patches = [misses_patch, falarms_patch, cnegs_patch]

leg = fig.legend(
    handles=[hits_patch, *other_patches],
    labels=["Hits", *[p.get_label() for p in other_patches]],
    handler_map={list: HandlerTuple(ndivide=None, pad=0)},
    ncols=2,
    bbox_to_anchor=(0.23, 1.07),
    loc="center left",
    frameon=True,
    bbox_transform=fig.transFigure,
)

outputname = "cell_existence_counts"
save_figs(fig, OUTPUT_DIR, outputname, conf.output_formats)

  ax.set_xticklabels(
  ax.set_xticklabels(
  ax.set_xticklabels(


In [10]:
# Save to csv: stack the per-state count tables, keeping the state name as a column.
outputname = "cell_existence_counts"

store_df = (
    pd.concat(list(store_dfs.values()), keys=list(store_dfs.keys()), axis=0)
    .reset_index()
    .rename(columns={"level_0": "state"})
)
store_df.to_csv(OUTPUT_DIR / f"{outputname}.csv")

In [11]:
# Line plots of categorical scores of cell existence against leadtime:
# one row per metric, one column per cell-state group.
metrics = ["CSI", "BIAS", "POD", "FAR"]

# Override the configured y-axis range and tick multiples of BIAS for this figure.
conf.metric_conf["BIAS"]["limits"] = [0.5, 6.0]
conf.metric_conf["BIAS"]["ticks"] = [0.5, 0.25]

contingency_tables = {
    "all": contingency_all,
    "decay": contingency_decay,
    "growth": contingency_growth
}
n_state_panels = len(contingency_tables)

fig, axs = plt.subplots(
    figsize=(n_state_panels*5, len(metrics)*5),
    nrows=len(metrics),
    ncols=n_state_panels,
    layout="compressed",
    sharey="row",
    sharex=True,
)

store_dfs = {}

for j, (name, contingency) in enumerate(contingency_tables.items()):

    dfs = {
        "CSI": contingency.threat_score(),
        "POD": contingency.hit_rate(),
        "FAR": contingency.false_alarm_ratio(),
        "BIAS": contingency.bias_score(),
    }
    # Shallow copy: the stored arrays are the same objects as in ``dfs``, so
    # the leadtime conversion below is also visible through ``store_dfs``.
    store_dfs[name] = dfs.copy()

    for i, metric in enumerate(metrics):
        ax = axs[i, j]
        df_ = dfs[metric]
        # Change leadtime to minutes (modifies the coordinate of the array in place)
        df_["leadtime"] = df_["leadtime"] * 5
        for model in conf.legend_order:
            df_.sel(dict(method=model)).plot.line(
                ax=ax,
                x="leadtime",
                color=conf.methods[model]["color"],
                label=conf.methods[model]["label"],
                linestyle=conf.methods[model]["linestyle"],
            )
        set_ax(ax, conf.metric_conf[metric], [5, 60], conf.leadtime_locator_multiples)
        ax.set_ylabel(conf.metric_conf[metric]["label"])
        # Panel letters run row by row; the row stride follows the number of
        # columns instead of a hard-coded 3.
        ax.set_title(f'({alphabet[i * n_state_panels + j]}) {STATE_GROUP_TITLES[name]}: {conf.metric_conf[metric]["full_name"]}', color=plt.rcParams["axes.titlecolor"])
        ax.grid(which="both", axis="both")
        ax.legend()
        ax.label_outer()

outputname = "cell_existence_growth_decay"
save_figs(fig, OUTPUT_DIR, outputname, conf.output_formats)

In [12]:
# Save to csv: one table per state, with one column per score, stacked under the state name.
outputname = "cell_existence_growth_decay"

_score_frames = [
    xr.merge(
        [score_da.rename(score_key) for score_key, score_da in store_dfs[state].items()]
    ).to_dataframe()
    for state in store_dfs.keys()
]
store_df = (
    pd.concat(_score_frames, keys=store_dfs.keys(), axis=0)
    .reset_index()
    .rename(columns={"level_0": "state"})
)
store_df.to_csv(OUTPUT_DIR / f"{outputname}.csv")

## Contingency scores of cell track decay/growth classification

In [13]:
# Contingency table of observed vs. predicted track state. With the category
# edges [1, 2, 2.5], category 1 is "growth" and category 2 is "decay".
_state_edges = np.array([1, 2, 2.5])
contingency = xs.Contingency(
    DATASET_CELL_STATE["state_int"],
    DATASET_CELL_STATE["state_pred_int"],
    _state_edges,
    np.array([1, 2, 2.5]),
    dim=["sample", "track"],
)

_STATE_CATEGORIES = (("growth", 1), ("decay", 2))


def _scores_per_state(prefix, score_fn):
    """Evaluate ``score_fn`` for growth and decay, keyed as ``<prefix>_<state>``."""
    return {
        f"{prefix}_{state}": score_fn(yes_category=category)
        for state, category in _STATE_CATEGORIES
    }


# The key order defines the row order of the scores panel in the figure.
scores = {
    **_scores_per_state("csi", contingency.threat_score),
    **_scores_per_state("pod", contingency.hit_rate),
    "ets": contingency.equit_threat_score(yes_category=1),
    "gerrity": contingency.gerrity_score(),
    **_scores_per_state("bias", contingency.bias_score),
    **_scores_per_state("far", contingency.false_alarm_ratio),
}

# Display names of the scores.
score_names = {
    f"{abbreviation.lower()}_{state}": f"{abbreviation} for {state}"
    for abbreviation in ("CSI", "POD", "FAR", "BIAS")
    for state in ("growth", "decay")
}
score_names["ets"] = "ETS"
score_names["gerrity"] = "Gerrity score"

# Table of scalar score values: one row per score, one column per model.
print_df = pd.DataFrame(columns=conf.legend_order, index=list(scores.keys()))
for score in scores:
    for model in conf.legend_order:
        print_df.loc[score, model] = scores[score].loc[model].item()

# Long format with the columns "index" (score), "variable" (model) and "value".
df_ = print_df.melt(ignore_index=False).reset_index()

# Category regarded as "yes" for each state; the key order sets the panel order.
yes_values = {
    "decay": 2,
    "growth": 1,
}

sorter = np.argsort(np.array(conf.legend_order))
n_leadtimes = DATASET_BASE.leadtime.values.size

In [14]:
# Figure with three panels: (a)/(b) counts of hits, misses, false alarms and
# correct negatives for the decay and growth classification of tracks, and
# (c) the scalar classification scores per model.
fig = plt.figure(layout="constrained", figsize=(12, 11))

axs = fig.subplot_mosaic([
        ["decay", "decay", "growth","growth"],
        [".", "scores", "scores", "."],
    ],
    width_ratios=[0.01, 0.49, 0.49, 0.01],
    height_ratios=[0.4, 0.6],
)

_legend_order = np.array(conf.legend_order)

store_dfs = {}

for i, (name, yes) in enumerate(yes_values.items()):
    ax = axs[name]

    hits = contingency.hits(yes_category=yes)
    misses = contingency.misses(yes_category=yes)
    false_alarms = contingency.false_alarms(yes_category=yes)
    correct_non_alarms = contingency.correct_negatives(yes_category=yes)

    df = pd.concat([hits.to_dataframe(), misses.to_dataframe(), false_alarms.to_dataframe(), correct_non_alarms.to_dataframe()], axis=1)
    df.columns = ["hits", "misses", "false_alarms", "correct_negatives"]

    df["total"] = df["hits"] + df["misses"] + df["false_alarms"] + df["correct_negatives"]

    # NOTE: despite the "_frac" suffix these columns hold raw counts (the
    # normalisation by "total" is disabled).
    df["hits_frac"] = df["hits"]
    df["misses_frac"] = df["misses"]
    df["false_alarms_frac"] = df["false_alarms"]
    df["correct_negatives_frac"] = df["correct_negatives"]

    # Cumulative heights of the overlapping bar layers.
    df["hits_misses"] = df["hits_frac"] + df["misses_frac"]
    df["hits_misses_falsealarms"] = df["hits_frac"] + df["misses_frac"] + df["false_alarms_frac"]
    df["hits_misses_correctnegatives"] = df["hits_frac"] + df["misses_frac"] + df["false_alarms_frac"] + df["correct_negatives_frac"]

    # Order the rows by each method's position in conf.legend_order.
    df.sort_values(
        by="method",
        key=lambda x: sorter[np.searchsorted(_legend_order, x, sorter=sorter)],
        inplace=True,
    )
    store_dfs[name] = df.copy()

    # The hue variable is the method here, so single-colour palettes need one
    # entry per method; a list sized by the number of leadtimes does not match
    # the number of hue levels and makes seaborn warn.
    n_hue_levels = df.index.get_level_values(0).nunique()

    # This shows up as correct negatives
    g_cnegs = sns.barplot(
        ax=ax,
        data=df,
        x="method",
        hue="method",
        y="hits_misses_correctnegatives",
        palette=["k"]*n_hue_levels,
        edgecolor="tab:gray",
        linewidth=0.5,
        legend=False,
    )
    # False alarms
    g_falarms = sns.barplot(
        ax=ax,
        data=df,
        x="method",
        hue="method",
        y="hits_misses_falsealarms",
        palette=["w"]*n_hue_levels,
        edgecolor="black",
        linewidth=0.5,
        legend=False,
    )
    # Misses
    g_misses = sns.barplot(
        ax=ax,
        data=df,
        x="method",
        hue="method",
        y="hits_misses",
        palette=["tab:gray"]*n_hue_levels,
        edgecolor="black",
        linewidth=0.5,
        legend=False,
    )
    # Hits
    g_hits = sns.barplot(
        ax=ax,
        data=df,
        x="method",
        hue="method",
        y="hits_frac",
        palette=COLORS_METHODS,
        edgecolor="black",
        linewidth=0.5,
        legend="full"
    )
    ax.set_title(f"({alphabet[i]}) {STATE_GROUP_TITLES[name]} classification track counts")
    # Fix the tick positions before relabelling them; calling set_xticklabels
    # without fixed ticks makes matplotlib emit a UserWarning.
    _tick_labels = [get_labelstr(l.get_text()) for l in ax.get_xticklabels()]
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(_tick_labels)
    g_hits.axes.get_legend().remove()

for ax in [axs["decay"], axs["growth"]]:
    ax.set_autoscale_on(False)
    ax.set_ylim(0, 115e3)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(25e3))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(5e3))
    # Show counts in thousands, matching the unit in COUNT_TITLE.
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, p: f"{x / 1000:.0f}"))
    ax.grid(which="major", axis="y")
    ax.grid(which="minor", axis="y", alpha=0.2)
    ax.set_xlabel(METHOD_X_LABEL)
    ax.set_ylabel(COUNT_TITLE)

# Make legend for hits, misses, false alarms, correct negatives labels.
# "Hits" gets one swatch per method colour, drawn side by side by HandlerTuple.
# NOTE(review): the swatches follow the order of METHODS, while the bars follow
# conf.legend_order - confirm this is the intended order.
palette_hue = sns.color_palette([COLORS_METHODS[n] for n in METHODS])
hits_patch = [patches.Patch(facecolor=c, edgecolor=c, label="Hits") for c in palette_hue]

misses_patch = patches.Patch(facecolor="tab:gray", edgecolor="tab:gray", label="Misses")
falarms_patch = patches.Patch(facecolor="white", edgecolor="k", label="False alarms")
cnegs_patch = patches.Patch(facecolor="k", edgecolor="tab:gray", label="Correct negatives")

other_patches = [misses_patch, falarms_patch, cnegs_patch]

leg = fig.legend(
    handles=[hits_patch, *other_patches],
    labels=["Hits", *[p.get_label() for p in other_patches]],
    handler_map={list: HandlerTuple(ndivide=None, pad=0)},
    ncols=2,
    bbox_to_anchor=(0.51, 1.02),
    loc="center",
    frameon=True,
    bbox_transform=fig.transFigure,
)
axs["decay"].sharey(axs["growth"])
axs["growth"].sharey(axs["decay"])

# Third panel for scores
g = sns.barplot(
    data=df_,
    x="value",
    y="index",
    hue="variable",
    palette=COLORS_METHODS,
    ax=axs["scores"],
    width=0.7,
    edgecolor="k",
    linewidth=0.5,
    gap=0.05,
)

# Vertical offsets of the value labels, one per model.
# NOTE(review): only four offsets are defined, so more than four entries in
# conf.legend_order would raise an IndexError below.
offsets = [-0.225, -0.03, 0.15, 0.34]

# Best value per score, computed once instead of once per labelled bar:
# closest to one for bias scores, lowest for FAR, highest for everything else.
score_order = list(scores.keys())
df_["diff_from_one"] = np.abs(df_["value"] - 1).astype(float)
_by_score = df_.groupby("index")
_max_by_score = _by_score["value"].max()
_min_by_score = _by_score["value"].min()
_closest_to_one_row = _by_score["diff_from_one"].idxmin()

# Label bars so that the best value is bold
for ind, row in df_.iterrows():
    val = row["value"]
    model = row["variable"]
    model_ind = conf.legend_order.index(model)
    score = row["index"]
    score_ind = score_order.index(score)
    if "bias" in score:
        best = df_.loc[_closest_to_one_row.loc[score], "value"]
    elif "far" not in score:
        best = _max_by_score.loc[score]
    else:
        best = _min_by_score.loc[score]
    fontweight = "bold" if val == best else "normal"

    t = axs["scores"].text(val + 0.01, score_ind + offsets[model_ind], f"{val:.3f}", fontweight=fontweight, fontsize="x-small")
    t.set_linespacing(1.0)

axs["scores"].set_xlabel("Value")
axs["scores"].set_ylabel("")
axs["scores"].set_xlim(0, 1.3)
axs["scores"].xaxis.set_major_locator(plt.MultipleLocator(0.1))
axs["scores"].xaxis.set_minor_locator(plt.MultipleLocator(0.05))
axs["scores"].grid(which="both", axis="x")
# Reference line at a value of 1.
axs["scores"].axvline(1, color="gray", linewidth=1.0, linestyle="--")

# horizontal line before bias scores
axs["scores"].axhline(
    y=score_order.index("bias_growth") - 0.5,
    color="gray",
    linewidth=1.0, linestyle="--"
)

# horizontal line before far scores
axs["scores"].axhline(
    y=score_order.index("far_growth") - 0.5,
    color="gray",
    linewidth=1.0, linestyle="--"
)

# Set y axis labels; the tick positions are fixed first so that relabelling
# does not trigger matplotlib's fixed-locator warning.
_score_keys = [l_.get_text() for l_ in axs["scores"].get_yticklabels()]
axs["scores"].set_yticks(axs["scores"].get_yticks())
axs["scores"].set_yticklabels([score_names[key] for key in _score_keys])

# Set title
axs["scores"].set_title("(c) Cell track classification metric values")

h, l = axs["scores"].get_legend_handles_labels()
l = [get_labelstr(l_) for l_ in l]
axs["scores"].legend(h, l, loc="upper right")

outputname = "growth_decay_track_classification"
save_figs(fig, OUTPUT_DIR, outputname, conf.output_formats)


  g_cnegs = sns.barplot(
  g_falarms = sns.barplot(
  g_misses = sns.barplot(
  ax.set_xticklabels(
  g_cnegs = sns.barplot(
  g_falarms = sns.barplot(
  g_misses = sns.barplot(
  ax.set_xticklabels(
  axs["scores"].set_yticklabels([score_names[l_.get_text()] for l_ in axs["scores"].get_yticklabels()])


In [15]:
# Save to csv

# Cell counts: stack the per-state tables, keep only the four contingency
# columns, and expose the state (outermost index level) as a regular column.
count_columns = ["hits", "misses", "false_alarms", "correct_negatives"]
store_df = pd.concat(store_dfs.values(), keys=store_dfs.keys(), axis=0)
store_df = store_df[count_columns].reset_index()
store_df = store_df.rename(columns={"level_0": "state"})

outputname = "growth_decay_track_classification_ab_counts"
store_df.to_csv(OUTPUT_DIR / f"{outputname}.csv")

# Scores: written as-is from the frame used for the score panel.
outputname = "growth_decay_track_classification_c_scores"
df_.to_csv(OUTPUT_DIR / f"{outputname}.csv")

## Area difference for all cells and by cell state

In [16]:
# Box plots of the area difference per method and leadtime:
# panel (a) for all cells, followed by one panel per cell state.

# Per-state data. `where` masks (sets to NaN) every entry for which the
# predicted and observed areas are not both positive.
ds_cs = DATASET_CELL_STATE[["pred_area", "obs_area", "area_diff", "state", "max_area"]].where(
    (DATASET_CELL_STATE["pred_area"] > 0) & (DATASET_CELL_STATE["obs_area"] > 0)
)
ds_cs = ds_cs.to_dataframe().reset_index()
ds_cs = ds_cs.drop_duplicates(subset=["sample", "track", "leadtime", "method"])
groups = ds_cs.groupby("state")

fig, axs = plt.subplots(
    ncols=len(groups) + 1,
    nrows=1,
    figsize=((FIG_WIDTH + 0.2) * len(groups) + 1, FIG_HEIGHT * 1),
    constrained_layout=True,
    sharey="row",
    squeeze=False,
)

# All-cells data, masked the same way as the per-state data.
ds_ = DATASET_BASE[["pred_area", "obs_area", "area_diff", "max_area"]].where(
    (DATASET_BASE["pred_area"] > 0) & (DATASET_BASE["obs_area"] > 0)
)
ds_ = ds_.to_dataframe().reset_index()
ds_ = ds_.drop_duplicates(subset=["sample", "track", "leadtime", "method"])

# Keyword arguments shared by every box plot of this figure.
boxplot_kws = dict(
    x="method",
    y="area_diff",
    hue="leadtime",
    order=conf.legend_order,
    whis=[5, 95],  # whiskers at the 5th and 95th percentiles
    showfliers=True,
    showmeans=True,
    meanline=True,
    medianprops=MEDIANPROPS,
    meanprops=MEANLINEPROPS,
    flierprops=FLIERPROPS,
    legend="full",
    palette=HUE_CMAP,
)

# Panel (a): all cells
g = sns.boxplot(data=ds_, ax=axs[0, 0], **boxplot_kws)
axs[0, 0].set_title("(a) All cells")

# Remaining panels: one per cell state
for i, (name, group) in enumerate(groups):
    ax = axs[0, 1 + i]
    g = sns.boxplot(data=group, ax=ax, **boxplot_kws)
    ax.set_title(f"({alphabet[i+1]}) {STATE_GROUP_TITLES[name]}")

for ax in axs.flatten():
    # Pin the tick positions before relabelling; relabelling ticks that are
    # still managed by a non-fixed locator triggers Matplotlib's
    # "set_ticklabels() should only be used with a fixed number of ticks" warning.
    method_labels = [get_labelstr(tick.get_text()) for tick in ax.get_xticklabels()]
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(method_labels)
    # Drop the per-panel legend; a shared figure legend is built below.
    ax.get_legend().remove()

    ax.set_autoscale_on(False)
    ax.set_ylabel(AREA_TITLE)
    ax.set_ylim(AREA_LIMITS)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(AREA_TICK_MULTIPLE))
    ax.grid(axis="y")
    ax.axhline(0, **ZEROLINE_PROPS)
    ax.set_xlabel(METHOD_X_LABEL)

# Empty proxy lines so that the median and mean line styles get legend entries.
# They are added last, so they are the final two handles returned below.
medline = axs[0, 0].plot([], [], **MEDIANPROPS, label="Median")
meanline = axs[0, 0].plot([], [], **MEANLINEPROPS, label="Mean")
h, l = axs[0, 0].get_legend_handles_labels()

# Leadtime (hue) legend
l1 = fig.legend(
    h[:-2],
    [leadtime_to_minutes((int(s)), 0) for s in l[:-2]],
    title="Leadtime [min]",
    bbox_to_anchor=(0.38, 1.07, 0, 0),
    loc="center left",
    frameon=True,
    bbox_transform=fig.transFigure,
    ncols=6,
    labelspacing=0.2,
)
# Median / mean legend
l2 = fig.legend(
    h[-2:],
    l[-2:],
    bbox_to_anchor=(0.36, 1.07, 0, 0),
    loc="center right",
    frameon=True,
    bbox_transform=fig.transFigure,
    ncols=1,
)
# NOTE(review): l1 was created with fig.legend, so it is presumably already
# attached to the figure and this call may be redundant — confirm before removing.
fig.add_artist(l1)

outputname = "area_diff_article"

save_figs(fig, OUTPUT_DIR, outputname, conf.output_formats)


  axs[0,0].set_xticklabels([get_labelstr(l.get_text()) for l in axs[0,0].get_xticklabels()])
  ax.set_xticklabels(
  ax.set_xticklabels(


In [17]:
# Save to csv
# Summary statistics of area_diff per leadtime and method, first for all cells
# and then for each cell state, stacked with the state as the outermost key.
summary_percentiles = [0.05, 0.25, 0.5, 0.75, 0.95]
summary_keys = ["leadtime", "method"]

all_stats = ds_.groupby(summary_keys).describe(percentiles=summary_percentiles)["area_diff"]
state_stats = {
    state: state_df.groupby(summary_keys).describe(percentiles=summary_percentiles)["area_diff"]
    for state, state_df in groups
}

store_df = pd.concat(
    [all_stats, *state_stats.values()],
    keys=["all", *state_stats.keys()],
    axis=0,
).reset_index()
# The unnamed outermost index level holds the state key
store_df = store_df.rename(columns={"level_0": "state"})

outputname = "area_diff_article"
store_df.to_csv(OUTPUT_DIR / f"{outputname}.csv")

## Mean rain rate difference for all cells and by cell state

In [18]:
# Box plots of the mean rain rate difference per method and leadtime:
# panel (a) for all cells, followed by one panel per cell state.

# Per-state data. `where` masks (sets to NaN) every entry for which the
# predicted and observed mean rain rates are not both positive.
ds_cs = DATASET_CELL_STATE[["pred_mean_rr", "obs_mean_rr", "mean_rr_diff", "state"]].where(
    (DATASET_CELL_STATE["pred_mean_rr"] > 0) & (DATASET_CELL_STATE["obs_mean_rr"] > 0)
)
ds_cs = ds_cs.to_dataframe().reset_index()
ds_cs = ds_cs.drop_duplicates(subset=["sample", "track", "leadtime", "method"])
groups = ds_cs.groupby("state")

fig, axs = plt.subplots(
    ncols=len(groups) + 1,
    nrows=1,
    figsize=((FIG_WIDTH + 0.2) * len(groups) + 1, FIG_HEIGHT * 1),
    constrained_layout=True,
    sharey="row",
    squeeze=False,
)

# All-cells data: masked the same way, then the masked rows are dropped
# before removing duplicates.
ds_ = DATASET_BASE[["pred_mean_rr", "obs_mean_rr", "mean_rr_diff"]].where(
    (DATASET_BASE["pred_mean_rr"] > 0) & (DATASET_BASE["obs_mean_rr"] > 0)
)
ds_ = ds_.to_dataframe().reset_index()
ds_ = ds_[(ds_["pred_mean_rr"] > 0) & (ds_["obs_mean_rr"] > 0)]
ds_ = ds_.drop_duplicates(subset=["sample", "track", "leadtime", "method"])

# Keyword arguments shared by every box plot of this figure.
boxplot_kws = dict(
    x="method",
    y="mean_rr_diff",
    hue="leadtime",
    order=conf.legend_order,
    whis=[5, 95],  # whiskers at the 5th and 95th percentiles
    showfliers=True,
    showmeans=True,
    meanline=True,
    medianprops=MEDIANPROPS,
    meanprops=MEANLINEPROPS,
    flierprops=FLIERPROPS,
    legend="full",
    palette=HUE_CMAP,
)

# Panel (a): all cells
g = sns.boxplot(data=ds_, ax=axs[0, 0], **boxplot_kws)
axs[0, 0].set_title("(a) All cells")

# Remaining panels: one per cell state
for i, (name, group) in enumerate(groups):
    ax = axs[0, 1 + i]
    g = sns.boxplot(data=group, ax=ax, **boxplot_kws)
    ax.set_title(f"({alphabet[i+1]}) {STATE_GROUP_TITLES[name]}")

for ax in axs.flatten():
    # Pin the tick positions before relabelling; relabelling ticks that are
    # still managed by a non-fixed locator triggers Matplotlib's
    # "set_ticklabels() should only be used with a fixed number of ticks" warning.
    method_labels = [get_labelstr(tick.get_text()) for tick in ax.get_xticklabels()]
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(method_labels)
    # Drop the per-panel legend; a shared figure legend is built below.
    ax.get_legend().remove()

    ax.set_autoscale_on(False)
    ax.set_ylabel(MEAN_RR_DIFF_TITLE)
    ax.set_ylim(MEAN_RR_LIMITS)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(MEAN_RR_TICK_MULTIPLE))
    ax.grid(axis="y")
    ax.axhline(0, **ZEROLINE_PROPS)
    ax.set_xlabel(METHOD_X_LABEL)

# Empty proxy lines so that the median and mean line styles get legend entries.
# They are added last, so they are the final two handles returned below.
medline = axs[0, 0].plot([], [], **MEDIANPROPS, label="Median")
meanline = axs[0, 0].plot([], [], **MEANLINEPROPS, label="Mean")

h, l = axs[0, 0].get_legend_handles_labels()

# Leadtime (hue) legend
l1 = fig.legend(
    h[:-2],
    [leadtime_to_minutes((int(s)), 0) for s in l[:-2]],
    title="Leadtime [min]",
    bbox_to_anchor=(0.38, 1.07, 0, 0),
    loc="center left",
    frameon=True,
    bbox_transform=fig.transFigure,
    ncols=6,
    labelspacing=0.2,
)
# Median / mean legend
l2 = fig.legend(
    h[-2:],
    l[-2:],
    bbox_to_anchor=(0.36, 1.07, 0, 0),
    loc="center right",
    frameon=True,
    bbox_transform=fig.transFigure,
    ncols=1,
)
# NOTE(review): l1 was created with fig.legend, so it is presumably already
# attached to the figure and this call may be redundant — confirm before removing.
fig.add_artist(l1)

outputname = "mean_rr_diff_article"
save_figs(fig, OUTPUT_DIR, outputname, conf.output_formats)

  axs[0, 0].set_xticklabels([get_labelstr(l.get_text()) for l in axs[0, 0].get_xticklabels()])
  ax.set_xticklabels(
  ax.set_xticklabels(


In [19]:
# Save to csv
# Summary statistics of mean_rr_diff per leadtime and method, first for all
# cells and then for each cell state, stacked with the state as the outermost key.
summary_percentiles = [0.05, 0.25, 0.5, 0.75, 0.95]
summary_keys = ["leadtime", "method"]

all_stats = ds_.groupby(summary_keys).describe(percentiles=summary_percentiles)["mean_rr_diff"]
state_stats = {
    state: state_df.groupby(summary_keys).describe(percentiles=summary_percentiles)["mean_rr_diff"]
    for state, state_df in groups
}

store_df = pd.concat(
    [all_stats, *state_stats.values()],
    keys=["all", *state_stats.keys()],
    axis=0,
).reset_index()
# The unnamed outermost index level holds the state key
store_df = store_df.rename(columns={"level_0": "state"})

outputname = "mean_rr_diff_article"
store_df.to_csv(OUTPUT_DIR / f"{outputname}.csv")

## RVR difference for all cells and by cell state

In [20]:
# Box plots of the rain rate sum (RVR) difference per method and leadtime:
# panel (a) for all cells, followed by one panel per cell state.

# Per-state data. `where` keeps entries for which the predicted OR the
# observed rain rate sum is positive and masks (sets to NaN) the rest.
ds_cs = DATASET_CELL_STATE[["pred_sum_rr", "obs_sum_rr", "sum_rr_diff", "state"]].where(
    (DATASET_CELL_STATE["pred_sum_rr"] > 0) | (DATASET_CELL_STATE["obs_sum_rr"] > 0)
)
ds_cs = ds_cs.to_dataframe().reset_index()
ds_cs = ds_cs.drop_duplicates(subset=["sample", "track", "leadtime", "method"])
groups = ds_cs.groupby("state")

fig, axs = plt.subplots(
    ncols=len(groups) + 1,
    nrows=1,
    figsize=((FIG_WIDTH + 0.2) * len(groups) + 1, FIG_HEIGHT * 1),
    constrained_layout=True,
    sharey="row",
    squeeze=False,
)

# All-cells data.
# NOTE(review): the xarray mask below uses OR, but the DataFrame filter that
# follows requires BOTH sums to be positive, so the all-cells panel is filtered
# more strictly than the per-state panels above (OR only) — confirm this is
# intended.
ds_ = DATASET_BASE[["pred_sum_rr", "obs_sum_rr", "sum_rr_diff"]].where(
    (DATASET_BASE["pred_sum_rr"] > 0) | (DATASET_BASE["obs_sum_rr"] > 0)
)
ds_ = ds_.to_dataframe().reset_index()
ds_ = ds_[(ds_["pred_sum_rr"] > 0) & (ds_["obs_sum_rr"] > 0)]
ds_ = ds_.drop_duplicates(subset=["sample", "track", "leadtime", "method"])

# Keyword arguments shared by every box plot of this figure.
boxplot_kws = dict(
    x="method",
    y="sum_rr_diff",
    hue="leadtime",
    order=conf.legend_order,
    whis=[5, 95],  # whiskers at the 5th and 95th percentiles
    showfliers=True,
    showmeans=True,
    meanline=True,
    medianprops=MEDIANPROPS,
    meanprops=MEANLINEPROPS,
    flierprops=FLIERPROPS,
    legend="full",
    palette=HUE_CMAP,
)

# Panel (a): all cells
g = sns.boxplot(data=ds_, ax=axs[0, 0], **boxplot_kws)
axs[0, 0].set_title("(a) All cells")

# Remaining panels: one per cell state
for i, (name, group) in enumerate(groups):
    ax = axs[0, 1 + i]
    g = sns.boxplot(data=group, ax=ax, **boxplot_kws)
    ax.set_title(f"({alphabet[i+1]}) {STATE_GROUP_TITLES[name]}")

for ax in axs.flatten():
    # Pin the tick positions before relabelling; relabelling ticks that are
    # still managed by a non-fixed locator triggers Matplotlib's
    # "set_ticklabels() should only be used with a fixed number of ticks" warning.
    method_labels = [get_labelstr(tick.get_text()) for tick in ax.get_xticklabels()]
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(method_labels)
    # Drop the per-panel legend; a shared figure legend is built below.
    ax.get_legend().remove()

    ax.set_autoscale_on(False)
    ax.set_ylabel(SUM_RR_DIFF_TITLE)
    ax.set_ylim(SUM_RR_LIMITS)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(SUM_RR_TICK_MULTIPLE))
    ax.grid(axis="y")
    ax.axhline(0, **ZEROLINE_PROPS)
    ax.set_xlabel(METHOD_X_LABEL)

# Empty proxy lines so that the median and mean line styles get legend entries.
# They are added last, so they are the final two handles returned below.
medline = axs[0, 0].plot([], [], **MEDIANPROPS, label="Median")
meanline = axs[0, 0].plot([], [], **MEANLINEPROPS, label="Mean")

h, l = axs[0, 0].get_legend_handles_labels()

# Leadtime (hue) legend
l1 = fig.legend(
    h[:-2],
    [leadtime_to_minutes((int(s)), 0) for s in l[:-2]],
    title="Leadtime [min]",
    bbox_to_anchor=(0.38, 1.07, 0, 0),
    loc="center left",
    frameon=True,
    bbox_transform=fig.transFigure,
    ncols=6,
    labelspacing=0.2,
)
# Median / mean legend
l2 = fig.legend(
    h[-2:],
    l[-2:],
    bbox_to_anchor=(0.36, 1.07, 0, 0),
    loc="center right",
    frameon=True,
    bbox_transform=fig.transFigure,
    ncols=1,
)
# NOTE(review): l1 was created with fig.legend, so it is presumably already
# attached to the figure and this call may be redundant — confirm before removing.
fig.add_artist(l1)

outputname = "sum_rr_diff_article"
save_figs(fig, OUTPUT_DIR, outputname, conf.output_formats)

  axs[0, 0].set_xticklabels([get_labelstr(l.get_text()) for l in axs[0, 0].get_xticklabels()])
  ax.set_xticklabels(
  ax.set_xticklabels(


In [21]:
# Save to csv
# Summary statistics of sum_rr_diff per leadtime and method, first for all
# cells and then for each cell state, stacked with the state as the outermost key.
summary_percentiles = [0.05, 0.25, 0.5, 0.75, 0.95]
summary_keys = ["leadtime", "method"]

all_stats = ds_.groupby(summary_keys).describe(percentiles=summary_percentiles)["sum_rr_diff"]
state_stats = {
    state: state_df.groupby(summary_keys).describe(percentiles=summary_percentiles)["sum_rr_diff"]
    for state, state_df in groups
}

store_df = pd.concat(
    [all_stats, *state_stats.values()],
    keys=["all", *state_stats.keys()],
    axis=0,
).reset_index()
# The unnamed outermost index level holds the state key
store_df = store_df.rename(columns={"level_0": "state"})

outputname = "sum_rr_diff_article"
store_df.to_csv(OUTPUT_DIR / f"{outputname}.csv")

## Distance error

In [22]:
# Box plots of the distance error (pred_dist) per method and leadtime:
# panel (a) for all cells, followed by one panel per cell state.

# Per-state data. `where` keeps entries for which the predicted OR the
# observed rain rate sum is positive and masks (sets to NaN) the rest.
ds_cs = DATASET_CELL_STATE[["pred_sum_rr", "obs_sum_rr", "pred_dist", "state"]].where(
    (DATASET_CELL_STATE["pred_sum_rr"] > 0) | (DATASET_CELL_STATE["obs_sum_rr"] > 0)
)
ds_cs = ds_cs.to_dataframe().reset_index()
ds_cs = ds_cs.drop_duplicates(subset=["sample", "track", "leadtime", "method"])
groups = ds_cs.groupby("state")

fig, axs = plt.subplots(
    ncols=len(groups) + 1,
    nrows=1,
    figsize=((FIG_WIDTH + 0.2) * len(groups) + 1, FIG_HEIGHT * 1),
    constrained_layout=True,
    sharey="row",
    squeeze=False,
)

# All-cells data.
# NOTE(review): the xarray mask below uses OR, but the DataFrame filter that
# follows requires BOTH sums to be positive, so the all-cells panel is filtered
# more strictly than the per-state panels above (OR only) — confirm this is
# intended.
ds_ = DATASET_BASE[["pred_sum_rr", "obs_sum_rr", "pred_dist"]].where(
    (DATASET_BASE["pred_sum_rr"] > 0) | (DATASET_BASE["obs_sum_rr"] > 0)
)
ds_ = ds_.to_dataframe().reset_index()
ds_ = ds_[(ds_["pred_sum_rr"] > 0) & (ds_["obs_sum_rr"] > 0)]
ds_ = ds_.drop_duplicates(subset=["sample", "track", "leadtime", "method"])

# Keyword arguments shared by every box plot of this figure.
boxplot_kws = dict(
    x="method",
    y="pred_dist",
    hue="leadtime",
    order=conf.legend_order,
    whis=[5, 95],  # whiskers at the 5th and 95th percentiles
    showfliers=True,
    showmeans=True,
    meanline=True,
    medianprops=MEDIANPROPS,
    meanprops=MEANLINEPROPS,
    flierprops=FLIERPROPS,
    legend="full",
    palette=HUE_CMAP,
)

# Panel (a): all cells
g = sns.boxplot(data=ds_, ax=axs[0, 0], **boxplot_kws)
axs[0, 0].set_title("(a) All cells")

# Remaining panels: one per cell state
for i, (name, group) in enumerate(groups):
    ax = axs[0, 1 + i]
    g = sns.boxplot(data=group, ax=ax, **boxplot_kws)
    ax.set_title(f"({alphabet[i+1]}) {STATE_GROUP_TITLES[name]}")

for ax in axs.flatten():
    # Pin the tick positions before relabelling; relabelling ticks that are
    # still managed by a non-fixed locator triggers Matplotlib's
    # "set_ticklabels() should only be used with a fixed number of ticks" warning.
    method_labels = [get_labelstr(tick.get_text()) for tick in ax.get_xticklabels()]
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(method_labels)
    # Drop the per-panel legend; a shared figure legend is built below.
    ax.get_legend().remove()

    ax.set_autoscale_on(False)
    ax.set_ylabel(CENTROID_DISTANCE_TITLE)
    ax.set_ylim(CENTROID_DISTANCE_LIMITS)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(5))
    ax.grid(axis="y")
    ax.axhline(0, **ZEROLINE_PROPS)
    ax.set_xlabel(METHOD_X_LABEL)

# Empty proxy lines so that the median and mean line styles get legend entries.
# They are added last, so they are the final two handles returned below.
medline = axs[0, 0].plot([], [], **MEDIANPROPS, label="Median")
meanline = axs[0, 0].plot([], [], **MEANLINEPROPS, label="Mean")

h, l = axs[0, 0].get_legend_handles_labels()

# Leadtime (hue) legend
l1 = fig.legend(
    h[:-2],
    [leadtime_to_minutes((int(s)), 0) for s in l[:-2]],
    title="Leadtime [min]",
    bbox_to_anchor=(0.38, 1.07, 0, 0),
    loc="center left",
    frameon=True,
    bbox_transform=fig.transFigure,
    ncols=6,
    labelspacing=0.2,
)
# Median / mean legend
l2 = fig.legend(
    h[-2:],
    l[-2:],
    bbox_to_anchor=(0.36, 1.07, 0, 0),
    loc="center right",
    frameon=True,
    bbox_transform=fig.transFigure,
    ncols=1,
)
# NOTE(review): l1 was created with fig.legend, so it is presumably already
# attached to the figure and this call may be redundant — confirm before removing.
fig.add_artist(l1)

outputname = "centroid_distance_article"
save_figs(fig, OUTPUT_DIR, outputname, conf.output_formats)


  axs[0, 0].set_xticklabels([get_labelstr(l.get_text()) for l in axs[0, 0].get_xticklabels()])
  ax.set_xticklabels(
  ax.set_xticklabels(


In [23]:
# Save to csv
# Summary statistics of pred_dist per leadtime and method, first for all cells
# and then for each cell state, stacked with the state as the outermost key.
summary_percentiles = [0.05, 0.25, 0.5, 0.75, 0.95]
summary_keys = ["leadtime", "method"]

all_stats = ds_.groupby(summary_keys).describe(percentiles=summary_percentiles)["pred_dist"]
state_stats = {
    state: state_df.groupby(summary_keys).describe(percentiles=summary_percentiles)["pred_dist"]
    for state, state_df in groups
}

store_df = pd.concat(
    [all_stats, *state_stats.values()],
    keys=["all", *state_stats.keys()],
    axis=0,
).reset_index()
# The unnamed outermost index level holds the state key
store_df = store_df.rename(columns={"level_0": "state"})

outputname = "centroid_distance_article"
store_df.to_csv(OUTPUT_DIR / f"{outputname}.csv")

In [24]:
# Drop the intermediate frames of the previous figures; the next section
# rebuilds all three names from scratch.
del ds_cs
del ds_
del groups

## Distributions of variables for all cells and by cell state

In [25]:
# Data for the histograms. Each variable is listed once: the original
# selections named "track_max_prev_rr" and "prev_mean_rr" twice, which selects
# the same variables.

# Per-state data: one row per (sample, track, prev_time), restricted to the
# rows with prev_time == 0.
ds_cs = DATASET_CELL_STATE[
    [
        "track_max_prev_rr",
        "track_max_obs_rr",
        "sum_rr_diff",
        "prev_max_rr",
        "state",
        "max_obs_area",
        "max_prev_area",
        "max_area",
        "prev_sum_rr",
        "prev_mean_rr",
        "prev_area",
        "lifetime_full",
    ]
]
ds_cs = ds_cs.to_dataframe().reset_index()
ds_cs = ds_cs.drop_duplicates(subset=["sample", "track", "prev_time"])

ds_cs = ds_cs[ds_cs["prev_time"] == 0]

groups = ds_cs.groupby("state")

# All-cells data: same de-duplication and prev_time == 0 restriction.
# NOTE(review): this frame additionally requires track_max_prev_rr > 0 and
# prev_sum_rr > 0, which the per-state frame above does not — confirm the
# difference is intended.
ds_ = DATASET_BASE[
    [
        "track_max_prev_rr",
        "track_max_obs_rr",
        "sum_rr_diff",
        "prev_max_rr",
        "max_obs_area",
        "max_prev_area",
        "max_area",
        "prev_sum_rr",
        "prev_mean_rr",
        "prev_area",
        "lifetime_full",
    ]
]
ds_ = ds_.to_dataframe().reset_index()
ds_ = ds_.drop_duplicates(subset=["sample", "track", "prev_time"])
ds_ = ds_[(ds_["prev_time"] == 0) & (ds_["track_max_prev_rr"] > 0) & (ds_["prev_sum_rr"] > 0)]

# Variables to plot, one figure row each (top to bottom)
variables = [
    "prev_sum_rr",
    "prev_area",
    "max_area",
    "lifetime_full",
]

# One row per variable; first column for all cells, then one column per state.
# Axes are shared within a row only.
fig, axs = plt.subplots(
    ncols=len(groups)+1,
    nrows=len(variables),
    figsize=(HIST_FIG_W * (len(groups)+1), HIST_FIG_H * len(variables) + 0.3),
    layout='compressed',
    sharey="row",
    sharex="row",
    squeeze=True
)


# Histogram settings, keyed by variable name.

# (lower, upper) value range covered by the histogram bins
bin_ranges = {
    "prev_mean_rr": (0, 50),
    "prev_max_rr": (0, 125),
    "prev_area": (0, 1600),
    "lifetime_full": (0, 15),
    "max_area": (0, 1600),
    "prev_sum_rr": (0, 200),
    "sum_rr_diff": (-200, 400),
    "track_max_obs_rr": (0, 200),
    "track_max_prev_rr": (0, 200),
}
# Number of bins within the range above
nbins = {
    "prev_mean_rr": 50,
    "prev_max_rr": 50,
    "prev_area": 160,
    "lifetime_full": 15,
    "max_area": 160,
    "prev_sum_rr": 400,
    "sum_rr_diff": 600,
    "track_max_obs_rr": 400,
    "track_max_prev_rr": 400,
}

# Variable part of the panel titles (Matplotlib mathtext). Raw strings keep the
# backslash of "\mathrm" literal; in a normal string "\m" is an invalid escape
# sequence that Python warns about.
titles = {
    "prev_mean_rr": r"$R_\mathrm{avg}$($t_0$)",
    "prev_max_rr": r"$R_\mathrm{max}$($t_0$)",
    "prev_area": "$A(t_0)$",
    "lifetime_full": "lifetime $L$",
    "max_area": r"$A_\mathrm{max}$",
    "prev_sum_rr": "RVR($t_0$)",
    "sum_rr_diff": "",
    "track_max_obs_rr": r"RVR$_\mathrm{max, target}$",
    "track_max_prev_rr": r"RVR$_\mathrm{max, input}$",
}

# Whether the variable is passed to the histogram as discrete.
# "sum_rr_diff" is included so that this dict has the same keys as the others.
discrete = {
    "prev_mean_rr": False,
    "prev_max_rr": False,
    "prev_area": False,
    "lifetime_full": True,
    "max_area": False,
    "prev_sum_rr": False,
    "sum_rr_diff": False,
    "track_max_obs_rr": False,
    "track_max_prev_rr": False,
}

# Position (axes fraction) and box style of the "N = ..." annotation
count_loc = (0.95, 0.95)
props = dict(facecolor='white', alpha=0.5)

# Filled by the plotting loop: histograms[variable][state] -> count table
histograms = {}

# One figure row per variable: column 0 shows all cells, the following columns
# one cell state each. Panel letters run row by row, so the letter index is
# derived from the number of columns rather than a hard-coded 3.
n_panel_cols = len(groups) + 1

for row, var in enumerate(variables):
    histograms[var] = {}

    # Lazily yield ("all", ds_) followed by every (state, frame) pair of `groups`.
    panels = (
        (key, frame)
        for source in ([("all", ds_)], groups)
        for key, frame in source
    )
    for i, (name, group) in enumerate(panels):
        ax = axs[row, i]

        # Plot distribution
        g = sns.histplot(
            data=group,
            x=var,
            ax=ax,
            stat="percent",
            color="k",
            bins=nbins[var],
            binrange=bin_ranges[var],
            discrete=discrete[var],
        )

        # Save histogram as raw counts, indexed by the left bin edge.
        # NOTE(review): the plot shows percentages, and for discrete variables
        # seaborn presumably centres the bars on the integer values, so the
        # exported bins may not line up with the plotted bars — confirm.
        hist, bins = np.histogram(group[var], bins=nbins[var], range=bin_ranges[var])
        histograms[var][name] = pd.DataFrame(hist,  index=bins[:-1], columns=["count"])

        # Add label with number of (non-missing) observations of this variable
        num_obs = group[var].count()
        ax.text(
            *count_loc,
            f"N = {num_obs:,d}",
            horizontalalignment="right",
            verticalalignment="top",
            transform=ax.transAxes,
            fontsize="medium",
            bbox=props,
        )
        ax.set_title(
            f"({alphabet[n_panel_cols * row + i]}) {STATE_GROUP_TITLES[name]}: {titles[var]}",
            fontsize="medium",
        )

# Axis formatting per figure row, driven by one settings table instead of four
# separate loops. Keys: y limits, y/x tick spacing, x limits, x label, an
# optional dashed vertical reference line, and an optional x tick formatter.
axis_specs = {
    # rain rate sum
    "prev_sum_rr": dict(
        ylim=[0, 35], ymajor=5, xmajor=5, xminor=1, xlim=(0, 20),
        xlabel="Volume rain rate [10$^6$ m$^3$h$^{-1}$]",
        vline=None, xformatter=None,
    ),
    # cell area
    "prev_area": dict(
        ylim=[0, 16], ymajor=2, xmajor=250, xminor=50, xlim=(0, 1600),
        xlabel="Cell area [km$^2$]",
        vline=25, xformatter=None,
    ),
    # maximum cell area
    "max_area": dict(
        ylim=[0, 10], ymajor=2, xmajor=250, xminor=50, xlim=(0, 1600),
        xlabel="Cell area [km$^2$]",
        vline=25, xformatter=None,
    ),
    # track lifetime; tick values are converted with leadtime_to_minutes
    "lifetime_full": dict(
        ylim=[0, 35], ymajor=5, xmajor=2, xminor=1, xlim=(0.5, 15.5),
        xlabel="Track lifetime [min]",
        vline=None, xformatter=leadtime_to_minutes,
    ),
}

for spec_var, spec in axis_specs.items():
    for ax in axs[variables.index(spec_var), :].flatten():
        ax.set_ylabel("Proportion [%]")
        ax.set_ylim(spec["ylim"])
        ax.yaxis.set_major_locator(ticker.MultipleLocator(spec["ymajor"]))
        ax.xaxis.set_major_locator(ticker.MultipleLocator(spec["xmajor"]))
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(spec["xminor"]))
        if spec["xformatter"] is not None:
            ax.xaxis.set_major_formatter(ticker.FuncFormatter(spec["xformatter"]))
        ax.grid(axis="y")
        ax.set_xlim(spec["xlim"])
        if spec["vline"] is not None:
            # Vertical reference line
            ax.axvline(spec["vline"], linestyle="--", color="k", linewidth=1.5)
        ax.set_xlabel(spec["xlabel"])
    
# Add some space between rows
# fig.get_layout_engine().set(hspace=0.05)

# Write the figure
outputname = "histograms_article"
save_figs(fig, OUTPUT_DIR, outputname, conf.output_formats)

# Save histograms: one CSV per variable, stacking the count tables of all
# states; the two index levels become the "state" and "value" columns.
for var in variables:
    tables_by_state = histograms[var]
    store_df = (
        pd.concat(tables_by_state.values(), keys=tables_by_state.keys(), axis=0)
        .reset_index()
        .rename(columns={"level_0": "state", "level_1": "value"})
    )
    outputname = f"histograms_article_{var}"
    store_df.to_csv(OUTPUT_DIR / f"{outputname}.csv")

# Free up memory
# del ds_cs