# Train models
This notebook allows you to train models for SOZ localisation from SPES responses. Prior to running this notebook, you should first produce the dataset by running create_dataset.ipynb.

In [1]:
import os

# Needed for MPS
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Paths to average and standard deviation of SPES responses
mean_filepath = 'data/mean'
std_filepath = 'data/std'

# Function for training and evaluating the model
from train import train_and_evaluate
from dataset import create_dataset
from evaluate import get_thresh_and_evaluate
from models import load_model_from_path
import pandas as pd

### Define the model
Uncomment the relevant section - the main model is Transformer (all)

In [2]:
# model_name = 'CNN (divergent)'
# learning_rate = 0.003962831229235175
# hyperparams = {'dropout_rate': 0.21763415739071962,
#                'input_channels': 49}

# model_name = 'CNN (convergent)'
# learning_rate = 0.001289042623854371
# hyperparams = {'dropout_rate': 0.44374819954858546,
#                'input_channels': 37}

# model_name = 'Transformer (base)'
# learning_rate = 0.00014782432202142655
# hyperparams = {'dropout_rate': 0.4610858126530446,
#                'embedding_dim': 2**4,
#                'num_layers': 2**1}

model_name = 'Transformer (all)'
learning_rate = 0.003368045116199473
hyperparams = {'dropout_rate': 0.4391902174353594,
               'embedding_dim': 2**4,
               'num_layers': 2**1}

### Set your device and the random seed
This is particularly useful for ensuring folds are consistent to compare models. Useful in particular when doing repeated k-fold cross-validation.

In [3]:
seed = 1

# E.g. cuda:0, cuda:1, cpu
device = 'mps'

### Train the model
This can be done for each fold. Note that for each model, the epoch that achieved the highest AUROC is saved, and the stats outputted at the end correspond to the test set with this checkpoint.

In [4]:
for fold in range(5):
    train_and_evaluate(model_name, fold, learning_rate, seed, 'final', device, mean_filepath, std_filepath, num_epochs=10, batch_size=8, **hyperparams)

100%|██████████| 145/145 [00:31<00:00,  4.63it/s]


Epoch 1/10, Train Loss: 1.4875, Train AUROC: 0.5507


  and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
100%|██████████| 55/55 [00:03<00:00, 14.25it/s]


Epoch 1/10, Valid AUROC: 0.7634

Best validation metric: 0.7634491600369968

Saving best model for epoch: 1



100%|██████████| 145/145 [00:29<00:00,  4.92it/s]


Epoch 2/10, Train Loss: 1.3068, Train AUROC: 0.5701


100%|██████████| 55/55 [00:03<00:00, 17.12it/s]


Epoch 2/10, Valid AUROC: 0.7343


100%|██████████| 145/145 [00:29<00:00,  4.95it/s]


Epoch 3/10, Train Loss: 1.2264, Train AUROC: 0.5923


100%|██████████| 55/55 [00:03<00:00, 17.15it/s]


Epoch 3/10, Valid AUROC: 0.7678

Best validation metric: 0.7678421516561882

Saving best model for epoch: 3



100%|██████████| 145/145 [00:29<00:00,  4.90it/s]


Epoch 4/10, Train Loss: 1.1938, Train AUROC: 0.6227


100%|██████████| 55/55 [00:03<00:00, 16.99it/s]


Epoch 4/10, Valid AUROC: 0.7973

Best validation metric: 0.7972895727276645

Saving best model for epoch: 4



100%|██████████| 145/145 [00:29<00:00,  4.89it/s]


Epoch 5/10, Train Loss: 1.1344, Train AUROC: 0.6930


100%|██████████| 55/55 [00:03<00:00, 16.80it/s]


Epoch 5/10, Valid AUROC: 0.7804


100%|██████████| 145/145 [00:29<00:00,  4.87it/s]


Epoch 6/10, Train Loss: 1.1431, Train AUROC: 0.6889


100%|██████████| 55/55 [00:03<00:00, 16.96it/s]


Epoch 6/10, Valid AUROC: 0.8036

