# Transformer

Train the base ViT model on the IEEE EEG dataset.

In [None]:
# Mount Google Drive for Colab env
import sys
from google.colab import drive

drive.mount("/content/drive", force_remount=False)

# Put the Drive root on the import path; the project imports in the next cell
# (`utils`, `models`) presumably resolve from there. The membership check keeps
# the cell idempotent: re-running it no longer appends duplicate entries to
# sys.path.
if "/content/drive/MyDrive" not in sys.path:
    sys.path.append("/content/drive/MyDrive")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from utils import *
from models.transformer import TransformerConfig, Transformer

# One-time notebook setup, using helpers that are not defined in this file
# (presumably provided by the `utils` star import above).
ignore_warnings()
fix_random_seed(42)  # fixed seed; presumably for reproducible runs — confirm in utils
# NOTE(review): this rebinds the name `device` from the imported helper to the
# helper's return value. Re-running this cell still works because the import
# executes again first, but later cells can no longer call the helper itself.
device = device(force_cuda=True)
print("Device:", device)

In [None]:
# Experiment configuration: training hyperparameters (`config`), dataset
# description (`data_config`) and model architecture (`model_config`).
# Notes below describe how later cells consume each field; meanings that live
# inside utils/models are hedged and should be confirmed there.
config = Config(
    name="ieee transformer",  # also used to build the JSON log file name
    batch=128,  # batch_size for the train/val/test DataLoaders
    grad_step=4,  # passed to train(grad_step=...); presumably gradient-accumulation steps — TODO confirm
    epochs=300,  # passed to train(); overwritten with train()'s return value afterwards
    lr=1e-4,  # Adam learning rate, also passed to train() as initial_lr
    warmup_steps=30,  # passed to train(warmup_steps=...)
    lr_decay_factor=2,  # passed to train(lr_decay_factor=...)
    weight_decay=1e-3,  # Adam weight_decay
    patience=20,  # passed to train(patience=...); presumably early-stopping patience — TODO confirm
)
# Supplies the split file names (train/val/test) plus length, channels and
# num_classes, which are read when building the datasets and the model.
data_config = IEEEData()
# Every field here is forwarded to the Transformer constructor
# (`dropout` is passed as `dropout_p`).
model_config = TransformerConfig(
    embed_dim=56,
    num_heads=4,
    num_blocks=6,
    block_hidden_dim=56,
    fc_hidden_dim=64,
    dropout=0.1,
)

# Show which run this is; `config.id` is provided by Config, not set here.
print("ID:", config.id)
print("Name:", config.name)

In [None]:
# Build the train/validation file paths with join_drive_path, one per split.
train_data_path, val_data_path = (
    join_drive_path("data", split_file)
    for split_file in (data_config.train, data_config.val)
)

# Wrap each split file in an EEGDataset (train first, then validation).
train_dataset, val_dataset = map(EEGDataset, (train_data_path, val_data_path))

# Only the training loader shuffles: the training set is stored in order, so
# shuffling is what makes its batches random. Validation keeps its order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=config.batch,
    shuffle=True,
)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=config.batch)

In [None]:
# Build the model, loss and optimizer for this run.
# clear_cache() is a project helper; presumably it frees cached (GPU) memory
# before the model is allocated — confirm in utils.
clear_cache()

# Architecture sizes come from model_config; input/output sizes (sequence
# length, input channels, number of classes) come from data_config.
model = Transformer(
    seq_length=data_config.length,
    embed_dim=model_config.embed_dim,
    num_heads=model_config.num_heads,
    num_blocks=model_config.num_blocks,
    block_hidden_dim=model_config.block_hidden_dim,
    fc_hidden_dim=model_config.fc_hidden_dim,
    num_classes=data_config.num_classes,
    input_channel=data_config.channels,
    dropout_p=model_config.dropout,
).to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
# NOTE(review): with optim.Adam, weight_decay is an L2 penalty added to the
# gradient, not the decoupled decay of optim.AdamW. Confirm this is intended.
optimizer = optim.Adam(
    model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)

In [None]:
# Run training through the project's train() helper. All schedule settings come
# from `config`, except min_lr, which is fixed here at 1e-7.
# model_path is the file the next cell reloads the weights from; presumably
# train() saves the model there — confirm in utils.
check_point = train(
    model=model,
    model_path=config.model_path,
    device=device,
    optimizer=optimizer,
    grad_step=config.grad_step,
    criterion=loss_fn,
    epochs=config.epochs,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    initial_lr=config.lr,
    min_lr=1e-7,
    warmup_steps=config.warmup_steps,
    lr_decay_factor=config.lr_decay_factor,
    patience=config.patience,
)
# train()'s return value replaces the configured epoch count, so the config
# logged later carries it (presumably the epoch of the saved checkpoint — TODO confirm).
config.epochs = check_point  # Update Epochs to save the result

In [None]:
# Reload the weights stored at config.model_path into the in-memory model
# before evaluation (presumably the checkpoint written by train() — confirm).
clear_cache()

# weights_only=True restricts unpickling to tensors and plain containers;
# map_location=device places the loaded tensors on the active device.
trained_weights = torch.load(config.model_path, weights_only=True, map_location=device)
model.load_state_dict(trained_weights)

In [None]:
# Load the test split and score the model on it with evaluate().
test_data_path = join_drive_path("data", data_config.test)
test_dataset = EEGDataset(test_data_path)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=config.batch)

metrics = evaluate(model, device, test_dataloader)

# Report each metric to three decimals: (display label, key in `metrics`).
for metric_label, metric_key in (
    ("Accuracy", "accuracy"),
    ("F1-Score", "f1-score"),
    ("Recall", "recall"),
    ("AUC", "auc"),
):
    print(f"{metric_label}: {metrics[metric_key]:.3f}")

In [None]:
# Plot the ROC curve: the "roc-curve" entry of `metrics` is unpacked into
# plot_roc's positional arguments (its exact contents are defined by evaluate()).
plot_roc(*metrics["roc-curve"])

In [None]:
# Write this run's configs and test metrics to a JSON log named after the
# run's name and ID.
json_path = join_drive_path("log", f"{config.name}_{config.id}.json")

# Log only these four metrics; other entries of `metrics` (e.g. "roc-curve")
# are left out.
metrics_to_log = {
    metric: metrics[metric]
    for metric in ("accuracy", "f1-score", "recall", "auc")
}
log_json(json_path, config, data_config, model_config, **metrics_to_log)