In [None]:
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from ospkg.constants import RESULTS_DIR

sns.set_theme()

# Number of rows a complete result file is expected to contain (presumably one
# row per cross-validation fold, cf. the `fold_num` column — TODO confirm).
EXPECTED_FOLDS = 5

res = []
# iterdir() order is filesystem-dependent; sort so the concatenated frame has a
# reproducible row order across machines and re-runs.
for res_file in sorted(RESULTS_DIR.iterdir()):
    res_json = pd.read_json(res_file)
    n_rows = len(res_json)
    if n_rows == EXPECTED_FOLDS:
        res.append(res_json)
    elif n_rows == EXPECTED_FOLDS - 1:
        # Runs missing a single row are still used, but reported.
        res.append(res_json)
        print(f"Detected incomplete run: '{res_file}' (len={n_rows}, kept)")
    else:
        print(f"Detected incomplete run: '{res_file}' (len={n_rows}, skipped)")

if not res:
    # pd.concat([]) would fail with an opaque "No objects to concatenate".
    raise ValueError(f"No usable result files found in '{RESULTS_DIR}'")

df = pd.concat(res, ignore_index=True)
# Normalise model names and drop model families excluded from this comparison.
# `.copy()` makes the filtered frame independent, so the `.loc` assignments
# below do not raise SettingWithCopyWarning.
df["model"] = df.model.str.lower()
df = df.loc[~df.model.isin(("sig", "dsig", "box_ord_n"))].copy()

# Append the key hyper-parameter to the model label. After concatenating files
# with differing columns these can be NaN-padded float64, which would render as
# e.g. "bin_n10.0"; the nullable Int64 cast keeps labels like "bin_n10".
is_bin_n = df.model == "bin_n"
df.loc[is_bin_n, "model"] += df.loc[is_bin_n, "n_bins"].astype("Int64").astype(str)
# NOTE(review): "box_ord_n" rows are dropped by the filter above, so this line
# currently matches no rows; kept in case that model is re-enabled.
is_box_ord_n = df.model == "box_ord_n"
df.loc[is_box_ord_n, "model"] += df.loc[is_box_ord_n, "order"].astype("Int64").astype(str)

# Tag SMOTE-resampled datasets and models tuned on validation MSE. `.eq(True)`
# treats missing flags (NaN, e.g. from files without the key) as False instead
# of failing on a non-boolean mask.
df.loc[df.smote.eq(True), "dataset"] += "_smote"
df.loc[df.val_mse.eq(True), "model"] += "_mse"
df["dataset"] += "_" + df.seed.astype(str)
# Non-numeric MSE entries become NaN and are skipped by the means below.
df["mse"] = pd.to_numeric(df["mse"], errors="coerce")

# Per (model, dataset): mean scores, fewest trials of any run, and row count.
agg_values = {f"mean_{v}": (v, "mean") for v in ["c_index", "mse", "best_trial_no"]}
stat_df = df.groupby(["model", "dataset"], dropna=False).agg(
    **agg_values, min_trial_num=("num_trials", "min"), count=("model", "size")
)
stat_df

In [None]:
def plot_heatmap(
    df, measure, sort_by=None, datasets=None, use_rank=False, *, lower_is_better=None, **kwargs
):
    """Draw a model x dataset heatmap of the mean of ``measure``.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format results with ``"model"``, ``"dataset"`` and ``measure`` columns.
    measure : str
        Column averaged per (model, dataset) pair.
    sort_by : str, optional
        Dataset column used to order the rows, best model first.
    datasets : list of str, optional
        Subset and order of dataset columns to display.
    use_rank : bool, default False
        Show each model's rank within its dataset column (1 = best) instead of
        the raw mean. Ties receive the average rank.
    lower_is_better : bool, optional
        Whether smaller values of ``measure`` are better. Defaults to
        ``measure == "mse"``.
    **kwargs
        Forwarded to ``seaborn.heatmap``. These may override the defaults
        ``annot=True`` and ``cmap="viridis"``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on.
    """
    if lower_is_better is None:
        lower_is_better = measure == "mse"

    data = df.groupby(["model", "dataset"])[measure].mean().unstack("dataset")
    if datasets is not None:
        data = data[datasets]
    if use_rank:
        # Rank within each dataset column so that rank 1 is always the best score.
        data = data.rank(ascending=lower_is_better)

    title = f"{measure} (rank)" if use_rank else measure
    if sort_by is not None:
        # Ranks sort ascending; raw values sort so the best score comes first.
        data = data.sort_values(sort_by, ascending=lower_is_better or use_rank)
        title += f" (sorted by {sort_by})"

    # Merge the defaults with kwargs so callers can override cmap/annot without
    # a duplicate-keyword TypeError.
    heatmap_kwargs = {"annot": True, "cmap": "viridis", **kwargs}
    ax = sns.heatmap(data=data, **heatmap_kwargs)
    ax.set_title(title)
    return ax


# Dataset columns (variant x seed) to compare, in display order.
datasets = [
    "snmmi_22",
    "snmmi_42",
    "snmmi_smote_22",
    "snmmi_smote_42",
    "snmmi_gauss_22",
    "snmmi_gauss_42",
    "snmmi_gauss_smote_22",
    "snmmi_gauss_smote_42",
]
# One panel per metric: the raw mean per (model, dataset), with rows ordered by
# the first dataset column.
fig, axes = plt.subplots(1, 2, figsize=(22, 6))
for ax, measure in zip(axes, ("c_index", "mse")):
    plot_heatmap(df, measure, sort_by=datasets[0], datasets=datasets, fmt=".3f", ax=ax)
    ax.set(ylabel=None, xlabel=None)

In [None]:
# Same comparison as the raw-value heatmaps, but each cell shows the model's
# rank within its dataset column (1 = best) instead of the score itself.
datasets = [
    "snmmi_22",
    "snmmi_42",
    "snmmi_smote_22",
    "snmmi_smote_42",
    "snmmi_gauss_22",
    "snmmi_gauss_42",
    "snmmi_gauss_smote_22",
    "snmmi_gauss_smote_42",
]
measures = ("c_index", "mse")
# Size the grid from the measure list so there is exactly one panel per metric
# and no empty slot squeezing the heatmaps.
fig, axes = plt.subplots(1, len(measures), figsize=(22, 6))
for ax, measure in zip(axes, measures):
    plot_heatmap(df, measure, sort_by=datasets[0], datasets=datasets, use_rank=True, ax=ax)
    ax.set(ylabel=None, xlabel=None)

In [None]:
# Spread of each metric per fold, pooled over all models and datasets in `df`.
measures = ("c_index", "mse", "c_index_td")
fig, axes = plt.subplots(1, len(measures), figsize=(18, 5))
for ax, measure in zip(axes, measures):
    sns.boxplot(data=df, y=measure, x="fold_num", ax=ax)
    # A box plot shows median and quartiles, not the mean, so title it as a distribution.
    ax.set_title(f"distribution of {measure} per fold")