# Transformer

Train a base ViT (Vision Transformer) model on the IEEE EEG dataset.

In [1]:
# Mount Google Drive (Colab only) and make the code stored on it importable.
import sys
from google.colab import drive

# Single source of truth for the mount location and the folder added to sys.path.
DRIVE_MOUNT_POINT = "/content/drive"
DRIVE_PROJECT_ROOT = f"{DRIVE_MOUNT_POINT}/MyDrive"

# force_remount=False keeps an existing mount instead of re-mounting it.
drive.mount(DRIVE_MOUNT_POINT, force_remount=False)

# Guarded so that re-running this cell does not stack duplicate sys.path entries.
# Presumably the `utils` and `models` packages imported later live in this folder.
if DRIVE_PROJECT_ROOT not in sys.path:
    sys.path.append(DRIVE_PROJECT_ROOT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from utils import (
    ignore_warnings,
    fix_random_seed,
    # Imported under an alias so the helper is not overwritten by the
    # `device` variable that the rest of the notebook uses.
    device as select_device,
    clear_cache,
    join_drive_path,
    log_json,
    train,
    WarmupScheduler,
    evaluate,
    Config,
    IEEEDataConfig,
    EEGDataset,
)
from models.transformer import TransformerConfig, ViTransformer

# Runtime setup: silence warnings and seed the RNGs before any data or model
# object is created, then pick the compute device.
ignore_warnings()
fix_random_seed(42)

# `device` is the selected device object used by every later cell
# (model placement, training, checkpoint loading, evaluation).
device = select_device(force_cuda=True)
print("Device:", device)

Device: cuda


In [3]:
# Run-level hyper-parameters. These values are forwarded to the optimizer,
# the scheduler and train() in later cells, and are written to the JSON run log.
config = Config(
    name="ieee transformer",
    batch=8, # CUDA runs out of memory at batch=16, hence the smaller batch.
    epochs=100,
    lr=1e-3,
    enable_fp16=True,
    grad_step=4,  # passed to train() as gradient_step; presumably gradient accumulation — TODO confirm
    warmup_steps=50,  # consumed by WarmupScheduler
    lr_decay_factor=0.5,  # consumed by WarmupScheduler as decay_factor
    weight_decay=1e-3,
    patience=30,  # passed to train(); presumably early-stopping patience — TODO confirm
)
# Dataset description (split file names, channels, length, class count) with its defaults.
data_config = IEEEDataConfig()
# Transformer architecture hyper-parameters, unpacked into ViTransformer below.
model_config = TransformerConfig(
    embed_dim=64,
    num_heads=4,
    num_blocks=4,
    block_hidden_dim=64,
    fc_hidden_dim=32,
    dropout=0.1,
)

# Show the run identity; note the printed name is "ieee-transformer"
# (hyphenated) rather than the string passed above.
print("ID:", config.id)
print("Name:", config.name)

ID: 250228075729819009
Name: ieee-transformer


In [4]:
# Resolve the train and validation split files through join_drive_path.
train_data_path, val_data_path = (
    join_drive_path("data", split_file)
    for split_file in (data_config.train, data_config.val)
)

# One dataset per split, built in the same order as before (train, then val).
train_dataset, val_dataset = map(EEGDataset, (train_data_path, val_data_path))

# Only the training loader shuffles; validation keeps the dataset order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.batch,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=config.batch,
)

In [5]:
# Run the project's cache-clearing helper before allocating the model.
clear_cache()

# Input/output sizes come from data_config; architecture sizes from model_config.
model = ViTransformer(
    input_channel=data_config.channels,
    seq_length=data_config.length,
    embed_dim=model_config.embed_dim,
    num_heads=model_config.num_heads,
    num_blocks=model_config.num_blocks,
    block_hidden_dim=model_config.block_hidden_dim,
    fc_hidden_dim=model_config.fc_hidden_dim,
    num_classes=data_config.num_classes,
    dropout_p=model_config.dropout,
)
model = model.to(device)

# Classification loss, placed on the same device as the model.
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

# Adam with the learning rate and weight decay from the run config.
optimizer = optim.Adam(
    model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)

# Learning-rate schedule driven by the warmup/decay settings in the run config.
scheduler = WarmupScheduler(
    optimizer,
    lr=config.lr,
    warmup_steps=config.warmup_steps,
    decay_factor=config.lr_decay_factor,
)