Best validation metric: 0.8035999530747394

Saving best model for epoch: 6



100%|██████████| 145/145 [00:29<00:00,  4.94it/s]


Epoch 7/10, Train Loss: 1.1538, Train AUROC: 0.6730


100%|██████████| 55/55 [00:03<00:00, 17.03it/s]


Epoch 7/10, Valid AUROC: 0.7673


100%|██████████| 145/145 [00:29<00:00,  4.93it/s]


Epoch 8/10, Train Loss: 1.1128, Train AUROC: 0.7014


100%|██████████| 55/55 [00:03<00:00, 17.09it/s]


Epoch 8/10, Valid AUROC: 0.8142

Best validation metric: 0.8142425998565542

Saving best model for epoch: 8



100%|██████████| 145/145 [00:29<00:00,  4.94it/s]


Epoch 9/10, Train Loss: 1.0824, Train AUROC: 0.7263


100%|██████████| 55/55 [00:03<00:00, 17.01it/s]


Epoch 9/10, Valid AUROC: 0.7360


100%|██████████| 145/145 [00:29<00:00,  4.93it/s]


Epoch 10/10, Train Loss: 1.1061, Train AUROC: 0.7083


100%|██████████| 55/55 [00:03<00:00, 17.09it/s]


Epoch 10/10, Valid AUROC: 0.7305


100%|██████████| 55/55 [00:03<00:00, 17.09it/s]
100%|██████████| 57/57 [00:04<00:00, 12.77it/s]


{'AUROC (averaged)': 0.5957765780213963, 'AUPRC (averaged)': 0.36089584768728283, 'AUROC (all)': 0.6528737220226581, 'AUPRC (all)': 0.2222967709828123, 'Baseline (averaged)': 0.17156696, 'Baseline (all)': 0.16997792, 'Youden (averaged)': 0.20460835246315762, 'Specificity (averaged)': 0.7646372124920175, 'Sensitivity (averaged)': 0.43997113997113996, 'Youden threshold': 0.23198655}


100%|██████████| 153/153 [00:35<00:00,  4.27it/s]


Epoch 1/10, Train Loss: 1.3640, Train AUROC: 0.5602


100%|██████████| 49/49 [00:03<00:00, 16.02it/s]


Epoch 1/10, Valid AUROC: 0.7186

Best validation metric: 0.718621182636938

Saving best model for epoch: 1



100%|██████████| 153/153 [00:33<00:00,  4.53it/s]


Epoch 2/10, Train Loss: 1.1657, Train AUROC: 0.6297


100%|██████████| 49/49 [00:01<00:00, 24.53it/s]


Epoch 2/10, Valid AUROC: 0.7628

Best validation metric: 0.7627582087405049

Saving best model for epoch: 2



100%|██████████| 153/153 [00:33<00:00,  4.59it/s]


Epoch 3/10, Train Loss: 1.1842, Train AUROC: 0.5950


100%|██████████| 49/49 [00:01<00:00, 24.68it/s]


Epoch 3/10, Valid AUROC: 0.8269

Best validation metric: 0.826860964531427

Saving best model for epoch: 3



100%|██████████| 153/153 [00:33<00:00,  4.61it/s]


Epoch 4/10, Train Loss: 1.1308, Train AUROC: 0.6377


100%|██████████| 49/49 [00:01<00:00, 24.54it/s]


Epoch 4/10, Valid AUROC: 0.8298

Best validation metric: 0.8298178817565924

Saving best model for epoch: 4



100%|██████████| 153/153 [00:33<00:00,  4.62it/s]


Epoch 5/10, Train Loss: 1.1338, Train AUROC: 0.6346


100%|██████████| 49/49 [00:01<00:00, 24.65it/s]


Epoch 5/10, Valid AUROC: 0.8220


100%|██████████| 153/153 [00:33<00:00,  4.60it/s]


Epoch 6/10, Train Loss: 1.1040, Train AUROC: 0.6591


100%|██████████| 49/49 [00:01<00:00, 24.89it/s]


Epoch 6/10, Valid AUROC: 0.8060


