In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from model_trainer import LabelTransferTrainer
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import keras

In [None]:
# Load the dataset table and drop the two Ho2020 studies from it
excluded_studies = ['Ho2020M', 'Ho2020F']
datasets = pd.read_csv('datasets.csv')
datasets = datasets.loc[~datasets['study_id'].isin(excluded_studies)]

# Trainer configuration shared by the whole notebook
trainer_kwargs = dict(
    hyperparams_file='hyperparams.csv',
    data_dir='exported_matrices',
    output_dir='label_transfer_model_output',
    verbose=False,
    random_seed=12345,
)
trainer = LabelTransferTrainer(**trainer_kwargs)

trainer.save_models_plots()

# Hyperparameter tuning

In [None]:
# Study ids present in the dataset table (pick one for `study_id` below)
available_study_ids = datasets['study_id'].unique()
print(available_study_ids)

In [None]:
# Study whose model is tuned / trained in the cells below
study_id = 'Lopez2021M'

# Toggle the cross-validated hyperparameter sweep (and its plots) on or off
do_hyperparam_search = True

# True: sweep a single hyperparameter; False: sweep a grid of two hyperparameters
single_param_search = False

Sweep over the model's hyperparameters (either a single one or a grid of two) using 5-fold cross-validation to find the best combination.

In [None]:
# Hyperparameters that can be swept:
#   learning_rate, num_layers, num_nodes, reg_factor,
#   reg_type, dropout_rate, bn_momentum, batch_size

if do_hyperparam_search:
    keras.backend.clear_session()

    # Rebuild the trainer so that any edits to hyperparams.csv are picked up
    trainer = LabelTransferTrainer(
        hyperparams_file='hyperparams.csv',
        data_dir='exported_matrices',
        output_dir='label_transfer_model_output',
        verbose=False,
        random_seed=12345,
    )

    if not single_param_search:
        # Grid over two hyperparameters
        param1, values1 = "bn_momentum", [0.5, 0.75, 0.9, 0.95, 0.99]
        param2, values2 = "reg_factor", [0.0, 0.0001, 0.001, 0.01, 0.1]
        res = trainer.tune_two_hyperparameters(
            study_id, param1, param2, values1, values2, n_folds=5)
    else:
        # Sweep over a single hyperparameter
        param, values = "bn_momentum", [0.5, 0.75, 0.9, 0.95, 0.99]
        res = trainer.tune_hyperparam(study_id, param, values, n_folds=5)

In [None]:
# Inspect the raw tuning results and the best fold-averaged validation score per setting.
# Guarded like the search cell: `res` only exists when the sweep was run.
if do_hyperparam_search:
    # Average over folds -> one row per (hyperparameter value(s), Epoch),
    # then group the rows of each hyperparameter setting together
    if single_param_search:
        res_grouped = (res.groupby([param, "Epoch"])
                       .mean(numeric_only=True)
                       .groupby(param))
    else:
        res_grouped = (res.groupby([param1, param2, "Epoch"])
                       .mean(numeric_only=True)
                       .groupby([param1, param2]))

    display(res)
    # Best (over epochs) fold-averaged validation macro F1 for each setting.
    # Displaying the GroupBy object itself would only show its repr.
    display(res_grouped["Val Macro F1 score"].max())

In [None]:
# Plot the results
# The results are saved in a pd dataframe with the following columns:
#   - <hyperparam>: the value of hyperparameter that was tuned
#   - Fold: the fold number
#   - Epoch: the epoch number
#   - Loss, Val loss: the training and validation loss
#   - Macro F1 score, Val Macro F1 score: the training and validation macro F1 score
# Lines are fold means; shaded bands span the min-max range across folds.


def _plot_fold_band(a, group, metric, label):
    """Plot the per-epoch fold mean of `metric` on axis `a`, shaded by its min-max across folds.

    `group` holds the raw (per-fold) result rows of one hyperparameter setting.
    """
    by_epoch = group.groupby("Epoch")
    mean_group = by_epoch.mean(numeric_only=True)
    sns.lineplot(data=mean_group, x="Epoch", y=metric, ax=a, label=label)
    a.fill_between(mean_group.index,
                   by_epoch[metric].min(), by_epoch[metric].max(), alpha=0.1)


