## Imports

In [None]:
import json
import os
import torch

import ccc
import utils
import utils_plots
import utils_shap

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from matplotlib.patches import Patch

## Settings

### General

In [None]:
# Target mode of the models to inspect; it selects the "targetmode_<N>"
# sub-directory below ccc.MODEL_ROOT_PATH in the preparation cells.
TARGET_MODE = 1
# NOTE(review): not referenced by any other cell visible in this notebook --
# presumably toggles the SHAP explainer type; confirm or remove.
GRADIENT_EXPLAINER = False

## Preparations

In [None]:
# Models are stored per target mode below "<MODEL_ROOT_PATH>/targetmode_<N>".
tmsubdir = 'targetmode_' + str(TARGET_MODE)
model_root_tm_path = os.path.join(ccc.MODEL_ROOT_PATH, tmsubdir)

In [None]:
# Offer every entry of the target-mode directory in reverse lexicographic
# order; the first entry is pre-selected.
modeldirs = sorted(os.listdir(model_root_tm_path), reverse=True)

wmodel = widgets.Dropdown(options=modeldirs, value=modeldirs[0], description='Choose a model:')

display(wmodel)

In [None]:
# Directory of the model picked in the dropdown above.
modelpath = os.path.join(model_root_tm_path, wmodel.value)

# Only the test-score summaries ("*test_scores.json") are selectable here.
shaps = sorted(x for x in os.listdir(modelpath) if x.endswith("test_scores.json"))

# Fail with an explicit message instead of a bare IndexError on `shaps[0]`.
if not shaps:
    raise FileNotFoundError(f"No '*test_scores.json' file found in {modelpath}")

wshaps = widgets.Dropdown(
    options=shaps,
    value=shaps[0],
    # This dropdown selects a scores file, not a model (label was copy-pasted).
    description='Choose scores:',
)

display(wshaps)

In [None]:
# Read the test scores selected in the dropdown; later cells take the
# decision threshold ("used_threshold") from this dictionary.
test_scores_file = os.path.join(modelpath, wshaps.value)
with open(test_scores_file, 'r') as scores_fh:
    test_scores_json = json.load(scores_fh)

In [None]:
# Number of the stored model whose artefacts are analysed. It is defined once
# so that the identifier passed to `utils.load_model` and the file-name prefix
# of the derived artefacts cannot drift apart.
# NOTE(review): presumably an epoch / checkpoint number -- confirm.
MODEL_NUMBER = 18

# Reuse `tmsubdir` instead of rebuilding the "targetmode_<N>" string.
_, model_name = utils.load_model(os.path.join(tmsubdir, wmodel.value), torch.device("cpu"), str(MODEL_NUMBER))

# Zero-padded to five digits, e.g. "_model_00018".
prefix = f"_model_{MODEL_NUMBER:05d}"

shap_path = os.path.join(modelpath, prefix + "_shap_parquet_bg_by_lon_lat_no_flash")
df_path = os.path.join(modelpath, prefix + "_test_df.pickle")

# SECURITY: unpickling executes arbitrary code -- only load pickles that were
# produced by this project's own pipeline.
dd = pd.read_pickle(df_path)

In [None]:
# Inspect which columns the pickled test frame provides.
dd.columns

In [None]:
# Directory that receives all figures written by this notebook.
vis_save_path = os.path.join(modelpath, prefix + "_shap_plots_bg_no_flash")

# The joined path can never be empty, so create it unconditionally;
# exist_ok=True makes the cell safe to re-run.
os.makedirs(vis_save_path, exist_ok=True)

In [None]:
# Decision threshold stored with the test scores, and the stricter
# "very confident" threshold derived from it by the project helper.
decision_threshold = test_scores_json["used_threshold"]
vc_threshold = utils.getVeryConfidentThreshold(decision_threshold)

# Work on a copy: the original `dd_transf = dd` only aliased the loaded frame,
# so every column added below silently modified `dd` as well.
dd_transf = dd.copy()

dd_transf["pred_class"] = np.where(dd_transf["output"] > decision_threshold, "pred_flash", "pred_no_flash")
dd_transf["real_class"] = np.where(dd_transf["target"] > 0.5, "real_flash", "real_no_flash")  # target col only contains 0s and 1s.

