In [1]:
import mne
import numpy as np
from scipy.signal import firwin, filtfilt
import matplotlib.pyplot as plt
import logging
import warnings
import torch
import random
warnings.filterwarnings("ignore", category=RuntimeWarning, module="mne")
# Completely silence MNE-Python output
mne.set_log_level('WARNING')  # or 'ERROR' for even less output
logging.getLogger('mne').setLevel(logging.WARNING)
#mne.set_log_level('debug')

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [2]:
def load_eeg(path):
    raw = mne.io.read_raw_gdf(path, preload=True)

    # Step 2: Get the sampling frequency
    sfreq = raw.info['sfreq']  # Hz

    # Step 3: Define FIR filter parameters
    low_cutoff = 1.0   # Hz
    high_cutoff = 30.0 # Hz
    filter_order = 177 # Must be odd for linear-phase FIR

    nyquist = 0.5 * sfreq


    fir_coeffs = firwin(
        numtaps=filter_order,
        cutoff=[low_cutoff / nyquist, high_cutoff / nyquist],
        pass_zero=False,
        window='blackman'
    )
    eeg_data = raw.get_data()
    filtered_data = filtfilt(fir_coeffs, 1.0, eeg_data, axis=1)

    new_raw = mne.io.RawArray(filtered_data, raw.info.copy())
    annotations = raw.annotations 
    new_raw.set_annotations(annotations)

    return new_raw

In [3]:
def preprocess(path, test=False):
    raw = load_eeg(path)
    eeg_channels = raw.ch_names[:22]
    raw.pick(eeg_channels)
    
    if not test:
        events, events_id = mne.events_from_annotations(raw, event_id= {'769': 0,'770': 1,'771': 2,'772': 3})
    else:
        events, events_id = mne.events_from_annotations(raw, event_id= {'768':6})
        #print(events_id)
    #print(events_id)
    epochs = mne.Epochs(
        raw,
        events=events,  
        tmin=0,     
        tmax=6.0,
        event_id=events_id,
        baseline=None,
        preload=True
    )

    labels = epochs.events[:, 2]
    #print(labels)
    data = epochs.get_data()
    
    return {
        'epochs': data,   
        'labels': labels
    }

#a = preprocess("E:/LOKI/BCI-IV/A01T.gdf")
# print(a['epochs'].shape)


In [4]:
# a = preprocess('/teamspace/studios/this_studio/EEG_REC/A01T.gdf')
# print(a['epochs'].shape)

In [5]:
import torch
import torch.nn as nn
import math

