In [1]:
import pyedflib
import numpy as np
import os
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical

trainPath = r"D:\v2.0.3\edf\train"
evalPath = r"D:\v2.0.3\edf\eval"
labelPath = r"C:\Users\dalto\Box Sync\abnLabels.csv"

labelFrame = pd.read_csv(labelPath)

labels = []
names = []
tests = []
noLabelCnt = 0

for split in enumerate([os.listdir(trainPath), os.listdir(evalPath)]):
    test = split[0]

    subName = split[1]

    for sub in subName:
        # if sub is in labelFrame
        if len(labelFrame.loc[labelFrame['name'] == sub]) == 0:
            print("No label found for", sub)
            noLabelCnt += 1
        else:
            labels.append(labelFrame.loc[labelFrame['name'] == sub, 'label'].values[0])
            names.append(sub)
            tests.append(test)

df = pd.DataFrame({'name': names, 'label': labels, 'test': tests})

display(df)

No label found for aaaaaagt
No label found for dataset-tuh_task-binary_datatype-train_v6
No label found for aaaaaqfm
No label found for aaaaaqmo
No label found for aaaaaqxr


Unnamed: 0,name,label,test
0,aaaaaaac,3,0
1,aaaaaaag,3,0
2,aaaaaaar,0,0
3,aaaaaaav,0,0
4,aaaaaabg,0,0
...,...,...,...
613,aaaaaswc,3,1
614,aaaaatao,0,1
615,aaaaatba,0,1
616,aaaaatdq,3,1


In [2]:
import glob

# list all of the .edf files in each train and eval folder
healthyEdfs = []  # List to store the paths of all .edf files
seizEdfs = []  # List to store the paths of all .edf files
edfLabels = []  # List to store the labels corresponding to each .edf file

# Sample of the 'df' DataFrame (in your case, the DataFrame might already be available)
# df = pd.read_csv('path/to/your_dataframe.csv')

# Loop through each subject in the 'names' list
for subject in names:
    subject_path = os.path.join(trainPath, subject)
    
    # Find the label corresponding to the subject from the DataFrame
    subject_label = df[df['name'] == subject]['label'].values[0]
    
    # Loop through all the date of recording folders (e.g., s001_2002)
    if os.path.exists(subject_path):
        for recording_date in os.listdir(subject_path):
            recording_date_path = os.path.join(subject_path, recording_date)
            
            # Loop through all cap setup folders (e.g., 02_tcp_le)
            if os.path.isdir(recording_date_path):
                for cap_setup in os.listdir(recording_date_path):
                    cap_setup_path = os.path.join(recording_date_path, cap_setup)
                    
                    # Find all .edf files in the final directory
                    if os.path.isdir(cap_setup_path):
                        edf_files_in_dir = glob.glob(os.path.join(cap_setup_path, "*.edf"))
                        
                        # Add each .edf file to the list, and also add its corresponding label
                        for edf_file in edf_files_in_dir:
                            if subject_label == 0:
                                healthyEdfs.append(edf_file)
                            else:
                                seizEdfs.append(edf_file)

# check for any duplicate .edf files
healthyEdfs = list(set(healthyEdfs))
seizEdfs = list(set(seizEdfs))
edfFiles = healthyEdfs + seizEdfs

# add the labels to the edf files
edfLabels = [0] * len(healthyEdfs) + [3] * len(seizEdfs)

# print the number of edf files
print("Edf count:", len(edfFiles))

# print the number of subjects in each edfLabel class
print("Healthy count:", len([x for x in edfLabels if x == 0]))
print("Seizure count:", len([x for x in edfLabels if x == 3]))


#### ONLY FOR EXPORTING TO CSV ####
# # take everything after 'train\' and add it to a new list
# edfNames = [x.split('train\\')[1] for x in edfFiles]

# # save the edfNames and edfLabels to a csv
# edfDf = pd.DataFrame({'name': edfNames, 'label': edfLabels})
# edfDf.to_csv('edfFiles.csv', index=False)

Edf count: 4662
Healthy count: 3489
Seizure count: 1173


In [3]:
# Set the desired number of channels and fixed length of time points
TARGET_CHANNELS = 40  # Number of channels you want to have in the final data
TARGET_POINTS = 75000  # Fixed number of time points for each sample

