In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
import os
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
import torch.nn.utils.parametrizations as param
from collections import Counter
from tqdm import tqdm




In [None]:
# Load the cleaned EEG archive. allow_pickle=True is only needed if the archive
# holds object arrays (presumably the labels -- TODO confirm); it can execute
# arbitrary code, so only load archives from a trusted source.
# The context manager closes the underlying zip file; indexing an NpzFile reads
# each array fully into memory, so eegs/labels stay valid after the block.
with np.load("eeg_cleaned_all.npz", allow_pickle=True) as data:
    eegs = data["X"]
    labels = data["Y"]


In [None]:
X_top = eegs
# Rebuild the labels as a fresh ndarray from their elements (e.g. turns an object
# array of scalars into a numeric array).
Y_top = np.array([y for y in labels])

# EEGTopDataset sizes itself from the labels, so a mismatch would silently drop
# trials or fail with an IndexError mid-epoch.
assert len(X_top) == len(Y_top), (
    f"X has {len(X_top)} trials but Y has {len(Y_top)} labels"
)

In [None]:
class EEGTopDataset(Dataset):
    """Map-style dataset over pre-loaded EEG trials.

    Each item is z-scored along axis 1 (per row of the trial, e.g. per channel
    if trials are (channels, time) -- TODO confirm layout) and returned as a
    float32 tensor together with its untouched label.
    """

    def __init__(self, X, y):
        self.X = X          # indexable collection of trials (numpy arrays)
        self.labels = y     # one label per trial; also defines the dataset length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        trial = self.X[idx]
        centre = trial.mean(axis=1, keepdims=True)
        spread = trial.std(axis=1, keepdims=True)
        # The small epsilon guards against division by zero on flat rows.
        normalised = (trial - centre) / (spread + 1e-6)
        return torch.tensor(normalised, dtype=torch.float32), self.labels[idx]


In [None]:
# Fixed seed: makes the train/test split reproducible across kernel restarts and
# seeds the global torch RNG used later for loader shuffling and weight init.
SEED = 42
torch.manual_seed(SEED)

# Dataset instance
dataset = EEGTopDataset(X_top, Y_top)

# 75/25 random split. A dedicated generator keeps the split independent of any
# other consumption of the global RNG.
train_size = int(0.75 * len(dataset))
test_size = len(dataset) - train_size
train_ds, test_ds = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
)

# DataLoaders
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=64, shuffle=False)


In [None]:
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 1. Focal loss with inverse-frequency class weights
class FocalBalancedLoss(nn.Module):
    """Focal loss scaled by normalised inverse-frequency class weights.

    Args:
        class_counts: per-class sample counts indexed by class id, i.e.
            ``class_counts[c]`` is the count for label ``c``. Counts below 1 are
            clamped to 1 so classes absent from the data do not get an infinite
            (and, after normalisation, NaN) weight.
        alpha: global scale of the focal term.
        gamma: focusing exponent; larger values down-weight confidently
            correct samples.

    ``forward(pred, target)`` takes raw logits ``(batch, num_classes)`` and
    integer targets ``(batch,)`` and returns the mean weighted loss (a scalar).
    """

    def __init__(self, class_counts, alpha=0.25, gamma=2):
        super().__init__()
        counts = torch.as_tensor(class_counts, dtype=torch.float).clamp(min=1.0)
        weights = 1.0 / counts
        # Buffer rather than plain attribute so criterion.to(...) moves it too.
        self.register_buffer("weights", weights / weights.sum())
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        log_prob = F.log_softmax(pred, dim=-1)
        # gather() needs int64 indices of shape (batch, 1).
        index = target.long().unsqueeze(1)
        target_log_prob = log_prob.gather(1, index).squeeze(1)
        target_prob = target_log_prob.exp()

        focal_loss = self.alpha * (1.0 - target_prob) ** self.gamma * (-target_log_prob)

        # Follow the logits' device so the loss works wherever the model runs.
        weights = self.weights.to(pred.device)
        return (focal_loss * weights[index.squeeze(1)]).mean()

