In [1]:
import csv
import glob
import math
import os
import random
import re
import shutil

import matplotlib.pyplot as plt
import numpy as np
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics import average_precision_score, balanced_accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.layers import Conv1D, GlobalMaxPooling1D
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [None]:
# Find the MNI brain areas shared by every usable ENCODE file.
#
# Each CSV corresponds to one word: the 'Unnamed: 0' column holds MNI brain-area labels
# (not unique per file) and the remaining ~280-300 columns are time bins of high-frequency
# oscillatory events. ~/ENCODE holds the words presented to the subject; ~/RECALL holds
# the CSVs of the words the subject later recalled.

MIN_AREAS = 10  # files with fewer valid area rows than this are skipped

# os.listdir() does not expand '~' (pandas.read_csv does), so expand it explicitly.
path = os.path.expanduser('~/ENCODE')

# Seed the running intersection with the (non-missing) areas of the first word.
df = pd.read_csv(os.path.join(path, 'word1.csv'))
shared_areas = {area for area in df['Unnamed: 0'] if pd.notna(area) and area not in ('nan', 'N/A')}

for file in sorted(os.listdir(path)):
    if not file.endswith('.csv'):
        continue  # skip e.g. .DS_Store or .ipynb_checkpoints
    data = pd.read_csv(os.path.join(path, file)).set_index('Unnamed: 0')
    # By default pandas parses 'nan' / 'N/A' labels as real NaN, so a string test such as
    # `'nan' in data.index` never matches; mask true NaNs and any literal leftovers.
    data = data[data.index.notna() & ~data.index.isin(['nan', 'N/A'])]
    if len(data.index) < MIN_AREAS:
        continue
    shared_areas &= set(data.index)

# Sorted so the area order is reproducible (set iteration order can vary between runs).
mni_list = sorted(shared_areas)
print(mni_list)

In [None]:
# Write one CSV per word to ~/lstm in which the HFO counts of all rows that share an MNI
# area label are summed, giving one row per area (labels repeat within a raw file).
# NOTE(review): `mni_list` (shared areas) is not applied here, so files can keep different
# area sets; the model-input cell only keeps files with a fixed row count — confirm intent.

MIN_AREAS = 10  # same minimum-row threshold as the intersection cell

# os.listdir() does not expand '~', so expand both directories explicitly.
encode_dir = os.path.expanduser('~/ENCODE')
lstm_dir = os.path.expanduser('~/lstm')
os.makedirs(lstm_dir, exist_ok=True)  # to_csv fails if the target directory is missing

for file in sorted(os.listdir(encode_dir)):
    if not file.endswith('.csv'):
        continue  # skip e.g. .DS_Store or .ipynb_checkpoints
    data = pd.read_csv(os.path.join(encode_dir, file)).set_index('Unnamed: 0')
    # pandas parses 'nan' / 'N/A' labels as NaN by default; drop those rows (and any literals).
    data = data[data.index.notna() & ~data.index.isin(['nan', 'N/A'])]
    if len(data.index) < MIN_AREAS:
        continue
    # groupby sorts its keys by default, so the output rows are ordered by area label.
    sum_df = data.fillna(0).groupby(level='Unnamed: 0').sum()
    sum_df.to_csv(os.path.join(lstm_dir, file))

In [None]:
# Label each summed CSV in ~/lstm: 1 if a CSV for the same word exists in ~/RECALL
# (the word was later recalled), otherwise 0. The word identifier is taken from the
# same '_'-separated field of the file name in both directories.

WORD_FIELD = 4  # index of the word identifier after splitting a file name on '_'


def word_key(filename):
    """Return the word-identifier field of `filename`.

    Raises:
        ValueError: if the name has fewer than WORD_FIELD + 1 '_'-separated parts.
    """
    parts = filename.split('_')
    if len(parts) <= WORD_FIELD:
        raise ValueError(f'{filename!r} has no field {WORD_FIELD} when split on "_"')
    return parts[WORD_FIELD]


# os.listdir() does not expand '~', so expand it explicitly.
recall_path = os.path.expanduser('~/RECALL')
lstm_path = os.path.expanduser('~/lstm')

# A set gives O(1) membership checks; non-CSV entries (.DS_Store, checkpoints) are ignored.
recalled = {word_key(f) for f in os.listdir(recall_path) if f.endswith('.csv')}

label_dict = {
    f: int(word_key(f) in recalled)
    for f in sorted(os.listdir(lstm_path))
    if f.endswith('.csv')
}

# Show the class balance (count of label 0 vs 1) instead of dumping every entry.
pd.Series(label_dict).value_counts()

