# import

In [1]:
import mne
import numpy as np
from scipy.signal import firwin, filtfilt
import matplotlib.pyplot as plt
import logging
import warnings
import torch
import random
warnings.filterwarnings("ignore", category=RuntimeWarning, module="mne")
# Completely silence MNE-Python output
mne.set_log_level('WARNING')  # or 'ERROR' for even less output
logging.getLogger('mne').setLevel(logging.WARNING)
#mne.set_log_level('debug')

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# load data

In [2]:
def load_eeg(path, sfreq = None):
    raw = mne.io.read_raw_gdf(path, preload=True)

    # Step 2: Get the sampling frequency
    if not sfreq:
        sfreq = raw.info['sfreq']  # Hz

    # Step 3: Define FIR filter parameters
    low_cutoff = 1.0   # Hz
    high_cutoff = 30.0 # Hz
    filter_order = 177 # Must be odd for linear-phase FIR

    nyquist = 0.5 * sfreq


    fir_coeffs = firwin(
        numtaps=filter_order,
        cutoff=[low_cutoff / nyquist, high_cutoff / nyquist],
        pass_zero=False,
        window='blackman'
    )
    eeg_data = raw.get_data()
    filtered_data = filtfilt(fir_coeffs, 1.0, eeg_data, axis=1)

    new_raw = mne.io.RawArray(filtered_data, raw.info.copy())
    annotations = raw.annotations 
    new_raw.set_annotations(annotations)

    return new_raw

# preprocess

In [3]:
def preprocess(path, test=False, event_id= None):

    raw = load_eeg(path)
    eeg_channels = raw.ch_names[:22]
    raw.pick(eeg_channels)
    
    if event_id is not None:
        events, events_id = mne.events_from_annotations(raw, event_id=event_id)
    else:
        events, events_id = mne.events_from_annotations(raw)

        #print(events_id)
    #print(events_id)
    epochs = mne.Epochs(
        raw,
        events=events,  
        tmin=0,     
        tmax=6.0,
        event_id=events_id,
        baseline=None,
        preload=True
    )

    labels = epochs.events[:, 2]
    #print(labels)
    data = epochs.get_data()
    
    return {
        'epochs': data,   
        'labels': labels
    }

#a = preprocess("E:/LOKI/BCI-IV/A01T.gdf")
# print(a['epochs'].shape)


# Data Loader

In [4]:
import glob
from torch.utils.data import Dataset, DataLoader
import scipy.io as sio
import os
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from scipy.ndimage import gaussian_filter1d

def gauss_smooth(inputs, device, smooth_kernel_std=2, smooth_kernel_size=100,  padding='same'):

    #print(inputs.shape)
    inputs = inputs.transpose(0, 2, 1)
    # Get Gaussian kernel
    inp = np.zeros(smooth_kernel_size, dtype=np.float32)
    inp[smooth_kernel_size // 2] = 1
    gaussKernel = gaussian_filter1d(inp, smooth_kernel_std)
    validIdx = np.argwhere(gaussKernel > 0.01)
    gaussKernel = gaussKernel[validIdx]
    gaussKernel = np.squeeze(gaussKernel / np.sum(gaussKernel))

    # Convert to tensor
    gaussKernel = torch.tensor(gaussKernel, dtype=torch.float32, device=device)
    gaussKernel = gaussKernel.view(1, 1, -1)  # [1, 1, kernel_size]

    # Prepare convolution
    B, T, C = inputs.shape
    inputs = inputs.transpose(0, 2, 1)  # [B, C, T]
    inputs = torch.tensor(inputs, dtype=torch.float32, device=device)
    
    gaussKernel = gaussKernel.repeat(C, 1, 1)  # [C, 1, kernel_size]

    # Perform convolution
    smoothed = F.conv1d(inputs, gaussKernel, padding=padding, groups=C)
    return smoothed  # [B, T, C]

# 
import pandas as pd

class BCI4_2a_Dataset(Dataset):

    def __init__(self, data_dir, subjects=[1], transform=None, target_transform=None):
        self.data_dir = data_dir
        self.subjects = subjects
        self.transform = transform
        self.target_transform = target_transform
        

        self.annotations = pd.read_csv('/teamspace/studios/shared-amethyst-w577/PME4_dataset_configs.csv')
        self.data, self.labels = self._load_data()

        self.labels = self.labels 
        
    def _load_data(self):
        all_data = []
        all_labels = []

        for subj in self.subjects:
            subj_str = f"s{subj:02d}"
            subj_annots = self.annotations[self.annotations['subject'] == subj]

            for _, row in subj_annots.iterrows():
                trial = row['trial']
                label = row['emotion_num']  # numeric label

                eeg_path = os.path.join(self.data_dir, subj_str, f"t{trial:03d}", f"{subj_str}_t{trial:03d}_processed_eeg_1kHz.npy")
                if not os.path.exists(eeg_path):
                    print(f"Warning: {eeg_path} not found!")
                    continue

                eeg_data = np.load(eeg_path)  # shape: (channels, timepoints)
                # Normalize the trial
                means = eeg_data.mean(axis=1, keepdims=True)
                stds = eeg_data.std(axis=1, keepdims=True)
                eeg_data = (eeg_data - means) / (stds + 1e-8)

                # Append original data
                all_data.append(eeg_data)
                all_labels.append(label)

                eeg_data = np.expand_dims(eeg_data, axis=0)
                eeg_smooth = gauss_smooth(eeg_data, device='cpu')  # output shape: [1, channels, timepoints]
                eeg_smooth = eeg_smooth.squeeze(0).numpy()  # back to (channels, timepoints)

                all_data.append(eeg_smooth)  # append augmented data
                all_labels.append(label)      # same label for augmented data

        # Convert to arrays
        all_data = np.stack(all_data, axis=0)  # shape: (num_samples, channels, timepoints)

        # Map labels
        label_map = {2:0, 4:1, 6:2, 8:3, 10:4, 12:5, 14:6}
        all_labels = np.array([label_map[l] for l in all_labels])

        return all_data, all_labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx]
        
        x = torch.from_numpy(x).float()  # shape: (1, timepoints, channels)
        
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
            
        return x, y