# Boolean building blocks of the confusion-matrix categories.
pred_flash = dd_transf["pred_class"] == "pred_flash"
pred_no_flash = dd_transf["pred_class"] == "pred_no_flash"
real_flash = dd_transf["real_class"] == "real_flash"
real_no_flash = dd_transf["real_class"] == "real_no_flash"
less_confident = dd_transf["output"] < vc_threshold
very_confident = dd_transf["output"] >= vc_threshold

# True positives are split into less confident (TP_LC) and very confident
# (TP_VC); 'ERROR' marks rows that match no category.
dd_transf["cat"] = np.select(
    [
        pred_flash & real_flash & less_confident,
        pred_flash & real_flash & very_confident,
        pred_no_flash & real_flash,
        pred_flash & real_no_flash,
        pred_no_flash & real_no_flash,
    ],
    ["TP_LC", "TP_VC", "FN", "FP", "TN"],
    default="ERROR",
)

# Integer id per category; -1 for anything unmapped (i.e. 'ERROR').
cat_to_cluster = {"TP_LC": 0, "TP_VC": 1, "FN": 2, "FP": 3, "TN": 4}
dd_transf["cluster"] = np.select(
    [dd_transf["cat"] == cat for cat in cat_to_cluster],
    list(cat_to_cluster.values()),
    default=-1,
)

dd_transf = dd_transf.rename(columns={"output": "pred_score"})

In [None]:
# Distribution of the prediction scores of all true positives.
tp_scores = dd_transf.loc[dd_transf["cat"].isin(["TP_LC", "TP_VC"]), "pred_score"]
sns.histplot(tp_scores)

In [None]:
# Configuration files stored next to the model.
with open(os.path.join(modelpath, 'data_cfg.json'), 'r') as f:
    config_data = json.load(f)

with open(os.path.join(modelpath, 'model_cfg.json'), 'r') as f:
    config_model = json.load(f)

traincols = ccc.TRAIN_COLS

# One column per level variable and level: 74 levels starting at level 64,
# i.e. levels 64..137 (the same range the SHAP cells use further below).
level_cols = [f"{col}_lvl{lvl}" for col in ccc.LVL_TRAIN_COLS for lvl in range(64, 64 + 74)]

# De-duplicate; column order is irrelevant for the selection below.
cols = list(set(level_cols + list(ccc.INDEX_COLS) + list(traincols)))

print("Load test data into spark df", flush=True)
# Cloud base / cloud top height are requested in addition to the model inputs.
test_package = utils.get_testdf_spark(config_data, cols + ["cbh", "cth"], None)
sparkdd_test = test_package.drop("features").drop("label")

# All categories are converted, TNs included (they form cluster 4 downstream);
# the old message wrongly claimed that TNs were excluded.
print("Convert predictions (all categories, incl. TNs) into spark df", flush=True)
spark = utils.getsparksession()
sparkdd = spark.createDataFrame(dd_transf)

print("Join the two dfs")
sparkdd = utils.joinDataframes(sparkdd, sparkdd_test)

print("Convert to pandas df")
dd_enriched = sparkdd.toPandas()

print("Free memory")
del sparkdd
del sparkdd_test

## Clustering

In [None]:
# Only true positives are clustered; everything else goes into `dd_fx`.
mask = dd_enriched["cat"].isin(['TP_VC', 'TP_LC'])
dd_tp = dd_enriched[mask]
dd_fx = dd_enriched[~mask]

In [None]:
# Peek at the first true-positive rows.
dd_tp.head(10)

In [None]:
# Number of true-positive rows vs. all remaining rows.
[len(dd_tp), len(dd_fx)]

In [None]:
# Rows per cluster; the smallest count is used later to draw equally sized
# samples from every cluster.
cluster_sizes = dd_enriched['cluster'].value_counts()
size_of_smallest_cl = cluster_sizes.min()
cluster_sizes

## Plotting Profiles

In [None]:
# All columns whose name starts with "geoh_".
geoh_cols = list(filter(lambda name: name.startswith("geoh_"), dd_enriched.columns))

# Index columns, the "geoh_" columns and the category / cluster labels.
selected_cols = ccc.INDEX_COLS + geoh_cols + ['cluster', 'cat']
df_many_cases = dd_enriched[selected_cols]

In [None]:
# Row count before the balanced sampling.
len(df_many_cases)

