# EEG Classification - Tensorflow
updated: Sep. 01, 2018

Data: https://www.physionet.org/pn4/eegmmidb/

## 1. Data Downloads

### Warning: Executing these blocks will automatically create directories and download datasets.

In [2]:
import tensorflow as tf
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))

In [3]:
from keras.backend.tensorflow_backend import set_session
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
config.log_device_placement = True  # to log device placement (on which device the operation ran)
sess = tf.Session(config=config)
set_session(sess)  # set this TensorFlow session as the default session for Keras

Using TensorFlow backend.


In [4]:
from keras import backend as K
K.tensorflow_backend._get_available_gpus()

['/job:localhost/replica:0/task:0/device:GPU:0',
 '/job:localhost/replica:0/task:0/device:GPU:1',
 '/job:localhost/replica:0/task:0/device:GPU:2',
 '/job:localhost/replica:0/task:0/device:GPU:3']

In [5]:
# Tensorflow Style Guide
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# System
import requests
import re
import os
import pathlib
import urllib

# Modeling & Preprocessing
from keras.layers import Conv2D, BatchNormalization, Activation, Flatten, Dense, Dropout, LSTM, Input, TimeDistributed
from keras import initializers, Model, optimizers, callbacks
from keras.utils.training_utils import multi_gpu_model
from keras import backend as K
from keras.models import load_model
from keras.callbacks import Callback, TensorBoard
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

# Essential Data Handling
import numpy as np
import pandas as pd
from math import ceil, floor

# Get Paths
from glob import glob

# EEG package
from mne import pick_types, events_from_annotations
from mne.io import read_raw_edf

import pickle
import sys

## Data Description

Subjects performed different motor/imagery tasks while 64-channel EEG were recorded using the BCI2000 system (http://www.bci2000.org). Each subject performed 14 experimental runs: 

- two one-minute baseline runs (one with eyes open, one with eyes closed)
- three two-minute runs of each of the four following tasks:
    - 1:
        - A target appears on either the left or the right side of the screen. 
        - The subject opens and closes the corresponding fist until the target disappears. 
        - Then the subject relaxes.
    - 2:
        - A target appears on either the left or the right side of the screen. 
        - The subject imagines opening and closing the corresponding fist until the target disappears. 
        - Then the subject relaxes.
    - 3:
        - A target appears on either the top or the bottom of the screen. 
        - The subject opens and closes either both fists (if the target is on top) or both feet (if the target is on the bottom) until the target disappears. 
        - Then the subject relaxes.
    - 4:
        - A target appears on either the top or the bottom of the screen. 
        - The subject imagines opening and closing either both fists (if the target is on top) or both feet (if the target is on the bottom) until the target disappears. 
        - Then the subject relaxes.

The data are provided here in EDF+ format (containing 64 EEG signals, each sampled at 160 samples per second, and an annotation channel). 
For use with PhysioToolkit software, rdedfann generated a separate PhysioBank-compatible annotation file (with the suffix .event) for each recording. 
The .event files and the annotation channels in the corresponding .edf files contain identical data.

### Summary of Tasks

Remembering that:

    - Task 1 (open and close left or right fist)
    - Task 2 (imagine opening and closing left or right fist)
    - Task 3 (open and close both fists or both feet)
    - Task 4 (imagine opening and closing both fists or both feet)

we will refer to 'Task *' with the meanings above.

In summary, the experimental runs were:

1.  Baseline, eyes open
2.  Baseline, eyes closed
3.  Task 1
4.  Task 2
5.  Task 3
6.  Task 4
7.  Task 1
8.  Task 2
9.  Task 3
10. Task 4
11. Task 1
12. Task 2
13. Task 3
14. Task 4
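
The run schedule above is periodic after the two baseline runs, so the task for a given run number can be computed rather than looked up. A minimal sketch (the function name `task_for_run` is hypothetical and not part of the original notebook):

```python
def task_for_run(run):
    """Return 'baseline_open', 'baseline_closed', or the task number 1-4 for a run in 1..14."""
    if not 1 <= run <= 14:
        raise ValueError("run must be between 1 and 14")
    if run == 1:
        return "baseline_open"
    if run == 2:
        return "baseline_closed"
    # Runs 3..14 cycle through tasks 1, 2, 3, 4 three times.
    return (run - 3) % 4 + 1


for run in range(1, 15):
    print(run, task_for_run(run))
```

For example, the runs that map to Task 2 (imagined left or right fist) are 4, 8, and 12.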

### Annotation

Each annotation includes one of three codes (T0, T1, or T2):

- T0 corresponds to rest
- T1 corresponds to onset of motion (real or imagined) of
    - the left fist (in runs 3, 4, 7, 8, 11, and 12)
    - both fists (in runs 5, 6, 9, 10, 13, and 14)
- T2 corresponds to onset of motion (real or imagined) of
    - the right fist (in runs 3, 4, 7, 8, 11, and 12)
    - both feet (in runs 5, 6, 9, 10, 13, and 14)
    
In the BCI2000-format versions of these files, which may be available from the contributors of this data set, these annotations are encoded as values of 0, 1, or 2 in the TargetCode state variable.

{'T0':0, 'T1':1, 'T2':2}

In our experiments we will see only:

- run_type_0:
    - append_X
- run_type_1
    - append_X_y
- run_type_2
    - append_X_y
    
and the coding is (run numbers in parentheses):

- T0 corresponds to rest 
    - (2)
- T1 (real or imagined)
    - (4,  8, 12) the left fist 
    - (6, 10, 14) both fists 
- T2 (real or imagined)
    - (4,  8, 12) the right fist 
    - (6, 10, 14) both feet 
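
The coding above amounts to a lookup from (run number, annotation code) to one of the five classes. The sketch below spells that lookup out. The class names and the helper `class_for` are illustrative assumptions, since the notebook does not show how labels were assigned in code. Following the list above, T0 segments inside the task runs are treated as unused.

```python
# Imagery runs only, as listed above (run numbers from the 14-run schedule).
FIST_RUNS = {4, 8, 12}         # Task 2: imagine left or right fist
FISTS_FEET_RUNS = {6, 10, 14}  # Task 4: imagine both fists or both feet
REST_RUN = 2                   # baseline, eyes closed


def class_for(run, code):
    """Map (run, annotation code) to a class name, or None if the pair is unused."""
    if run == REST_RUN and code == "T0":
        return "rest"
    if run in FIST_RUNS:
        return {"T1": "left_fist", "T2": "right_fist"}.get(code)
    if run in FISTS_FEET_RUNS:
        return {"T1": "both_fists", "T2": "both_feet"}.get(code)
    return None


print(class_for(4, "T1"))
print(class_for(14, "T2"))
print(class_for(2, "T0"))
```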

## 2. Raw Data Import

I will use an EEG data handling package named MNE (https://martinos.org/mne/stable/index.html) to import raw data and event annotations from EDF files. This package also provides essential signal analysis features, e.g. band-pass filtering. The raw data were filtered with a 1 Hz high-pass filter.

In this research, the data have 5 classes: four imagined motions plus rest.
    - right fist, 
    - left fist, 
    - both fists, 
    - both feet,
    - rest with eyes closed.

Data from one of the 109 subjects (S089) were excluded because the recording was severely corrupted.

## 3. Data Preprocessing

The original goal of applying neural networks is to exclude hand-crafted algorithms and preprocessing as much as possible. To build an end-to-end classifier from the dataset, I did not use any preprocessing techniques beyond standardization.
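
Standardization here means rescaling each feature to zero mean and unit variance. The notebook imports `StandardScaler` from scikit-learn but does not show the call, so the following NumPy sketch is an assumption about the general procedure, not the author's exact code. Statistics are computed on the training set only and reused for the test set. The data are synthetic.

```python
import numpy as np


def fit_standardizer(train):
    """Compute per-channel mean and standard deviation from data of shape (samples, channels)."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    std[std == 0] = 1.0  # avoid division by zero for constant channels
    return mean, std


def standardize(data, mean, std):
    """Apply previously fitted statistics."""
    return (data - mean) / std


rng = np.random.default_rng(0)
train = rng.normal(loc=5.0, scale=3.0, size=(1000, 64))  # synthetic stand-in for 64-channel EEG
test = rng.normal(loc=5.0, scale=3.0, size=(200, 64))

mean, std = fit_standardizer(train)
train_z = standardize(train, mean, std)
test_z = standardize(test, mean, std)
print(train_z.mean(axis=0)[:3], train_z.std(axis=0)[:3])
```

Fitting on the training set alone keeps test-set statistics from leaking into the model.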

In [6]:
[X_train, y_train] = pickle.load( open( "./py/stack/train.p", "rb" ) )

In [7]:
[X_test, y_test] = pickle.load( open( "./py/stack/test.p", "rb" ) )

As the EEG recording electrodes have 3D locations over the subjects' scalps, it is essential for the model to learn from the spatial pattern as well as the temporal pattern. I transformed the data into 2D meshes that represent the locations of the electrodes so that stacked convolutional neural networks can capture the spatial information.
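
One way to picture the mesh transform: each electrode gets a fixed (row, column) cell in a 10 x 11 grid, and cells with no electrode stay zero. The electrode-to-cell table below is a made-up three-electrode example, not the real 64-channel layout, which the notebook does not list. The helper names are hypothetical.

```python
import numpy as np

MESH_SHAPE = (10, 11)

# Hypothetical positions: channel index -> (row, col). Illustrative only.
POSITIONS = {0: (4, 3), 1: (4, 5), 2: (4, 7)}


def to_mesh(sample, positions=POSITIONS, shape=MESH_SHAPE):
    """Place a 1-D vector of channel readings onto a 2-D grid; empty cells stay 0."""
    mesh = np.zeros(shape, dtype=float)
    for channel, (row, col) in positions.items():
        mesh[row, col] = sample[channel]
    return mesh


def to_mesh_sequence(frames):
    """Convert (time, channels) readings into a (time, rows, cols) stack of meshes."""
    return np.stack([to_mesh(frame) for frame in frames])


frames = np.arange(30, dtype=float).reshape(10, 3)  # 10 time steps, 3 channels
meshes = to_mesh_sequence(frames)
print(meshes.shape)  # (10, 10, 11)
```

Neighbouring electrodes end up in neighbouring cells, which is what lets a 3x3 convolution kernel see local spatial structure.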

## 4. Modeling - Time-Distributed CNN + RNN

Training Plan:

+ 4 GPU units (Nvidia Tesla P100) were used to train this neural network.
+ Instead of training the whole model at once, I trained the first block (CNN) first. Then, using the trained parameters as initial values, I trained the next blocks step by step. This approach can greatly reduce the time required for training and help avoid getting stuck in poor local minima.
+ The first blocks (CNN) can be applied for other EEG classification models as a pre-trained base.

+ The initial learning rate is set to $10^{-4}$ (`lr=1e-4` in the compile call) with Adam optimization. I used several callbacks, such as ReduceLROnPlateau, which lowers the learning rate when the validation loss stops improving. I also record logs for TensorBoard to monitor the training process.
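
To make the plateau rule concrete, here is a simplified pure-Python model of a reduce-on-plateau schedule with `factor=0.1` and `patience=5`, the values passed to `ReduceLROnPlateau` in the callbacks cell. It ignores Keras options such as `cooldown`, `min_delta`, and `min_lr`, so treat it as an illustration of the idea rather than a reimplementation of the callback. The loss values are invented.

```python
def simulate_reduce_on_plateau(val_losses, initial_lr=1e-4, factor=0.1, patience=5):
    """Return the learning rate in effect after each epoch's validation loss is seen."""
    lr = initial_lr
    best = float("inf")
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best:
            # Improvement: remember it and reset the patience counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # No improvement for `patience` epochs: shrink the learning rate.
                lr *= factor
                wait = 0
        history.append(lr)
    return history


losses = [1.0, 0.8, 0.7] + [0.7] * 5 + [0.6]
lrs = simulate_reduce_on_plateau(losses)
for epoch, lr in enumerate(lrs, start=1):
    print(epoch, lr)
```

In this run the rate stays at its initial value through epoch 7 and drops by a factor of ten at epoch 8, after five epochs without improvement.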

In [None]:
print(X_train.shape)
print(X_test.shape)

In [8]:
X_train = X_train.squeeze().reshape(*X_train.squeeze().shape, 1)
X_test = X_test.squeeze().reshape(*X_test.squeeze().shape, 1)

In [9]:
print(X_train.shape)
print(X_test.shape)

(1024056, 10, 10, 11, 1)
(256014, 10, 10, 11, 1)
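
The shapes above read as (samples, time frames, mesh rows, mesh columns, channels), assuming the first 10 is the time axis that `TimeDistributed` iterates over in the model definition. A tiny NumPy stand-in shows the same squeeze-then-append-axis step:

```python
import numpy as np

# Small stand-in for the real data: 4 samples, 10 frames, 10 x 11 mesh.
x = np.zeros((4, 10, 10, 11))

# Append a trailing channel axis of size 1, since Conv2D expects (rows, cols, channels).
x5 = x.squeeze().reshape(*x.squeeze().shape, 1)
print(x5.shape)  # (4, 10, 10, 11, 1)

# np.newaxis does the same thing without the squeeze/reshape pair.
x5_alt = x[..., np.newaxis]
print(x5_alt.shape)  # (4, 10, 10, 11, 1)
```

Note that `squeeze()` drops every axis of size 1, so it would also remove the sample axis of a single-sample array. The `np.newaxis` form avoids that.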


In [None]:
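# Add a trailing dimension of size 1 so the CNN can be applied to each time frame.
# In [8] above already does this; only reshape if the axis is still missing.
if X_train.ndim == 4:
    X_train = X_train.reshape(*X_train.shape, 1)
    X_test = X_test.reshape(*X_test.shape, 1)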

In [None]:
print(X_train.shape)
print(X_test.shape)

### 4.1 Keras Implementation

The Keras functional API is the way to go for defining complex models, such as multi-output models, directed acyclic graphs, or models with shared layers.

In [16]:
# Complicated Model - the same as Zhang's
input_shape = (10, 10, 11, 1)
lecun = initializers.lecun_normal(seed=42)

# TimeDistributed Wrapper
def timeDist(layer, prev_layer, name):
    return TimeDistributed(layer, name=name)(prev_layer)
    
# Input layer
inputs = Input(shape=input_shape)

# Convolutional layers block
x = timeDist(Conv2D(16, (3,3), padding='same', data_format='channels_last', kernel_initializer=lecun), inputs, name='CNN1')
x = BatchNormalization(name='batch1')(x)
x = Activation('elu', name='act1')(x)
x = timeDist(Conv2D(32, (3,3), padding='same', data_format='channels_last', kernel_initializer=lecun), x, name='CNN2')
x = BatchNormalization(name='batch2')(x)
x = Activation('elu', name='act2')(x)
x = timeDist(Conv2D(64, (3,3), padding='same', data_format='channels_last', kernel_initializer=lecun), x, name='CNN3')
x = BatchNormalization(name='batch3')(x)
x = Activation('elu', name='act3')(x)
x = timeDist(Flatten(), x, name='flatten')

# Fully connected layer block
y = Dense(1024, kernel_initializer=lecun, name='FC')(x)
y = Dropout(0.5, name='dropout1')(y)
y = BatchNormalization(name='batch4')(y)
y = Activation(activation='elu')(y)

# Recurrent layers block
z = LSTM(64, kernel_initializer=lecun, return_sequences=True, name='LSTM1')(y)
z = LSTM(64, kernel_initializer=lecun, name='LSTM2')(z)

# Fully connected layer block
h = Dense(1024, kernel_initializer=lecun, activation='elu', name='FC2')(z)
h = Dropout(0.5, name='dropout2')(h)

# Output layer
outputs = Dense(5, activation='softmax')(h)

# Model compile
model = Model(inputs=inputs, outputs=outputs)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         (None, 10, 10, 11, 1)     0         
_________________________________________________________________
CNN1 (TimeDistributed)       (None, 10, 10, 11, 16)    160       
_________________________________________________________________
batch1 (BatchNormalization)  (None, 10, 10, 11, 16)    64        
_________________________________________________________________
act1 (Activation)            (None, 10, 10, 11, 16)    0         
_________________________________________________________________
CNN2 (TimeDistributed)       (None, 10, 10, 11, 32)    4640      
_________________________________________________________________
batch2 (BatchNormalization)  (None, 10, 10, 11, 32)    128       
_________________________________________________________________
act2 (Activation)            (None, 10, 10, 11, 32)    0         
__________
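
The parameter counts in the summary can be checked by hand. A 3x3 convolution has `3 * 3 * in_channels * filters` weights plus one bias per filter, and batch normalization keeps four values per channel (gamma, beta, moving mean, moving variance). A short check of the rows shown above (the helper names are illustrative):

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Weights plus biases for a standard 2-D convolution."""
    return kernel_h * kernel_w * in_channels * filters + filters


def batchnorm_params(channels):
    """Gamma, beta, moving mean and moving variance for each channel."""
    return 4 * channels


print(conv2d_params(3, 3, 1, 16))   # CNN1   -> 160
print(batchnorm_params(16))         # batch1 -> 64
print(conv2d_params(3, 3, 16, 32))  # CNN2   -> 4640
print(batchnorm_params(32))         # batch2 -> 128
```

Because `TimeDistributed` shares one set of convolution weights across all 10 time frames, the counts do not depend on the number of frames.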

In [12]:
model.save('./py/type1/model/model_base.h5')

In [13]:
class TrainValTensorBoard(TensorBoard):
    '''
    Plot training and validation losses on the same Tensorboard graph
    Supersede Tensorboard callback
    '''
    def __init__(self, log_dir="./py/logs/", **kwargs):
        # Make the original `TensorBoard` log to a subdirectory 'training'
        training_log_dir = os.path.join(log_dir, 'training')
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)

        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, 'validation')

    def set_model(self, model):
        # Setup writer for validation metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pop the validation logs and handle them separately with
        # `self.val_writer`. Also rename the keys so that they can
        # be plotted on the same figure with the training metrics
        logs = logs or {}
        val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
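
The key step in `on_epoch_end` is splitting the Keras log dictionary into training and validation parts and stripping the `val_` prefix so both land under the same tag. In isolation, with made-up metric values (the function name `split_logs` is hypothetical):

```python
def split_logs(logs):
    """Split a Keras-style logs dict into (training, validation) dicts with matching keys."""
    val_logs = {k.replace("val_", ""): v for k, v in logs.items() if k.startswith("val_")}
    train_logs = {k: v for k, v in logs.items() if not k.startswith("val_")}
    return train_logs, val_logs


# Illustrative values only, not results from this notebook.
logs = {"loss": 0.52, "acc": 0.74, "val_loss": 0.61, "val_acc": 0.70}
train_logs, val_logs = split_logs(logs)
print(train_logs)
print(val_logs)
```

Writing the two dictionaries to sibling log directories under identical tags is what makes TensorBoard overlay the training and validation curves on one chart.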

In [14]:
callbacks_list = [callbacks.ModelCheckpoint("./py/weights/weights_{epoch:02d}_{val_acc:.4f}.h5", save_best_only=False, monitor='val_loss'),
                 callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5),
                 callbacks.CSVLogger("./py/logs/log.csv", separator=',', append=True),
                 TrainValTensorBoard()]

# Compile the model (training starts with model.fit below)
model.compile(loss='categorical_crossentropy', optimizer=optimizers.adam(lr=1e-4), metrics=['acc'])
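
The checkpoint path passed to `ModelCheckpoint` is an ordinary Python format string that Keras fills in with the epoch number and the logged metrics at the end of each epoch. With `save_best_only=False`, a file is written every epoch. The values below are placeholders to show the resulting name:

```python
template = "./py/weights/weights_{epoch:02d}_{val_acc:.4f}.h5"

# Placeholder values; Keras supplies the real epoch number and metrics.
filename = template.format(epoch=3, val_acc=0.74021)
print(filename)  # ./py/weights/weights_03_0.7402.h5
```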

In [None]:
history = model.fit(X_train, y_train, batch_size=64, epochs=500, shuffle=True, 
                    validation_split=0.2, callbacks=callbacks_list)

Train on 819244 samples, validate on 204812 samples
Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
139456/819244 [====>.........................] - ETA: 23:33 - loss: 0.5216 - acc: 0.7402
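
The sample counts in the log follow from `validation_split=0.2`: Keras holds out the last 20% of the arrays (taken before any shuffling) and trains on the rest. A quick arithmetic check, including the number of batches per epoch at `batch_size=64`. The split formula reflects Keras 2.x behaviour as I understand it:

```python
from math import ceil

n_samples = 1024056
validation_split = 0.2
batch_size = 64

# The split index is int(n * (1 - validation_split)); the tail becomes validation data.
n_train = int(n_samples * (1.0 - validation_split))
n_val = n_samples - n_train
steps_per_epoch = ceil(n_train / batch_size)

print(n_train, n_val, steps_per_epoch)  # 819244 204812 12801
```

Since the validation set is the tail of the arrays, the order of samples in `X_train` determines what ends up in validation.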

## 5. Evaluation

In [None]:
# load in libraries
import pickle
import itertools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

In [None]:
# make directories
if not os.path.exists('./py/metrics/'):
    os.makedirs('./py/metrics/')

In [None]:
def plot_history(history):
    # Accept either a Keras History object or a plain dict of metric lists
    history = getattr(history, 'history', history)

    loss_list = [s for s in history.keys() if 'loss' in s and 'val' not in s]
    val_loss_list = [s for s in history.keys() if 'loss' in s and 'val' in s]
    acc_list = [s for s in history.keys() if 'acc' in s and 'val' not in s]
    val_acc_list = [s for s in history.keys() if 'acc' in s and 'val' in s]
    
    if len(loss_list) == 0:
        print('Loss is missing in history')
        return 
    
    ## As loss always exists
    epochs = range(1,len(history[loss_list[0]]) + 1)
    
    ## Loss
    plt.figure(1)
    for l in loss_list:
        plt.plot(epochs, history[l], 'b', label='Training loss (' + format(history[l][-1],'.5f') + ')')
    for l in val_loss_list:
        plt.plot(epochs, history[l], 'g', label='Validation loss (' + format(history[l][-1],'.5f') + ')')
    
    plt.title('Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig("./py/metrics/loss.png")
    
    ## Accuracy
    plt.figure(2)
    for l in acc_list:
        plt.plot(epochs, history[l], 'b', label='Training accuracy (' + format(history[l][-1],'.5f') + ')')
    for l in val_acc_list:    
        plt.plot(epochs, history[l], 'g', label='Validation accuracy (' + format(history[l][-1],'.5f') + ')')

    plt.title('Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    # Save before show(): show() can clear the figure and leave an empty image on disk
    plt.savefig("./py/metrics/acc.png")
    plt.show()
    
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        title='Normalized confusion matrix'
    else:
        title='Confusion matrix'

    plt.figure(3)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig("./py/metrics/confuMat.png")
    plt.show()
    
def full_multiclass_report(model,
                           x,
                           y_true,
                           classes):
    
    # 2. Predict classes and stores in y_pred
    y_pred = model.predict(x).argmax(axis=1)
    
    # 3. Print accuracy score
    print("Accuracy : "+ str(accuracy_score(y_true,y_pred)))
    
    print("")
    
    # 4. Print classification report
    print("Classification Report")
    print(classification_report(y_true,y_pred,digits=4))    
    
    # 5. Plot confusion matrix
    cnf_matrix = confusion_matrix(y_true,y_pred)
    print(cnf_matrix)
    plot_confusion_matrix(cnf_matrix,classes=classes)    

In [None]:
# Load in the data
howManyTest = 0.2

# Draw a random subset (with replacement) of about 20% of the test set
thisInd = np.random.randint(0, len(X_test), size=int(len(X_test) * howManyTest))
X_conf, y_conf = X_test[thisInd], y_test[thisInd]

'''
## Only if you have a previous model + history
# Get the model
model = load_model('./py/model/model_1230.h5')

# Get the history
with open('./history/history_1230.pkl', 'rb') as hist:
    history = pickle.load(hist)
'''

# Get the graphics
plot_history(history)
X_test = X_test.reshape(X_test.shape[0], X_train.shape[1], X_train.shape[2], X_train.shape[3], 1)
full_multiclass_report(model,
                       X_test,
                       y_test.argmax(axis=1),
                       [1,2,3,4,5])