if do_hyperparam_search:
    if single_param_search:
        # Single hyperparameter
        res_grouped = res.groupby(param)
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))

        for param_value, group in res_grouped:
            _plot_fold_band(ax[0], group, "Macro F1 score", f"{param}={param_value}")
            _plot_fold_band(ax[1], group, "Val Macro F1 score", f"{param}={param_value}")

        # Average over all tuned values (and folds) at each epoch. Drawn black/dashed
        # because 'C1' is also the colour of the second tuned value's line.
        overall_mean = res.groupby("Epoch").mean(numeric_only=True)
        ax[0].plot(overall_mean.index, overall_mean["Macro F1 score"],
                   label="Average - train", color='k', linestyle='--')
        ax[1].plot(overall_mean.index, overall_mean["Val Macro F1 score"],
                   label="Average - validation", color='k', linestyle='--')

        for a in ax:
            a.set_ylim(0, 1)  # plt.ylim would only affect the current axes
            a.legend()        # rebuild so the average line is included

        ax[0].set_title(f"F1 score vs {param}")
        ax[1].set_title(f"Validation F1 score vs {param}")
        fig.suptitle(study_id)
    else:
        # Two hyperparameters
        # Mean the folds: one row per (param1, param2, Epoch)
        res_grouped = (res.groupby([param1, param2, "Epoch"])
                       .mean(numeric_only=True)
                       .groupby([param1, param2]))
        # Raw per-fold rows of each setting, so the bands show the spread across folds
        # (fold-averaged groups have a single row per epoch, i.e. a zero-width band)
        res_by_setting = res.groupby([param1, param2])

        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        # Best-epoch, fold-averaged validation macro F1 for each (param1, param2)
        max_f1 = res_grouped.max()['Val Macro F1 score'].unstack()
        sns.heatmap(max_f1, annot=True, ax=ax[0], cmap='viridis')
        ax[0].set_title(f"{study_id} - Max Macro F1 score")
        ax[0].set_xlabel(param2)
        ax[0].set_ylabel(param1)

        # Values of param1 and param2 that give the highest score over the whole grid
        best_row, best_col = np.unravel_index(
            np.nanargmax(max_f1.to_numpy()), max_f1.shape)
        max_p1, max_p2 = max_f1.index[best_row], max_f1.columns[best_col]

        # F1 over epochs for every param2 value, with param1 fixed at its best value
        for param2_value in values2:
            _plot_fold_band(ax[1], res_by_setting.get_group((max_p1, param2_value)),
                            "Val Macro F1 score", f"{param2}={param2_value}")
        ax[1].set_title(f"F1 score vs {param2} for {param1}={max_p1}")

        # F1 over epochs for every param1 value, with param2 fixed at its best value
        for param1_value in values1:
            _plot_fold_band(ax[2], res_by_setting.get_group((param1_value, max_p2)),
                            "Val Macro F1 score", f"{param1}={param1_value}")
        ax[2].set_title(f"F1 score vs {param1} for {param2}={max_p2}")

    plt.show()

## Train a single model

In [None]:
keras.backend.clear_session()
# We recreate the trainer object here in case we changed the hyperparams
trainer = LabelTransferTrainer(hyperparams_file='hyperparams.csv',
                               data_dir='exported_matrices',
                               output_dir='label_transfer_model_output',
                               verbose=False,
                               random_seed=12345)

trainer.models[study_id], history = trainer.train_single_model(study_id)

# One panel per metric: (history key, y-axis label). The validation curve is
# read from the matching 'val_<key>' entry.
history_panels = [
    ('loss', 'Loss'),
    ('macro_f1_score', 'F1 score (macro)'),
    ('micro_f1_score', 'F1 score (micro)'),
    ('ROCAUC', 'ROC AUC'),
]

fig, ax = plt.subplots(2, 2, figsize=(10, 10))
for a, (key, ylabel) in zip(ax.flat, history_panels):
    a.plot(history[key], label='train')
    a.plot(history[f'val_{key}'], label='validation')
    a.set_xlabel('Epoch')
    a.set_ylabel(ylabel)
    a.legend()  # every panel, including loss, gets a legend
    if key != 'loss':
        # Score-type metrics: fixed range so panels are comparable
        a.set_ylim(0, 1.1)

plt.suptitle(f"Training history for {study_id}")
plt.tight_layout()

trainer.evaluate_single_model(study_id)

# Training

Train the models for all datasets (one model per dataset) and plot their training curves.

In [None]:
keras.backend.clear_session()
# We recreate the trainer object here in case we changed the hyperparams
trainer = LabelTransferTrainer(hyperparams_file='hyperparams.csv',
                               data_dir='exported_matrices',
                               output_dir='label_transfer_model_output',
                               verbose=False,
                               random_seed=12345)