In [None]:
# Draw the same number of rows from every cluster (the size of the smallest
# one). A fixed seed makes the subsample, and therefore every plot derived
# from it, reproducible across kernel restarts.
SAMPLE_SEED = 42
df_many_cases_sampled = df_many_cases.groupby('cluster').sample(n=size_of_smallest_cl, random_state=SAMPLE_SEED)

In [None]:
# Row count after the balanced sampling (clusters x size of smallest cluster).
len(df_many_cases_sampled)

In [None]:
# Load the stored SHAP values and rename their index columns from the
# "meta"-infixed form (utils_shap.colname_meta_infix) back to the plain
# names in ccc.INDEX_COLS.
meta_to_index_col = {utils_shap.colname_meta_infix(col): col for col in ccc.INDEX_COLS}
dshap = pd.read_parquet(shap_path).rename(columns=meta_to_index_col)

In [None]:
# Attach the SHAP values to the sampled cases; the "flash_meta" column is
# not needed afterwards.
df_many_cases_shap = utils.joinDataframes(df_many_cases_sampled, dshap).drop(columns=["flash_meta"])

In [None]:
def invsig(y):
    """Inverse of the logistic sigmoid (logit).

    Maps a probability-like value ``y`` to its log-odds ``log(y / (1 - y))``.
    Works element-wise on anything numpy can broadcast.
    """
    odds = y / (1 - y)
    return np.log(odds)

In [None]:
# Debugging aid: re-import utils_plots so that edits to utils_plots.py take
# effect without restarting the kernel. Not needed for a clean run.
import importlib
importlib.reload(utils_plots)

In [None]:
# Options forwarded to utils_plots.plot_many_profiles in the plotting cells.
ptype = "q50"  # can be mult, q50, q95
use_cache = False
write_cache = False

separate_clusters = False

# True: cluster true positives by their dominant SHAP group (cloud / mass / wind);
# False: cluster by confusion-matrix category.
# (Was `false`, which is not a Python name and raised NameError.)
plot_grouped = False

only_show_cols = []

y_axis = "geopotential_altitude"  # level, geopotential_altitude

In [None]:
# Decision threshold of the selected model. It is read from the loaded test
# scores (the same value the TP/FN/FP/TN categories were derived from) instead
# of a hard-coded literal that is only valid for one particular model.
used_threshold = test_scores_json["used_threshold"]

df_shap_to_plot = df_many_cases_shap.copy()

# Levels covered by the SHAP value columns (64..137).
shap_levels = range(64, 138)
shapsum_cols = [f"{c}_shapsum" for c in ccc.LVL_TRAIN_COLS]

# Per-variable SHAP total: sum of the variable's SHAP values over all levels.
df_plot_shap_cols = [
    df_shap_to_plot[[f"{varname}_shapval_lvl{lvl}" for lvl in shap_levels]].sum(axis=1)
    for varname in ccc.LVL_TRAIN_COLS
]
df_plot_shap_tmp = pd.concat(df_plot_shap_cols, axis=1)
df_plot_shap_tmp.columns = shapsum_cols

df_shap_to_plot = pd.concat([df_shap_to_plot, df_plot_shap_tmp], axis=1)

# NOTE: from here on `traincols` holds the per-level SHAP column names, not
# ccc.TRAIN_COLS as assigned in the data-loading cell.
traincols = [f"{traincol}_shapval_lvl{lvl}" for traincol in ccc.LVL_TRAIN_COLS for lvl in shap_levels]

# Express every SHAP value as a fraction of the distance between the base
# value and the decision threshold, both in logit space. The denominator does
# not depend on the column, so it is computed once.
shap_scale = invsig(used_threshold) - df_shap_to_plot["shap_base_value"]
for c in traincols + shapsum_cols:
    df_shap_to_plot[c] = df_shap_to_plot[c] / shap_scale

