# Transformer

Train Base ViT model for IEEE EEG dataset

In [1]:
# Mount Google Drive for Colab env
import sys
from google.colab import drive

drive.mount("/content/drive", force_remount=False)
sys.path.append("/content/drive/MyDrive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from utils import (
    ignore_warnings,
    fix_random_seed,
    device,
    clear_cache,
    join_drive_path,
    log_json,
    train,
    evaluate,
    Config,
    IEEEDataConfig,
    EEGDataset,
)
from models.transformer import TransformerConfig, ViTransformer

ignore_warnings()
fix_random_seed(42)
device = device(force_cuda=True)
print("Device:", device)

Device: cuda


In [3]:
config = Config(
    name="ieee transformer",
    batch=8,
    epochs=50,
    lr=1e-3,
    grad_step=2,
    warmup_steps=30,
    lr_decay_factor=2,
    weight_decay=1e-3,
    patience=30,
)
data_config = IEEEDataConfig()
model_config = TransformerConfig(
    embed_dim=64,
    num_heads=4,
    num_blocks=4,
    block_hidden_dim=64,
    fc_hidden_dim=32,
    dropout=0.1,
)

print("ID:", config.id)
print("Name:", config.name)

ID: 250227082819731394
Name: ieee-transformer


In [4]:
train_data_path = join_drive_path("data", data_config.train)
val_data_path = join_drive_path("data", data_config.val)

train_dataset = EEGDataset(train_data_path)
val_dataset = EEGDataset(val_data_path)

train_dataloader = DataLoader(train_dataset, batch_size=config.batch, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=config.batch)

In [5]:
clear_cache()

model = ViTransformer(
    input_channel=data_config.channels,
    seq_length=data_config.length,
    embed_dim=model_config.embed_dim,
    num_heads=model_config.num_heads,
    num_blocks=model_config.num_blocks,
    block_hidden_dim=model_config.block_hidden_dim,
    fc_hidden_dim=model_config.fc_hidden_dim,
    num_classes=data_config.num_classes,
    dropout_p=model_config.dropout,
).to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(
    model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)

In [6]:
check_point = train(
    model=model,
    model_path=join_drive_path("log", config.model_path),
    device=device,
    optimizer=optimizer,
    gradient_step=config.grad_step,
    criterion=loss_fn,
    epochs=config.epochs,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    learning_rate=config.lr,
    min_lr=1e-8,
    warmup_steps=config.warmup_steps,
    lr_decay_factor=config.lr_decay_factor,
    patience=config.patience,
)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.72046,  Val-Loss: 0.68152
Epoch 2, Train-Loss: 0.72222,  Val-Loss: 0.69384
Epoch 3, Train-Loss: 0.69556,  Val-Loss: 0.70985
Epoch 4, Train-Loss: 0.69749,  Val-Loss: 0.72461
Epoch 5, Train-Loss: 0.67572,  Val-Loss: 0.73123
Epoch 6, Train-Loss: 0.66886,  Val-Loss: 0.74125
Epoch 7, Train-Loss: 0.63411,  Val-Loss: 0.72936
Epoch 8, Train-Loss: 0.65321,  Val-Loss: 0.74320
Epoch 9, Train-Loss: 0.63364,  Val-Loss: 0.75853
Epoch 10, Train-Loss: 0.60568,  Val-Loss: 0.72771
Epoch 11, Train-Loss: 0.62832,  Val-Loss: 0.70706
Epoch 12, Train-Loss: 0.61022,  Val-Loss: 0.68605
Epoch 13, Train-Loss: 0.63782,  Val-Loss: 0.72742
Epoch 14, Train-Loss: 0.57227,  Val-Loss: 0.71311
Epoch 15, Train-Loss: 0.57045,  Val-Loss: 0.71467
Epoch 16, Train-Loss: 0.53519,  Val-Loss: 0.68736
Epoch 17, Train-Loss: 0.54855,  Val-Loss: 0.71036
Epoch 18, Train-Loss: 0.54712,  Val-Loss: 0.64775
Epoch 19, Train-Loss: 0.45137,  Val-Loss: 0.80003
Epoch 20, Train-Loss: 0.50532,  Val-Loss: 0.67170
Epoch 21,

KeyboardInterrupt: 

In [7]:
config.epochs = 30  # Interrupt manually

In [8]:
clear_cache()

trained_weights = torch.load(
    join_drive_path("log", config.model_path), weights_only=True, map_location=device
)
model.load_state_dict(trained_weights)

<All keys matched successfully>

In [9]:
test_data_path = join_drive_path("data", data_config.test)
test_dataset = EEGDataset(test_data_path)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch)

metrics = evaluate(model, device, test_dataloader)

print(f"Accuracy: {metrics['accuracy']:.3f}")
print(f"F1-Score: {metrics['f1-score']:.3f}")
print(f"Recall: {metrics['recall']:.3f}")
print(f"AUC: {metrics['auc']:.3f}")

Accuracy: 0.944
F1-Score: 0.952
Recall: 0.952
AUC: 0.943


In [10]:
json_path = join_drive_path("log", f"{config.name}_{config.id}.json")
metrics_to_log = {
    "accuracy": metrics["accuracy"],
    "f1-score": metrics["f1-score"],
    "recall": metrics["recall"],
    "auc": metrics["auc"],
}
log_json(json_path, config, data_config, model_config, **metrics_to_log)

{'id': '250227082819731394',
 'name': 'ieee-transformer',
 'model_path': 'ieee-transformer_250227082819731394.pt',
 'batch': 8,
 'epochs': 30,
 'lr': 0.001,
 'grad_step': 2,
 'warmup_steps': 30,
 'lr_decay_factor': 2,
 'weight_decay': 0.001,
 'patience': 30,
 'embed_dim': 64,
 'num_heads': 4,
 'num_blocks': 4,
 'block_hidden_dim': 64,
 'fc_hidden_dim': 32,
 'dropout': 0.1,
 'accuracy': 0.9444444444444444,
 'f1-score': 0.9523809523809523,
 'recall': 0.9523809523809523,
 'auc': 0.9428571428571428}