# 2. Model Architecture
class AdvancedEEGNet(nn.Module):
    """Convolutional classifier for multi-channel EEG trials.

    Pipeline: temporal conv -> depthwise conv spanning all electrode rows ->
    weight-normalised temporal conv + AttentionBlock -> wider weight-normalised
    temporal conv -> global average pool -> MLP head producing ``num_classes``
    logits.

    Args:
        num_classes: number of output logits.
        input_shape: ``(channels, timepoints)`` of one trial. Only ``channels``
            is used by the layers (as the height of the depthwise kernel);
            ``timepoints`` is stored but unused, since the adaptive pooling in
            block 3 makes the head independent of the input length.
    """
    def __init__(self, num_classes=80, input_shape=(62, 501)):
        super().__init__()
        self.num_classes = num_classes
        self.channels, self.timepoints = input_shape

        # Block 1: temporal filtering (kernel of 64 samples along the last axis),
        # then a depthwise conv (groups=64 -> 2 filters per input map) whose
        # kernel height equals self.channels, collapsing the electrode axis to
        # height 1 when the input has exactly self.channels rows. The pool then
        # downsamples the time axis by 4.
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 64, (1, 64), padding=(0, 32), bias=False),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 128, (self.channels, 1), groups=64, bias=False),
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(0.25)
        )

        # Block 2: weight-normalised temporal conv to 256 maps, followed by
        # AttentionBlock. AttentionBlock is defined later in this cell; that is
        # fine because the name is only looked up when __init__ runs.
        self.block2 = nn.Sequential(
            param.weight_norm(nn.Conv2d(128, 256, (1, 32), padding=(0, 16), bias=False)),
            nn.BatchNorm2d(256),
            nn.GELU(),
            AttentionBlock(256),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(0.4)
        )

        # Block 3: 512 maps, then AdaptiveAvgPool2d((1, 1)) reduces each map to
        # one value, giving a 512-dim feature regardless of input length.
        self.block3 = nn.Sequential(
            param.weight_norm(nn.Conv2d(256, 512, (1, 16), padding=(0, 8), bias=False)),
            nn.BatchNorm2d(512),
            nn.GELU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Dropout(0.5)
        )

        # Classifier: 512 -> 2048 -> 1024 -> num_classes raw logits (no softmax
        # here; the output is returned as-is).
        self.classifier = nn.Sequential(
        nn.Linear(512, 2048),
        nn.GELU(),
        nn.LayerNorm(2048),
        nn.Dropout(0.4),
        nn.Linear(2048, 1024),
        nn.GELU(),
        nn.Dropout(0.3),
        nn.Linear(1024, num_classes)
    )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Assumes ``x`` is (batch, channels, timepoints) -- TODO confirm against
        the dataset; a singleton feature-map axis is inserted for Conv2d.
        """
        x = x.unsqueeze(1)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = x.flatten(start_dim=1)  # (batch, 512, 1, 1) -> (batch, 512)
        return self.classifier(x)

# 3. Attention Block
class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.att = nn.Sequential(
            nn.Conv2d(channels, channels//8, 1),
            nn.GELU(),
            nn.Conv2d(channels//8, channels, 1),
            nn.Sigmoid()
        )
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv = param.weight_norm(self.conv)


    def forward(self, x):
        att = self.att(x)
        return x * att + self.conv(x)

# 4. Evaluation Function
def evaluate(model, loader, criterion):
    """Score `model` on `loader` without updating it.

    Returns a ``(mean_loss, accuracy)`` pair. The loss is averaged per sample:
    each batch's criterion value is weighted by its batch size.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_x, batch_y in tqdm(loader, desc="Evaluating"):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device).long()

            # Mixed precision only when CUDA is available; disabled otherwise.
            with torch.amp.autocast(device_type='cuda', enabled=torch.cuda.is_available()):
                logits = model(batch_x)
                batch_loss = criterion(logits, batch_y)

            predicted = torch.max(logits, 1).indices
            n_correct += (predicted == batch_y).sum().item()
            loss_sum += batch_loss.item() * batch_x.size(0)
            n_seen += batch_y.size(0)

    return loss_sum / n_seen, n_correct / n_seen

