# Notebook for training density estimator

In [None]:
import os, random
import sys
# Register the parent of the current working directory on the import path
# (only once), presumably so the project-local `modules` package imported
# further down resolves when the notebook runs from its own folder.
module_path = os.path.abspath(os.pardir)
if not (module_path in sys.path):
    sys.path.append(module_path)

import pickle
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger#, WandbLogger
#import wandb
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import optuna
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

from modules.helpers import seed#, fit_model
from modules.predictors_modules import FairClfDataModule, Density_estimator, StandardRegressor

import matplotlib.pyplot as plt

# Expose no usable CUDA device id to this process so everything runs on CPU;
# the Trainers further down are also created with accelerator="cpu".
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

#wandb.login()

# Fix the random seed through the project helper.
# NOTE(review): which RNGs `seed` covers (random / numpy / torch) is not
# visible from this notebook -- confirm in modules.helpers.
seed(1)

In [None]:
# Identifiers of the simulated experiments; each one is interpolated into the
# data path "../data/simulator/<id>_full_dataframe" by the loops below.
experiment_list = ["sim_cont_" + variant for variant in ("b", "c", "d", "e")]

# Root folder for TensorBoard logs and model checkpoints.
save_location = "../models"

In [None]:
# Baseline settings shared by all experiments in this notebook. The tuned
# values (learning rate, batch size, hidden size, dropout) are replaced by the
# best Optuna trial when the final models are trained.
train_config = {
    # network architecture
    "criterion": nn.BCEWithLogitsLoss,  # the loss class itself, not an instance
    "hidden_dim": 64,
    "dropout": 0.1,
    # training
    "batch_size": 128,
    "learning_rate": 0.0001,
    "max_epochs": 300,
    "num_workers": 0,
    # optimization
    "number_of_trials": 10,
}

In [None]:
import warnings
# Silence every warning category for the rest of the kernel session.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# used below; consider narrowing the filter once the notebook is stable.
warnings.filterwarnings("ignore")

### Real-world data

In [None]:
experiment = "prison"

data_location = "../data/prison/prison_dataframe"

# SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
# The handle is deliberately not called `input`, so the builtin `input()` is
# not left shadowed in the kernel namespace after this cell has run.
with open(data_location, "rb") as data_file:
    full_data = pickle.load(data_file)

# Drop the columns this estimator does not use. "M" is the label and every
# remaining column counts as a feature (see n_features below).
full_data = full_data.drop(["Y", "UIE", "UDE", "USE"],axis=1)

# Distinct label values of M, computed once and reused for the class count
# and for the balanced class weights.
# NOTE(review): unlike the synthetic cells, two label values are NOT collapsed
# to n_classes = 1 here -- confirm this is intended for the prison data.
m_classes = np.unique(full_data["M"])
n_classes = len(m_classes)
n_features = full_data.drop("M", axis = 1).shape[1]

class_weights = compute_class_weight(class_weight ='balanced', classes = m_classes, y= full_data["M"].astype(float))

In [None]:
# Baseline data module for the prison data, using the default batch size and
# worker count from train_config.
dm = FairClfDataModule(
    data_dir=data_location,
    label_col="M",
    mode="density",
    batch_size=train_config["batch_size"],
    num_workers=train_config["num_workers"],
)

# Baseline density estimator with the default architecture / optimiser
# settings from train_config.
model = Density_estimator(
    n_features=n_features,
    n_classes=n_classes,
    class_weights=class_weights,
    criterion=train_config["criterion"],
    hidden_dim=train_config["hidden_dim"],
    dropout=train_config["dropout"],
    learning_rate=train_config["learning_rate"],
)