In [6]:
# Launch training. The value returned by train() is kept as `check_point`
# and is written into config.epochs by a later cell.
check_point = train(
    # what to train and where to save it
    model=model,
    model_path=join_drive_path("log", config.model_path),
    device=device,
    # data
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    # optimisation
    criterion=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    # schedule / run control, all taken from the run config
    epochs=config.epochs,
    gradient_step=config.grad_step,
    patience=config.patience,
    enable_fp16=config.enable_fp16,
)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1, Train-Loss: 0.72421,  Val-Loss: 0.67883
Epoch 2, Train-Loss: 0.73435,  Val-Loss: 0.68087
Epoch 3, Train-Loss: 0.71614,  Val-Loss: 0.68610
Epoch 4, Train-Loss: 0.72005,  Val-Loss: 0.69597
Epoch 5, Train-Loss: 0.69731,  Val-Loss: 0.70465
Epoch 6, Train-Loss: 0.68367,  Val-Loss: 0.71672
Epoch 7, Train-Loss: 0.65140,  Val-Loss: 0.72831
Epoch 8, Train-Loss: 0.67178,  Val-Loss: 0.74406
Epoch 9, Train-Loss: 0.65741,  Val-Loss: 0.75937
Epoch 10, Train-Loss: 0.63346,  Val-Loss: 0.76816
Epoch 11, Train-Loss: 0.67670,  Val-Loss: 0.76804
Epoch 12, Train-Loss: 0.64273,  Val-Loss: 0.75644
Epoch 13, Train-Loss: 0.67917,  Val-Loss: 0.74818
Epoch 14, Train-Loss: 0.63282,  Val-Loss: 0.74361
Epoch 15, Train-Loss: 0.63663,  Val-Loss: 0.74388
Epoch 16, Train-Loss: 0.61675,  Val-Loss: 0.74098
Epoch 17, Train-Loss: 0.64906,  Val-Loss: 0.74360
Epoch 18, Train-Loss: 0.62233,  Val-Loss: 0.72756
Epoch 19, Train-Loss: 0.57583,  Val-Loss: 0.72918
Epoch 20, Train-Loss: 0.61097,  Val-Loss: 0.73921
Epoch 21,

In [10]:
# Replace the configured epoch budget with the value returned by train(), so
# the JSON run log records that value instead of the budget (presumably the
# epoch of the saved checkpoint — TODO confirm against utils.train).
config.epochs = check_point

In [7]:
# Run the project's cache-clearing helper before loading the checkpoint.
clear_cache()

# Same path that was handed to train() as model_path.
weights_path = join_drive_path("log", config.model_path)

# weights_only=True restricts unpickling to tensors and plain containers;
# map_location places the loaded tensors on the active device.
trained_weights = torch.load(weights_path, map_location=device, weights_only=True)

# Left as the last expression so the key-matching report is displayed.
model.load_state_dict(trained_weights)

<All keys matched successfully>

In [8]:
# Build the held-out test loader (no shuffling).
test_data_path = join_drive_path("data", data_config.test)
test_dataset = EEGDataset(test_data_path)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=config.batch)

# evaluate() returns a mapping that is indexed by metric name below.
metrics = evaluate(model, device, test_dataloader, enable_fp16=config.enable_fp16)

# Report each metric to three decimals: (display label, key in `metrics`).
for label, key in (
    ("Accuracy", "accuracy"),
    ("F1-Score", "f1-score"),
    ("Recall", "recall"),
    ("AUC", "auc"),
):
    print(f"{label}: {metrics[key]:.3f}")

Accuracy: 0.972
F1-Score: 0.976
Recall: 0.952
AUC: 0.976


In [11]:
# The run log is named after the run's name and id.
json_path = join_drive_path("log", f"{config.name}_{config.id}.json")

# Select only the headline metrics for the log, in a fixed order.
metrics_to_log = {
    key: metrics[key] for key in ("accuracy", "f1-score", "recall", "auc")
}

# Left as the last expression so the logged record is displayed as the cell output.
log_json(json_path, config, data_config, model_config, **metrics_to_log)

{'id': '250228075729819009',
 'name': 'ieee-transformer',
 'model_path': 'ieee-transformer_250228075729819009.pt',
 'batch': 8,
 'epochs': 47,
 'lr': 0.001,
 'enable_fp16': True,
 'grad_step': 4,
 'warmup_steps': 50,
 'lr_decay_factor': 0.5,
 'weight_decay': 0.001,
 'patience': 30,
 'tag': 'IEEE_23',
 'train': 'ieee_train.pt',
 'test': 'ieee_test.pt',
 'val': 'ieee_val.pt',
 'channels': 19,
 'length': 9250,
 'num_classes': 2,
 'embed_dim': 64,
 'num_heads': 4,
 'num_blocks': 4,
 'block_hidden_dim': 64,
 'fc_hidden_dim': 32,
 'dropout': 0.1,
 'accuracy': 0.9722222222222222,
 'f1-score': 0.975609756097561,
 'recall': 0.9523809523809523,
 'auc': 0.9761904761904762}