## Overfitting Exploration

### Useful Preliminaries

In [None]:
import os
import sys
sys.path.append("..")  # add project root

import shutil
import re
from argparse import ArgumentParser
from pickle import dump, load

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

import zarr
import dask.array as da

from ray import tune

from sklearn.metrics import balanced_accuracy_score, roc_auc_score

from src.data_utils import *
from src.constants import *
from src.tuner import train_cv, RayAdaptiveRepeatedCVSearch

In [None]:
# Render floats in DataFrame output with thousands separators and three
# decimals, right-aligned in a 10-character field.
pd.set_option("display.float_format", "{:10,.3f}".format)

In [None]:
# Seed NumPy's global random state so any np.random draws made later in the
# notebook are repeatable across runs.
np.random.seed(420)

In [None]:
# Apply seaborn's theme with the "talk" plotting context to the figures below.
sns.set_theme(context="talk")

In [None]:
# path constants
# Root folder that is listed for dataset sub-directories further down.
# Absolute, machine-specific path: edit it when running elsewhere.
train_dir = "/home/mr2238/project_pi_np442/mr2238/accelerate/data/smooth46/"

In [None]:
# Output folder intended for figures from this notebook. It is created
# (including missing parents) if absent; exist_ok=True makes re-running a no-op.
img_dir = "/home/mr2238/project_pi_np442/mr2238/accelerate/imgs/overfit"
os.makedirs(img_dir, exist_ok=True)

### Loading Model and Results

In [None]:
# Select which tuning run to inspect: the run label, whether the "_debug"
# variant is wanted, and the dataset folder it belongs to.
small = False
run_name = "current"
dataset_name = "smooth_downsample_w_300s_hr_rso2r_rso2l_spo2_abp"
# Resulting folder name is "models_<run>" or "models_debug_<run>".
model_name = "models" + ("_debug" if small else "") + "_" + run_name

In [None]:
# Folder of the selected run: <train_dir>/<dataset_name>/<model_name>.
model_store = os.path.join(train_dir, dataset_name, model_name)
print(model_store)

In [None]:
# Show what the run folder contains before loading anything from it.
print(os.listdir(model_store))

In [None]:
# Load one Ray Tune analysis per entry of the run folder, keyed by entry name.
# Entries ending in ".pkl" are pickled objects, not experiment folders, and
# are ignored here.
model_states = {}
for f in os.listdir(model_store):
    if f.endswith(".pkl"):
        continue
    experiment_path = os.path.join(model_store, f)
    try:
        state = tune.ExperimentAnalysis(experiment_checkpoint_path=experiment_path)
    except ValueError:
        # An entry without a loadable experiment should not abort the whole
        # cell; report it and carry on with the remaining models.
        print(f"Could not find experiment at {experiment_path}, skipping.")
        continue
    model_states[f] = state

In [None]:
# TODO: load the test metrics here, or move this step to eval.py instead

### Plot Best Results

In [None]:
# Peek at the first nine result columns of every loaded experiment.
for name, analysis in model_states.items():
    print(name, analysis.results_df.columns[:9], sep="\n")

In [None]:
# gather results
def gather_results(model_states, metric, others_to_fetch):
    """Collect, for every model, the trial with the highest value of ``metric``.

    Args:
        model_states: Mapping from model name to an object exposing a
            ``results_df`` DataFrame.
        metric: Column used to pick each model's best row (its maximum).
        others_to_fetch: Additional columns to carry over for that row.

    Returns:
        A DataFrame with the columns ``["model", metric, *others_to_fetch]``
        and one row per model that could be summarised. Models whose results
        lack a requested column, or have no usable ``metric`` values, are
        reported and left out. The frame is empty (but keeps those columns)
        when no model qualifies.
    """
    of_interest = ['model', metric, *others_to_fetch]
    rows = []
    for name, analysis in model_states.items():
        df = analysis.results_df
        try:
            # .assign returns a new frame, so the model label is never written
            # into a slice of the experiment's own results table.
            best = df.loc[[df[metric].idxmax()]].assign(model=name)
            rows.append(best[of_interest])
        except (KeyError, ValueError, TypeError) as err:
            # Best-effort: an unusable experiment is skipped, but visibly so.
            print(f"Skipping {name}: {type(err).__name__}: {err}")
    if not rows:
        # pd.concat refuses an empty list; hand back an empty, typed table.
        return pd.DataFrame(columns=of_interest)
    return pd.concat(rows, ignore_index=True)