def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective for the prison density estimator.

    Samples a learning rate, batch size, hidden size and dropout rate, builds
    a fresh data module and model from those sampled values, trains with early
    stopping on ``val_loss`` and returns the ``val_loss`` found in
    ``trainer.callback_metrics`` once fitting stops.

    Previously the sampled values were only logged: every trial fitted the
    same module-level ``model``/``dm`` (built from ``train_config``), so the
    search did not vary anything and the shared model kept training across
    trials. Each trial now gets its own instances, built the same way the
    final model is built from ``study.best_params``.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        The validation loss of the trial (lower is better).
    """
    learning_rate = trial.suggest_float("learning_rate", 0.5e-4, 5e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size",[8, 16, 32, 64, 128])
    hidden_dim = trial.suggest_categorical("hidden_dim",[32, 64, 128, 256])
    dropout = trial.suggest_float("dropout",low=0, high=0.5)

    # Fresh data module per trial so the sampled batch size is actually used.
    # NOTE(review): if FairClfDataModule splits the data randomly, confirm the
    # split is seeded so that trials are compared on the same validation set.
    trial_dm = FairClfDataModule(
        data_dir = data_location,
        label_col = "M",
        batch_size = batch_size,
        num_workers = train_config['num_workers'],
        mode = "density"
    )
    # Fresh model per trial: no weights carried over from earlier trials.
    trial_model = Density_estimator(
        n_features = n_features,
        n_classes = n_classes,
        hidden_dim = hidden_dim,
        dropout = dropout,
        criterion = train_config['criterion'],
        class_weights = class_weights,
        learning_rate = learning_rate
    )

    early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=5, verbose=False, mode="min", check_on_train_epoch_end=True)
    # NOTE(review): `gpus=0` is redundant next to accelerator="cpu" and the
    # argument was removed in Lightning 2.0 -- drop it when upgrading.
    trainer = pl.Trainer(
        max_epochs=train_config['max_epochs'],
        gpus=0,
        logger = True,
        callbacks = [early_stop_callback],
        deterministic=True,
        accelerator="cpu"
    )
    hyperparameters = dict(learning_rate = learning_rate, batch_size = batch_size, hidden_dim = hidden_dim, dropout = dropout)
    trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(trial_model, datamodule = trial_dm)
    return trainer.callback_metrics["val_loss"].item()

# Seeded TPE search that minimises the value returned by `objective`.
# NOTE(review): `objective` never calls trial.report(), so the median pruner
# has no intermediate values to act on and cannot stop a trial early.
sampler = optuna.samplers.TPESampler(seed=1)
pruner = optuna.pruners.MedianPruner()

study = optuna.create_study(
    direction="minimize",
    sampler=sampler,
    pruner=pruner,
)
study.optimize(
    objective,
    timeout=100000,
    n_trials=train_config["number_of_trials"],
)

# Name used for the TensorBoard run and the checkpoint folder/file.
model_name = "density_estimator_" + experiment

# Report the winning trial and its sampled hyperparameters.
print("Best trial:")
trial = study.best_trial
print("  Validation loss: {}".format(trial.value))
print("  Parameters: ")
for param_name, param_value in trial.params.items():
    print("    {}: {}".format(param_name, param_value))

    
# Train model with optimized hyperparameters

# Estimator rebuilt with the hyperparameters of the best Optuna trial.
final_model = Density_estimator(
    n_features=n_features,
    n_classes=n_classes,
    class_weights=class_weights,
    criterion=train_config["criterion"],
    hidden_dim=study.best_params["hidden_dim"],
    dropout=study.best_params["dropout"],
    learning_rate=study.best_params["learning_rate"],
)

# Data module using the best trial's batch size.
dm_final = FairClfDataModule(
    data_dir=data_location,
    label_col="M",
    mode="density",
    batch_size=study.best_params["batch_size"],
    num_workers=train_config["num_workers"],
)

logger = TensorBoardLogger(save_location, name=model_name)

# Keep only the single checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath="/".join((save_location, model_name)),
    filename=model_name + "_checkpoints",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)
# Stop once val_loss has not improved for 5 consecutive checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=5,
    min_delta=0.00,
    verbose=False,
    check_on_train_epoch_end=True,
)

# NOTE(review): `gpus=0` is redundant next to accelerator="cpu" and the
# argument was removed in Lightning 2.0 -- drop it when upgrading.
final_trainer = pl.Trainer(
    accelerator="cpu",
    gpus=0,
    deterministic=True,
    max_epochs=train_config["max_epochs"],
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
)

final_trainer.fit(final_model, dm_final)

### Synthetic

In [None]:
for experiment in experiment_list:

    data_location = "../data/simulator/"+ experiment + "_full_dataframe"

    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    # The handle is deliberately not called `input`, so the builtin `input()`
    # is not left shadowed in the kernel namespace after this cell has run.
    with open(data_location, "rb") as data_file:
        full_data = pickle.load(data_file)

    # Drop the columns this estimator does not use. "A" is the label and every
    # remaining column counts as a feature (see n_features below).
    full_data = full_data.drop(["Y", "USE", "UIE", "UDE", "M"],axis=1)

    # Two distinct label values are treated as n_classes = 1 (presumably a
    # single-output binary setup matching BCEWithLogitsLoss -- confirm in
    # Density_estimator).
    n_classes = len(np.unique(full_data["A"]))
    if n_classes == 2:
        n_classes = 1
    n_features = full_data.drop("A", axis = 1).shape[1]

    # One weight per label value: two values in the collapsed binary case,
    # n_classes values otherwise.
    # assumes labels are encoded as 0..k-1 -- TODO confirm against the data
    class_weights = compute_class_weight(
        class_weight ='balanced',
        classes = np.float64(np.arange(2 if n_classes == 1 else n_classes)),
        y = full_data["A"].astype(float)
    )


    # Configure model

    # Baseline data module for label "A", using the default batch size and
    # worker count from train_config.
    dm = FairClfDataModule(
        data_dir=data_location,
        label_col="A",
        mode="density_A",
        batch_size=train_config["batch_size"],
        num_workers=train_config["num_workers"],
    )

    # Baseline density estimator with the default architecture / optimiser
    # settings from train_config.
    model = Density_estimator(
        n_features=n_features,
        n_classes=n_classes,
        class_weights=class_weights,
        criterion=train_config["criterion"],
        hidden_dim=train_config["hidden_dim"],
        dropout=train_config["dropout"],
        learning_rate=train_config["learning_rate"],
    )

    # Hyperparameter optimization

    def objective(trial: optuna.trial.Trial) -> float:
        """Optuna objective for the density estimator of A on the current experiment.

        Samples a learning rate, batch size, hidden size and dropout rate,
        builds a fresh data module and model from those sampled values, trains
        with early stopping on ``val_loss`` and returns the ``val_loss`` found
        in ``trainer.callback_metrics`` once fitting stops.

        Previously the sampled values were only logged: every trial fitted the
        same ``model``/``dm`` pair built from ``train_config``, so the search
        did not vary anything and the shared model kept training across
        trials. Each trial now gets its own instances, built the same way the
        final model is built from ``study.best_params``.

        Args:
            trial: The Optuna trial used to sample hyperparameters.

        Returns:
            The validation loss of the trial (lower is better).
        """
        learning_rate = trial.suggest_float("learning_rate", 5e-4, 5e-2, log=True)
        batch_size = trial.suggest_categorical("batch_size",[64, 128, 256])
        hidden_dim = trial.suggest_categorical("hidden_dim",[64, 128, 256])
        dropout = trial.suggest_float("dropout",low=0, high=0.5)

        # Fresh data module per trial so the sampled batch size is actually used.
        # NOTE(review): if FairClfDataModule splits the data randomly, confirm
        # the split is seeded so trials share the same validation set.
        trial_dm = FairClfDataModule(
            data_dir = data_location,
            label_col = "A",
            batch_size = batch_size,
            num_workers = train_config['num_workers'],
            mode = "density_A"
        )
        # Fresh model per trial: no weights carried over from earlier trials.
        trial_model = Density_estimator(
            n_features = n_features,
            n_classes = n_classes,
            class_weights = class_weights,
            hidden_dim = hidden_dim,
            dropout = dropout,
            criterion = train_config['criterion'],
            learning_rate = learning_rate
        )

        early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=20, verbose=False, mode="min", check_on_train_epoch_end=True)
        # NOTE(review): `gpus=0` is redundant next to accelerator="cpu" and the
        # argument was removed in Lightning 2.0 -- drop it when upgrading.
        trainer = pl.Trainer(
            max_epochs=train_config['max_epochs'],
            gpus=0,
            logger = True,
            callbacks = [early_stop_callback],
            deterministic=True,
            accelerator="cpu"
        )
        hyperparameters = dict(learning_rate = learning_rate, batch_size = batch_size, hidden_dim = hidden_dim, dropout = dropout)
        trainer.logger.log_hyperparams(hyperparameters)
        trainer.fit(trial_model, datamodule = trial_dm)
        return trainer.callback_metrics["val_loss"].item()

    # Seeded TPE search that minimises the value returned by `objective`.
    # NOTE(review): `objective` never calls trial.report(), so the median
    # pruner has no intermediate values to act on and cannot stop a trial early.
    sampler = optuna.samplers.TPESampler(seed=1)
    pruner = optuna.pruners.MedianPruner()

    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=pruner,
    )
    study.optimize(
        objective,
        timeout=100000,
        n_trials=train_config["number_of_trials"],
    )

    # Name used for the TensorBoard run and the checkpoint folder/file.
    model_name = "density_estimator_A_" + experiment

    # Report the winning trial and its sampled hyperparameters.
    print("Best trial:")
    trial = study.best_trial
    print("  Validation loss: {}".format(trial.value))
    print("  Parameters: ")
    for param_name, param_value in trial.params.items():
        print("    {}: {}".format(param_name, param_value))

    
    # Train model with optimized hyperparameters

    # Estimator rebuilt with the hyperparameters of the best Optuna trial.
    final_model = Density_estimator(
        n_features=n_features,
        n_classes=n_classes,
        class_weights=class_weights,
        criterion=train_config["criterion"],
        hidden_dim=study.best_params["hidden_dim"],
        dropout=study.best_params["dropout"],
        learning_rate=study.best_params["learning_rate"],
    )

    # Data module using the best trial's batch size.
    dm_final = FairClfDataModule(
        data_dir=data_location,
        label_col="A",
        mode="density_A",
        batch_size=study.best_params["batch_size"],
        num_workers=train_config["num_workers"],
    )

    logger = TensorBoardLogger(save_location, name=model_name)

    # Keep only the single checkpoint with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        dirpath="/".join((save_location, model_name)),
        filename=model_name + "_checkpoints",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
    )
    # Stop once val_loss has not improved for 10 consecutive checks.
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=10,
        min_delta=0.00,
        verbose=False,
        check_on_train_epoch_end=True,
    )

    # NOTE(review): `gpus=0` is redundant next to accelerator="cpu" and the
    # argument was removed in Lightning 2.0 -- drop it when upgrading.
    final_trainer = pl.Trainer(
        accelerator="cpu",
        gpus=0,
        deterministic=True,
        max_epochs=train_config["max_epochs"],
        logger=logger,
        callbacks=[checkpoint_callback, early_stop_callback],
    )

    final_trainer.fit(final_model, dm_final)

In [None]:
for experiment in experiment_list:

    data_location = "../data/simulator/"+ experiment + "_full_dataframe"

    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    # The handle is deliberately not called `input`, so the builtin `input()`
    # is not left shadowed in the kernel namespace after this cell has run.
    with open(data_location, "rb") as data_file:
        full_data = pickle.load(data_file)

    # Drop the columns this estimator does not use. "M" is the label and every
    # remaining column counts as a feature (see n_features below).
    full_data = full_data.drop(["Y", "USE", "UIE", "UDE"],axis=1)

    # Two distinct label values are treated as n_classes = 1 (presumably a
    # single-output binary setup matching BCEWithLogitsLoss -- confirm in
    # Density_estimator).
    n_classes = len(np.unique(full_data["M"]))
    if n_classes == 2:
        n_classes = 1
    n_features = full_data.drop("M", axis = 1).shape[1]

    # One weight per label value: two values in the collapsed binary case,
    # n_classes values otherwise.
    # assumes labels are encoded as 0..k-1 -- TODO confirm against the data
    class_weights = compute_class_weight(
        class_weight ='balanced',
        classes = np.float64(np.arange(2 if n_classes == 1 else n_classes)),
        y = full_data["M"].astype(float)
    )


    # Configure model

    # Baseline data module for label "M", using the default batch size and
    # worker count from train_config.
    dm = FairClfDataModule(
        data_dir=data_location,
        label_col="M",
        mode="density_M",
        batch_size=train_config["batch_size"],
        num_workers=train_config["num_workers"],
    )

    # Baseline density estimator with the default architecture / optimiser
    # settings from train_config.
    model = Density_estimator(
        n_features=n_features,
        n_classes=n_classes,
        class_weights=class_weights,
        criterion=train_config["criterion"],
        hidden_dim=train_config["hidden_dim"],
        dropout=train_config["dropout"],
        learning_rate=train_config["learning_rate"],
    )

    # Hyperparameter optimization

    def objective(trial: optuna.trial.Trial) -> float:
        """Optuna objective for the density estimator of M on the current experiment.

        Samples a learning rate, batch size, hidden size and dropout rate,
        builds a fresh data module and model from those sampled values, trains
        with early stopping on ``val_loss`` and returns the ``val_loss`` found
        in ``trainer.callback_metrics`` once fitting stops.

        Previously the sampled values were only logged: every trial fitted the
        same ``model``/``dm`` pair built from ``train_config``, so the search
        did not vary anything and the shared model kept training across
        trials. Each trial now gets its own instances, built the same way the
        final model is built from ``study.best_params``.

        Args:
            trial: The Optuna trial used to sample hyperparameters.

        Returns:
            The validation loss of the trial (lower is better).
        """
        learning_rate = trial.suggest_float("learning_rate", 5e-4, 5e-2, log=True)
        batch_size = trial.suggest_categorical("batch_size",[64, 128, 256])
        hidden_dim = trial.suggest_categorical("hidden_dim",[64, 128, 256])
        dropout = trial.suggest_float("dropout",low=0, high=0.5)

        # Fresh data module per trial so the sampled batch size is actually used.
        # NOTE(review): if FairClfDataModule splits the data randomly, confirm
        # the split is seeded so trials share the same validation set.
        trial_dm = FairClfDataModule(
            data_dir = data_location,
            label_col = "M",
            batch_size = batch_size,
            num_workers = train_config['num_workers'],
            mode = "density_M"
        )
        # Fresh model per trial: no weights carried over from earlier trials.
        trial_model = Density_estimator(
            n_features = n_features,
            n_classes = n_classes,
            hidden_dim = hidden_dim,
            dropout = dropout,
            criterion = train_config['criterion'],
            class_weights = class_weights,
            learning_rate = learning_rate
        )

        early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=20, verbose=False, mode="min", check_on_train_epoch_end=True)
        # NOTE(review): `gpus=0` is redundant next to accelerator="cpu" and the
        # argument was removed in Lightning 2.0 -- drop it when upgrading.
        trainer = pl.Trainer(
            max_epochs=train_config['max_epochs'],
            gpus=0,
            logger = True,
            callbacks = [early_stop_callback],
            deterministic=True,
            accelerator="cpu"
        )
        hyperparameters = dict(learning_rate = learning_rate, batch_size = batch_size, hidden_dim = hidden_dim, dropout = dropout)
        trainer.logger.log_hyperparams(hyperparameters)
        trainer.fit(trial_model, datamodule = trial_dm)
        return trainer.callback_metrics["val_loss"].item()

    # Seeded TPE search that minimises the value returned by `objective`.
    # NOTE(review): `objective` never calls trial.report(), so the median
    # pruner has no intermediate values to act on and cannot stop a trial early.
    sampler = optuna.samplers.TPESampler(seed=1)
    pruner = optuna.pruners.MedianPruner()

    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=pruner,
    )
    study.optimize(
        objective,
        timeout=100000,
        n_trials=train_config["number_of_trials"],
    )

    # Name used for the TensorBoard run and the checkpoint folder/file.
    model_name = "density_estimator_" + experiment

    # Report the winning trial and its sampled hyperparameters.
    print("Best trial:")
    trial = study.best_trial
    print("  Validation loss: {}".format(trial.value))
    print("  Parameters: ")
    for param_name, param_value in trial.params.items():
        print("    {}: {}".format(param_name, param_value))

    
    # Train model with optimized hyperparameters

    # Estimator rebuilt with the hyperparameters of the best Optuna trial.
    final_model = Density_estimator(
        n_features=n_features,
        n_classes=n_classes,
        class_weights=class_weights,
        criterion=train_config["criterion"],
        hidden_dim=study.best_params["hidden_dim"],
        dropout=study.best_params["dropout"],
        learning_rate=study.best_params["learning_rate"],
    )

    # Data module using the best trial's batch size.
    dm_final = FairClfDataModule(
        data_dir=data_location,
        label_col="M",
        mode="density_M",
        batch_size=study.best_params["batch_size"],
        num_workers=train_config["num_workers"],
    )

    logger = TensorBoardLogger(save_location, name=model_name)

    # Keep only the single checkpoint with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        dirpath="/".join((save_location, model_name)),
        filename=model_name + "_checkpoints",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
    )
    # Stop once val_loss has not improved for 10 consecutive checks.
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=10,
        min_delta=0.00,
        verbose=False,
        check_on_train_epoch_end=True,
    )

    # NOTE(review): `gpus=0` is redundant next to accelerator="cpu" and the
    # argument was removed in Lightning 2.0 -- drop it when upgrading.
    final_trainer = pl.Trainer(
        accelerator="cpu",
        gpus=0,
        deterministic=True,
        max_epochs=train_config["max_epochs"],
        logger=logger,
        callbacks=[checkpoint_callback, early_stop_callback],
    )

    final_trainer.fit(final_model, dm_final)