
I utilized the LSTM implementation specifically tailored for EEG data from the following GitHub repository: https://github.com/theyou21/BigProject. This resource provided invaluable support for my LSTM analysis.

In [1]:
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import numpy as np
import pandas as pd
import os

In [3]:
ec_data_dir = "/content/drive/MyDrive/TD-BRAIN/training_data/data/EC_26"
eo_data_dir = "/content/drive/MyDrive/TD-BRAIN/training_data/data/EO_26"
ec_eeg_data = np.load(os.path.join(ec_data_dir, "normalized_epoch_eeg_data.npy"))
eo_eeg_data = np.load(os.path.join(eo_data_dir, "normalized_epoch_eeg_data.npy"))

In [4]:
print(ec_eeg_data.shape)
print(eo_eeg_data.shape)

(4356, 1, 26, 4975)
(4344, 1, 26, 4975)


In [5]:
ec_labels_dir = "/content/drive/MyDrive/TD-BRAIN/training_data/data/EC_26"
eo_labels_dir = "/content/drive/MyDrive/TD-BRAIN/training_data/data/EO_26"
ec_eeg_labels = np.load(os.path.join(ec_labels_dir, "labels_data.npy"))
eo_eeg_labels = np.load(os.path.join(eo_labels_dir, "labels_data.npy"))

In [6]:
print(ec_eeg_labels.shape)
print(eo_eeg_labels.shape)

(4356, 2)
(4344, 2)


In [7]:
for label in ec_eeg_labels:
  sample_id = label[0]
  if sample_id not in eo_eeg_labels[:, 0]:
        index_to_remove = np.where(ec_eeg_labels[:, 0] == sample_id)[0]
        ec_eeg_labels = np.delete(ec_eeg_labels, index_to_remove, axis=0)
        ec_eeg_data = np.delete(ec_eeg_data, index_to_remove, axis=0)
print(ec_eeg_labels.shape)
print(ec_eeg_data.shape)

(4344, 2)
(4344, 1, 26, 4975)


In [8]:
eeg_data = np.concatenate((ec_eeg_data[:, 0], eo_eeg_data[:, 0]), axis=1)
eeg_data.shape

(4344, 52, 4975)

In [9]:
eeg_labels = ec_eeg_labels

In [10]:
healthy_count, mdd_count = 0, 0
for sample in eeg_labels:
  if sample[1] == "MDD":
      mdd_count += 1
  else:
      healthy_count += 1

print(f"Number of MDD participants: {mdd_count}")
print(f"Number of Healthy participants: {healthy_count}")

Number of MDD participants: 3780
Number of Healthy participants: 564


Extracting data for male participants

In [11]:
import pandas as pd
import numpy as np

# Load the participants data
df_participants = pd.read_pickle('/content/drive/MyDrive/TD-BRAIN/TDBRAIN_participants_V2_data/df_participants.pkl')

# Prepare lists to hold the filtered data and labels
eeg_data_male = []
eeg_label_male = []

# Loop over each label in your existing eeg_labels list
for i, labels in enumerate(eeg_labels):
    sample_id = labels[0]
    index = df_participants.loc[df_participants['participants_ID'] == sample_id].index

    if not index.empty:  # Check if the index is not empty
        participant_gender = df_participants.loc[index, 'gender'].values[0]
        participant_condition = labels[1]  # Assuming the condition (MDD/HEALTHY) is stored in labels[1]

        # Check if participant is male and has the condition "Healthy" or "MDD"
        if participant_gender == 0 and (participant_condition == "HEALTHY" or participant_condition == "MDD"):
            eeg_data_male.append(eeg_data[i])
            eeg_label_male.append(labels)

# Convert lists to NumPy arrays for further processing
eeg_data_male = np.array(eeg_data_male)
eeg_label_male = np.array(eeg_label_male)

# Output the shape of the arrays to verify the results
print(f"Shape of male EEG data: {eeg_data_male.shape}")
print(f"Shape of male EEG labels: {eeg_label_male.shape}")


Shape of male EEG data: (2412, 52, 4975)
Shape of male EEG labels: (2412, 2)


In [12]:
healthy_count_male, mdd_count_male = 0, 0
for sample in eeg_label_male:
  if sample[1] == "MDD":
      mdd_count_male += 1
  else:
      healthy_count_male += 1

print(f"Number of MDD male participants: {mdd_count_male}")
print(f"Number of Healthy male participants: {healthy_count_male}")

Number of MDD male participants: 2040
Number of Healthy male participants: 372


# **Model**

In [13]:
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization
import matplotlib.pyplot as plt
from keras.regularizers import l2
from keras.metrics import Precision, Recall
from sklearn.metrics import f1_score, confusion_matrix

