# CNN for many subjects

$$ x = y^2 + 2$$

In [1]:
%pylab
%matplotlib inline

import glob
import os
import mne
CORPORA_PATH = "~/corpora/sets"

file_path = os.path.expanduser(CORPORA_PATH)
files = glob.glob(os.path.join(file_path, "*.set"))

def normalize_subject(X):
    mean = X.mean(axis=(0, 2)).reshape(-1, 1)
    std = X.std(axis=(0, 2)).reshape(-1, 1)
    return (X - mean) / std

def load_data(filename, normalize=True):
    data_mne = mne.io.read_raw_eeglab(filename, preload=True, event_id={"0": 1, "1": 2})
    data_mne.filter(0, 20)
    events = mne.find_events(data_mne)
    epochs = mne.Epochs(
        data_mne, events,
        baseline=(None, 0), tmin=-0.1, tmax=0.7)

    epochs.load_data()
    
    ch_names = epochs.ch_names
    
    X = epochs.get_data()[:, :-1]
    y = (events[:, 2] == 2).astype('float')

    if len(events) != len(epochs):
        raise ValueError("Epochs events mismatch")
    if normalize: 
        X = normalize_subject(X)
    
    
    return X, y 


Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


Targets appear as 2 in the third column


We remove last channel as well

# CNN with more data

In [2]:
from IPython.core.debugger import set_trace as bp

filenames = files


X = None
y = None
print(filenames)
for filename in filenames:
    try:
        X_subject, y_subject = load_data(filename)

        if X is None:
            X, y = X_subject, y_subject
        else:
            print(X.shape, X_subject.shape)
            X = np.vstack((X, X_subject))
            print(y.shape, y_subject.shape)
            y = np.vstack((y.reshape(-1,1), y_subject.reshape(-1,1)))
    except ValueError as e:
        print(e)