100%|██████████| 153/153 [00:33<00:00,  4.57it/s]


Epoch 7/10, Train Loss: 1.1214, Train AUROC: 0.6508


100%|██████████| 49/49 [00:02<00:00, 24.00it/s]


Epoch 7/10, Valid AUROC: 0.7638


100%|██████████| 153/153 [00:33<00:00,  4.62it/s]


Epoch 8/10, Train Loss: 1.1299, Train AUROC: 0.6339


100%|██████████| 49/49 [00:01<00:00, 24.95it/s]


Epoch 8/10, Valid AUROC: 0.7931


100%|██████████| 153/153 [00:33<00:00,  4.60it/s]


Epoch 9/10, Train Loss: 1.1353, Train AUROC: 0.6220


100%|██████████| 49/49 [00:01<00:00, 24.65it/s]


Epoch 9/10, Valid AUROC: 0.8041


100%|██████████| 153/153 [00:33<00:00,  4.55it/s]


Epoch 10/10, Train Loss: 1.1377, Train AUROC: 0.6137


100%|██████████| 49/49 [00:01<00:00, 24.50it/s]


Epoch 10/10, Valid AUROC: 0.7752


100%|██████████| 49/49 [00:01<00:00, 24.87it/s]
100%|██████████| 55/55 [00:03<00:00, 16.58it/s]


{'AUROC (averaged)': 0.7898241225348538, 'AUPRC (averaged)': 0.5492288010720215, 'AUROC (all)': 0.7523261595283756, 'AUPRC (all)': 0.3839905485255579, 'Baseline (averaged)': 0.15719962, 'Baseline (all)': 0.17767654, 'Youden (averaged)': 0.37748728524178904, 'Specificity (averaged)': 0.6670732998278037, 'Sensitivity (averaged)': 0.7104139854139854, 'Youden threshold': 0.39216185}


100%|██████████| 162/162 [00:36<00:00,  4.41it/s]


Epoch 1/10, Train Loss: 1.3164, Train AUROC: 0.5506


100%|██████████| 46/46 [00:02<00:00, 16.80it/s]


Epoch 1/10, Valid AUROC: 0.8276

Best validation metric: 0.8275628707003037

Saving best model for epoch: 1



100%|██████████| 162/162 [00:34<00:00,  4.64it/s]


Epoch 2/10, Train Loss: 1.1347, Train AUROC: 0.6083


100%|██████████| 46/46 [00:02<00:00, 18.02it/s]


Epoch 2/10, Valid AUROC: 0.8028


100%|██████████| 162/162 [00:35<00:00,  4.62it/s]


Epoch 3/10, Train Loss: 1.1009, Train AUROC: 0.6572


100%|██████████| 46/46 [00:02<00:00, 18.00it/s]


Epoch 3/10, Valid AUROC: 0.8036


100%|██████████| 162/162 [00:34<00:00,  4.64it/s]


Epoch 4/10, Train Loss: 1.0739, Train AUROC: 0.6726


100%|██████████| 46/46 [00:02<00:00, 17.91it/s]


Epoch 4/10, Valid AUROC: 0.7801


100%|██████████| 162/162 [00:34<00:00,  4.65it/s]


Epoch 5/10, Train Loss: 1.0989, Train AUROC: 0.6527


100%|██████████| 46/46 [00:02<00:00, 17.97it/s]


Epoch 5/10, Valid AUROC: 0.8321

Best validation metric: 0.8321300351111869

Saving best model for epoch: 5



100%|██████████| 162/162 [00:35<00:00,  4.63it/s]


Epoch 6/10, Train Loss: 1.0755, Train AUROC: 0.6651


100%|██████████| 46/46 [00:02<00:00, 18.06it/s]


Epoch 6/10, Valid AUROC: 0.8511

Best validation metric: 0.851105291111302

Saving best model for epoch: 6



100%|██████████| 162/162 [00:34<00:00,  4.63it/s]


Epoch 7/10, Train Loss: 1.0876, Train AUROC: 0.6611


100%|██████████| 46/46 [00:02<00:00, 18.00it/s]