# Function to load and preprocess each .edf file
def load_and_preprocess_edf(filePath, target_channels=TARGET_CHANNELS, target_points=TARGET_POINTS):
    # Load the raw EEG data
    RawEEGDataFile = mne.io.read_raw_edf(filePath, preload=True, verbose=False)
    RawEEGDataFile.interpolate_bads()

    # Get the raw data (channels × time)
    data = RawEEGDataFile.get_data()

    # Determine current number of channels
    current_channels, current_points = data.shape

    # Pad or truncate channels to make them equal to target_channels (e.g., 40)
    if current_channels < target_channels:
        # Pad with zeros if there are fewer channels than target_channels
        padding = target_channels - current_channels
        data = np.pad(data, ((0, padding), (0, 0)), mode='constant')
    else:
        # Truncate channels if there are more than target_channels
        data = data[:target_channels, :]

    # Interpolate or resample data to ensure target_points are present
    if current_points != target_points:
        data = np.array([np.interp(np.linspace(0, current_points - 1, target_points), np.arange(current_points), data[ch, :]) for ch in range(target_channels)])

    return data

In [4]:
# Cell 1: Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import mne  # assuming mne is used for EEG data loading
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import gc

# Cell 2: Define a dataset class to load EEG data in segments
class EEGDataset(Dataset):
    def __init__(self, eeg_file_paths, target_channels=40, target_points=75000, segment_length=25000):
        self.eeg_file_paths = eeg_file_paths
        self.target_channels = target_channels
        self.target_points = target_points
        self.segment_length = segment_length

    def __len__(self):
        return len(self.eeg_file_paths)

    def __getitem__(self, idx):
        file_path = self.eeg_file_paths[idx]
        raw = mne.io.read_raw_edf(file_path, preload=True)
        
        # Apply preprocessing steps such as ICA or filtering here
        data = raw.get_data()  # Shape: (channels, time_points)
        
        # Pad or trim channels to match target_channels
        if data.shape[0] < self.target_channels:
            padding = np.zeros((self.target_channels - data.shape[0], data.shape[1]))
            data = np.vstack((data, padding))
        elif data.shape[0] > self.target_channels:
            data = data[:self.target_channels, :]
        
        # Split the data into segments to handle large files
        segments = []
        for start in range(0, data.shape[1], self.segment_length):
            end = min(start + self.segment_length, data.shape[1])
            segment = data[:, start:end]
            
            # Interpolate or compress each segment to match target_points
            if segment.shape[1] != self.target_points:
                segment = np.array([np.interp(np.linspace(0, 1, self.target_points), np.linspace(0, 1, segment.shape[1]), channel) for channel in segment])
            
            # Reshape segment to match the input dimensions required by the ResNet model
            segment = np.expand_dims(segment, axis=0)  # Add a batch dimension if needed
            segments.append(torch.tensor(segment, dtype=torch.float32))
        
        return segments

# Cell 3: List of all your EEG file paths
eeg_files = edfFiles

# Cell 4: Create a dataset and a dataloader for batch processing
batch_size = 2  # Set batch size based on your system's memory
segment_length = 5000  # Set segment length based on your system's memory

dataset = EEGDataset(eeg_files, segment_length=segment_length)

def collate_fn(batch):
    # Flatten the list of segments and create a new batch
    segments = [segment for segments in batch for segment in segments]
    return torch.stack(segments)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Cell 5: Define the ResNet model
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        # Define a simple ResNet-like architecture
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 2)  # Assuming binary classification

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = ResNet()

# Cell 6: Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Cell 7: Training loop with batch processing
num_epochs = 2  # Set the number of epochs
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, batch in enumerate(dataloader):
        inputs = batch
        # Get labels from edfLabels
        labels = edfLabels[i * batch_size : (i + 1) * batch_size]

        # Zero the gradient
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Clear unused variables to free memory
        del inputs, labels, outputs, loss
        gc.collect()
        torch.cuda.empty_cache()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")

print("Finished Training")

# Explanation of Computational Efficiency
# Increasing the batch size can make the training more computationally efficient, as it allows more data to be processed in parallel.
# However, it also requires more memory. If your system has limited memory, a smaller batch size may be more practical to avoid crashes.
# Increasing the segment length means more data is processed per segment, which can also improve efficiency.
# However, larger segments require more memory, which may not be feasible on systems with limited resources.
# Therefore, both batch size and segment length should be chosen carefully based on the available memory to balance efficiency and stability.


Extracting EDF parameters from D:\v2.0.3\edf\train\aaaaadns\s004_2014\01_tcp_ar\aaaaadns_s004_t001.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 3249  =      0.000 ...    12.996 secs...
Extracting EDF parameters from D:\v2.0.3\edf\train\aaaaanrb\s005_2012\01_tcp_ar\aaaaanrb_s005_t005.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 150249  =      0.000 ...   600.996 secs...


MemoryError: Unable to allocate 22.9 MiB for an array with shape (40, 75000) and data type float64