In [None]:
# Extra columns reported alongside the metric used to pick the best trial.
others = [
    'mean_train_auc',
    'std_val_auc',
    'std_train_auc',
    'mean_val_auc',
    'mean_val_balanced_accuracy',
    'std_val_balanced_accuracy',
    'std_train_balanced_accuracy',
]

In [None]:
# One row per model: the trial with the highest mean *training* balanced
# accuracy, with the columns in `others` shown for that same trial.
r = gather_results(model_states, 'mean_train_balanced_accuracy', others)
print(r)

### Plot results per model

##### Prelims

In [None]:
# From here on the "rapid" runs are inspected (overrides the earlier run_name).
run_name = "rapid"
# Every entry directly under train_dir is treated as a dataset folder.
dataset_names = os.listdir(train_dir)
dataset_names

In [None]:
# Exclude the "debug" folder from the datasets to scan. Filtering (rather than
# list.remove) keeps this cell safe to re-run and tolerant of the folder being
# absent, both of which would otherwise raise ValueError.
dataset_names = [name for name in dataset_names if name != "debug"]

In [None]:
# loop through training dirs, pick out training results per model
def model_path_iter(dataset_names, run_name, root=None):
    """Yield the non-pickle entries of every matching run folder.

    Args:
        dataset_names: Names of dataset folders located under ``root``.
        run_name: Substring a run folder's name must contain to be scanned.
        root: Folder containing the dataset folders. Defaults to the
            notebook-level ``train_dir``.

    Yields:
        ``(path, is_debug)`` tuples. ``path`` is an entry of a matching run
        folder whose name does not end in ".pkl"; ``is_debug`` is True when
        the run folder's name contains "debug". Entries are visited in sorted
        name order so repeated runs produce the same sequence.
    """
    base = train_dir if root is None else root
    for ds in dataset_names:
        ds_path = os.path.join(base, ds)
        if not os.path.isdir(ds_path):
            # Stray files (or stale names) next to the dataset folders cannot
            # be listed; skip them instead of raising.
            continue
        for model_dir in sorted(os.listdir(ds_path)):
            md_path = os.path.join(ds_path, model_dir)
            if run_name not in model_dir or not os.path.isdir(md_path):
                continue
            is_debug = "debug" in model_dir
            for m in sorted(os.listdir(md_path)):
                if not m.endswith(".pkl"):
                    yield os.path.join(md_path, m), is_debug

In [None]:
# Result columns kept for every trial: mean/std of AUC and balanced accuracy
# on the train and validation splits.
of_interest = [
    'mean_val_auc',
    'mean_train_auc',
    'std_val_auc',
    'std_train_auc',
    'mean_train_balanced_accuracy',
    'mean_val_balanced_accuracy',
    'std_val_balanced_accuracy',
    'std_train_balanced_accuracy',
]

##### DF

In [None]:
# Build one table holding the finished trials of every model/dataset pair.
large_results = []
for m, d in model_path_iter(dataset_names, run_name):
    # grab results_df
    try:
        state = tune.ExperimentAnalysis(experiment_checkpoint_path=m)
    except ValueError:
        print(f"Could not find experiment at {m}, skipping.")
        continue
    df = state.results_df
    if df.shape[1] == 0:
        # experiment that reported no results at all
        continue

    missing = [c for c in ['done', *of_interest] if c not in df.columns]
    if missing:
        print(f"Missing columns {missing} at {m}, skipping.")
        continue

    # Keep finished trials only. .loc + .assign builds a fresh frame, so the
    # bookkeeping columns are never written into a filtered slice.
    df = df.loc[df['done'].eq(True), of_interest].assign(
        # add debug flag to df
        debug=d,
        # add model_name
        model=os.path.basename(m),
        # add dataset_name (two levels above the experiment entry)
        dataset=os.path.basename(os.path.dirname(os.path.dirname(m))),
    )

    # combine into one dataset
    large_results.append(df)

if not large_results:
    # pd.concat([]) fails with an unhelpful message; say what went wrong.
    raise ValueError(f"No experiment results found for run {run_name!r}.")
large_result_df = pd.concat(large_results)
print(large_result_df.shape)