Epoch 7/10, Valid AUROC: 0.8435


100%|██████████| 162/162 [00:34<00:00,  4.63it/s]


Epoch 8/10, Train Loss: 1.0632, Train AUROC: 0.6880


100%|██████████| 46/46 [00:02<00:00, 18.02it/s]


Epoch 8/10, Valid AUROC: 0.8444


100%|██████████| 162/162 [00:35<00:00,  4.63it/s]


Epoch 9/10, Train Loss: 1.0840, Train AUROC: 0.6621


100%|██████████| 46/46 [00:02<00:00, 17.98it/s]


Epoch 9/10, Valid AUROC: 0.8350


100%|██████████| 162/162 [00:34<00:00,  4.64it/s]


Epoch 10/10, Train Loss: 1.0783, Train AUROC: 0.6655


100%|██████████| 46/46 [00:02<00:00, 18.02it/s]


Epoch 10/10, Valid AUROC: 0.8600

Best validation metric: 0.8600030330892311

Saving best model for epoch: 10



100%|██████████| 46/46 [00:02<00:00, 18.23it/s]
100%|██████████| 49/49 [00:02<00:00, 23.64it/s]


{'AUROC (averaged)': 0.7680081494879715, 'AUPRC (averaged)': 0.30087190308762707, 'AUROC (all)': 0.6853551912568305, 'AUPRC (all)': 0.146189190019619, 'Baseline (averaged)': 0.06182741, 'Baseline (all)': 0.06393862, 'Youden (averaged)': 0.3413912935624789, 'Specificity (averaged)': 0.5874230395942249, 'Sensitivity (averaged)': 0.753968253968254, 'Youden threshold': 0.62363684}


100%|██████████| 161/161 [00:37<00:00,  4.34it/s]


Epoch 1/10, Train Loss: 1.4233, Train AUROC: 0.5184


100%|██████████| 51/51 [00:03<00:00, 12.92it/s]


Epoch 1/10, Valid AUROC: 0.7282

Best validation metric: 0.7282352166775606

Saving best model for epoch: 1



100%|██████████| 161/161 [00:34<00:00,  4.63it/s]


Epoch 2/10, Train Loss: 1.2295, Train AUROC: 0.5610


100%|██████████| 51/51 [00:02<00:00, 20.95it/s]


Epoch 2/10, Valid AUROC: 0.7735

Best validation metric: 0.7735435620849548

Saving best model for epoch: 2



100%|██████████| 161/161 [00:34<00:00,  4.66it/s]


Epoch 3/10, Train Loss: 1.1790, Train AUROC: 0.5771


100%|██████████| 51/51 [00:02<00:00, 21.08it/s]


Epoch 3/10, Valid AUROC: 0.7056


100%|██████████| 161/161 [00:34<00:00,  4.65it/s]


Epoch 4/10, Train Loss: 1.1474, Train AUROC: 0.6170


100%|██████████| 51/51 [00:02<00:00, 21.06it/s]


Epoch 4/10, Valid AUROC: 0.7497


100%|██████████| 161/161 [00:34<00:00,  4.64it/s]


Epoch 5/10, Train Loss: 1.1840, Train AUROC: 0.5984


100%|██████████| 51/51 [00:02<00:00, 21.16it/s]


Epoch 5/10, Valid AUROC: 0.7568


100%|██████████| 161/161 [00:34<00:00,  4.64it/s]


Epoch 6/10, Train Loss: 1.1550, Train AUROC: 0.6269


100%|██████████| 51/51 [00:02<00:00, 21.07it/s]


Epoch 6/10, Valid AUROC: 0.7605


100%|██████████| 161/161 [00:34<00:00,  4.65it/s]


Epoch 7/10, Train Loss: 1.1418, Train AUROC: 0.6398


100%|██████████| 51/51 [00:02<00:00, 21.09it/s]


Epoch 7/10, Valid AUROC: 0.7490


100%|██████████| 161/161 [00:34<00:00,  4.65it/s]


Epoch 8/10, Train Loss: 1.1301, Train AUROC: 0.6643