['/home/ubuntu/corpora/sets/PruebasMuseo_36001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_24540001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_2089001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_4305001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_22109001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_12521001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_21011001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_27496001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_1414001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_24888001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_18112001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_32459001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_29273001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_23794001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_25922001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_20947001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_17576001.set', '/home/ubuntu/corpora/sets/PruebasMuseo_16893001.set', '/home/ubuntu/c

1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
Reading /home/ubuntu/corpora/sets/PruebasMuseo_24540001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
(1800, 14, 104) (1800, 14, 104)
(1800,) (1800,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_2089001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad e

Filter length of 169 samples (1.320 sec) selected
2700 events found
Events id: [1 2]
2700 matching events found
0 projection items activated
Loading data for 2700 events and 104 original time points ...
0 bad epochs dropped
(36180, 14, 104) (2700, 14, 104)
(36180, 1) (2700,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_5568001.fdt
Reading 0 ... 69503  =      0.000 ...   542.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1980 events found
Events id: [1 2]
1980 matching events found
0 projection items activated
Loading data for 1980 events and 104 original time points ...
0 bad epochs dropped
(38880, 14, 104) (1980, 14, 104)
(38880, 1) (1980,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_24053001.fdt
Reading 0 ... 69503  =      0.000 ...   542.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1980 events found
E

Reading 0 ... 69503  =      0.000 ...   542.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1980 events found
Events id: [1 2]
1980 matching events found
0 projection items activated
Loading data for 1980 events and 104 original time points ...
0 bad epochs dropped
(73080, 14, 104) (1980, 14, 104)
(73080, 1) (1980,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_21120001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
(75060, 14, 104) (1800, 14, 104)
(75060, 1) (1800,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_25217001.fdt
Reading 0 ... 31871  =      0.000 ...   248.992 secs...
Setting up l

0 bad epochs dropped
(105480, 14, 104) (1980, 14, 104)
(105480, 1) (1980,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_8762001.fdt
Reading 0 ... 88319  =      0.000 ...   689.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
2520 events found
Events id: [1 2]
2520 matching events found
0 projection items activated
Loading data for 2520 events and 104 original time points ...
0 bad epochs dropped
(107460, 14, 104) (2520, 14, 104)
(107460, 1) (2520,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_3195001.fdt
Reading 0 ... 94591  =      0.000 ...   738.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
2700 events found
Events id: [1 2]
2700 matching events found
0 projection items activated
Loading data for 2700 events and 104 original time points ...
0 bad epochs dropped
(109980, 14, 104) (2700, 14, 104)
(109980, 1)

1980 events found
Events id: [1 2]
1980 matching events found
0 projection items activated
Loading data for 1980 events and 104 original time points ...
0 bad epochs dropped
(142560, 14, 104) (1980, 14, 104)
(142560, 1) (1980,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_29789001.fdt
Reading 0 ... 31911  =      0.000 ...   249.305 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
900 events found
Events id: [1 2]
900 matching events found
0 projection items activated
Loading data for 900 events and 104 original time points ...
0 bad epochs dropped
(144540, 14, 104) (900, 14, 104)
(144540, 1) (900,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_10924001.fdt
Reading 0 ... 88319  =      0.000 ...   689.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
2520 events found
Events id: [1 2]
2520 matching events found
0 proje

Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
(175500, 14, 104) (1800, 14, 104)
(175500, 1) (1800,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_31777001.fdt
Reading 0 ... 88319  =      0.000 ...   689.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
2520 events found
Events id: [1 2]
2520 matching events found
0 projection items activated
Loading data for 2520 events and 104 original time points ...
0 bad epochs dropped
(177300, 14, 104) (2520, 14, 104)
(177300, 1) (2520,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_10729001.fdt
Reading 0 ... 94591  =      0.000 ...   738.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to 

0 bad epochs dropped
(208620, 14, 104) (1980, 14, 104)
(208620, 1) (1980,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_27058001.fdt
Reading 0 ... 69503  =      0.000 ...   542.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1980 events found
Events id: [1 2]
1980 matching events found
0 projection items activated
Loading data for 1980 events and 104 original time points ...
0 bad epochs dropped
(210600, 14, 104) (1980, 14, 104)
(210600, 1) (1980,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_23344001.fdt
Reading 0 ... 69503  =      0.000 ...   542.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1980 events found
Events id: [1 2]
1980 matching events found
0 projection items activated
Loading data for 1980 events and 104 original time points ...
0 bad epochs dropped
(212580, 14, 104) (1980, 14, 104)
(212580, 

Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
(244980, 14, 104) (1800, 14, 104)
(244980, 1) (1800,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_23732001.fdt
Reading 0 ... 69503  =      0.000 ...   542.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1980 events found
Events id: [1 2]
1980 matching events found
0 projection items activated
Loading data for 1980 events and 104 original time points ...
0 bad epochs dropped
(246780, 14, 104) (1980, 14, 104)
(246780, 1) (1980,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_5251001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events fou

(280440, 1) (1800,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_16943001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
(282240, 14, 104) (1800, 14, 104)
(282240, 1) (1800,)
Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
(284040, 14, 104) (1800, 14, 104)
(284040, 1) (1800,)
Reading /home/ubuntu/corpora/sets/PruebasMuse

In [3]:
X.shape

(287640, 14, 104)

In [4]:
X_t = X[:, :, :, np.newaxis]

X_t.shape

(287640, 14, 104, 1)

In [5]:
 sum(y) / len(y)

0.16666666666666666

In [6]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test= train_test_split(X_t, y, test_size=0.1, stratify=y)

X_train.shape, y_train.shape

((258876, 14, 104, 1), (258876, 1))

In [10]:
from keras.models import Sequential
from keras.layers import Conv1D, Conv2D, Flatten, Dense, Dropout

model = Sequential()

n_kernels = 12
model.add(Conv2D(n_kernels, (14, 1), padding='same', 
                activation='relu', input_shape=(14, 104, 1)))
model.add(Conv2D(5*n_kernels, (1, 13), padding='same',
                activation='relu'))
model.add(Flatten())
model.add(Dropout(0.35))
model.add(Dense(256, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', # using the cross-entropy loss function
              optimizer='rmsprop', 
              metrics=['accuracy']) # reporting the accuracy


In [13]:
from keras.callbacks import ModelCheckpoint, EarlyStopping
checkpointer = ModelCheckpoint(filepath='model.h5', verbose=1, save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', patience=3)

model.fit(
    X_train, y_train, epochs=30, 
    batch_size=256, class_weight={0:1, 1:6}, validation_split=0.01,
    callbacks=[checkpointer, early_stopping]
)

Train on 256287 samples, validate on 2589 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30


<keras.callbacks.History at 0x7fd1569be4e0>

In [14]:
y_pred = model.predict_classes(X_test)
y_prob = model.predict(X_test)


from sklearn.metrics import precision_score, recall_score, roc_auc_score, accuracy_score

precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_prob)
accuracy = accuracy_score(y_test, y_pred)

print("""
Accuracy   = {}
Precision  = {}
Recall     = {}
ROC AUC    = {}
""".format(accuracy, precision, recall, auc))


Accuracy   = 0.5545821165345571
Precision  = 0.21949342289392668
Recall     = 0.6543596161869003
ROC AUC    = 0.6437523028455294



In [15]:
model.save("model.2.h5")

In [17]:
from keras.models import load_model

model_2 = load_model("model.h5")

y_pred = model_2.predict_classes(X_test)
y_prob = model_2.predict(X_test)


from sklearn.metrics import precision_score, recall_score, roc_auc_score, accuracy_score

precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_prob)
accuracy = accuracy_score(y_test, y_pred)

print("""
Accuracy   = {}
Precision  = {}
Recall     = {}
ROC AUC    = {}
""".format(accuracy, precision, recall, auc))


Accuracy   = 0.5848630232234738
Precision  = 0.22698449079379632
Recall     = 0.6197329995828118
ROC AUC    = 0.6464899717331966



In [23]:
from keras.models import Sequential
from keras.layers import Conv1D, Conv2D, Flatten, Dense, Dropout, MaxPool2D

model = Sequential()

n_kernels = 10
model.add(Conv2D(n_kernels, (14, 1), padding='same', 
                activation='relu', input_shape=(14, 104, 1)))
model.add(Conv2D(5*n_kernels, (1, 13), padding='same',
                activation='relu'))
model.add(MaxPool2D((1, 4)))
model.add(Flatten())
model.add(Dropout(0.35))
model.add(Dense(128, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', # using the cross-entropy loss function
              optimizer='rmsprop', 
              metrics=['accuracy']) # reporting the accuracy

In [24]:
from keras.callbacks import ModelCheckpoint, EarlyStopping
checkpointer = ModelCheckpoint(filepath='model.with_maxpool.h5', verbose=1, save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', patience=3)

model.fit(
    X_train, y_train, epochs=30, 
    batch_size=256, class_weight={0:1, 1:6}, validation_split=0.01,
    callbacks=[checkpointer, early_stopping]
)

Train on 256287 samples, validate on 2589 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30


<keras.callbacks.History at 0x7fd11de63b70>

In [27]:
from keras.models import load_model


y_pred = model.predict_classes(X_test)
y_prob = model.predict(X_test)


from sklearn.metrics import precision_score, recall_score, roc_auc_score, accuracy_score

precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_prob)
accuracy = accuracy_score(y_test, y_pred)

print("""
Accuracy   = {}
Precision  = {}
Recall     = {}
ROC AUC    = {}
""".format(accuracy, precision, recall, auc))


Accuracy   = 0.5865665415102211
Precision  = 0.22678983833718244
Recall     = 0.6145181476846058
ROC AUC    = 0.642395853946901



## Model with two layers

In [37]:
from keras.models import Sequential
from keras.layers import Conv1D, Conv2D, Flatten, Dense, Dropout, MaxPool2D

model = Sequential()

n_kernels = 10
model.add(Conv2D(2*n_kernels, (14, 1), padding='same', 
                activation='relu', input_shape=(14, 104, 1)))
model.add(Conv2D(5*n_kernels, (1, 13), padding='same',
                activation='relu'))
model.add(MaxPool2D((1, 4)))
model.add(Conv2D(n_kernels, (14, 1), padding='same', 
                activation='relu', input_shape=(14, 104, 1)))
model.add(Conv2D(2*n_kernels, (1, 13), padding='same',
                activation='relu'))
model.add(MaxPool2D((1, 4)))

model.add(Flatten())
model.add(Dropout(0.35))
model.add(Dense(256, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', # using the cross-entropy loss function
              optimizer='rmsprop', 
              metrics=['accuracy']) # reporting the accuracy

In [38]:
from keras.callbacks import ModelCheckpoint, EarlyStopping
checkpointer = ModelCheckpoint(filepath='model.2conv_with_maxpool.h5', verbose=1, save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', patience=3)

model.fit(
    X_train, y_train, epochs=30, 
    batch_size=256, class_weight={0:1, 1:6}, validation_split=0.01,
    callbacks=[checkpointer, early_stopping]
)

Train on 256287 samples, validate on 2589 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30


<keras.callbacks.History at 0x7fd0f7f2ef98>

In [36]:
from keras.models import load_model


y_pred = model.predict_classes(X_test)
y_prob = model.predict(X_test)


from sklearn.metrics import precision_score, recall_score, roc_auc_score, accuracy_score

precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_prob)
accuracy = accuracy_score(y_test, y_pred)

print("""
Accuracy   = {}
Precision  = {}
Recall     = {}
ROC AUC    = {}
""".format(accuracy, precision, recall, auc))


Accuracy   = 0.3304130162703379
Precision  = 0.18425877422734416
Recall     = 0.8804755944931164
ROC AUC    = 0.6389369560302485



## Conv with two layers but one simple

In [39]:
from keras.models import Sequential
from keras.layers import Conv1D, Conv2D, Flatten, Dense, Dropout, MaxPool2D

model = Sequential()

n_kernels = 10
model.add(Conv2D(2*n_kernels, (14, 1), padding='same', 
                activation='relu', input_shape=(14, 104, 1)))
model.add(Conv2D(5*n_kernels, (1, 13), padding='same',
                activation='relu'))
model.add(MaxPool2D((1, 4)))
model.add(Conv2D(n_kernels, (14, 5), padding='same', 
                activation='relu', input_shape=(14, 104, 1)))
model.add(Flatten())
model.add(Dropout(0.35))
model.add(Dense(256, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', # using the cross-entropy loss function
              optimizer='rmsprop', 
              metrics=['accuracy']) # reporting the accuracy

In [40]:
from keras.callbacks import ModelCheckpoint, EarlyStopping
checkpointer = ModelCheckpoint(filepath='model.2conv_with_maxpool.h5', verbose=1, save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', patience=3)

model.fit(
    X_train, y_train, epochs=30, 
    batch_size=256, class_weight={0:1, 1:6}, validation_split=0.01,
    callbacks=[checkpointer, early_stopping]
)

Train on 256287 samples, validate on 2589 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30


<keras.callbacks.History at 0x7fd11c4c54e0>

In [41]:
from keras.models import load_model


y_pred = model.predict_classes(X_test)
y_prob = model.predict(X_test)


from sklearn.metrics import precision_score, recall_score, roc_auc_score, accuracy_score

precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_prob)
accuracy = accuracy_score(y_test, y_pred)

print("""
Accuracy   = {}
Precision  = {}
Recall     = {}
ROC AUC    = {}
""".format(accuracy, precision, recall, auc))


Accuracy   = 0.48338200528438324
Precision  = 0.2059820072438369
Recall     = 0.7355027117229871
ROC AUC    = 0.6363816524932344

