In [115]:
import pandas as pd
import numpy as np

In [133]:
n400 = pd.read_csv("N400_by_trial.csv")
p600 = pd.read_csv("P600_by_trial.csv")
spr = pd.read_csv("SPR_by_trial.csv")

In [169]:
def extract_features(dataframe: pd.DataFrame, condition):
    """
    Extracts features from a dataframe of ERPs / SRP's. 
    
    Args:
        dataframe: A dataframe of ERPs / SRP's.
        condition: The condition to extract features from.
    Returns:
        A numpy array of features.
    """
    
    features = dataframe[dataframe["Condition"] == condition]
    features = features.drop(["Condition"], axis=1)
    features = features.set_index("ItemNum")
    features = np.expand_dims(features.to_numpy(), axis=0)
    return features

n400_control = extract_features(n400, "control")
n400_script_related = extract_features(n400, "script-related")
n400_script_unrelated = extract_features(n400, "script-unrelated")

p600_control = extract_features(p600, "control")
p600_script_related = extract_features(p600, "script-related")
p600_script_unrelated = extract_features(p600, "script-unrelated")

spr_control = extract_features(spr, "control")
spr_script_related = extract_features(spr, "script-related")
spr_script_unrelated = extract_features(spr, "script-unrelated")

### Visualise data

In [175]:
def print_item(item = 0):
    print(f"ItemNum: {item+1}")
    print(f"ERP's (n400): \n{n400_control[0][item]}")
    print(f"SPR: {spr_control[0][item]}")

print_item(0)

ItemNum: 1
ERP's (n400): 
[-1.79205279 -1.86712636 -1.70394696 -3.9231718  -4.74066061 -4.90966067
 -2.37030723 -1.98819744 -3.95095261 -3.93343971 -4.15195217 -2.20787836
 -3.72650287 -2.34174259  0.50864363 -1.46534323 -2.34354867 -2.66679243
 -0.31249529  0.80925981 -1.85934109 -1.26222229 -1.31437606 -1.49421967
 -0.46220971 -0.38511491]
SPR: [432.36363636]


In [200]:
# check data shape
print(spr_control.shape)
print(n400_control.shape)
print(p600_control.shape)

# check data type
print(type(spr_control))
print(n400_control.dtype)
print(p600_control.dtype)

(1, 90, 1)
(1, 90, 26)
(1, 90, 26)
<class 'numpy.ndarray'>
float64
float64


### Prepare data for training

In [153]:
from sklearn.model_selection import train_test_split

In [220]:
X = np.concatenate((n400_control, p600_control), axis=0)
y = np.reshape(spr_control, (90, 1, 1))

# reshape x to have 90 in first dimension
X = np.reshape(X, (90, X.shape[0], 26))

In [221]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [222]:
print(f"X_train shape: {X_train.shape}")
print(f"y_train shape: {y_train.shape}")

X_train shape: (72, 2, 26)
y_train shape: (72, 1, 1)


### Helpers

In [295]:
import tensorflow.keras.backend as K
from keras.utils import get_custom_objects

def r_squared(y_true, y_pred):
    SS_res = K.sum(K.square(y_true - y_pred))
    SS_tot = K.sum(K.square(y_true - K.mean(y_true)))
    return 1 - SS_res / (SS_tot + K.epsilon())

get_custom_objects().update({"r_squared": r_squared})

In [395]:
from keras.callbacks import EarlyStopping
early_stopping = EarlyStopping(monitor='val_loss', patience=20)

### Construct RNN Model

In [99]:
# from tensorflow import Sequential
from keras import layers, Sequential

In [389]:
# make an rnn model that takes in the n400 and p600 data and predicts the spr data

model = Sequential()
model.add(layers.Input(shape=(2, 26)))
model.add(layers.LSTM(32, input_shape=(2, 26)))
model.add(layers.Dense(64, activation="relu"))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(128, activation="relu"))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(256, activation="relu"))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(128, activation="relu"))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(64, activation="relu"))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(32, activation="relu"))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(16, activation="relu"))
model.add(layers.Dense(1, activation="relu"))

model.compile(optimizer="adam", loss="mae", metrics=['mae', 'mse', 'r_squared'])

model.summary()

Model: "sequential_53"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lstm_87 (LSTM)              (None, 32)                7552      
                                                                 
 dense_194 (Dense)           (None, 128)               4224      
                                                                 
 dropout_31 (Dropout)        (None, 128)               0         
                                                                 
 dense_195 (Dense)           (None, 256)               33024     
                                                                 
 dropout_32 (Dropout)        (None, 256)               0         
                                                                 
 dense_196 (Dense)           (None, 128)               32896     
                                                                 
 dropout_33 (Dropout)        (None, 128)             

In [396]:
model.fit(X_train, y_train, epochs=200, batch_size=64, validation_split=0.2, callbacks=[early_stopping])

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200


<keras.callbacks.History at 0x27c70e658b0>

### Test model

In [337]:
from sklearn.metrics import mean_squared_error, r2_score

In [397]:
pred = model.predict(X_test)

# remove 1 dim from y_test to match pred
y_test_ = np.squeeze(y_test, axis=1)
pred = np.squeeze(pred, axis=1)

mse = mean_squared_error(y_test_, pred)
print(f"\nMSE: {mse}")

r_squared = r2_score(y_test_, pred)
print(f"R^2: {r_squared}")

print("\nFirst 5 predictions: ")
for i in range(min(len(pred), 5)):
    print(f"pred: {round(pred[i])}", end=" ")
    print(f"actual: {round(y_test_[i])}")


MSE: 8447.366241759839
R^2: -0.2929320909621169

First 5 predictions: 
pred: 412 actual: 595
pred: 357 actual: 428
pred: 356 actual: 425
pred: 368 actual: 472
pred: 400 actual: 432
