# Load Packages

In [None]:
%load_ext autoreload
%autoreload 2

import sys
from os.path import join
from tqdm.auto import tqdm
import joblib
sys.path.append("../../")

from src.file_manager.load_data import load_split_dict
from src.models.rue.tuning import model_tuning_regressor, model_tuning_decoder
from src.models.rue.training import model_training_predictor, model_training_decoder
from src.models.rue.save_load_model import save_model, load_model
from src.models.rue.predicting import model_test_predictions
from src.misc import create_folder
from seed_file import seed

# `seed` (the per-run seed) is imported from seed_file above. `tuning_seed` is the
# single run whose hyperparameter-search results are stored under fp_tuning and
# re-loaded by every run, so tuning cells only execute when seed == tuning_seed.
tuning_seed = 2023
data_label = "physionet"
batch_size = 64

# File paths
fp_notebooks_folder = "../"
fp_project_folder = join(fp_notebooks_folder, "../")
fp_processed_data_folder = join(fp_project_folder, "processed_data")
# Keyed by data_label (like the checkpoint folder below) so both stay in sync.
fp_output_data_folder = join(fp_processed_data_folder, data_label)
fp_checkpoint_folder = join(fp_project_folder, "checkpoints")
fp_project_checkpoints = join(fp_checkpoint_folder, data_label)
fp_tuning = join(fp_project_checkpoints, "tuning")
fp_models = join(fp_project_checkpoints, "models")
fp_predictions = join(fp_project_checkpoints, "predictions")

# Seed filepaths: tuning artefacts live under the tuning seed, while models and
# predictions live under the seed of the current run.
fp_cur_tune_folder = join(fp_tuning, str(tuning_seed))
create_folder(fp_cur_tune_folder)
fp_cur_model_folder = join(fp_models, str(seed))
create_folder(fp_cur_model_folder)
fp_cur_predictions_folder = join(fp_predictions, str(seed))
create_folder(fp_cur_predictions_folder)

# Load Data

In [None]:
# Load the pre-split dataset from fp_output_data_folder. The cells below read the
# keys "feat_cols", "target_cols", "train_df", "valid_df" and "test_df" from it.
split_dict = load_split_dict(fp_output_data_folder)

# Tune and Train Predictor

## Tune

In [None]:
# Hyperparameter search for the predictor. It runs only for the designated tuning
# seed; the training cell re-loads the saved result from fp_cur_tune_folder.
if seed == tuning_seed:
    all_rue_predictor_best_hp = {}
    # One independent search per entry of split_dict["target_cols"].
    for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
        rue_predictor_tuning_df, rue_predictor_best_hp = model_tuning_regressor(
            # Only the encoder size is searched here; the decoder size is held fixed.
            param_grid={
                "encoder_width": [128, 256, 512],
                "encoder_depth": [2, 3, 4],
                "decoder_width": [64],
                "decoder_depth": [2],
            },
            predictors=split_dict["feat_cols"],
            pred_cols=target_cols,
            train_df=split_dict["train_df"],
            valid_df=split_dict["valid_df"],
            seed=seed,
            batch_size=batch_size,
            max_epochs=10000,
            verbose=1,
            patience=20,
        )
        display(rue_predictor_tuning_df)
        # Keep the full tuning table per target group alongside the best setting.
        rue_predictor_tuning_df.to_csv(join(fp_cur_tune_folder, f"tuning_rue_{time_label}.csv"))
        all_rue_predictor_best_hp[time_label] = rue_predictor_best_hp
    joblib.dump(all_rue_predictor_best_hp, join(fp_cur_tune_folder, "all_rue_predictor_best_hp.joblib"))
    # A bare expression inside an `if` body is not echoed by IPython (only the
    # cell's last top-level expression is), so show the result explicitly.
    display(all_rue_predictor_best_hp)

## Train

In [None]:
# Fit one predictor per target group using the tuned hyperparameters, then
# checkpoint it under the current seed's model folder.
all_rue_predictor_best_hp = joblib.load(
    join(fp_cur_tune_folder, "all_rue_predictor_best_hp.joblib")
)
for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
    best_predictor_hp = all_rue_predictor_best_hp[time_label]
    # Data and optimisation settings for this fit.
    predictor_fit_kwargs = dict(
        predictors=split_dict["feat_cols"],
        pred_cols=target_cols,
        train_df=split_dict["train_df"],
        valid_df=split_dict["valid_df"],
        seed=seed,
        batch_size=batch_size,
        max_epochs=10000,
        verbose=1,
        patience=20,
    )
    ae_regressor = model_training_predictor(best_predictor_hp, **predictor_fit_kwargs)
    # override=True is passed so a re-run can replace an existing checkpoint.
    save_model(
        model=ae_regressor,
        name=f"rue_predictor_{time_label}",
        fp_checkpoints=fp_cur_model_folder,
        override=True,
    )