class ChannelAttention(nn.Module):
    def __init__(self, num_channels, reduction_ratio=4):
        super(ChannelAttention, self).__init__()
        self.num_channels = num_channels
        self.reduction_ratio = reduction_ratio

        self.mlp = nn.Sequential(
            nn.Linear(num_channels, num_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(num_channels // reduction_ratio, num_channels)
        )

    def forward(self, x):
        channel_avg = torch.mean(x, dim=2, keepdim=True)  
        channel_max, _ = torch.max(x, dim=2, keepdim=True) 
        combined = channel_avg + channel_max  
        combined = combined.squeeze(2)
        attention = self.mlp(combined)  # (batch_size, channels)
        attention = torch.sigmoid(attention).unsqueeze(2) # (batch_size, 1, 1, channels)
        attended_x = x * attention
        return attended_x, attention.squeeze()

class LearnableSTFT(nn.Module):

    def __init__(self, window_size, hop_size, dft_size=None):
        super(LearnableSTFT, self).__init__()


        self.window_size = window_size
        self.hop_size = hop_size
        self.dft_size = dft_size if dft_size is not None else window_size

        initial_window = 0.54 - 0.46 * torch.cos(
            2 * math.pi * torch.arange(window_size, dtype=torch.float32) / (window_size - 1)
        )

        self.window = nn.Parameter(initial_window)

        dft_matrix = self._create_dft_matrix(self.dft_size, self.window_size)
        
        self.register_buffer('dft_matrix', dft_matrix)

    def _create_dft_matrix(self, dft_size, window_size):

        k = torch.arange(dft_size).unsqueeze(1) # Shape: [dft_size, 1]
        n = torch.arange(window_size)            # Shape: [window_size]

        # Calculate the angle for the complex exponential
        angle = -2 * math.pi * k * n / dft_size

        # Use Euler's formula to create the complex matrix
        # e^(j*angle) = cos(angle) + j*sin(angle)
        dft_matrix = torch.complex(torch.cos(angle), torch.sin(angle))
        return dft_matrix

    def forward(self, signal):

        if signal.dim() == 1:
            signal = signal.unsqueeze(0).unsqueeze(0) # [1, 1, T]
        elif signal.dim() == 2:
            signal = signal.unsqueeze(1) # [B, 1, T]
        batch_size, num_channels, num_samples = signal.shape

        signal_reshaped = signal.reshape(batch_size * num_channels, num_samples)

        learnable_window = self.window
        frames = signal_reshaped.unfold(dimension=1, size=self.window_size, step=self.hop_size)

        num_frames_unfolded = frames.shape[1]
        expected_num_frames = int(math.ceil((num_samples - self.window_size) / self.hop_size)) + 1

        if num_frames_unfolded < expected_num_frames:
            padding_amount = (expected_num_frames - 1) * self.hop_size + self.window_size - num_samples
            padded_signal = torch.nn.functional.pad(signal_reshaped, (0, padding_amount))
            frames = padded_signal.unfold(1, self.window_size, self.hop_size)

        windowed_frames = frames * learnable_window

        BC, F, W = windowed_frames.shape # BC = batch_size * num_channels
        windowed_frames_reshaped = windowed_frames.reshape(BC * F, W)

        windowed_frames_complex = windowed_frames_reshaped.to(self.dft_matrix.dtype)
        stft_result_reshaped = self.dft_matrix @ windowed_frames_complex.T
        stft_result = stft_result_reshaped.T.reshape(batch_size, num_channels, F, self.dft_size)

        return stft_result

class Attention4D(nn.Module):

    def __init__(self, in_channels, time_frames, freq_bins, d_model=128, n_heads=4, d_ff=256, dropout=0.1):

        super().__init__()

        self.d_model = d_model
        self.time_frames = time_frames

        in_features = in_channels * freq_bins

        self.projection = nn.Linear(in_features, d_model)

        self.positional_encoding = nn.Parameter(torch.randn(1, time_frames, d_model))

        self.layernorm1 = nn.LayerNorm(d_model)
        self.attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)

        self.layernorm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.dropout2 = nn.Dropout(dropout)


    def forward(self, x):


        batch_size = x.shape[0]

        x = x.permute(0, 2, 1, 3)

        x = x.reshape(batch_size, self.time_frames, -1)

        x = self.projection(x) # Output: (batch, time_frames, d_model)

        x = x + self.positional_encoding

        x_norm1 = self.layernorm1(x)
        attn_output, _ = self.attention(x_norm1, x_norm1, x_norm1)
        x = x + self.dropout1(attn_output)

        x_norm2 = self.layernorm2(x)
        ff_output = self.feed_forward(x_norm2)
        x = x + self.dropout2(ff_output)

        return x
class Classifier(nn.Module):

    def __init__(self, in_features, num_classes):

        super().__init__()

        self.pooling = nn.AdaptiveAvgPool1d(1) # A flexible way to do global average pooling
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        
        x = x.permute(0, 2, 1)  # Shape -> (4, 128, 17)

        x = self.pooling(x).squeeze(2) # Shape -> (4, 128)

        output = self.classifier(x) # Shape -> (4, num_classes)

        return output
    
class EEGClassifier(nn.Module):

    def __init__(self, num_classes=4):
        super(EEGClassifier, self).__init__()
        self.channel_attn = ChannelAttention(num_channels = 22)
        self.learnable_stft = LearnableSTFT(window_size=250, hop_size=100)
        
        self.attention_module = Attention4D(
            in_channels=22,
            time_frames=14,
            freq_bins=250,
            d_model=128,
            n_heads=8
        )
        self.classifier = Classifier(in_features=128, num_classes=num_classes)

    def forward(self, x):

        x,_ = self.channel_attn(x)
        x = self.learnable_stft(x)
        x = torch.abs(x)
        x = self.attention_module(x)
        x = self.classifier(x)

        return x
    

#start of trials (768)

In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
import scipy.io as sio
import os
from sklearn.model_selection import train_test_split

class BCI4_2a_Dataset(Dataset):

    def __init__(self, data_dir, subjects=[1], train=True, transform=None, target_transform=None, test=False):
        self.data_dir = data_dir
        self.subjects = subjects
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.test=test
        self.data, self.labels = self._load_data()
        
        self.labels = self.labels 
        
    def _load_data(self):
        all_data = []
        all_labels = []
        
        for subject in self.subjects:
            if self.train:
                filename = f'A{subject:02d}T.gdf'
            else:
                filename = f'A{subject:02d}E.gdf'
                
            filepath = os.path.join(self.data_dir, filename)
            data_ = preprocess(filepath, self.test)
            
            data = data_['epochs']
            labels = data_['labels']
            #data = np.transpose(data, (0, 2, 1))
            
            all_data.append(data)
            all_labels.append(labels)
            
        all_data = np.concatenate(all_data, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)
        #print(all_labels.shape)
        def normalize_eeg(trial_data):
            """trial_data shape: (channels, timepoints)"""
            means = trial_data.mean(axis=1, keepdims=True)
            stds = trial_data.std(axis=1, keepdims=True)
            return (trial_data - means) / (stds + 1e-8)

        all_data_n = normalize_eeg(all_data)

        return all_data_n, all_labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx]
        
        x = torch.from_numpy(x).float()  # shape: (1, timepoints, channels)
        
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
            
        return x, y