In [None]:
# Build (time_bins, areas) sequences from the per-area summed CSVs in ~/lstm, then train
# and evaluate a 1D CNN that predicts whether each word was later recalled.

# os.listdir() does not expand '~' (pandas does), so expand it explicitly.
path = os.path.expanduser('~/lstm')

max_sequence_length = 150  # time bins kept per word; shorter words are zero-padded
N_AREAS = 27  # only words whose summed CSV has exactly this many MNI-area rows are used

X = []  # one (time_bins, N_AREAS) array per word
y = []  # label from label_dict: 1 = recalled, 0 = not recalled
reference_areas = None  # area labels of the first kept file, used to check row alignment

# Sorted so sample order (and therefore the seeded train/test split) is reproducible.
for filename in sorted(os.listdir(path)):
    if not filename.endswith('.csv'):
        continue
    data = pd.read_csv(os.path.join(path, filename))
    areas = list(data['Unnamed: 0'])
    if len(areas) != N_AREAS:
        continue
    # Samples are stacked positionally, so every word must list the same areas in the same order.
    if reference_areas is None:
        reference_areas = areas
    assert areas == reference_areas, f'{filename}: MNI areas differ from those of the first kept file'
    # CSV rows are areas and columns are time bins. Transpose so time is the sequence axis
    # and areas are the per-step features, matching the (steps, features) input of Conv1D/LSTM.
    time_series = data.drop(columns=['Unnamed: 0']).iloc[:, :max_sequence_length].to_numpy().T
    X.append(time_series)
    y.append(label_dict[filename])

y = np.array(y)

# pad_sequences pads/truncates the first axis of each sample (time, after the transpose).
X_padded = pad_sequences(
    X, maxlen=max_sequence_length, dtype='float32', padding='post', truncating='post'
)

# Stratify so the recalled / not-recalled ratio is preserved in both splits.
X_train, X_test, y_train, y_test = train_test_split(
    X_padded, y, test_size=0.2, random_state=42, stratify=y
)

# 1D CNN: convolution over time, max-pooled over the whole sequence, then a dense head.
model = Sequential()
model.add(Conv1D(64, kernel_size=3, activation='relu', input_shape=(max_sequence_length, X_train.shape[2])))
model.add(GlobalMaxPooling1D())
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Monitor on a slice of the training data so the test set stays unseen until evaluation.
model.fit(X_train, y_train, epochs=50, batch_size=32, validation_split=0.1)

# Evaluate on the held-out test set.
y_pred_probs = model.predict(X_test).ravel()  # (n, 1) -> (n,)
y_pred = (y_pred_probs > 0.5).astype(int)

cm = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:")
print(cm)

print("\nClassification Report:")
print(classification_report(y_test, y_pred))

roc_auc = roc_auc_score(y_test, y_pred_probs)
print("\nROC-AUC Score:", roc_auc)

pr_auc = average_precision_score(y_test, y_pred_probs)
print("PR-AUC Score:", pr_auc)

balanced_acc = balanced_accuracy_score(y_test, y_pred)
print("Balanced Accuracy:", balanced_acc)

In [None]:
# Train and evaluate an LSTM on the same train/test split used by the CNN cell.
lstm_model = Sequential()
lstm_model.add(LSTM(64, input_shape=(max_sequence_length, X_train.shape[2])))
lstm_model.add(Dense(1, activation='sigmoid'))
lstm_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Monitor on a slice of the training data so the test set stays unseen until evaluation.
lstm_model.fit(X_train, y_train, epochs=10, batch_size=32, validation_split=0.1)

loss, accuracy = lstm_model.evaluate(X_test, y_test)
print('Test Loss:', loss)
print('Test Accuracy:', accuracy)

threshold = 0.5
predictions = lstm_model.predict(X_test).ravel()  # probabilities, (n, 1) -> (n,)
binary_predictions = (predictions > threshold).astype(int)

# All metrics use this LSTM's own outputs, not the CNN cell's y_pred / y_pred_probs.
cm = confusion_matrix(y_test, binary_predictions)
print("Confusion Matrix:")
print(cm)

print("\nClassification Report:")
print(classification_report(y_test, binary_predictions))

roc_auc = roc_auc_score(y_test, predictions)
print("\nROC-AUC Score:", roc_auc)

pr_auc = average_precision_score(y_test, predictions)
print("PR-AUC Score:", pr_auc)

balanced_acc = balanced_accuracy_score(y_test, binary_predictions)
print("Balanced Accuracy:", balanced_acc)