# Load Packages

In [14]:
%load_ext autoreload
%autoreload 2

import sys
from os.path import join
from tqdm.auto import tqdm
import joblib
import torch
sys.path.append("../../")

from src.file_manager.load_data import load_split_dict
from src.file_processing.processing_predictions import load_prediction_df_dict
from src.file_manager.save_load_scaler import load_scaler
from src.pi_methods.knn import knn_prediction_interval
from src.pi_methods.weighted import weighted_prediction_interval
from src.misc import create_folder

# --- Experiment settings -----------------------------------------------------
seed = 2023               # passed to the PI functions and used to name the per-seed folders
data_label = "physionet"  # selects the dataset sub-folder under data/ and checkpoints/
batch_size = 64

# --- Directory layout ---------------------------------------------------------
# Every path is relative, so it resolves against the kernel's working directory.
fp_notebooks_folder = "../"
fp_project_folder = join(fp_notebooks_folder, "../")
fp_data_folder = join(fp_project_folder, "../", "data")
fp_output_data_folder = join(fp_data_folder, data_label)
fp_checkpoint_folder = join(fp_project_folder, "checkpoints")
fp_project_checkpoints = join(fp_checkpoint_folder, data_label)

# Artifact folders that live directly under this dataset's checkpoint folder.
fp_tuning, fp_models, fp_predictions, fp_pi_predictions = (
    join(fp_project_checkpoints, artifact)
    for artifact in ("tuning", "models", "predictions", "pi_predictions")
)

# --- Per-seed folders ---------------------------------------------------------
# One sub-folder named after the seed inside each artifact folder; each is
# handed to create_folder in the same order as before (models, predictions,
# pi_predictions).
fp_cur_model_folder, fp_cur_predictions_folder, fp_cur_pi_predictions_folder = (
    join(parent, str(seed))
    for parent in (fp_models, fp_predictions, fp_pi_predictions)
)
create_folder(fp_cur_model_folder)
create_folder(fp_cur_predictions_folder)
create_folder(fp_cur_pi_predictions_folder)

# --- Uncertainty-estimation (UE) methods --------------------------------------
# For each method label: the value later passed as `pred_label` and the value
# later passed as `ue_col` to the prediction-interval functions.
ue_dict = {
    method: {"pred_label": pred_label, "ue_col": ue_col}
    for method, pred_label, ue_col in (
        ("RUE", "_direct", "rue"),
        ("MC Dropout", "_mean", "mcd"),
        ("GPR", "_gpr", "gpr_std_mean"),
        ("Infer-Noise", "_infernoise", "infernoise_uncertainty"),
        ("BNN", "_bnn", "bnn_uncertainty"),
        ("DER", "_der", "der_uncertainty"),
    )
}

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load Data

In [2]:
# Load the split dictionary stored in this dataset's data folder; later cells
# read its "feat_cols" entry to obtain the predictor columns.
split_dict = load_split_dict(fp_output_data_folder)

# Load All Predictions

In [6]:
# Scaler stored next to the dataset; it is forwarded unchanged to the
# prediction-interval functions below.
# NOTE(review): the file is named *.pickle - if load_scaler unpickles it, only
# load files from a trusted source.
scaler = load_scaler(join(fp_output_data_folder, "scaler.pickle"))

# Prediction frames for the current seed. Only the "rue", "gpr" and
# "infernoise" prediction files are requested; "bnn" and "der" are left out.
pred_df_dict = load_prediction_df_dict(
    split_dict,
    fp_cur_predictions_folder,
    pred_file_names=["rue", "gpr", "infernoise"],  # "bnn", "der"
)

  0%|          | 0/3 [00:00<?, ?it/s]

# Add KNN PI

In [None]:
# Add KNN prediction intervals, driven by the RUE uncertainty estimate, to the
# test frame of every entry in pred_df_dict (entries are keyed by a time label
# such as "t+1").
ue_label = "RUE"

# The UE settings do not depend on the time label, so look them up once.
ue_info = ue_dict[ue_label]
pred_label = ue_info["pred_label"]
ue_col = ue_info["ue_col"]

for time_label, time_info in tqdm(pred_df_dict.items()):
    val_df = time_info["valid_df"]
    test_df = time_info["test_df"]
    pred_cols = time_info["pred_cols"]

    # The returned frame replaces the stored test frame for this time label.
    pred_df_dict[time_label]["test_df"] = knn_prediction_interval(
        df_val=val_df,
        df_test=test_df,
        predictors=split_dict["feat_cols"],
        pred_cols=pred_cols,
        pred_label=pred_label,
        regressor_label=time_label,
        ue_col=ue_col,
        scaler=scaler,
        seed=seed,
    )


# Add Weighted Percentile PI

In [16]:
# Add weighted-percentile prediction intervals, driven by the RUE uncertainty
# estimate, to the test frame of every entry in pred_df_dict (entries are keyed
# by a time label such as "t+1").
ue_label = "RUE"

# The UE settings do not depend on the time label, so look them up once.
ue_info = ue_dict[ue_label]
pred_label = ue_info["pred_label"]
ue_col = ue_info["ue_col"]