if plot_grouped:
    plot_clusters = {
                    1: 'TP_CLOUD_HIGH',
                    2: 'TP_MASS_HIGH',
                    3: 'TP_WIND_HIGH',
                    4: 'TN',
    }

    is_tp = df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC'])
    cloud_share = df_shap_to_plot['cswc_shapsum'] + df_shap_to_plot['ciwc_shapsum'] + df_shap_to_plot['crwc_shapsum'] + df_shap_to_plot['clwc_shapsum']
    mass_share = df_shap_to_plot['q_shapsum'] + df_shap_to_plot['t_shapsum']
    wind_share = df_shap_to_plot['u_shapsum'] + df_shap_to_plot['v_shapsum'] + df_shap_to_plot['w_shapsum']

    # np.select takes the first matching condition, so a true positive is put
    # into the first group whose normalised SHAP share exceeds 0.5.
    df_shap_to_plot.loc[:, 'cluster'] = np.select(
        [
            is_tp & (cloud_share > 0.5),
            is_tp & (mass_share > 0.5),
            is_tp & (wind_share > 0.5),
            df_shap_to_plot['cat'].isin(['TN']),
        ],
        [1, 2, 3, 4],
        default=-1
    )

    palette = "Pastel1"
else:
    # Only the clusters listed here are plotted; add 2: 'FN', 3: 'FP' or
    # 4: 'TN' to include the other categories.
    plot_clusters = {
                        0: 'TP less confident',
                        1: 'TP very confident',
    }

    cat_to_cluster = {"TP_LC": 0, "TP_VC": 1, "FN": 2, "FP": 3, "TN": 4}
    df_shap_to_plot.loc[:, 'cluster'] = np.select(
        [df_shap_to_plot['cat'] == cat for cat in cat_to_cluster],
        list(cat_to_cluster.values()),
        default=-1
    )

    palette = None

# Drop rows without a cluster and the helper column that is no longer needed.
df_shap_to_plot = df_shap_to_plot.query("cluster != -1").drop(columns=['cat'])

In [None]:
# Check the cluster ids that will be plotted.
df_shap_to_plot["cluster"]

In [None]:
# Profiles of the raw feature values per cluster.
feature_plot_options = dict(
    ptype=ptype,
    y_axis=y_axis,
    separate_clusters=separate_clusters,
    save_path=vis_save_path,
    use_cache=use_cache,
    plot_clusters=plot_clusters,
    only_show_cols=only_show_cols,
    write_cache=write_cache,
)
utils_plots.plot_many_profiles(df_shap_to_plot, "feature", **feature_plot_options)

In [None]:
# Profiles of the (normalised) SHAP values per cluster.
shap_plot_options = dict(
    ptype=ptype,
    y_axis=y_axis,
    separate_clusters=separate_clusters,
    save_path=vis_save_path,
    use_cache=use_cache,
    plot_clusters=plot_clusters,
    only_show_cols=only_show_cols,
    write_cache=write_cache,
)
utils_plots.plot_many_profiles(df_shap_to_plot, "shap", **shap_plot_options)

In [None]:
# Key columns identifying one case plus its cloud base / top height and cluster.
ucols = ["longitude", "latitude", "year", "month", "day", "hour", "cbh", "cth", "cluster"]

# Select from `dd_enriched`: `df_many_cases` is built from ccc.INDEX_COLS, the
# "geoh_*" columns, 'cluster' and 'cat' only, so "cbh" / "cth" are not
# guaranteed to be present there and the lookup could fail with a KeyError.
# Both frames contain the same rows, so the grouping is unchanged.
df_cbh_cth_grouped = dd_enriched[ucols].groupby(ucols)

In [None]:
# Each group key is a tuple ordered like `ucols`; pull out the cloud base
# height, cloud top height and cluster id of every distinct case.
cbh_pos = ucols.index("cbh")
cth_pos = ucols.index("cth")
cluster_pos = ucols.index("cluster")

group_keys = [key for key, _ in df_cbh_cth_grouped]

cbh_ls = [key[cbh_pos] for key in group_keys]
cth_ls = [key[cth_pos] for key in group_keys]
cluster_ls = [key[cluster_pos] for key in group_keys]

In [None]:
df_cbh_cth = pd.DataFrame({"cbh": cbh_ls, "cth": cth_ls, "cluster": cluster_ls})
# Cloud height = cloud top height - cloud base height.
df_cbh_cth["ch"] = df_cbh_cth["cth"] - df_cbh_cth["cbh"]

# Keep only the clusters that are plotted. The explicit copy avoids assigning
# a new column on a filtered slice (SettingWithCopyWarning).
# NOTE(review): `cluster` holds the category ids 0-4 (TP_LC, TP_VC, FN, FP, TN)
# from the categorisation cell; with plot_grouped=True `plot_clusters` uses a
# different numbering, so the labels would not match -- confirm intended use.
df_cbh_cth = df_cbh_cth[df_cbh_cth["cluster"].isin(list(plot_clusters))].copy()
df_cbh_cth["cluster_labels"] = df_cbh_cth["cluster"].map(plot_clusters)

