In [1]:
import mne
import numpy as np
from scipy.signal import firwin, filtfilt
import matplotlib.pyplot as plt
import logging
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning, module="mne")
# Completely silence MNE-Python output
mne.set_log_level('WARNING')  # or 'ERROR' for even less output
logging.getLogger('mne').setLevel(logging.WARNING)

In [2]:
def load_eeg(path):
    raw = mne.io.read_raw_gdf(path, preload=True)

    # Step 2: Get the sampling frequency
    sfreq = raw.info['sfreq']  # Hz

    # Step 3: Define FIR filter parameters
    low_cutoff = 1.0   # Hz
    high_cutoff = 30.0 # Hz
    filter_order = 177 # Must be odd for linear-phase FIR

    nyquist = 0.5 * sfreq


    fir_coeffs = firwin(
        numtaps=filter_order,
        cutoff=[low_cutoff / nyquist, high_cutoff / nyquist],
        pass_zero=False,
        window='blackman'
    )
    eeg_data = raw.get_data()
    filtered_data = filtfilt(fir_coeffs, 1.0, eeg_data, axis=1)

    new_raw = mne.io.RawArray(filtered_data, raw.info.copy())
    annotations = raw.annotations 
    new_raw.set_annotations(annotations)

    return new_raw

In [3]:
def preprocess(path, test=False):
    raw = load_eeg(path)
    eeg_channels = raw.ch_names[:22]
    raw.pick(eeg_channels)
    
    if not test:
        events, events_id = mne.events_from_annotations(raw, event_id= {'769': 0,'770': 1,'771': 2,'772': 3})
    else:
        events, events_id = mne.events_from_annotations(raw, event_id= {'768':6})
        #print(events_id)
    #print(events_id)
    epochs = mne.Epochs(
        raw,
        events=events,  
        tmin=0,     
        tmax=4.0,
        event_id=events_id,
        baseline=None,
        preload=True
    )

    labels = epochs.events[:, 2]
    #print(labels)
    data = epochs.get_data()
    
    return {
        'epochs': data,   
        'labels': labels
    }

a = preprocess('/teamspace/studios/this_studio/EEG_REC/A01T.gdf')
# print(a['epochs'].shape)


#start of trials (768)

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import scipy.io as sio
import os
from sklearn.model_selection import train_test_split

class BCI4_2a_Dataset(Dataset):

    def __init__(self, data_dir, subjects=[1], train=True, transform=None, target_transform=None, test=False):
        self.data_dir = data_dir
        self.subjects = subjects
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.test=test
        self.data, self.labels = self._load_data()
        
        self.labels = self.labels 
        
    def _load_data(self):
        all_data = []
        all_labels = []
        
        for subject in self.subjects:
            if self.train:
                filename = f'A{subject:02d}T.gdf'
            else:
                filename = f'A{subject:02d}E.gdf'
                
            filepath = os.path.join(self.data_dir, filename)
            data_ = preprocess(filepath, self.test)
            
            data = data_['epochs']
            labels = data_['labels']
           
            data = np.transpose(data, (0, 2, 1))
            
            all_data.append(data)
            all_labels.append(labels)
            
        all_data = np.concatenate(all_data, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)
        #print(all_labels.shape)
        def normalize_eeg(trial_data):
            """trial_data shape: (channels, timepoints)"""
            means = trial_data.mean(axis=1, keepdims=True)
            stds = trial_data.std(axis=1, keepdims=True)
            return (trial_data - means) / (stds + 1e-8)

        all_data_n = normalize_eeg(all_data)

        return all_data_n, all_labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx]
        
        x = torch.from_numpy(x).float().unsqueeze(0)  # shape: (1, timepoints, channels)
        
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
            
        return x, y

