# Transformer

Train a base ViT (Vision Transformer) model on the IEEE EEG dataset (ADHD vs. control classification).

In [None]:
import os
import json
import secrets
import gc

import numpy as np
from tqdm.auto import tqdm, trange
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, f1_score, roc_curve, auc
import warnings
from google.colab import drive

# Mount Google Drive; join_path() below builds every data/checkpoint/result
# path under /content/drive/MyDrive.
drive.mount("/content/drive")
# NOTE(review): this silences *all* warnings for the rest of the session,
# including PyTorch / scikit-learn deprecation warnings — consider narrowing
# the filter to specific categories or modules.
warnings.filterwarnings("ignore")

def join_path(*args):
    """Return ``args`` joined as a path under the mounted Drive root.

    :param args: path components relative to ``/content/drive/MyDrive``
    :return: the joined path string
    """
    drive_root = "/content/drive/MyDrive"
    return os.path.join(drive_root, *args)

def clear():
    """Release memory: run Python's garbage collector, then empty PyTorch's CUDA cache."""
    # Collect first so objects freed by the GC no longer pin cached CUDA blocks.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

## Settings

In [None]:
# Create unique ID: draw random 8-byte (16 hex character) tokens until neither a
# checkpoint (<id>.pth) nor a result file (<id>.json) with that name exists.
def _id_is_taken(candidate):
    """Return True if a ``.pth`` or ``.json`` file named after ``candidate`` exists."""
    return any(
        os.path.exists(join_path(f"{candidate}{suffix}"))
        for suffix in (".pth", ".json")
    )


experiment_id = secrets.token_hex(8)
while _id_is_taken(experiment_id):
    experiment_id = secrets.token_hex(8)
print("ID:", experiment_id)

# Fix random seeds (PyTorch, then NumPy)
torch.manual_seed(42)
np.random.seed(42)

# Run / optimisation settings
ARGS = {
    "id": experiment_id,
    "name": "4 Head Self attention",
    "model_path": join_path(f"{experiment_id}.pth"),  # where the best checkpoint is saved
    "batch": 256,  # DataLoader batch size
    "grad_step": 1,  # batches whose gradients are accumulated per optimizer step
    "epochs": 100,  # maximum number of training epochs
    "lr": 1e-4,
    "weight_decay": 1e-2,
    "patience": 2,  # early-stopping patience, counted in epochs
}
# Dataset files and metadata
DATA = {
    "train_path": join_path("data", "train.pt"),
    "test_path": join_path("data", "test.pt"),
    "val_path": join_path("data", "val.pt"),
    "channel": 19,  # presumably the number of EEG channels — not read in this section
    "length": 2560,  # presumably samples per EEG window — not read in this section
    "labels": ["control", "ADHD"],
}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

## Utils

In [3]:
class EarlyStopping(object):
    """Stop training when the monitored loss does not decrease.

    Each call to :meth:`should_stop` that sees a new minimum loss saves the
    model's ``state_dict`` to ``save_path`` and resets the counter. Any other
    call (a higher loss, an equal loss, or NaN) counts as "no improvement".
    After ``patience`` consecutive such calls, :meth:`should_stop` returns True.

    :param patience: number of consecutive non-improving checks tolerated before stopping
    :param save_path: path to save the best model
    """

    def __init__(self, patience, save_path):
        self._min_loss = np.inf
        self._patience = patience
        self._path = save_path
        self.__counter = 0

    def should_stop(self, model, loss):
        """Record ``loss`` and report whether training should stop.

        :param model: model whose ``state_dict`` is saved on a new best loss
        :param loss: current (e.g. validation) loss
        :return: True once ``patience`` consecutive checks brought no new minimum
        """
        if loss < self._min_loss:
            self._min_loss = loss
            self.__counter = 0
            torch.save(model.state_dict(), self._path)
            return False

        # A plateau (equal loss) or a NaN loss is not an improvement either, so
        # it counts toward patience; counting only strict increases would let a
        # flat loss run forever.
        self.__counter += 1
        return self.__counter >= self._patience

    def load(self, model):
        """Load the best saved weights into ``model`` and return it.

        :param model: model structure matching the saved ``state_dict``
        :return: the same ``model`` object with the best weights loaded
        """
        # Map to CPU so a checkpoint written on a GPU also loads on a CPU-only
        # runtime; load_state_dict copies the values onto the model's own device.
        model.load_state_dict(torch.load(self._path, map_location="cpu"))
        return model

    @property
    def counter(self):
        """Number of consecutive checks without a new minimum loss."""
        return self.__counter