# Tune and Train Decoder

## Tune

In [None]:
# Hyperparameter search for the decoder. It runs only for the designated tuning
# seed and builds on the predictor checkpoints saved by the training cell above.
if seed == tuning_seed:
    all_rue_predictor_best_hp = joblib.load(join(fp_cur_tune_folder, "all_rue_predictor_best_hp.joblib"))
    all_rue_decoder_best_hp = {}
    for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
        best_predictor_hp = all_rue_predictor_best_hp[time_label]
        prev_model = load_model(
            name=f"rue_predictor_{time_label}", fp_checkpoints=fp_cur_model_folder)
        # NOTE(review): batch_size is not forwarded here, unlike the other tuning
        # and training calls in this notebook - confirm that the default used by
        # model_tuning_decoder is intended.
        rue_tuning_df, rue_decoder_best_hp = model_tuning_decoder(
            # Encoder size is pinned to the tuned predictor's values; only the
            # decoder size is searched.
            param_grid=dict(
                encoder_width=[best_predictor_hp["encoder_width"]],
                encoder_depth=[best_predictor_hp["encoder_depth"]],
                decoder_width=[128, 256, 512],
                decoder_depth=[2, 3, 4],
            ),
            predictors=split_dict["feat_cols"],
            pred_cols=target_cols,
            train_df=split_dict["train_df"],
            valid_df=split_dict["valid_df"],
            seed=seed,
            max_epochs=10000,
            verbose=1,
            patience=20,
            prev_model=prev_model,
        )
        display(rue_tuning_df)
        rue_tuning_df.to_csv(join(fp_cur_tune_folder, f"tuning_rue_decoder_{time_label}.csv"))
        all_rue_decoder_best_hp[time_label] = rue_decoder_best_hp
    joblib.dump(all_rue_decoder_best_hp, join(fp_cur_tune_folder, "all_rue_decoder_best_hp.joblib"))
    # A bare expression inside an `if` body is not echoed by IPython (only the
    # cell's last top-level expression is), so show the result explicitly.
    display(all_rue_decoder_best_hp)

## Train

In [None]:
# Train the decoder for each target group on top of its saved predictor, using
# the tuned decoder hyperparameters, and checkpoint the result as "rue_<label>".
all_rue_decoder_best_hp = joblib.load(
    join(fp_cur_tune_folder, "all_rue_decoder_best_hp.joblib")
)
for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
    # The predictor checkpoint written by the predictor training cell.
    prev_model = load_model(
        name=f"rue_predictor_{time_label}",
        fp_checkpoints=fp_cur_model_folder,
    )
    hp_dict = all_rue_decoder_best_hp[time_label]
    # Data and optimisation settings for this fit.
    decoder_fit_kwargs = dict(
        predictors=split_dict["feat_cols"],
        pred_cols=target_cols,
        train_df=split_dict["train_df"],
        valid_df=split_dict["valid_df"],
        seed=seed,
        prev_model=prev_model,
        batch_size=batch_size,
        max_epochs=10000,
        verbose=1,
        patience=20,
    )
    ae_regressor = model_training_decoder(hp_dict, **decoder_fit_kwargs)
    save_model(
        model=ae_regressor,
        name=f"rue_{time_label}",
        fp_checkpoints=fp_cur_model_folder,
        override=True,
    )

# Prediction

In [None]:
# Run each trained RUE model on the validation and test splits and write the
# resulting frames to the current seed's predictions folder.
for time_label, target_cols in tqdm(split_dict["target_cols"].items()):
    # Use the whole trailing run of digits of the label, not just its last
    # character, so a label ending in e.g. "10" yields 10 instead of 0 and does
    # not share output filenames with another label ending in the same digit.
    label_suffix = time_label[len(time_label.rstrip("0123456789")):]
    if not label_suffix:
        raise ValueError(f"time_label {time_label!r} does not end with a number")

    ae_regressor = load_model(name=f"rue_{time_label}", fp_checkpoints=fp_cur_model_folder)

    # Arguments shared by the validation and test prediction calls.
    predict_kwargs = dict(
        df_train=split_dict["train_df"],
        pred_cols=target_cols,
        predictors=split_dict["feat_cols"],
        regressor_label="_" + time_label,
        pred_min=int(label_suffix),
        T=10,
        seed=seed,
    )
    rue_valid_df = model_test_predictions(
        ae_regressor, df_test=split_dict["valid_df"], **predict_kwargs)
    rue_test_df = model_test_predictions(
        ae_regressor, df_test=split_dict["test_df"], **predict_kwargs)

    display(rue_test_df)
    rue_valid_df.to_csv(join(fp_cur_predictions_folder, f"rue_valid_{label_suffix}.csv"))
    rue_test_df.to_csv(join(fp_cur_predictions_folder, f"rue_test_{label_suffix}.csv"))
