# SEEG Seizure Detection Script

This script detects seizures in stereo-EEG (SEEG) recordings stored as EDF files. It preprocesses the recordings into epochs, encodes each epoch with a pre-trained BrainBERT model, and classifies the resulting embeddings with a logistic regression model.

## Code Overview

### Importing Libraries and Modules

The script begins by importing standard libraries (`os`, `sys`, `time`), `numpy`, `torch`, `joblib`, and scikit-learn, along with `OmegaConf` for configuration handling. It also imports custom modules: `preprocess_edf_pipeline` for SEEG preprocessing, `demo_brainbert_annotated` for BrainBERT helper functions, and `models`.

### Setting the Working Directory

The working directory is changed to the BrainBERT directory, and the parent of that directory is appended to `sys.path`, so that the custom modules can be imported.
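Changing the working directory makes the rest of the notebook depend on where it was run from. A minimal sketch of an alternative that does not rely on the current directory, assuming the custom modules are plain `.py` files in a known folder, is to put that folder itself on `sys.path` and import by name. `add_to_sys_path` and the throwaway module `demo_helpers` are hypothetical stand-ins for helpers such as `preprocess_edf_pipeline`.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def add_to_sys_path(folder):
    """Prepend `folder` to sys.path (only once) so modules inside it can be imported."""
    folder = str(Path(folder).resolve())
    if folder not in sys.path:
        sys.path.insert(0, folder)
    return folder


# Create a throwaway module standing in for the custom helpers (hypothetical).
tmp_dir = tempfile.mkdtemp()
Path(tmp_dir, "demo_helpers.py").write_text(
    "def greet():\n    return 'imported without os.chdir'\n"
)

add_to_sys_path(tmp_dir)
importlib.invalidate_caches()  # make sure the import system sees the newly written file
demo_helpers = importlib.import_module("demo_helpers")
print(demo_helpers.greet())  # prints: imported without os.chdir
```

If the BrainBERT modules open data or config files through relative paths, the `os.chdir` call may still be needed alongside this.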

### Function to Process an EDF File and Make Predictions

A function, `process_edf_file`, is defined to:
- Preprocess the EDF files in the input directory into epochs and save them.
- Load the preprocessed `.npy` files from the concatenated output directory.
- Generate a BrainBERT embedding for each preprocessed waveform.
- Use the logistic regression model to classify each embedding as seizure or non-seizure.
- Identify and return the epochs and channels where seizures are detected.
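The epoching step can be pictured with a minimal NumPy sketch. It assumes fixed-length, non-overlapping windows; the actual `process_directory` function uses MNE filtering and events, which this sketch does not reproduce. The 256 Hz rate (taken from the checkpoint name) and the 5 s epoch length are assumptions chosen to match the run log below, which reports 75 epochs of 1280 samples for a recording of about 378 s. `make_epochs` is a hypothetical helper.

```python
import numpy as np


def make_epochs(data, fs, epoch_sec):
    """Split a (channels, samples) array into non-overlapping epochs.

    Returns an array of shape (n_epochs, channels, samples_per_epoch);
    trailing samples that do not fill a whole epoch are dropped.
    """
    samples_per_epoch = int(round(fs * epoch_sec))
    n_epochs = data.shape[-1] // samples_per_epoch
    trimmed = data[:, : n_epochs * samples_per_epoch]
    # Reshape to (channels, n_epochs, samples), then move the epoch axis first.
    epochs = trimmed.reshape(data.shape[0], n_epochs, samples_per_epoch)
    return epochs.transpose(1, 0, 2)


fs = 256            # assumed sampling rate after resampling
duration_sec = 378  # roughly the recording length reported in the log
rng = np.random.default_rng(0)
raw = rng.standard_normal((38, fs * duration_sec))  # 38 synthetic channels
epochs = make_epochs(raw, fs, epoch_sec=5.0)
print(epochs.shape)  # (75, 38, 1280)
```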

### Loading the BrainBERT Model

The BrainBERT model is built from a checkpoint file, moved to the GPU, and loaded with the pre-trained weights stored in that checkpoint.

### Loading the Logistic Regression Model

The trained logistic regression model is loaded from a joblib file.
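The script only loads the classifier; how `logistic_model.joblib` was produced is not shown. A minimal sketch, assuming the file holds a scikit-learn `LogisticRegression` fit on one embedding vector per epoch, is a `joblib.dump`/`joblib.load` round trip. The 16-dimensional features and labels are synthetic stand-ins, not BrainBERT embeddings.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 16))              # synthetic stand-ins for per-epoch embeddings
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)   # synthetic labels: 1 = seizure, 0 = no seizure

clf = LogisticRegression(max_iter=1000).fit(X, y)

# Save the fitted model, then reload it the same way the script does.
model_path = os.path.join(tempfile.mkdtemp(), "logistic_model.joblib")
joblib.dump(clf, model_path)
loaded = joblib.load(model_path)

print(loaded.predict(X[:5]))
```

Joblib files are pickles, so they should only be loaded from trusted sources and with a scikit-learn version compatible with the one used to save them.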

### Defining Paths

Paths for the EDF directory, epoch output directory, and concatenated output directory are defined to organize the processing workflow.
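Because `process_edf_file` loads every `.npy` file it finds in the concatenated output directory, files left over from an earlier run would be classified along with the new ones. The hypothetical helper `prepare_output_dirs` sketches one way to create the output directories and optionally clear stale `.npy` files before a run; whether `process_directory` already does this is not shown here.

```python
import os
import tempfile
from pathlib import Path


def prepare_output_dirs(*dirs, clear_npy=False):
    """Create each directory if missing; optionally delete leftover .npy files in it."""
    for d in dirs:
        path = Path(d)
        path.mkdir(parents=True, exist_ok=True)
        if clear_npy:
            for stale in path.glob("*.npy"):
                stale.unlink()


base = Path(tempfile.mkdtemp())
epoch_dir, concat_dir = base / "epoch output dir", base / "concat output dir"
prepare_output_dirs(epoch_dir, concat_dir)
(concat_dir / "old_run.npy").write_bytes(b"")  # simulate a file left by an earlier run
prepare_output_dirs(epoch_dir, concat_dir, clear_npy=True)
print(sorted(os.listdir(concat_dir)))  # []
```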

### Processing the EDF File

The main part of the script runs `process_edf_file` on the EDF directory, collects the epochs and channels classified as seizures, and prints them.
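Mapping predictions back to channels only works if each prediction can be traced to the file it came from. A minimal sketch, assuming each `.npy` file holds one waveform and that its filename begins with the channel name followed by an underscore (a naming convention this document does not confirm), pairs each prediction with its filename. `seizure_channels` and the filenames are hypothetical.

```python
def seizure_channels(filenames, predictions):
    """Return (index, channel) pairs for every prediction equal to 1.

    Assumes filenames look like '<channel>_<anything>.npy' (hypothetical convention)
    and that `predictions` is in the same order as `filenames`.
    """
    if len(filenames) != len(predictions):
        raise ValueError("need exactly one prediction per file")
    hits = []
    for i, (name, pred) in enumerate(zip(filenames, predictions)):
        if pred == 1:
            hits.append((i, name.split("_")[0]))
    return hits


files = sorted(["LA1_epoch000.npy", "LA1_epoch001.npy", "RH2_epoch000.npy"])
preds = [0, 1, 1]
print(seizure_channels(files, preds))  # [(1, 'LA1'), (2, 'RH2')]
```

Sorting the filenames keeps the order reproducible across runs; channel names that themselves contain underscores would need a different parsing rule.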

### Example Usage

The script demonstrates how to use the defined functions to detect seizures in an EDF file and output the epochs and channels where seizures are detected.
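For recordings with many channels, a per-channel summary can be easier to read than one line per detection. A small sketch that groups the `(epoch, channel)` pairs returned by `process_edf_file`; the helper `epochs_by_channel` and the sample detections are made up for illustration.

```python
from collections import defaultdict


def epochs_by_channel(seizure_epochs_channels):
    """Group (epoch, channel) pairs into {channel: sorted list of epochs}."""
    grouped = defaultdict(list)
    for epoch, channel in seizure_epochs_channels:
        grouped[channel].append(epoch)
    return {channel: sorted(epochs) for channel, epochs in sorted(grouped.items())}


detections = [(12, "LA1"), (3, "LA1"), (7, "RH2")]  # hypothetical detections
for channel, epochs in epochs_by_channel(detections).items():
    print(f"{channel}: seizure in epochs {epochs}")
# LA1: seizure in epochs [3, 12]
# RH2: seizure in epochs [7]
```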


In [6]:
# Import Libraries and Modules
import os
import sys
import numpy as np
import torch
from omegaconf import OmegaConf
import joblib
from sklearn.linear_model import LogisticRegression
import time

print("Starting script...")

# Set the working directory to the BrainBERT directory
os.chdir('/home/vineetreddy/Dropbox/CZW_MIT/BrainBERT')
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)  # Add the parent directory to the system path

print("Changed directory and updated sys.path...")

#from train_brainbert_logreg import *
from preprocess_edf_pipeline import *  # Import custom preprocessing functions
from demo_brainbert_annotated import *  # Import custom functions
import models  # Custom models

print("Imported custom functions and models...")




Starting script...
Changed directory and updated sys.path...
Imported custom functions and models...


In [7]:
# Function to load pre-trained model weights and configuration
def load_brainbert_model(ckpt_path):
    """
    Loads the BrainBERT model with pre-trained weights.

    Args:
        ckpt_path (str): Path to the checkpoint file.

    Returns:
        torch.nn.Module: BrainBERT model with loaded weights.
    """
    cfg = OmegaConf.create({"upstream_ckpt": ckpt_path})
    brainbert_model = build_model(cfg)  # Build the model with the given configuration
    brainbert_model.to('cuda')  # Move the model to the GPU
    init_state = torch.load(ckpt_path)  # Load the checkpoint dictionary
    load_model_weights(brainbert_model, init_state['model'], False)  # Copy the pre-trained weights into the model
    brainbert_model.eval()  # Inference only: disable dropout so embeddings are deterministic
    return brainbert_model

# Function to generate BrainBERT embeddings from example waveforms
def generate_brainbert_embeddings(model, example_wavs):
    """
    Generates one BrainBERT embedding per example waveform.

    Args:
        model (torch.nn.Module): BrainBERT model.
        example_wavs (np.array): Array of example waveforms.

    Returns:
        np.array: Array of BrainBERT embeddings, averaged over time, one row per example.
    """
    brainbert_outs = []
    for example_wav in example_wavs:
        # Get the Short-Time Fourier Transform (STFT) of the signal
        f, t, linear = get_stft(example_wav, 2048, clip_fs=25, nperseg=400, noverlap=350, normalizing="zscore", return_onesided=True)
        inputs = torch.FloatTensor(linear).unsqueeze(0).transpose(1, 2).to('cuda')  # Shape: (1, time, freq)
        mask = torch.zeros((inputs.shape[:2])).bool().to('cuda')  # All-False mask: no time steps are masked
        with torch.no_grad():
            out = model.forward(inputs, mask, intermediate_rep=True)  # Use the model passed in, not a global
        brainbert_outs.append(out.cpu().numpy())  # Append the output to the list

    # Concatenate along the example axis, then average over time to get one vector per example
    brainbert_outs_arr = np.concatenate(brainbert_outs, axis=0).mean(axis=1)

    return brainbert_outs_arr
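The tensor shapes in `generate_brainbert_embeddings` are easier to follow with a CPU-only NumPy/SciPy sketch. It assumes that `get_stft` returns the magnitude of `scipy.signal.stft`, keeps the first `clip_fs` frequency bins, and z-scores each bin over time; the real helper may differ. The `encoder` matrix is a random linear projection standing in for BrainBERT, so only the shapes, not the values, are meaningful.

```python
import numpy as np
from scipy import signal


def stft_features(wav, fs=256, clip_fs=25, nperseg=400, noverlap=350):
    """Magnitude STFT clipped to the first `clip_fs` bins, z-scored per bin (assumed behavior)."""
    _, _, zxx = signal.stft(wav, fs, nperseg=nperseg, noverlap=noverlap, return_onesided=True)
    mag = np.abs(zxx)[:clip_fs]                    # (freq, time)
    mean = mag.mean(axis=-1, keepdims=True)
    std = mag.std(axis=-1, keepdims=True) + 1e-8   # avoid division by zero
    return (mag - mean) / std


rng = np.random.default_rng(0)
wavs = rng.standard_normal((4, 1280))   # 4 synthetic 5 s epochs at 256 Hz
encoder = rng.standard_normal((25, 8))  # stand-in for BrainBERT: 25 bins -> 8 dims

outs = []
for wav in wavs:
    feats = stft_features(wav).T[np.newaxis]  # (1, time, freq), like unsqueeze(0).transpose(1, 2)
    outs.append(feats @ encoder)              # (1, time, hidden)
embeddings = np.concatenate(outs, axis=0).mean(axis=1)  # average over time -> (n_examples, hidden)
print(embeddings.shape)  # (4, 8)
```

Averaging over the time axis is what turns a variable-length sequence of frame embeddings into a single fixed-size vector that the logistic regression can consume.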

In [8]:
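# Function to process a directory of EDF files and make predictions
def process_edf_file(edf_dir, model, logistic_model, epoch_output_dir, concat_output_dir):
    """
    Preprocesses the EDF files in `edf_dir`, embeds each epoch with BrainBERT,
    and classifies the embeddings with the logistic regression model.

    Args:
        edf_dir (str): Directory containing the EDF files to process.
        model (torch.nn.Module): BrainBERT model.
        logistic_model: Fitted scikit-learn logistic regression classifier.
        epoch_output_dir (str): Output directory for the epochs.
        concat_output_dir (str): Output directory for the concatenated .npy files.

    Returns:
        list of (int, str): (index, channel) pairs predicted as seizures.
    """
    start_time = time.time()
    print("Entered process_edf_file function...")

    # Preprocess the EDF files to create epochs and save them
    print("Starting preprocessing...")
    preprocess_start_time = time.time()
    process_directory(edf_dir, epoch_output_dir, concat_output_dir)
    print(f"Preprocessing took {time.time() - preprocess_start_time:.2f} seconds.")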
    
    # Load the preprocessed data; sort the filenames so the order is reproducible
    print("Loading preprocessed data...")
    load_start_time = time.time()
    example_wavs = []
    example_filenames = []  # Keep each waveform's source filename to recover its channel later
    for filename in sorted(os.listdir(concat_output_dir)):
        if filename.endswith('.npy'):
            filepath = os.path.join(concat_output_dir, filename)
            example_wavs.append(np.load(filepath))
            example_filenames.append(filename)
    example_wavs = np.array(example_wavs)
    print(f"Loading preprocessed data took {time.time() - load_start_time:.2f} seconds.")

    # Generate BrainBERT embeddings
    print("Generating BrainBERT embeddings...")
    embedding_start_time = time.time()
    brainbert_outs_arr = generate_brainbert_embeddings(model, example_wavs)
    print(f"Generating embeddings took {time.time() - embedding_start_time:.2f} seconds.")

    # Predict using the logistic regression model (1 = seizure)
    print("Making predictions...")
    prediction_start_time = time.time()
    predictions = logistic_model.predict(brainbert_outs_arr)
    print(f"Making predictions took {time.time() - prediction_start_time:.2f} seconds.")

    # Identify epochs and channels with seizures.
    # The channel is taken from each file's name rather than from the directory name,
    # assuming filenames start with the channel name followed by '_'.
    print("Identifying seizures...")
    seizure_identification_start_time = time.time()
    seizure_epochs_channels = []
    for i, (pred, filename) in enumerate(zip(predictions, example_filenames)):
        if pred == 1:
            channel_name = filename.split('_')[0]
            seizure_epochs_channels.append((i, channel_name))
    print(f"Identifying seizures took {time.time() - seizure_identification_start_time:.2f} seconds.")
    
    print(f"Total processing time: {time.time() - start_time:.2f} seconds.")
    
    return seizure_epochs_channels

print("Defined process_edf_file function...")

# Load the BrainBERT model
print("Loading BrainBERT model...")
ckpt_path = "/home/vineetreddy/Dropbox/CZW_MIT/stft_large_pretrained_256hz.pth"
brainbert_model = load_brainbert_model(ckpt_path)
print("BrainBERT model loaded...")

# Load the trained logistic regression model from a joblib file
print("Loading logistic regression model...")
logistic_model_path = "/home/vineetreddy/save model/logistic_model.joblib" #path to logistic model joblib file
logistic_model = joblib.load(logistic_model_path)
print("Logistic regression model loaded...")

# Define the paths
seizure_prediction_edf_dir = '/home/vineetreddy/edf predict'
epoch_output_dir = '/home/vineetreddy/epoch output dir'
concat_output_dir = '/home/vineetreddy/concat output dir'

print("Defined paths...")

# Process the EDF file and get the epochs and channels with seizures
print("Starting to process the EDF file...")
seizure_epochs_channels = process_edf_file(seizure_prediction_edf_dir, brainbert_model, logistic_model, epoch_output_dir, concat_output_dir)
print("Finished processing the EDF file...")

# Print the results
for epoch, channel in seizure_epochs_channels:
    print(f"Seizure detected in epoch {epoch} on channel {channel}")

Defined process_edf_file function...
Loading BrainBERT model...




BrainBERT model loaded...
Loading logistic regression model...
Logistic regression model loaded...
Defined paths...
Starting to process the EDF file...
Entered process_edf_file function...
Starting preprocessing...
Extracting EDF parameters from /home/vineetreddy/edf predict/sub-HUP060_ses-presurgery_task-ictal_acq-seeg_run-01_ieeg.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 188999  =      0.000 ...   377.998 secs...
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 0.1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Filter length: 16501 sample

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  38 out of  38 | elapsed:    0.3s finished


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 1691 samples (6.605 s)

Not setting metadata
75 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 75 events and 1280 original time points ...
0 bad epochs dropped


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  38 out of  38 | elapsed:    0.1s finished


KeyboardInterrupt: 