## Imports

In [None]:
import math
import matplotlib
import os

import ccc
import utils
import utils_plots
import utils_shap

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch
from mpl_toolkits.basemap import Basemap
from scipy import stats

## Settings

In [None]:
# Default font size (12) applied to every matplotlib figure created in this notebook.
font = {'size' : 12}
matplotlib.rc('font', **font)

## Preparations

In [None]:
# Directory of the model run whose test dataframe and SHAP values are loaded below.
model_root_tm_path = os.path.join(ccc.MODEL_ROOT_PATH, 'targetmode_1', '2022_02_21__11-11__ALDIS_paper')

In [None]:
# Pickled test dataframe of model 00018, stored next to the model.
df_path = os.path.join(model_root_tm_path, "_model_00018_test_df.pickle")

# NOTE(review): unpickling can execute arbitrary code; only load this file from a trusted location.
dd = pd.read_pickle(df_path)

In [None]:
# Columns shown in most plots below: one "<name>_relative_shap" column per
# per-level training variable, plus the "hour" and "dayofyear" features.
PLOT_COLS = [f"{c}_relative_shap" for c in (ccc.LVL_TRAIN_COLS + ["hour", "dayofyear"])]

In [None]:
dd

In [None]:
# Parquet dataset with the SHAP values of model 00018
# (the file name suggests a per-lon/lat background without flashes — TODO confirm).
shap_path = os.path.join(model_root_tm_path, "_model_00018_shap_parquet_bg_by_lon_lat_no_flash")
dshap = pd.read_parquet(shap_path)

In [None]:
# Rename every index column from the name returned by utils_shap.colname_meta_infix(col)
# back to its plain name `col` (in place), then display the frame.
dshap.rename(columns={utils_shap.colname_meta_infix(col) : col for col in ccc.INDEX_COLS}, inplace=True)
dshap

In [None]:
sns.histplot(dshap["shap_base_value"])

In [None]:
# Score threshold that turns the model output into a flash / no-flash decision,
# and the derived threshold separating "very confident" from "less confident" hits.
used_threshold = 0.8708981871604919
vc_threshold = utils.getVeryConfidentThreshold(used_threshold)

df_joined = utils.joinDataframes(dshap, dd)

# Boolean views of prediction and ground truth, computed once and reused below.
predicted_flash = df_joined["output"] > used_threshold
observed_flash = df_joined["target"] > 0.5  # target col only contains 0s and 1s.

df_joined.loc[:, "pred_class"] = np.where(predicted_flash, "pred_flash", "pred_no_flash")
df_joined.loc[:, "real_class"] = np.where(observed_flash, "real_flash", "real_no_flash")

# Confusion-matrix category per row; true positives are split at vc_threshold.
# Rows matching none of the conditions are labelled 'ERROR'.
true_positive = predicted_flash & observed_flash
category_conditions = {
    'TP_LC': true_positive & (df_joined["output"] < vc_threshold),
    'TP_VC': true_positive & (df_joined["output"] >= vc_threshold),
    'FN': ~predicted_flash & observed_flash,
    'FP': predicted_flash & ~observed_flash,
    'TN': ~predicted_flash & ~observed_flash,
}

df_joined.loc[:, 'cat'] = np.select(
    list(category_conditions.values()),
    list(category_conditions.keys()),
    default='ERROR'
)

# Numeric id per category (used as hue in the plots); -1 marks unmatched rows.
cluster_ids = {'TP_LC': 0, 'TP_VC': 1, 'FN': 2, 'FP': 3, 'TN': 4}

df_joined.loc[:, 'cluster'] = np.select(
    [df_joined['cat'] == category for category in cluster_ids],
    list(cluster_ids.values()),
    default=-1
)

df_joined = df_joined.rename(columns={"output": "pred_score"})

