# Transformer

Train a base ViT model on the IEEE EEG dataset with k-fold cross-validation, then evaluate the best fold on the held-out test set.

In [1]:
# Mount Google Drive for the Colab environment and put the Drive root on the
# import path. The project-local `utils` and `models` packages imported by the
# next cell are presumably stored there.
import sys
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
DRIVE_ROOT = f"{DRIVE_MOUNT_POINT}/MyDrive"

drive.mount(DRIVE_MOUNT_POINT, force_remount=False)

# Guard the append so that re-running this cell does not keep adding duplicate
# entries to sys.path.
if DRIVE_ROOT not in sys.path:
    sys.path.append(DRIVE_ROOT)

Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# The `device` helper is imported under an alias so that the name `device`
# only ever refers to the selected device value used by the cells below,
# instead of first naming the helper and then being rebound to its result.
from utils import (
    ignore_warnings,
    fix_random_seed,
    device as select_device,
    clear_cache,
    join_drive_path,
    log_json,
    train_with_kfold,
    WarmupScheduler,
    evaluate,
    Config,
    IEEEDataConfig,
    EEGDataset,
)
from models.transformer import TransformerConfig, ViTransformer

# One-time runtime setup: warning filter, fixed seed (42), device selection.
ignore_warnings()
fix_random_seed(42)
device = select_device(force_cuda=True)
print("Device:", device)

Device: cuda


In [3]:
# Run-level settings, collected in one mapping and unpacked into Config.
run_settings = {
    "name": "ieee transformer sm",
    "batch": 8,
    "epochs": 50,
    "lr": 1e-3,
    "enable_fp16": True,
    "grad_step": 4,
    "warmup_steps": 30,
    "lr_decay_factor": 0.5,
    "weight_decay": 1e-3,
    "patience": 30,
}
config = Config(**run_settings)
# The fold count is attached after construction via Config.add.
config.add(k_folds=5)

# Dataset description (file names, channels, length, classes) comes from the
# IEEE defaults.
data_config = IEEEDataConfig()

# Architecture settings for the transformer.
architecture_settings = {
    "embed_dim": 32,
    "num_heads": 4,
    "num_blocks": 4,
    "block_hidden_dim": 64,
    "fc_hidden_dim": 16,
    "dropout": 0.1,
}
model_config = TransformerConfig(**architecture_settings)

# Show the identifiers of this run as stored on the config object.
print("ID:", config.id)
print("Name:", config.name)

ID: 250303012806180100
Name: ieee-transformer-sm


In [4]:
# Load the stored train and validation splits from the Drive "data" folder.
train_data, val_data = (
    torch.load(join_drive_path("data", split_file), weights_only=True)
    for split_file in (data_config.train, data_config.val)
)

# Merge both splits along dim 0: cross-validation creates its own
# train/validation folds, so the predefined split is not needed here.
signals, labels = (
    torch.cat([train_data[field], val_data[field]], dim=0)
    for field in ("data", "label")
)

train_dataset = EEGDataset({"data": signals, "label": labels})

In [5]:
# Constructor arguments for ViTransformer. Kept in a module-level name because
# the weight-loading cell rebuilds the model from the same arguments.
model_param = dict(
    input_channel=data_config.channels,
    seq_length=data_config.length,
    embed_dim=model_config.embed_dim,
    num_heads=model_config.num_heads,
    num_blocks=model_config.num_blocks,
    block_hidden_dim=model_config.block_hidden_dim,
    fc_hidden_dim=model_config.fc_hidden_dim,
    num_classes=data_config.num_classes,
    dropout_p=model_config.dropout,
)

criterion = nn.CrossEntropyLoss()

# Keyword arguments forwarded to the optimizer and scheduler classes.
optimizer_settings = {"lr": config.lr, "weight_decay": config.weight_decay}
scheduler_settings = {
    "lr": config.lr,
    "warmup_steps": config.warmup_steps,
    "decay_factor": config.lr_decay_factor,
}
checkpoint_path = join_drive_path("log", config.model_path)