for time_label, time_info in tqdm(pred_df_dict.items()):
    val_df = time_info["valid_df"]
    test_df = time_info["test_df"]
    pred_cols = time_info["pred_cols"]

    # The returned frame replaces the stored test frame for this time label.
    pred_df_dict[time_label]["test_df"] = weighted_prediction_interval(
        df_val=val_df,
        df_test=test_df,
        predictors=split_dict["feat_cols"],
        pred_cols=pred_cols,
        pred_label=pred_label,
        regressor_label=time_label,
        ue_col=ue_col,
        scaler=scaler,
        seed=seed,
    )

  0%|          | 0/3 [00:00<?, ?it/s]

0
2.424518261263724e-05


  0%|          | 0/5 [00:00<?, ?it/s]

0
2.045193056612151e-05


  0%|          | 0/5 [00:00<?, ?it/s]

0
2.147562592395685e-05


  0%|          | 0/5 [00:00<?, ?it/s]

## Add Cond Gauss PI

In [None]:
# Apply cond_gauss_prediction_interval to the RUE uncertainty estimate for every
# entry in pred_df_dict, then persist the result with save_df_dict.
#
# NOTE(review): `cond_gauss_prediction_interval` and `save_df_dict` are not
# imported anywhere in the visible notebook. Add the matching imports to the
# import cell before running this cell - TODO confirm their module paths.

# Restrict this step to RUE without overwriting the notebook-wide `ue_dict`,
# which other cells index by method label.
cond_gauss_ue_dict = {"RUE": ue_dict["RUE"]}

# Iterate the frames loaded above (`pred_df_dict`) and use the same predictor
# columns as the KNN / weighted cells (`split_dict["feat_cols"]`); the names
# previously used here (`df_dict`, `predictors`) are not defined in this notebook.
for time_label, time_info in pred_df_dict.items():
    val_df, pred_cols = time_info["valid_df"], time_info["pred_cols"]
    for ue_label, ue_info in cond_gauss_ue_dict.items():
        pred_label, ue_col = ue_info["pred_label"], ue_info["ue_col"]
        # Re-read the stored test frame on each pass so that results from
        # several UE methods would accumulate on the same frame.
        pred_df_dict[time_label]["test_df"] = cond_gauss_prediction_interval(
            df_val=val_df,
            df_test=pred_df_dict[time_label]["test_df"],
            predictors=split_dict["feat_cols"],
            pred_cols=pred_cols,
            pred_label=pred_label,
            regressor_label=time_label,
            ue_col=ue_col,
            scaler=scaler,
        )

save_df_dict(df_dict=pred_df_dict, seed=seed)

In [17]:
# Show the "t+1" test frame to check the prediction-interval columns added by
# the cells above (e.g. the *_rue_knn and *_rue_weighted columns).
pred_df_dict["t+1"]["test_df"]

Unnamed: 0,HR_t-4,SysABP_t-4,DiasABP_t-4,MAP_t-4,Urine_t-4,time_t-4,HR_t-3,SysABP_t-3,DiasABP_t-3,MAP_t-3,...,HR_t+1_rue_knn,SysABP_t+1_rue_knn,DiasABP_t+1_rue_knn,MAP_t+1_rue_knn,Urine_t+1_rue_knn,HR_t+1_rue_weighted,SysABP_t+1_rue_weighted,DiasABP_t+1_rue_weighted,MAP_t+1_rue_weighted,Urine_t+1_rue_weighted
0,0.295699,0.405303,0.192157,0.267148,0.240,33.783333,0.333333,0.443182,0.207843,0.296029,...,69.588272,55.151863,44.522202,121.469887,543.430359,19.750633,36.727158,12.890018,16.426033,348.703857
1,0.333333,0.443182,0.207843,0.296029,0.260,34.783333,0.290323,0.405303,0.192157,0.270758,...,69.588272,55.151863,44.522202,121.469887,543.430359,19.750633,29.815804,15.415569,21.550995,348.008423
2,0.290323,0.405303,0.192157,0.270758,0.200,35.783333,0.279570,0.375000,0.172549,0.245487,...,69.588272,55.151863,44.522202,121.469887,543.430359,11.960800,32.405914,9.969261,15.759071,87.836746
3,0.279570,0.375000,0.172549,0.245487,0.100,36.783333,0.290323,0.386364,0.180392,0.259928,...,69.588272,55.151863,44.522202,121.469887,543.430359,69.588272,32.405914,17.812626,17.824089,342.530518
4,0.290323,0.386364,0.180392,0.259928,0.100,37.783333,0.322581,0.371212,0.184314,0.259928,...,69.588272,55.151863,44.522202,121.469887,543.430359,11.960800,32.405914,11.230831,15.759071,87.836746
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4,0.430108,0.465909,0.152941,0.234657,0.010,36.183333,0.376344,0.390152,0.141176,0.234657,...,69.588272,55.151863,44.522202,121.469887,543.430359,9.142712,26.901810,14.054073,16.887100,201.715027
5,0.376344,0.390152,0.141176,0.234657,0.025,37.183333,0.365591,0.401515,0.141176,0.238267,...,69.588272,55.151863,44.522202,121.469887,543.430359,22.434761,29.815804,15.415569,21.550995,348.008423
6,0.365591,0.401515,0.141176,0.238267,0.033,38.183333,0.376344,0.462121,0.156863,0.267148,...,69.588272,55.151863,44.522202,121.469887,543.430359,15.907608,29.531738,9.791492,15.720116,82.571487
7,0.376344,0.462121,0.156863,0.267148,0.030,39.183333,0.365591,0.439394,0.152941,0.259928,...,69.588272,55.151863,44.522202,121.469887,543.430359,24.371895,27.836327,11.848957,18.882324,82.571487
