# Load Packages and Configure Paths

In [None]:
%load_ext autoreload
%autoreload 2

import sys
from os.path import join
from tqdm.auto import tqdm
import joblib
import torch
sys.path.append("../../")

from src.file_manager.load_data import load_split_dict
from src.models.der.tuning import tune_der_model
from src.models.der.training import train_der_w_param
from src.models.der.predicting import der_model_prediction
from src.misc import create_folder
from seed_file import seed

# Experiment configuration. `seed` (imported from seed_file) selects the run whose
# models and predictions this notebook produces. Hyper-parameter tuning only runs
# when `seed == tuning_seed`; every seed loads the hyper-parameters tuned under
# `tuning_seed`.
tuning_seed = 2023
data_label = "physionet"
batch_size = 64

# File paths (relative to the current working directory)
fp_notebooks_folder = "../"
fp_project_folder = join(fp_notebooks_folder, "../")
fp_processed_data_folder = join(fp_processed_data_folder, data_label) if False else join(fp_project_folder, "processed_data")
# Derive the dataset sub-folder from data_label so that switching datasets
# only requires changing that one value.
fp_output_data_folder = join(fp_processed_data_folder, data_label)
fp_checkpoint_folder = join(fp_project_folder, "checkpoints")
fp_project_checkpoints = join(fp_checkpoint_folder, data_label)
fp_tuning = join(fp_project_checkpoints, "tuning")
fp_models = join(fp_project_checkpoints, "models")
fp_predictions = join(fp_project_checkpoints, "predictions")

# Seed-specific folders: tuning results live under the tuning seed, while
# models and predictions live under the current seed.
fp_cur_tune_folder = join(fp_tuning, str(tuning_seed))
fp_cur_model_folder = join(fp_models, str(seed))
fp_cur_predictions_folder = join(fp_predictions, str(seed))
for fp_folder in (fp_cur_tune_folder, fp_cur_model_folder, fp_cur_predictions_folder):
    create_folder(fp_folder)

# Load Data

In [None]:
# Load the pre-split dataset. Later cells read the keys "train_df", "valid_df",
# "test_df", "feat_cols" and "target_cols" (a mapping of time label -> target columns).
split_dict = load_split_dict(
    fp_output_data_folder,
)

# Tune DER

In [None]:
# Hyper-parameter search only runs when this notebook executes under the tuning
# seed. Each target group (keyed by time label) gets its own grid search over
# hidden-layer count and hidden width. The per-label search results are written
# to CSV, and the best settings are collected into a single joblib file that the
# training cell loads for every seed.
if seed == tuning_seed:
    all_der_best_hp = {}
    # Arguments shared by every per-label search.
    common_tuning_args = dict(
        train_df=split_dict["train_df"],
        valid_df=split_dict["valid_df"],
        feat_cols=split_dict["feat_cols"],
        epochs=500,
        patience=5,
        seed=seed,
        batch_size=batch_size,
    )
    for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
        search_space = {
            "n_hidden_layers": [2, 3, 4],
            "hidden_width": [128, 256, 512],
        }
        der_tuning_df, der_best_hp = tune_der_model(
            param_grid=search_space,
            target_cols=target_cols,
            **common_tuning_args,
        )
        all_der_best_hp[time_label] = der_best_hp
        der_tuning_df.to_csv(join(fp_cur_tune_folder, f"tuning_der_{time_label}.csv"))
        display(der_tuning_df)
    joblib.dump(all_der_best_hp, join(fp_cur_tune_folder, "all_der_best_hp.joblib"))
    display(all_der_best_hp)

# Train DER

In [None]:
# Retrain one DER model per time label using the hyper-parameters tuned under
# `tuning_seed`, then checkpoint the whole model object for the prediction cell.
all_der_best_hp = joblib.load(join(fp_cur_tune_folder, "all_der_best_hp.joblib"))
for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
    best_hp = all_der_best_hp[time_label]
    der_model, _ = train_der_w_param(
        train_df=split_dict["train_df"],
        valid_df=split_dict["valid_df"],
        inputs=split_dict["feat_cols"],
        outputs=target_cols,
        seed=seed,
        max_epochs=500,
        patience=5,
        batch_size=batch_size,
        **best_hp,
    )
    fp_model = join(fp_cur_model_folder, f"der_{time_label}.pt")
    torch.save(der_model, fp_model)

# Prediction

In [None]:
# Run each trained DER model on the validation and test splits and save the
# predictions as CSV.
#
# The prediction files are named by the LAST CHARACTER of each time label
# (e.g. "der_test_1.csv"), whereas the model checkpoints use the full label.
# Fail loudly if two labels share a last character; otherwise their prediction
# files would silently overwrite each other.
_pred_suffixes = [label[-1] for label in split_dict["target_cols"]]
if len(set(_pred_suffixes)) != len(_pred_suffixes):
    raise ValueError(
        f"Time labels {list(split_dict['target_cols'])} do not have unique last "
        "characters; their prediction CSV files would overwrite each other."
    )

for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
    fp_model = join(fp_cur_model_folder, f"der_{time_label}.pt")
    # The training cell saves the whole model object with torch.save(der_model, ...),
    # so it has to be unpickled. torch >= 2.6 defaults to weights_only=True,
    # which rejects such checkpoints. These files are produced by this project,
    # not untrusted input. (The weights_only argument requires torch >= 1.13.)
    der_model = torch.load(fp_model, weights_only=False)
    der_valid_df = der_model_prediction(
        der_model, test_df=split_dict["valid_df"],
        feat_cols=split_dict["feat_cols"], target_cols=target_cols,
        seed=seed, silent=False, regressor_label=time_label)
    der_test_df = der_model_prediction(
        der_model, test_df=split_dict["test_df"],
        feat_cols=split_dict["feat_cols"], target_cols=target_cols,
        seed=seed, silent=False, regressor_label=time_label)
    der_valid_df.to_csv(join(fp_cur_predictions_folder, f"der_valid_{time_label[-1]}.csv"))
    der_test_df.to_csv(join(fp_cur_predictions_folder, f"der_test_{time_label[-1]}.csv"))