# Run k-fold training; the two returned values are stored on the config in the
# next cell.
check_point, best_model_path = train_with_kfold(
    k_folds=config.k_folds,
    train_dataset=train_dataset,
    batch=config.batch,
    epochs=config.epochs,
    patience=config.patience,
    gradient_step=config.grad_step,
    enable_fp16=config.enable_fp16,
    device=device,
    model_class=ViTransformer,
    model_params=model_param,
    model_path=checkpoint_path,
    criterion=criterion,
    optimizer_class=optim.Adam,
    optimizer_params=optimizer_settings,
    scheduler_class=WarmupScheduler,
    scheduler_params=scheduler_settings,
)


===== Fold 1 =====


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.70930,  Val-Loss: 0.61891
Epoch 2, Train-Loss: 0.74221,  Val-Loss: 0.62191
Epoch 3, Train-Loss: 0.71575,  Val-Loss: 0.62577
Epoch 4, Train-Loss: 0.68658,  Val-Loss: 0.63199
Epoch 5, Train-Loss: 0.70477,  Val-Loss: 0.63704
Epoch 6, Train-Loss: 0.68911,  Val-Loss: 0.64607
Epoch 7, Train-Loss: 0.69961,  Val-Loss: 0.65375
Epoch 8, Train-Loss: 0.66531,  Val-Loss: 0.66093
Epoch 9, Train-Loss: 0.67174,  Val-Loss: 0.66049
Epoch 10, Train-Loss: 0.65581,  Val-Loss: 0.65605
Epoch 11, Train-Loss: 0.64386,  Val-Loss: 0.65476
Epoch 12, Train-Loss: 0.66868,  Val-Loss: 0.64147
Epoch 13, Train-Loss: 0.63774,  Val-Loss: 0.63545
Epoch 14, Train-Loss: 0.64486,  Val-Loss: 0.63030
Epoch 15, Train-Loss: 0.63013,  Val-Loss: 0.61938
Epoch 16, Train-Loss: 0.63768,  Val-Loss: 0.60935
Epoch 17, Train-Loss: 0.63651,  Val-Loss: 0.60205
Epoch 18, Train-Loss: 0.60311,  Val-Loss: 0.60168
Epoch 19, Train-Loss: 0.59671,  Val-Loss: 0.61282
Epoch 20, Train-Loss: 0.58090,  Val-Loss: 0.59736
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.79971,  Val-Loss: 0.70547
Epoch 2, Train-Loss: 0.74759,  Val-Loss: 0.70012
Epoch 3, Train-Loss: 0.77210,  Val-Loss: 0.69473
Epoch 4, Train-Loss: 0.72796,  Val-Loss: 0.69169
Epoch 5, Train-Loss: 0.71608,  Val-Loss: 0.69278
Epoch 6, Train-Loss: 0.69677,  Val-Loss: 0.69772
Epoch 7, Train-Loss: 0.69807,  Val-Loss: 0.70102
Epoch 8, Train-Loss: 0.67199,  Val-Loss: 0.70462
Epoch 9, Train-Loss: 0.72549,  Val-Loss: 0.70513
Epoch 10, Train-Loss: 0.66281,  Val-Loss: 0.70260
Epoch 11, Train-Loss: 0.62716,  Val-Loss: 0.69767
Epoch 12, Train-Loss: 0.66020,  Val-Loss: 0.68966
Epoch 13, Train-Loss: 0.70724,  Val-Loss: 0.68184
Epoch 14, Train-Loss: 0.63162,  Val-Loss: 0.67239
Epoch 15, Train-Loss: 0.62857,  Val-Loss: 0.66634
Epoch 16, Train-Loss: 0.59206,  Val-Loss: 0.65785
Epoch 17, Train-Loss: 0.64023,  Val-Loss: 0.65037
Epoch 18, Train-Loss: 0.59327,  Val-Loss: 0.64640
Epoch 19, Train-Loss: 0.62196,  Val-Loss: 0.64174
Epoch 20, Train-Loss: 0.62442,  Val-Loss: 0.63888
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.80287,  Val-Loss: 0.77357
Epoch 2, Train-Loss: 0.78662,  Val-Loss: 0.75573
Epoch 3, Train-Loss: 0.79580,  Val-Loss: 0.73220
Epoch 4, Train-Loss: 0.73017,  Val-Loss: 0.70623
Epoch 5, Train-Loss: 0.71905,  Val-Loss: 0.68281
Epoch 6, Train-Loss: 0.67437,  Val-Loss: 0.66520
Epoch 7, Train-Loss: 0.72603,  Val-Loss: 0.65591
Epoch 8, Train-Loss: 0.69704,  Val-Loss: 0.65096
Epoch 9, Train-Loss: 0.69152,  Val-Loss: 0.64680
Epoch 10, Train-Loss: 0.68314,  Val-Loss: 0.64261
Epoch 11, Train-Loss: 0.64982,  Val-Loss: 0.63811
Epoch 12, Train-Loss: 0.67232,  Val-Loss: 0.63477
Epoch 13, Train-Loss: 0.64902,  Val-Loss: 0.63119
Epoch 14, Train-Loss: 0.68590,  Val-Loss: 0.63009
Epoch 15, Train-Loss: 0.62186,  Val-Loss: 0.62683
Epoch 16, Train-Loss: 0.61387,  Val-Loss: 0.62405
Epoch 17, Train-Loss: 0.62640,  Val-Loss: 0.61581
Epoch 18, Train-Loss: 0.60097,  Val-Loss: 0.60155
Epoch 19, Train-Loss: 0.60945,  Val-Loss: 0.59424
Epoch 20, Train-Loss: 0.59105,  Val-Loss: 0.58655
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.65219,  Val-Loss: 0.68580
Epoch 2, Train-Loss: 0.67183,  Val-Loss: 0.68595
Epoch 3, Train-Loss: 0.66912,  Val-Loss: 0.68602
Epoch 4, Train-Loss: 0.67036,  Val-Loss: 0.68608
Epoch 5, Train-Loss: 0.66525,  Val-Loss: 0.68642
Epoch 6, Train-Loss: 0.65525,  Val-Loss: 0.68666
Epoch 7, Train-Loss: 0.66714,  Val-Loss: 0.68708
Epoch 8, Train-Loss: 0.65917,  Val-Loss: 0.68652
Epoch 9, Train-Loss: 0.66055,  Val-Loss: 0.68591
Epoch 10, Train-Loss: 0.67485,  Val-Loss: 0.68514
Epoch 11, Train-Loss: 0.64726,  Val-Loss: 0.68439
Epoch 12, Train-Loss: 0.63649,  Val-Loss: 0.68310
Epoch 13, Train-Loss: 0.65423,  Val-Loss: 0.68132
Epoch 14, Train-Loss: 0.65197,  Val-Loss: 0.68069
Epoch 15, Train-Loss: 0.63596,  Val-Loss: 0.67921
Epoch 16, Train-Loss: 0.63409,  Val-Loss: 0.67790
Epoch 17, Train-Loss: 0.62535,  Val-Loss: 0.67677
Epoch 18, Train-Loss: 0.62106,  Val-Loss: 0.67483
Epoch 19, Train-Loss: 0.62768,  Val-Loss: 0.67410
Epoch 20, Train-Loss: 0.59466,  Val-Loss: 0.67326
Epoch 21,

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.69108,  Val-Loss: 0.76678
Epoch 2, Train-Loss: 0.68906,  Val-Loss: 0.76171
Epoch 3, Train-Loss: 0.67436,  Val-Loss: 0.75415
Epoch 4, Train-Loss: 0.67359,  Val-Loss: 0.74421
Epoch 5, Train-Loss: 0.66190,  Val-Loss: 0.73400
Epoch 6, Train-Loss: 0.68461,  Val-Loss: 0.72246
Epoch 7, Train-Loss: 0.66427,  Val-Loss: 0.71522
Epoch 8, Train-Loss: 0.64807,  Val-Loss: 0.71039
Epoch 9, Train-Loss: 0.67149,  Val-Loss: 0.70541
Epoch 10, Train-Loss: 0.67436,  Val-Loss: 0.70140
Epoch 11, Train-Loss: 0.65320,  Val-Loss: 0.69974
Epoch 12, Train-Loss: 0.64071,  Val-Loss: 0.70004
Epoch 13, Train-Loss: 0.64126,  Val-Loss: 0.69482
Epoch 14, Train-Loss: 0.63503,  Val-Loss: 0.69197
Epoch 15, Train-Loss: 0.64208,  Val-Loss: 0.68976
Epoch 16, Train-Loss: 0.62703,  Val-Loss: 0.69024
Epoch 17, Train-Loss: 0.62573,  Val-Loss: 0.69161
Epoch 18, Train-Loss: 0.63301,  Val-Loss: 0.69373
Epoch 19, Train-Loss: 0.61411,  Val-Loss: 0.69202
Epoch 20, Train-Loss: 0.61457,  Val-Loss: 0.68024
Epoch 21,