In [None]:
# Substring of a model folder name -> label of the input representation.
# The first matching key (in this order) is the one used.
mapping = {
    "_separate_decomp": "separate_pca",
    "_pca": "pca",
    "_raw": "raw",
}
# Model family names searched for in folder names; the first match (in this
# order) is the one used.
modelnames = [
    "log_reg",
    "svm",
    "knn",
    "rand_forest",
    "decision_tree",
    "xgb",
    "rocket",
    "kn_multivar",
]

In [None]:
# Normalise naming variants first so the substring lookups below are
# unambiguous. regex=False makes these literal replacements regardless of the
# pandas version's default for `regex`.
large_result_df["model"] = (
    large_result_df["model"]
    .str.replace("_separate_pca", "_separate_decomp", regex=False)
    .str.replace("knn_multivar", "kn_multivar", regex=False)
)

# Input representation: value of the first `mapping` key contained in the
# folder name, or None when nothing matches.
large_result_df["datamode"] = large_result_df["model"].apply(
    lambda x: next((v for k, v in mapping.items() if k in x), None)
)

# Model family: first entry of `modelnames` contained in the folder name, or
# None when nothing matches. Must run after datamode, which needs the full name.
large_result_df["model"] = large_result_df["model"].apply(
    lambda x: next((m for m in modelnames if m in x), None)
)
large_result_df["model"]

In [None]:
# Long form: one row per trial and split, with the split name in "metric" and
# its AUC in "auc"; the identifying columns are repeated on every row.
long_df = pd.melt(
    large_result_df,
    id_vars=["dataset", "debug", "model", "datamode"],
    value_vars=["mean_train_auc", "mean_val_auc"],
    var_name="metric",
    value_name="auc",
)

In [None]:
# plot all model performances on scatter plot
# One figure per input representation. Rows whose datamode is missing (folder
# name matched no known representation) are left out: they have no label to
# title the figure with and would select no rows anyway.
for embedding in large_result_df.datamode.dropna().unique():
    subset = large_result_df[large_result_df.datamode == embedding]
    plt.figure(figsize=(16, 12))
    g = sns.scatterplot(
        data=subset,
        x="mean_train_auc",
        y="mean_val_auc",
        hue="model",
        style="dataset",
        s=150,
        alpha=0.7,
    )
    # add y=x line: points below it score worse on validation than on training
    g.plot([0, 1], [0, 1], ls='--', c='gray')
    g.set_title(f"{embedding.upper()} model performances")
    g.set_xlabel("Mean Train AUC")
    g.set_ylabel("Mean Val AUC")
    # legend outside the axes, to the right
    g.legend(loc="upper left", bbox_to_anchor=(1.02, 1))
    g.set_ylim(0.3, 0.8)
    g.set_xlim(0.5, 1.0)

    img_name = f"{embedding}_all_models_performance.png"
    # plt.savefig(os.path.join(img_dir, img_name), bbox_inches='tight')


In [None]:
# plot
# Per dataset: grouped bars of train vs. validation AUC, one panel per model
# and one bar group per input representation.
for ds in long_df.dataset.unique():
    plot_df = long_df[(long_df.dataset == ds)].copy()
    if plot_df.empty:
        continue
    # mark models that come from a debug run with a trailing "*"
    plot_df["model"] = np.where(plot_df["debug"], plot_df["model"] + "*", plot_df["model"])

    # sns.catplot is figure-level and creates its own figure, so no
    # plt.figure() call here (it would only leave a blank figure behind).
    g = sns.catplot(
        data=plot_df,
        x="datamode",
        y="auc",
        hue="metric",
        col="model",
        kind="bar",
        dodge=True,
        height=4,
        aspect=1.2,
        col_wrap=3,
        sharex=False,
    )
    g.set_titles("{col_name}")
    # g.set_xticklabels(rotation=30)
    g.set_axis_labels("", "AUC")
    g.set(ylim=(0, 1))
    # g.legend.set_loc("upper right")

    plt.suptitle(f"{ds}", y=1.04)
    # plt.tight_layout()
    # plt.legend(loc=(1,1))
    img_name = f"{ds}.png"
    # plt.savefig(os.path.join(img_dir, img_name), bbox_inches='tight')
    plt.show()


In [None]:
# Close the current matplotlib figure.
plt.close()

