# <b> Detecting Seizure Activity in EEG Using Spiking Neural Networks with the CHB-MIT Dataset </b>

##### This tutorial provides a step-by-step demonstration of using Spiking Neural Networks (SNNs) for seizure detection in EEG data. The dataset, preprocessing steps, SNN model design, and evaluation metrics are discussed comprehensively.

### Author: Kai Juarez-Jimenez

--- 


## <b> I. Introduction </b>
### <b>Objective </b>
#### <b> Importance of Seizure Detection in EEG Data:</b> Seizure detection plays a vital role in advancing research and innovation aimed at developing solutions for epilepsy, such as wearable devices or neuromorphic chips. It also supports the accurate diagnosis of epilepsy and aids in identifying the precise location of seizure activity, which is critical for effective treatment planning [2].

#### <b> How Do SNNs Address This?: </b> Spiking Neural Networks (SNNs) leverage their ability to process temporal information and replicate firing patterns similar to actual neurons. This capability enables them to effectively identify changes in brain activity associated with seizures, providing a more biologically inspired approach to analysis.

#### <b> Tutorial's Aim: </b> This tutorial demonstrates the end-to-end process of:
#### 1. Preprocessing EEG data from the CHB-MIT dataset.
#### 2. Training an SNN model tailored to detect seizures.
#### 3. Evaluating model performance using metrics like AUC-ROC and confusion matrices.


---

## <b> II. Data and Preprocessing </b>
### <b> Dataset Overview </b>
#### <b> About the CHB-MIT Dataset:</b> The CHB-MIT dataset comprises EEG recordings from 23 patients who experienced seizures, with a total of 24 cases, including one patient with two recordings taken 1.5 years apart. It contains 969 hours of scalp EEG data capturing 173 seizure events. The patients range in age from 3 to 22 years old [1].

#### <b> Channels: </b> Each EEG recording contains data from multiple channels (23 channels by default in CHB-MIT), representing different electrodes placed on the scalp. Each channel records the electrical activity at a specific location on the brain [1].

#### <b> Data: </b> The raw EEG signals are stored as continuous time-series data, often in digital formats such as .edf, a standard for storing biomedical signals. 

#### <b> Frequency: </b> The CHB-MIT dataset has a fixed sampling rate of 256 Hz, meaning that 256 data points are recorded per second for each channel. This high sampling rate is sufficient to capture fast-changing brain activity, including seizure-related patterns [1]. 
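
#### As a quick sanity check on these numbers, the sketch below works out how many samples a recording holds and the highest frequency a 256 Hz sampling rate can represent (the Nyquist limit). The one-hour duration is an illustrative assumption; individual CHB-MIT files vary in length.

```python
fs = 256                 # CHB-MIT sampling rate in Hz
n_channels = 23          # typical channel count per recording
duration_s = 60 * 60     # assumption: a one-hour recording

samples_per_channel = fs * duration_s
total_samples = samples_per_channel * n_channels
nyquist_hz = fs / 2      # highest frequency the sampling rate can represent

print(samples_per_channel)  # 921600
print(total_samples)        # 21196800
print(nyquist_hz)           # 128.0
```

#### The 128 Hz limit sits well above the 30-50 Hz gamma band used later in this tutorial.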

<div class = "alert alert-block alert-success" > 
    Before starting, download the dataset, which contains recordings across 23 channels. Note that the full dataset is approximately 43 GB in size, so ensure you have sufficient storage available. Dataset download available at: https://physionet.org/content/chbmit/1.0.0/#files-panel
</div>

---

### <b> Preprocessing </b>
#### Before training the model, the raw EEG data requires preprocessing to extract meaningful features.
#### <b> EEG Segmentations: </b> To analyze EEG recordings effectively, we segment the data into smaller time windows. This step ensures localized analysis of the signal and facilitates applying frequency-domain transformations. Segmenting the data also simplifies computational processing, making it feasible to apply machine learning techniques.
#### <b> Fourier Transformation: </b> Fourier Transformation is applied to each segment to convert the signal from the time domain to the frequency domain.
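
#### The segmentation step can be sketched with NumPy as follows. This is a minimal illustration on a synthetic signal, assuming non-overlapping one-second windows; the helper name segment_signal is hypothetical.

```python
import numpy as np

def segment_signal(signal, window_size=256):
    """Split a 1-D signal into non-overlapping windows, dropping leftover samples."""
    n_windows = len(signal) // window_size
    trimmed = signal[: n_windows * window_size]
    return trimmed.reshape(n_windows, window_size)

fs = 256
# Synthetic stand-in for one EEG channel: 10 full seconds plus a partial window.
signal = np.arange(10 * fs + 100, dtype=float)
windows = segment_signal(signal, window_size=fs)

print(windows.shape)  # (10, 256)
```

#### Each row is one window, so a frequency transform can then be applied row by row.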


<div class = "alert alert-block alert-success" > 
    <b> Why Fourier Transformation? </b>
    It converts the raw EEG data into the frequency domain, making it possible to identify specific frequency bands (such as gamma, 30-50 Hz) whose power increases significantly during a seizure. This band power can be used as a key indicator for seizure detection [4].
</div>
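
#### To make the idea of band power concrete, here is a small sketch on synthetic sine waves. It uses a plain FFT (np.fft.rfft) rather than the short-time Fourier transform, and the helper name band_power and the test signals are assumptions made for illustration.

```python
import numpy as np

def band_power(segment, fs=256, low=30.0, high=50.0):
    """Mean spectral magnitude of one segment within [low, high] Hz."""
    freqs = np.fft.rfftfreq(len(segment), d=1.0 / fs)
    magnitudes = np.abs(np.fft.rfft(segment)) / len(segment)
    in_band = (freqs >= low) & (freqs <= high)
    return magnitudes[in_band].mean()

fs = 256
t = np.arange(fs) / fs                    # one second of sample times
slow_wave = np.sin(2 * np.pi * 10 * t)    # synthetic 10 Hz oscillation
gamma_wave = np.sin(2 * np.pi * 40 * t)   # synthetic 40 Hz oscillation

print(band_power(gamma_wave, fs) > band_power(slow_wave, fs))  # True
```

#### Only the 40 Hz signal contributes energy inside the 30-50 Hz band, which is the property the gamma-band feature relies on.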


---

#### This code processes EEG signals stored in EDF files to extract meaningful features that can be used for seizure detection. It focuses on the gamma-band power (30–50 Hz), which is a known indicator of seizure activity. The extracted features are saved incrementally to a CSV file for easy access and analysis.

In [122]:
import numpy as np
import pandas as pd
import os
from scipy.signal import stft
import pyedflib
import glob

def segment_and_transform(signal, fs=256, window_size=256):
    """segmenting signal into smaller windows and apply fourier transformation."""
    segments = [signal[i:i+window_size] for i in range(0, len(signal), window_size) if i+window_size <= len(signal)]
    features = []
    for segment in segments:
        f, _, Zxx = stft(segment, fs=fs, nperseg=window_size)
        gamma_band = (f >= 30) & (f <= 50)
        gamma_power = np.abs(Zxx[gamma_band]).mean()
        features.append(gamma_power)
    return np.array(features).mean()  # aggregate features

def preprocess_eeg(filepath, fs=256, window_size=256):
    """preprocess all channels in a file."""
    try:
        print(f"Processing: {filepath}")
        with pyedflib.EdfReader(filepath) as f:
            n = f.signals_in_file
            signals = np.zeros((n, f.getNSamples()[0]))
            for i in range(n):
                signals[i, :] = f.readSignal(i)
        features = [segment_and_transform(channel, fs, window_size) for channel in signals]
        # ensure exactly 23 features
        if len(features) > 23:
            features = features[:23]  # truncate to 23
        elif len(features) < 23:
            features += [0] * (23 - len(features))  # pad with zeros
        return np.array(features)
    except Exception as e:
        print(f"error processing file {filepath}: {e}")
        return None


# dir containing the dataset
data_dir = './CHBMIT'
file_paths = glob.glob(f'{data_dir}/**/*.edf', recursive=True)

# csv file for saving features
output_file = 'preprocessed_features.csv'

# initialize or load existing data
if os.path.exists(output_file):
    print(f"resuming from existing file: {output_file}")
    processed_files = [os.path.abspath(p) for p in pd.read_csv(output_file)['file_path'].tolist()]
    all_features = pd.read_csv(output_file)
else:
    processed_files = []
    all_features = pd.DataFrame(columns=['file_path'] + [f'feature_{i+1}' for i in range(23)])

# preprocess and save incrementally
for file_path in file_paths:
    if os.path.abspath(file_path) in processed_files:
        #print(f"skipping already processed file: {file_path}")
        continue
    print(f"processing file: {file_path}")
    features = preprocess_eeg(file_path)  # preprocess file
    if features is None or len(features) == 0:
        print(f"skipping file due to invalid or empty features: {file_path}")
        continue
    feature_row = [os.path.abspath(file_path)] + features.tolist()
    new_row = pd.DataFrame([feature_row], columns=all_features.columns)
    all_features = pd.concat([all_features, new_row], ignore_index=True)  # append new row
    all_features.to_csv(output_file, index=False)  # save incrementally
    print(f"saved features for file: {file_path}")

#  (verify) the saved features
print(f"shape of the dataset features: {all_features.shape}")

# preprocessed_data = pd.read_csv('preprocessed_features.csv')
# print(preprocessed_data.head())


resuming from existing file: preprocessed_features.csv
shape of the dataset features: (686, 24)


---

### <b> Load and Inspect the Preprocessed Features </b>
#### The preprocessed features stored in the preprocessed_features.csv file contain critical information extracted from EEG signals, such as gamma-band power for each channel. Loading this file allows us to verify the data structure and ensure it aligns with our modeling requirements.

<div class="alert alert-block alert-success"> 
    <b>Why Is This Important?</b><br>
    <b>Data Integrity:</b> Verifying the shape and structure ensures that the preprocessing pipeline worked as intended and the dataset contains the expected number of features for each EEG file.<br>
    <b>Preparation for Merging:</b> This inspection is a crucial step before merging the features with labels, ensuring compatibility in downstream tasks like training and evaluation.
</div>


In [123]:
import pandas as pd

# load preprocessed features
data_file = 'preprocessed_features.csv'
data = pd.read_csv(data_file)
print(f"dataset shape: {data.shape}")  # should be (686, 24) rows, col
print(data.head())  # inspect the first few rows


dataset shape: (686, 24)
                                   file_path  feature_1  feature_2  feature_3  \
0  /Users/kjuarezj/CHBMIT/chb11/chb11_53.edf   0.335455   0.425711   0.319864   
1  /Users/kjuarezj/CHBMIT/chb11/chb11_92.edf   1.194302   1.326720   1.355358   
2  /Users/kjuarezj/CHBMIT/chb11/chb11_82.edf   0.948020   1.374275   1.305503   
3  /Users/kjuarezj/CHBMIT/chb11/chb11_55.edf   0.822861   1.004402   0.914919   
4  /Users/kjuarezj/CHBMIT/chb11/chb11_54.edf   0.302837   0.320896   0.251632   

   feature_4  feature_5  feature_6  feature_7  feature_8  feature_9  ...  \
0   0.467975   0.001104   0.390983   0.483227   0.351886   0.382416  ...   
1   0.914141   0.001104   0.915612   0.531348   0.400397   0.460694  ...   
2   0.501360   0.001104   0.826069   0.363019   0.249405   0.309681  ...   
3   0.607224   0.001104   0.780986   0.471013   0.413295   0.492717  ...   
4   0.326476   0.001104   0.345955   0.339471   0.260942   0.309337  ...   

   feature_14  feature_15  feat

---

### <b> Data Splitting </b> 
#### This code prepares the dataset for machine learning by organizing EEG recordings into two labeled groups: "seizure" (1) and "non-seizure" (0). Each recording receives a single label for the whole file. The code then performs a stratified split into training (80%) and testing (20%) sets, so both sets keep the same proportion of seizure and non-seizure recordings. Finally, it saves the training and testing sets to separate CSV files for use in machine learning or analysis.

<div class = "alert alert-block alert-success" > 
    <b> Debugging: </b>
    Print statements are invaluable for ensuring your dataset is clean and correctly processed! <3 Verify the total file counts, inspect the splits, and ensure there are no missing or mislabeled files. Additionally, printing the first few rows of the labeled DataFrame can help confirm that the labeling process was executed as expected.
</div>
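
#### The labeling logic boils down to a set difference. The sketch below shows it on a handful of hypothetical record names that stand in for the contents of RECORDS and RECORDS-WITH-SEIZURES.

```python
# Hypothetical record names, for illustration only.
all_files = {"chb01/chb01_01.edf", "chb01/chb01_02.edf", "chb01/chb01_03.edf"}
seizure_files = {"chb01/chb01_03.edf"}

# Every seizure file should also appear in the full listing.
assert seizure_files <= all_files, "seizure list contains unknown files"

# Files not listed as containing a seizure are labeled 0.
non_seizure_files = all_files - seizure_files
labels = {path: int(path in seizure_files) for path in sorted(all_files)}

print(len(non_seizure_files))  # 2
print(labels)
```

#### The subset check is a cheap guard against a mistyped path silently turning a seizure file into a non-seizure one.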


In [104]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split

# load seizure files with clean handling
seizure_files_path = "./CHBMIT/RECORDS-WITH-SEIZURES"  # replace with actual path
with open(seizure_files_path, "r") as f:
    seizure_files = [line.strip() for line in f if line.strip() and line.strip().endswith(".edf")]

# add base path to seizure files
seizure_files = {os.path.join("./CHBMIT", file) for file in seizure_files}

# load all files with clean handling
all_files_path = "./CHBMIT/RECORDS"  # replace with actual path
with open(all_files_path, "r") as f:
    all_files = [line.strip() for line in f if line.strip() and line.strip().endswith(".edf")]

# add base path to all files
all_files = {os.path.join("./CHBMIT", file) for file in all_files}

# debug step: ensuring  all files and seizure files are clean
print(f"total files: {len(all_files)}") #686
print(f"seizure files: {len(seizure_files)}") #141

# find non-seizure files
non_seizure_files = all_files - seizure_files
print(f"non-seizure files: {len(non_seizure_files)}") #545

# create dataframes
seizure_df = pd.DataFrame({"file_path": list(seizure_files), "label": 1})
non_seizure_df = pd.DataFrame({"file_path": list(non_seizure_files), "label": 0})

# combine into a single DataFrame
all_labels_df = pd.concat([seizure_df, non_seizure_df], ignore_index=True)


print(f"total labeled files: {len(all_labels_df)}")  # should match total files in RECORDS (686)

# sort the DataFrame for consistency
all_labels_df = all_labels_df.sort_values("file_path").reset_index(drop=True)

# display the first __ rows of the labeled data
print("first __ rows of the labeled data:")
print(all_labels_df.head(5))

# save the dataframe to a csv file
output_file = "labeled_files.csv"
all_labels_df.to_csv(output_file, index=False)
print(f"labeled files saved to: {os.path.abspath(output_file)}")

# split the dataset into training and testing sets
print("\nsplitting the dataset into training and testing sets...")

# extract file paths and labels
file_paths = all_labels_df["file_path"].values
labels = all_labels_df["label"].values

# split into training and testing sets
train_files, test_files, train_labels, test_labels = train_test_split(
    file_paths, labels, test_size=0.2, random_state=42, stratify=labels
)

# debugging: verifying the splits
print(f"training set size: {len(train_files)}, Testing set size: {len(test_files)}")
print(f"training labels: {sum(train_labels)} seizures, {len(train_labels) - sum(train_labels)} non-seizures")
print(f"testing labels: {sum(test_labels)} seizures, {len(test_labels) - sum(test_labels)} non-seizures")

# save the splits to csv files for future reference (optional tbh) 
train_df = pd.DataFrame({"file_path": train_files, "label": train_labels})
test_df = pd.DataFrame({"file_path": test_files, "label": test_labels})

train_output_file = "train_files.csv"
test_output_file = "test_files.csv"

train_df.to_csv(train_output_file, index=False)
test_df.to_csv(test_output_file, index=False)

print(f"training set saved to: {os.path.abspath(train_output_file)}")
print(f"testing set saved to: {os.path.abspath(test_output_file)}")


total files: 686
seizure files: 141
non-seizure files: 545
total labeled files: 686
first __ rows of the labeled data:
                     file_path  label
0  ./CHBMIT/chb01/chb01_01.edf      0
1  ./CHBMIT/chb01/chb01_02.edf      0
2  ./CHBMIT/chb01/chb01_03.edf      1
3  ./CHBMIT/chb01/chb01_04.edf      1
4  ./CHBMIT/chb01/chb01_05.edf      0
labeled files saved to: /Users/kjuarezj/labeled_files.csv

splitting the dataset into training and testing sets...
training set size: 548, Testing set size: 138
training labels: 113 seizures, 435 non-seizures
testing labels: 28 seizures, 110 non-seizures
training set saved to: /Users/kjuarezj/train_files.csv
testing set saved to: /Users/kjuarezj/test_files.csv


---

### <b> Merging the Datasets  </b> 
#### The datasets must be successfully merged, ensuring file paths align correctly and that features are combined with their corresponding labels. This process creates a feature matrix X (features) and a label array y (labels), which are critical for model training.

In [105]:
import pandas as pd
import os

# step 1: load both datasets
labeled_data = pd.read_csv("labeled_files.csv")
features_data = pd.read_csv("preprocessed_features.csv")


# step 2: convert relative paths to absolute paths in labeled data
labeled_data["file_path"] = labeled_data["file_path"].apply(os.path.abspath)


# step 3: merge the datasets on file_path, keeping only files present in both
merged_data = pd.merge(labeled_data, features_data, on="file_path", how="inner")

# debugging merged data
print("\nFirst few rows of merged data:")
print(merged_data.head())


# step 4: save the merged data
merged_data.to_csv("merged_data.csv", index=False)
print("\nMerged data saved to 'merged_data.csv'")

# step 5: extract features and labels
X = merged_data.iloc[:, 2:25].values  # features (23 columns)
y = merged_data["label"].values       # labels

# verify extracted data
print(f"\nX shape: {X.shape}")  # should be (number of merged rows, 23)
print(f"y shape: {y.shape}")  # should be (number of merged rows,)




First few rows of merged data:
                                   file_path  label  feature_1  feature_2  \
0  /Users/kjuarezj/CHBMIT/chb01/chb01_01.edf      0   1.904987   1.282136   
1  /Users/kjuarezj/CHBMIT/chb01/chb01_02.edf      0   1.023876   0.692827   
2  /Users/kjuarezj/CHBMIT/chb01/chb01_03.edf      1   0.461945   0.314066   
3  /Users/kjuarezj/CHBMIT/chb01/chb01_04.edf      1   0.547646   0.392939   
4  /Users/kjuarezj/CHBMIT/chb01/chb01_05.edf      0   1.508249   1.026032   

   feature_3  feature_4  feature_5  feature_6  feature_7  feature_8  ...  \
0   1.160174   0.660278   2.245718   1.012637   0.605064   0.827183  ...   
1   0.688489   0.656426   1.179537   0.777356   0.439819   0.825834  ...   
2   0.315250   0.430209   0.490375   0.572592   0.380799   0.585074  ...   
3   0.360169   0.444842   0.581529   0.576234   0.382254   0.619699  ...   
4   0.933145   0.472396   1.744471   0.703685   0.443001   0.561593  ...   

   feature_14  feature_15  feature_16  feature_1
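
#### The merge joins on file_path, so both tables must spell each path identically. The sketch below shows why the relative paths are converted with os.path.abspath first; the example path is illustrative and does not need to exist on disk.

```python
import os

relative = "./CHBMIT/chb01/chb01_01.edf"
absolute = os.path.join(os.getcwd(), "CHBMIT", "chb01", "chb01_01.edf")

# As raw strings the two differ, so a join on them would find no match.
print(relative == absolute)  # False

# After normalization both strings name the same key.
print(os.path.abspath(relative) == os.path.abspath(absolute))  # True
```

#### Because an inner join silently drops unmatched rows, comparing the merged row count with the input row counts is a useful follow-up check.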

---

### <b> Splitting the Data </b> 
#### Before splitting, we reload the merged dataset and extract the feature matrix X (the 23 per-channel features) and the label vector y. Checking their shapes confirms that all 686 recordings survived the merge.

<div class = "alert alert-block alert-success" > 
    <b> Why Split the Data? </b>
   To evaluate the performance of our model effectively, we must split the dataset into training and testing sets. The training set is used to train the model, while the testing set serves as unseen data to validate the model's ability to generalize.
</div>

In [106]:
import pandas as pd
from sklearn.model_selection import train_test_split

# load merged data
merged_data = pd.read_csv("merged_data.csv")  # replace with actual merged file path

# extract features and labels
X = merged_data.iloc[:, 2:].values  # features (columns 2–24)
y = merged_data["label"].values     # labels

# verify shapes
print(f"X shape: {X.shape}")  # should be (686, 23)
print(f"y shape: {y.shape}")  # should be (686,)




X shape: (686, 23)
y shape: (686,)


<div class="alert alert-block alert-success"> <b> Debugging Tip: </b> Always verify the shapes of your feature matrix (X) and label vector (y) after splitting to ensure the data is correctly prepared for training and evaluation. </div>

---

### <b> Train-Test Split </b>
#### <b> Purpose of the Train-Test Split: </b> The train-test split is crucial for evaluating a model's performance. By dividing the data:
#### <b> Training Set: </b> Used to train the model, enabling it to learn patterns from the EEG data.
#### <b> Testing Set: </b> Held out to validate the model's ability to generalize to unseen data, ensuring it isn't overfitting.

In [107]:
# split into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# print splits for verification
print(f"training set size: {len(X_train)}")
print(f"testing set size: {len(X_test)}")
print(f"class distribution in training set: {pd.Series(y_train).value_counts().to_dict()}")
print(f"class distribution in testing set: {pd.Series(y_test).value_counts().to_dict()}")


training set size: 548
testing set size: 138
class distribution in training set: {0: 435, 1: 113}
class distribution in testing set: {0: 110, 1: 28}


<div class="alert alert-block alert-success"> <b> Stratification: </b>
Stratifying the split ensures the same proportion of "seizure" (1) and "non-seizure" (0) labels in both training and testing sets. This step is essential because imbalanced datasets can lead to biased models. </div>
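
#### The effect of stratification can be checked directly from the counts printed above. The short calculation below compares the seizure fraction in each set and also derives the accuracy a trivial "always non-seizure" classifier would reach on the test set.

```python
# Counts taken from the split output above.
counts = {
    "full":  {"seizure": 141, "total": 686},
    "train": {"seizure": 113, "total": 548},
    "test":  {"seizure": 28,  "total": 138},
}

fractions = {name: c["seizure"] / c["total"] for name, c in counts.items()}
for name, frac in fractions.items():
    print(f"{name}: {frac:.3f}")   # full: 0.206, train: 0.206, test: 0.203

# Accuracy of always predicting the majority class on the test set.
baseline_accuracy = (138 - 28) / 138
print(f"majority-class baseline: {baseline_accuracy:.3f}")  # 0.797
```

#### Roughly one recording in five contains a seizure, so accuracy alone is a weak yardstick here: a model has to beat about 80% before it has learned anything about seizures.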

---

### <b> Convert Data into PyTorch Tensors </b> 
#### <b> Purpose:</b> To train and evaluate the Spiking Neural Network (SNN) in PyTorch, the data must be converted into tensors. PyTorch tensors are the primary data structure used in PyTorch for computations on GPUs. The DataLoader is used to efficiently manage the data during training and testing [6].

In [124]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# convert to tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

# create dataLoaders
train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=32, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor), batch_size=32, shuffle=False)

# check tensor shapes
print(f"X_train_tensor shape: {X_train_tensor.shape}, y_train_tensor shape: {y_train_tensor.shape}")
print(f"X_test_tensor shape: {X_test_tensor.shape}, y_test_tensor shape: {y_test_tensor.shape}")

# check dataLoader length
print(f"Number of training batches: {len(train_loader)}")
print(f"Number of testing batches: {len(test_loader)}")

print("DataLoaders created successfully.")


X_train_tensor shape: torch.Size([548, 23]), y_train_tensor shape: torch.Size([548])
X_test_tensor shape: torch.Size([138, 23]), y_test_tensor shape: torch.Size([138])
Number of training batches: 18
Number of testing batches: 5
DataLoaders created successfully.


<div class="alert alert-block alert-success"> <b> Why Use DataLoaders? </b> DataLoaders simplify data handling during training by automating the batching and shuffling process. They ensure optimal GPU/CPU utilization by loading data efficiently in the background, allowing seamless training even with large datasets. </div>
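
#### The batch counts reported above follow from simple arithmetic, assuming the DataLoader keeps the final partial batch (drop_last=False, PyTorch's default).

```python
import math

batch_size = 32
n_train, n_test = 548, 138   # sample counts from the split above

train_batches = math.ceil(n_train / batch_size)
test_batches = math.ceil(n_test / batch_size)
# Size of the final, partially filled training batch.
last_train_batch = n_train - (train_batches - 1) * batch_size

print(train_batches, test_batches)  # 18 5
print(last_train_batch)             # 4
```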


---

## <b> III. Model Architecture and Training</b> 

### <b> Network Design </b>
#### <b> Input Layer:</b> 23 neurons, one for each of the 23 preprocessed EEG channel features. This layer feeds the preprocessed features into the network for analysis. 
#### <b> Hidden Layer:</b> 1000 Leaky Integrate-and-Fire (LIF) Neurons designed to capture temporal dependencies and patterns in EEG data over time. Utilizes spiking dynamics for event-based processing, emulating biological neurons.
#### <b> Output Layer:</b> 2 Neurons designed for binary classification—detecting whether a seizure event is present (1) or absent (0).
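
#### To get a feel for the size of this network, the sketch below counts its trainable parameters. It assumes the layers are fully connected with bias terms and that the LIF neurons use a fixed (not learned) decay rate, so they add no trainable parameters.

```python
input_size, hidden_size, output_size = 23, 1000, 2

# A fully connected layer has (inputs x outputs) weights plus one bias per output.
fc1_params = input_size * hidden_size + hidden_size
fc2_params = hidden_size * output_size + output_size

print(fc1_params)               # 24000
print(fc2_params)               # 2002
print(fc1_params + fc2_params)  # 26002
```

#### With only a few hundred training recordings available, a model of this size can overfit easily, which is worth keeping in mind when reading the results.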



### <b> Leaky Integrate-and-Fire</b>  
#### The LIF model governs the behavior of the hidden layer, dynamically updating each neuron's membrane potential based on incoming input [5]. A spike is generated whenever the membrane potential exceeds a threshold, after which the potential resets.

$$
\tau_m \frac{dV}{dt} = - (V - V_{rest}) + R_m I(t)
$$
where:

- $\tau_m$ is the membrane time constant, controlling the rate of decay of the membrane potential.
- $V$ is the membrane potential, representing the current state of the neuron.
- $V_{rest}$ is the resting membrane potential, the baseline state of the neuron.
- $R_m$ is the membrane resistance, determining how much the membrane potential changes in response to an input current.
- $I(t)$ is the input current, derived from the preceding layer or external input.
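
#### In discrete time, with $V_{rest} = 0$ and a fixed time step $\Delta t$, a forward-Euler step of this equation has the form $V[t+1] = \beta V[t] + \text{input}$, where $\beta = 1 - \Delta t / \tau_m$ is the decay factor. The sketch below simulates one such neuron in plain Python. The values beta = 0.9 and threshold = 1.0, the reset-by-subtraction rule, and the helper name simulate_lif are illustrative assumptions; library implementations may order the reset differently.

```python
def simulate_lif(inputs, beta=0.9, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron over a list of input currents."""
    mem = 0.0
    spikes, trace = [], []
    for current in inputs:
        mem = beta * mem + current            # leak, then integrate the input
        spike = 1 if mem > threshold else 0   # fire when the threshold is crossed
        mem -= spike * threshold              # reset by subtracting the threshold
        spikes.append(spike)
        trace.append(mem)
    return spikes, trace

strong_spikes, _ = simulate_lif([0.3] * 25)   # constant input, strong enough to fire
weak_spikes, _ = simulate_lif([0.05] * 25)    # input too weak to reach threshold

print(strong_spikes.index(1))  # 3: the first spike occurs on the fourth step
print(sum(weak_spikes))        # 0
```

#### The weak input settles at 0.05 / (1 - 0.9) = 0.5, below threshold, so the neuron stays silent. This is how the firing rate comes to encode input strength.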


#### The model is an SNN designed for binary classification of EEG data. It uses biologically inspired spiking neurons, and the same 23-feature vector is presented to the network at each of the num_steps simulation steps. The architecture consists of input, hidden, and output layers, with LIF neurons modeling the hidden and output dynamics.

In [168]:
import torch
import torch.nn as nn
import snntorch as snn

class SNNModel(nn.Module):
    def __init__(self, input_size=23, hidden_size=1000, output_size=2, beta=0.9, num_steps=25):
        super(SNNModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.lif2 = snn.Leaky(beta=beta)
        self.num_steps = num_steps

    def forward(self, x):
        # initialize states for LIF neurons
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # record the output of the final layer across time steps
        spk2_rec = []
        mem2_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        # return spikes and membrane potentials across time steps
        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# initialize the model
model = SNNModel()
print(model)


SNNModel(
  (fc1): Linear(in_features=23, out_features=1000, bias=True)
  (lif1): Leaky()
  (fc2): Linear(in_features=1000, out_features=2, bias=True)
  (lif2): Leaky()
)


---

<div class = "alert alert-block alert-success" > 
    <b> Why LIF? </b>
   The use of LIF neurons allows the model to leverage temporal dynamics inherent in EEG signals [5].
</div>

<div class = "alert alert-block alert-success" > 
    <b> Why Spiking Neural Networks? </b>
    SNNs efficiently process temporal and event-based data like EEG signals. Spike-based computation mimics brain activity [7]. 
</div>

<div class = "alert alert-block alert-success" > 
    <b> Why This Architecture? </b>
   This architecture is intended to capture the temporal dependencies inherent in EEG signals. By leveraging the event-based processing capabilities of Spiking Neural Networks (SNNs), the design provides a biologically inspired approach to seizure detection [8]. The LIF neurons in the hidden layer are well suited to detecting temporal patterns, making the network a natural fit for binary classification tasks like seizure detection.
</div>


---

### <b> Define Training Parameters </b>
#### <b> Optimizer and Loss Function: </b> To train the Spiking Neural Network (SNN), we use the following:
#### <b> Optimizer:</b>  Adam, which adapts the learning rate for each parameter to optimize convergence [3].
#### <b> Learning Rate:</b> Set to a small value (0.001) for gradual learning, ensuring the model captures nuanced patterns in EEG data.
#### <b> Loss Function:</b> CrossEntropyLoss, a standard choice for classification, applied here to the two output neurons. It measures how far the predicted class probabilities are from the true labels [3].
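
#### For intuition, the sketch below computes cross-entropy for a single sample by hand. PyTorch's nn.CrossEntropyLoss takes raw scores (logits), applies log-softmax internally, and by default averages over the batch; the helper cross_entropy here is a hypothetical plain-Python stand-in for one sample.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    peak = max(logits)
    shifted = [z - peak for z in logits]      # shift for numerical stability
    log_sum_exp = math.log(sum(math.exp(z) for z in shifted))
    return -(shifted[target] - log_sum_exp)

# Equal scores for both classes: the loss is ln(2).
print(round(cross_entropy([0.0, 0.0], target=1), 4))  # 0.6931

# Scores favoring class 0 give a low loss if the label is 0, a high one if it is 1.
confident = [2.0, -1.0]
print(cross_entropy(confident, target=0) < cross_entropy(confident, target=1))  # True
```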


In [180]:
from torch.optim import Adam

# optimizer and loss function
optimizer = Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

# debug: check optimizer parameters
print("Optimizer initialized with the following parameters:")
for param_group in optimizer.param_groups:
    print(f"Learning Rate: {param_group['lr']}")


print("Optimizer and loss function initialized and validated successfully.")


Optimizer initialized with the following parameters:
Learning Rate: 0.001
Optimizer and loss function initialized and validated successfully.


---

### <b> Training the Model </b>
#### This section focuses on training the SNN using the training dataset. The training process leverages the cross-entropy loss function, a standard loss function for classification tasks, and gradient-based backpropagation with surrogate gradients. Surrogate gradients enable the training of spiking neurons by approximating gradients for non-differentiable spiking activation functions [3].

In [185]:
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
import snntorch as snn
from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report
from torch.nn.functional import softmax

# parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_steps = model.num_steps  # number of simulation steps (25), kept in sync with the model

# assuming model, train_loader, test_loader are already defined
model.to(device)  # keep the model on the same device as the batches
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)

# training Loop
num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch_X, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)

        optimizer.zero_grad()

        # forward pass
        spk_rec, mem_rec = model(batch_X)

        # loss calculation
        loss_val = torch.zeros((1), device=device)
        for step in range(num_steps):
            loss_val += loss_fn(mem_rec[step], batch_y)

        # backward pass
        loss_val.backward()
        optimizer.step()

        total_loss += loss_val.item()

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader)}")

# testing Loop with AUC-ROC
model.eval()
test_loss = 0
correct = 0
total = 0
true_labels = []
predicted_probs = []

with torch.no_grad():
    for batch_X, batch_y in tqdm(test_loader, desc="Testing"):
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)

        spk_rec, mem_rec = model(batch_X)

        # Summed loss over time
        loss_val = torch.zeros((1), device=device)
        for step in range(num_steps):
            loss_val += loss_fn(mem_rec[step], batch_y)

        test_loss += loss_val.item()

        # Collect probabilities for AUC-ROC
        probabilities = softmax(spk_rec.sum(dim=0), dim=1)  # Sum over time and apply softmax
        predicted_probs.extend(probabilities[:, 1].cpu().numpy())  # Probability for positive class
        true_labels.extend(batch_y.cpu().numpy())

        # Accuracy Calculation
        _, predicted = spk_rec.sum(dim=0).max(1)
        correct += (predicted == batch_y).sum().item()
        total += batch_y.size(0)

test_accuracy = 100.0 * correct / total  # accuracy as a percentage
print(f"Test Loss: {test_loss / len(test_loader):.4f}, Test Accuracy: {test_accuracy:.2f}")


Epoch 1: 100%|████████████████████████████████| 18/18 [00:00<00:00, 73.06it/s]


Epoch 1, Loss: 10.321447372436523


Epoch 2: 100%|████████████████████████████████| 18/18 [00:00<00:00, 79.01it/s]


Epoch 2, Loss: 9.645870526631674


Epoch 3: 100%|████████████████████████████████| 18/18 [00:00<00:00, 81.47it/s]


Epoch 3, Loss: 9.829082594977486


Epoch 4: 100%|████████████████████████████████| 18/18 [00:00<00:00, 83.35it/s]


Epoch 4, Loss: 10.852201249864367


Epoch 5: 100%|████████████████████████████████| 18/18 [00:00<00:00, 83.50it/s]


Epoch 5, Loss: 11.142020596398247


Epoch 6: 100%|████████████████████████████████| 18/18 [00:00<00:00, 69.05it/s]


Epoch 6, Loss: 10.889189614189995


Epoch 7: 100%|████████████████████████████████| 18/18 [00:00<00:00, 81.75it/s]


Epoch 7, Loss: 11.645541508992514


Epoch 8: 100%|████████████████████████████████| 18/18 [00:00<00:00, 84.28it/s]


Epoch 8, Loss: 10.80745177798801


Epoch 9: 100%|████████████████████████████████| 18/18 [00:00<00:00, 83.86it/s]


Epoch 9, Loss: 9.705048031277126


Epoch 10: 100%|███████████████████████████████| 18/18 [00:00<00:00, 81.38it/s]


Epoch 10, Loss: 9.561542722913954


Epoch 11: 100%|███████████████████████████████| 18/18 [00:00<00:00, 81.78it/s]


Epoch 11, Loss: 8.944989654752943


Epoch 12: 100%|███████████████████████████████| 18/18 [00:00<00:00, 84.16it/s]


Epoch 12, Loss: 9.114972670873007


Epoch 13: 100%|███████████████████████████████| 18/18 [00:00<00:00, 83.45it/s]


Epoch 13, Loss: 8.85685912768046


Epoch 14: 100%|███████████████████████████████| 18/18 [00:00<00:00, 82.87it/s]


Epoch 14, Loss: 9.71735077434116


Epoch 15: 100%|███████████████████████████████| 18/18 [00:00<00:00, 84.23it/s]


Epoch 15, Loss: 9.51416187816196


Testing: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 141.23it/s]

Test Loss: 14.3072, Test Accuracy: 80.43





<div class = "alert alert-block alert-success" > 
    <b> Why This Approach? </b>
    The gradient-based backpropagation with surrogate gradients allows the model to learn temporal patterns in EEG data. The spiking neuron architecture, with time-stepped simulation, ensures temporal dependencies are captured effectively.

<b> Temporal Learning:</b> The spiking dynamics allow the network to analyze EEG signals in a biologically plausible way.<br>
<b> Loss Summation:</b> Accumulating the loss over num_steps ensures that the model learns from patterns across the entire time window.
</div>
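
#### The surrogate-gradient idea can be sketched in a few lines: the forward pass uses a hard threshold, while the backward pass substitutes a smooth derivative. The fast-sigmoid form and the slope value below are assumptions chosen for illustration; the source does not state which surrogate function the model actually uses.

```python
def heaviside(x):
    """Forward pass: emit a spike (1.0) when the membrane potential exceeds threshold."""
    return 1.0 if x > 0 else 0.0

def fast_sigmoid_grad(x, slope=25.0):
    """Backward-pass stand-in: derivative of x / (1 + slope * |x|)."""
    return 1.0 / (1.0 + slope * abs(x)) ** 2

# x is the membrane potential minus the threshold.
for x in (-0.5, -0.05, 0.0, 0.05, 0.5):
    print(x, heaviside(x), round(fast_sigmoid_grad(x), 4))
```

#### The true derivative of the step function is zero almost everywhere, which would stop learning. The surrogate is largest near the threshold, so neurons that were close to firing receive the strongest weight updates.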


## <b> IV. Results and Evaluation </b>
### <b>AUC-ROC and Confusion Matrix </b>
#### <b> AUC-ROC:</b> AUC-ROC evaluates the model's ability to distinguish between classes across all decision thresholds. It is particularly valuable for seizure detection as EEG signals are inherently noisy and seizures may activate multiple regions simultaneously. AUC-ROC provides a balanced view of sensitivity (true positive rate) and specificity (true negative rate) over varying thresholds [9].

#### <b>Confusion Matrix:</b> The confusion matrix visually represents:
#### <b> True Positives (TP):</b> Correctly detected seizures.
#### <b> True Negatives (TN):</b> Correctly identified non-seizure instances.
#### <b> False Positives (FP):</b> Non-seizures incorrectly classified as seizures.
#### <b> False Negatives (FN):</b> Missed seizure detections.
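
#### The four counts above can be tallied directly from labels and predicted probabilities. This is a minimal sketch on toy data, assuming label 1 means seizure and a fixed 0.5 threshold; `confusion_counts` is a hypothetical helper, not part of the tutorial's code.

```python
def confusion_counts(true_labels, predicted_probs, threshold=0.5):
    """Return (tn, fp, fn, tp) for binary labels, where 1 = seizure."""
    tn = fp = fn = tp = 0
    for label, prob in zip(true_labels, predicted_probs):
        predicted = 1 if prob > threshold else 0
        if label == 1 and predicted == 1:
            tp += 1  # seizure correctly detected
        elif label == 0 and predicted == 0:
            tn += 1  # non-seizure correctly identified
        elif label == 0 and predicted == 1:
            fp += 1  # false alarm
        else:
            fn += 1  # missed seizure
    return tn, fp, fn, tp

# Toy data, for illustration only
true_labels = [0, 0, 0, 1, 1, 1]
predicted_probs = [0.1, 0.7, 0.3, 0.9, 0.4, 0.6]

tn, fp, fn, tp = confusion_counts(true_labels, predicted_probs)
# Same layout as scikit-learn's confusion_matrix: rows are true classes, columns are predictions
print([[tn, fp], [fn, tp]])                    # [[2, 1], [1, 2]]
print(f"sensitivity: {tp / (tp + fn):.2f}")    # 2 / 3
print(f"specificity: {tn / (tn + fp):.2f}")    # 2 / 3
```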


<div class="alert alert-block alert-success"> <b> Why AUC-ROC? </b> AUC-ROC (Area Under the Receiver Operating Characteristic Curve) is a robust metric for assessing the model’s performance across all thresholds, mitigating biases introduced by noisy EEG signals. It offers insights into trade-offs between sensitivity and specificity, particularly crucial in medical applications like seizure detection [9]. </div>


In [189]:
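# Compute AUC-ROC from the predicted seizure probabilities
auc_score = roc_auc_score(true_labels, predicted_probs)
test_accuracy = 100.0 * correct / total  # uses `correct` and `total` from earlier cells; not printed here

# Confusion matrix and per-class report at a fixed 0.5 threshold
predicted_labels = [1 if prob > 0.5 else 0 for prob in predicted_probs]
conf_matrix = confusion_matrix(true_labels, predicted_labels)  # rows: true class, columns: predicted class
report = classification_report(true_labels, predicted_labels, target_names=["non-Seizure", "seizure"])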

print(f" AUC-ROC: {auc_score:.2f}")
print("\nconfusion matrix:")
print(conf_matrix)
print("\nclassification report:")
print(report)

 AUC-ROC: 0.67

confusion matrix:
[[109   1]
 [ 24   4]]

classification report:
              precision    recall  f1-score   support

 non-Seizure       0.82      0.99      0.90       110
     seizure       0.80      0.14      0.24        28

    accuracy                           0.82       138
   macro avg       0.81      0.57      0.57       138
weighted avg       0.82      0.82      0.76       138



### <b> Analysis</b> 
#### <b> AUC-ROC:</b> The value of 0.67 indicates modest performance: better than chance (0.5), but with considerable room for improvement in separating seizure from non-seizure classes.
#### <b> Confusion Matrix:</b> The model classifies non-seizures well (109 of 110 correct) but has low sensitivity for seizures: 24 of the 28 seizure instances are false negatives.

#### <b> Classification Report:</b> Precision for seizures is relatively high (0.80), but it rests on only five positive predictions, and the recall of 0.14 shows that the model misses most seizures. Balancing sensitivity and specificity will be critical in subsequent improvements.
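
#### The figures in the classification report follow directly from the four entries of the confusion matrix printed above. The short check below reproduces them; only the variable names are introduced here.

```python
# Entries of the confusion matrix printed above: [[TN, FP], [FN, TP]]
tn, fp, fn, tp = 109, 1, 24, 4

precision = tp / (tp + fp)                           # 4 / 5
recall = tp / (tp + fn)                              # 4 / 28 (sensitivity)
specificity = tn / (tn + fp)                         # 109 / 110
f1 = 2 * precision * recall / (precision + recall)
accuracy = (tp + tn) / (tn + fp + fn + tp)           # 113 / 138

print(f"precision:   {precision:.2f}")    # 0.80
print(f"recall:      {recall:.2f}")       # 0.14
print(f"specificity: {specificity:.2f}")  # 0.99
print(f"f1-score:    {f1:.2f}")           # 0.24
print(f"accuracy:    {accuracy:.2f}")     # 0.82
```

#### The high accuracy is driven almost entirely by the majority class, which is why recall and F1 for the seizure class are the more informative numbers here.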


---

## <b>VII. Conclusion and Future Work  </b>
### <b> Performance of the SNN model: </b> 
#### The Spiking Neural Network (SNN) model achieved a test accuracy of 80.43% and an AUC-ROC score of 0.67. It identified non-seizure instances reliably (109 of 110) but detected only 4 of 28 seizure instances, so sensitivity is the main area for improvement before the approach can be relied on for automated seizure detection.
### <b> Challenges Encountered </b> 

#### <b> Noise in EEG Data:</b> Background noise and artifacts in the EEG recordings introduced additional complexity, necessitating robust preprocessing techniques.

#### <b>Class Imbalance:</b> The dataset's significant imbalance between seizure and non-seizure instances posed challenges for achieving high sensitivity.

#### <b> Biological Complexity:</b>  Seizure patterns are non-linear and often asynchronous, which can be challenging for binary classification.


## <b> VIII. References  </b>
#### [1] Guttag, John. "CHB-MIT Scalp EEG Database" (version 1.0.0). PhysioNet (2010), https://doi.org/10.13026/C2K01R. 

#### [2] Shoeibi A, Khodatars M, Ghassemi N, Jafari M, Moridian P, Alizadehsani R, Panahiazar M, Khozeimeh F, Zare A, Hosseini-Nejad H, Khosravi A, Atiya AF, Aminshahidi D, Hussain S, Rouhani M, Nahavandi S, Acharya UR. Epileptic Seizures Detection Using Deep Learning Techniques: A Review. Int J Environ Res Public Health. 2021 May 27;18(11):5780. doi: 10.3390/ijerph18115780. PMID: 34072232; PMCID: PMC8199071.
#### [3] Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. “Training Spiking Neural Networks Using Lessons From Deep Learning”. arXiv preprint arXiv:2109.12894, September 2021.
#### [4] Wang, Z., Mengoni, P. Seizure classification with selected frequency bands and EEG montages: a Natural Language Processing approach. Brain Inf. 9, 11 (2022). https://doi.org/10.1186/s40708-022-00159-3
#### [5] Lu, Sijia, and Feng Xu. "Linear Leaky-Integrate-and-Fire Neuron Model-Based Spiking Neural Networks and Its Mapping Relationship to Deep Neural Networks." Frontiers in Neuroscience, vol. 16, 2022, https://doi.org/10.3389/fnins.2022.857513. 
#### [6] PyTorch Documentation. “Tensors.” PyTorch. https://pytorch.org/docs/stable/tensors.html
#### [7] Yamazaki K, Vo-Ho VK, Bulsara D, Le N. Spiking Neural Networks and Their Applications: A Review. Brain Sci. 2022 Jun 30;12(7):863. doi: 10.3390/brainsci12070863. PMID: 35884670; PMCID: PMC9313413.
#### [8] Hussein R, Palangi H, Ward RK, Wang ZJ. Optimized deep neural network architecture for robust detection of epileptic seizures using EEG signals. Clin Neurophysiol. 2019 Jan;130(1):25-37. doi: 10.1016/j.clinph.2018.10.010. Epub 2018 Nov 15. PMID: 30472579.
#### [9] Dastgoshadeh M, Rabiei Z. Detection of epileptic seizures through EEG signals using entropy features and ensemble learning. Front Hum Neurosci. 2023 Feb 1;16:1084061. doi: 10.3389/fnhum.2022.1084061. PMID: 36875740; PMCID: PMC9976189.