100%|██████████| 51/51 [00:02<00:00, 21.11it/s]


Epoch 8/10, Valid AUROC: 0.7482


100%|██████████| 161/161 [00:34<00:00,  4.64it/s]


Epoch 9/10, Train Loss: 1.1473, Train AUROC: 0.6340


100%|██████████| 51/51 [00:02<00:00, 21.13it/s]


Epoch 9/10, Valid AUROC: 0.7146


100%|██████████| 161/161 [00:34<00:00,  4.64it/s]


Epoch 10/10, Train Loss: 1.1533, Train AUROC: 0.6390


100%|██████████| 51/51 [00:02<00:00, 20.46it/s]


Epoch 10/10, Valid AUROC: 0.7520


100%|██████████| 51/51 [00:02<00:00, 20.78it/s]
100%|██████████| 46/46 [00:02<00:00, 17.54it/s]


{'AUROC (averaged)': 0.7599954333644977, 'AUPRC (averaged)': 0.38365848233220584, 'AUROC (all)': 0.664983164983165, 'AUPRC (all)': 0.18206888456622278, 'Baseline (averaged)': 0.12356382, 'Baseline (all)': 0.11956522, 'Youden (averaged)': 0.3556224067512243, 'Specificity (averaged)': 0.566506760492721, 'Sensitivity (averaged)': 0.7891156462585034, 'Youden threshold': 0.5026825}


100%|██████████| 150/150 [00:35<00:00,  4.21it/s]


Epoch 1/10, Train Loss: 1.3719, Train AUROC: 0.5716


100%|██████████| 57/57 [00:03<00:00, 16.23it/s]


Epoch 1/10, Valid AUROC: 0.6725

Best validation metric: 0.6725066076317755

Saving best model for epoch: 1



100%|██████████| 150/150 [00:32<00:00,  4.68it/s]


Epoch 2/10, Train Loss: 1.2425, Train AUROC: 0.5974


100%|██████████| 57/57 [00:03<00:00, 16.77it/s]


Epoch 2/10, Valid AUROC: 0.6263


100%|██████████| 150/150 [00:32<00:00,  4.67it/s]


Epoch 3/10, Train Loss: 1.1688, Train AUROC: 0.6459


100%|██████████| 57/57 [00:03<00:00, 17.04it/s]


Epoch 3/10, Valid AUROC: 0.6066


100%|██████████| 150/150 [00:31<00:00,  4.70it/s]


Epoch 4/10, Train Loss: 1.1821, Train AUROC: 0.6422


100%|██████████| 57/57 [00:03<00:00, 16.89it/s]


Epoch 4/10, Valid AUROC: 0.5848


100%|██████████| 150/150 [00:31<00:00,  4.72it/s]


Epoch 5/10, Train Loss: 1.1238, Train AUROC: 0.6768


100%|██████████| 57/57 [00:03<00:00, 17.21it/s]


Epoch 5/10, Valid AUROC: 0.5648


100%|██████████| 150/150 [00:31<00:00,  4.76it/s]


Epoch 6/10, Train Loss: 1.1445, Train AUROC: 0.6625


100%|██████████| 57/57 [00:03<00:00, 17.01it/s]


Epoch 6/10, Valid AUROC: 0.5498


100%|██████████| 150/150 [00:31<00:00,  4.75it/s]


Epoch 7/10, Train Loss: 1.1743, Train AUROC: 0.6510


100%|██████████| 57/57 [00:03<00:00, 17.20it/s]


Epoch 7/10, Valid AUROC: 0.5183


100%|██████████| 150/150 [00:31<00:00,  4.72it/s]


Epoch 8/10, Train Loss: 1.1337, Train AUROC: 0.6739


100%|██████████| 57/57 [00:03<00:00, 17.01it/s]


Epoch 8/10, Valid AUROC: 0.5485


100%|██████████| 150/150 [00:32<00:00,  4.65it/s]


Epoch 9/10, Train Loss: 1.1227, Train AUROC: 0.6898


100%|██████████| 57/57 [00:03<00:00, 16.37it/s]