In [None]:
class WarmupScheduler:
    """Warm up the learning rate linearly, then decay it whenever the loss rises.

    Behaviour of :meth:`step`:

    * Calls 1..``warmup_steps``: every param group's LR is set to
      ``initial_lr * step / warmup_steps``; the loss is ignored.
    * Later calls: if ``loss`` is greater than the best loss seen since warmup
      ended, every param group's LR is divided by ``decay_factor`` (never going
      below ``min_lr``). The best loss is then updated.

    The constructor sets every param group's LR to 0. Optimizer updates made
    before the first :meth:`step` call therefore use a learning rate of 0.

    :param optimizer: torch optimizer
    :param initial_lr: learning rate reached at the end of warmup
    :param min_lr: lower bound applied when decaying
    :param warmup_steps: number of ``step`` calls spent warming up (must be > 0)
    :param decay_factor: divisor applied on each decay (must be > 1)
    :raises ValueError: if ``warmup_steps`` or ``decay_factor`` is out of range
    """

    def __init__(
        self, optimizer, initial_lr, min_lr=1e-6, warmup_steps=10, decay_factor=10
    ):
        # Validate with real exceptions: ``assert`` is stripped under ``python -O``.
        # Checked before touching the optimizer so a bad config leaves it unchanged.
        if not warmup_steps > 0:
            raise ValueError("Warmup steps must be greater than 0")
        if not decay_factor > 1:
            raise ValueError("Decay factor must be greater than 1")

        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.decay_factor = decay_factor

        self.global_step = 0
        self.best_loss = float("inf")

        # Warmup ramps up from zero: start every param group at LR 0.
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = 0

    def step(self, loss=None):
        """Advance the schedule by one step.

        :param loss: monitored loss for this step. It is ignored during warmup.
            After warmup, ``None`` skips the decay check, so callers that do not
            track a loss can still drive the warmup.
        """
        self.global_step += 1

        if self.global_step <= self.warmup_steps:
            # Linear warmup
            warmup_lr = self.initial_lr * (self.global_step / self.warmup_steps)
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = warmup_lr
            return

        if loss is None:
            return

        # Decay whenever the loss is worse than the best seen so far.
        if loss > self.best_loss:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = max(param_group["lr"] / self.decay_factor, self.min_lr)
        self.best_loss = min(self.best_loss, loss)

    def get_lr(self):
        """Return current learning rates."""
        return [param_group["lr"] for param_group in self.optimizer.param_groups]

## Dataset