# 5. Training Function
def train(model, loader, optimizer, criterion, scaler, scheduler=None):
    """Run one training epoch with mixed precision and gradient-norm clipping.

    Args:
        model: module to train (put into train mode here).
        loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, int64 targets)``.
        scaler: ``torch.amp.GradScaler`` used for loss scaling.
        scheduler: optional LR scheduler stepped once at the end of the epoch.
            Pass None when the caller steps its schedulers itself.

    Returns:
        ``(mean_loss, accuracy, grad_norm)``. ``grad_norm`` is the unscaled total
        gradient norm (before clipping) of the last batch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    grad_norm = None
    use_amp = torch.cuda.is_available()

    for inputs, labels in tqdm(loader, desc="Training"):
        inputs = inputs.to(device)
        labels = labels.to(device).long()

        optimizer.zero_grad()

        with torch.amp.autocast(device_type='cuda', enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale at this point. Unscale
        # them first so the 1.0 threshold applies to the true gradients; otherwise
        # clipping shrinks every real update by the scale factor. (No-op when the
        # scaler is disabled, e.g. on CPU.)
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        running_loss += loss.item() * inputs.size(0)
        total += labels.size(0)

    if total == 0:
        raise ValueError("train() received a loader that yielded no batches")

    if scheduler is not None:
        scheduler.step()
    return running_loss / total, correct / total, grad_norm.item()

# 6. Main Training Loop
def _count_labels(loader, num_classes):
    """Return an int64 tensor of per-class sample counts, indexed by class id."""
    counts = torch.zeros(num_classes, dtype=torch.long)
    for _, labels in loader:
        batch = torch.as_tensor(labels).long().flatten().cpu()
        if batch.numel() and (batch.min() < 0 or batch.max() >= num_classes):
            raise ValueError(
                f"labels must lie in [0, {num_classes}); got range "
                f"[{batch.min().item()}, {batch.max().item()}]"
            )
        counts += torch.bincount(batch, minlength=num_classes)
    return counts


def main(train_loader, test_loader, num_classes=80):
    """Train AdvancedEEGNet with focal loss, cosine warm restarts, SWA and early stopping.

    The LR follows CosineAnnealingWarmRestarts until epoch ``swa_start``, then
    SWALR while weights are averaged into an AveragedModel. The base model with
    the best test accuracy is saved to ``best_modelre.pth``. The averaged model
    gets its BatchNorm statistics recomputed, is saved to ``best_model_swa.pth``
    and is evaluated on ``test_loader``.

    Args:
        train_loader: loader of ``(inputs, labels)`` training batches.
        test_loader: loader of ``(inputs, labels)`` evaluation batches.
        num_classes: number of classes; labels must lie in ``[0, num_classes)``.

    Returns:
        The SWA ``AveragedModel``. If training stops before SWA begins, it holds
        the best checkpoint's weights instead of the untrained initial weights.
    """
    # Counts must be ordered by class id and cover every class. Absent classes
    # are counted as 1 so their inverse-frequency weight stays finite.
    class_counts = _count_labels(train_loader, num_classes).clamp(min=1).tolist()

    # Initialize model and components
    model = AdvancedEEGNet(num_classes=num_classes).to(device)
    criterion = FocalBalancedLoss(class_counts)
    optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=30, T_mult=2, eta_min=1e-6)
    scaler = torch.amp.GradScaler()
    swa_start = 80
    swa_model = AveragedModel(model)
    swa_scheduler = SWALR(optimizer, swa_lr=1e-4)
    swa_updates = 0

    best_acc = 0.0
    best_state = None
    patience = 15
    no_improve = 0

    print(f"Training model with {sum(p.numel() for p in model.parameters()):,} parameters")

    for epoch in range(1, 201):
        # Exactly one scheduler step per epoch happens below, so train() gets
        # no scheduler (passing it made the cosine schedule advance twice).
        train_loss, train_acc, grad_norm = train(model, train_loader, optimizer, criterion, scaler)

        if epoch >= swa_start:
            swa_model.update_parameters(model)
            swa_updates += 1
            swa_scheduler.step()
        else:
            scheduler.step()

        test_loss, test_acc = evaluate(model, test_loader, criterion)

        print(f"\nEpoch {epoch}:")
        print(f"Train Acc: {train_acc:.4f} | Loss: {train_loss:.4f} | Grad Norm: {grad_norm:.2f}")
        print(f"Test Acc:  {test_acc:.4f} | Loss: {test_loss:.4f}")
        print(f"LR: {optimizer.param_groups[0]['lr']:.2e}")

        if test_acc > best_acc:
            best_acc = test_acc
            no_improve = 0
            # In-memory copy for the SWA fallback below; the file is for later use.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(model.state_dict(), 'best_modelre.pth')
            print(" Saved best model!")
        else:
            no_improve += 1

        if no_improve >= patience:
            print(f"  Early stopping at epoch {epoch}")
            break

    if swa_updates == 0:
        # Early stopping fired before swa_start, so the averaged model still
        # holds the random initial weights. Seed it with the best weights seen.
        print("SWA phase never started; using best checkpoint as the averaged model")
        if best_state is not None:
            model.load_state_dict(best_state)
        swa_model.update_parameters(model)

    print("\nUpdating batch norm stats with SWA model")
    swa_model.to(device)
    # update_bn only moves batches to a device when one is given; without it,
    # CPU batches would be fed to a CUDA model.
    update_bn(train_loader, swa_model, device=device)
    torch.save(swa_model.state_dict(), 'best_model_swa.pth')

    final_test_loss, final_test_acc = evaluate(swa_model, test_loader, criterion)
    print(f"\n Final SWA Test Accuracy: {final_test_acc*100:.2f}%")

    return swa_model

In [None]:
# Train with the loaders built above; main() returns the SWA-averaged model.
model = main(train_loader, test_loader)