def get_data_loaders(data_dir,all_subjects=[1], test_subject=[1,2], batch_size=32, val_split=0.2, random_state=42):
    # Get all subjects except the test subject for training/validation
    
    #train_val_subjects = [s for s in all_subjects if s not in test_subject]
    
    # Create datasets
    train_val_dataset = BCI4_2a_Dataset(data_dir, subjects=all_subjects)
    test_dataset = BCI4_2a_Dataset(data_dir, subjects=test_subject)
    single_trial_shape = test_dataset.data[0].shape
    print(f"Train+Val dataset size: {len(train_val_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")
    
    # Split train_val into train and validation
    train_idx, val_idx = train_test_split(
        range(len(train_val_dataset)),
        test_size=val_split,
        random_state=random_state,
        stratify=train_val_dataset.labels
    )
    
    train_dataset = torch.utils.data.Subset(train_val_dataset, train_idx)
    val_dataset = torch.utils.data.Subset(train_val_dataset, val_idx)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Test batches: {len(test_loader)}")
    print(f"Single Trial Shape: {single_trial_shape}")
    return train_loader, val_loader, test_loader, single_trial_shape



# Model

In [5]:
import torch
import torch.nn as nn
import math

class ChannelAttention(nn.Module):
    def __init__(self, num_channels, reduction_ratio=4, dropout=0.1):
        super(ChannelAttention, self).__init__()
        self.num_channels = num_channels
        self.reduction_ratio = reduction_ratio

        self.mlp = nn.Sequential(
            nn.Linear(num_channels, num_channels // reduction_ratio),
            nn.ReLU(),
            nn.Dropout(dropout),  # Added dropout
            nn.Linear(num_channels // reduction_ratio, num_channels)
        )

    def forward(self, x):
        channel_avg = torch.mean(x, dim=2, keepdim=True)  
        channel_max, _ = torch.max(x, dim=2, keepdim=True) 
        combined = channel_avg + channel_max  
        combined = combined.squeeze(2)
        attention = self.mlp(combined)
        attention = torch.sigmoid(attention).unsqueeze(2)
        attended_x = x * attention
        return attended_x, attention.squeeze(2)  # Return proper attention scores


class LearnableSTFT(nn.Module):
    def __init__(self, window_size, hop_size, learnable_window=True, dropout=0.05):
        super(LearnableSTFT, self).__init__()
        
        self.window_size = window_size
        self.hop_size = hop_size
        self.dft_size = window_size
        self.learnable_window = learnable_window
        
        # Initialize with Hamming window
        initial_window = 0.54 - 0.46 * torch.cos(
            2 * math.pi * torch.arange(window_size, dtype=torch.float32) / (window_size - 1)
        )
        
        if learnable_window:
            self.window = nn.Parameter(initial_window)
        else:
            # Fixed window reduces overfitting
            self.register_buffer('window', initial_window)
        
        # Batch normalization after STFT
        self.batch_norm = nn.BatchNorm2d(1)
        self.dropout = nn.Dropout2d(dropout)
        
        dft_matrix = self._create_dft_matrix(self.dft_size, self.window_size)
        self.register_buffer('dft_matrix', dft_matrix)

    def _create_dft_matrix(self, dft_size, window_size):
        k = torch.arange(dft_size).unsqueeze(1)
        n = torch.arange(window_size)
        angle = -2 * math.pi * k * n / dft_size
        dft_matrix = torch.complex(torch.cos(angle), torch.sin(angle))
        return dft_matrix

    def forward(self, signal):
        if signal.dim() == 1:
            signal = signal.unsqueeze(0).unsqueeze(0)
        elif signal.dim() == 2:
            signal = signal.unsqueeze(1)
            
        batch_size, num_channels, num_samples = signal.shape
        signal_reshaped = signal.reshape(batch_size * num_channels, num_samples)
        
        frames = signal_reshaped.unfold(dimension=1, size=self.window_size, step=self.hop_size)
        num_frames_unfolded = frames.shape[1]
        expected_num_frames = int(math.ceil((num_samples - self.window_size) / self.hop_size)) + 1
        
        if num_frames_unfolded < expected_num_frames:
            padding_amount = (expected_num_frames - 1) * self.hop_size + self.window_size - num_samples
            padded_signal = torch.nn.functional.pad(signal_reshaped, (0, padding_amount))
            frames = padded_signal.unfold(1, self.window_size, self.hop_size)
        
        windowed_frames = frames * self.window
        BC, F, W = windowed_frames.shape
        windowed_frames_reshaped = windowed_frames.reshape(BC * F, W)
        
        windowed_frames_complex = windowed_frames_reshaped.to(self.dft_matrix.dtype)
        stft_result_reshaped = self.dft_matrix @ windowed_frames_complex.T
        stft_result = stft_result_reshaped.T.reshape(batch_size, num_channels, F, self.dft_size)
        
        # Apply magnitude and normalization
        stft_magnitude = torch.abs(stft_result)
        
        # Reshape for batch norm: (B*C, 1, F, freq_bins)
        stft_normalized = stft_magnitude.reshape(batch_size * num_channels, 1, F, self.dft_size)
        stft_normalized = self.batch_norm(stft_normalized)
        stft_normalized = self.dropout(stft_normalized)
        stft_normalized = stft_normalized.reshape(batch_size, num_channels, F, self.dft_size)
        
        return stft_normalized


class Attention4D(nn.Module):
    def __init__(self, in_channels, time_frames, freq_bins, d_model=128, n_heads=4, 
                 d_ff=256, dropout=0.2, num_layers=1):
        super().__init__()
        
        self.d_model = d_model
        self.time_frames = time_frames
        self.num_layers = num_layers
        
        in_features = in_channels * freq_bins
        
        # Project input to d_model with layer norm
        self.projection = nn.Sequential(
            nn.Linear(in_features, d_model),
            nn.LayerNorm(d_model),
            nn.Dropout(dropout)
        )
        
        # Learnable positional encoding with smaller init
        self.positional_encoding = nn.Parameter(
            torch.randn(1, time_frames, d_model) * 0.02
        )
        
        # Stack multiple transformer layers for better representation
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.final_norm = nn.LayerNorm(d_model)
        
    def forward(self, x):
        batch_size = x.shape[0]
        
        # Reshape: (B, C, T, F) -> (B, T, C*F)
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(batch_size, self.time_frames, -1)
        
        # Project and add positional encoding
        x = self.projection(x)
        x = x + self.positional_encoding
        
        # Store attention weights from all layers
        all_attention_weights = []
        
        # Pass through transformer layers
        for layer in self.transformer_layers:
            x, attn_weights = layer(x)
            all_attention_weights.append(attn_weights)
        
        x = self.final_norm(x)
        
        return x, all_attention_weights


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        
        self.layernorm1 = nn.LayerNorm(d_model)
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model, 
            num_heads=n_heads, 
            dropout=dropout, 
            batch_first=True
        )
        self.dropout1 = nn.Dropout(dropout)
        
        self.layernorm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),  # GELU often works better than ReLU
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x):
        # Pre-norm architecture (more stable)
        x_norm = self.layernorm1(x)
        attn_output, attn_weights = self.attention(x_norm, x_norm, x_norm, 
                                                    need_weights=True, 
                                                    average_attn_weights=True)
        x = x + self.dropout1(attn_output)
        
        x_norm = self.layernorm2(x)
        ff_output = self.feed_forward(x_norm)
        x = x + self.dropout2(ff_output)
        
        return x, attn_weights