Epoch 9/10, Valid AUROC: 0.5760


100%|██████████| 150/150 [00:32<00:00,  4.63it/s]


Epoch 10/10, Train Loss: 1.1115, Train AUROC: 0.6939


100%|██████████| 57/57 [00:03<00:00, 17.07it/s]


Epoch 10/10, Valid AUROC: 0.5807


100%|██████████| 57/57 [00:03<00:00, 17.17it/s]
100%|██████████| 51/51 [00:02<00:00, 19.76it/s]


{'AUROC (averaged)': 0.7495985892405087, 'AUPRC (averaged)': 0.4742048869894257, 'AUROC (all)': 0.7615243583027763, 'AUPRC (all)': 0.3923485925837928, 'Baseline (averaged)': 0.17097831, 'Baseline (all)': 0.17206983, 'Youden (averaged)': 0.4546008328903066, 'Specificity (averaged)': 0.6367280258727627, 'Sensitivity (averaged)': 0.8178728070175438, 'Youden threshold': 0.62970114}


# Evaluate model performance
Define the seeds and models that you have trained:

In [5]:
# Having trained one model with one seed
seeds = [seed]
model_names = [model_name]

# E.g., if you've trained all models with 5 random seeds
# seeds = [1, 2, 3, 4, 5]
# model_names = ['CNN (divergent)', 'CNN (convergent)', 'Transformer (base)', 'Transformer (all)']

Produce a table with the scores for each of these models:

In [6]:
table = []
for seed in seeds:

    print("Seed", seed)

    # Iterate over each fold
    for fold in range(5):

        print("Fold", fold + 1)

        # Load data for the current fold
        train_loader, val_loader, test_loader, pos_weight = create_dataset(mean_filepath, std_filepath, fold, seed=seed, batch_size=8)

        # Iterate over each model
        for model_name in model_names:

            # Load the model
            model_save_path = f'../models/best_model_{model_name}_seed_{seed}_fold_{fold}.pth'
            model = load_model_from_path(model_name, model_save_path, device, **hyperparams)

            # Evaluate the model on the validation and test sets
            test_metrics = get_thresh_and_evaluate(model, device, val_loader, test_loader)

            # Append the results to the table
            table.append([model_name, seed, fold, test_metrics['AUROC (averaged)'], test_metrics['AUPRC (averaged)'], test_metrics['Youden (averaged)'], test_metrics['Specificity (averaged)'], test_metrics['Sensitivity (averaged)']])

# Convert to pandas table
table_df = pd.DataFrame(table, columns=['model', 'seed', 'fold', 'AUROC', 'AUPRC', 'Youden', 'Specificity', 'Sensitivity'])

Seed 1
Fold 1


100%|██████████| 55/55 [00:03<00:00, 16.41it/s]
100%|██████████| 57/57 [00:03<00:00, 16.24it/s]


Fold 2


100%|██████████| 49/49 [00:02<00:00, 23.74it/s]
100%|██████████| 55/55 [00:03<00:00, 16.41it/s]


Fold 3


100%|██████████| 46/46 [00:02<00:00, 17.85it/s]
100%|██████████| 49/49 [00:02<00:00, 23.76it/s]


Fold 4


100%|██████████| 51/51 [00:02<00:00, 20.44it/s]
100%|██████████| 46/46 [00:02<00:00, 17.49it/s]


Fold 5


100%|██████████| 57/57 [00:03<00:00, 16.11it/s]
100%|██████████| 51/51 [00:02<00:00, 19.50it/s]


In [7]:
# for each unique model, calculate the mean and standard deviation of the metrics
table_df.groupby('model').agg(['mean', 'std']).style.background_gradient(axis=0).format(precision=3)

Unnamed: 0_level_0,seed,seed,fold,fold,AUROC,AUROC,AUPRC,AUPRC,Youden,Youden,Specificity,Specificity,Sensitivity,Sensitivity
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2
Transformer (all),1.0,0.0,2.0,1.581,0.733,0.078,0.414,0.098,0.347,0.091,0.644,0.078,0.702,0.152