In [None]:
# For each split (train / validation AUC), print the five highest-scoring rows.
for metric_name in long_df.metric.unique():
    ranked = long_df[long_df.metric == metric_name].sort_values(by="auc", ascending=False)
    print(f"Top 5 for {metric_name}:", ranked.head(5), sep="\n")

In [None]:
# Rank datasets by validation AUC: first by their best trial, then by the
# average over all their trials (at most 50 lines each).
groups = large_result_df.groupby(['dataset'])
val_auc = groups['mean_val_auc']
for summary in (val_auc.max(), val_auc.mean()):
    print(summary.sort_values(ascending=False).head(50))

### Models

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Highest validation AUC per (dataset, model, datamode), best first; then show
# just the values belonging to the "rocket" model.
groups = large_result_df.groupby(['dataset', 'model', 'datamode'])
best_models = groups['mean_val_auc'].max().sort_values(ascending=False)
best_table = best_models.reset_index()
print(best_table.loc[best_table.model == 'rocket', 'mean_val_auc'])

In [None]:
# loop through training dirs, pick out best model
def model_pkl_iter(dataset_names, run_name, root=None):
    """Yield the pickled entries of every matching run folder.

    Args:
        dataset_names: Names of dataset folders located under ``root``.
        run_name: Substring a run folder's name must contain to be scanned.
        root: Folder containing the dataset folders. Defaults to the
            notebook-level ``train_dir``.

    Yields:
        ``(path, is_debug)`` tuples. ``path`` is an entry of a matching run
        folder whose name ends in ".pkl"; ``is_debug`` is True when the run
        folder's name contains "debug". Entries are visited in sorted name
        order so repeated runs produce the same sequence.
    """
    base = train_dir if root is None else root
    for ds in dataset_names:
        ds_path = os.path.join(base, ds)
        if not os.path.isdir(ds_path):
            # Stray files (or stale names) next to the dataset folders cannot
            # be listed; skip them instead of raising.
            continue
        for model_dir in sorted(os.listdir(ds_path)):
            md_path = os.path.join(ds_path, model_dir)
            if run_name not in model_dir or not os.path.isdir(md_path):
                continue
            is_debug = "debug" in model_dir
            for m in sorted(os.listdir(md_path)):
                if m.endswith(".pkl"):
                    yield os.path.join(md_path, m), is_debug

In [None]:
# big hyper parameter plots for train and test auc
# large_results = []
# for m, d in model_path_iter(dataset_names, run_name):
#     # grab results_df
#     try:
#         state = tune.ExperimentAnalysis(experiment_checkpoint_path=m)
#     except ValueError:
#         print(f"Could not find experiment at {m}, skipping.")
#         continue
#     df = state.dataframe()
#     df = df[df["done"]]
#     hps = [c for c in df.columns if c.startswith("config/") and df[c].dtype != 'object']
#     df = pd.melt(df, id_vars=["mean_val_auc", "mean_train_auc"], value_vars=hps, var_name="hp", value_name="value")

#     print('\n')
#     print(os.path.basename(m), os.path.basename(os.path.dirname(os.path.dirname(m))))
    
#     plt.figure(figsize=(12,6))
#     g = sns.FacetGrid(
#         df,
#         col="hp",
#         col_wrap=3,
#         sharey=True,
#         sharex=False
#     )
#     g.map_dataframe(
#         sns.scatterplot,
#         x="value",
#         y="mean_val_auc",
#         s=10,
#     )
#     g.map_dataframe(
#         sns.scatterplot,
#         x="value",
#         y="mean_train_auc",
#         s=10,
#         color='red',
#     )

#     # g.set_axis_labels(x_var="hp", y_var="mean_val_auc")
#     g.set_titles(col_template="{col_name}", size=8)
#     plt.show()

In [None]:
# confusion matrices for best models TBD
# For now this only loads each pickled model and prints its type and parameters.
for m, d in model_pkl_iter(dataset_names, run_name):
    # NOTE(review): unpickling can execute arbitrary code; only load model
    # files written by this project's own training runs.
    with open(m, "rb") as model_file:  # closes the handle once loaded
        model_inst = load(model_file)
    print(type(model_inst))
    print(model_inst.get_params())
    # load  dataset

    # split train, val


    # fit


    # plot confusion matrix with fitted model