class Classifier(nn.Module):
    def __init__(self, in_features, num_classes, dropout=0.3):
        super().__init__()
        
        self.pooling = nn.AdaptiveAvgPool1d(1)
        
        # Add intermediate layer for better capacity
        self.classifier = nn.Sequential(
            nn.Linear(in_features, in_features // 2),
            nn.LayerNorm(in_features // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_features // 2, num_classes)
        )
    
    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B, d_model, T)
        x = self.pooling(x).squeeze(2)  # (B, d_model)
        output = self.classifier(x)  # (B, num_classes)
        return output


class EEGClassifier(nn.Module):
    def __init__(self, num_samples, num_classes=4, window_percent=0.25, 
                 overlap_percent=0.10, d_model=128, n_heads=8, 
                 transformer_layers=1, dropout=0.2, learnable_window=True):

        super(EEGClassifier, self).__init__()
        
        window_size = int(num_samples * window_percent)
        hop_size = int(num_samples * overlap_percent)
        
        self.window_size = max(window_size, 30)
        self.hop_size = max(hop_size, 10)
        
        self.channel_attn = ChannelAttention(num_channels=8, dropout=dropout)
        
        self.learnable_stft = LearnableSTFT(
            window_size=self.window_size,
            hop_size=self.hop_size,
            learnable_window=learnable_window,
            dropout=dropout * 0.5
        )
        
        time_frames = int(math.ceil((num_samples - self.window_size) / self.hop_size)) + 1
        
        self.attention_module = Attention4D(
            in_channels=8,
            time_frames=time_frames,
            freq_bins=self.window_size,
            d_model=d_model,
            n_heads=n_heads,
            num_layers=transformer_layers,
            dropout=dropout
        )
        
        self.classifier = Classifier(
            in_features=d_model, 
            num_classes=num_classes,
            dropout=dropout
        )

    def forward(self, x, return_attention=False):

        # Channel attention
        x, channel_attn_scores = self.channel_attn(x)
        
        # STFT
        x = self.learnable_stft(x)
        
        # Transformer attention
        x, transformer_attn_weights = self.attention_module(x)
        
        # Classification
        logits = self.classifier(x)
        
        if return_attention:
            attention_dict = {
                'channel_attention': channel_attn_scores,  # (B, 22)
                'transformer_attention': transformer_attn_weights,  # List of (B, T, T) for each layer
            }
            return logits, attention_dict
        
        return logits


class AttentionRegularizedLoss(nn.Module):

    def __init__(self, num_classes, alpha=0.01, beta=0.01):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.alpha = alpha  # Channel attention regularization weight
        self.beta = beta    # Transformer attention regularization weight
        
    def forward(self, logits, targets, attention_dict):
        # Classification loss
        ce = self.ce_loss(logits, targets)
        
        # Channel attention regularization (encourage diversity)
        
        # Transformer attention regularization (discourage extreme peakiness)
        transformer_attn = attention_dict['transformer_attention'][-1]  # Use last layer (B, T, T)
        # Variance is always non-negative (variance >= 0)
        attn_variance = torch.var(transformer_attn, dim=-1).mean()
        
        # Calculate total loss: CE + (alpha * negative_entropy) + (beta * variance)
        # Since alpha * negative_entropy is negative, it reduces the loss.
        total_loss = ce + self.beta * attn_variance
        
        # *** CRITICAL FIX: Ensure Total Loss is Non-Negative ***
        # If CE is near zero, the negative channel_reg term can cause total_loss < 0.
        # Clipping maintains optimization goal while ensuring mathematical stability.
        total_loss = torch.max(total_loss, torch.tensor(0.0, device=total_loss.device))
        
        return total_loss, {
            'ce_loss': ce.item(),
            'attn_variance': attn_variance.item()
        }




# test model

In [6]:
# if __name__ == "__main__":
#     # Model initialization
#     model = EEGClassifier(
#         num_samples=1000,
#         num_classes=4,
#         d_model=64,  # Reduced from 128
#         n_heads=4,   # Reduced from 8
#         transformer_layers=2,
#         dropout=0.3,
#         learnable_window=False  # Start with fixed window
#     )
    
#     # Loss function
#     criterion = AttentionRegularizedLoss(num_classes=4, alpha=0.01, beta=0.01)
    
#     # Dummy data
#     x = torch.randn(4, 22, 1000)  # (batch, channels, time)
#     y = torch.randint(0, 4, (4,))
    
#     # Forward pass with attention
#     logits, attention_dict = model(x, return_attention=True)
    
#     # Calculate loss
#     loss, loss_dict = criterion(logits, y, attention_dict)
    
#     print(f"Logits shape: {logits.shape}")
#     print(f"Channel attention shape: {attention_dict['channel_attention'].shape}")
#     print(f"Num transformer layers: {len(attention_dict['transformer_attention'])}")
#     print(f"Transformer attention shape: {attention_dict['transformer_attention'][0].shape}")
#     print(f"Total loss: {loss.item():.4f}")
#     print(f"Loss components: {loss_dict}")

# Train model

In [7]:
import time
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
import json
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

class EEGTrainer:
    def __init__(self, model, train_loader, val_loader, test_loader=None, config=None):

        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        
        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        
        # Default configuration
        self.config = {
            'lr': 1e-3,
            'weight_decay': 1e-4,
            'patience': 10,
            'min_lr': 1e-6,
            'epochs': 100,
            'save_dir': 'experiments',
            'experiment_name': f'exp_{time.strftime("%Y%m%d-%H%M%S")}',
            'save_attention_maps': True,
            'attention_map_freq': 5
        }
        
        # Update with user config if provided
        if config:
            self.config.update(config)
            
        # Create experiment directory
        self.exp_dir = os.path.join(self.config['save_dir'], self.config['experiment_name'])
        os.makedirs(self.exp_dir, exist_ok=True)
        
        # Initialize components
        self._init_components()
        
        # Save config
        self._save_config()
        
    def _init_components(self):
        """Initialize training components"""
        # Loss function (CrossEntropy + KLDiv for attention regularization)
        self.criterion = AttentionRegularizedLoss(num_classes=4, alpha=0.01, beta=0.01)#nn.CrossEntropyLoss()
        
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config['lr'],
            weight_decay=self.config['weight_decay']
        )
        
        # Learning rate scheduler
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=0.5,
            patience=self.config['patience']//2,
            min_lr=self.config['min_lr'],
            #verbose=True
        )
        
        self.best_val_acc = 0.0
        self.early_stop_counter = 0
        
        # Tensorboard writer
        self.writer = SummaryWriter(log_dir=self.exp_dir)
        
        # Attention maps directory
        if self.config['save_attention_maps']:
            self.attention_dir = os.path.join(self.exp_dir, 'attention_maps')
            os.makedirs(self.attention_dir, exist_ok=True)
    
    def _save_config(self):
        config_path = os.path.join(self.exp_dir, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config, f, indent=4)
    
    def _compute_metrics(self, outputs, labels):
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / labels.size(0)
        return accuracy
    
    def _log_metrics(self, phase, metrics, epoch):
        loss = metrics['loss']
        acc = metrics['accuracy']
        pre = metrics['precision']
        rec = metrics['recall']
        f1 = metrics['f1']
        
        # Console logging
        print(f"{phase.capitalize()} - Epoch: {epoch+1} | "
              f"Loss: {loss:.4f} | Acc: {acc:.2%} | Precision: {pre:.2%} | recall: {rec:.2%} | f1: {f1:.2%}")
        
        # Tensorboard logging
        self.writer.add_scalar(f'Loss/{phase}', loss, epoch)
        self.writer.add_scalar(f'Accuracy/{phase}', acc, epoch)
        
        
        # Log learning rate
        if phase == 'train':
            lr = self.optimizer.param_groups[0]['lr']
            self.writer.add_scalar('LR', lr, epoch)
    
    def _save_checkpoint(self, epoch, is_best=False):
        """Save model checkpoint"""
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc
        }
        
        # Save regular checkpoint
        checkpoint_path = os.path.join(self.exp_dir, f'checkpoint_epoch_{epoch}.pth')
        torch.save(state, checkpoint_path)
        
        # Save best model
        if is_best:
            best_path = os.path.join(self.exp_dir, 'best_model.pth')
            torch.save(state, best_path)
    

    
    def train_epoch(self, epoch):
        """Train for one epoch"""
        self.model.train()
        running_loss = 0.0
        total_samples = 0

        # Store all predictions and labels for epoch-wise metrics
        all_preds = []
        all_labels = []

        for inputs, labels in self.train_loader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # Zero gradients
            self.optimizer.zero_grad()
            
            # Forward pass
            outputs, attention_dict = self.model(inputs, return_attention=True)
            loss ,_ = self.criterion(outputs, labels,attention_dict)
            
            # Backward pass and optimize
            loss.backward()
            self.optimizer.step()
            
            # ----- accumulate loss -----
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            total_samples += batch_size

            # ----- accumulate preds + labels for metrics -----
            preds = torch.argmax(outputs, dim=1)          # [B]
            all_preds.append(preds.detach().cpu())
            all_labels.append(labels.detach().cpu())

        # ---- end of epoch: stack everything and compute metrics ----
        all_preds = torch.cat(all_preds).numpy()
        all_labels = torch.cat(all_labels).numpy()

        epoch_loss = running_loss / total_samples

        # accuracy
        epoch_acc = accuracy_score(all_labels, all_preds)

        # precision, recall, f1
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels,
            all_preds,
            average='weighted',      # change to 'macro' / 'binary' if needed
            zero_division=0          # avoid NaN if a class is missing in preds
        )

        return {
            'loss': epoch_loss,
            'accuracy': epoch_acc,
            'precision': precision,
            'recall': recall,
            'f1': f1,
        }


    
    def validate_epoch(self, epoch):
        """Validate for one epoch"""
        self.model.eval()
        running_loss = 0.0
        total_samples = 0

        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for inputs, labels in self.val_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
     
                # Forward pass
                #outputs = self.model(inputs)
                outputs, attention_dict = self.model(inputs, return_attention=True)
                loss, _ = self.criterion(outputs, labels, attention_dict)
                #loss = self.criterion(outputs, labels)
                
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                total_samples += batch_size

                # ----- accumulate preds + labels for metrics -----
                preds = torch.argmax(outputs, dim=1)          # [B]
                all_preds.append(preds.detach().cpu())
                all_labels.append(labels.detach().cpu())

            # ---- end of epoch: stack everything and compute metrics ----
            all_preds = torch.cat(all_preds).numpy()
            all_labels = torch.cat(all_labels).numpy()

            epoch_loss = running_loss / total_samples

            # accuracy
            epoch_acc = accuracy_score(all_labels, all_preds)

            # precision, recall, f1
            precision, recall, f1, _ = precision_recall_fscore_support(
                all_labels,
                all_preds,
                average='weighted',      # change to 'macro' / 'binary' if needed
                zero_division=0          # avoid NaN if a class is missing in preds
            )

            return {
                'loss': epoch_loss,
                'accuracy': epoch_acc,
                'precision': precision,
                'recall': recall,
                'f1': f1,
            }
    
    def train(self):
        """Main training loop"""
        start_time = time.time()
        
        for epoch in range(self.config['epochs']):
            # Train and validate
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.validate_epoch(epoch)
            
            # Log metrics
            self._log_metrics('train', train_metrics, epoch)
            self._log_metrics('val', val_metrics, epoch)
            
            # Save attention maps periodically
            
            # Step scheduler
            self.scheduler.step(val_metrics['accuracy'])
            
            # Check for best model
            if val_metrics['accuracy'] > self.best_val_acc:
                self.best_val_acc = val_metrics['accuracy']
                self._save_checkpoint(epoch, is_best=True)
                self.early_stop_counter = 0
            else:
                self.early_stop_counter += 1
            
            # Save checkpoint periodically
            if epoch % 10 == 0:
                self._save_checkpoint(epoch)
            
            # Early stopping
            # if self.early_stop_counter >= self.config['patience']:
            #     print(f"Early stopping at epoch {epoch+1}")
            #     break
        
        # Training complete
        training_time = time.time() - start_time
        print(f"Training completed in {training_time//60:.0f}m {training_time%60:.0f}s")
        print(f"Best validation accuracy: {self.best_val_acc:.2%}")
        
        # Test if test loader provided
        if self.test_loader:
            test_acc = self.test()
            print(f"Test accuracy: {test_acc:.2%}")
        
        # Close tensorboard writer
        self.writer.close()
        
        return self.best_val_acc
    
    def test(self):
        """Evaluate on test set"""
        self.model.eval()
        running_acc = 0.0
        total_samples = 0
        
        # Load best model
        best_path = os.path.join(self.exp_dir, 'best_model.pth')
        if os.path.exists(best_path):
            checkpoint = torch.load(best_path)
            self.model.load_state_dict(checkpoint['state_dict'])
            print(f"Loaded best model from epoch {checkpoint['epoch']}")
        
        all_preds=[]
        all_labels=[]
        with torch.no_grad():
            for inputs, labels in self.test_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                
                outputs = self.model(inputs)
                batch_size = inputs.size(0)
                total_samples += batch_size

                # ----- accumulate preds + labels for metrics -----
                preds = torch.argmax(outputs, dim=1)          # [B]
                all_preds.append(preds.detach().cpu())
                all_labels.append(labels.detach().cpu())

            # ---- end of epoch: stack everything and compute metrics ----
            all_preds = torch.cat(all_preds).numpy()
            all_labels = torch.cat(all_labels).numpy()

            epoch_acc = accuracy_score(all_labels, all_preds)

            # precision, recall, f1
            precision, recall, f1, _ = precision_recall_fscore_support(
                all_labels,
                all_preds,
                average='weighted',      # change to 'macro' / 'binary' if needed
                zero_division=0          # avoid NaN if a class is missing in preds
            )
        
        test_acc = epoch_acc
        self.writer.add_scalar('Accuracy/test', test_acc)
        print("acc: ",test_acc)
        print("precision:", precision)
        print("recall:", recall)
        print("F1 :", f1)
        return epoch_acc







