# Transformer

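Train a base Vision Transformer (ViT) model on the IEEE EEG dataset with 5-fold cross-validation, then evaluate the best checkpoint on the held-out test set.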

In [1]:
# Mount Google Drive for the Colab environment and put MyDrive on sys.path so
# the project packages imported in the next cell (`utils`, `models`) can be
# found -- assumed to live directly under MyDrive.
import sys
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
PROJECT_ROOT = "/content/drive/MyDrive"

drive.mount(DRIVE_MOUNT_POINT, force_remount=False)
# Guard the append so re-running this cell does not pile up duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from utils import (
    ignore_warnings,
    fix_random_seed,
    # Aliased so the resolved torch device below does not shadow the helper.
    device as get_device,
    clear_cache,
    join_drive_path,
    log_json,
    train_with_kfold,
    WarmupScheduler,
    evaluate,
    Config,
    IEEEDataConfig,
    EEGDataset,
)
from models.transformer import TransformerConfig, ViTransformer

SEED = 42

ignore_warnings()
fix_random_seed(SEED)
# Resolve the compute device once; every later cell uses this `device` value.
device = get_device(force_cuda=True)
print("Device:", device)

Device: cuda


In [3]:
# Training / optimisation hyper-parameters for this run.
config = Config(
    name="ieee transformer",  # Config reports this as "ieee-transformer"
    batch=8,
    epochs=50,
    lr=1e-3,
    enable_fp16=True,
    grad_step=4,  # passed to train_with_kfold as `gradient_step`
    warmup_steps=50,
    lr_decay_factor=0.5,
    weight_decay=1e-3,
    patience=30,
)
# Number of cross-validation folds used by the training cell.
config.add(k_folds=5)

# File names, channel count, sequence length and class count of the IEEE split.
data_config = IEEEDataConfig()

# Architecture of the ViTransformer built in the training cell.
model_config = TransformerConfig(
    embed_dim=64,
    num_heads=4,
    num_blocks=4,
    block_hidden_dim=64,
    fc_hidden_dim=32,
    dropout=0.1,
)

print(f"ID: {config.id}")
print(f"Name: {config.name}")

ID: 250302035900177174
Name: ieee-transformer


In [4]:
# Load the pre-split train / validation tensors (dicts with "data" and "label").
train_data_path = join_drive_path("data", data_config.train)
val_data_path = join_drive_path("data", data_config.val)

train_data = torch.load(train_data_path, weights_only=True)
val_data = torch.load(val_data_path, weights_only=True)

# Concat Train-set and Validation-set for Cross validation; train_with_kfold is
# expected to split this combined pool into its own folds.
signals = torch.cat([train_data["data"], val_data["data"]], dim=0)
labels = torch.cat([train_data["label"], val_data["label"]], dim=0)
# Every sample must keep exactly one label after concatenation.
assert len(signals) == len(labels), (
    f"signals/labels length mismatch: {len(signals)} vs {len(labels)}"
)

train_dataset = EEGDataset({"data": signals, "label": labels})

In [5]:
# Constructor arguments for ViTransformer; reused by the evaluation cell to
# rebuild the model before loading the best weights.
model_param = dict(
    input_channel=data_config.channels,
    seq_length=data_config.length,
    embed_dim=model_config.embed_dim,
    num_heads=model_config.num_heads,
    num_blocks=model_config.num_blocks,
    block_hidden_dim=model_config.block_hidden_dim,
    fc_hidden_dim=model_config.fc_hidden_dim,
    num_classes=data_config.num_classes,
    dropout_p=model_config.dropout,
)

criterion = nn.CrossEntropyLoss()

# Run k-fold training. The returned checkpoint is recorded as the epoch count,
# and the returned path names the best model's weights file under "log".
check_point, best_model_path = train_with_kfold(
    # data / folds
    k_folds=config.k_folds,
    train_dataset=train_dataset,
    batch=config.batch,
    # model
    model_class=ViTransformer,
    model_params=model_param,
    device=device,
    model_path=config.model_path,
    # optimisation
    optimizer_class=optim.Adam,
    optimizer_params={"lr": config.lr, "weight_decay": config.weight_decay},
    criterion=criterion,
    epochs=config.epochs,
    gradient_step=config.grad_step,
    patience=config.patience,
    enable_fp16=config.enable_fp16,
    # learning-rate schedule
    scheduler_class=WarmupScheduler,
    scheduler_params={
        "lr": config.lr,
        "warmup_steps": config.warmup_steps,
        "decay_factor": config.lr_decay_factor,
    },
)


===== Fold 1 =====


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.73011,  Val-Loss: 0.76622
Epoch 2, Train-Loss: 0.71604,  Val-Loss: 0.75153
Epoch 3, Train-Loss: 0.72164,  Val-Loss: 0.73429
Epoch 4, Train-Loss: 0.70390,  Val-Loss: 0.71230
Epoch 5, Train-Loss: 0.73161,  Val-Loss: 0.69196
Epoch 6, Train-Loss: 0.68867,  Val-Loss: 0.67226
Epoch 7, Train-Loss: 0.68459,  Val-Loss: 0.65197
Epoch 8, Train-Loss: 0.70673,  Val-Loss: 0.63855
Epoch 9, Train-Loss: 0.68117,  Val-Loss: 0.62884
Epoch 10, Train-Loss: 0.66076,  Val-Loss: 0.61990
Epoch 11, Train-Loss: 0.65410,  Val-Loss: 0.61634
Epoch 12, Train-Loss: 0.65531,  Val-Loss: 0.61345
Epoch 13, Train-Loss: 0.66445,  Val-Loss: 0.61860
Epoch 14, Train-Loss: 0.65732,  Val-Loss: 0.62434
Epoch 15, Train-Loss: 0.64022,  Val-Loss: 0.62886
Epoch 16, Train-Loss: 0.62848,  Val-Loss: 0.62792
Epoch 17, Train-Loss: 0.63894,  Val-Loss: 0.63122
Epoch 18, Train-Loss: 0.62234,  Val-Loss: 0.62925
Epoch 19, Train-Loss: 0.59838,  Val-Loss: 0.62325
Epoch 20, Train-Loss: 0.61902,  Val-Loss: 0.61838
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.89777,  Val-Loss: 0.81088
Epoch 2, Train-Loss: 0.87446,  Val-Loss: 0.79575
Epoch 3, Train-Loss: 0.83191,  Val-Loss: 0.77538
Epoch 4, Train-Loss: 0.83071,  Val-Loss: 0.75215
Epoch 5, Train-Loss: 0.79939,  Val-Loss: 0.72701
Epoch 6, Train-Loss: 0.77655,  Val-Loss: 0.70493
Epoch 7, Train-Loss: 0.73282,  Val-Loss: 0.69123
Epoch 8, Train-Loss: 0.70611,  Val-Loss: 0.68572
Epoch 9, Train-Loss: 0.68295,  Val-Loss: 0.68351
Epoch 10, Train-Loss: 0.68416,  Val-Loss: 0.68601
Epoch 11, Train-Loss: 0.66392,  Val-Loss: 0.69197
Epoch 12, Train-Loss: 0.65196,  Val-Loss: 0.69898
Epoch 13, Train-Loss: 0.67978,  Val-Loss: 0.70664
Epoch 14, Train-Loss: 0.65855,  Val-Loss: 0.71028
Epoch 15, Train-Loss: 0.65202,  Val-Loss: 0.71159
Epoch 16, Train-Loss: 0.66598,  Val-Loss: 0.71032
Epoch 17, Train-Loss: 0.66502,  Val-Loss: 0.70478
Epoch 18, Train-Loss: 0.66883,  Val-Loss: 0.69814
Epoch 19, Train-Loss: 0.64987,  Val-Loss: 0.68785
Epoch 20, Train-Loss: 0.62815,  Val-Loss: 0.67911
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.75067,  Val-Loss: 0.75355
Epoch 2, Train-Loss: 0.77797,  Val-Loss: 0.73730
Epoch 3, Train-Loss: 0.71009,  Val-Loss: 0.71488
Epoch 4, Train-Loss: 0.69068,  Val-Loss: 0.69160
Epoch 5, Train-Loss: 0.69044,  Val-Loss: 0.67075
Epoch 6, Train-Loss: 0.68914,  Val-Loss: 0.65678
Epoch 7, Train-Loss: 0.66510,  Val-Loss: 0.64843
Epoch 8, Train-Loss: 0.71499,  Val-Loss: 0.64449
Epoch 9, Train-Loss: 0.67933,  Val-Loss: 0.64353
Epoch 10, Train-Loss: 0.63600,  Val-Loss: 0.64146
Epoch 11, Train-Loss: 0.66698,  Val-Loss: 0.63734
Epoch 12, Train-Loss: 0.64862,  Val-Loss: 0.63168
Epoch 13, Train-Loss: 0.66648,  Val-Loss: 0.62858
Epoch 14, Train-Loss: 0.62387,  Val-Loss: 0.62947
Epoch 15, Train-Loss: 0.65095,  Val-Loss: 0.62753
Epoch 16, Train-Loss: 0.58086,  Val-Loss: 0.62461
Epoch 17, Train-Loss: 0.60535,  Val-Loss: 0.61822
Epoch 18, Train-Loss: 0.64294,  Val-Loss: 0.61100
Epoch 19, Train-Loss: 0.60192,  Val-Loss: 0.60831
Epoch 20, Train-Loss: 0.56411,  Val-Loss: 0.60026
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.82260,  Val-Loss: 0.74705
Epoch 2, Train-Loss: 0.80955,  Val-Loss: 0.74007
Epoch 3, Train-Loss: 0.78061,  Val-Loss: 0.72967
Epoch 4, Train-Loss: 0.76019,  Val-Loss: 0.71611
Epoch 5, Train-Loss: 0.73630,  Val-Loss: 0.70157
Epoch 6, Train-Loss: 0.74912,  Val-Loss: 0.68779
Epoch 7, Train-Loss: 0.69200,  Val-Loss: 0.67847
Epoch 8, Train-Loss: 0.67633,  Val-Loss: 0.67484
Epoch 9, Train-Loss: 0.68412,  Val-Loss: 0.67472
Epoch 10, Train-Loss: 0.70130,  Val-Loss: 0.67833
Epoch 11, Train-Loss: 0.64139,  Val-Loss: 0.68283
Epoch 12, Train-Loss: 0.64348,  Val-Loss: 0.68741
Epoch 13, Train-Loss: 0.64057,  Val-Loss: 0.68744
Epoch 14, Train-Loss: 0.64092,  Val-Loss: 0.68282
Epoch 15, Train-Loss: 0.64274,  Val-Loss: 0.67616
Epoch 16, Train-Loss: 0.62897,  Val-Loss: 0.66850
Epoch 17, Train-Loss: 0.61223,  Val-Loss: 0.66136
Epoch 18, Train-Loss: 0.60613,  Val-Loss: 0.65612
Epoch 19, Train-Loss: 0.62201,  Val-Loss: 0.65173
Epoch 20, Train-Loss: 0.61434,  Val-Loss: 0.64859
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.76217,  Val-Loss: 0.69103
Epoch 2, Train-Loss: 0.73444,  Val-Loss: 0.68778
Epoch 3, Train-Loss: 0.72799,  Val-Loss: 0.68534
Epoch 4, Train-Loss: 0.73019,  Val-Loss: 0.68584
Epoch 5, Train-Loss: 0.69699,  Val-Loss: 0.69084
Epoch 6, Train-Loss: 0.68929,  Val-Loss: 0.69876
Epoch 7, Train-Loss: 0.69224,  Val-Loss: 0.71112
Epoch 8, Train-Loss: 0.64153,  Val-Loss: 0.72383
Epoch 9, Train-Loss: 0.66979,  Val-Loss: 0.72995
Epoch 10, Train-Loss: 0.65469,  Val-Loss: 0.73112
Epoch 11, Train-Loss: 0.68099,  Val-Loss: 0.72671
Epoch 12, Train-Loss: 0.67363,  Val-Loss: 0.72264
Epoch 13, Train-Loss: 0.63756,  Val-Loss: 0.71089
Epoch 14, Train-Loss: 0.63931,  Val-Loss: 0.69935
Epoch 15, Train-Loss: 0.64633,  Val-Loss: 0.69133
Epoch 16, Train-Loss: 0.61992,  Val-Loss: 0.68225
Epoch 17, Train-Loss: 0.61384,  Val-Loss: 0.67184
Epoch 18, Train-Loss: 0.61573,  Val-Loss: 0.65892
Epoch 19, Train-Loss: 0.59340,  Val-Loss: 0.65417
Epoch 20, Train-Loss: 0.56893,  Val-Loss: 0.65294
Epoch 21,

In [6]:
# Store the training outcome on the run config. NOTE: `config.epochs` (the
# configured training budget) is overwritten with the selected checkpoint, so
# the JSON log written at the end reports the checkpoint under "epochs".
config.epochs = check_point
config.model_path = best_model_path

print(f"Best model path: {join_drive_path('log', config.model_path)}")
print(f"Model checkpoint: {config.epochs}")

Best model path: ieee-transformer_250302035900177174_5.pt
Model checkpoint: 50


In [8]:
# Release cached memory left over from the k-fold training run (utils helper).
clear_cache()

trained_weights = torch.load(
    join_drive_path("log", config.model_path), weights_only=True, map_location=device
)
# Build the model on the same device the weights were mapped to, and switch to
# inference mode so any dropout layers are disabled regardless of what
# `evaluate` does internally.
model = ViTransformer(**model_param).to(device)
model.eval()
# Kept as the last expression so the "<All keys matched successfully>" report shows.
model.load_state_dict(trained_weights)

<All keys matched successfully>

In [11]:
# Evaluate the reloaded best model on the held-out test split. EEGDataset is
# given a file path here (vs. an in-memory dict for training); both forms work.
test_data_path = join_drive_path("data", data_config.test)
test_dataset = EEGDataset(test_data_path)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch)

metrics = evaluate(model, device, test_dataloader)

# NOTE(review): in the recorded run, AUC (0.976) equals (recall + 1) / 2 with
# precision 1.0, which is what ROC AUC computed from hard 0/1 predictions gives.
# TODO confirm `utils.evaluate` scores AUC from class probabilities.
print(
    "\n".join(
        f"{title}: {metrics[key]:.3f}"
        for title, key in (
            ("Accuracy", "accuracy"),
            ("F1-Score", "f1-score"),
            ("Recall", "recall"),
            ("AUC", "auc"),
        )
    )
)

Accuracy: 0.972
F1-Score: 0.976
Recall: 0.952
AUC: 0.976


In [12]:
# Persist run config, data/model settings and test metrics as JSON in the same
# "log" directory as the model weights; the returned record is displayed below.
json_path = join_drive_path(
    "log",
    f"{config.name}_{config.id}.json",
)
log_json(
    json_path,
    config=config,
    data=data_config,
    model=model_config,
    metrics=metrics,
)

{'config': {'id': '250302035900177174',
  'name': 'ieee-transformer',
  'model_path': 'ieee-transformer_250302035900177174_5.pt',
  'batch': 8,
  'epochs': 50,
  'lr': 0.001,
  'enable_fp16': True,
  'grad_step': 4,
  'warmup_steps': 50,
  'lr_decay_factor': 0.5,
  'weight_decay': 0.001,
  'patience': 30,
  'extra': None},
 'data': {'tag': 'IEEE_23',
  'train': 'ieee_train.pt',
  'test': 'ieee_test.pt',
  'val': 'ieee_val.pt',
  'channels': 19,
  'length': 9250,
  'num_classes': 2},
 'model': {'embed_dim': 64,
  'num_heads': 4,
  'num_blocks': 4,
  'block_hidden_dim': 64,
  'fc_hidden_dim': 32,
  'dropout': 0.1},
 'metrics': {'accuracy': 0.9722222222222222,
  'f1-score': 0.975609756097561,
  'recall': 0.9523809523809523,
  'auc': 0.9761904761904762}}