## EEGConformer - Neonatal EEG Seizure Detection

This project tackles the challenge of **class imbalance** in neonatal EEG classification, where the dataset naturally contains uneven numbers of samples across different classes. Initially, the raw dataset is loaded and visualized, revealing significant imbalance which can negatively impact model performance by biasing it toward majority classes. To address this, two complementary strategies are employed: **SMOTE (Synthetic Minority Over-sampling Technique)** and **cost-sensitive learning**.

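In the **SMOTE approach**, each sample is first flattened to a 2D array (samples × features), which is the format SMOTE requires. SMOTE then creates synthetic minority-class samples by interpolating between existing minority samples and their nearest minority neighbours until the classes are balanced, and the result is reshaped back to the original multi-channel time-series format. This augmented data is split into training and validation sets to train the EEGConformer model. Because resampling happens before the split, the validation set contains synthetic samples derived from training data, so validation metrics for this variant are likely to be optimistic.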
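To make the flatten, interpolate, reshape sequence concrete, here is a minimal NumPy sketch of SMOTE-style oversampling. It is an illustration only, not imblearn's `SMOTE` (which the notebook actually uses); the helper name `smote_like_oversample`, the toy shapes, and the binary-label assumption are hypothetical. Each synthetic sample is placed at a random point on the line segment between a minority sample and one of its k nearest minority neighbours.

```python
import numpy as np

def smote_like_oversample(X, y, minority_label, k=5, seed=0):
    """Illustrative SMOTE-style oversampling for (samples, channels, time) arrays.

    Hypothetical helper for binary labels, not imblearn's implementation.
    It builds a dense O(m^2) distance matrix, so it suits small toy inputs only.
    """
    rng = np.random.default_rng(seed)
    n, n_ch, n_t = X.shape
    X_flat = X.reshape(n, -1)                        # SMOTE works on 2D (samples, features)
    minority = X_flat[y == minority_label]
    n_needed = int((y != minority_label).sum()) - len(minority)
    if n_needed <= 0:
        return X, y
    if len(minority) < 2:
        raise ValueError("need at least two minority samples to interpolate")

    # Euclidean distances between all pairs of minority samples
    dist = np.linalg.norm(minority[:, None, :] - minority[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)                   # a sample is not its own neighbour
    k = min(k, len(minority) - 1)
    neighbours = np.argsort(dist, axis=1)[:, :k]     # k nearest minority neighbours per sample

    base = rng.integers(0, len(minority), size=n_needed)            # random seed samples
    partner = neighbours[base, rng.integers(0, k, size=n_needed)]   # one random neighbour each
    lam = rng.random((n_needed, 1))                                 # interpolation factor in [0, 1)
    synthetic = minority[base] + lam * (minority[partner] - minority[base])

    X_new = np.vstack([X_flat, synthetic]).reshape(-1, n_ch, n_t)   # back to 3D
    y_new = np.concatenate([y, np.full(n_needed, minority_label, dtype=y.dtype)])
    return X_new, y_new

# Toy data shaped like the notebook's windows: 18 channels x 64 time steps
rng = np.random.default_rng(42)
X = rng.normal(size=(20, 18, 64))
y = np.array([0] * 16 + [1] * 4)
X_res, y_res = smote_like_oversample(X, y, minority_label=1)
print(X_res.shape, np.bincount(y_res))   # (32, 18, 64) [16 16]
```

Because every synthetic row is a convex combination of two real minority rows, it stays inside their per-feature range; this is also why resampling before a train/validation split can place near-copies of training data in the validation set.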

In parallel, **cost-sensitive learning** keeps the original imbalanced dataset intact but calculates class weights inversely proportional to class frequencies. These weights are used in the loss function to penalize misclassification of minority classes more heavily, guiding the model to pay increased attention to underrepresented classes without modifying the data distribution.

Both methods use the EEGConformer architecture, tailored for EEG time-series data, and are trained with the Adam optimizer. Training and validation loss and accuracy are logged every epoch; after training, ROC curves, precision-recall curves, and a confusion matrix are computed on the validation set to assess how well each strategy handles the imbalance. The trained models are saved and can be reloaded for evaluation.

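This combined approach, first acknowledging the natural imbalance and then applying SMOTE for data-level balancing and cost-sensitive learning for algorithm-level adjustment, is intended to improve classification of the minority class on imbalanced EEG datasets. Because accuracy alone is misleading under class imbalance, minority-class recall, precision, F1-score, and the ROC and precision-recall curves are the more informative yardsticks for comparing the baseline with the two strategies.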


It includes:

- Dataset handling

- Model definition

- Training and evaluation

- Visualization and metrics reporting

- Saving and loading model checkpoints

#### Key Components
1. Imports and Libraries
- Data Processing: numpy, pandas, mne, scipy, h5py, etc.

- Modeling: torch, torch.nn, torch.optim, braindecode.models.EEGConformer

- Visualization: matplotlib, seaborn

- Evaluation: sklearn.metrics

2. Dataset Preparation
- Data is loaded from .npy files: EEG features (X) and corresponding labels (y).

- Two Dataset classes are defined:

  - EEGDataset: accepts arrays directly.

  - EEGDatasetFromNumpy: loads data from disk using file paths.

- Class distribution is visualized before and after SMOTE, and before cost-sensitive weighting, to monitor imbalance handling.


3. Train-Validation Split
- Dataset is split using train_test_split (80/20).

- DataLoader is used for batching.

4. Model: EEGConformer
- Convolution + Transformer-based model specifically designed for EEG signals.

Model hyperparameters:

- n_filters_time=40

- filter_time_length=16

- pool_time_length=8, pool_time_stride=4

- att_depth=6, att_heads=10

- drop_prob=0.5 (dropout)

Model is adapted for EEG signals with 18 channels and 64 time steps.
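As a sanity check on these hyperparameters, the number of temporal tokens reaching the transformer can be worked out by hand. The sketch below assumes braindecode's patch embedding applies an unpadded ("valid") temporal convolution followed by unpadded average pooling, and the helper `conformer_tokens` is hypothetical. Under that assumption a 64-sample window yields 11 tokens of size 40, and `att_heads=10` divides 40 evenly (4 dimensions per head). The same formula applied to braindecode's original 1000-sample defaults gives 2440, which appears consistent with the value older releases hard-coded for `final_fc_length`.

```python
def conformer_tokens(n_times, filter_time_length, pool_time_length, pool_time_stride):
    """Hypothetical helper: temporal tokens after conv + pooling (no padding assumed)."""
    conv_len = n_times - filter_time_length + 1          # 'valid' temporal convolution
    return (conv_len - pool_time_length) // pool_time_stride + 1  # average pooling

n_filters_time, att_heads = 40, 10
seq_len = conformer_tokens(64, 16, 8, 4)
print(seq_len)                      # 11 tokens
print(seq_len * n_filters_time)     # 440 flattened features before the classifier head
print(n_filters_time % att_heads)   # 0, i.e. 4 dimensions per attention head

# Cross-check with the original defaults (1000 samples, kernel 25, pool 75 / stride 15)
print(conformer_tokens(1000, 25, 75, 15) * n_filters_time)   # 2440
```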

5. Training Loop
Implements:

- Loss tracking (CrossEntropyLoss)

- Accuracy calculation

- Softmax probability extraction

Keeps detailed logs of:

- Epoch-wise training/validation loss and accuracy

- Predictions and probabilities
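The per-batch bookkeeping listed above (softmax probability of the positive class, argmax prediction, accuracy) can be reproduced on toy logits with NumPy. This sketch only mirrors what `train_model` computes with `torch.softmax(outputs, dim=1)[:, 1]` and `outputs.argmax(dim=1)`; the logits and labels are made up.

```python
import numpy as np

def softmax(logits):
    # Subtract the row max for numerical stability; each output row sums to 1
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

logits = np.array([[2.0, 0.5], [-1.0, 1.5], [0.2, 0.1]])   # toy model outputs
labels = np.array([0, 1, 1])

probs = softmax(logits)
pos_prob = probs[:, 1]            # positive-class probability, used for ROC / PR curves
preds = logits.argmax(axis=1)     # softmax is monotonic, so argmax(logits) == argmax(probs)
accuracy = (preds == labels).mean()
print(preds, accuracy)
```

Taking the argmax of the raw logits avoids an unnecessary softmax for predictions; the softmax is only needed for the threshold-free ROC and precision-recall curves.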

6. Evaluation & Visualization
- ROC Curve

- Precision-Recall Curve

- Confusion Matrix

- Loss and Accuracy Curves

Classification Report & Metrics:

- Accuracy

- Precision

- Recall

- F1-score
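These metrics follow directly from the four confusion-matrix counts, in the `tn, fp, fn, tp` order that `confusion_matrix(...).ravel()` returns for binary labels. The sketch below (a hypothetical helper with invented counts) also shows why accuracy alone is a poor guide on imbalanced data: a model that never predicts the minority class still scores 95% accuracy when 95% of windows are negative.

```python
def binary_metrics(tn, fp, fn, tp):
    """Hypothetical helper: metrics from binary confusion-matrix counts."""
    total = tn + fp + fn + tp
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"Accuracy": (tp + tn) / total, "Precision": precision,
            "Recall": recall, "F1-Score": f1}

# Invented counts: 950 negative and 50 positive windows,
# scored by a model that always predicts "negative"
print(binary_metrics(tn=950, fp=0, fn=50, tp=0))
# {'Accuracy': 0.95, 'Precision': 0.0, 'Recall': 0.0, 'F1-Score': 0.0}

# A model that actually detects most positives, at the cost of some false alarms
print(binary_metrics(tn=900, fp=50, fn=10, tp=40))
```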

7. Model Saving/Loading
- Trained model is saved using torch.save.

- Reloading is done by reconstructing the architecture and applying load_state_dict.


In [None]:
import os
import re
import random
import glob
from pathlib import Path
from collections import Counter
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import scipy.io
from scipy import signal
from scipy.signal import butter, filtfilt, spectrogram
from scipy.stats import skew, kurtosis

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, Subset

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    precision_score, recall_score, f1_score, roc_auc_score,
    roc_curve, auc, precision_recall_curve, confusion_matrix, classification_report
)

import h5py
import mne
from braindecode.models import EEGConformer

## EEGConformer - With Imbalanced EEG Dataset

The code below implements a pipeline for classifying neonatal EEG signals with the **EEGConformer** model from the Braindecode library. It imports libraries for data handling, signal processing, deep learning, and evaluation. EEG data is loaded from `.npy` files, wrapped in custom PyTorch `Dataset` classes, split into training and validation sets with `train_test_split`, and batched with `DataLoader`. EEGConformer combines convolutional layers with transformer self-attention and is configured here for inputs of 18 channels and 64 time steps. The training loop tracks loss and accuracy every epoch; predictions are the argmax of the logits, and the softmax gives the positive-class probability used for the ROC and precision-recall curves. After training, the model is evaluated with accuracy, precision, recall, F1-score, and ROC-AUC, and helper functions plot the loss and accuracy curves, the ROC and precision-recall curves, and the confusion matrix. The trained model is saved, and a loading routine restores it for inference. The same framework is reused for the baseline, SMOTE, and cost-sensitive experiments.

In [3]:
warnings.filterwarnings("ignore", category=RuntimeWarning, module="mne")

In [4]:

# =====================================
# Define a Dataset Class that accepts arrays directly
# =====================================
class EEGDataset(Dataset):
    def __init__(self, X, y):
        # as_tensor accepts NumPy arrays or existing tensors (no copy-construct warning)
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.long)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# Training Function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    
    # To collect probabilities and labels for further analysis
    train_labels_all, val_labels_all = [], []
    train_preds_all, val_preds_all = [], []
    train_probs_all, val_probs_all = [], []
    
    for epoch in range(num_epochs):
        # --- Training phase (with dropout for weight updates) ---
        model.train()
        running_loss = 0.0
        for batch_data, batch_labels in train_loader:
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * batch_data.size(0)
        
        # --- Evaluate training metrics in eval mode (dropout off) ---
        model.eval()
        correct_train = 0
        running_loss_eval = 0.0
        epoch_train_probs, epoch_train_preds, epoch_train_labels = [], [], []
        with torch.no_grad():
            for batch_data, batch_labels in train_loader:
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                
                outputs = model(batch_data)
                loss = criterion(outputs, batch_labels)
                running_loss_eval += loss.item() * batch_data.size(0)
                
                probs = torch.softmax(outputs, dim=1)[:, 1]
                preds = outputs.argmax(dim=1)
                correct_train += (preds == batch_labels).sum().item()
                
                epoch_train_probs.extend(probs.cpu().numpy())
                epoch_train_preds.extend(preds.cpu().numpy())
                epoch_train_labels.extend(batch_labels.cpu().numpy())
        
        train_loss_epoch = running_loss_eval / len(train_loader.dataset)
        train_accuracy = correct_train / len(train_loader.dataset)
        train_losses.append(train_loss_epoch)
        train_accuracies.append(train_accuracy)
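        # Keep only the latest epoch's outputs so the returned lists describe the
        # final model (extending would concatenate all num_epochs passes together)
        train_labels_all = epoch_train_labels
        train_preds_all = epoch_train_preds
        train_probs_all = epoch_train_probs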
        
        # --- Validation phase ---
        running_loss_val = 0.0
        correct_val = 0
        epoch_val_probs, epoch_val_preds, epoch_val_labels = [], [], []
        with torch.no_grad():
            for batch_data, batch_labels in val_loader:
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                
                outputs = model(batch_data)
                loss = criterion(outputs, batch_labels)
                running_loss_val += loss.item() * batch_data.size(0)
                
                probs = torch.softmax(outputs, dim=1)[:, 1]
                preds = outputs.argmax(dim=1)
                correct_val += (preds == batch_labels).sum().item()
                
                epoch_val_probs.extend(probs.cpu().numpy())
                epoch_val_preds.extend(preds.cpu().numpy())
                epoch_val_labels.extend(batch_labels.cpu().numpy())
        
        val_loss_epoch = running_loss_val / len(val_loader.dataset)
        val_accuracy = correct_val / len(val_loader.dataset)
        val_losses.append(val_loss_epoch)
        val_accuracies.append(val_accuracy)
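        # Keep only the latest epoch's outputs so ROC / PR / confusion-matrix plots
        # reflect the final model rather than all epochs stacked together
        val_labels_all = epoch_val_labels
        val_preds_all = epoch_val_preds
        val_probs_all = epoch_val_probs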
        
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {train_loss_epoch:.4f}, Train Acc: {train_accuracy:.4f} | "
              f"Val Loss: {val_loss_epoch:.4f}, Val Acc: {val_accuracy:.4f}")
    
    return (train_losses, val_losses, train_accuracies, val_accuracies, 
            train_labels_all, val_labels_all, train_preds_all, val_preds_all, 
            train_probs_all, val_probs_all)
    


# --- Plotting Functions ---

def plot_roc_curve(true_labels, pred_probs, title="ROC Curve"):
    fpr, tpr, _ = roc_curve(true_labels, pred_probs)
    roc_auc = auc(fpr, tpr)
    
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
    plt.plot([0, 1], [0, 1], color='gray', linestyle='--', lw=1)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc='lower right')
    plt.grid(alpha=0.5)
    plt.savefig(f"{title.replace(' ', '_')}.png", dpi=300)  # file name follows the title
    plt.show()

def plot_precision_recall_curve(true_labels, pred_probs, title="Precision-Recall Curve"):
    precision, recall, _ = precision_recall_curve(true_labels, pred_probs)
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, color='b', lw=2)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(title)
    plt.savefig(f"{title.replace(' ', '_')}.png", dpi=300)  # file name follows the title
    plt.show()

def plot_confusion_matrix_and_metrics(true_labels, preds, title="Confusion Matrix"):
    cm = confusion_matrix(true_labels, preds)
    
    plt.figure(figsize=(6, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                xticklabels=['Negative', 'Positive'], yticklabels=['Negative', 'Positive'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(title)
    plt.savefig(f"{title.replace(' ', '_')}.png", dpi=300)  # file name follows the title
    plt.show()
    
    metrics = {}
    tn, fp, fn, tp = cm.ravel()
    metrics["Accuracy"] = (tp + tn) / (tp + tn + fp + fn)
    metrics["Precision"] = tp / (tp + fp) if (tp + fp) != 0 else 0
    metrics["Recall"] = tp / (tp + fn) if (tp + fn) != 0 else 0
    metrics["F1-Score"] = 2 * (metrics["Precision"] * metrics["Recall"]) / (metrics["Precision"] + metrics["Recall"]) if (metrics["Precision"] + metrics["Recall"]) != 0 else 0
    
    print("\nModel Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")
    
    print("\nClassification Report:")
    print(classification_report(true_labels, preds, target_names=['Negative', 'Positive']))

def plot_loss_curve(train_losses, val_losses, title="Training and Validation Loss"):
    plt.figure(figsize=(8, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.legend()
    plt.savefig(f"{title.replace(' ', '_')}.png", dpi=300)  # file name follows the title
    plt.show()

def plot_accuracy_curve(train_accuracies, val_accuracies, title="Training and Validation Accuracy"):
    plt.figure(figsize=(8, 6))
    plt.plot(train_accuracies, label='Training Accuracy')
    plt.plot(val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title(title)
    plt.legend()
    plt.savefig(f"{title.replace(' ', '_')}.png", dpi=300)  # file name follows the title
    plt.show()

# Set seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [5]:
# --- Training Setup (shared by all experiments below) ---
set_seed(42)  # seed Python, NumPy, and PyTorch RNGs using the helper defined above
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
num_epochs = 100
learning_rate = 0.001


# Load dataset (update paths as needed)
X_path = "Neontal_eeg_dataset1/annotations/EMS_Normalized_annonated_X_features.npy"
y_path = "Neontal_eeg_dataset1/annotations/EMS_Normalized_annonated_y_features.npy"

In [None]:
# --- Dataset Definition ---
class EEGDatasetFromNumpy(Dataset):
    def __init__(self, X_path, y_path):
        # Load the data directly from .npy files
        self.X = np.load(X_path)
        self.y = np.load(y_path)
        
        # Ensure matching number of samples
        assert len(self.X) == len(self.y), "Mismatch between number of samples in X and y"
        
        # Convert to tensors
        self.X = torch.tensor(self.X, dtype=torch.float32)
        self.y = torch.tensor(self.y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
dataset = EEGDatasetFromNumpy(X_path, y_path)
print("Dataset X shape:", dataset.X.shape, "y shape:", dataset.y.shape)

In [None]:
unique_orig, counts_orig = np.unique(dataset.y, return_counts=True)
plt.figure()
plt.bar(unique_orig, counts_orig, color='skyblue')
plt.title("Class Distribution Before Resampling (Original Data)")
plt.xlabel("Class")
plt.ylabel("Count")
plt.show()

## Imbalanced Dataset (Baseline)

In [None]:


# Stratify so the validation set keeps the original class ratio
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(
    indices, test_size=0.2, random_state=42, stratify=dataset.y.numpy()
)
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


from braindecode.models import EEGConformer

n_outputs = len(torch.unique(dataset.y))  # Number of classes
n_chans = 18  # EEG channels per window (dataset.X.shape[1])
n_times = 64  # time samples per window (dataset.X.shape[2])

model = EEGConformer(
    n_outputs=n_outputs,
    n_chans=n_chans,
    n_times=n_times,
    n_filters_time=40,           # Default
    filter_time_length=16,       # Adjusted for 64 time steps
    pool_time_length=8,          # Reduced from default 75
    pool_time_stride=4,          # Reduced from default 15
    att_depth=6,                 # Default
    att_heads=10,                # Default
    drop_prob=0.5,               # Default
    final_fc_length="auto",      # Automatically inferred
    return_features=False,       # Return logits, not features
)

# Initialize model and training components
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train model
(train_losses, val_losses, train_accuracies, val_accuracies, 
 train_labels, val_labels, train_preds, val_preds, 
 train_probs, val_probs) = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)

# Plot curves
plot_loss_curve(train_losses, val_losses, title="Imbalanced EEG Conformer Training and Validation Loss")
plot_accuracy_curve(train_accuracies, val_accuracies, title="Imbalanced EEG Conformer Training and Validation Accuracy")
plot_roc_curve(val_labels, val_probs, title="Imbalanced EEG Conformer Validation ROC Curve")
plot_precision_recall_curve(val_labels, val_probs, title="Imbalanced EEG Conformer Validation Precision-Recall Curve")
plot_confusion_matrix_and_metrics(val_labels, val_preds, title="Imbalanced EEG Conformer Validation Confusion Matrix")

In [11]:
# Save the model state dict
torch.save(model.state_dict(), "orig_imb_model_state_conformer_v4.pth")

In [None]:
# Load the model by re-creating the architecture, then loading the saved state dict
model = EEGConformer(
    n_outputs=n_outputs,
    n_chans=n_chans,
    n_times=n_times,
    n_filters_time=40,
    filter_time_length=16,
    pool_time_length=8,
    pool_time_stride=4,
    att_depth=6,
    att_heads=10,
    drop_prob=0.5,
    final_fc_length="auto",
    return_features=False,
)
model.load_state_dict(torch.load("orig_imb_model_state_conformer_v4.pth", map_location=device))
model = model.to(device)
model.eval()  # Set to evaluation mode (disables dropout)

In [13]:
# Alternatively, save the entire pickled model (architecture + weights).
# Loading it later requires the same class definitions to be importable.
torch.save(model, "final_orig_imb_model_state_conformer_v4.pth")

In [None]:
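# Load the full pickled model. PyTorch 2.6+ defaults to weights_only=True,
# which rejects pickled modules, so opt out explicitly for files you trust.
model = torch.load("final_orig_imb_model_state_conformer_v4.pth", map_location=device, weights_only=False)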
model.eval()

## SMOTE

This section incorporates **SMOTE (Synthetic Minority Over-sampling Technique)** to address class imbalance in the neonatal EEG classification task using the **EEGConformer** model. The EEG data is loaded and flattened into the 2D format SMOTE requires. SMOTE then generates synthetic minority-class samples by interpolating between existing minority samples and their nearest minority neighbours until the classes are balanced. The resampled data is reshaped back into its original 3D EEG format (samples × channels × time), wrapped in a custom PyTorch `Dataset` class, and split into training and validation sets. The **EEGConformer**, a hybrid model combining convolution and self-attention, is instantiated with parameters suited to 18-channel, 64-timepoint inputs and trained with cross-entropy loss and the Adam optimizer while training and validation performance is monitored over multiple epochs. The visualization functions show training dynamics (loss and accuracy curves), classification performance (ROC and precision-recall curves), and prediction quality (confusion matrix and metrics). Finally, the model is saved and reloaded in evaluation mode. Balancing the classes reduces the model's pull toward the majority class; however, because SMOTE is applied before the train/validation split, the validation set contains synthetic samples interpolated from training data, so the validation metrics here are likely optimistic compared with performance on unseen recordings.
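One way to avoid that optimistic bias is to split first and resample only the training portion. The sketch below illustrates the ordering with NumPy alone; it uses a hypothetical stratified split helper and plain random duplication as a dependency-free stand-in for SMOTE. In the notebook the resampling step would instead be `SMOTE(random_state=42).fit_resample(X_train_flat, y_train)` on the flattened training arrays only, and `train_test_split(..., stratify=y)` can replace the split helper.

```python
import numpy as np

def stratified_split(y, test_frac=0.2, seed=42):
    """Hypothetical helper: shuffle each class separately, then cut off test_frac."""
    rng = np.random.default_rng(seed)
    train_idx, val_idx = [], []
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))
        n_val = int(round(len(idx) * test_frac))
        val_idx.extend(idx[:n_val])
        train_idx.extend(idx[n_val:])
    return np.array(train_idx), np.array(val_idx)

def random_oversample(X, y, seed=42):
    """Stand-in for SMOTE: duplicate random rows of smaller classes until counts match."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    extra = [rng.choice(np.flatnonzero(y == c), counts.max() - n, replace=True)
             for c, n in zip(classes, counts)]
    keep = np.concatenate([np.arange(len(y))] + extra)
    return X[keep], y[keep]

# Toy imbalanced data: 90 negative and 10 positive windows of 18 x 64
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 18, 64)).astype(np.float32)
y = np.array([0] * 90 + [1] * 10)

train_idx, val_idx = stratified_split(y)                    # 1) split the original data
X_tr, y_tr = random_oversample(X[train_idx], y[train_idx])  # 2) resample training only
X_val, y_val = X[val_idx], y[val_idx]                       # 3) validation keeps true ratio
print(np.bincount(y_tr), np.bincount(y_val))                # [72 72] [18  2]
```

With this ordering, every validation window is a real recording segment that the oversampler never saw, so the validation metrics estimate performance under the true class ratio.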


In [25]:
%%capture
!pip install imbalanced-learn

In [None]:
from imblearn.over_sampling import SMOTE
# --- Apply SMOTE ---
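# SMOTE expects a 2D array, so flatten each (channels, time) sample into one row.
# Convert the tensors to NumPy arrays; fit_resample returns NumPy arrays.
n_samples, n_channels, n_timepoints = dataset.X.shape
X_flat = dataset.X.numpy().reshape(n_samples, -1)
y_orig = dataset.y.numpy()

smote = SMOTE(random_state=42)
X_res, y_res = smote.fit_resample(X_flat, y_orig)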
X_res = X_res.reshape((-1, n_channels, n_timepoints))
print("After SMOTE -> X:", X_res.shape, "y:", y_res.shape)

# =====================================
# Visualize Class Distribution After SMOTE
# =====================================
unique_res, counts_res = np.unique(y_res, return_counts=True)
plt.figure()
plt.bar(unique_res, counts_res, color='salmon')
plt.title("Class Distribution After SMOTE")
plt.xlabel("Class")
plt.ylabel("Count")
plt.show()

In [27]:
# Create dataset from the resampled arrays
dataset = EEGDataset(X_res, y_res)

In [None]:
from sklearn.model_selection import train_test_split
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


from braindecode.models import EEGConformer

n_outputs = len(torch.unique(dataset.y))  # Number of classes
n_chans = 18  # EEG channels per window (dataset.X.shape[1])
n_times = 64  # time samples per window (dataset.X.shape[2])

model = EEGConformer(
    n_outputs=n_outputs,
    n_chans=n_chans,
    n_times=n_times,
    n_filters_time=40,           # Default
    filter_time_length=16,       # Adjusted for 64 time steps
    pool_time_length=8,          # Reduced from default 75
    pool_time_stride=4,          # Reduced from default 15
    att_depth=6,                 # Default
    att_heads=10,                # Default
    drop_prob=0.5,               # Default
    final_fc_length="auto",      # Automatically inferred
    return_features=False,       # Return logits, not features
)

# Initialize model and training components
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train model
(train_losses, val_losses, train_accuracies, val_accuracies, 
 train_labels, val_labels, train_preds, val_preds, 
 train_probs, val_probs) = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)

# Plot curves
plot_loss_curve(train_losses, val_losses, title="SMOTE EEG Conformer Training and Validation Loss")
plot_accuracy_curve(train_accuracies, val_accuracies, title="SMOTE EEG Conformer Training and Validation Accuracy")
plot_roc_curve(val_labels, val_probs, title="SMOTE EEG Conformer Validation ROC Curve")
plot_precision_recall_curve(val_labels, val_probs, title="SMOTE EEG Conformer Validation Precision-Recall Curve")
plot_confusion_matrix_and_metrics(val_labels, val_preds, title="SMOTE EEG Conformer Validation Confusion Matrix")


In [43]:
# Save the model state dict
torch.save(model.state_dict(), "smote_best_model_state_conformer_v4.pth")

In [None]:
# Load the model by re-creating the architecture, then loading the saved state dict
model = EEGConformer(
    n_outputs=n_outputs,
    n_chans=n_chans,
    n_times=n_times,
    n_filters_time=40,
    filter_time_length=16,
    pool_time_length=8,
    pool_time_stride=4,
    att_depth=6,
    att_heads=10,
    drop_prob=0.5,
    final_fc_length="auto",
    return_features=False,
)
model.load_state_dict(torch.load("smote_best_model_state_conformer_v4.pth", map_location=device))
model = model.to(device)
model.eval()  # Set to evaluation mode (disables dropout)

In [32]:
# Alternatively, save the entire pickled model (architecture + weights).
# Loading it later requires the same class definitions to be importable.
torch.save(model, "final_smote_best_model_state_conformer_v4.pth")

In [None]:
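# Load the full pickled model. PyTorch 2.6+ defaults to weights_only=True,
# which rejects pickled modules, so opt out explicitly for files you trust.
model = torch.load("final_smote_best_model_state_conformer_v4.pth", map_location=device, weights_only=False)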
model.eval()

## Cost Sensitive Learning

This section implements **cost-sensitive learning (CSL)** to handle class imbalance in the neonatal EEG classification problem using the **EEGConformer** model. The EEG dataset is loaded without resampling, and the original class distribution is visualized. The dataset is then split into training and validation subsets, and data loaders are created for both. Class weights are computed inversely proportional to class frequencies (weight_c = total_samples / (num_classes × count_c)), so that errors on minority classes are penalized more heavily. These weights are passed to the cross-entropy loss, and the EEGConformer model, configured for multi-channel time-series EEG input, is trained with the weighted loss and the Adam optimizer. Training and validation performance are monitored via loss and accuracy curves, ROC and precision-recall curves, and confusion-matrix visualizations. Finally, the trained model's state is saved and reloaded for evaluation. Because the data distribution is left untouched, the validation set reflects the true class ratio, and the class-specific costs are intended to improve the model's sensitivity to the minority class.
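The weighting formula and its effect on the loss can be checked numerically. The sketch below computes weights with the same formula as the code (`total_samples / (num_classes * count)`) and reimplements in NumPy the weighted mean that `nn.CrossEntropyLoss(weight=w)` uses with its default `reduction="mean"`: each sample's negative log-likelihood is scaled by the weight of its true class, and the sum is divided by the sum of those weights rather than by the batch size. Helper names and toy labels are hypothetical; computing the weights from the training labels only is a suggestion, since the notebook uses counts from the full dataset.

```python
import numpy as np

def inverse_frequency_weights(y):
    # weight_c = total_samples / (num_classes * count_c), as in the notebook
    classes, counts = np.unique(y, return_counts=True)
    return len(y) / (len(classes) * counts)

def weighted_cross_entropy(logits, labels, weights):
    # NumPy mirror of weighted cross-entropy with mean reduction:
    # sum_i w[y_i] * (-log p_i[y_i]) / sum_i w[y_i]
    z = logits - logits.max(axis=1, keepdims=True)              # stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]
    w = weights[labels]
    return (w * nll).sum() / w.sum()

y_train = np.array([0] * 90 + [1] * 10)
w = inverse_frequency_weights(y_train)
print(w)              # approx. [0.556, 5.0]: a positive error costs 9x a negative one

logits = np.array([[2.0, -1.0], [0.5, 0.3], [1.0, 1.2]])
labels = np.array([0, 1, 1])
print(weighted_cross_entropy(logits, labels, w))
```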


In [None]:
# Load dataset (update paths as needed)
X_path = "Neontal_eeg_dataset1/annotations/EMS_Normalized_annonated_X_features.npy"
y_path = "Neontal_eeg_dataset1/annotations/EMS_Normalized_annonated_y_features.npy"

dataset = EEGDatasetFromNumpy(X_path, y_path)
print("Dataset X shape:", dataset.X.shape, "y shape:", dataset.y.shape)

# =====================================
# Visualize Class Distribution Before Cost-Sensitive Adjustment
# =====================================
unique, counts = np.unique(dataset.y, return_counts=True)
plt.figure()
plt.bar(unique, counts, color='skyblue')
plt.title("Class Distribution Before Cost-Sensitive Adjustment")
plt.xlabel("Class")
plt.ylabel("Count")
plt.show()

In [None]:
# Create dataset from the original arrays (no resampling)
dataset = EEGDataset(dataset.X, dataset.y)

# Split dataset into training and validation sets
# Stratify so the validation set keeps the original class ratio
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(
    indices, test_size=0.2, random_state=42, stratify=dataset.y.numpy()
)
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset   = torch.utils.data.Subset(dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# ------------------------------------
# Cost-Sensitive Learning: Compute Class Weights
# ------------------------------------
# Compute weights using the formula:
#    weight_i = total_samples / (num_classes * count_i)
total_samples = len(dataset.y)
num_classes = len(unique)
class_weights = [total_samples / (num_classes * count) for count in counts]
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
print("Class Weights:", class_weights)

model = EEGConformer(
    n_outputs=n_outputs,
    n_chans=n_chans,
    n_times=n_times,
    n_filters_time=40,           # Default
    filter_time_length=16,       # Adjusted for 64 time steps
    pool_time_length=8,          # Reduced from default 75
    pool_time_stride=4,          # Reduced from default 15
    att_depth=6,                 # Default
    att_heads=10,                # Default
    drop_prob=0.5,               # Default
    final_fc_length="auto",      # Automatically inferred
    return_features=False,       # Return logits, not features
)

# Initialize model and training components
model = model.to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)  # class weights penalize minority-class errors more heavily
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train model
(train_losses, val_losses, train_accuracies, val_accuracies, 
 train_labels, val_labels, train_preds, val_preds, 
 train_probs, val_probs) = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)

# Plot curves
plot_loss_curve(train_losses, val_losses, title="CSL EEG Conformer Training and Validation Loss")
plot_accuracy_curve(train_accuracies, val_accuracies, title="CSL EEG Conformer Training and Validation Accuracy")
plot_roc_curve(val_labels, val_probs, title="CSL EEG Conformer Validation ROC Curve")
plot_precision_recall_curve(val_labels, val_probs, title="CSL EEG Conformer Validation Precision-Recall Curve")
plot_confusion_matrix_and_metrics(val_labels, val_preds, title="CSL EEG Conformer Validation Confusion Matrix")

In [17]:
# Save the model state dict
torch.save(model.state_dict(), "csl_imb_model_state_conformer_v4.pth")

In [None]:
# Load the model by re-creating the architecture, then loading the saved state dict
model = EEGConformer(
    n_outputs=n_outputs,
    n_chans=n_chans,
    n_times=n_times,
    n_filters_time=40,
    filter_time_length=16,
    pool_time_length=8,
    pool_time_stride=4,
    att_depth=6,
    att_heads=10,
    drop_prob=0.5,
    final_fc_length="auto",
    return_features=False,
)
model.load_state_dict(torch.load("csl_imb_model_state_conformer_v4.pth", map_location=device))
model = model.to(device)
model.eval()  # Set to evaluation mode (disables dropout)

In [19]:
# Alternatively, save the entire pickled model (architecture + weights).
# Loading it later requires the same class definitions to be importable.
torch.save(model, "final_csl_imb_model_state_conformer_v4.pth")

In [None]:
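# Load the full pickled model. PyTorch 2.6+ defaults to weights_only=True,
# which rejects pickled modules, so opt out explicitly for files you trust.
model = torch.load("final_csl_imb_model_state_conformer_v4.pth", map_location=device, weights_only=False)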
model.eval()