2025-11-29 10:46:15.696764: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-11-29 10:46:15.744081: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-11-29 10:46:15.744125: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-11-29 10:46:15.744155: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-11-29 10:46:15.753844: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: A

# Main Code

In [8]:
if __name__ == "__main__":
    # Example configuration
    config = {
        'lr': 3e-4,
        'weight_decay': 1e-5,
        'patience': 15,
        'epochs': 100,
        'experiment_name': 'eeg_attention_experiment',
        'save_attention_maps': True
    }
    
    
    
    train_loader, val_loader, test_loader, trail_shape = get_data_loaders(data_dir="/teamspace/studios/shared-amethyst-w577/data2", all_subjects = [1,2,3,4,5,6,7,8,9,10,11])
    model = EEGClassifier(num_samples = trail_shape[-1], num_classes=7)
    # Create trainer
    trainer = EEGTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        config=config
    )
    
    # Start training
    best_val_acc = trainer.train()

Train+Val dataset size: 7658
Test dataset size: 1398
Train batches: 192
Val batches: 48
Test batches: 44
Single Trial Shape: (8, 5000)
Train - Epoch: 1 | Loss: 1.9614 | Acc: 14.32% | Precision: 14.21% | recall: 14.32% | f1: 14.18%
Val - Epoch: 1 | Loss: 1.9487 | Acc: 15.86% | Precision: 16.68% | recall: 15.86% | f1: 10.03%
Train - Epoch: 2 | Loss: 1.9556 | Acc: 13.86% | Precision: 13.78% | recall: 13.86% | f1: 13.64%
Val - Epoch: 2 | Loss: 1.9465 | Acc: 14.30% | Precision: 2.04% | recall: 14.30% | f1: 3.58%
Train - Epoch: 3 | Loss: 1.9500 | Acc: 14.48% | Precision: 14.15% | recall: 14.48% | f1: 13.67%
Val - Epoch: 3 | Loss: 1.9455 | Acc: 15.21% | Precision: 7.39% | recall: 15.21% | f1: 7.75%
Train - Epoch: 4 | Loss: 1.9485 | Acc: 15.03% | Precision: 14.92% | recall: 15.03% | f1: 14.71%
Val - Epoch: 4 | Loss: 1.9438 | Acc: 15.01% | Precision: 16.96% | recall: 15.01% | f1: 6.08%
Train - Epoch: 5 | Loss: 1.9448 | Acc: 16.18% | Precision: 15.83% | recall: 16.18% | f1: 15.58%
Val - Epoch: 5

  checkpoint = torch.load(best_path)


acc:  0.969241773962804
precision: 0.9693417285448676
recall: 0.969241773962804
F1 : 0.9692434215433547
Test accuracy: 96.92%