In [6]:
# Store the k-fold result on the run config so the following cells (weight
# loading, JSON log) read the best checkpoint from `config`.
# NOTE(review): this overwrites the initial `epochs` / `model_path` values in
# place, so re-running the training cell after this one would read the updated
# values instead of the original ones.
config.epochs = check_point
config.model_path = best_model_path

# NOTE(review): the recorded output shows `model_path` is already an absolute
# Drive path at this point; presumably join_drive_path returns it unchanged —
# confirm.
print("Best model path:", join_drive_path("log", config.model_path))
print("Model checkpoint:", config.epochs)

Best model path: /content/drive/MyDrive/log/ieee-transformer-sm_250303012806180100_2.pt
Model checkpoint: 49


In [7]:
# Release cached memory left over from training before loading the weights.
clear_cache()

# Read the saved state dict of the best fold, mapped onto the active device.
best_weights_path = join_drive_path("log", config.model_path)
trained_weights = torch.load(
    best_weights_path, map_location=device, weights_only=True
)

# Rebuild the network from the same constructor arguments used for training
# and restore its parameters; the result of load_state_dict is the cell output.
model = ViTransformer(**model_param)
model.load_state_dict(trained_weights)

<All keys matched successfully>

In [8]:
# Build the test loader from the stored test split, using the run's batch size.
test_data_path = join_drive_path("data", data_config.test)
test_dataset = EEGDataset(test_data_path)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch)