# Delete all previously saved model files before retraining.
# Set to False to keep the existing saved models.
CLEAR_SAVED_MODELS = True
if CLEAR_SAVED_MODELS:
    trainer.clear_saved_models()
trainer.train_all_models(reset_models=True)

In [None]:
# Human-readable "author year sex" name per study id; keys without an entry
# (e.g. studies dropped from `datasets`) fall back to the key itself.
# NOTE(review): assumes training_histories is keyed by study_id — confirm.
study_names = {
    row['study_id']: f"{row['author']} {row['year']} {row['sex']}"
    for _, row in datasets.iterrows()
}
hist_keys = list(trainer.training_histories.keys())
# Titles aligned with hist_keys, so each panel is labelled with the history it shows
dataset_names = [study_names.get(k, k) for k in hist_keys]

# Grid sized to the number of histories (3 columns, 3.75 in. per row)
n_cols = 3
n_rows = max(1, int(np.ceil(len(hist_keys) / n_cols)))
fig, ax = plt.subplots(n_rows, n_cols, figsize=(15, 3.75 * n_rows), squeeze=False)

for a, key, name in zip(ax.flat, hist_keys, dataset_names):
    hist = trainer.training_histories[key]
    a.plot(hist['macro_f1_score'], label='train')
    a.plot(hist['val_macro_f1_score'], label='validation')
    a.set_title(name)
    a.set_ylim([0, 1.1])
    a.set_xlabel('Epoch')
    a.set_ylabel('F1 score')

if hist_keys:
    ax.flat[0].legend()

# Hide the panels left over in the last row
for a in ax.flat[len(hist_keys):]:
    a.axis('off')
plt.tight_layout()

# Evaluation

In [None]:
trainer.load_best_models()
eval_res = trainer.evaluate_models()

macro_f1 = [r["macro_f1_score"] for r in eval_res.values()]
micro_f1 = [r["micro_f1_score"] for r in eval_res.values()]
rocauc = [r["ROCAUC"] for r in eval_res.values()]

# Bars are red below RED_BELOW, yellow below YELLOW_BELOW and green otherwise
RED_BELOW, YELLOW_BELOW = 0.5, 0.75


def _score_barplot(a, labels, scores, title, ylabel):
    """Draw one threshold-coloured bar per study on axis `a`, annotated with its score."""
    positions = np.arange(len(scores))
    colours = ['red' if s < RED_BELOW else 'yellow' if s < YELLOW_BELOW else 'green'
               for s in scores]
    # Plain matplotlib bars: seaborn deprecates passing `palette` without `hue`
    a.bar(positions, scores, color=colours)
    # Fix the tick positions before labelling them (avoids FixedLocator warnings)
    a.set_xticks(positions)
    a.set_xticklabels(labels, rotation=45, ha='right')
    for x, v in zip(positions, scores):
        a.text(x, v + 0.01, f"{v:.2f}", ha='center')
    a.set_title(title)
    a.set_ylabel(ylabel)
    # Headroom above 1 so value labels on near-perfect scores are not clipped
    a.set_ylim([0, 1.1])


fig, ax = plt.subplots(1, 3, figsize=(15, 5))
study_labels = list(eval_res.keys())
_score_barplot(ax[0], study_labels, macro_f1, 'Macro F1 score', 'F1 score')
_score_barplot(ax[1], study_labels, micro_f1, 'Micro F1 score', 'F1 score')
_score_barplot(ax[2], study_labels, rocauc, 'ROC AUC', 'ROC AUC')

plt.tight_layout()
#     print(f"{f:.2f},", end="")
    

In [None]:
# Mean / median / std of the per-study scores from the evaluation cell,
# one group per metric separated by a row of asterisks
summaries = [("macro F1 score", macro_f1), ("ROC AUC", rocauc)]
statistics = (("Mean", np.mean), ("Median", np.median), ("Std", np.std))

for idx, (metric_name, scores) in enumerate(summaries):
    if idx > 0:
        print("*" * 20)
    for stat_name, stat_fn in statistics:
        print(f"{stat_name} {metric_name}: {stat_fn(scores):.2f}")

In [None]:
# Reload the best models with clear_others=True (presumably removes the other
# saved models — TODO confirm) and tabulate what the trainer returns.
best_models = trainer.load_best_models(clear_others=True)
bm = pd.DataFrame(best_models)
bm