# Transformer

Train a base ViT-style Transformer model on the IEEE EEG dataset (ADHD vs. control classification).

In [None]:
!pip install torchsummary

In [None]:
import os
import json
import secrets
import gc
from collections import OrderedDict

import numpy as np
from tqdm.auto import tqdm, trange
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from sklearn.metrics import accuracy_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt
import warnings
from google.colab import drive

# Mount Google Drive (Colab only) so the paths built by join_path() under
# /content/drive/MyDrive are reachable.
drive.mount("/content/drive")
# NOTE(review): this silences *every* warning for the rest of the session
# (torch, sklearn metric warnings, ...); consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

def join_path(*args):
    """Build a path rooted at the mounted Google Drive folder.

    :param args: path components relative to ``/content/drive/MyDrive``
    :return: the joined path string
    """
    parts = ("/content/drive/MyDrive",) + args
    return os.path.join(*parts)

def clear():
    """Free memory: run the Python garbage collector, then release cached CUDA blocks."""
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

## Settings

In [None]:
# Draw random 16-hex-character IDs until one collides with neither an existing
# checkpoint (.pth) nor an existing result log (.json) on the Drive.
experiment_id = secrets.token_hex(8)
while any(os.path.exists(join_path(f"{experiment_id}{suffix}")) for suffix in (".pth", ".json")):
    experiment_id = secrets.token_hex(8)
print("ID:", experiment_id)

# Seed torch and numpy so runs are repeatable
torch.manual_seed(42)
np.random.seed(42)

# Training hyper-parameters
ARGS = dict(
    id=experiment_id,
    name="4 Head Self attention",
    model_path=join_path(f"{experiment_id}.pth"),
    batch=256,
    grad_step=1,
    epochs=100,
    lr=1e-4,
    warmup_steps=10,
    weight_decay=1e-2,
    patience=2,
)
# Dataset locations and signal geometry
DATA = dict(
    train_path=join_path("data", "train.pt"),
    test_path=join_path("data", "test.pt"),
    val_path=join_path("data", "val.pt"),
    channel=19,
    length=2560,
    labels=["control", "ADHD"],
)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

## Utils

In [None]:
class EarlyStopping(object):
    """Stop training once the monitored loss stops improving.

    Every strict improvement of the loss saves the model's ``state_dict`` to
    ``save_path``. Any other value (equal to the best, worse, or NaN) counts
    towards ``patience``. The counter resets on the next improvement.

    :param patience: number of consecutive non-improving checks tolerated before stopping
    :param save_path: path to save the best model
    """

    def __init__(self, patience, save_path):
        self._min_loss = np.inf
        self._patience = patience
        self._path = save_path
        self.__counter = 0

    def should_stop(self, model, loss):
        """Record ``loss`` and report whether training should stop.

        :param model: model whose ``state_dict`` is saved on improvement
        :param loss: current loss (lower is better)
        :return: True once ``patience`` consecutive checks failed to improve
        """
        if loss < self._min_loss:
            self._min_loss = loss
            self.__counter = 0
            torch.save(model.state_dict(), self._path)
            return False

        # Equal (plateau) and NaN losses are not improvements either; counting
        # them guarantees a flat or diverged run still stops.
        self.__counter += 1
        return self.__counter >= self._patience

    def load(self, model):
        """Load the best saved weights into ``model`` and return it.

        :param model: model structure
        :raises FileNotFoundError: if no improvement was ever recorded, so nothing was saved
        """
        model.load_state_dict(torch.load(self._path))
        return model

    @property
    def patience(self):
        """Return patience

        When ``should_stop`` returned True at some epoch, the saved checkpoint
        belongs to epoch ``epoch - patience``:
        >>> stopper = EarlyStopping(...)
        >>> train()
        >>> check_point = epoch - stopper.patience

        This formula does not hold if training ended without early stopping.
        """
        return self._patience