In [None]:
class EEGDataset(Dataset):
    def __init__(self, file_path):
        self.data = torch.load(file_path, mmap=True) # lazy load
        self.eeg = self.data["data"]
        self.labels = self.data["label"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.eeg[idx], self.labels[idx]

In [None]:
# One EEGDataset per split (train, validation, test), loaded in that order.
train_dataset, val_dataset, test_dataset = (
    EEGDataset(DATA[split]) for split in ("train_path", "val_path", "test_path")
)

# Only the training split is shuffled; validation and test keep file order.
batch_size = ARGS["batch"]
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader, test_dataloader = (
    DataLoader(split_dataset, batch_size=batch_size)
    for split_dataset in (val_dataset, test_dataset)
)

## Model

## Train

In [None]:
# BCELoss takes probabilities in [0, 1], so the model is expected to end in a
# sigmoid (or similar). NOTE(review): confirm against EEGTransformer's output.
loss_fn = nn.BCELoss()
model = EEGTransformer().to(DEVICE)
# AdamW applies decoupled weight decay. Adam's ``weight_decay`` is an L2 term
# added to the gradient and then rescaled by the adaptive step, which weakens
# the intended regularisation. 1e-2 is AdamW's own default decay.
optimizer = optim.AdamW(
    model.parameters(),
    lr=ARGS["lr"],
    weight_decay=ARGS["weight_decay"],
)

In [None]:
def _forward_loss(model, criterion, eeg, label, device):
    """Run one ``(eeg, label)`` batch through ``model`` and return the loss tensor."""
    eeg = eeg.to(device)
    output = model(eeg)
    # Accept both plain tensors and HuggingFace-style outputs exposing ``.logits``.
    output = getattr(output, "logits", output)
    # BCELoss needs a float target with the prediction's dtype and shape.
    # NOTE(review): assumes the model emits one value per sample, e.g. shape
    # (batch,) or (batch, 1); confirm once EEGTransformer is defined.
    target = label.to(device=device, dtype=output.dtype).reshape_as(output)
    return criterion(output, target)


def evaluate(model, criterion, val_loader, device, mode=None):
    """Compute the per-batch loss of ``model`` over ``val_loader`` without gradients.

    :param model: model to evaluate (switched to eval mode)
    :param criterion: loss function, e.g. ``nn.BCELoss``
    :param val_loader: iterable of ``(eeg, label)`` batches, as yielded by EEGDataset
    :param device: device each batch is moved to
    :param mode: unused; kept so existing callers keep working
    :return: list of per-batch loss values as Python floats
    """
    model.eval()
    with torch.no_grad():
        return [
            _forward_loss(model, criterion, eeg, label, device).item()
            for eeg, label in val_loader
        ]


def train(model, optimizer, scheduler, criterion, train_loader, val_loader, device):
    """Train ``model`` with early stopping and return it with its best weights.

    Settings come from ``ARGS`` (``model_path``, ``grad_step``, ``epochs``,
    ``patience``). Gradients are accumulated over ``grad_step`` batches per
    optimizer step. The checkpoint with the lowest validation loss is reloaded
    at the end.

    :param model: model to train
    :param optimizer: torch optimizer over ``model``'s parameters
    :param scheduler: object with ``step(loss)`` (e.g. WarmupScheduler). It is
        called once per epoch with the mean validation loss. May be None.
    :param criterion: loss function, e.g. ``nn.BCELoss``
    :param train_loader: iterable of ``(eeg, label)`` training batches
    :param val_loader: iterable of ``(eeg, label)`` validation batches
    :param device: device to train on
    :return: ``model`` with the best saved weights loaded
    """
    clear()

    model_path = ARGS["model_path"]
    grad_step = ARGS["grad_step"]
    early_stopper = EarlyStopping(ARGS["patience"], model_path)

    model.to(device)
    criterion.to(device)
    model.zero_grad()

    best_epoch, best_val_loss = 0, np.inf
    stopped_early = False
    for epoch in trange(1, ARGS["epochs"] + 1):
        model.train()
        train_loss = []
        has_pending_grads = False
        for batch_id, (eeg, label) in enumerate(train_loader, start=1):
            batch_loss = _forward_loss(model, criterion, eeg, label, device)
            train_loss.append(batch_loss.item())

            # Scale so the accumulated gradient averages over grad_step batches.
            (batch_loss / grad_step).backward()
            has_pending_grads = True

            if batch_id % grad_step == 0:
                optimizer.step()
                model.zero_grad()
                has_pending_grads = False

        # Apply any partially accumulated gradients so they don't leak into the
        # first step of the next epoch.
        if has_pending_grads:
            optimizer.step()
            model.zero_grad()

        val_loss = np.mean(evaluate(model, criterion, val_loader, device, mode="train"))
        train_loss = np.mean(train_loss)
        tqdm.write(
            f"Epoch {epoch}, Train-Loss: {train_loss:.5f},  Val-Loss: {val_loss:.5f}"
        )

        # Track the best epoch here (same strict "<" rule the stopper uses to save).
        if val_loss < best_val_loss:
            best_epoch, best_val_loss = epoch, val_loss

        if early_stopper.should_stop(model, val_loss):
            stopped_early = True
            break

        if scheduler is not None:
            scheduler.step(val_loss)

    if stopped_early:
        tqdm.write(f"\n\n -- EarlyStopping: [Epoch: {best_epoch}]")
    else:
        tqdm.write(f"\n\n -- Training finished, best epoch: {best_epoch}")
    tqdm.write(f"Model saved at '{model_path}'.")
    model = early_stopper.load(model)

    return model

## Test

In [None]:
# Collect labels and predicted probabilities over the whole test split.
model.eval()
y_true, y_prob = [], []
with torch.no_grad():
    for eeg, label in test_dataloader:
        output = model(eeg.to(DEVICE))
        # Accept both plain tensors and outputs exposing ``.logits``.
        output = getattr(output, "logits", output)
        # NOTE(review): assumes one probability per sample for label 1
        # (DATA["labels"][1] == "ADHD"), consistent with training under BCELoss.
        y_prob.extend(output.reshape(-1).float().cpu().tolist())
        y_true.extend(int(v) for v in label.reshape(-1).tolist())
# Hard predictions: threshold the probabilities at 0.5.
y_pred = [int(p >= 0.5) for p in y_prob]

accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
fpr, tpr, _ = roc_curve(y_true, y_prob)
auc_value = auc(fpr, tpr)

## Log

In [None]:
# Gather the run's hyper-parameters and test metrics into one record
# (key order: id, model_path, batch, lr, weight_decay, accuracy, f1, auc).
experiment_result = {key: ARGS[key] for key in ("id", "model_path")}
experiment_result["batch"] = (ARGS["batch"], ARGS["grad_step"])
experiment_result.update(
    lr=ARGS["lr"],
    weight_decay=ARGS["weight_decay"],
    accuracy=accuracy,
    f1=f1,
    auc=auc_value,
)

print("\n".join(f"{key}: {value}" for key, value in experiment_result.items()))

# Save result in json format, next to the <id>.pth checkpoint on Drive
result_path = join_path(f"{ARGS['id']}.json")
with open(result_path, "w") as result_file:
    json.dump(experiment_result, result_file, indent=2)