**NeoEEGTransformer**

The **NeoEEGTransformer** is a custom transformer-based neural network for seizure classification from multi-channel EEG signals; it is implemented below as the `EEGTransformer` class. Class imbalance poses a significant challenge in EEG data and often causes models to underperform on minority classes. This notebook compares two remedies. SMOTE generates synthetic minority-class samples, which balances the class distribution but may alter the characteristics of the data. Cost-sensitive learning instead uses a class-weighted loss function that penalizes misclassifications of minority classes more heavily, improving recognition without changing the original data. SMOTE therefore works at the data level, while cost-sensitive learning works at the algorithm level. The model itself combines normalization, positional encoding, multi-head self-attention, and feed-forward processing, which suits EEG analysis tasks where capturing complex spatial-temporal relationships is critical.







The model is engineered to capture complex spatial-temporal dynamics inherent in EEG signals through the following key components:

- **Input and Preprocessing:**  
  - **Input Shape:** The model expects EEG data in the shape `(batch_size, num_channels, num_timepoints)`.
  - **Standardization:** Each channel of each sample is z-scored along the time axis by subtracting its mean and dividing by its standard deviation (plus a small epsilon for numerical stability).

- **Positional Encoding:**  
  - A fixed positional encoding is computed using sine and cosine functions across channels and time points.  
  - This encoding is added to the normalized input to inject temporal information into the model, enabling it to distinguish the order of time points.

- **Multi-Head Self-Attention:**  
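  - A multi-head self-attention layer (`nn.MultiheadAttention`) attends across time points, using the channel values at each time point as the embedding vector, so it captures how multi-channel patterns relate over time.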
  - This component allows the model to focus on relevant parts of the signal by computing attention weights.

- **Feed-Forward Network (FFN) and Residual Connections:**  
  - The attention outputs are layer-normalized and passed through a feed-forward network consisting of two linear layers with a ReLU activation in between.
  - A residual connection around the feed-forward network, followed by a second layer normalization, stabilizes training and improves gradient flow.

- **Classifier:**  
  - The processed features are flattened and fed into a final linear layer that acts as the classifier, mapping the features to the desired output classes.

- **Regularization:**  
  - Dropout (p = 0.1) is applied to the attention output to reduce overfitting and improve generalization.
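The preprocessing and positional-encoding steps listed above can be sketched in NumPy to make the shapes explicit. This is a minimal illustration rather than the model code itself: the function names `standardize` and `positional_encoding` and the random input are assumptions made for the example, and the encoding follows the same even/odd channel rule as the `EEGTransformer` class defined below.

```python
import numpy as np

def standardize(x, eps=1e-5):
    """Z-score each channel of each sample along the time axis (axis 2)."""
    mean = x.mean(axis=2, keepdims=True)
    std = x.std(axis=2, ddof=1, keepdims=True)  # ddof=1 mirrors the default of torch.Tensor.std
    return (x - mean) / (std + eps)

def positional_encoding(num_channels=18, num_timepoints=64):
    """Sine on even channel rows, cosine on odd rows, evaluated at every time index."""
    j = np.arange(num_channels)[:, None]        # channel index, shape (C, 1)
    k = np.arange(num_timepoints)[None, :]      # time index, shape (1, T)
    exponent = (j - (j % 2)) / num_channels     # an odd row reuses the frequency of the even row before it
    angle = k / (10000.0 ** exponent)
    return np.where(j % 2 == 0, np.sin(angle), np.cos(angle))

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(4, 18, 64))   # hypothetical batch: (batch, channels, timepoints)
x_tilde = standardize(x) + positional_encoding()        # the encoding broadcasts over the batch axis
print(x_tilde.shape)
```

Because the encoding has shape `(num_channels, num_timepoints)`, it is added identically to every sample in the batch.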

In [1]:
import math
import os
from collections import Counter

import h5py
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    auc,
    classification_report,
    confusion_matrix,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import learning_curve, train_test_split
from torch.utils.data import DataLoader, Dataset, Subset

In [10]:
# EEG Transformer Model
class EEGTransformer(nn.Module):
    def __init__(self, num_channels=18, num_timepoints=64, output_dim=2,
                 num_heads=6, intermediate_dim=128, ffn_output_dim=18):
        super(EEGTransformer, self).__init__()
        
        # Build the positional encoding using math.sin and math.cos
        positional_encoding = torch.zeros(num_channels, num_timepoints)
        for j in range(num_channels):
            for k in range(num_timepoints):
                if j % 2 == 0:
                    positional_encoding[j, k] = math.sin(k / (10000 ** (j / num_channels)))
                else:
                    positional_encoding[j, k] = math.cos(k / (10000 ** ((j - 1) / num_channels)))
                    
        # Register positional_encoding as a buffer so that it's moved to the correct device
        self.register_buffer('positional_encoding', positional_encoding)
        
        self.multihead_attn = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(num_channels, intermediate_dim),
            nn.ReLU(),
            nn.Linear(intermediate_dim, ffn_output_dim)
        )
        self.norm1 = nn.LayerNorm(num_channels)
        self.norm2 = nn.LayerNorm(num_channels)
        self.classifier = nn.Linear(num_channels * num_timepoints, output_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, X):
        # Input shape: (batch_size, num_channels, num_timepoints)
        mean = X.mean(dim=2, keepdim=True)
        std = X.std(dim=2, keepdim=True)
        X_hat = (X - mean) / (std + 1e-5)
        
        X_tilde = X_hat + self.positional_encoding  # Buffer automatically moves with the model
        X_tilde = X_tilde.permute(2, 0, 1)  # (num_timepoints, batch_size, num_channels)
        
        attn_output, _ = self.multihead_attn(X_tilde, X_tilde, X_tilde)
        attn_output = self.dropout(attn_output)
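        # LayerNorm acts on the last (channel) axis, so normalize the whole tensor at once,
        # then switch to batch-first layout: (batch_size, num_timepoints, num_channels)
        X_ring = self.norm1(attn_output).permute(1, 0, 2).contiguous()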
        
        ff_output = self.ffn(X_ring)
        O = self.norm2(ff_output + X_ring)
        O_flat = O.view(O.size(0), -1)
        output = self.classifier(O_flat)
        
        return output
        
# Updated train_model that evaluates training data in eval mode (to disable dropout)
# and also saves the softmax probabilities for ROC/PR calculations.
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    
    # To collect probabilities and labels for further analysis
    train_labels_all, val_labels_all = [], []
    train_preds_all, val_preds_all = [], []
    train_probs_all, val_probs_all = [], []
    
    for epoch in range(num_epochs):
        # --- Training phase (with dropout for weight updates) ---
        model.train()
        running_loss = 0.0
        for batch_data, batch_labels in train_loader:
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * batch_data.size(0)
        
        # --- Evaluate training metrics in eval mode (dropout off) ---
        model.eval()
        correct_train = 0
        running_loss_eval = 0.0
        epoch_train_probs, epoch_train_preds, epoch_train_labels = [], [], []
        with torch.no_grad():
            for batch_data, batch_labels in train_loader:
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                
                outputs = model(batch_data)
                loss = criterion(outputs, batch_labels)
                running_loss_eval += loss.item() * batch_data.size(0)
                
                probs = torch.softmax(outputs, dim=1)[:, 1]
                preds = outputs.argmax(dim=1)
                correct_train += (preds == batch_labels).sum().item()
                
                epoch_train_probs.extend(probs.cpu().numpy())
                epoch_train_preds.extend(preds.cpu().numpy())
                epoch_train_labels.extend(batch_labels.cpu().numpy())
        
        train_loss_epoch = running_loss_eval / len(train_loader.dataset)
        train_accuracy = correct_train / len(train_loader.dataset)
        train_losses.append(train_loss_epoch)
        train_accuracies.append(train_accuracy)
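        # Keep only the latest epoch's outputs so the returned labels, predictions,
        # and probabilities describe the final model rather than all epochs pooled.
        train_labels_all = epoch_train_labels
        train_preds_all = epoch_train_preds
        train_probs_all = epoch_train_probs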
        
        # --- Validation phase ---
        running_loss_val = 0.0
        correct_val = 0
        epoch_val_probs, epoch_val_preds, epoch_val_labels = [], [], []
        with torch.no_grad():
            for batch_data, batch_labels in val_loader:
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                
                outputs = model(batch_data)
                loss = criterion(outputs, batch_labels)
                running_loss_val += loss.item() * batch_data.size(0)
                
                probs = torch.softmax(outputs, dim=1)[:, 1]
                preds = outputs.argmax(dim=1)
                correct_val += (preds == batch_labels).sum().item()
                
                epoch_val_probs.extend(probs.cpu().numpy())
                epoch_val_preds.extend(preds.cpu().numpy())
                epoch_val_labels.extend(batch_labels.cpu().numpy())
        
        val_loss_epoch = running_loss_val / len(val_loader.dataset)
        val_accuracy = correct_val / len(val_loader.dataset)
        val_losses.append(val_loss_epoch)
        val_accuracies.append(val_accuracy)
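        # Keep only the latest epoch's outputs so the ROC, precision-recall, and
        # confusion-matrix plots describe the final model rather than all epochs pooled.
        val_labels_all = epoch_val_labels
        val_preds_all = epoch_val_preds
        val_probs_all = epoch_val_probs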
        
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {train_loss_epoch:.4f}, Train Acc: {train_accuracy:.4f} | "
              f"Val Loss: {val_loss_epoch:.4f}, Val Acc: {val_accuracy:.4f}")
    
    return (train_losses, val_losses, train_accuracies, val_accuracies, 
            train_labels_all, val_labels_all, train_preds_all, val_preds_all, 
            train_probs_all, val_probs_all)

# --- Plotting Functions ---

def plot_roc_curve(true_labels, pred_probs, title="ROC Curve"):
    fpr, tpr, _ = roc_curve(true_labels, pred_probs)
    roc_auc = auc(fpr, tpr)
    
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
    plt.plot([0, 1], [0, 1], color='gray', linestyle='--', lw=1)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc='lower right')
    plt.grid(alpha=0.5)
    plt.savefig("CustomTransormer_ROC.png", dpi=300)
    plt.show()

def plot_precision_recall_curve(true_labels, pred_probs, title="Precision-Recall Curve"):
    precision, recall, _ = precision_recall_curve(true_labels, pred_probs)
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, color='b', lw=2)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(title)
    plt.savefig("CustomTransormer_Precision-RecallCurve.png", dpi=300)
    plt.show()

def plot_confusion_matrix_and_metrics(true_labels, preds, title="Confusion Matrix"):
    cm = confusion_matrix(true_labels, preds)
    
    plt.figure(figsize=(6, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                xticklabels=['Negative', 'Positive'], yticklabels=['Negative', 'Positive'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(title)
    plt.savefig("CustomConfusionMatrix.png", dpi=300)
    plt.show()
    
    metrics = {}
    tn, fp, fn, tp = cm.ravel()
    metrics["Accuracy"] = (tp + tn) / (tp + tn + fp + fn)
    metrics["Precision"] = tp / (tp + fp) if (tp + fp) != 0 else 0
    metrics["Recall"] = tp / (tp + fn) if (tp + fn) != 0 else 0
    metrics["F1-Score"] = 2 * (metrics["Precision"] * metrics["Recall"]) / (metrics["Precision"] + metrics["Recall"]) if (metrics["Precision"] + metrics["Recall"]) != 0 else 0
    
    print("\nModel Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")
    
    print("\nClassification Report:")
    print(classification_report(true_labels, preds, target_names=['Negative', 'Positive']))

def plot_loss_curve(train_losses, val_losses, title="Training and Validation Loss"):
    plt.figure(figsize=(8, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.legend()
    plt.savefig("CustomTrainingandValidationLoss.png", dpi=300)
    plt.show()

def plot_accuracy_curve(train_accuracies, val_accuracies, title="Training and Validation Accuracy"):
    plt.figure(figsize=(8, 6))
    plt.plot(train_accuracies, label='Training Accuracy')
    plt.plot(val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title(title)
    plt.legend()
    plt.savefig("CustomTrainingandValidationAccuracy.png", dpi=300)
    plt.show()

# --- Dataset Definition ---
class EEGDatasetFromNumpy(Dataset):
    def __init__(self, X_path, y_path):
        # Load the data directly from .npy files
        self.X = np.load(X_path)
        self.y = np.load(y_path)
        
        # Ensure matching number of samples
        assert len(self.X) == len(self.y), "Mismatch between number of samples in X and y"
        
        # Convert to tensors
        self.X = torch.tensor(self.X, dtype=torch.float32)
        self.y = torch.tensor(self.y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]



## Imbalanced Training Pipeline 

Class imbalance is a common challenge in EEG classification tasks where certain brain state classes are underrepresented compared to others. This imbalance can bias a model toward the majority class, leading to poor predictive performance on minority classes and overall reduced clinical utility. Properly addressing class imbalance is critical to developing reliable models that generalize well across all categories.
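A small worked example shows why accuracy alone can hide this bias. The label counts below are hypothetical and chosen only for illustration: a degenerate model that always predicts the majority class scores high accuracy while never detecting the minority class.

```python
# Hypothetical validation set: 950 majority-class (0) and 50 minority-class (1) segments.
labels = [0] * 950 + [1] * 50
preds = [0] * len(labels)  # a degenerate model that always predicts the majority class

accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)
true_pos = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
recall = true_pos / sum(y == 1 for y in labels)

print(f"accuracy={accuracy:.2f}, minority recall={recall:.2f}")
# accuracy=0.95, minority recall=0.00
```

This is why the pipeline below reports recall, precision, F1-score, and ROC and precision-recall curves in addition to accuracy.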

This code provides an end-to-end pipeline for training and evaluating a Transformer-based EEG classification model with strong emphasis on performance visualization and detailed metrics reporting.


**1. Imports and Setup**

* Uses PyTorch for deep learning, NumPy for array handling, and scikit-learn for metrics and data splitting.
* Matplotlib and Seaborn are used for visualization.

**2. Custom EEGTransformer Model**

* Implements a Transformer-based neural network for EEG data classification.
* Includes positional encoding based on sine and cosine functions for temporal context.
* Uses multi-head self-attention, feedforward layers, layer normalization, dropout, and a final linear classifier.
* Input shape: `(batch_size, channels, timepoints)`.

**3. Dataset Handling**

* `EEGDatasetFromNumpy` loads EEG features and labels from `.npy` files, converting them into PyTorch tensors.
* Data is split into training and validation subsets using `train_test_split`.

**4. Training Loop (`train_model`)**

* Alternates between training and evaluation mode in every epoch so that dropout is active for weight updates and disabled when metrics are computed.
* Collects losses, accuracies, predicted labels, and prediction probabilities for both train and validation sets.
* Uses cross-entropy loss and Adam optimizer.

**5. Evaluation and Visualization**

* Plots training/validation loss and accuracy curves over epochs.
* Generates ROC curves, precision-recall curves, and confusion matrices to assess model performance.
* Prints detailed classification reports and key metrics (accuracy, precision, recall, F1-score).

**6. Model Saving and Inference**

* Saves the trained model’s state dictionary.
* Demonstrates how to reload the model and set it to evaluation mode for inference.

In [None]:
# --- Training Setup ---
# Hyperparameters used by the DataLoaders, the optimizer, and the training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
num_epochs = 100
learning_rate = 0.001

# Load dataset (update paths as needed)
X_path = "Neontal_eeg_dataset1/annotations/Normalized_Updated_annonated_X_features.npy"
y_path = "Neontal_eeg_dataset1/annotations/Normalized_Updated_annonated_y_features.npy"

dataset = EEGDatasetFromNumpy(X_path, y_path)
print("Dataset X shape:", dataset.X.shape, "y shape:", dataset.y.shape)

from sklearn.model_selection import train_test_split
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize model and training components
model = EEGTransformer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train model
(train_losses, val_losses, train_accuracies, val_accuracies, 
 train_labels, val_labels, train_preds, val_preds, 
 train_probs, val_probs) = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)

# Plot curves
plot_loss_curve(train_losses, val_losses, title="Training and Validation Loss")
plot_accuracy_curve(train_accuracies, val_accuracies, title="Training and Validation Accuracy")
plot_roc_curve(val_labels, val_probs, title="Validation ROC Curve")
plot_precision_recall_curve(val_labels, val_probs, title="Validation Precision-Recall Curve")
plot_confusion_matrix_and_metrics(val_labels, val_preds, title="Validation Confusion Matrix")


In [4]:
# Save the model state dictionary after training
torch.save(model.state_dict(), 'custom_neonatal_eeg_transformer.pth')

In [None]:
# Inference
# Recreate the model architecture (the arguments must match those used for training)
model = EEGTransformer(num_channels=18, num_timepoints=64, output_dim=2,
                       num_heads=6, intermediate_dim=128, ffn_output_dim=18)

# Load the saved weights, mapping them onto the device used for inference
model.load_state_dict(torch.load('custom_neonatal_eeg_transformer.pth', map_location=device))
model.to(device)

# Set the model to evaluation mode
model.eval()


**Device Management**

- If you trained your model on a GPU, ensure that during inference the model and the data are moved to the correct device (using `model.to(device)` and `input_data.to(device)`).

**Batch Inference**

- The `predict` function defined below handles a batch of samples. For a single sample, the input shape should be `(1, 18, 64)`.

**Data Preprocessing**

- Apply to new input data any preprocessing (normalization, scaling, etc.) that was applied during training.
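As a brief sketch of the single-sample case, a lone segment of shape `(18, 64)` needs a leading batch axis before it can be passed to the model. The random `segment` array is a placeholder assumed for illustration; real data would come from the same preprocessing used for training.

```python
import numpy as np

# Placeholder for one preprocessed EEG segment with shape (channels, timepoints)
segment = np.random.default_rng(0).normal(size=(18, 64)).astype(np.float32)

batch = segment[np.newaxis, ...]   # add the leading batch axis
print(batch.shape)                 # (1, 18, 64)

# torch.from_numpy(batch) would then give a float32 tensor in the expected layout
```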

In [None]:
# Inference helper
def predict(input_data):
    """
    Perform inference on input_data using the loaded model.
    
    Parameters:
        input_data (torch.Tensor): Input tensor of shape (batch_size, num_channels, num_timepoints)
        
    Returns:
        predictions (torch.Tensor): Predicted class indices.
        probabilities (torch.Tensor): Softmax probabilities for each class.
    """
    model.eval()  # Ensure the model is in eval mode
    with torch.no_grad():
        # Move the input to the same device as the model
        input_data = input_data.to(device)
        outputs = model(input_data)
        
        # Compute softmax probabilities
        probabilities = torch.softmax(outputs, dim=1)
        
        # Get the predicted classes (for example, using argmax)
        predictions = torch.argmax(probabilities, dim=1)
    
    return predictions, probabilities

# Example usage:
# Suppose you have a single sample with shape (1, 18, 64)
sample_input = torch.randn(1, 18, 64)  # Replace with your actual data sample

# Run prediction
predicted_class, pred_probs = predict(sample_input)
print("Predicted class:", predicted_class.item())
print("Prediction probabilities:", pred_probs)

## SMOTE

One popular approach to handle imbalance is SMOTE (Synthetic Minority Over-sampling Technique), which artificially increases the number of minority class samples by generating synthetic data points in feature space. In the EEG context, the high-dimensional time-series data is flattened to apply SMOTE, then reshaped back to the original dimensions. This augmentation balances class distributions and allows the model to learn more robust decision boundaries. However, SMOTE modifies the data distribution, which may introduce noise or reduce interpretability.

This section balances the EEG dataset with SMOTE before training the Transformer, which improves the handling of class imbalance and can improve classification performance. `smote.fit_resample(X_flat, y)` first fits SMOTE to the data, identifying the minority class and the nearest neighbours of its samples in feature space. It then resamples the dataset by creating synthetic minority-class samples so that the overall dataset becomes more balanced.
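The core idea behind the synthetic samples can be sketched in a few lines of NumPy. This is a simplified illustration of the interpolation step only, not the imbalanced-learn implementation; the function name `smote_like_sample`, the choice `k=5`, and the random minority matrix are assumptions made for the example.

```python
import numpy as np

def smote_like_sample(minority, k=5, rng=None):
    """Create one synthetic row by interpolating between a random minority row
    and one of its k nearest minority-class neighbours."""
    rng = np.random.default_rng() if rng is None else rng
    i = rng.integers(len(minority))
    dists = np.linalg.norm(minority - minority[i], axis=1)
    neighbours = np.argsort(dists)[1:k + 1]    # position 0 is the row itself (distance 0)
    j = rng.choice(neighbours)
    lam = rng.random()                         # interpolation factor in [0, 1)
    return minority[i] + lam * (minority[j] - minority[i])

rng = np.random.default_rng(42)
minority = rng.normal(size=(20, 18 * 64))      # 20 flattened minority-class segments
synthetic = smote_like_sample(minority, rng=rng)
print(synthetic.shape)
```

Each synthetic row lies on the line segment between two real minority rows, which is why the technique can blur fine temporal structure when applied to flattened EEG segments.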


**1. Imports and Setup**

* Uses PyTorch, NumPy, Matplotlib, Seaborn for model building, data handling, and visualization.
* Incorporates `SMOTE` from imbalanced-learn to address class imbalance.

**2. Dataset Preparation**

* Original EEG dataset is loaded as a 3D array `(samples, channels, timepoints)` with associated labels.
* Flatten each sample into 2D `(samples, features)` to apply SMOTE, which synthesizes new minority class samples.
* Reshape the oversampled data back to original 3D EEG format for model input.
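The flatten-and-restore round trip can be checked on a small stand-in array. The array contents here are arbitrary and assumed only for illustration; the point is that reshaping is lossless and that flat feature index `c * n_timepoints + t` corresponds to channel `c` at time `t`.

```python
import numpy as np

# Stand-in for EEG data with shape (samples, channels, timepoints)
X = np.arange(3 * 18 * 64, dtype=np.float32).reshape(3, 18, 64)
n_samples, n_channels, n_timepoints = X.shape

X_flat = X.reshape(n_samples, -1)                        # (3, 1152): one row per sample
X_back = X_flat.reshape(-1, n_channels, n_timepoints)    # back to (3, 18, 64)

c, t = divmod(70, n_timepoints)                          # flat feature 70 -> channel 1, time 6
print(X_flat.shape, X_back.shape, (c, t))
print(np.array_equal(X, X_back), X_flat[0, 70] == X[0, c, t])
```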

**3. Visualization of Class Distribution**

* Plots class distribution before and after SMOTE to demonstrate balancing effectiveness.
* Helps visually confirm that minority class samples are augmented properly.

**4. Custom Dataset Class**

* `EEGDataset` accepts data arrays directly (post-SMOTE) and converts them to PyTorch tensors.
* Provides `__len__` and `__getitem__` for DataLoader compatibility.

**5. Data Splitting and Loading**

* Splits the balanced dataset into training and validation sets using `train_test_split`.
* Creates DataLoaders for batching during training and validation.
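One variation worth considering is to split before oversampling, so that the validation set contains only real recordings and no synthetic sample is derived from a validation row. The sketch below is an assumption-laden illustration of that ordering, not the notebook's pipeline: the data are random, the split is stratified by hand, and plain random duplication stands in for SMOTE so that the example needs only NumPy.

```python
import numpy as np

rng = np.random.default_rng(42)
y = np.array([0] * 90 + [1] * 10)             # hypothetical imbalanced labels
X = rng.normal(size=(len(y), 18 * 64))        # flattened stand-in features

# 1) Split first (20% of each class goes to validation).
val_idx = np.concatenate([
    rng.permutation(np.flatnonzero(y == c))[: (y == c).sum() // 5] for c in (0, 1)
])
train_idx = np.setdiff1d(np.arange(len(y)), val_idx)
X_train, y_train = X[train_idx], y[train_idx]

# 2) Balance the training split only. Random duplication is a stand-in here;
#    with imbalanced-learn installed, SMOTE would be fitted on (X_train, y_train) instead.
minority_rows = np.flatnonzero(y_train == 1)
n_extra = int((y_train == 0).sum() - (y_train == 1).sum())
extra = rng.choice(minority_rows, size=n_extra, replace=True)
X_bal = np.concatenate([X_train, X_train[extra]])
y_bal = np.concatenate([y_train, y_train[extra]])

print(np.bincount(y_bal), np.bincount(y[val_idx], minlength=2))
```

With this ordering the validation metrics reflect the original class distribution, at the cost of a validation set that remains imbalanced.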

**6. Model Architecture: EEGTransformer**

* Transformer-based model designed for EEG data with positional encoding, multi-head attention, feedforward layers, normalization, dropout, and a linear classification head.
* Input shape: `(batch_size, channels, timepoints)`.

**7. Training Loop**

* Implements training and evaluation phases per epoch, tracking loss, accuracy, predictions, and probabilities.
* Uses cross-entropy loss and Adam optimizer.

**8. Evaluation and Visualization**

* Generates performance plots:

  * Training & validation loss and accuracy curves.
  * ROC curve and Precision-Recall curve for validation.
  * Confusion matrix with detailed classification metrics.

**9. Model Saving**

* Saves the trained model weights after training completes, allowing for future inference or fine-tuning.

In [None]:
%%capture
!pip install imbalanced-learn

In [None]:
# Load dataset (update paths as needed)
X_path = "Neontal_eeg_dataset1/annotations/Normalized_Updated_annonated_X_features.npy"
y_path = "Neontal_eeg_dataset1/annotations/Normalized_Updated_annonated_y_features.npy"

dataset = EEGDatasetFromNumpy(X_path, y_path)
print("Dataset X shape:", dataset.X.shape, "y shape:", dataset.y.shape)

In [23]:
# =====================================
# Define a Dataset Class that accepts arrays directly
# =====================================
class EEGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [None]:
# Class distribution of the original (not yet resampled) dataset
unique, counts = np.unique(dataset.y.numpy(), return_counts=True)
plt.figure()
plt.bar(unique, counts, color='skyblue')
plt.title("Class Distribution Before SMOTE")
plt.xlabel("Class")
plt.ylabel("Count")
plt.show()

In [None]:
# --- Apply SMOTE ---
from imblearn.over_sampling import SMOTE

# SMOTE expects a 2D NumPy array, so convert the tensors and flatten each sample:
n_samples, n_channels, n_timepoints = dataset.X.shape
X_flat = dataset.X.numpy().reshape(n_samples, -1)

smote = SMOTE(random_state=42)
X_res, y_res = smote.fit_resample(X_flat, dataset.y.numpy())
X_res = X_res.reshape((-1, n_channels, n_timepoints))
print("After SMOTE -> X:", X_res.shape, "y:", y_res.shape)

# =====================================
# Visualize Class Distribution After SMOTE
# =====================================
unique_res, counts_res = np.unique(y_res, return_counts=True)
plt.figure()
plt.bar(unique_res, counts_res, color='salmon')
plt.title("Class Distribution After SMOTE")
plt.xlabel("Class")
plt.ylabel("Count")
plt.show()

In [27]:

# Create dataset from the resampled arrays
dataset = EEGDataset(X_res, y_res)

In [None]:
# --- Training Setup ---
# Hyperparameters used by the DataLoaders, the optimizer, and the training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
num_epochs = 100
learning_rate = 0.001


from sklearn.model_selection import train_test_split
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize model and training components
model = EEGTransformer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train model
(train_losses, val_losses, train_accuracies, val_accuracies, 
 train_labels, val_labels, train_preds, val_preds, 
 train_probs, val_probs) = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)

# Plot curves
plot_loss_curve(train_losses, val_losses, title="Training and Validation Loss")
plot_accuracy_curve(train_accuracies, val_accuracies, title="Training and Validation Accuracy")
plot_roc_curve(val_labels, val_probs, title="Validation ROC Curve")
plot_precision_recall_curve(val_labels, val_probs, title="Validation Precision-Recall Curve")
plot_confusion_matrix_and_metrics(val_labels, val_preds, title="Validation Confusion Matrix")


In [29]:
# Save the model state dictionary after training
torch.save(model.state_dict(), 'smote_custom_neonatal_eeg_transformer.pth')

## Cost-Sensitive Learning
An alternative strategy is cost-sensitive learning, where the model is trained using weighted loss functions that assign higher penalties to misclassifications of minority classes without altering the original data. Class weights are computed inversely proportional to class frequencies, making the model “pay more attention” to underrepresented classes during training. This method preserves the dataset’s natural distribution and can improve minority class recognition while avoiding the risks associated with synthetic sample generation.


This section adapts the EEG Transformer pipeline to use **Cost-Sensitive Learning** instead of SMOTE. The steps are summarized below:

**1. Imports and Setup**

* Utilizes PyTorch, NumPy, Matplotlib, Seaborn for deep learning and visualization.
* No external resampling method; relies on class weighting during training to handle imbalance.

**2. Dataset Preparation**

* Creates the `EEGDataset` directly from the original EEG arrays `(samples, channels, timepoints)` and labels, **without any resampling or augmentation**.
* This preserves the original data distribution.

**3. Data Splitting and Loading**

* Splits the dataset into training and validation sets using `train_test_split`.
* Creates DataLoaders for efficient batching during training and validation.

**4. Cost-Sensitive Learning Setup**

* Calculates class weights to penalize the loss function for underrepresented classes more heavily.
* Formula used:

  $$\text{weight}_i = \frac{\text{total samples}}{\text{number of classes} \times \text{count}_i}$$

* These weights are passed to `nn.CrossEntropyLoss`, so misclassifications of the minority class are penalized more heavily during training.

* Class weights are converted to tensors and moved to the training device (CPU/GPU).
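A short worked example makes the effect of the weights concrete. The label counts are hypothetical, and `weighted_nll` is an illustrative helper showing only the per-sample term of a weighted cross-entropy, not the full batch reduction that `nn.CrossEntropyLoss` performs.

```python
import math
from collections import Counter

labels = [0] * 900 + [1] * 100                 # hypothetical imbalanced label list
counts = Counter(labels)
num_classes = len(counts)

# weight_i = total_samples / (num_classes * count_i)
weights = {c: len(labels) / (num_classes * n) for c, n in counts.items()}
print(weights)                                 # class 0 gets about 0.556, class 1 gets 5.0

def weighted_nll(prob_true_class, label):
    """Per-sample weighted cross-entropy term: -w[label] * log(p[label])."""
    return -weights[label] * math.log(prob_true_class)

# Same model confidence (0.3 on the true class), different true classes:
loss_majority = weighted_nll(0.3, 0)
loss_minority = weighted_nll(0.3, 1)
print(loss_minority / loss_majority)           # the minority error costs about 9 times more
```

With a 9:1 class ratio the weights differ by the same factor of 9, so an equally confident mistake on the minority class contributes nine times as much to the loss.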

**5. Model Architecture: EEGTransformer**

* Same Transformer-based architecture tailored for EEG input with positional encoding, multi-head attention, normalization, dropout, and a classification head.
* Input dimensions match the dataset shape.

**6. Training Loop with Weighted Loss**

* Uses `CrossEntropyLoss` with the computed class weights for **cost-sensitive training** to mitigate class imbalance impact.
* Tracks training and validation losses, accuracies, and prediction probabilities.

**7. Evaluation and Visualization**

* Produces comprehensive plots to assess performance:

  * Training & validation loss and accuracy curves.
  * ROC curve and Precision-Recall curve on validation data.
  * Confusion matrix with precision, recall, F1-score, and accuracy metrics.

In [None]:
# Load dataset (update paths as needed)
X_path = "Neontal_eeg_dataset1/annotations/Normalized_Updated_annonated_X_features.npy"
y_path = "Neontal_eeg_dataset1/annotations/Normalized_Updated_annonated_y_features.npy"

dataset = EEGDatasetFromNumpy(X_path, y_path)
print("Dataset X shape:", dataset.X.shape, "y shape:", dataset.y.shape)

# =====================================
# Visualize Class Distribution Before Cost-Sensitive Adjustment
# =====================================
unique, counts = np.unique(dataset.y, return_counts=True)
plt.figure()
plt.bar(unique, counts, color='skyblue')
plt.title("Class Distribution Before Cost-Sensitive Adjustment")
plt.xlabel("Class")
plt.ylabel("Count")
plt.show()

In [None]:
# Create dataset from the original arrays (no resampling)
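# dataset.X and dataset.y are already tensors; pass NumPy views so EEGDataset
# does not re-wrap tensors with torch.tensor (which emits a copy-construct warning)
dataset = EEGDataset(dataset.X.numpy(), dataset.y.numpy())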

# Split dataset into training and validation sets
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset   = torch.utils.data.Subset(dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# ------------------------------------
# Cost-Sensitive Learning: Compute Class Weights
# ------------------------------------
# Compute weights using the formula:
#    weight_i = total_samples / (num_classes * count_i)
total_samples = len(dataset.y)
num_classes = len(unique)
class_weights = [total_samples / (num_classes * count) for count in counts]
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
print("Class Weights:", class_weights)

# Initialize model and training components
model = EEGTransformer(num_channels=dataset.X.shape[1], num_timepoints=dataset.X.shape[2]).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model using cost-sensitive learning
(train_losses, val_losses, train_accuracies, val_accuracies, 
 train_labels, val_labels, train_preds, val_preds, 
 train_probs, val_probs) = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)

# Plot performance curves
plot_loss_curve(train_losses, val_losses, title="Training and Validation Loss")
plot_accuracy_curve(train_accuracies, val_accuracies, title="Training and Validation Accuracy")
plot_roc_curve(val_labels, val_probs, title="Validation ROC Curve")
plot_precision_recall_curve(val_labels, val_probs, title="Validation Precision-Recall Curve")
plot_confusion_matrix_and_metrics(val_labels, val_preds, title="Validation Confusion Matrix")

In [35]:
# Save the model state dictionary after training
torch.save(model.state_dict(), 'costsenstive_custom_neonatal_eeg_transformer.pth')