# Load Packages

In [None]:
# IPython autoreload: mode 2 reloads imported modules before each cell runs,
# so edits to the project's src/ helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
from os import makedirs
from os.path import join
from tqdm.auto import tqdm
import joblib
import torch
sys.path.append("../../")

from src.evaluation.consolidate import consolidate_pred_perf, consolidate_ue_perf, consolidate_pi_perf
from src.evaluation.perf_eval import display_pred_perf
from src.evaluation.ue_eval import display_ue_perf
from src.evaluation.pi_eval import display_pi_perf, reorganise_pi_table, display_pi_perf_reorganised
from src.df_display.latex import df_to_latex, df_to_latex_grouped
from ue_pi_dicts import pi_order
from src.df_display.heatmap import generate_pi_heatmap

# Experiment configuration: consecutive random seeds, dataset label, batch size.
first_seed = 2023
n_seeds = 5
seed_list = list(range(first_seed, first_seed + n_seeds))
data_label = "physionet"
batch_size = 64

# File paths, relative to this notebook's folder.
fp_notebooks_folder = "../"
fp_project_folder = join(fp_notebooks_folder, "../")
fp_processed_data_folder = join(fp_processed_data_folder, data_label) if False else join(fp_project_folder, "processed_data")
# Keyed by data_label (like the checkpoint folders below) so switching dataset
# only requires changing data_label.
fp_output_data_folder = join(fp_processed_data_folder, data_label)
fp_checkpoint_folder = join(fp_project_folder, "checkpoints")
fp_project_checkpoints = join(fp_checkpoint_folder, data_label)
fp_tuning = join(fp_project_checkpoints, "tuning")
fp_models = join(fp_project_checkpoints, "models")
fp_predictions = join(fp_project_checkpoints, "predictions")
fp_evaluation = join(fp_project_checkpoints, "model_evaluation")
fp_consolidated = join(fp_project_checkpoints, "consolidated_results")

# Prediction Performance

In [None]:
# Consolidate prediction-performance metrics over all seeds, display the table
# and save it as pred_perf.csv in the consolidated-results folder.
pred_perf_df = consolidate_pred_perf(seed_list, fp_evaluation)
display_pred_perf(pred_perf_df, consolidated=True)
# to_csv does not create missing parent folders; create it on first run.
makedirs(fp_consolidated, exist_ok=True)
pred_perf_df.to_csv(join(fp_consolidated, "pred_perf.csv"))

In [None]:
# LaTeX export of the per-horizon table; every horizon column (t+1..t+3) uses the "min" rule.
horizon_format = {f"t+{step}": "min" for step in range(1, 4)}
print(df_to_latex(pred_perf_df, column_format_dict=horizon_format))

In [None]:
# Same consolidation, but with one_col=True (the LaTeX cell below formats the
# resulting "Aggregated" column).
# NOTE: this rebinds pred_perf_df, so re-running the per-horizon LaTeX cell after
# this one would see the aggregated table instead.
# The aggregated table is deliberately not written to disk: saving it as
# "pred_perf.csv" would overwrite the per-horizon table saved earlier.
pred_perf_df = consolidate_pred_perf(seed_list, fp_evaluation, one_col=True)
display_pred_perf(pred_perf_df, consolidated=True)

In [None]:
# LaTeX export of the aggregated table; the single "Aggregated" column uses the "min" rule.
aggregated_format = {"Aggregated": "min"}
print(df_to_latex(pred_perf_df, column_format_dict=aggregated_format))

# UE Performance

In [None]:
# Gather UE metrics for every seed, leaving out the "Pval" column(s), and display them.
ue_perf_df = consolidate_ue_perf(
    seed_list,
    fp_evaluation,
    exclude_columns="Pval",
)
display_ue_perf(ue_perf_df, consolidated=True)

In [None]:
# LaTeX export of the UE table. The Sigma=0.3 and Sigma=0.4 columns are dropped
# first, so the format dict only covers the remaining columns.
ue_columns_to_drop = ["Sigma=0.3", "Sigma=0.4"]
ue_format = {
    "Corr": "max",
    "AURC": "min",
    "Sigma=0.1": "min",
    "Sigma=0.2": "min",
}
ue_table_trimmed = ue_perf_df.drop(columns=ue_columns_to_drop)
print(df_to_latex_grouped(ue_table_trimmed, ue_format))

# PI Performance

In [None]:
# Consolidate prediction-interval metrics over all seeds (pi_order presumably
# fixes the method ordering — see ue_pi_dicts), display the table and save it
# as pi_perf.csv in the consolidated-results folder.
pi_metrics = ["CovP", "PINAW", "PINAFD", "CWFDC"]
pi_perf_df = consolidate_pi_perf(
    seed_list, fp_evaluation, selected_columns=pi_metrics,
    pi_order=pi_order,
)
display_pi_perf(pi_perf_df, consolidated=True)
# to_csv does not create missing parent folders; create it on first run.
makedirs(fp_consolidated, exist_ok=True)
pi_perf_df.to_csv(join(fp_consolidated, "pi_perf.csv"))

## Heatmap

In [None]:
# Heatmap of PI metrics across seeds. The two listed RUE variants are bolded and
# two others are dropped; save_fig=True asks the helper to write the figure.
heatmap_options = dict(
    fp_evaluation=fp_evaluation,
    fp_consolidated=fp_consolidated,
    bolded_methods=["RUE Gauss Copula", "RUE KNN"],
    pi_order=pi_order,
    gamma=0.6,
    save_fig=True,
    methods_to_drop=["RUE Cond Gauss", "RUE Weighted"],
    metrics=["CovP", "PINAW", "CWFDC"],
    width=15,
    height=5.5,
)
generate_pi_heatmap(seed_list, **heatmap_options)

## Reorganised Table

In [None]:
# Rebuild the PI table in the reorganised layout, then drop the per-horizon
# CWFDC columns (named "CWFDC\nt+1" .. "CWFDC\nt+3") before displaying it.
pi_perf_df_reorganised = reorganise_pi_table(
    seed_list,
    fp_evaluation,
    pi_order,
    selected_columns=["CovP", "PINAW", "PINAFD", "CWFDC"],
)
cwfdc_horizon_cols = [f"CWFDC\nt+{h}" for h in range(1, 4)]
pi_perf_df_reorganised = pi_perf_df_reorganised.drop(columns=cwfdc_horizon_cols)
display_pi_perf_reorganised(pi_perf_df_reorganised)

## Time-Method Table

In [None]:
# LaTeX export of the consolidated PI table, one formatting rule per metric.
# NOTE(review): CovP uses "min" like the width metrics. If "min" means
# "highlight the best value", coverage would normally favour higher (or
# nearest-to-nominal) values — confirm this is intended.
pi_format = {"PINAW": "min", "PINAFD": "min", "CovP": "min", "CWFDC": "min"}
print(df_to_latex_grouped(pi_perf_df, pi_format))