def get_data_loaders(data_dir, subjects=[1], batch_size=32, val_split=0.2, test_split=0.2, random_state=42):

    full_train_dataset = BCI4_2a_Dataset(data_dir, subjects=subjects, train=True)
    
    test_dataset = BCI4_2a_Dataset(data_dir, subjects=subjects, train=False, test=True )
    
    train_idx, val_idx = train_test_split(
        range(len(full_train_dataset)),
        test_size=val_split,
        random_state=random_state,
        stratify=full_train_dataset.labels
    )
    
    train_dataset = torch.utils.data.Subset(full_train_dataset, train_idx)
    val_dataset = torch.utils.data.Subset(full_train_dataset, val_idx)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader, test_loader

# if __name__ == "__main__":
#     DATA_DIR = "/teamspace/studios/this_studio/EEG_REC"
    
#     train_loader, val_loader, test_loader = get_data_loaders(
#         data_dir=DATA_DIR,
#         subjects=[1,2,3,5,6,7,8,9],
#         batch_size=32
#     )
    
#     print("Dataset sizes:")
#     print(f"Training samples: {len(train_loader.dataset)}")
#     print(f"Validation samples: {len(val_loader.dataset)}")
#     print(f"Test samples: {len(test_loader.dataset)}")
    
#     x, y = next(iter(train_loader))
#     print("\nBatch shape:")
#     print(f"Input shape: {x.shape}")  # Should be (batch_size, 1, timepoints, channels)
#     print(f"Labels shape: {y.shape}")

In [5]:
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    """Channel-wise attention mechanism for EEG signals"""
    def __init__(self, num_channels, reduction_ratio=4):
        super(ChannelAttention, self).__init__()
        self.num_channels = num_channels
        self.reduction_ratio = reduction_ratio
        
        self.mlp = nn.Sequential(
            nn.Linear(num_channels, num_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(num_channels // reduction_ratio, num_channels)
        )
        
    def forward(self, x):
        
        channel_avg = torch.mean(x, dim=2, keepdim=True)  # (batch_size, 1, 1, channels)
        channel_max, _ = torch.max(x, dim=2, keepdim=True) # (batch_size, 1, 1, channels)
        
        combined = channel_avg + channel_max  # (batch_size, 1, 1, channels)
        combined = combined.squeeze(1).squeeze(1)  # (batch_size, channels)
        
        attention = self.mlp(combined)  # (batch_size, channels)
        attention = torch.sigmoid(attention).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, channels)
        
        attended_x = x * attention
        
        return attended_x, attention.squeeze()

class TemporalFeatureExtractor(nn.Module):
    """Temporal feature extractor with depthwise separable convolutions"""
    def __init__(self, num_channels, timepoints):
        super(TemporalFeatureExtractor, self).__init__()
        
        # Depthwise separable convolutions with BatchNorm
        self.depthwise = nn.Sequential(
            nn.Conv2d(1, num_channels, kernel_size=(1, 64), 
            groups=1, padding=(0, 32)),
            nn.BatchNorm2d(num_channels),
            nn.ELU()  # Using ELU for EEG data often works better
        )
        
        self.pointwise = nn.Sequential(
            nn.Conv2d(num_channels, 32, kernel_size=1),
            nn.BatchNorm2d(32),
            nn.ELU()
        )
        
        # Temporal convolution with BatchNorm
        self.temporal_conv = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(1, 16), 
                     stride=(1, 4), padding=(0, 8)),
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Dropout(0.3)  # Added dropout for better regularization
        )
        
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 64))
        
    def forward(self, x):
        
        x = F.relu(self.depthwise(x))
        x = F.relu(self.pointwise(x))
        
        x = F.relu(self.temporal_conv(x))
        
        x = self.adaptive_pool(x)
        
        return x

class EEGAttentionNet(nn.Module):
    def __init__(self, num_channels=22, num_classes=4, timepoints=1125):
        super(EEGAttentionNet, self).__init__()
        
        self.channel_attention = ChannelAttention(num_channels)
        self.temporal_extractor = TemporalFeatureExtractor(num_channels, timepoints)
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 64, 128),
            nn.BatchNorm1d(128),  # Added BatchNorm
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        
    def forward(self, x):
        
        attended_x, attention_weights = self.channel_attention(x)
        
        temporal_features = self.temporal_extractor(attended_x)
        
        logits = self.classifier(temporal_features)
        
        return logits, attention_weights