In [None]:
def sig(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``; works on scalars and array-likes."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


def invsig(y):
    """Inverse of the logistic sigmoid (logit): ``log(y / (1 - y))``."""
    odds = y / (1 - y)
    return np.log(odds)

In [None]:
# One column per training variable: the row-wise sum of its per-level
# SHAP columns "<var>_shapval_lvl64" ... "<var>_shapval_lvl137".
df_plot_shap_cols = []
for varname in ccc.LVL_TRAIN_COLS:
    level_columns = [f"{varname}_shapval_lvl{lvl}" for lvl in range(64, 138)]
    df_plot_shap_cols.append(df_joined[level_columns].sum(axis=1))

df_plot_shap = pd.concat(df_plot_shap_cols, axis=1)
df_plot_shap.columns = [f"{c}_shap" for c in ccc.LVL_TRAIN_COLS]

In [None]:
# One column per training variable: the row-wise sum of its per-level
# meta columns "<var>_meta_lvl64" ... "<var>_meta_lvl137".
df_plot_meta_cols = []
for varname in ccc.LVL_TRAIN_COLS:
    level_columns = [f"{varname}_meta_lvl{lvl}" for lvl in range(64, 138)]
    df_plot_meta_cols.append(df_joined[level_columns].sum(axis=1))

df_plot_meta = pd.concat(df_plot_meta_cols, axis=1)
df_plot_meta.columns = [f"{c}_meta" for c in ccc.LVL_TRAIN_COLS]

In [None]:
# Joined per-sample frame extended by the per-variable SHAP sums and meta sums (column-wise concat).
df_all = pd.concat([df_joined, df_plot_shap, df_plot_meta], axis=1)

In [None]:
invsig(used_threshold)

In [None]:
# Relative SHAP = SHAP value divided by the gap between the logit of the decision
# threshold and the sample's SHAP base value. The gap is identical for every
# feature, so it is computed once.
logit_gap = invsig(used_threshold) - df_all["shap_base_value"]

relative_shap_sources = [(f"{c}_relative_shap", f"{c}_shap") for c in ccc.LVL_TRAIN_COLS]
relative_shap_sources += [
    ("hour_relative_shap", "hour_shapval"),
    ("dayofyear_relative_shap", "dayofyear_shapval"),
]

for relative_col, shap_col in relative_shap_sources:
    df_all[relative_col] = df_all[shap_col] / logit_gap

In [None]:
# Columns kept in the per-category frames: absolute and relative SHAP sums,
# the base value, the category / cluster labels and the index columns.
cols = (
    [f"{c}_shap" for c in ccc.LVL_TRAIN_COLS]
    + [f"{c}_relative_shap" for c in ccc.LVL_TRAIN_COLS + ["hour", "dayofyear"]]
    + ["shap_base_value", "cat", "cluster"]
    + ccc.INDEX_COLS
)

# One frame per confusion-matrix category (TP contains both confidence levels).
df_TP = df_all.loc[df_all['cat'].isin(["TP_LC", "TP_VC"]), cols]
df_FP = df_all.loc[df_all['cat'] == "FP", cols]
df_FN = df_all.loc[df_all['cat'] == "FN", cols]
df_TN = df_all.loc[df_all['cat'] == "TN", cols]

# True positives split into less confident / very confident.
df_TP_LC = df_TP[df_TP['cat'] == "TP_LC"]
df_TP_VC = df_TP[df_TP['cat'] == "TP_VC"]

In [None]:
# True positives in which one variable group's summed relative SHAP exceeds 0.5:
#   cloud: cswc + ciwc + crwc + clwc
#   mass:  q + t
#   wind:  u + v + w
df_cloud = df_TP[df_TP['cswc_relative_shap'] + df_TP['ciwc_relative_shap'] + df_TP['crwc_relative_shap'] + df_TP['clwc_relative_shap'] > 0.5]
df_mass = df_TP[df_TP['q_relative_shap'] + df_TP['t_relative_shap'] > 0.5]
df_wind = df_TP[df_TP['u_relative_shap'] + df_TP['v_relative_shap'] + df_TP['w_relative_shap'] > 0.5]

# Each group-dominant subset combined with all true negatives, for TP-vs-TN comparisons.
df_cloud_plus_TN = pd.concat([df_cloud, df_TN])
df_mass_plus_TN = pd.concat([df_mass, df_TN])
df_wind_plus_TN = pd.concat([df_wind, df_TN])

In [None]:
# Report how many true positives fall into each group-dominant subset.
print(f"Number of samples in cloud-dominant TPs:\t\t{len(df_cloud)}")
print(f"Number of samples in mass-dominant TPs: \t\t {len(df_mass)}")
print(f"Number of samples in wind-dominant TPs: \t\t{len(df_wind)}")

In [None]:
# Used for debugging to quickly reload utils_plots.py
# (picks up edits to the module without restarting the kernel).
# NOTE(review): this import sits outside the import cell at the top of the notebook.
import importlib
importlib.reload(utils_plots)

In [None]:
# Cluster id -> category name; the ids match the "cluster" column of the joined frame.
# Used to relabel plot legends, which otherwise show the numeric ids.
plot_clusters = {
                0 : 'TP_LC',
                1 : 'TP_VC',
                2 : 'FN',
                3 : 'FP',
                4 : 'TN',
}


def get_color_palette(categories_to_plot=['TP_LC', 'TP_VC', 'FN', 'FP', 'TN']):
    """Build a seaborn palette with one colour per category, in the given order.

    Colours are looked up in ``utils_plots.CLUSTER_COLORS``; by default all five
    categories are included.
    """
    colors = []
    for category in categories_to_plot:
        colors.append(utils_plots.CLUSTER_COLORS[category])

    return sns.color_palette(colors)

In [None]:
# Box plot of every relative SHAP column for the less confident true positives.
# Note that `fig` holds the object returned by sns.boxplot (a matplotlib Axes).
fig = sns.boxplot(df_TP_LC[PLOT_COLS])
plt.xticks(rotation=90)

fig

In [None]:
# Box plot of the SHAP base value, one box per cluster id; legend placed right of the axes.
fig = sns.boxplot(df_all, y="shap_base_value", hue="cluster", palette=get_color_palette())
fig.legend(loc='center left', bbox_to_anchor=(1, 0.5))

# Replace the numeric cluster ids in the legend by the category names.
# NOTE(review): legend entries are indexed by cluster id, which presumably relies on
# all five clusters (0..4) being present in df_all — confirm.
for key in plot_clusters:
    fig.legend_.texts[key].set_text(plot_clusters[key])

In [None]:
def plot_bars_per_variable(df, title=""):
    """Draw one box plot per column in ``PLOT_COLS``, split by cluster, and save the figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the ``PLOT_COLS`` columns plus ``"cluster"`` and ``"cat"``.
        Only the clusters present in ``df`` get a colour and a legend entry.
        The frame is not modified.
    title : str, optional
        Figure title; when non-empty it is also appended to the output file name.

    The figure is written to ``tmp/boxplots[_<title>].png``.
    """
    f, axs = plt.subplots(len(PLOT_COLS), 1, constrained_layout=True, sharex=True)
    f.set_figheight(50)
    f.set_figwidth(15)

    if title != "":
        f.suptitle(title)

    # Sort into a new frame (instead of in place) so the caller's dataframe keeps
    # its row order; sorting makes unique() return clusters/categories in id order.
    df = df.sort_values(by="cluster")

    avail_clusters = df["cluster"].unique()
    avail_categories = df["cat"].unique()
    palette = get_color_palette(avail_categories)

    for ax, varname in zip(axs, PLOT_COLS):
        sns.boxplot(df, ax=ax, y=varname, hue="cluster", palette=palette)
        ax.set_ylim(-1, 1)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

        # Replace the numeric cluster ids in the legend by the category names.
        for legend_idx, key in enumerate(avail_clusters):
            ax.legend_.texts[legend_idx].set_text(plot_clusters[key])

    postfix = f"_{title}" if title != "" else ""

    f.savefig(os.path.join("tmp", f"boxplots{postfix}.png"), bbox_inches='tight')

In [None]:
plot_bars_per_variable(df_mass_plus_TN, "mass_shapsum > 0.5")

In [None]:
# One boxen plot per relative SHAP column, split by cluster, for all samples.
f, axs = plt.subplots(len(PLOT_COLS), 1, sharex=True)
f.set_figheight(50)
f.set_figwidth(15)

for idx, varname in enumerate(PLOT_COLS):
    fig = sns.boxenplot(df_all, ax=axs[idx], y=varname, hue="cluster", palette=get_color_palette())
    fig.set_ylim(-1, 1)
    fig.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # Replace the numeric cluster ids in the legend by the category names.
    # NOTE(review): indexed by cluster id; presumably relies on all five clusters being present — confirm.
    for key in plot_clusters:
        fig.legend_.texts[key].set_text(plot_clusters[key])
        
f.savefig(os.path.join("tmp", "boxenplots.png"), bbox_inches='tight')

In [None]:
# Box plots per relative SHAP column comparing true positives with true negatives.
f, axs = plt.subplots(len(PLOT_COLS), 1, sharex=True)
f.set_figheight(50)
f.set_figwidth(15)

# Keep only clusters 0 (TP_LC), 1 (TP_VC) and 4 (TN); is_TP is True for both TP clusters.
df_filtered = df_all.query("cluster in [0.0, 1.0, 4.0]").copy()
df_filtered.loc[:, "is_TP"] = (df_filtered["cluster"] <= 1)

# Colour / legend label for the two is_TP groups, in the order False, True.
# NOTE(review): the True group holds TP_LC and TP_VC rows but is coloured and labelled as 'TP_VC'.
plot_clusters_filtered = {
                            4: 'TN',
                            1: 'TP_VC',
}

# NOTE(review): palette_filtered is not used below; the plots call get_color_palette(...) instead.
palette_filtered = sns.color_palette([utils_plots.CLUSTER_COLORS[plot_clusters_filtered[col]] for col in plot_clusters_filtered] if len(plot_clusters_filtered) > 0 else utils_plots.CLUSTER_COLORS.values())

for idx, varname in enumerate(PLOT_COLS):
    fig = sns.boxplot(df_filtered, ax=axs[idx], y=varname, hue="is_TP", palette=get_color_palette(plot_clusters_filtered.values()))
    fig.set_ylim(-1, 1)
    fig.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # The inner loop reuses the name idx; the outer idx is not read again in this iteration.
    for idx, key in enumerate(plot_clusters_filtered):
        fig.legend_.texts[idx].set_text(plot_clusters_filtered[key])
        
f.savefig(os.path.join("tmp", "boxplots_TP_vs_TN.png"), bbox_inches='tight')

In [None]:
# Per plotted column: compare the median over true positives with the
# third quartile over true negatives.
for col in PLOT_COLS:
    tpm = np.quantile(df_TP[col], 0.5)
    tnq = np.quantile(df_TN[col], 0.75)

    print(f"{col}: TP median = {tpm:0.2f} - TN 3rd quartile = {tnq:0.2f}")

In [None]:
# Per plotted column: one-sided two-sample t-test without the equal-variance
# assumption (equal_var=False), testing whether the TP mean is greater than the TN mean.
for col in PLOT_COLS:
    print(f"{col}: {stats.ttest_ind(df_TP[col], df_TN[col], equal_var=False, alternative='greater')}")

In [None]:
def plot_histograms_per_variable(df, title=""):
    """Overlay histograms of TP_LC, TP_VC and TN for every column in ``PLOT_COLS``.

    Each category is randomly down-sampled to the size of the smallest cluster
    found in ``df`` so that the histograms are comparable. The sampling is not
    seeded, so the figure differs between runs.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the ``PLOT_COLS`` columns plus ``"cluster"`` and ``"cat"``.
    title : str, optional
        Figure title; when non-empty it is also appended to the output file name.

    The figure is written to ``tmp/histgram_relative[_<title>].png``.
    """
    categories = ["TP_LC", "TP_VC", "TN"]
    rows_per_category = df['cluster'].value_counts().min()

    figure, axes = plt.subplots(len(PLOT_COLS), 1, constrained_layout=True, sharey=True)
    figure.set_figheight(60)
    figure.set_figwidth(15)

    if title != "":
        figure.suptitle(title)

    for axis, column in zip(axes, PLOT_COLS):
        for category in categories:
            category_rows = df.query(f"cat == '{category}'")
            sns.histplot(
                category_rows.sample(rows_per_category),
                ax=axis,
                color=utils_plots.CLUSTER_COLORS[category],
                x=column,
            )

        axis.set_xlim([-1, 1])
        axis.legend(labels=categories)

    file_suffix = f"_{title}" if title != "" else ""
    figure.savefig(os.path.join("tmp", f"histgram_relative{file_suffix}.png"), bbox_inches='tight')

In [None]:
plot_histograms_per_variable(df_cloud_plus_TN, "cloud")

In [None]:
plot_histograms_per_variable(df_mass_plus_TN, "mass")

In [None]:
plot_histograms_per_variable(df_wind_plus_TN, "wind")

In [None]:
# Overlaid histograms of the SHAP base value for TP_LC, TP_VC and TN.
# Every group is down-sampled to the size of the smallest one so the counts are comparable.
# NOTE(review): sample() is called without random_state, so the plot differs between runs.
size_of_smallest_cl = min([len(df_TP_LC), len(df_TP_VC), len(df_TN)])

ax = sns.histplot(df_TP_LC.sample(size_of_smallest_cl), x="shap_base_value", color=utils_plots.CLUSTER_COLORS["TP_LC"])
sns.histplot(df_TP_VC.sample(size_of_smallest_cl), ax=ax, x="shap_base_value", color=utils_plots.CLUSTER_COLORS["TP_VC"])
sns.histplot(df_TN.sample(size_of_smallest_cl), ax=ax, x="shap_base_value", color=utils_plots.CLUSTER_COLORS["TN"])
    
ax.legend(labels=["TP less confident", "TP very confident", "TN"])

In [None]:
ax.get_figure().savefig("tmp/basevalues.png", bbox_inches='tight')

In [None]:
sns.histplot(df_TN.sample(size_of_smallest_cl), x="shap_base_value", color=utils_plots.CLUSTER_COLORS["TN"])

In [None]:
df_all[["longitude", "latitude", "shap_base_value"]].describe()

In [None]:
# Mean SHAP base value per (longitude, latitude) grid cell.
df_basevalues = df_all.groupby(["longitude", "latitude"])["shap_base_value"].mean().reset_index()

# Map with corners at (8E, 45N) and (17E, 50N), full-resolution boundaries.
m = Basemap(projection='lcc', resolution='f', lon_0=12.5, lat_0=47.5, llcrnrlon=8, llcrnrlat=45, urcrnrlon=17, urcrnrlat=50)
m.drawmapboundary()
m.drawcountries(linewidth=2)

# One coloured dot per grid cell; latlon=True lets Basemap project the coordinates.
m.scatter(df_basevalues['longitude'], df_basevalues['latitude'], c=df_basevalues["shap_base_value"], cmap="jet", s=10, latlon=True)
plt.colorbar(label='BASE_VALUE', extend="max")

In [None]:
# Number of rows with target == 1 per (longitude, latitude) grid cell.
df_countlightningcells = df_all.query("target == 1.0").groupby(["longitude", "latitude"])["target"].count().reset_index()
df_countlightningcells.rename(columns={"target": "count"}, inplace=True)

# Same map extent as used for the base values.
m = Basemap(projection='lcc', resolution='f', lon_0=12.5, lat_0=47.5, llcrnrlon=8, llcrnrlat=45, urcrnrlon=17, urcrnrlat=50)
m.drawmapboundary()
m.drawcountries(linewidth=2)

m.scatter(df_countlightningcells['longitude'], df_countlightningcells['latitude'], c=df_countlightningcells["count"], cmap="jet", s=10, latlon=True)
plt.colorbar(label='NR_LIGHTNING_CELLS', extend="max")

In [None]:
# Per grid cell: positive-target count next to the mean base value. The outer merge keeps
# grid cells that appear in only one of the two frames (their missing value becomes NaN).
df_countvsbase = df_countlightningcells.merge(df_basevalues, how="outer", on=["longitude", "latitude"])

sns.scatterplot(df_countvsbase, x="count", y="shap_base_value")

In [None]:
def plot_map(df, plotcols=PLOT_COLS, vmin=0, vmax=0.5):
    """Show one map per column with that column's mean per (longitude, latitude) cell.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``"longitude"``, ``"latitude"`` and every column in ``plotcols``.
    plotcols : list of str, optional
        Columns to plot; defaults to ``PLOT_COLS``.
    vmin, vmax : float, optional
        Limits of the colour scale.
    """
    grid_keys = ["longitude", "latitude"]

    for column in plotcols:
        cell_means = df.groupby(grid_keys)[column].mean()
        cell_means = cell_means.reset_index()

        # A fresh map is drawn for every column; plt.show() closes it off below.
        basemap = Basemap(projection='lcc', resolution='f', lon_0=12.5, lat_0=47.5, llcrnrlon=8, llcrnrlat=45, urcrnrlon=17, urcrnrlat=50)
        basemap.drawmapboundary()
        basemap.drawcountries(linewidth=2)

        basemap.scatter(
            cell_means['longitude'],
            cell_means['latitude'],
            c=cell_means[column],
            vmin=vmin,
            vmax=vmax,
            cmap="jet",
            s=20,
            latlon=True,
        )
        plt.colorbar(label=column, extend="max")
        plt.show()

In [None]:
def _count_bin_ticks(max_value):
    """Pick the bin width, colour-bar tick positions and tick labels for a maximum value.

    Larger maxima get wider bins (20, 10, 5, 3 or 2 units wide). Returns the
    tuple ``(divisor, ticks, ticklabels)`` where ``divisor`` is the bin width.
    """
    def roundup(x, divisor):
        return math.ceil(x / divisor) * divisor

    if max_value > 50:
        divisor = 20
        max_rounded = roundup(max_value, divisor)

        ticks = list(np.arange(10, max_rounded, divisor))
        ticklabels = [f'{math.floor(n - 9)} - {math.floor(n + 10)}' for n in ticks]
    elif max_value > 30:
        divisor = 10
        max_rounded = roundup(max_value, divisor)

        ticks = list(np.arange(5.5, max_rounded, divisor))
        ticklabels = [f'{math.floor(n - 4)} - {math.floor(n + 5)}' for n in ticks]
    elif max_value >= 20:
        divisor = 5
        max_rounded = roundup(max_value, divisor)

        ticks = list(range(3, max_rounded, divisor))
        ticklabels = [f'{math.floor(n - 2)} - {math.floor(n + 2)}' for n in ticks]
    elif max_value >= 10:
        divisor = 3
        max_rounded = roundup(max_value, divisor)

        ticks = list(np.arange(1.5, max_rounded, divisor))
        ticklabels = [f'{math.ceil(n - 1)} - {math.ceil(n + 1)}' for n in ticks]
    else:
        divisor = 2
        max_rounded = roundup(max_value, divisor)

        ticks = list(np.arange(1, max_rounded, divisor))
        ticklabels = [f'{math.ceil(n - 0.5)} - {math.ceil(n + 0.5)}' for n in ticks]

    return divisor, ticks, ticklabels


def plot_map_count(df, suffix="", df_norm=None):
    """Plot how many rows of ``df`` fall into each (longitude, latitude) grid cell.

    The values are shown with a discrete colour map whose bin width depends on
    the largest value.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``"longitude"``, ``"latitude"`` and ``"shap_base_value"``.
    suffix : str, optional
        When non-empty the figure is saved as ``tmp/mapplot_<suffix>.png``.
    df_norm : pandas.DataFrame, optional
        When given, the count of each cell is expressed as a percentage of the
        number of ``df_norm`` rows in the same cell. It must contain
        ``"longitude"``, ``"latitude"`` and ``"hour"``; cells missing from
        ``df_norm`` are dropped by the inner merge.
    """
    varname = "shap_base_value"  # does not matter; we only count anyways

    df_var = df.groupby(["longitude", "latitude"])[varname].count().reset_index()

    if df_norm is not None:
        tmpcol = "hour"  # does not matter; we only count
        df_norm_count = df_norm.groupby(["longitude", "latitude"])[tmpcol].count().reset_index()
        df_var = df_var.merge(df_norm_count, on=["longitude", "latitude"])

        # Convert the absolute count into a percentage of the reference count.
        df_var[varname] /= df_var[tmpcol]
        df_var[varname] *= 100

        df_var = df_var.drop(tmpcol, axis=1)

    # The maximum is taken before the shift below so the bin choice uses the true values.
    max_value = df_var[varname].max()

    df_var[varname] -= 0.01  # hacky solution to ensure each sample is in the correct bin

    divisor, ticks, ticklabels = _count_bin_ticks(max_value)

    m = Basemap(projection='lcc', resolution='f', lon_0=12.5, lat_0=47.5, llcrnrlon=8, llcrnrlat=45, urcrnrlon=17, urcrnrlat=50)
    m.drawmapboundary()
    m.drawcountries(linewidth=2)

    # One colour per bin, resampled from a fixed light-to-dark colour list.
    nr_categories = len(ticks)
    cm = ListedColormap(("#F8DCD9","#F4B8C0","#E198B5","#C87AAD","#AA5FA5","#87489D","#5B3794")).resampled(nr_categories)

    m.scatter(df_var['longitude'], df_var['latitude'], c=df_var[varname], cmap=cm, vmin=0, vmax=(nr_categories * divisor), latlon=True, s=20)

    cbar = m.colorbar(ticks=ticks)
    cbar.ax.set_yticklabels(ticklabels)

    if suffix != "":
        plt.savefig(os.path.join("tmp", f"mapplot_{suffix}.png"), bbox_inches='tight')

    plt.show()

In [None]:
plot_map(df_cloud)

In [None]:
plot_map(df_mass)

In [None]:
plot_map(df_wind)

In [None]:
plot_map_count(df_cloud, "cloud_percent", df_TP)

In [None]:
plot_map_count(df_mass, "mass_percent", df_TP)

In [None]:
plot_map_count(df_wind, "wind_percent", df_TP)

In [None]:
plot_map_count(df_TP, "true_positives")

In [None]:
plot_map_count(df_cloud, "cloud")

In [None]:
plot_map_count(df_mass, "mass")

In [None]:
plot_map_count(df_wind, "wind")

In [None]:
# Overlaid histograms of every relative SHAP column for TP_LC, TP_VC and TN,
# each group down-sampled to size_of_smallest_cl rows (unseeded random sample).
f, axs = plt.subplots(len(PLOT_COLS), 1, sharey=True)
f.set_figheight(40)
f.set_figwidth(15)

for idx, varname in enumerate(PLOT_COLS):
    sns.histplot(df_TP_LC.sample(size_of_smallest_cl), ax=axs[idx], color=utils_plots.CLUSTER_COLORS["TP_LC"], x=varname)
    sns.histplot(df_TP_VC.sample(size_of_smallest_cl), ax=axs[idx], color=utils_plots.CLUSTER_COLORS["TP_VC"], x=varname)
    sns.histplot(df_TN.sample(size_of_smallest_cl), ax=axs[idx], color=utils_plots.CLUSTER_COLORS["TN"], x=varname)
    axs[idx].set_xlim([-1, 1])
    
    axs[idx].legend(labels=["TP less confident", "TP very confident", "TN"])

In [None]:
# Overlaid histograms of the summed "<var>_meta" column of every training variable
# for clusters 0 (TP_LC), 1 (TP_VC) and 4 (TN), each down-sampled to size_of_smallest_cl rows.
# Relies on df_filtered and size_of_smallest_cl defined in earlier cells.
f, axs = plt.subplots(len(ccc.LVL_TRAIN_COLS), 1, sharey=True)
f.set_figheight(40)
f.set_figwidth(15)

for idx, varname in enumerate(ccc.LVL_TRAIN_COLS):
    sns.histplot(df_filtered.query("cluster == 0").sample(size_of_smallest_cl), ax=axs[idx], color=utils_plots.CLUSTER_COLORS["TP_LC"], x=f"{varname}_meta")
    sns.histplot(df_filtered.query("cluster == 1").sample(size_of_smallest_cl), ax=axs[idx], color=utils_plots.CLUSTER_COLORS["TP_VC"], x=f"{varname}_meta")
    sns.histplot(df_filtered.query("cluster == 4").sample(size_of_smallest_cl), ax=axs[idx], color=utils_plots.CLUSTER_COLORS["TN"], x=f"{varname}_meta")
    
    axs[idx].legend(labels=["TP less confident", "TP very confident", "TN"])

In [None]:
def plot_stack(df, nr_samples=100):
    """Stack plot of the relative SHAP values of a random sample of rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``"<name>_relative_shap"`` column for every name listed below.
    nr_samples : int, optional
        Number of rows to draw at random. If ``df`` has fewer rows, all of them
        are used instead of raising.

    The sampled rows are sorted in descending order by the listed columns and a
    dashed red line is drawn at a stacked value of 1.
    NOTE(review): 'dayofyear' is not part of the stack here — confirm that is intended.
    """
    cols = [
                'cswc',
                'q',
                'w',
                'ciwc',
                'hour',
                'clwc',
                'crwc',
                't',
                'u',
                'v',
    ]

    cols_r = [f"{c}_relative_shap" for c in cols]

    # DataFrame.sample raises when asked for more rows than exist (without replacement).
    nr_samples = min(nr_samples, len(df))

    # Sort into a new frame rather than in place on a column slice of the sample.
    df_samp = df.sample(nr_samples)[cols_r].sort_values(by=cols_r, ascending=False)

    plt.figure(figsize=(20, 6))
    plt.stackplot(range(nr_samples), df_samp[cols_r].T, labels=cols)
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.hlines(1, -1, nr_samples, color="red", linestyle="dashed")

    plt.show()

In [None]:
def plot_stack_other(df, nr_samples=100):
    """Stack plot of relative SHAP values with the minor features merged into "other".

    'cswc', 'q', 'w' and 'ciwc' are shown individually; the relative SHAP values of
    the remaining features are summed into a single "other" layer.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``"<name>_relative_shap"`` column for every name listed below.
    nr_samples : int, optional
        Number of rows to draw at random. If ``df`` has fewer rows, all of them
        are used instead of raising.

    The sampled rows are sorted in descending order by the individually shown
    columns and a dashed red line is drawn at a stacked value of 1.
    """
    cols = [
                'cswc',
                'q',
                'w',
                'ciwc'
    ]

    other_cols =  [
                    'clwc',
                    'crwc',
                    't',
                    'u',
                    'v',
                    'hour',
                    'dayofyear'
    ]

    cols_r = [f"{c}_relative_shap" for c in cols]
    other_cols_r = [f"{c}_relative_shap" for c in other_cols]

    # DataFrame.sample raises when asked for more rows than exist (without replacement).
    nr_samples = min(nr_samples, len(df))

    df_samp = df.sample(nr_samples)[cols_r + other_cols_r].copy()

    df_samp["other"] = df_samp[other_cols_r].sum(axis=1)
    df_samp = df_samp.sort_values(by=cols_r, ascending=False)

    plt.figure(figsize=(20, 6))
    plt.stackplot(range(nr_samples), df_samp[cols_r + ["other"]].T, labels=cols + ["other"])
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.hlines(1, -1, nr_samples, color="red", linestyle="dashed")

    plt.show()

In [None]:
def plot_stack_grouped(df, nr_samples=100):
    """Stack plot of relative SHAP values summed per variable group.

    Groups: cloud (cswc, ciwc, clwc, crwc), mass (q, t), wind (u, v, w) and
    time (hour, dayofyear).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the ``"<name>_relative_shap"`` columns of all group members.
        The frame itself is not modified.
    nr_samples : int, optional
        Number of rows to draw at random. If ``df`` has fewer rows, all of them
        are used instead of raising.

    The sampled rows are sorted in descending order by time, cloud, wind, mass and
    a dashed red line is drawn at a stacked value of 1.
    """
    # DataFrame.sample raises when asked for more rows than exist (without replacement).
    nr_samples = min(nr_samples, len(df))

    df_samp = df.sample(nr_samples)

    df_samp["cloud"] = df_samp[[f"{c}_relative_shap" for c in ['cswc', 'ciwc', 'clwc', 'crwc']]].sum(axis=1)
    df_samp["mass"] = df_samp[[f"{c}_relative_shap" for c in ['q', 't']]].sum(axis=1)
    df_samp["wind"] = df_samp[[f"{c}_relative_shap" for c in ['u', 'v', 'w']]].sum(axis=1)
    df_samp["time"] = df_samp[[f"{c}_relative_shap" for c in ['hour', 'dayofyear']]].sum(axis=1)

    cols = [
                "time",
                "cloud",
                "wind",
                "mass",
    ]

    df_samp = df_samp.sort_values(by=cols, ascending=False)

    plt.figure(figsize=(20, 6))
    plt.stackplot(range(nr_samples), df_samp[cols].T, labels=cols)
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.hlines(1, -1, nr_samples, color="red", linestyle="dashed")

    plt.show()

In [None]:
def plot_barplots_grouped(df):
    """Box plots of the relative SHAP values summed per variable group, split by cluster.

    Groups: cloud (cswc, ciwc, clwc, crwc), mass (q, t), wind (u, v, w) and
    time (hour, dayofyear). Colours and legend labels are derived from the
    clusters actually present in ``df``, so a frame holding only some clusters
    is labelled correctly as well.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the ``"<name>_relative_shap"`` columns of all group members
        plus ``"cluster"``. The frame itself is not modified.

    The figure is written to ``tmp/boxplots_grouped.png`` and then shown.
    """
    df_tmp = df.copy()

    df_tmp["cloud"] = df_tmp[[f"{c}_relative_shap" for c in ['cswc', 'ciwc', 'clwc', 'crwc']]].sum(axis=1)
    df_tmp["mass"] = df_tmp[[f"{c}_relative_shap" for c in ['q', 't']]].sum(axis=1)
    df_tmp["wind"] = df_tmp[[f"{c}_relative_shap" for c in ['u', 'v', 'w']]].sum(axis=1)
    df_tmp["time"] = df_tmp[[f"{c}_relative_shap" for c in ['hour', 'dayofyear']]].sum(axis=1)

    cols = [
                "time",
                "cloud",
                "mass",
                "wind",
    ]

    # seaborn draws numeric hue levels in ascending order, so palette and legend
    # labels follow the sorted cluster ids found in the data.
    avail_clusters = sorted(df_tmp["cluster"].unique())
    palette = get_color_palette([plot_clusters[key] for key in avail_clusters])

    f, axs = plt.subplots(len(cols), 1, sharex=True)
    f.set_figheight(50)
    f.set_figwidth(15)

    for ax, varname in zip(axs, cols):
        sns.boxplot(df_tmp, ax=ax, y=varname, hue="cluster", palette=palette)
        ax.set_ylim(-1, 1)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

        # Replace the numeric cluster ids in the legend by the category names.
        for legend_idx, key in enumerate(avail_clusters):
            ax.legend_.texts[legend_idx].set_text(plot_clusters[key])

    f.savefig(os.path.join("tmp", "boxplots_grouped.png"), bbox_inches='tight')

    plt.show()

In [None]:
plot_barplots_grouped(df_all)

In [None]:
plot_stack_other(df_TP, 100)

In [None]:
plot_stack_grouped(df_TP, 300)

In [None]:
plot_stack_grouped(df_TP_LC)

In [None]:
plot_stack(df_TP_VC)

In [None]:
plot_stack(df_TN)

In [None]:
# Relative SHAP value of the hour feature, one box per hour, for rows with cluster id <= 1
# (TP_LC and TP_VC; rows with the -1 fallback id would be included too, if any exist).
plt.figure(figsize=(20, 6))

sns.boxplot(df_all.query("cluster <= 1.0"), y="hour_relative_shap", x="hour")