In [14]:
import numpy as np

# Filtering eeg_label_male and eeg_data_male
encountered_sample_ids = {}
sample_ids_with_more_than_12_entries = []

for index, sample_id in enumerate(eeg_label_male):
    sample_id_tuple = tuple(sample_id)
    count = encountered_sample_ids.get(sample_id_tuple, 0)
    count += 1
    encountered_sample_ids[sample_id_tuple] = count
    if count > 12:
        sample_ids_with_more_than_12_entries.append((sample_id_tuple, index))

indices_to_remove = [index for _, index in sample_ids_with_more_than_12_entries]
eeg_label_male = [sample for i, sample in enumerate(eeg_label_male) if i not in indices_to_remove]
eeg_data_male = [data for i, data in enumerate(eeg_data_male) if i not in indices_to_remove]
print("Length of filtered eeg_label_male:", len(eeg_label_male))
print("Length of filtered eeg_data_male:", len(eeg_data_male))

# Undersampling and preparing training data
ll = eeg_label_male
unique_sample_id = []
encountered_sample_ids = set()

for sample_id in ll:
    sample_id_tuple = tuple(sample_id)
    if sample_id_tuple not in encountered_sample_ids:
        unique_sample_id.append(sample_id)
        encountered_sample_ids.add(sample_id_tuple)

num_samples_per_class = 16
indices_mdd = [index for index, sample in enumerate(unique_sample_id) if sample[1] == "MDD"]
indices_healthy = [index for index, sample in enumerate(unique_sample_id) if sample[1] == "HEALTHY"]

# Ensure we have at least 16 samples of each class
if len(indices_mdd) < num_samples_per_class or len(indices_healthy) < num_samples_per_class:
    raise ValueError("Not enough samples to balance the classes with 16 samples each.")

undersampled_mdd = np.random.choice(indices_mdd, num_samples_per_class, replace=False)
undersampled_healthy = np.random.choice(indices_healthy, num_samples_per_class, replace=False)

balanced_data_indices = np.concatenate([undersampled_mdd, undersampled_healthy])
balanced_unique_sample_id = [unique_sample_id[i] for i in balanced_data_indices]

# Extract all unique sample IDs from balanced_unique_sample_id
unique_sample_ids = [sample_id[0] for sample_id in balanced_unique_sample_id]

# Extract all indices from eeg_labels for sample IDs in balanced_unique_sample_id
indices = []
for i, sample_id in enumerate(eeg_label_male):
    if sample_id[0] in unique_sample_ids:
        indices.append(i)

# Convert indices to a NumPy array
indices = np.array(indices)
X_train = []
y_train = []

for i in indices:
    X_train.append(eeg_data_male[i])
    y_train.append(eeg_label_male[i])

X_train = np.array(X_train)
y_train = np.array(y_train)

# Shuffle together with their indices
permutation = np.random.permutation(len(X_train))
X_train = X_train[permutation]
y_train = y_train[permutation]

print(X_train.shape)

sample_ids = []
for sample in y_train:
    sample_ids.append(sample[0])
sample_ids = np.array(sample_ids)
l = np.array([1 if label[1] == "MDD" else 0 for label in y_train])


Length of filtered eeg_label_male: 2328
Length of filtered eeg_data_male: 2328
(384, 52, 4975)


In [24]:
import numpy as np
from keras.models import Sequential
from keras.layers import Conv1D, MaxPooling1D, LSTM, Dense, Dropout
from keras.optimizers import Adam
from keras.metrics import Precision, Recall
from keras.callbacks import EarlyStopping
from keras.regularizers import l2
from sklearn.metrics import f1_score, confusion_matrix
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold

class EEGClassifier:
    def __init__(self, input_shape=(52, 4975), lstm_units=32):
        self.input_shape = input_shape
        self.lstm_units = lstm_units
        self.model = self.build_model()

    def build_model(self):
        model = Sequential()
        model.add(Conv1D(filters=32, kernel_size=4, activation='relu', input_shape=self.input_shape, kernel_regularizer=l2(0.001)))
        model.add(Conv1D(filters=64, kernel_size=4, activation='relu', kernel_regularizer=l2(0.001)))
        model.add(MaxPooling1D(pool_size=2))
        model.add(Dropout(0.2))
        model.add(LSTM(units=32, kernel_regularizer=l2(0.001)))
        model.add(Dropout(0.2))
        model.add(Dense(units=64, activation='relu', kernel_regularizer=l2(0.001)))
        model.add(Dense(units=1, activation='sigmoid'))
        model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.001), metrics=['accuracy', Precision(), Recall()])
        return model

    def train(self, X_train, y_train, X_val, y_val, epochs=30, batch_size=16):
        # Calculate class weights
        class_weights = {0: 1, 1: 1}  # Initialize with equal weights
        num_minority = np.sum(y_train == 0)
        num_majority = np.sum(y_train == 1)
        total_samples = len(y_train)
        class_weights[0] = (1 / num_minority) * (total_samples / 2.0)
        class_weights[1] = (1 / num_majority) * (total_samples / 2.0)

        early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

        history = self.model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_val, y_val), class_weight=class_weights, callbacks=[early_stopping], verbose=0)
        return history

    def evaluate(self, X_test, y_test):
        loss, accuracy, precision, recall = self.model.evaluate(X_test, y_test, verbose=0)
        y_pred = self.model.predict(X_test, verbose=0)
        y_pred_classes = np.round(y_pred)
        f1 = f1_score(y_test, y_pred_classes)
        cm = confusion_matrix(y_test, y_pred_classes)

        evaluation_metrics = {
            'loss': loss,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'confusion_matrix': cm
        }

        return evaluation_metrics

    def predict(self, X):
        return self.model.predict(X, verbose=0)

    def plot_loss(self, history):
        plt.plot(history.history['loss'], label='Training Loss')
        plt.plot(history.history['val_loss'], label='Validation Loss')
        plt.title('Training and Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()


def main():
    classifier = EEGClassifier()

    num_splits = 5
    cv = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=42)

    overall_train_metrics = []
    overall_val_metrics = []

    for fold_idx, (train_index, val_index) in enumerate(cv.split(X_train, l), 1):
        print(f"Fold {fold_idx}:")

        X_train_fold, X_val_fold = X_train[train_index], X_train[val_index]
        y_train_fold, y_val_fold = l[train_index], l[val_index]

        history = classifier.train(X_train_fold, y_train_fold, X_val_fold, y_val_fold)

        # Evaluate on training set after training
        train_metrics = classifier.evaluate(X_train_fold, y_train_fold)
        print(f'Training Results - Loss: {train_metrics["loss"]}, Accuracy: {train_metrics["accuracy"]}, '
              f'Precision: {train_metrics["precision"]}, Recall: {train_metrics["recall"]}, '
              f'F1 Score: {train_metrics["f1_score"]}')
        overall_train_metrics.append(train_metrics)

        # Evaluate on the validation set after training
        val_metrics = classifier.evaluate(X_val_fold, y_val_fold)
        print(f'Validation Results - Loss: {val_metrics["loss"]}, Accuracy: {val_metrics["accuracy"]}, '
              f'Precision: {val_metrics["precision"]}, Recall: {val_metrics["recall"]}, '
              f'F1 Score: {val_metrics["f1_score"]}')
        overall_val_metrics.append(val_metrics)
        print()

    # Calculate and print overall metrics
    def calculate_overall_metrics(metrics_list):
        avg_metrics = {}
        for key in metrics_list[0].keys():
            avg_metrics[key] = np.mean([metrics[key] for metrics in metrics_list], axis=0)
        return avg_metrics

    overall_train_metrics = calculate_overall_metrics(overall_train_metrics)
    overall_val_metrics = calculate_overall_metrics(overall_val_metrics)

    print("Overall Training Metrics:")
    print(overall_train_metrics)
    print("\nOverall Validation Metrics:")
    print(overall_val_metrics)

if __name__ == "__main__":
    main()


Fold 1:
Training Results - Loss: 0.8347872495651245, Accuracy: 0.7752442955970764, Precision: 0.7470588088035583, Recall: 0.8300653696060181, F1 Score: 0.7863777089783281
Validation Results - Loss: 0.9261130094528198, Accuracy: 0.5714285969734192, Precision: 0.5600000023841858, Recall: 0.7179487347602844, F1 Score: 0.6292134831460674

Fold 2:
Training Results - Loss: 0.8235868215560913, Accuracy: 0.7394136786460876, Precision: 0.7267080545425415, Recall: 0.7647058963775635, F1 Score: 0.7452229299363057
Validation Results - Loss: 0.8876842260360718, Accuracy: 0.6753246784210205, Precision: 0.6842105388641357, Recall: 0.6666666865348816, F1 Score: 0.6753246753246753

Fold 3:
Training Results - Loss: 0.8354975581169128, Accuracy: 0.7166123986244202, Precision: 0.7133758068084717, Recall: 0.7272727489471436, F1 Score: 0.7202572347266881
Validation Results - Loss: 0.7591734528541565, Accuracy: 0.8181818127632141, Precision: 0.8529411554336548, Recall: 0.7631579041481018, F1 Score: 0.8055555