def get_data_loaders(data_dir, subjects=[1], batch_size=32, val_split=0.2, test_split=0.2, random_state=42):

    full_train_dataset = BCI4_2a_Dataset(data_dir, subjects=subjects, train=True)
    
    test_dataset = BCI4_2a_Dataset(data_dir, subjects=subjects, train=False, test=True )
    
    train_idx, val_idx = train_test_split(
        range(len(full_train_dataset)),
        test_size=val_split,
        random_state=random_state,
        stratify=full_train_dataset.labels
    )
    
    train_dataset = torch.utils.data.Subset(full_train_dataset, train_idx)
    val_dataset = torch.utils.data.Subset(full_train_dataset, val_idx)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader, test_loader

# if __name__ == "__main__":
#     DATA_DIR = "/teamspace/studios/this_studio/EEG_REC"
    
#     train_loader, val_loader, test_loader = get_data_loaders(
#         data_dir=DATA_DIR,
#         subjects=[1,2,3,5,6,7,8,9],
#         batch_size=32
#     )
    
#     print("Dataset sizes:")
#     print(f"Training samples: {len(train_loader.dataset)}")
#     print(f"Validation samples: {len(val_loader.dataset)}")
#     print(f"Test samples: {len(test_loader.dataset)}")
    
#     x, y = next(iter(train_loader))
#     print("\nBatch shape:")
#     print(f"Input shape: {x.shape}")  # Should be (batch_size, 1, timepoints, channels)
#     print(f"Labels shape: {y.shape}")

In [7]:
import os
import time
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
import json