# if __name__ == "__main__":
#     NUM_CHANNELS = 22  # BCI-IV 2a has 22 EEG channels
#     NUM_CLASSES = 4    # 4 motor imagery tasks
#     TIMEPOINTS = 1001  # Typical sample length in the dataset
    
#     # Create model
#     model = EEGAttentionNet(num_channels=NUM_CHANNELS, 
#                           num_classes=NUM_CLASSES,
#                           timepoints=TIMEPOINTS)
    
#     print(model)
    
#     batch_size = 32
#     x = torch.randn(batch_size, 1, TIMEPOINTS, NUM_CHANNELS)
#     logits, attention = model(x)
    
#     print(f"\nInput shape: {x.shape}")
#     print(f"Output logits shape: {logits.shape}")
#     print(f"Attention weights shape: {attention.shape}")

In [6]:
import os
import time
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
import json

class EEGTrainer:
    def __init__(self, model, train_loader, val_loader, test_loader=None, config=None):
        """
        Initialize the EEG trainer
        
        Args:
            model: EEGAttentionNet model
            train_loader: Training data loader
            val_loader: Validation data loader
            test_loader: Optional test data loader
            config: Dictionary of training configuration
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        
        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        
        # Default configuration
        self.config = {
            'lr': 1e-3,
            'weight_decay': 1e-4,
            'patience': 10,
            'min_lr': 1e-6,
            'epochs': 100,
            'save_dir': 'experiments',
            'experiment_name': f'exp_{time.strftime("%Y%m%d-%H%M%S")}',
            'save_attention_maps': True,
            'attention_map_freq': 5
        }
        
        # Update with user config if provided
        if config:
            self.config.update(config)
            
        # Create experiment directory
        self.exp_dir = os.path.join(self.config['save_dir'], self.config['experiment_name'])
        os.makedirs(self.exp_dir, exist_ok=True)
        
        # Initialize components
        self._init_components()
        
        # Save config
        self._save_config()
        
    def _init_components(self):
        """Initialize training components"""
        # Loss function (CrossEntropy + KLDiv for attention regularization)
        self.criterion = nn.CrossEntropyLoss()
        
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config['lr'],
            weight_decay=self.config['weight_decay']
        )
        
        # Learning rate scheduler
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=0.5,
            patience=self.config['patience']//2,
            min_lr=self.config['min_lr'],
            verbose=True
        )
        
        # Early stopping counter
        self.best_val_acc = 0.0
        self.early_stop_counter = 0
        
        # Tensorboard writer
        self.writer = SummaryWriter(log_dir=self.exp_dir)
        
        # Attention maps directory
        if self.config['save_attention_maps']:
            self.attention_dir = os.path.join(self.exp_dir, 'attention_maps')
            os.makedirs(self.attention_dir, exist_ok=True)
    
    def _save_config(self):
        """Save training configuration to JSON file"""
        config_path = os.path.join(self.exp_dir, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config, f, indent=4)
    
    def _compute_metrics(self, outputs, labels):
        """Compute accuracy and other metrics"""
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / labels.size(0)
        return accuracy
    
    def _log_metrics(self, phase, metrics, epoch):
        """Log metrics to Tensorboard and print to console"""
        loss = metrics['loss']
        acc = metrics['accuracy']
        
        # Console logging
        print(f"{phase.capitalize()} - Epoch: {epoch+1} | "
              f"Loss: {loss:.4f} | Acc: {acc:.2%}")
        
        # Tensorboard logging
        self.writer.add_scalar(f'Loss/{phase}', loss, epoch)
        self.writer.add_scalar(f'Accuracy/{phase}', acc, epoch)
        
        # Log learning rate
        if phase == 'train':
            lr = self.optimizer.param_groups[0]['lr']
            self.writer.add_scalar('LR', lr, epoch)
    
    def _save_checkpoint(self, epoch, is_best=False):
        """Save model checkpoint"""
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc
        }
        
        # Save regular checkpoint
        checkpoint_path = os.path.join(self.exp_dir, f'checkpoint_epoch_{epoch}.pth')
        torch.save(state, checkpoint_path)
        
        # Save best model
        if is_best:
            best_path = os.path.join(self.exp_dir, 'best_model.pth')
            torch.save(state, best_path)
    
    def _save_attention_maps(self, attention_weights, epoch):
        """Save attention weights for visualization"""
        attention_path = os.path.join(
            self.attention_dir,
            f'attention_epoch_{epoch}.npy'
        )
        np.save(attention_path, attention_weights.cpu().numpy())
    
    def train_epoch(self, epoch):
        """Train for one epoch"""
        self.model.train()
        running_loss = 0.0
        running_acc = 0.0
        total_samples = 0
        
        for inputs, labels in self.train_loader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            #print(labels)
            # Zero gradients
            self.optimizer.zero_grad()
            
            # Forward pass
            outputs, attention_weights = self.model(inputs)
            loss = self.criterion(outputs, labels)
            
            # Backward pass and optimize
            loss.backward()
            self.optimizer.step()
            
            # Compute metrics
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            running_acc += self._compute_metrics(outputs, labels) * batch_size
            total_samples += batch_size
        
        epoch_loss = running_loss / total_samples
        epoch_acc = running_acc / total_samples
        
        return {
            'loss': epoch_loss,
            'accuracy': epoch_acc,
            'attention_weights': attention_weights if self.config['save_attention_maps'] else None
        }
    
    def validate_epoch(self, epoch):
        """Validate for one epoch"""
        self.model.eval()
        running_loss = 0.0
        running_acc = 0.0
        total_samples = 0
        all_attention_weights = []
        
        with torch.no_grad():
            for inputs, labels in self.val_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
     
                # Forward pass
                outputs, attention_weights = self.model(inputs)
                loss = self.criterion(outputs, labels)
                
                # Compute metrics
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                running_acc += self._compute_metrics(outputs, labels) * batch_size
                total_samples += batch_size
                
                if self.config['save_attention_maps']:
                    all_attention_weights.append(attention_weights.cpu())
        
        epoch_loss = running_loss / total_samples
        epoch_acc = running_acc / total_samples
        
        attention_weights = torch.cat(all_attention_weights) if all_attention_weights else None
        
        return {
            'loss': epoch_loss,
            'accuracy': epoch_acc,
            'attention_weights': attention_weights
        }
    
    def train(self):
        """Main training loop"""
        start_time = time.time()
        
        for epoch in range(self.config['epochs']):
            # Train and validate
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.validate_epoch(epoch)
            
            # Log metrics
            self._log_metrics('train', train_metrics, epoch)
            self._log_metrics('val', val_metrics, epoch)
            
            # Save attention maps periodically
            if (self.config['save_attention_maps'] and 
                (epoch % self.config['attention_map_freq'] == 0 or epoch == self.config['epochs']-1)):
                self._save_attention_maps(val_metrics['attention_weights'], epoch)
            
            # Step scheduler
            self.scheduler.step(val_metrics['accuracy'])
            
            # Check for best model
            if val_metrics['accuracy'] > self.best_val_acc:
                self.best_val_acc = val_metrics['accuracy']
                self._save_checkpoint(epoch, is_best=True)
                self.early_stop_counter = 0
            else:
                self.early_stop_counter += 1
            
            # Save checkpoint periodically
            if epoch % 10 == 0:
                self._save_checkpoint(epoch)
            
            # Early stopping
            # if self.early_stop_counter >= self.config['patience']:
            #     print(f"Early stopping at epoch {epoch+1}")
            #     break
        
        # Training complete
        training_time = time.time() - start_time
        print(f"Training completed in {training_time//60:.0f}m {training_time%60:.0f}s")
        print(f"Best validation accuracy: {self.best_val_acc:.2%}")
        
        # Test if test loader provided
        if self.test_loader:
            test_acc = self.test()
            print(f"Test accuracy: {test_acc:.2%}")
        
        # Close tensorboard writer
        self.writer.close()
        
        return self.best_val_acc
    
    def test(self):
        """Evaluate on test set"""
        self.model.eval()
        running_acc = 0.0
        total_samples = 0
        
        # Load best model
        best_path = os.path.join(self.exp_dir, 'best_model.pth')
        if os.path.exists(best_path):
            checkpoint = torch.load(best_path)
            self.model.load_state_dict(checkpoint['state_dict'])
            print(f"Loaded best model from epoch {checkpoint['epoch']}")
        
        with torch.no_grad():
            for inputs, labels in self.test_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                
                outputs, _ = self.model(inputs)
                running_acc += self._compute_metrics(outputs, labels) * labels.size(0)
                total_samples += labels.size(0)
        
        test_acc = running_acc / total_samples
        self.writer.add_scalar('Accuracy/test', test_acc)
        
        return test_acc

# Example usage
if __name__ == "__main__":
    # Example configuration
    config = {
        'lr': 3e-4,
        'weight_decay': 1e-5,
        'patience': 15,
        'epochs': 500,
        'experiment_name': 'eeg_attention_experiment',
        'save_attention_maps': True
    }
    
    # Initialize components (using the previous data loader and model code)
    model = EEGAttentionNet(num_channels=22, num_classes=4, timepoints=1125)
    train_loader, val_loader, test_loader = get_data_loaders(data_dir="/teamspace/studios/this_studio/EEG_REC", subjects = [1,2,3,5,6,7,8,9])
    
    # Create trainer
    trainer = EEGTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        config=config
    )
    
    # Start training
    best_val_acc = trainer.train()

2025-07-21 08:49:45.235074: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-07-21 08:49:45.281961: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-21 08:49:45.282009: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-21 08:49:45.282043: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-21 08:49:45.291023: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: A

Train - Epoch: 1 | Loss: 1.4420 | Acc: 31.31%
Val - Epoch: 1 | Loss: 1.3227 | Acc: 34.06%
Train - Epoch: 2 | Loss: 1.3500 | Acc: 36.68%
Val - Epoch: 2 | Loss: 1.4187 | Acc: 29.50%
Train - Epoch: 3 | Loss: 1.3153 | Acc: 39.23%
Val - Epoch: 3 | Loss: 1.4910 | Acc: 38.61%
Train - Epoch: 4 | Loss: 1.2824 | Acc: 39.45%
Val - Epoch: 4 | Loss: 1.4825 | Acc: 27.98%
Train - Epoch: 5 | Loss: 1.2484 | Acc: 42.54%
Val - Epoch: 5 | Loss: 1.3176 | Acc: 40.35%
Train - Epoch: 6 | Loss: 1.2119 | Acc: 47.26%
Val - Epoch: 6 | Loss: 1.7936 | Acc: 30.80%
Train - Epoch: 7 | Loss: 1.1891 | Acc: 46.55%
Val - Epoch: 7 | Loss: 1.3239 | Acc: 40.56%
Train - Epoch: 8 | Loss: 1.1532 | Acc: 50.30%
Val - Epoch: 8 | Loss: 1.5068 | Acc: 34.49%
Train - Epoch: 9 | Loss: 1.1147 | Acc: 50.57%
Val - Epoch: 9 | Loss: 1.5174 | Acc: 40.35%
Train - Epoch: 10 | Loss: 1.0927 | Acc: 52.58%
Val - Epoch: 10 | Loss: 1.2459 | Acc: 44.03%
Train - Epoch: 11 | Loss: 1.0714 | Acc: 53.34%
Val - Epoch: 11 | Loss: 1.2674 | Acc: 40.35%
Train 

KeyboardInterrupt: 