In [None]:
class WarmupScheduler:
    """Linear learning-rate warmup followed by loss-driven decay.

    ``step`` is meant to be called once per epoch, *after* the epoch has been
    trained, with that epoch's loss. During warmup, epoch ``k`` (1-based) trains
    with ``initial_lr * k / warmup_steps``. The very first epoch therefore
    already makes progress, and epoch ``warmup_steps`` runs at the full
    ``initial_lr``. After that, every step whose loss exceeds the best loss
    seen since warmup ended divides the learning rate by ``decay_factor``,
    never going below ``min_lr``.

    :param optimizer: torch optimizer whose param groups are driven
    :param initial_lr: peak learning rate reached at the end of warmup
    :param min_lr: lower bound for the decayed learning rate
    :param warmup_steps: number of warmup steps (epochs), must be > 0
    :param decay_factor: divisor applied when the loss increases, must be > 1
    """

    def __init__(
        self, optimizer, initial_lr, min_lr=1e-6, warmup_steps=10, decay_factor=10
    ):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.decay_factor = decay_factor

        assert self.warmup_steps > 0, "Warmup steps must be greater than 0"
        assert self.decay_factor > 1, "Decay factor must be greater than 1"

        self.global_step = 0
        self.best_loss = float("inf")

        # Start at the first warmup value rather than 0. step() only runs after
        # an epoch has been trained, so an initial lr of 0 would waste epoch 1.
        self._set_lr(self.initial_lr / self.warmup_steps)

    def _set_lr(self, lr):
        """Assign ``lr`` to every param group of the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def step(self, loss):
        """Update the learning rate for the next epoch based on the current loss."""
        self.global_step += 1

        if self.global_step < self.warmup_steps:
            # Linear warmup: learning rate for epoch global_step + 1
            self._set_lr(self.initial_lr * (self.global_step + 1) / self.warmup_steps)
            return

        # Past warmup: decay whenever the loss rises above the best seen so far
        if loss > self.best_loss:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = max(param_group["lr"] / self.decay_factor, self.min_lr)
        self.best_loss = min(self.best_loss, loss)

    def get_lr(self):
        """Return current learning rates."""
        return [param_group["lr"] for param_group in self.optimizer.param_groups]

## Dataset

In [None]:
class EEGDataset(Dataset):
    def __init__(self, file_path):
        self.data = torch.load(file_path, mmap=True) # lazy load
        self.eeg = self.data["data"]
        self.labels = self.data["label"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.eeg[idx], self.labels[idx]

In [None]:
# Load the three splits (train, val, test, in that order) from the paths in DATA
train_dataset, val_dataset, test_dataset = (
    EEGDataset(DATA[split]) for split in ("train_path", "val_path", "test_path")
)

# Only the training split is shuffled; evaluation batches keep file order
train_dataloader = DataLoader(train_dataset, batch_size=ARGS["batch"], shuffle=True)
val_dataloader, test_dataloader = (
    DataLoader(split_dataset, batch_size=ARGS["batch"])
    for split_dataset in (val_dataset, test_dataset)
)

## Model

In [None]:
class AttentionBlock(nn.Module):
    """Transformer encoder block in the post-norm ("Add & Norm") layout.

    ``out = norm2(h + mlp(h))`` with ``h = norm1(x + self_attention(x))``.
    Each sub-layer has its own residual connection and its own LayerNorm.

    :param embed_dim: token embedding size (must be divisible by ``num_heads``)
    :param num_heads: number of attention heads
    :param mlp_dim: hidden width of the feed-forward sub-layer
    """

    def __init__(self, embed_dim, num_heads, mlp_dim):
        super(AttentionBlock, self).__init__()
        # batch_first=True: inputs are (batch, seq, embed_dim)
        self.self_attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, input: torch.Tensor):
        # Multi-head self-attention
        attn_out, _ = self.self_attention(input, input, input, need_weights=False)

        # Add & Norm (residual around attention)
        x = self.norm1(input + attn_out)

        # Feed Forward, then Add & Norm (residual around the MLP)
        x = self.norm2(x + self.mlp(x))
        return x

In [None]:
class EEGTransformer(nn.Module):
    """Transformer encoder classifier for multichannel signals.

    Pipeline: Conv1d token embedding -> learned positional embedding ->
    ``num_block`` AttentionBlocks -> global max-pool over the sequence -> MLP head.

    Input is expected as ``(batch, in_channels, seq_length)``, as implied by the
    Conv1d embedding. NOTE(review): confirm the saved EEG tensors use
    (channel, time) order per sample.

    :param embed_dim: token embedding size
    :param num_heads: attention heads per block
    :param num_block: number of stacked AttentionBlocks
    :param num_class: number of output logits
    :param seq_length: number of time steps (tokens); must match the input length
    :param mlp_dim: hidden width inside each AttentionBlock
    :param fc_dim: hidden width of the classification head
    :param in_channels: number of input signal channels (default 19)
    :return: ``forward`` returns raw logits of shape (batch, num_class)
    """

    def __init__(self, embed_dim, num_heads, num_block, num_class, seq_length, mlp_dim, fc_dim,
                 in_channels=19):
        super(EEGTransformer, self).__init__()

        # Embedding: each time step becomes one embed_dim-sized token
        self.embedding = nn.Conv1d(in_channels, embed_dim, kernel_size=3, padding=1)
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, embed_dim).normal_(std=0.02))

        # Attention Blocks
        attention_blocks = OrderedDict()
        for i in range(num_block):
            attention_blocks[f"attention_block_{i}"] = AttentionBlock(embed_dim, num_heads, mlp_dim)
        self.encoder = nn.Sequential(attention_blocks)

        # Decoding layers
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Sequential(
            nn.Flatten(1, -1),
            nn.Linear(embed_dim, fc_dim),
            nn.Tanh(),
            nn.Linear(fc_dim, num_class),
        )

    def forward(self, input):
        # (B, C_in, L) -> (B, E, L)
        x = self.embedding(input)
        # (B, E, L) -> (B, L, E): tokens on dim 1, matching pos_embedding and batch_first attention
        x = x.transpose(1, 2)
        x = x + self.pos_embedding
        x = self.encoder(x)
        # back to (B, E, L) so AdaptiveMaxPool1d reduces over the sequence axis -> (B, E, 1)
        x = x.transpose(1, 2)
        x = self.global_max_pool(x)
        return self.fc(x)

In [None]:
# num_heads=4 follows ARGS["name"]; embed_dim / num_block / mlp_dim / fc_dim are
# placeholder sizes. TODO: tune them.
# NOTE(review): every block attends over all DATA["length"] time steps, which can be
# memory-heavy at ARGS["batch"]; reduce the batch or the sequence length if CUDA runs out.
model = EEGTransformer(
    embed_dim=64,
    num_heads=4,
    num_block=4,
    num_class=len(DATA["labels"]),
    seq_length=DATA["length"],
    mlp_dim=128,
    fc_dim=64,
).to(DEVICE)

# torchsummary expects the device as the string "cuda"/"cpu" and the per-sample
# input shape without the batch dim: (channels, time), matching the Conv1d embedding.
try:
    summary(model, (DATA["channel"], DATA["length"]), device=DEVICE.type)
except AttributeError:
    # torchsummary's hooks call .size() on every tuple element of a module output;
    # nn.MultiheadAttention returns (out, None) with need_weights=False.
    print(model)
print("Trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

## Train

In [None]:
loss_fn = nn.CrossEntropyLoss().to(DEVICE)
# AdamW applies weight decay directly to the weights (decoupled). Adam's
# `weight_decay` instead adds an L2 term to the gradient, which is then rescaled
# per parameter and no longer acts as uniform decay.
optimizer = optim.AdamW(
    model.parameters(),
    lr=ARGS["lr"],
    weight_decay=ARGS["weight_decay"],
)
# NOTE: constructing the scheduler overwrites the optimizer's lr with the warmup start value.
scheduler = WarmupScheduler(optimizer, ARGS["lr"], warmup_steps=ARGS["warmup_steps"])

In [None]:
def evaluate(model, criterion, val_loader, device):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    Leaves the model in eval mode and runs without gradient tracking.

    :param model: classifier returning logits, or an object with a ``.logits`` attribute
    :param criterion: loss function taking (logits, integer labels)
    :param val_loader: iterable of (signal, label) batches that supports ``len()``
    :param device: device the model lives on
    :return: average of the per-batch losses
    """
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for signal, label in val_loader:
            signal = signal.to(device)
            label = label.to(device)
            output = model(signal)
            # EEGTransformer returns a plain logits tensor; also accept `.logits` wrappers
            logits = getattr(output, "logits", output)

            val_loss += criterion(logits, label.long()).item()

    return val_loss / len(val_loader)


