# Transformer

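Train a base Vision Transformer (ViT) model on the IEEE EEG dataset using 5-fold cross-validation, then evaluate the best fold's model on the held-out test set.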

In [1]:
# Mount Google Drive for the Colab environment and put the Drive root on the
# import path (presumably where the `utils` and `models` packages live).
import sys
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
PROJECT_ROOT = "/content/drive/MyDrive"

drive.mount(DRIVE_MOUNT_POINT, force_remount=False)
# Guard so that re-running this cell does not append a duplicate entry.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from utils import (
    ignore_warnings,
    fix_random_seed,
    device as get_device,
    clear_cache,
    join_drive_path,
    log_json,
    train_with_kfold,
    WarmupScheduler,
    evaluate,
    Config,
    IEEEDataConfig,
    EEGDataset,
)
from models.transformer import TransformerConfig, ViTransformer

SEED = 42  # global random seed for reproducible runs

ignore_warnings()
fix_random_seed(SEED)
# The helper is imported as `get_device` so that `device` holds the resolved
# device object instead of shadowing (and destroying) the function it came from.
device = get_device(force_cuda=True)
print("Device:", device)

Device: cuda


In [3]:
# Training / optimisation hyper-parameters for this run.
train_settings = dict(
    name="ieee transformer",
    batch=8,
    epochs=50,
    lr=1e-3,
    enable_fp16=True,
    grad_step=4,
    warmup_steps=30,
    lr_decay_factor=0.5,
    weight_decay=1e-3,
    patience=30,
)
config = Config(**train_settings)
config.add(k_folds=5)  # number of cross-validation folds

data_config = IEEEDataConfig()

# Architecture hyper-parameters for the transformer.
model_settings = dict(
    embed_dim=64,
    num_heads=4,
    num_blocks=4,
    block_hidden_dim=128,
    fc_hidden_dim=32,
    dropout=0.1,
)
model_config = TransformerConfig(**model_settings)

print("ID:", config.id)
print("Name:", config.name)

ID: 250303001232982598
Name: ieee-transformer


In [4]:
train_data_path = join_drive_path("data", data_config.train)
val_data_path = join_drive_path("data", data_config.val)

train_data = torch.load(train_data_path, weights_only=True)
val_data = torch.load(val_data_path, weights_only=True)

# K-fold training below only receives one dataset and builds its own
# validation splits, so the official train and validation sets are merged
# into a single pool. The test set is loaded separately for final evaluation.
signals = torch.cat([train_data["data"], val_data["data"]], dim=0)
labels = torch.cat([train_data["label"], val_data["label"]], dim=0)

# Sanity check: every sample must still have exactly one label after merging.
assert signals.shape[0] == labels.shape[0], (
    f"sample/label count mismatch: {signals.shape[0]} signals "
    f"vs {labels.shape[0]} labels"
)

# NOTE: despite its name, this dataset holds train + val combined.
train_dataset = EEGDataset({"data": signals, "label": labels})

In [5]:
# Constructor arguments for ViTransformer; reused later to rebuild the best model.
model_param = dict(
    input_channel=data_config.channels,
    seq_length=data_config.length,
    embed_dim=model_config.embed_dim,
    num_heads=model_config.num_heads,
    num_blocks=model_config.num_blocks,
    block_hidden_dim=model_config.block_hidden_dim,
    fc_hidden_dim=model_config.fc_hidden_dim,
    num_classes=data_config.num_classes,
    dropout_p=model_config.dropout,
)

criterion = nn.CrossEntropyLoss()

optimizer_params = {"lr": config.lr, "weight_decay": config.weight_decay}
scheduler_params = {
    "lr": config.lr,
    "warmup_steps": config.warmup_steps,
    "decay_factor": config.lr_decay_factor,
}

# Train one model per fold. The two returned values are used below as the
# best checkpoint epoch and the best fold's model file name.
check_point, best_model_path = train_with_kfold(
    k_folds=config.k_folds,
    model_class=ViTransformer,
    device=device,
    model_path=config.model_path,
    optimizer_class=optim.Adam,
    criterion=criterion,
    epochs=config.epochs,
    train_dataset=train_dataset,
    batch=config.batch,
    gradient_step=config.grad_step,
    patience=config.patience,
    model_params=model_param,
    optimizer_params=optimizer_params,
    enable_fp16=config.enable_fp16,
    scheduler_class=WarmupScheduler,
    scheduler_params=scheduler_params,
)


===== Fold 1 =====


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.83362,  Val-Loss: 0.87037
Epoch 2, Train-Loss: 0.81363,  Val-Loss: 0.82943
Epoch 3, Train-Loss: 0.76236,  Val-Loss: 0.77422
Epoch 4, Train-Loss: 0.72067,  Val-Loss: 0.71213
Epoch 5, Train-Loss: 0.70140,  Val-Loss: 0.65857
Epoch 6, Train-Loss: 0.71600,  Val-Loss: 0.62298
Epoch 7, Train-Loss: 0.69855,  Val-Loss: 0.60456
Epoch 8, Train-Loss: 0.69132,  Val-Loss: 0.59623
Epoch 9, Train-Loss: 0.67160,  Val-Loss: 0.59117
Epoch 10, Train-Loss: 0.67245,  Val-Loss: 0.59479
Epoch 11, Train-Loss: 0.63504,  Val-Loss: 0.60853
Epoch 12, Train-Loss: 0.62889,  Val-Loss: 0.63093
Epoch 13, Train-Loss: 0.61484,  Val-Loss: 0.63689
Epoch 14, Train-Loss: 0.59062,  Val-Loss: 0.61239
Epoch 15, Train-Loss: 0.57963,  Val-Loss: 0.59356
Epoch 16, Train-Loss: 0.58103,  Val-Loss: 0.57167
Epoch 17, Train-Loss: 0.58916,  Val-Loss: 0.55743
Epoch 18, Train-Loss: 0.52298,  Val-Loss: 0.56049
Epoch 19, Train-Loss: 0.51367,  Val-Loss: 0.55241
Epoch 20, Train-Loss: 0.51065,  Val-Loss: 0.57405
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 1.42598,  Val-Loss: 1.23487
Epoch 2, Train-Loss: 1.42278,  Val-Loss: 1.16691
Epoch 3, Train-Loss: 1.27158,  Val-Loss: 1.07139
Epoch 4, Train-Loss: 1.13190,  Val-Loss: 0.95914
Epoch 5, Train-Loss: 1.03817,  Val-Loss: 0.84918
Epoch 6, Train-Loss: 0.89114,  Val-Loss: 0.76055
Epoch 7, Train-Loss: 0.78621,  Val-Loss: 0.70179
Epoch 8, Train-Loss: 0.71217,  Val-Loss: 0.68459
Epoch 9, Train-Loss: 0.65768,  Val-Loss: 0.71130
Epoch 10, Train-Loss: 0.66891,  Val-Loss: 0.74660
Epoch 11, Train-Loss: 0.69187,  Val-Loss: 0.76084
Epoch 12, Train-Loss: 0.71086,  Val-Loss: 0.74400
Epoch 13, Train-Loss: 0.71402,  Val-Loss: 0.71523
Epoch 14, Train-Loss: 0.68297,  Val-Loss: 0.69520
Epoch 15, Train-Loss: 0.63206,  Val-Loss: 0.67902
Epoch 16, Train-Loss: 0.64472,  Val-Loss: 0.66958
Epoch 17, Train-Loss: 0.64827,  Val-Loss: 0.66165
Epoch 18, Train-Loss: 0.62881,  Val-Loss: 0.65522
Epoch 19, Train-Loss: 0.59074,  Val-Loss: 0.64909
Epoch 20, Train-Loss: 0.62995,  Val-Loss: 0.64437
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.75827,  Val-Loss: 0.69509
Epoch 2, Train-Loss: 0.69780,  Val-Loss: 0.68518
Epoch 3, Train-Loss: 0.69323,  Val-Loss: 0.67533
Epoch 4, Train-Loss: 0.70772,  Val-Loss: 0.66479
Epoch 5, Train-Loss: 0.70435,  Val-Loss: 0.65783
Epoch 6, Train-Loss: 0.70065,  Val-Loss: 0.65160
Epoch 7, Train-Loss: 0.64972,  Val-Loss: 0.64447
Epoch 8, Train-Loss: 0.68423,  Val-Loss: 0.63612
Epoch 9, Train-Loss: 0.62811,  Val-Loss: 0.62854
Epoch 10, Train-Loss: 0.61472,  Val-Loss: 0.62196
Epoch 11, Train-Loss: 0.61602,  Val-Loss: 0.61580
Epoch 12, Train-Loss: 0.59977,  Val-Loss: 0.60498
Epoch 13, Train-Loss: 0.59105,  Val-Loss: 0.59890
Epoch 14, Train-Loss: 0.55905,  Val-Loss: 0.59677
Epoch 15, Train-Loss: 0.51647,  Val-Loss: 0.57628
Epoch 16, Train-Loss: 0.56760,  Val-Loss: 0.56595
Epoch 17, Train-Loss: 0.53443,  Val-Loss: 0.55991
Epoch 18, Train-Loss: 0.49017,  Val-Loss: 0.54511
Epoch 19, Train-Loss: 0.48041,  Val-Loss: 0.53248
Epoch 20, Train-Loss: 0.43992,  Val-Loss: 0.52212
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 1.25596,  Val-Loss: 1.11313
Epoch 2, Train-Loss: 1.17531,  Val-Loss: 1.05027
Epoch 3, Train-Loss: 1.09979,  Val-Loss: 0.96372
Epoch 4, Train-Loss: 1.01238,  Val-Loss: 0.86444
Epoch 5, Train-Loss: 0.88032,  Val-Loss: 0.77309
Epoch 6, Train-Loss: 0.81281,  Val-Loss: 0.70388
Epoch 7, Train-Loss: 0.69886,  Val-Loss: 0.67336
Epoch 8, Train-Loss: 0.65465,  Val-Loss: 0.67899
Epoch 9, Train-Loss: 0.64395,  Val-Loss: 0.70543
Epoch 10, Train-Loss: 0.71167,  Val-Loss: 0.72850
Epoch 11, Train-Loss: 0.71020,  Val-Loss: 0.73492
Epoch 12, Train-Loss: 0.70954,  Val-Loss: 0.72124
Epoch 13, Train-Loss: 0.66057,  Val-Loss: 0.69878
Epoch 14, Train-Loss: 0.63484,  Val-Loss: 0.66908
Epoch 15, Train-Loss: 0.66688,  Val-Loss: 0.64941
Epoch 16, Train-Loss: 0.59632,  Val-Loss: 0.64080
Epoch 17, Train-Loss: 0.58627,  Val-Loss: 0.63666
Epoch 18, Train-Loss: 0.62124,  Val-Loss: 0.63211
Epoch 19, Train-Loss: 0.57871,  Val-Loss: 0.62476
Epoch 20, Train-Loss: 0.58519,  Val-Loss: 0.61825
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.85703,  Val-Loss: 0.73628
Epoch 2, Train-Loss: 0.84693,  Val-Loss: 0.71466
Epoch 3, Train-Loss: 0.82297,  Val-Loss: 0.69954
Epoch 4, Train-Loss: 0.76304,  Val-Loss: 0.70273
Epoch 5, Train-Loss: 0.70288,  Val-Loss: 0.72590
Epoch 6, Train-Loss: 0.72836,  Val-Loss: 0.76239
Epoch 7, Train-Loss: 0.72023,  Val-Loss: 0.79107
Epoch 8, Train-Loss: 0.70082,  Val-Loss: 0.79616
Epoch 9, Train-Loss: 0.66297,  Val-Loss: 0.77336
Epoch 10, Train-Loss: 0.64663,  Val-Loss: 0.73679
Epoch 11, Train-Loss: 0.65394,  Val-Loss: 0.70028
Epoch 12, Train-Loss: 0.67720,  Val-Loss: 0.68226
Epoch 13, Train-Loss: 0.62701,  Val-Loss: 0.66449
Epoch 14, Train-Loss: 0.62007,  Val-Loss: 0.66030
Epoch 15, Train-Loss: 0.60256,  Val-Loss: 0.66815
Epoch 16, Train-Loss: 0.61664,  Val-Loss: 0.67745
Epoch 17, Train-Loss: 0.55244,  Val-Loss: 0.65428
Epoch 18, Train-Loss: 0.52648,  Val-Loss: 0.62615
Epoch 19, Train-Loss: 0.54095,  Val-Loss: 0.61666
Epoch 20, Train-Loss: 0.48300,  Val-Loss: 0.63503
Epoch 21,

In [6]:
# Adopt the best fold's checkpoint for the rest of the notebook.
# NOTE(review): this mutates `config` in place. Re-running the training cell
# after this one would train for `check_point` epochs and pass the best
# model's file name as `model_path`; restart and run from the top instead.
config.epochs = check_point
config.model_path = best_model_path

best_model_file = join_drive_path("log", config.model_path)
print("Best model path:", best_model_file)
print("Model checkpoint:", config.epochs)

Best model path: /content/drive/MyDrive/log/ieee-transformer_250303001232982598_3.pt
Model checkpoint: 35


In [8]:
clear_cache()

trained_weights = torch.load(
    join_drive_path("log", config.model_path), weights_only=True, map_location=device
)
# Build the model directly on the target device, so the weights already mapped
# to `device` are not copied back to the CPU. Switch it to inference mode so
# dropout layers (dropout_p > 0) are inactive during evaluation.
# `.to()` and `.eval()` both return the module itself.
model = ViTransformer(**model_param).to(device).eval()
model.load_state_dict(trained_weights)

<All keys matched successfully>

In [9]:
# Evaluate the selected model on the held-out test split (never seen during k-fold).
test_data_path = join_drive_path("data", data_config.test)
test_dataset = EEGDataset(test_data_path)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch)

metrics = evaluate(model, device, test_dataloader)

# Report each metric to three decimal places.
for label, key in (
    ("Accuracy", "accuracy"),
    ("F1-Score", "f1-score"),
    ("Recall", "recall"),
    ("AUC", "auc"),
):
    print(f"{label}: {metrics[key]:.3f}")

Accuracy: 0.972
F1-Score: 0.976
Recall: 0.952
AUC: 0.976


In [10]:
# Persist the run description (training config, data config, model config and
# test metrics) as JSON in the same "log" directory as the model weights.
# NOTE: `config.epochs` now holds the best checkpoint epoch, not the configured
# maximum, because the best-checkpoint cell overwrote it.
run_tag = f"{config.name}_{config.id}"
json_path = join_drive_path("log", f"{run_tag}.json")
log_json(
    json_path,
    config=config,
    data=data_config,
    model=model_config,
    metrics=metrics,
)

{'config': {'name': 'ieee-transformer',
  'batch': 8,
  'epochs': 35,
  'lr': 0.001,
  'enable_fp16': True,
  'grad_step': 4,
  'warmup_steps': 30,
  'lr_decay_factor': 0.5,
  'weight_decay': 0.001,
  'patience': 30,
  'id': '250303001232982598',
  'model_path': 'ieee-transformer_250303001232982598_3.pt',
  'k_folds': 5},
 'data': {'tag': 'IEEE_23',
  'train': 'ieee_train.pt',
  'test': 'ieee_test.pt',
  'val': 'ieee_val.pt',
  'channels': 19,
  'length': 9250,
  'num_classes': 2},
 'model': {'embed_dim': 64,
  'num_heads': 4,
  'num_blocks': 4,
  'block_hidden_dim': 128,
  'fc_hidden_dim': 32,
  'dropout': 0.1},
 'metrics': {'accuracy': 0.9722222222222222,
  'f1-score': 0.975609756097561,
  'recall': 0.9523809523809523,
  'auc': 0.9761904761904762}}