## Imports

In [1]:
import json
import os
import torch

import ccc
import utils
import utils_plots
import utils_shap

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from matplotlib.patches import Patch

  from pyarrow import LocalFileSystem


## Settings

### General

In [2]:
TARGET_MODE = 1
GRADIENT_EXPLAINER = False

## Preparations

In [3]:
tmsubdir = f'targetmode_{TARGET_MODE}'
model_root_tm_path = os.path.join(ccc.MODEL_ROOT_PATH, tmsubdir)

In [4]:
modeldirs = os.listdir(model_root_tm_path)
modeldirs.sort(reverse=True)

wmodel = widgets.Dropdown(
                    options=modeldirs,
                    value=modeldirs[0],
                    description='Choose a model:',
)

display(wmodel)

Dropdown(description='Choose a model:', options=('2023_10_10__09-55_1D_CNN', '2023_08_24__15-54__fc_testyear_2…

In [5]:
modelpath = os.path.join(model_root_tm_path, wmodel.value)
shaps = os.listdir(modelpath)
shaps.sort()

shaps = [x for x in shaps if x.endswith(".test.json")]

wshaps = widgets.Dropdown(
                    options=shaps,
                    value=shaps[0],
                    description='Choose a model:',
)

display(wshaps)

IndexError: list index out of range

In [None]:
with open(os.path.join(modelpath, wshaps.value), 'r') as f:
    test_scores_json = json.load(f)

In [None]:
gradexplain_postfix = "_gradientexplainer" if GRADIENT_EXPLAINER else ""

_, model_name = utils.load_model(os.path.join(f'targetmode_{TARGET_MODE}', wmodel.value), torch.device("cpu"), test_scores_json["epoch"])

shap_path = os.path.join(modelpath, test_scores_json["prefix_dirs"] + f"_shap_parquet{gradexplain_postfix}")
df_path = os.path.join(modelpath, test_scores_json["test_pickle"])

dd = pd.read_pickle(df_path)

In [None]:
dd.columns

In [None]:
vis_save_path = os.path.join(modelpath, test_scores_json["prefix_dirs"] + f"_shap_plots{gradexplain_postfix}")

if vis_save_path != "":
    if not os.path.isdir(vis_save_path):
        os.makedirs(vis_save_path)

In [None]:
vc_threshold = utils.getVeryConfidentThreshold(test_scores_json["used_threshold"])
dd_transf = dd

dd_transf.loc[:, "pred_class"] = np.where(dd_transf["output"] > test_scores_json["used_threshold"], "pred_flash", "pred_no_flash")
dd_transf.loc[:, "real_class"] = np.where(dd_transf["target"] > 0.5, "real_flash", "real_no_flash")  # target col only contains 0s and 1s.

dd_transf.loc[:, 'cat'] = np.select(
    [
        (dd_transf['pred_class'] == "pred_flash") & (dd_transf['real_class'] == "real_flash") & (dd_transf["output"] < vc_threshold), 
        (dd_transf['pred_class'] == "pred_flash") & (dd_transf['real_class'] == "real_flash") & (dd_transf["output"] >= vc_threshold), 
        (dd_transf['pred_class'] == "pred_no_flash") & (dd_transf['real_class'] == "real_flash"), 
        (dd_transf['pred_class'] == "pred_flash") & (dd_transf['real_class'] == "real_no_flash"), 
        (dd_transf['pred_class'] == "pred_no_flash") & (dd_transf['real_class'] == "real_no_flash"), 
    ], 
    [
        'TP_LC', 
        'TP_VC',
        'FN',
        'FP',
        'TN',
    ], 
    default='ERROR'
)

dd_transf.loc[:, 'cluster'] = np.select(
    [dd_transf['cat'] == "TP_LC", dd_transf['cat'] == "TP_VC", dd_transf['cat'] == "FN", dd_transf['cat'] == "FP", dd_transf['cat'] == "TN",],
    [0, 1, 2, 3, 4,],
    default=-1
)

dd_transf = dd_transf.rename(columns={"output": "pred_score"})

In [None]:
sns.histplot(dd_transf.query("cat in ['TP_LC', 'TP_VC']")['pred_score'])

In [None]:
np.median(dd_transf.query("cat in ['TP_LC', 'TP_VC']")['pred_score'])

In [None]:
vc_threshold

In [None]:
with open(os.path.join(modelpath, 'data_cfg.json'), 'r') as f:
    config_data = json.load(f)

with open(os.path.join(modelpath, 'model_cfg.json'), 'r') as f:
    config_model = json.load(f)

traincols = config_model["traincols"] if ("traincols" in config_model) else utils.get_train_cols(config_data["datamode"])

cols = []

for col in ccc.LVL_COLS_ETL:
    for lvl_idx in range(74):
        lvl = 64 + lvl_idx
        cols.append(f"{col}_lvl{lvl}")

cols.extend(ccc.INDEX_COLS)
cols.extend(traincols)

cols = list(set(cols))
print("Load test data into spark df", flush=True)
test_package = utils.get_testdf_spark(config_data, cols, None, use_months=test_scores_json["used_months"], test_scores_config=test_scores_json)
sparkdd_test = test_package["data"].drop("features").drop("label")

print("Convert test data (excluding TNs) into spark df", flush=True)
spark = utils.getsparksession()
sparkdd = spark.createDataFrame(dd_transf.query("cat != 'TN'"))

print("Join the two dfs")
sparkdd = utils.joinDataframes(sparkdd, sparkdd_test)

print("Convert to pandas df")
dd_enriched = sparkdd.toPandas()

print("Free memory")
del sparkdd
del sparkdd_test

## Clustering

In [None]:
mask = (dd_enriched["cat"] == 'TP_VC') | (dd_enriched["cat"] == 'TP_LC')  # we only cluster for true positives
dd_tp = dd_enriched.loc[mask, :]
dd_fx = dd_enriched.loc[~mask, :]

In [None]:
dd_tp.head(10)

In [None]:
[len(dd_tp), len(dd_fx)]

In [None]:
size_of_smallest_cl = dd_enriched['cluster'].value_counts().min()
dd_enriched['cluster'].value_counts()

## Plotting Profiles

In [None]:
geoh_cols = [col for col in dd_enriched.columns if col.startswith("geoh_")]

df_many_cases = dd_enriched[ccc.INDEX_COLS + geoh_cols + ['cluster']]

In [None]:
len(df_many_cases)

In [None]:
df_many_cases_sampled = df_many_cases.groupby('cluster').sample(size_of_smallest_cl)

In [None]:
len(df_many_cases_sampled)

In [None]:
dshap = pd.read_parquet(shap_path) ## serves as test file
dshap.rename(columns={utils_shap.colname_meta_infix(col) : col for col in ccc.INDEX_COLS}, inplace=True)

In [None]:
df_many_cases_shap = utils.joinDataframes(df_many_cases_sampled, dshap)
df_many_cases_shap.drop(["flash_meta"], axis=1, inplace=True)

In [None]:
# Casestudies
case1 = {  
        "lon" : 15,
        "lat" : 45.25,
        "day" : 11,
        "month" : 6,
        "hour" : 19,
        "year" : 2019,
}

case2 = {  
        "lon" : 9.25,
        "lat" : 48.25,
        "day" : 12,
        "month" : 7,
        "hour" : 19,
        "year" : 2019,
}

case3 = {  
        "lon" : 10.5,
        "lat" : 47.5,
        "day" : 12,
        "month" : 7,
        "hour" : 22,
        "year" : 2019,
}

case4 = {  
        "lon" : 8.75,
        "lat" : 46.25,
        "day" : 14,
        "month" : 7,
        "hour" : 21,
        "year" : 2019,
}

case5 = {  
        "lon" : 16.75,
        "lat" : 48.0,
        "day" : 1,
        "month" : 7,
        "hour" : 21,
        "year" : 2019,
}

case6 = {  
        "lon" : 11.25,
        "lat" : 49.75,
        "day" : 3,
        "month" : 6,
        "hour" : 18,
        "year" : 2019,
}

case7 = {  
        "lon" : 16.75,
        "lat" : 46.75,
        "day" : 24,
        "month" : 8,
        "hour" : 17,
        "year" : 2019,
}

case8 = {  
        "lon" : 13.5,
        "lat" : 47.0,
        "day" : 1,
        "month" : 8,
        "hour" : 19,
        "year" : 2019,
}

case9 = {  
        "lon" : 13.5,
        "lat" : 48.5,
        "day" : 10,
        "month" : 6,
        "hour" : 20,
        "year" : 2019,
}

case10 = {  
        "lon" : 16.0,
        "lat" : 47.0,
        "day" : 27,
        "month" : 6,
        "hour" : 19,
        "year" : 2019,
}

case11 = {  
        "lon" : 16.0,
        "lat" : 45.5,
        "day" : 27,
        "month" : 7,
        "hour" : 19,
        "year" : 2019,
}

case12 = {  
        "lon" : 15,
        "lat" : 47.25,
        "day" : 20,
        "month" : 6,
        "hour" : 20,
        "year" : 2019,
}

case13 = {  
        "lon" : 11.25,
        "lat" : 49.75,
        "day" : 3,
        "month" : 6,
        "hour" : 18,
        "year" : 2019,
}

case14 = {  
        "lon" : 10.25,
        "lat" : 46.25,
        "day" : 24,
        "month" : 7,
        "hour" : 21,
        "year" : 2019,
}

case15 = {  
        "lon" : 11.0,
        "lat" : 48.0,
        "day" : 10,
        "month" : 6,
        "hour" : 16,
        "year" : 2019,
}

only_show_case = None  # choose a case you are interested in; otherwise set to None

df_shap_to_plot = df_many_cases_shap if only_show_case is None else df_many_cases_shap.query(f"longitude == {only_show_case['lon']} and latitude == {only_show_case['lat']} and day == {only_show_case['day']} and month == {only_show_case['month']} and hour == {only_show_case['hour']}")

In [None]:
df_shap_to_plot

In [None]:
# Used for debugging to quickly reload utils_plots.py
import importlib
importlib.reload(utils_plots)

In [None]:
ptype = "q50"  # can be mult, q50, q95
use_cache = True
separate_clusters = False
plot_clusters = {
                    0: 'TP less confident',
#                    1: 'TP very confident',
#                    2: 'FN',
                    3: 'FP',
#                    4: 'TN',
                }

only_show_cols = []
#only_show_cols = ["ciwc", "cswc"]

y_axis = "geopotential_altitude"  # level, geopotential_altitude

In [None]:
utils_plots.plot_many_profiles(df_shap_to_plot, "feature", config_data["datamode"], ptype=ptype, y_axis=y_axis, separate_clusters=separate_clusters, save_path=vis_save_path, use_cache=use_cache, plot_clusters=plot_clusters, only_show_cols=only_show_cols)

In [None]:
utils_plots.plot_many_profiles(df_many_cases_shap, "shap", config_data["datamode"], ptype=ptype, y_axis=y_axis, separate_clusters=separate_clusters, save_path=vis_save_path, use_cache=use_cache, plot_clusters=plot_clusters, only_show_cols=only_show_cols)

In [None]:
ucols = ["longitude", "latitude", "year", "month", "day", "hour", "cbh", "cth", "cluster"]
df_cbh_cth_grouped = df_many_cases.reset_index()[ucols].groupby(ucols)

In [None]:
cbh_ls = []
cth_ls = []
cluster_ls = []

for group, _ in df_cbh_cth_grouped:
    cbh_ls.append(group[ucols.index("cbh")])
    cth_ls.append(group[ucols.index("cth")])
    cluster_ls.append(group[ucols.index("cluster")])

In [None]:
df_cbh_cth = pd.DataFrame({"cbh" : cbh_ls, "cth" : cth_ls, "cluster" : cluster_ls})
df_cbh_cth["ch"] = df_cbh_cth["cth"] - df_cbh_cth["cbh"]
df_cbh_cth = df_cbh_cth[df_cbh_cth["cluster"].isin(plot_clusters.keys())]
df_cbh_cth["cluster_labels"] = df_cbh_cth["cluster"].replace(plot_clusters)

In [None]:
palette = sns.color_palette([utils_plots.CLUSTER_COLORS[col] for col in plot_clusters.keys()] if len(plot_clusters) > 0 else utils_plots.CLUSTER_COLORS)

In [None]:
legend_elements = [Patch(color=utils_plots.CLUSTER_COLORS[key], label=plot_clusters[key]) for key in plot_clusters]

In [None]:
cloud_medians = df_cbh_cth.groupby("cluster").median()

In [None]:
cloud_medians

In [None]:
g = plt.figure(figsize=(8, 15))
graph = sns.violinplot(data=df_cbh_cth, x="cluster", y="cth", fliersize=3, palette=palette, cut=0)
g.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(0.9, 0.5))

for idx in plot_clusters.keys():
    graph.axhline(cloud_medians.query(f"cluster == {idx}")["cth"].values[0], color=utils_plots.CLUSTER_COLORS[idx])

ofile = os.path.join(vis_save_path, f"cth_violinplot")
g.savefig(f"{ofile}.pdf")
g.savefig(f"{ofile}.png")

In [None]:
g = plt.figure(figsize=(8, 15))
graph = sns.violinplot(data=df_cbh_cth, x="cluster", y="cbh", fliersize=3, palette=palette, cut=0)
g.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(0.9, 0.5))

for idx in plot_clusters.keys():
    graph.axhline(cloud_medians.query(f"cluster == {idx}")["cbh"].values[0], color=utils_plots.CLUSTER_COLORS[idx])

ofile = os.path.join(vis_save_path, f"violinplot_cbh")
g.savefig(f"{ofile}.pdf")
g.savefig(f"{ofile}.png")

In [None]:
g = plt.figure(figsize=(8, 15))
graph = sns.violinplot(data=df_cbh_cth, x="cluster", y="ch", fliersize=3, palette=palette, cut=0)
g.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(0.9, 0.5))

for idx in plot_clusters.keys():
    graph.axhline(cloud_medians.query(f"cluster == {idx}")["ch"].values[0], color=utils_plots.CLUSTER_COLORS[idx])

ofile = os.path.join(vis_save_path, f"violinplot_ch")
g.savefig(f"{ofile}.pdf")
g.savefig(f"{ofile}.png")