def _train_epoch(model, optimizer, criterion, train_loader, device, grad_step):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss."""
    model.train()
    num_batches = len(train_loader)
    total_loss = 0.0

    for batch_id, (signal, label) in enumerate(train_loader, start=1):
        # Batches must be on the same device as the model parameters
        signal = signal.to(device)
        label = label.to(device)

        output = model(signal)
        logits = getattr(output, "logits", output)
        batch_loss = criterion(logits, label.long())
        total_loss += batch_loss.item()

        # Scale so grad_step accumulated batches add up to one averaged gradient
        (batch_loss / grad_step).backward()

        # Batch accumulation: also step on the final batch so leftover gradients
        # are not carried into the next epoch.
        if batch_id % grad_step == 0 or batch_id == num_batches:
            optimizer.step()
            model.zero_grad()

    return total_loss / num_batches


def train(model, optimizer, scheduler, criterion, train_loader, val_loader, device):
    """Train with early stopping on validation loss and return the best model.

    Reads ``model_path``, ``grad_step``, ``epochs`` and ``patience`` from ``ARGS``.
    The scheduler is stepped with the training loss after each epoch.

    :return: ``model`` with the weights of the best validation epoch loaded
    """
    clear()

    model_path = ARGS["model_path"]
    grad_step = ARGS["grad_step"]
    early_stopper = EarlyStopping(ARGS["patience"], model_path)

    # Tracked here with the same strict "<" rule EarlyStopping uses to save, so the
    # reported checkpoint is right whether or not training stopped early.
    best_val_loss = float("inf")
    best_epoch = 0

    model.zero_grad()

    for epoch in trange(1, ARGS["epochs"] + 1):
        train_loss = _train_epoch(model, optimizer, criterion, train_loader, device, grad_step)

        # Validate Training Epoch
        val_loss = evaluate(model, criterion, val_loader, device)
        tqdm.write(
            f"Epoch {epoch}, Train-Loss: {train_loss:.5f},  Val-Loss: {val_loss:.5f}"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch

        # Early stopping
        if early_stopper.should_stop(model, val_loss):
            break

        # Learning Rate Scheduling
        scheduler.step(train_loss)

    tqdm.write(f"\n--Check point: [Epoch: {best_epoch}]")
    model = early_stopper.load(model)

    return model

In [None]:
# Train with early stopping; the returned model carries the best validation weights
model = train(
    model,
    optimizer,
    scheduler,
    criterion=loss_fn,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    device=DEVICE,
)

## Test

In [None]:
def test(model, test_loader, device):
    """Compute binary-classification metrics of ``model`` on ``test_loader``.

    Accuracy and F1 use the argmax class. The ROC curve and AUC use the softmax
    probability of class index 1, because a ROC curve needs a continuous score.
    Hard labels would collapse it to a single operating point.

    :param model: classifier returning (batch, 2) logits, or an object with ``.logits``
    :param test_loader: iterable of (signal, label) batches
    :param device: device the model lives on
    :return: dict with "accuracy", "f1-score", "auc" and "roc-curve" = (fpr, tpr)
    """
    model.eval()
    y_pred = []
    y_score = []
    y_true = []

    with torch.no_grad():
        for signal, label in test_loader:
            signal = signal.to(device)
            output = model(signal)
            logits = getattr(output, "logits", output)

            y_pred.extend(logits.argmax(1).cpu().numpy())
            # Probability of the positive class (index 1, "ADHD" in DATA["labels"])
            y_score.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
            y_true.extend(label.cpu().numpy())

    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc_value = auc(fpr, tpr)

    return {
        "accuracy": accuracy,
        "f1-score": f1,
        "auc": auc_value,
        "roc-curve": (fpr, tpr),
    }

In [None]:
# Evaluate the best checkpoint on the held-out test split and print headline metrics
metrics = test(model, test_dataloader, DEVICE)
for _title, _key in (("Accuracy:", "accuracy"), ("F1-Score:", "f1-score"), ("AUC:", "auc")):
    print(_title, metrics[_key])

In [None]:
# ROC curve of the test-set predictions against the chance diagonal
fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(*metrics["roc-curve"], color="blue")
ax.plot([0, 1], [0, 1], color="grey", linestyle="--")  # Baseline
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("ROC Curve")
ax.grid()
plt.show()

## Log

In [None]:
# Collect every hyper-parameter needed to reproduce the run, plus the test metrics
experiment_result = {
    "id": ARGS["id"],
    "name": ARGS["name"],
    "weights": ARGS["model_path"],
    "batch": (ARGS["batch"], ARGS["grad_step"]),  # Batch size with Gradient Accumulation
    "epochs": ARGS["epochs"],
    "lr": ARGS["lr"],
    "warmup_steps": ARGS["warmup_steps"],
    "weight_decay": ARGS["weight_decay"],
    "patience": ARGS["patience"],
    # float() turns any NumPy scalar returned by sklearn into a plain JSON number
    "accuracy": float(metrics["accuracy"]),
    "f1-score": float(metrics["f1-score"]),
    "auc": float(metrics["auc"]),
}

for key, value in experiment_result.items():
    print(f"{key}: {value}")

# Save result in json format next to the model weights
with open(join_path(f"{ARGS['id']}.json"), "w", encoding="utf-8") as f:
    json.dump(experiment_result, f, indent=2)