In [12]:
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
import torch.optim as optim

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
class DEAPDataset(Dataset):
    def __init__(
        self,
        data_dir,
        label_type="valence",   # "valence" or "arousal"
        window_size=256,        # 2 seconds @ 128 Hz
        step_size=128,          # 50% overlap
        transform=None
    ):
        self.samples = []
        self.labels = []
        self.transform = transform
        self.subject_ids = []

        label_idx = {"valence": 0, "arousal": 1}[label_type]

        for file in sorted(os.listdir(data_dir)):
            if not file.endswith(".dat"):
                continue
            
            subject_id = int(file[1:3])  # s01 → 1
            path = os.path.join(data_dir, file)
            with open(path, "rb") as f:
                subject = pickle.load(f, encoding="latin1")

            eeg = subject["data"][:, :32, :]    # (40, 32, 8064)
            labels = subject["labels"][:, label_idx]

            for trial in range(eeg.shape[0]):
                signal = eeg[trial]
                label = 1 if labels[trial] >= 5 else 0

                # Sliding window segmentation
                for start in range(0, signal.shape[1] - window_size, step_size):
                    window = signal[:, start:start + window_size]
                    self.samples.append(window)
                    self.labels.append(label)
                    self.subject_ids.append(subject_id)

        self.samples = np.array(self.samples, dtype=np.float32)
        self.labels = np.array(self.labels, dtype=np.int64)

        print(f"Loaded {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = self.samples[idx]
        y = self.labels[idx]

        if self.transform:
            x = self.transform(x)

        return torch.tensor(x), torch.tensor(y), self.subject_ids[idx]

In [15]:
def loso_split(dataset, test_subject):
    train_idx = []
    test_idx = []

    for i, s in enumerate(dataset.subject_ids):
        if s == test_subject:
            test_idx.append(i)
        else:
            train_idx.append(i)

    return Subset(dataset, train_idx), Subset(dataset, test_idx)

In [16]:
def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y, _ in loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    if total == 0:
        return 0.0   # or np.nan

    return 100.0 * correct / total

In [17]:
class MultiScaleTemporalCNN(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(in_ch, out_ch, kernel_size=7, padding=3)
        self.bn = nn.BatchNorm1d(out_ch * 3)

    def forward(self, x):
        # x: (B, C, T)
        f1 = self.conv1(x)
        f2 = self.conv2(x)
        f3 = self.conv3(x)
        out = torch.cat([f1, f2, f3], dim=1)
        return self.bn(out)

In [18]:
class DynamicGraphLearner(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.theta = nn.Linear(in_features, in_features, bias=False)

    def forward(self, x):
        # x: (B, C, F)
        theta_x = self.theta(x)
        A = torch.matmul(theta_x, x.transpose(1, 2))
        A = F.softmax(A, dim=-1)
        return A

In [None]:
class GraphTransformerEncoder(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, adj):
        # x: (B, C, F)
        # adj: (B, C, C)

        B, C, F = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)

        q, k, v = [
            t.reshape(B, C, self.heads, F // self.heads).transpose(1, 2)
            for t in qkv
        ]  # (B, H, C, F/H)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, C, C)

        attn = attn * adj.unsqueeze(1)

        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, C, F)
        out = self.proj(out)

        return self.norm(x + out)

In [20]:
class DySTGATFormer(nn.Module):
    def __init__(self, channels=32, time_steps=256, num_classes=2):
        super().__init__()

        self.temporal = MultiScaleTemporalCNN(channels, 32)  # → (B, 96, T)

        self.graph_learner = DynamicGraphLearner(96)

        self.transformer = GraphTransformerEncoder(dim=96, heads=4)

        self.pool = nn.AdaptiveAvgPool1d(1)

        self.classifier = nn.Sequential(
            nn.Linear(96, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        # x: (B, C, T)

        x = self.temporal(x)        # (B, 96, T)
        x = x.permute(0, 2, 1)     # (B, T, 96)

        x = x.mean(dim=1)          # temporal aggregation → (B, 96)

        x = x.unsqueeze(1).repeat(1, 32, 1)  # (B, C, 96)

        adj = self.graph_learner(x)           # (B, C, C)

        x = self.transformer(x, adj)          # (B, C, 96)

        x = x.mean(dim=1)                     # global pooling

        return self.classifier(x)

In [21]:
data_dir = "DEAP_dataset"

full_dataset = DEAPDataset(
    data_dir=data_dir,
    label_type="valence",  # or "arousal"
    window_size=256,
    step_size=128
)

Loaded 29280 samples


In [22]:
subject_accuracies = []

for subject in range(1, 33):

    train_set, test_set = loso_split(full_dataset, subject)
    if len(test_set) == 0 or len(train_set) == 0:
        print(f"Skipping subject {subject:02d} (empty split)")
        continue

    print(f"Subject {subject:02d}: "
        f"train={len(train_set)}, test={len(test_set)}")

    train_loader = DataLoader(
        train_set, batch_size=32, shuffle=True,
        num_workers=4, pin_memory=True
    )

    test_loader = DataLoader(
        test_set, batch_size=32, shuffle=False,
        num_workers=4, pin_memory=True
    )

    model = DySTGATFormer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # ---- TRAIN ----
    for epoch in range(10):
        model.train()
        for x, y, _ in train_loader:
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

    # ---- EVALUATE ----
    acc = evaluate(model, test_loader, device)
    subject_accuracies.append(acc)

    print(f"Subject {subject:02d} accuracy: {acc:.2f}%")

Subject 01: train=26840, test=2440
Subject 01 accuracy: 48.81%
Subject 02: train=26840, test=2440
Subject 02 accuracy: 62.50%
Subject 03: train=26840, test=2440
Subject 03 accuracy: 54.84%
Skipping subject 04 (empty split)
Subject 05: train=26840, test=2440
Subject 05 accuracy: 59.30%
Skipping subject 06 (empty split)
Subject 07: train=26840, test=2440
Subject 07 accuracy: 53.89%
Subject 08: train=26840, test=2440
Subject 08 accuracy: 54.30%
Subject 09: train=26840, test=2440
Subject 09 accuracy: 55.70%
Skipping subject 10 (empty split)
Subject 11: train=26840, test=2440
Subject 11 accuracy: 59.39%
Skipping subject 12 (empty split)
Skipping subject 13 (empty split)
Skipping subject 14 (empty split)
Subject 15: train=26840, test=2440
Subject 15 accuracy: 49.84%
Subject 16: train=26840, test=2440
Subject 16 accuracy: 33.32%
Skipping subject 17 (empty split)
Skipping subject 18 (empty split)
Skipping subject 19 (empty split)
Skipping subject 20 (empty split)
Skipping subject 21 (empty spl

In [23]:
import numpy as np

mean_acc = np.mean(subject_accuracies)
std_acc = np.std(subject_accuracies)

print(f"\nLOSO Accuracy: {mean_acc:.2f} ± {std_acc:.2f}")


LOSO Accuracy: 52.58 ± 7.24