metrics = evaluate(model, device, test_dataloader)

# (display title, key in `metrics`) pairs, printed in this order with three
# decimals.
reported_metrics = (
    ("Accuracy", "accuracy"),
    ("F1-Score", "f1-score"),
    ("Recall", "recall"),
    ("AUC", "auc"),
)
for metric_title, metric_key in reported_metrics:
    print(f"{metric_title}: {metrics[metric_key]:.3f}")

Accuracy: 0.944
F1-Score: 0.952
Recall: 0.952
AUC: 0.943


In [9]:
# Write the run configuration, data/model settings and test metrics to a JSON
# file in the Drive "log" folder, named after the run name and id.
log_file_name = f"{config.name}_{config.id}.json"
json_path = join_drive_path("log", log_file_name)

# Sections are passed as keywords in this order: config, data, model, metrics.
logged_sections = {
    "config": config,
    "data": data_config,
    "model": model_config,
    "metrics": metrics,
}
log_json(json_path, **logged_sections)

{'config': {'name': 'ieee-transformer-sm',
  'batch': 8,
  'epochs': 49,
  'lr': 0.001,
  'enable_fp16': True,
  'grad_step': 4,
  'warmup_steps': 30,
  'lr_decay_factor': 0.5,
  'weight_decay': 0.001,
  'patience': 30,
  'id': '250303012806180100',
  'model_path': '/content/drive/MyDrive/log/ieee-transformer-sm_250303012806180100_2.pt',
  'k_folds': 5},
 'data': {'tag': 'IEEE_23',
  'train': 'ieee_train.pt',
  'test': 'ieee_test.pt',
  'val': 'ieee_val.pt',
  'channels': 19,
  'length': 9250,
  'num_classes': 2},
 'model': {'embed_dim': 32,
  'num_heads': 4,
  'num_blocks': 4,
  'block_hidden_dim': 64,
  'fc_hidden_dim': 16,
  'dropout': 0.1},
 'metrics': {'accuracy': 0.9444444444444444,
  'f1-score': 0.9523809523809523,
  'recall': 0.9523809523809523,
  'auc': 0.9428571428571428}}