In [None]:
# One colour per plotted cluster, taken from the shared cluster colour table;
# fall back to the whole table when no cluster is selected.
if len(plot_clusters) > 0:
    cluster_colors = [utils_plots.CLUSTER_COLORS[col] for col in plot_clusters.keys()]
else:
    cluster_colors = utils_plots.CLUSTER_COLORS

palette = sns.color_palette(cluster_colors)

In [None]:
# Legend entries: one coloured patch per plotted cluster, labelled with the
# cluster's display name.
legend_elements = []
for cluster_id, cluster_label in plot_clusters.items():
    legend_elements.append(Patch(color=utils_plots.CLUSTER_COLORS[cluster_id], label=cluster_label))

In [None]:
# Per-cluster medians of cbh, cth and ch. numeric_only=True skips the string
# column `cluster_labels`, for which a median is undefined (pandas >= 2.0
# raises a TypeError instead of silently dropping such columns).
cloud_medians = df_cbh_cth.groupby("cluster").median(numeric_only=True)

In [None]:
# Show the per-cluster medians used for the horizontal lines in the violin plots.
cloud_medians

In [None]:
# Violin plot of the cloud top height (cth) per cluster.
g, ax = plt.subplots(figsize=(8, 15))
# `fliersize` is a boxplot option and not part of violinplot's signature:
# older seaborn versions ignore it, newer ones forward unknown keywords to
# matplotlib, which rejects it -- so it is dropped here.
graph = sns.violinplot(data=df_cbh_cth, x="cluster", y="cth", palette=palette, cut=0, ax=ax)
g.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(0.9, 0.5))

# Horizontal line at each cluster's median, in the cluster's colour.
for idx in plot_clusters.keys():
    if idx not in cloud_medians.index:
        continue  # cluster has no rows, hence no median to draw
    graph.axhline(cloud_medians.loc[idx, "cth"], color=utils_plots.CLUSTER_COLORS[idx])

ofile = os.path.join(vis_save_path, "cth_violinplot")
g.savefig(f"{ofile}.pdf")
g.savefig(f"{ofile}.png")

In [None]:
# Violin plot of the cloud base height (cbh) per cluster.
g, ax = plt.subplots(figsize=(8, 15))
# `fliersize` is a boxplot option and not part of violinplot's signature:
# older seaborn versions ignore it, newer ones forward unknown keywords to
# matplotlib, which rejects it -- so it is dropped here.
graph = sns.violinplot(data=df_cbh_cth, x="cluster", y="cbh", palette=palette, cut=0, ax=ax)
g.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(0.9, 0.5))

# Horizontal line at each cluster's median, in the cluster's colour.
for idx in plot_clusters.keys():
    if idx not in cloud_medians.index:
        continue  # cluster has no rows, hence no median to draw
    graph.axhline(cloud_medians.loc[idx, "cbh"], color=utils_plots.CLUSTER_COLORS[idx])

ofile = os.path.join(vis_save_path, "violinplot_cbh")
g.savefig(f"{ofile}.pdf")
g.savefig(f"{ofile}.png")

In [None]:
# Violin plot of the cloud height (ch = cth - cbh) per cluster.
g, ax = plt.subplots(figsize=(8, 15))
# `fliersize` is a boxplot option and not part of violinplot's signature:
# older seaborn versions ignore it, newer ones forward unknown keywords to
# matplotlib, which rejects it -- so it is dropped here.
graph = sns.violinplot(data=df_cbh_cth, x="cluster", y="ch", palette=palette, cut=0, ax=ax)
g.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(0.9, 0.5))

# Horizontal line at each cluster's median, in the cluster's colour.
for idx in plot_clusters.keys():
    if idx not in cloud_medians.index:
        continue  # cluster has no rows, hence no median to draw
    graph.axhline(cloud_medians.loc[idx, "ch"], color=utils_plots.CLUSTER_COLORS[idx])

ofile = os.path.join(vis_save_path, "violinplot_ch")
g.savefig(f"{ofile}.pdf")
g.savefig(f"{ofile}.png")