class EEGTrainer:
    def __init__(self, model, train_loader, val_loader, test_loader=None, config=None):
        """
        Initialize the EEG trainer
        
        Args:
            model: EEGAttentionNet model
            train_loader: Training data loader
            val_loader: Validation data loader
            test_loader: Optional test data loader
            config: Dictionary of training configuration
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        
        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        
        # Default configuration
        self.config = {
            'lr': 1e-3,
            'weight_decay': 1e-4,
            'patience': 10,
            'min_lr': 1e-6,
            'epochs': 100,
            'save_dir': 'experiments',
            'experiment_name': f'exp_{time.strftime("%Y%m%d-%H%M%S")}',
            'save_attention_maps': True,
            'attention_map_freq': 5
        }
        
        # Update with user config if provided
        if config:
            self.config.update(config)
            
        # Create experiment directory
        self.exp_dir = os.path.join(self.config['save_dir'], self.config['experiment_name'])
        os.makedirs(self.exp_dir, exist_ok=True)
        
        # Initialize components
        self._init_components()
        
        # Save config
        self._save_config()
        
    def _init_components(self):
        """Initialize training components"""
        # Loss function (CrossEntropy + KLDiv for attention regularization)
        self.criterion = nn.CrossEntropyLoss()
        
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config['lr'],
            weight_decay=self.config['weight_decay']
        )
        
        # Learning rate scheduler
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=0.5,
            patience=self.config['patience']//2,
            min_lr=self.config['min_lr'],
            verbose=True
        )
        
        # Early stopping counter
        self.best_val_acc = 0.0
        self.early_stop_counter = 0
        
        # Tensorboard writer
        self.writer = SummaryWriter(log_dir=self.exp_dir)
        
        # Attention maps directory
        if self.config['save_attention_maps']:
            self.attention_dir = os.path.join(self.exp_dir, 'attention_maps')
            os.makedirs(self.attention_dir, exist_ok=True)
    
    def _save_config(self):
        """Save training configuration to JSON file"""
        config_path = os.path.join(self.exp_dir, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config, f, indent=4)
    
    def _compute_metrics(self, outputs, labels):
        """Compute accuracy and other metrics"""
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / labels.size(0)
        return accuracy
    
    def _log_metrics(self, phase, metrics, epoch):
        """Log metrics to Tensorboard and print to console"""
        loss = metrics['loss']
        acc = metrics['accuracy']
        
        # Console logging
        print(f"{phase.capitalize()} - Epoch: {epoch+1} | "
              f"Loss: {loss:.4f} | Acc: {acc:.2%}")
        
        # Tensorboard logging
        self.writer.add_scalar(f'Loss/{phase}', loss, epoch)
        self.writer.add_scalar(f'Accuracy/{phase}', acc, epoch)
        
        # Log learning rate
        if phase == 'train':
            lr = self.optimizer.param_groups[0]['lr']
            self.writer.add_scalar('LR', lr, epoch)
    
    def _save_checkpoint(self, epoch, is_best=False):
        """Save model checkpoint"""
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc
        }
        
        # Save regular checkpoint
        checkpoint_path = os.path.join(self.exp_dir, f'checkpoint_epoch_{epoch}.pth')
        torch.save(state, checkpoint_path)
        
        # Save best model
        if is_best:
            best_path = os.path.join(self.exp_dir, 'best_model.pth')
            torch.save(state, best_path)
    

    
    def train_epoch(self, epoch):
        """Train for one epoch"""
        self.model.train()
        running_loss = 0.0
        running_acc = 0.0
        total_samples = 0
        
        for inputs, labels in self.train_loader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            #print(labels)
            # Zero gradients
            self.optimizer.zero_grad()
            
            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            
            # Backward pass and optimize
            loss.backward()
            self.optimizer.step()
            
            # Compute metrics
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            running_acc += self._compute_metrics(outputs, labels) * batch_size
            total_samples += batch_size
        
        epoch_loss = running_loss / total_samples
        epoch_acc = running_acc / total_samples
        
        return {
            'loss': epoch_loss,
            'accuracy': epoch_acc,
        }
    
    def validate_epoch(self, epoch):
        """Validate for one epoch"""
        self.model.eval()
        running_loss = 0.0
        running_acc = 0.0
        total_samples = 0
        
        with torch.no_grad():
            for inputs, labels in self.val_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
     
                # Forward pass
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                
                # Compute metrics
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                running_acc += self._compute_metrics(outputs, labels) * batch_size
                total_samples += batch_size
                
        
        epoch_loss = running_loss / total_samples
        epoch_acc = running_acc / total_samples
        
        
        return {
            'loss': epoch_loss,
            'accuracy': epoch_acc,
        }
    
    def train(self):
        """Main training loop"""
        start_time = time.time()
        
        for epoch in range(self.config['epochs']):
            # Train and validate
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.validate_epoch(epoch)
            
            # Log metrics
            self._log_metrics('train', train_metrics, epoch)
            self._log_metrics('val', val_metrics, epoch)
            
            # Save attention maps periodically
            
            # Step scheduler
            self.scheduler.step(val_metrics['accuracy'])
            
            # Check for best model
            if val_metrics['accuracy'] > self.best_val_acc:
                self.best_val_acc = val_metrics['accuracy']
                self._save_checkpoint(epoch, is_best=True)
                self.early_stop_counter = 0
            else:
                self.early_stop_counter += 1
            
            # Save checkpoint periodically
            if epoch % 10 == 0:
                self._save_checkpoint(epoch)
            
            # Early stopping
            # if self.early_stop_counter >= self.config['patience']:
            #     print(f"Early stopping at epoch {epoch+1}")
            #     break
        
        # Training complete
        training_time = time.time() - start_time
        print(f"Training completed in {training_time//60:.0f}m {training_time%60:.0f}s")
        print(f"Best validation accuracy: {self.best_val_acc:.2%}")
        
        # Test if test loader provided
        if self.test_loader:
            test_acc = self.test()
            print(f"Test accuracy: {test_acc:.2%}")
        
        # Close tensorboard writer
        self.writer.close()
        
        return self.best_val_acc
    
    def test(self):
        """Evaluate on test set"""
        self.model.eval()
        running_acc = 0.0
        total_samples = 0
        
        # Load best model
        best_path = os.path.join(self.exp_dir, 'best_model.pth')
        if os.path.exists(best_path):
            checkpoint = torch.load(best_path)
            self.model.load_state_dict(checkpoint['state_dict'])
            print(f"Loaded best model from epoch {checkpoint['epoch']}")
        
        with torch.no_grad():
            for inputs, labels in self.test_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                
                outputs, _ = self.model(inputs)
                running_acc += self._compute_metrics(outputs, labels) * labels.size(0)
                total_samples += labels.size(0)
        
        test_acc = running_acc / total_samples
        self.writer.add_scalar('Accuracy/test', test_acc)
        
        return test_acc

# Example usage
if __name__ == "__main__":
    # Example configuration
    config = {
        'lr': 3e-4,
        'weight_decay': 1e-5,
        'patience': 15,
        'epochs': 500,
        'experiment_name': 'eeg_attention_experiment',
        'save_attention_maps': True
    }
    
    # Initialize components (using the previous data loader and model code)
    model = EEGClassifier(num_classes=4)
    train_loader, val_loader, test_loader = get_data_loaders(data_dir="/teamspace/studios/this_studio/EEG_REC", subjects = [1,2,3,5,6,7,8,9])
    
    # Create trainer
    trainer = EEGTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        config=config
    )
    
    # Start training
    best_val_acc = trainer.train()

2025-08-02 17:55:46.087202: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-08-02 17:55:46.501841: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-08-02 17:55:46.501947: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-08-02 17:55:46.504334: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-08-02 17:55:46.766687: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: A

Train - Epoch: 1 | Loss: 3.1257 | Acc: 26.58%
Val - Epoch: 1 | Loss: 1.4869 | Acc: 31.74%
Train - Epoch: 2 | Loss: 1.4734 | Acc: 31.54%
Val - Epoch: 2 | Loss: 1.5332 | Acc: 29.78%
Train - Epoch: 3 | Loss: 1.4785 | Acc: 32.35%
Val - Epoch: 3 | Loss: 1.5823 | Acc: 26.30%
Train - Epoch: 4 | Loss: 1.4291 | Acc: 34.69%
Val - Epoch: 4 | Loss: 1.3433 | Acc: 35.00%
Train - Epoch: 5 | Loss: 1.2677 | Acc: 41.72%
Val - Epoch: 5 | Loss: 1.2090 | Acc: 45.87%
Train - Epoch: 6 | Loss: 1.2233 | Acc: 45.70%
Val - Epoch: 6 | Loss: 1.1974 | Acc: 47.17%
Train - Epoch: 7 | Loss: 1.0587 | Acc: 56.37%
Val - Epoch: 7 | Loss: 1.2052 | Acc: 48.91%
Train - Epoch: 8 | Loss: 0.9946 | Acc: 59.10%
Val - Epoch: 8 | Loss: 1.4230 | Acc: 39.78%
Train - Epoch: 9 | Loss: 0.8129 | Acc: 69.99%
Val - Epoch: 9 | Loss: 1.1989 | Acc: 49.78%
Train - Epoch: 10 | Loss: 0.6990 | Acc: 74.46%
Val - Epoch: 10 | Loss: 1.2951 | Acc: 46.74%
Train - Epoch: 11 | Loss: 0.5685 | Acc: 79.30%
Val - Epoch: 11 | Loss: 1.2700 | Acc: 48.04%
Train 

KeyboardInterrupt: 

Saved 2304 individual epochs to /teamspace/studios/this_studio/epochs


In [7]:
import numpy as np

# Load the saved epoch
epoch_file = '/teamspace/studios/this_studio/epochs/epoch_0.npy'
loaded = np.load(epoch_file, allow_pickle=True).item()

# Access data and label
epoch_data = loaded['data']
epoch_label = loaded['label']

print("Epoch data shape:", epoch_data.shape)
print("Epoch label:", epoch_label)


Epoch data shape: (22, 31, 95)
Epoch label: 3


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, num_channels, reduction_ratio=4):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Linear(num_channels, num_channels // reduction_ratio, bias=False),
            nn.ReLU(),
            nn.Linear(num_channels // reduction_ratio, num_channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid() # Sigmoid activation to scale weights between 0 and 1

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x).squeeze(-1).squeeze(-1))
        max_out = self.fc(self.max_pool(x).squeeze(-1).squeeze(-1))

        out = avg_out + max_out
        return self.sigmoid(out).unsqueeze(-1).unsqueeze(-1)

class EEGDecoder(nn.Module):
    def __init__(self, num_channels=22, freq_bins=64, time_bins=76, num_classes=4,
                 mha_num_heads=5, mha_dropout=0.1, channel_attention_reduction=4):
        super(EEGDecoder, self).__init__()

        self.num_channels = num_channels
        self.freq_bins = freq_bins
        self.time_bins = time_bins
        self.num_classes = num_classes

        self.channel_attention = ChannelAttention(num_channels, reduction_ratio=channel_attention_reduction)

        self.multihead_attention = nn.MultiheadAttention(
            embed_dim=time_bins,      # The embedding dimension for MHA, must match the last dimension of input
            num_heads=mha_num_heads,  # Number of parallel attention heads
            dropout=mha_dropout,      # Dropout rate for attention weights
            batch_first=False         # Specifies input format: (seq_len, batch_size, embed_dim)
        )

        flattened_size = num_channels * freq_bins * time_bins
        self.classifier = nn.Sequential(
            nn.Flatten(),                   # Flattens the input tensor (N, C, F, T) into (N, C*F*T)
            nn.Linear(flattened_size, 512), # First fully connected layer with 512 hidden units
            nn.ReLU(),                      # ReLU activation function
            nn.Dropout(0.5),                # Dropout for regularization to prevent overfitting
            nn.Linear(512, num_classes)     # Output layer, mapping to the number of classification classes
        )

    def forward(self, x):
        N, C, F, T = x.shape

        channel_attn_weights = self.channel_attention(x) # Shape: (N, C, 1, 1)
        x = x * channel_attn_weights                     # Resulting shape: (N, C, F, T)

        x_reshaped_for_mha = x.view(N * C, F, T).permute(1, 0, 2)

        mha_output, _ = self.multihead_attention(
            query=x_reshaped_for_mha,
            key=x_reshaped_for_mha,
            value=x_reshaped_for_mha
        ) # Output shape: (F, N*C, T)

        x = mha_output.permute(1, 0, 2).view(N, C, F, T)

        logits = self.classifier(x) # Output shape: (N, num_classes)

        return logits

# if __name__ == "__main__":
#     num_channels = 22
#     freq_bins = 31
#     time_bins = 95
#     num_classes = 4
#     batch_size = 16 # Example batch size

#     dummy_input = torch.randn(batch_size, num_channels, freq_bins, time_bins)
#     print(f"Dummy input shape: {dummy_input.shape}")

#     model = EEGDecoder(
#         num_channels=num_channels,
#         freq_bins=freq_bins,
#         time_bins=time_bins,
#         num_classes=num_classes
#     )
#     print("\nModel Architecture:")
#     print(model)

#     output_logits = model(dummy_input)
#     print(f"\nOutput logits shape: {output_logits.shape}")

#     expected_output_shape = (batch_size, num_classes)
#     assert output_logits.shape == expected_output_shape, \
#         f"Output shape mismatch! Expected {expected_output_shape}, got {output_logits.shape}"
#     print("Output shape is correct!")

#     probabilities = F.softmax(output_logits, dim=1)
#     print(f"Example probabilities for the first sample: {probabilities[0]}")



Dummy input shape: torch.Size([16, 22, 31, 95])

Model Architecture:
EEGDecoder(
  (channel_attention): ChannelAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (max_pool): AdaptiveMaxPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=22, out_features=5, bias=False)
      (1): ReLU()
      (2): Linear(in_features=5, out_features=22, bias=False)
    )
    (sigmoid): Sigmoid()
  )
  (multihead_attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=95, out_features=95, bias=True)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=64790, out_features=512, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=512, out_features=4, bias=True)
  )
)

Output logits shape: torch.Size([16, 4])
Output shape is correct!
Example probabilities for the first sample: tensor([0.2553, 0.2518, 0.2418, 0.2511], grad_fn=<SelectBackward0>)
