# Load Packages and Configure Paths

In [None]:

%load_ext autoreload
%autoreload 2

import sys
from os.path import join
from tqdm.auto import tqdm
import joblib
import torch
sys.path.append("../../")

from src.file_manager.load_data import load_split_dict
from src.models.bnn.tuning import tune_bnn_model
from src.models.bnn.training import train_model_w_best_param
from src.models.bnn.predicting import bnn_model_prediction
from src.misc import create_folder
from seed_file import seed

# `seed` (imported from seed_file) keys the per-seed model and prediction
# folders below; hyperparameter tuning only runs when it equals `tuning_seed`.
tuning_seed = 2023
data_label = "physionet"
batch_size = 64

# File paths (relative to the current working directory).
fp_notebooks_folder = "../"
fp_project_folder = join(fp_notebooks_folder, "../")
fp_processed_data_folder = join(fp_project_folder, "processed_data")
# Derive the input data folder from `data_label` (as the checkpoint folder is),
# so inputs and checkpoints cannot silently refer to different datasets.
fp_output_data_folder = join(fp_processed_data_folder, data_label)
fp_checkpoint_folder = join(fp_project_folder, "checkpoints")
fp_project_checkpoints = join(fp_checkpoint_folder, data_label)
fp_tuning = join(fp_project_checkpoints, "tuning")
fp_models = join(fp_project_checkpoints, "models")
fp_predictions = join(fp_project_checkpoints, "predictions")

# Seed-specific output folders: tuning results are keyed by `tuning_seed`,
# trained models and predictions by the current `seed`.
fp_cur_tune_folder = join(fp_tuning, str(tuning_seed))
create_folder(fp_cur_tune_folder)
fp_cur_model_folder = join(fp_models, str(seed))
create_folder(fp_cur_model_folder)
fp_cur_predictions_folder = join(fp_predictions, str(seed))
create_folder(fp_cur_predictions_folder)

# Load Data

In [None]:
# Load the pre-split data and column metadata stored under fp_output_data_folder.
split_dict = load_split_dict(fp_output_data_folder)

# Fail fast: the tuning, training and prediction cells index exactly these keys,
# and a missing one would otherwise only surface mid-loop as a KeyError.
missing_split_keys = {"train_df", "valid_df", "test_df", "feat_cols", "target_cols"} - set(split_dict)
assert not missing_split_keys, f"split_dict is missing keys: {sorted(missing_split_keys)}"

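# Tune BNN

Hyperparameters are tuned only when `seed == tuning_seed`; runs with any other seed reuse the best settings saved to `all_bnn_best_hp.joblib`.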

In [None]:
# Hyperparameter search happens only on the designated tuning seed; runs with
# any other `seed` skip this cell and reuse the best settings persisted below.
if seed == tuning_seed:
    all_bnn_best_hp = {}
    for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
        # Checkpoint path handed to the tuner for this time label.
        fp_model = join(fp_cur_model_folder, f"bnn_{time_label}.pt")
        # Grid over network depth x width; a fresh dict is built for every label.
        bnn_tuning_df, bnn_best_hp = tune_bnn_model(
            train_df=split_dict["train_df"],
            valid_df=split_dict["valid_df"],
            feat_cols=split_dict["feat_cols"],
            target_cols=target_cols,
            param_grid={"num_layers": [2, 3], "width": [32, 64, 128, 256]},
            epochs=500,
            patience=5,
            batch_size=batch_size,
            seed=seed,
            fp_model=fp_model,
        )
        display(bnn_tuning_df)
        fp_tuning_csv = join(fp_cur_tune_folder, f"tuning_bnn_{time_label}.csv")
        bnn_tuning_df.to_csv(fp_tuning_csv)
        all_bnn_best_hp[time_label] = bnn_best_hp

    # Stored directly under fp_tuning (not the per-seed folder) because the
    # training cell loads this same file for every seed.
    fp_best_hp = join(fp_tuning, "all_bnn_best_hp.joblib")
    joblib.dump(all_bnn_best_hp, fp_best_hp)
    display(all_bnn_best_hp)

# Training

Retrain one BNN per time label with its tuned hyperparameters, using the current `seed`.

In [None]:
# The best hyperparameters are only written by the tuning cell when
# `seed == tuning_seed`, so runs with any other seed depend on that earlier run.
fp_best_hp = join(fp_tuning, "all_bnn_best_hp.joblib")
try:
    all_bnn_best_hp = joblib.load(fp_best_hp)
except FileNotFoundError as err:
    raise FileNotFoundError(
        f"{fp_best_hp} not found: run this notebook once with seed == tuning_seed "
        f"({tuning_seed}) so the tuning cell saves the best hyperparameters."
    ) from err

missing_hp_labels = [label for label in split_dict["target_cols"] if label not in all_bnn_best_hp]
assert not missing_hp_labels, f"no tuned hyperparameters for time labels {missing_hp_labels}; re-run tuning"

for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
    # The trained model is not kept in memory: the prediction cell reloads it
    # from this path, so train_model_w_best_param presumably saves it there.
    fp_model = join(fp_cur_model_folder, f"bnn_{time_label}.pt")
    train_model_w_best_param(
        best_param=all_bnn_best_hp[time_label],
        train_df=split_dict["train_df"], valid_df=split_dict["valid_df"],
        feat_cols=split_dict["feat_cols"], target_cols=target_cols,
        epochs=500, patience=5, seed=seed, fp_model=fp_model, batch_size=batch_size
    )

# Prediction

Reload each trained BNN and write validation and test predictions for the current `seed`.

In [None]:
# Prediction CSVs are named by only the LAST character of each time label (kept
# as-is in case other notebooks read these names), so two labels sharing a final
# character would silently overwrite each other's files. Check that up front.
prediction_suffixes = [label[-1] for label in split_dict["target_cols"]]
assert len(set(prediction_suffixes)) == len(prediction_suffixes), (
    f"time labels {list(split_dict['target_cols'])} share a final character; "
    "bnn_valid_/bnn_test_ CSVs would overwrite each other"
)

for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
    fp_model = join(fp_cur_model_folder, f"bnn_{time_label}.pt")
    # The loaded checkpoint is used directly as a model object (not a state_dict),
    # so it must be fully unpickled; PyTorch >= 2.6 defaults to weights_only=True,
    # which rejects that. Only load checkpoints this project produced: unpickling
    # can execute arbitrary code.
    bnn_model = torch.load(fp_model, weights_only=False)
    # Identical settings for both splits. NOTE(review): T is presumably the number
    # of stochastic forward passes — confirm in bnn_model_prediction.
    prediction_kwargs = dict(
        feat_cols=split_dict["feat_cols"], target_cols=target_cols,
        T=10, seed=seed, regressor_label=time_label, batch_size=batch_size,
    )
    bnn_valid_df = bnn_model_prediction(bnn_model, split_dict["valid_df"], **prediction_kwargs)
    bnn_test_df = bnn_model_prediction(bnn_model, split_dict["test_df"], **prediction_kwargs)
    bnn_valid_df.to_csv(join(fp_cur_predictions_folder, f"bnn_valid_{time_label[-1]}.csv"))
    bnn_test_df.to_csv(join(fp_cur_predictions_folder, f"bnn_test_{time_label[-1]}.csv"))