# EEG Classification - Tensorflow
updated: Sep. 01, 2018

Data: https://www.physionet.org/pn4/eegmmidb/

## 1. Data Downloads

### Warning: Executing these blocks will automatically create directories and download datasets.

In [1]:
import tensorflow as tf
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))

Using TensorFlow backend.
W0702 16:00:58.624246 4726728128 deprecation_wrapper.py:119] From /anaconda3/envs/eeg-env/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.

W0702 16:00:58.625966 4726728128 deprecation_wrapper.py:119] From /anaconda3/envs/eeg-env/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.

W0702 16:00:58.627233 4726728128 deprecation_wrapper.py:119] From /anaconda3/envs/eeg-env/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:186: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

W0702 16:00:58.646740 4726728128 deprecation_wrapper.py:119] From /anaconda3/envs/eeg-env/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:190: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_

[]

In [2]:
from keras.backend.tensorflow_backend import set_session
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
config.log_device_placement = True  # to log device placement (on which device the operation ran)
sess = tf.Session(config=config)
set_session(sess)  # set this TensorFlow session as the default session for Keras

Using TensorFlow backend.


In [3]:
from keras import backend as K
K.tensorflow_backend._get_available_gpus()

['/job:localhost/replica:0/task:0/device:GPU:0',
 '/job:localhost/replica:0/task:0/device:GPU:1']

In [2]:
# Tensorflow Style Guide
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# System
import requests
import re
import os
import pathlib
import urllib

# Modeling & Preprocessing
from keras.layers import Conv2D, BatchNormalization, Activation, Flatten, Dense, Dropout, LSTM, Input, TimeDistributed
from keras import initializers, Model, optimizers, callbacks
from keras.utils.training_utils import multi_gpu_model
from keras import backend as K
from keras.models import load_model
from keras.callbacks import Callback, TensorBoard
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

# Essential Data Handling
import numpy as np
import pandas as pd
from math import ceil, floor

# Get Paths
from glob import glob

# EEG package
from mne import pick_types, events_from_annotations
from mne.io import read_raw_edf

import pickle
import sys

In [3]:
# Build the download URLs for the dataset on the PhysioNet website
CONTEXT = 'pn4/'
MATERIAL = 'eegmmidb/'
URL = 'https://www.physionet.org/' + CONTEXT + MATERIAL

# Change this directory according to your setting
USERDIR = './py/data/'

page = requests.get(URL).text
FOLDERS = sorted(list(set(re.findall(r'S[0-9]+', page))))

URLS = [URL+x+'/' for x in FOLDERS]
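
The cell above only collects the subject folder URLs. As a minimal sketch of the next step, the snippet below shows how the same regular expression picks subject folders out of an index page and how per-run file URLs could be built. The helper names `subject_folders` and `edf_urls` and the sample HTML are hypothetical, and the `S001R04.edf` naming pattern is an assumption inferred from the `subj+'R*.edf'` glob used later in this notebook. Nothing is downloaded here.

```python
import re

def subject_folders(page_html):
    """Return sorted, de-duplicated subject folder names such as 'S001'."""
    return sorted(set(re.findall(r"S[0-9]+", page_html)))

def edf_urls(base_url, subject, runs):
    """Build one URL per run, assuming files are named like S001R04.edf."""
    return ["{}{}/{}R{:02d}.edf".format(base_url, subject, subject, run) for run in runs]

# Hypothetical fragment of an index page, used instead of a real request
sample_page = '<a href="S002/">S002/</a> <a href="S001/">S001/</a> <a href="S001/">S001/</a>'

folders = subject_folders(sample_page)
urls = edf_urls("https://www.physionet.org/pn4/eegmmidb/", folders[0], [2, 4, 6])

print(folders)
print(urls)
```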

## Data Description

Subjects performed different motor/imagery tasks while 64-channel EEG were recorded using the BCI2000 system (http://www.bci2000.org). Each subject performed 14 experimental runs: 

- two one-minute baseline runs (one with eyes open, one with eyes closed)
- three two-minute runs of each of the four following tasks:
    - 1:
        - A target appears on either the left or the right side of the screen. 
        - The subject opens and closes the corresponding fist until the target disappears. 
        - Then the subject relaxes.
    - 2:
        - A target appears on either the left or the right side of the screen. 
        - The subject imagines opening and closing the corresponding fist until the target disappears. 
        - Then the subject relaxes.
    - 3:
        - A target appears on either the top or the bottom of the screen. 
        - The subject opens and closes either both fists (if the target is on top) or both feet (if the target is on the bottom) until the target disappears. 
        - Then the subject relaxes.
    - 4:
        - A target appears on either the top or the bottom of the screen. 
        - The subject imagines opening and closing either both fists (if the target is on top) or both feet (if the target is on the bottom) until the target disappears. 
        - Then the subject relaxes.

The data are provided here in EDF+ format (containing 64 EEG signals, each sampled at 160 samples per second, and an annotation channel). 
For use with PhysioToolkit software, rdedfann generated a separate PhysioBank-compatible annotation file (with the suffix .event) for each recording. 
The .event files and the annotation channels in the corresponding .edf files contain identical data.
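
To make the numbers above concrete, here is a small sketch of the array shape one would expect per run at the nominal durations. The function name `expected_shape` is hypothetical, and real recordings may be slightly longer or shorter than the nominal one or two minutes.

```python
SFREQ = 160       # samples per second, from the description above
N_CHANNELS = 64   # EEG signals per recording

def expected_shape(duration_sec, sfreq=SFREQ, n_channels=N_CHANNELS):
    """Shape of a (channels, samples) array for a run of the given nominal length."""
    return (n_channels, int(duration_sec * sfreq))

baseline = expected_shape(60)   # one-minute baseline run
task = expected_shape(120)      # two-minute task run
print(baseline, task)
```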

# Summary tasks

Remembering that:

    - Task 1 (open and close left or right fist)
    - Task 2 (imagine opening and closing left or right fist)
    - Task 3 (open and close both fists or both feet)
    - Task 4 (imagine opening and closing both fists or both feet)

we will refer to 'Task *' with the meaning above.

In summary, the experimental runs were:

1.  Baseline, eyes open
2.  Baseline, eyes closed
3.  Task 1
4.  Task 2
5.  Task 3
6.  Task 4
7.  Task 1
8.  Task 2
9.  Task 3
10. Task 4
11. Task 1
12. Task 2
13. Task 3
14. Task 4
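
The run list follows a simple pattern: after the two baselines, tasks 1 to 4 repeat three times. A minimal sketch of that mapping is shown below; `task_for_run` is a hypothetical helper and is not part of the notebook code.

```python
def task_for_run(run):
    """Map an experimental run number (1-14) to its description."""
    if run == 1:
        return "Baseline, eyes open"
    if run == 2:
        return "Baseline, eyes closed"
    if 3 <= run <= 14:
        # Runs 3-14 cycle through tasks 1, 2, 3, 4
        return "Task {}".format((run - 3) % 4 + 1)
    raise ValueError("run must be between 1 and 14")

# Runs with imagined movement are those for Task 2 and Task 4
imagined_runs = [r for r in range(3, 15) if task_for_run(r) in ("Task 2", "Task 4")]
print(imagined_runs)
```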

# Annotation

Each annotation includes one of three codes (T0, T1, or T2):

- T0 corresponds to rest
- T1 corresponds to onset of motion (real or imagined) of
    - the left fist (in runs 3, 4, 7, 8, 11, and 12)
    - both fists (in runs 5, 6, 9, 10, 13, and 14)
- T2 corresponds to onset of motion (real or imagined) of
    - the right fist (in runs 3, 4, 7, 8, 11, and 12)
    - both feet (in runs 5, 6, 9, 10, 13, and 14)
    
In the BCI2000-format versions of these files, which may be available from the contributors of this data set, these annotations are encoded as values of 0, 1, or 2 in the TargetCode state variable.

{'T0':0, 'T1':1, 'T2':2}

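In our experiments we use only the eyes-closed baseline and the imagery runs, grouped as follows (each with the helper that processes it):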

- run_type_0:
    - append_X
- run_type_1
    - append_X_y
- run_type_2
    - append_X_y
    
and the coding is: 

- T0 corresponds to rest 
    - (2)
- T1 (real or imagined)
    - (4,  8, 12) the left fist 
    - (6, 10, 14) both fists 
- T2 (real or imagined)
    - (4,  8, 12) the right fist 
    - (6, 10, 14) both feet 
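
The coding above can be sketched as a lookup from run number and annotation code to a class index. The index order (0 for rest, 1 and 2 for left and right fist, 3 and 4 for both fists and both feet) is an assumption that mirrors the labels assigned in `get_data` below; `class_for` and `CLASS_NAMES` are hypothetical names.

```python
LEFT_RIGHT_RUNS = {4, 8, 12}    # imagined left or right fist
FISTS_FEET_RUNS = {6, 10, 14}   # imagined both fists or both feet
CLASS_NAMES = ["rest (eyes closed)", "left fist", "right fist", "both fists", "both feet"]

def class_for(run, code):
    """Return the class index, or None when the annotation is skipped."""
    if code not in ("T0", "T1", "T2"):
        raise ValueError("unknown annotation code: {}".format(code))
    if run == 2:
        return 0                 # the whole eyes-closed baseline counts as rest
    if code == "T0":
        return None              # rest periods inside task runs are not used
    if run in LEFT_RIGHT_RUNS:
        return 1 if code == "T1" else 2
    if run in FISTS_FEET_RUNS:
        return 3 if code == "T1" else 4
    raise ValueError("run {} is not used in these experiments".format(run))

print(CLASS_NAMES[class_for(4, "T1")], "|", CLASS_NAMES[class_for(10, "T2")])
```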

## 2. Raw Data Import

I use MNE (https://martinos.org/mne/stable/index.html), an EEG data handling package, to import the raw data and the event annotations from the EDF files. The package also provides essential signal analysis features, e.g. band-pass filtering. The raw data were high-pass filtered at 1 Hz.

In this research, the data fall into 5 classes: rest with eyes closed, plus imagined motion of:

- the right fist,
- the left fist,
- both fists,
- both feet.

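Data from one of the 109 subjects (S089) were excluded because the recording was severely corrupted. Subjects S088, S092, and S100 were also excluded because they were recorded at 128 Hz rather than 160 Hz.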

In [4]:
# Get file paths
PATH = './py/data/'
SUBS = glob(PATH + 'S[0-9]*')
FNAMES = sorted([x[-4:] for x in SUBS])

REMOVE = ['S088', 'S089', 'S092', 'S100']

# Remove subject 'S089' with damaged data and 'S088', 'S092', 'S100' with 128Hz sampling rate (we want 160Hz)
FNAMES = [ x for x in FNAMES if x not in REMOVE] 

emb = {'T0': 1, 'T1': 2, 'T2': 3}

In [16]:
def get_data(subj_num=FNAMES, epoch_sec=0.0625):
    """ Import from edf files data and targets in the shape of 3D tensor
    
        Output shape: (Trial*Channel*TimeFrames)
        
        Some edf+ files recorded at low sampling rate, 128Hz, are excluded. 
        Majority was sampled at 160Hz.
        
        epoch_sec: length in seconds of one segment (window) of meshes (0.0625 is 1/16 as a fraction)
    """
    
    # Event codes mean different actions for two groups of runs
    # run_type_0 = '01'.split(',')
    # run_type_1 = '03,07,11'.split(',')
    # run_type_2 = '05,09,13'.split(',')
    
    run_type_0 = '02'.split(',')
    run_type_1 = '04,08,12'.split(',')
    run_type_2 = '06,10,14'.split(',')
    
    # Initiate X, y
    X = []
    y = []
    
    # To compute the completion rate
    count = len(subj_num)
    
    # fixed numbers
    nChan = 64 
    sfreq = 160
    sliding = epoch_sec/2 
    timeFromQue = 0.5
    timeExercise = 4.1 # seconds

    # Sub-function to assign X and X, y
    def append_X(n_segments, data, event=[]):
        # Data should be changed
        '''This function generate a tensor for X and append it to the existing X'''
    
        def window(n):
            # (80) + (160 * 1/16 * n) 
            windowStart = int(timeFromQue*sfreq) + int(sfreq*sliding*n) 
            # (80) + (160 * 1/16 * (n+2))
            windowEnd = int(timeFromQue*sfreq) + int(sfreq*sliding*(n+2)) 
            
            while (windowEnd - windowStart) != sfreq*epoch_sec:
                windowEnd += int(sfreq*epoch_sec) - (windowEnd - windowStart)
                
            return [windowStart, windowEnd]
        
        new_x = []
        for n in range(n_segments):
            # print('data[:, ',window(n)[0],':',window(n)[1],'].shape = ', data[:, window(n)[0]:window(n)[1]].shape, '(',nChan,',',int(sfreq*epoch_sec),')')
            
            if data[:, window(n)[0]:window(n)[1]].shape==(nChan, int(sfreq*epoch_sec)):
                new_x.append(data[:, window(n)[0]: window(n)[1]])
                 
        return new_x
    
    def append_X_Y(run_type, event, old_x, old_y, data):
        '''This function separates the events by type
        (refer to the data descriptions for the list of types),
        then assigns X and Y according to the event type'''
        # Number of sliding windows

        # print('data', data.shape[1])
        n_segments = floor(data.shape[1]/(epoch_sec*sfreq*timeFromQue) - 1/epoch_sec - 1)
        # print('run_'+str(run_type),' n_segments', n_segments)
        
        
        # Rest excluded
        if event[2] == emb['T0']:
            return old_x, old_y
        
        # y assignment (event[2] is the integer event id, so compare it against emb)
        if run_type == 1:
            temp_y = [1] if event[2] == emb['T1'] else [2]
        
        elif run_type == 2:
            temp_y = [3] if event[2] == emb['T1'] else [4]
            
        # print('timeExercise * sfreq', timeExercise*sfreq, ' ?= 656')
        new_x = append_X(n_segments, data, event)
        new_y = old_y + temp_y*len(new_x)
        
        return old_x + new_x, new_y
    
    # Iterate over subj_num: S001, S002, S003, ...
    for i, subj in enumerate(subj_num):
        
        
        
        # Return completion rate
        if i%((len(subj_num)//10)+1) == 0:
            print('\n')
            print('working on {}, {:.0%} completed'.format(subj, i/count))
            print('\n')
        
        old_size = np.array(y).shape[0]
        # print('subj:', subj, '| y.shape', np.array(y).shape ,'| X.shape', np.array(X).shape)

        # Get file names
        fnames = glob(os.path.join(PATH, subj, subj+'R*.edf'))
        # Hold only the files that have an even number
        fnames = sorted([name for name in fnames if name[-6:-4] in run_type_0+run_type_1+run_type_2])

        # for each of ['02', '04', '06', '08', '10', '12', '14']
        for i, fname in enumerate(fnames):
            
            # Import data into MNE raw object
            raw = read_raw_edf(fname, preload=True, verbose=False)
            
            picks = pick_types(raw.info, eeg=True)
            # print('n_times', raw.n_times)
            
            if raw.info['sfreq'] != 160:
                print('{} is sampled at 128Hz so will be excluded.'.format(subj))
                break
            
            
            # High-pass filtering
            raw.filter(l_freq=1, h_freq=None, picks=picks)

            # Get annotation
            try:
                events = events_from_annotations(raw, verbose=False)
            except:
                continue

            # Get data
            data = raw.get_data(picks=picks)

            # print('event.shape', np.array(events[0]).shape, '| data.shape', data.shape)

            # Number of this run
            which_run = fname[-6:-4]

            """ Assignment Starts """ 
            # run 2 - baseline (eyes closed)
            if which_run in run_type_0:

                # Number of sliding windows
                n_segments = floor(data.shape[1]/(epoch_sec*sfreq*timeFromQue) - 1/epoch_sec - 1)
                # print('run_0 n_segments', n_segments)

                # Append one 0 label per window
                new_X = append_X(n_segments, data)
                X += new_X
                y.extend([0] * len(new_X))
                # print(events[0])   

            # run 4,8,12 - imagine opening and closing left or right fist    
            elif which_run in run_type_1:

                for i, event in enumerate(events[0]):

                    X, y = append_X_Y(run_type=1, event=event, old_x=X, old_y=y, data=data[:, int(event[0]) : int(event[0] + timeExercise*sfreq)])
                    # print(event)   

            # run 6,10,14 - imagine opening and closing both fists or both feet
            elif which_run in run_type_2:

                for i, event in enumerate(events[0]):      

                    X, y = append_X_Y(run_type=2, event=event, old_x=X, old_y=y, data=data[:, int(event[0]) : int(event[0] + timeExercise*sfreq)])
                    # print(event)    

        print('subj:', subj, '|', np.array(y).shape[0] - old_size, '| y.shape', np.array(y).shape ,'| X.shape', np.array(X).shape)
        
    print(np.array(X).shape)

    X = np.stack(X)
    y = np.array(y).reshape((-1,1))
    return X, y
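
The window arithmetic inside `get_data` can be checked in isolation. The sketch below reproduces the same formulas for a single trial: 10-sample windows with a 5-sample step, starting 80 samples (0.5 s) after the cue. `window_bounds` is a hypothetical helper and only returns index pairs, so it does not touch any EEG data.

```python
from math import floor

def window_bounds(n_samples, sfreq=160, epoch_sec=0.0625, time_from_cue=0.5):
    """Return (start, end) sample indices of the half-overlapping windows."""
    width = int(sfreq * epoch_sec)          # 10 samples per window
    step = int(sfreq * epoch_sec / 2)       # 5 samples between window starts
    offset = int(time_from_cue * sfreq)     # skip the first 80 samples after the cue
    # Same segment-count formula as in get_data
    n_segments = floor(n_samples / (epoch_sec * sfreq * time_from_cue) - 1 / epoch_sec - 1)
    bounds = []
    for n in range(n_segments):
        start = offset + step * n
        end = start + width
        if end <= n_samples:                # keep only full-width windows
            bounds.append((start, end))
    return bounds

# One trial spans timeExercise * sfreq = 4.1 * 160 = 656 samples
bounds = window_bounds(656)
print(len(bounds), bounds[0], bounds[1], bounds[-1])
```

Each window is 10 samples wide, which matches the last dimension of the `X.shape` values printed by `get_data`.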

In [17]:
X, y = get_data(FNAMES, epoch_sec=0.0625)  # use epoch_sec=0.125 to get 20-sample windows instead of 10



working on S001, 0% completed


subj: S001 | 12195 | y.shape (12195,) | X.shape (12195, 64, 10)
subj: S002 | 12195 | y.shape (24390,) | X.shape (24390, 64, 10)
subj: S003 | 12195 | y.shape (36585,) | X.shape (36585, 64, 10)
subj: S004 | 12195 | y.shape (48780,) | X.shape (48780, 64, 10)
subj: S005 | 12195 | y.shape (60975,) | X.shape (60975, 64, 10)
subj: S006 | 12195 | y.shape (73170,) | X.shape (73170, 64, 10)
subj: S007 | 12195 | y.shape (85365,) | X.shape (85365, 64, 10)
subj: S008 | 12195 | y.shape (97560,) | X.shape (97560, 64, 10)
subj: S009 | 12195 | y.shape (109755,) | X.shape (109755, 64, 10)
subj: S010 | 12195 | y.shape (121950,) | X.shape (121950, 64, 10)
subj: S011 | 12195 | y.shape (134145,) | X.shape (134145, 64, 10)


working on S012, 10% completed


subj: S012 | 12195 | y.shape (146340,) | X.shape (146340, 64, 10)
subj: S013 | 12195 | y.shape (158535,) | X.shape (158535, 64, 10)
subj: S014 | 12195 | y.shape (170730,) | X.shape (170730, 64, 10)
subj: S015 | 12195 | y.

In [18]:
print(X.shape)
print(y.shape)

(1280070, 64, 10)
(1280070, 1)


In [None]:
print(sys.getsizeof(X))

In [None]:
pickle.dump( X , open( "./py/stack/X.p", "wb" ) , protocol=4)

In [None]:
pickle.dump( y , open( "./py/stack/y.p", "wb" ) , protocol=4)

In [None]:
X = pickle.load( open( "./py/stack/X.p", "rb" ) )
y = pickle.load( open( "./py/stack/y.p", "rb" ) )

## 3. Data Preprocessing

The original goal of applying neural networks is to exclude hand-crafted algorithms and preprocessing as much as possible. I did not use any preprocessing techniques beyond standardization, in order to build an end-to-end classifier from the dataset.
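
As a reminder of what standardization does, here is a minimal pure-Python sketch of z-scoring one list of values. It uses the population standard deviation, which is also what `sklearn.preprocessing.scale` uses. The helper `zscore` and the sample values are illustrative only.

```python
from statistics import mean, pstdev

def zscore(values):
    """Center values to zero mean and scale them to unit variance."""
    mu = mean(values)
    sigma = pstdev(values)   # population standard deviation (ddof = 0)
    if sigma == 0:
        return [0.0 for _ in values]   # constant input: nothing to scale
    return [(v - mu) / sigma for v in values]

column = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
z = zscore(column)
print(z)
```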

In [19]:
import numpy as np
from sklearn.preprocessing import OneHotEncoder, scale

#%%
def convert_mesh(X):
    
    mesh = np.zeros((X.shape[0], X.shape[2], 10, 11, 1))
    X = np.swapaxes(X, 1, 2)
    
    # 1st line
    mesh[:, :, 0, 4:7, 0] = X[:,:,21:24]; print('1st finished')
    
    # 2nd line
    mesh[:, :, 1, 3:8, 0] = X[:,:,24:29]; print('2nd finished')
    
    # 3rd line
    mesh[:, :, 2, 1:10, 0] = X[:,:,29:38]; print('3rd finished')
    
    # 4th line
    mesh[:, :, 3, 1:10, 0] = np.concatenate((X[:,:,38].reshape(-1, X.shape[1], 1),\
                                          X[:,:,0:7], X[:,:,39].reshape(-1, X.shape[1], 1)), axis=2)
    print('4th finished')
    
    # 5th line
    mesh[:, :, 4, 0:11, 0] = np.concatenate((X[:,:,(42, 40)],\
                                        X[:,:,7:14], X[:,:,(41, 43)]), axis=2)
    print('5th finished')
    
    # 6th line
    mesh[:, :, 5, 1:10, 0] = np.concatenate((X[:,:,44].reshape(-1, X.shape[1], 1),\
                                        X[:,:,14:21], X[:,:,45].reshape(-1, X.shape[1], 1)), axis=2)
    print('6th finished')
               
    # 7th line
    mesh[:, :, 6, 1:10, 0] = X[:,:,46:55]; print('7th finished')
    
    # 8th line
    mesh[:, :, 7, 3:8, 0] = X[:,:,55:60]; print('8th finished')
    
    # 9th line
    mesh[:, :, 8, 4:7, 0] = X[:,:,60:63]; print('9th finished')
    
    # 10th line
    mesh[:, :, 9, 5, 0] = X[:,:,63]; print('10th finished')
    
    return mesh

#%%
def prepare_data(X, y, test_ratio=0.2, return_mesh=True, set_seed=42):
    
    # y encoding
    oh = OneHotEncoder(categories='auto')
    y = oh.fit_transform(y).toarray()
    
    # Shuffle trials
    np.random.seed(set_seed)
    trials = X.shape[0]
    shuffle_indices = np.random.permutation(trials)
    X = X[shuffle_indices]
    y = y[shuffle_indices]
    
    # Test set separation
    train_size = int(trials*(1-test_ratio)) 
    X_train, X_test, y_train, y_test = X[:train_size,:,:], X[train_size:,:,:],\
                                    y[:train_size,:], y[train_size:,:]
                                    
    # Z-score Normalization
    def scale_data(X):
        shape = X.shape
        for i in range(shape[0]):
            # Standardize a dataset along any axis
            # Center to the mean and component wise scale to unit variance.
            X[i,:, :] = scale(X[i,:, :])
            if i%int(shape[0]//10) == 0:
                print('{:.0%} done'.format((i+1)/shape[0]))   
        return X
            
    X_train, X_test  = scale_data(X_train), scale_data(X_test)
    if return_mesh:
        X_train, X_test = convert_mesh(X_train), convert_mesh(X_test)
    
    return X_train, y_train, X_test, y_test
    
    

In [20]:
X_train, y_train, X_test, y_test = prepare_data(X, y)

0% done
10% done
20% done
30% done
40% done
50% done
60% done
70% done
80% done
90% done
100% done
0% done
10% done
20% done
30% done
40% done
50% done
60% done
70% done
80% done
90% done
100% done
1st finished
2nd finished
3rd finished
4th finished
5th finished
6th finished
7th finished
8th finished
9th finished
10th finished
1st finished
2nd finished
3rd finished
4th finished
5th finished
6th finished
7th finished
8th finished
9th finished
10th finished


In [None]:
pickle.dump( [X_train, y_train]  , open( "./py/stack/train.p", "wb" ) , protocol=4)

In [None]:
pickle.dump( [X_test, y_test]  , open( "./py/stack/test.p", "wb" ) , protocol=4)

In [5]:
[X_train, y_train] = pickle.load( open( "./py/stack/train.p", "rb" ) )

In [6]:
[X_test, y_test] = pickle.load( open( "./py/stack/test.p", "rb" ) )

As the electrodes of the EEG recording instrument sit at 3D locations over the subject's scalp, it is essential for the model to learn from the spatial pattern as well as the temporal pattern. I transformed the data into 2D meshes that represent the locations of the electrodes so that stacked convolutional neural networks can capture the spatial information.
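
The layout used by `convert_mesh` can be visualized without any data. The sketch below builds a 10 x 11 grid of channel indices from the same row and column assignments; `MESH_ROWS` and `channel_grid` are hypothetical names, and empty cells are shown as `--`.

```python
# (first column, channel indices) for each of the 10 mesh rows, as in convert_mesh
MESH_ROWS = [
    (4, [21, 22, 23]),
    (3, list(range(24, 29))),
    (1, list(range(29, 38))),
    (1, [38] + list(range(0, 7)) + [39]),
    (0, [42, 40] + list(range(7, 14)) + [41, 43]),
    (1, [44] + list(range(14, 21)) + [45]),
    (1, list(range(46, 55))),
    (3, list(range(55, 60))),
    (4, [60, 61, 62]),
    (5, [63]),
]

def channel_grid(rows=MESH_ROWS, width=11):
    """Return a list of rows; each cell holds a channel index or None."""
    grid = []
    for first_col, channels in rows:
        line = [None] * width
        line[first_col:first_col + len(channels)] = channels
        grid.append(line)
    return grid

grid = channel_grid()
for line in grid:
    print(" ".join("--" if c is None else "{:02d}".format(c) for c in line))
```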

## 4. Modeling - Time-Distributed CNN + RNN

Training Plan:

+ 4 GPU units (Nvidia Tesla P100) were used to train this neural network.
+ Instead of training the whole model at once, I trained the first block (CNN) first. Then, using the trained parameters as initial values, I trained the next blocks step by step. This approach can greatly reduce the training time and help avoid getting stuck in poor local minima.
+ The first blocks (CNN) can be applied for other EEG classification models as a pre-trained base.

+ The initial learning rate is set to $10^{-4}$ with Adam optimization (`lr=1e-4` in the compile call in Section 4.1). I used several callbacks, such as ReduceLROnPlateau, which reduces the learning rate when the validation loss stops improving. I also record logs for TensorBoard to monitor the training process.
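
To illustrate the learning-rate schedule, here is a simplified pure-Python sketch of the reduce-on-plateau idea with `factor=0.1` and `patience=5`, the values passed to the callback later in this notebook. It is not the Keras implementation (which also supports options such as a cooldown and a minimum improvement threshold); `reduce_on_plateau` and the sample losses are hypothetical.

```python
def reduce_on_plateau(val_losses, lr=1e-4, factor=0.1, patience=5, min_lr=0.0):
    """Return the learning rate in effect after each epoch."""
    best = float("inf")
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss          # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr = max(lr * factor, min_lr)   # plateau: shrink the learning rate
                wait = 0
        history.append(lr)
    return history

# Two improving epochs followed by five epochs without improvement
losses = [1.0, 0.9, 0.95, 0.95, 0.95, 0.95, 0.95]
lr_history = reduce_on_plateau(losses)
print(lr_history)
```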

In [21]:
print(X_train.shape)
print(X_test.shape)

(1024056, 10, 10, 11, 1)
(256014, 10, 10, 11, 1)


In [31]:
print(X_train.shape)
print(X_test.shape)

In [7]:
X_train = X_train.squeeze().reshape(*X_train.squeeze().shape, 1)
X_test = X_test.squeeze().reshape(*X_test.squeeze().shape, 1)

In [32]:
print(X_train.shape)
print(X_test.shape)

(1024056, 10, 10, 11, 1)
(256014, 10, 10, 11, 1)


In [24]:
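# Add a trailing dimension of size 1 so the CNN can be applied to each time frame.
# convert_mesh already returns 5-D arrays, so only reshape when that axis is missing.
if X_train.ndim == 4:
    X_train = X_train.reshape(*X_train.shape, 1)
    X_test = X_test.reshape(*X_test.shape, 1)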

In [33]:
print(X_train.shape)
print(X_test.shape)

(1024056, 10, 10, 11, 1)
(256014, 10, 10, 11, 1)


### 4.1 Keras Implementation

The Keras functional API is the way to go for defining complex models, such as multi-output models, directed acyclic graphs, or models with shared layers.

In [40]:
# Complicated Model - the same as Zhang's
input_shape = (10, 10, 11, 1)
lecun = initializers.lecun_normal(seed=42)

# TimeDistributed Wrapper
def timeDist(layer, prev_layer, name):
    return TimeDistributed(layer, name=name)(prev_layer)
    
# Input layer
inputs = Input(shape=input_shape)

# Convolutional layers block
x = timeDist(Conv2D(32, (3,3), padding='same', data_format='channels_last', kernel_initializer=lecun), inputs, name='CNN1')
x = BatchNormalization(name='batch1')(x)
x = Activation('elu', name='act1')(x)
x = timeDist(Conv2D(64, (3,3), padding='same', data_format='channels_last', kernel_initializer=lecun), x, name='CNN2')
x = BatchNormalization(name='batch2')(x)
x = Activation('elu', name='act2')(x)
x = timeDist(Conv2D(128, (3,3), padding='same', data_format='channels_last', kernel_initializer=lecun), x, name='CNN3')
x = BatchNormalization(name='batch3')(x)
x = Activation('elu', name='act3')(x)
x = timeDist(Flatten(), x, name='flatten')

# Fully connected layer block
y = Dense(1024, kernel_initializer=lecun, name='FC')(x)
y = Dropout(0.5, name='dropout1')(y)
y = BatchNormalization(name='batch4')(y)
y = Activation(activation='elu')(y)

# Recurrent layers block
z = LSTM(64, kernel_initializer=lecun, return_sequences=True, name='LSTM1')(y)
z = LSTM(64, kernel_initializer=lecun, name='LSTM2')(z)

# Fully connected layer block
h = Dense(1024, kernel_initializer=lecun, activation='elu', name='FC2')(z)
h = Dropout(0.5, name='dropout2')(h)

# Output layer
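# One output unit per one-hot column in y_train, so the layer always matches the labels
outputs = Dense(y_train.shape[1], activation='softmax')(h)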

# Model compile
model = Model(inputs=inputs, outputs=outputs)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 10, 10, 11, 1)     0         
_________________________________________________________________
CNN1 (TimeDistributed)       (None, 10, 10, 11, 32)    320       
_________________________________________________________________
batch1 (BatchNormalization)  (None, 10, 10, 11, 32)    128       
_________________________________________________________________
act1 (Activation)            (None, 10, 10, 11, 32)    0         
_________________________________________________________________
CNN2 (TimeDistributed)       (None, 10, 10, 11, 64)    18496     
_________________________________________________________________
batch2 (BatchNormalization)  (None, 10, 10, 11, 64)    256       
_________________________________________________________________
act2 (Activation)            (None, 10, 10, 11, 64)    0         
__________
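
The parameter counts in the summary can be checked by hand. A 2D convolution has one weight per kernel cell and input channel plus one bias, for every filter, and batch normalization keeps four values per channel. The sketch below is a plain arithmetic check; the helper names are hypothetical.

```python
def conv2d_params(kernel_size, in_channels, filters):
    """Weights plus one bias per filter for a standard 2D convolution."""
    kernel_h, kernel_w = kernel_size
    return (kernel_h * kernel_w * in_channels + 1) * filters

def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance: four values per channel."""
    return 4 * channels

in_channels = 1
for name, filters in [("CNN1", 32), ("CNN2", 64), ("CNN3", 128)]:
    print(name, conv2d_params((3, 3), in_channels, filters),
          "| batch norm", batchnorm_params(filters))
    in_channels = filters
```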

In [10]:
model.save('./py/model/model_base.h5')

In [None]:
model = load_model('./py/model/model_14_0.7098.h5')

In [None]:
''' 
# Load a model to transfer pre-trained parameters
trans_model = load_model('CNN_3blocks.h5')

# Transfer learning - parameter copy & paste
which_layer = 'CNN1,CNN2,CNN3,batch1,batch2,batch3'.split(',')
layer_names = [layer.name for layer in model.layers]
trans_layer_names = [layer.name for layer in trans_model.layers]

for layer in which_layer:
    ind = layer_names.index(layer)
    trans_ind = trans_layer_names.index(layer)
    model.layers[ind].set_weights(trans_model.layers[trans_ind].get_weights())
    
for layer in model.layers[:9]: # Freeze the first 9 layers(CNN block)
    layer.trainable = False
    


# Turn on multi-GPU mode
model = multi_gpu_model(model, gpus=4)

These metrics calculate sensitivity and specificity batch-wise.
The Keras development team removed this feature because
these metrics should be understood as global metrics.


I am not using them this time.

# Metrics - sensitivity, specificity, accuracy
def sens(y_true, y_pred): # Sensitivity
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())

def spec(y_true, y_pred): # Specificity
    true_negatives = K.sum(K.round(K.clip((1-y_true) * (1-y_pred), 0, 1)))
    possible_negatives = K.sum(K.round(K.clip(1-y_true, 0, 1)))
    return true_negatives / (possible_negatives + K.epsilon())
''' 
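
The point about global metrics can be seen with a toy example: averaging a batch-wise sensitivity is not the same as computing it once over all samples. The sketch below uses plain Python lists with made-up binary labels; the function names are hypothetical.

```python
def sensitivity(y_true, y_pred):
    """True positives divided by all actual positives."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    positives = sum(1 for t in y_true if t == 1)
    return tp / positives if positives else 0.0

def specificity(y_true, y_pred):
    """True negatives divided by all actual negatives."""
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    negatives = sum(1 for t in y_true if t == 0)
    return tn / negatives if negatives else 0.0

# Made-up labels, split into two batches of four
y_true = [1, 1, 1, 0, 1, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 0, 0, 1]
batches = [(y_true[:4], y_pred[:4]), (y_true[4:], y_pred[4:])]

global_sens = sensitivity(y_true, y_pred)
batch_mean_sens = sum(sensitivity(t, p) for t, p in batches) / len(batches)
global_spec = specificity(y_true, y_pred)
print(global_sens, batch_mean_sens, global_spec)
```

Here the global sensitivity is 0.5 while the mean over the two batches is about 0.33, because the batches contain different numbers of positives.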

In [42]:
from keras.callbacks import Callback

class CustomModelCheckPoint(Callback):
    def __init__(self, **kargs):
        super(CustomModelCheckPoint, self).__init__(**kargs)
        self.epoch_accuracy = {}  # accuracy at given epoch
        self.epoch_loss = {}  # loss at given epoch

    def on_epoch_begin(self, epoch, logs=None):
        # Things done at the beginning of the epoch.
        return

    def on_epoch_end(self, epoch, logs=None):
        # Things done at the end of the epoch
        logs = logs or {}
        self.epoch_accuracy[epoch] = logs.get("acc")
        self.epoch_loss[epoch] = logs.get("loss")
        self.model.save_weights("./py/model/model_%d.h5" % epoch)  # save the model weights

In [15]:
class TrainValTensorBoard(TensorBoard):
    '''
    Plot training and validation losses on the same Tensorboard graph
    Supersede Tensorboard callback
    '''
    def __init__(self, log_dir="./py/logs/", **kwargs):
        # Make the original `TensorBoard` log to a subdirectory 'training'
        training_log_dir = os.path.join(log_dir, 'training')
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)

        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, 'validation')

    def set_model(self, model):
        # Setup writer for validation metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pop the validation logs and handle them separately with
        # `self.val_writer`. Also rename the keys so that they can
        # be plotted on the same figure with the training metrics
        logs = logs or {}
        val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()

In [16]:
callbacks_list = [callbacks.ModelCheckpoint("./py/weights/weights_{epoch:02d}_{val_acc:.4f}.h5", save_best_only=False, monitor='val_loss'),
                 callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5),
                 callbacks.CSVLogger("./py/logs/log.csv", separator=',', append=True),
                 TrainValTensorBoard()]

# Compile the model (training starts with model.fit in the next cell)
model.compile(loss='categorical_crossentropy', optimizer=optimizers.adam(lr=1e-4), metrics=['acc'])

In [None]:
history = model.fit(X_train, y_train, batch_size=64, epochs=500, shuffle=True, 
                    validation_split=0.2, callbacks=callbacks_list)

Train on 819244 samples, validate on 204812 samples
Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
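
The sample counts in the log above follow from two successive 80/20 splits. The first formula is the one used in `prepare_data`. Applying the same formula to `validation_split=0.2` is an assumption about how Keras picks the split index, although it does reproduce the counts printed above. `split_sizes` is a hypothetical helper.

```python
def split_sizes(n_samples, held_out_ratio):
    """Return (kept, held_out) sizes, with the split index truncated to an integer."""
    n_kept = int(n_samples * (1 - held_out_ratio))
    return n_kept, n_samples - n_kept

n_train, n_test = split_sizes(1280070, 0.2)   # prepare_data(test_ratio=0.2)
n_fit, n_val = split_sizes(n_train, 0.2)      # model.fit(validation_split=0.2)
print(n_train, n_test, n_fit, n_val)
```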

### 4.2 Tensorflow Eager Execution API

TensorFlow's eager execution is an imperative programming environment that evaluates operations immediately, without building graphs: operations return concrete values instead of constructing a computational graph to run later. This makes it easy to get started with TensorFlow and debug models, and it reduces boilerplate as well. To follow along with this guide, run the code samples below in an interactive python interpreter.

In [None]:
# Parameters
learning_rate = 0.001
num_steps = 1000
batch_size = 128
display_step = 100

# Network Parameters
num_input = X_train.shape[1:] # PhysioNet data input per sample (time frames, mesh shape: 10*11, 1 channel)
num_classes = 5 # PhysioNet total classes

In [None]:
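import tensorflow.contrib.eager as tfe  # TF 1.x eager module; needs tf.enable_eager_execution() at startup

# Using TF Dataset to split data into batches
dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(batch_size)
dataset_iter = tfe.Iterator(dataset)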

In [None]:
class ZhangModel(tf.keras.Model):
    def __init__(self, n_nodes=[3,2,2], 
                 initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_AVG")):
        """
        This is a TensorFlow implementation of Zhang's model (2018) with 
        3 CNN, 2 LSTM, 2 dense layers by default.
        
        n_nodes [list] : specifies the number of layers for each block. 
                        [CNN, LSTM, DENSE] respectively.
        initializer [tf initializer] : default is a variance-scaling initializer with mode FAN_AVG
        """
        super().__init__()
        self.n_CNN, self.n_LSTM, self.n_dense = n_nodes
        self.CNN, self.LSTM, self.dense = [[] for i in range(len(n_nodes))]
        self.initializer = initializer
        
        for n in range(self.n_CNN):
            # 32, 64, 128 filters, matching the Keras model in 4.1
            n_filter = 32 * 2**n
            
            self.CNN.append(tf.keras.layers.Conv2D(n_filter, (3, 3), padding='same', 
                                                data_format='channels_last',
                                                kernel_initializer=self.initializer))
        
        for n in range(self.n_LSTM):    
            n_hidden = 64
            return_sequences = False if n == self.n_LSTM-1 else True
            
            name = 'LSTM' + str(len(self.LSTM)+1)            
            self.LSTM.append(tf.keras.layers.LSTM(n_hidden, kernel_initializer=self.initializer, 
                                             return_sequences=return_sequences, name=name))
            
        for n in range(self.n_dense):
            n_node=1024
            
            name = 'Dense' + str(len(self.dense)+1)
            self.dense.append(tf.keras.layers.Dense(n_node, 
                                                    kernel_initializer=self.initializer, name=name))

    def call(self, input_tensor):
        "Run the model."
        
        assert self.CNN, 'No CNN blocks defined!'
        assert self.dense, 'No Dense Blocks defined!'
        assert self.LSTM, 'No LSTM blocks defined!'
        
        def timeDist(layer, prev_layer, name):
            return tf.keras.layers.TimeDistributed(layer, name=name)(prev_layer)
        
        for i, layer in enumerate(self.CNN):
            name = 'CNN' + str(i+1)
            nameBatch = 'batch' + str(i+1)
            nameAct = 'act' + str(i+1)
            
            prev_layer = input_tensor if i==0 else x
            
            x = timeDist(layer, prev_layer, name=name)
            x = tf.keras.layers.BatchNormalization(name=nameBatch)(x)
            x = tf.nn.elu(x, name=nameAct)
            
        x = timeDist(tf.keras.layers.Flatten(), x, name='flatten')
        
        for i, layer in enumerate(self.dense):            
            
            nameDrop = 'drop' + str(i+1)
            nameBatch = 'batch' + str(i+1)
            nameAct = 'act' + str(i+1)
            
            if i == len(self.dense)-1:
                break
            
            x = layer(x)
            x = tf.keras.layers.Dropout(0.5, name=nameDrop)(x)
            x = tf.keras.layers.BatchNormalization(name=nameBatch)(x)
            x = tf.nn.elu(x, name=nameAct)
            
        for i, layer in enumerate(self.LSTM):
            x = layer(x)
        
        x = self.dense[-1](x)
        x = tf.keras.layers.Dropout(0.5, name=nameDrop)(x)
        x = tf.keras.layers.BatchNormalization(name=nameBatch)(x)
        x = tf.nn.elu(x, name=nameAct)
        
        output = tf.keras.layers.Dense(5, activation='softmax')(x)
        
        return output    

In [None]:
model = ZhangModel([3, 2, 2])
print(model(tf.random_normal([1, 10, 10, 11, 1])))
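
To make the expected tensor shapes in the smoke test above easier to follow, here is a minimal, framework-free sketch that traces the shape after each stage of the CNN -> Dense -> LSTM -> Dense stack. It assumes the layer sizes used in the class above (filters of `32 * 2**i`, 64 LSTM units, 1024 dense units, 5 classes) and `'same'` padding; `trace_shapes` is a hypothetical helper written for this illustration, not part of the model.

```python
def trace_shapes(input_shape, n_nodes, n_classes=5):
    """Return (stage, shape) pairs for a CNN -> Dense -> LSTM -> Dense stack.

    input_shape is (batch, time, height, width, channels).
    n_nodes is (n_CNN, n_LSTM, n_dense), as passed to the model.
    """
    n_cnn, n_lstm, n_dense = n_nodes
    batch, time, height, width, channels = input_shape
    shapes = [('input', input_shape)]

    # 'same' padding keeps height and width; only the channel count changes
    for i in range(1, n_cnn + 1):
        channels = 32 * 2 ** i
        shapes.append(('CNN%d' % i, (batch, time, height, width, channels)))

    # Each time step is flattened separately
    features = height * width * channels
    shapes.append(('flatten', (batch, time, features)))

    # All dense layers except the last act on every time step
    for i in range(1, n_dense):
        shapes.append(('Dense%d' % i, (batch, time, 1024)))

    # Only the final LSTM drops the time axis
    for i in range(1, n_lstm + 1):
        shape = (batch, 64) if i == n_lstm else (batch, time, 64)
        shapes.append(('LSTM%d' % i, shape))

    shapes.append(('Dense%d' % n_dense, (batch, 1024)))
    shapes.append(('softmax', (batch, n_classes)))
    return shapes


for stage, shape in trace_shapes((1, 10, 10, 11, 1), (3, 2, 2)):
    print('%-8s %s' % (stage, shape))
```

For the `[1, 10, 10, 11, 1]` input and `[3, 2, 2]` configuration, the final entry is `(1, 5)`: one row of five class probabilities per example.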

In [None]:
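# Cross-entropy loss function
def loss_fn(inference_fn, inputs, labels):
    # The model already ends in a softmax layer, so its output is treated as
    # probabilities here instead of being passed through softmax a second time
    probs = inference_fn(inputs)
    labels = tf.cast(labels, probs.dtype)  # one-hot labels
    return tf.reduce_mean(tf.keras.losses.categorical_crossentropy(labels, probs))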

# Calculate accuracy
def accuracy_fn(inference_fn, inputs, labels):
    prediction = inference_fn(inputs)
    correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(labels, 1))
    return tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Adam optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

# Compute gradients
grad = tfe.implicit_gradients(loss_fn)
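
Softmax cross-entropy is defined on raw logits, so a model whose last layer is already a softmax should be paired with a probability-based loss. The sketch below uses only the standard library and made-up numbers to show what happens when probabilities are fed where logits are expected; the helper names are hypothetical and exist only for this illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_from_logits(logits, one_hot):
    """-sum(y * log(softmax(logits))) for a single example."""
    probs = softmax(logits)
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs) if y > 0)


logits = [4.0, 0.0, 0.0, 0.0, 0.0]  # confident, correct prediction (made up)
label = [1, 0, 0, 0, 0]

loss_correct = cross_entropy_from_logits(logits, label)
# Passing probabilities where logits are expected applies softmax twice
loss_double = cross_entropy_from_logits(softmax(logits), label)

print('loss on logits:        %.4f' % loss_correct)
print('loss on probabilities: %.4f' % loss_double)
```

Because probabilities lie in [0, 1], the second softmax can never produce a confident distribution: with five classes the doubly normalized loss cannot fall below `log(1 + 4/e)`, roughly 0.905, however good the prediction is.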

In [None]:
# Training
average_loss = 0.
average_acc = 0.
for step in range(num_steps):

    # Iterate through the dataset
    try:
        d = dataset_iter.next()
    except StopIteration:
        # Refill queue
        dataset_iter = tfe.Iterator(dataset)
        d = dataset_iter.next()

    # EEGs
    x_batch = d[0]
    # Labels
    y_batch = tf.cast(d[1], dtype=tf.int64)

    # Compute the batch loss
    batch_loss = loss_fn(model, x_batch, y_batch)
    average_loss += batch_loss
    
    # Compute the batch accuracy
    batch_accuracy = accuracy_fn(model, x_batch, y_batch)
    average_acc += batch_accuracy

    if step == 0:
        # Display the initial cost, before optimizing
        print("Initial loss= {:.9f}".format(average_loss))

    # Update the variables following gradients info
    optimizer.apply_gradients(grad(model, x_batch, y_batch))

    # Display info
    if (step + 1) % display_step == 0 or step == 0:
        if step > 0:
            average_loss /= display_step
            average_acc /= display_step
        print("Step:", '%04d' % (step + 1), " loss=",
              "{:.9f}".format(average_loss), " accuracy=",
              "{:.4f}".format(average_acc))
        average_loss = 0.
        average_acc = 0.
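
The loop above combines two small patterns: refilling the dataset iterator when it runs out, and averaging a metric over blocks of `display_step` steps. A minimal sketch of both in plain Python follows, assuming a non-empty dataset; `cycle` and `windowed_means` are hypothetical helpers, and the three-item list stands in for a real dataset.

```python
def cycle(make_iterator):
    """Yield items forever, rebuilding the iterator whenever it is exhausted.

    Assumes make_iterator() always returns a non-empty iterator.
    """
    iterator = make_iterator()
    while True:
        try:
            item = next(iterator)
        except StopIteration:
            # Refill, as the training loop does with a fresh iterator
            iterator = make_iterator()
            item = next(iterator)
        yield item


def windowed_means(values, window):
    """Mean of each consecutive, complete block of `window` values."""
    means, running = [], 0.0
    for step, value in enumerate(values, start=1):
        running += value
        if step % window == 0:
            means.append(running / window)
            running = 0.0
    return means


batches = cycle(lambda: iter([1.0, 2.0, 3.0]))  # stand-in for a 3-batch dataset
losses = [next(batches) for _ in range(8)]
print(losses)
print(windowed_means(losses, window=4))
```

Dividing by the window size is only an exact mean when the block really contains that many values, which is why this sketch emits an average only for complete blocks.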

### 5. Evaluation

In [55]:
# load in libraries
import pickle
import itertools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

In [56]:
# make directories
if not os.path.exists('./py/metrics/'):
    os.makedirs('./py/metrics/')

In [57]:
def plot_history(history):
    loss_list = [s for s in history.keys() if 'loss' in s and 'val' not in s]
    val_loss_list = [s for s in history.keys() if 'loss' in s and 'val' in s]
    acc_list = [s for s in history.keys() if 'acc' in s and 'val' not in s]
    val_acc_list = [s for s in history.keys() if 'acc' in s and 'val' in s]
    
    if len(loss_list) == 0:
        print('Loss is missing in history')
        return 
    
    ## As loss always exists
    epochs = range(1,len(history[loss_list[0]]) + 1)
    
    ## Loss
    plt.figure(1)
    for l in loss_list:
        plt.plot(epochs, history[l], 'b', label='Training loss (' + format(history[l][-1], '.5f') + ')')
    for l in val_loss_list:
        plt.plot(epochs, history[l], 'g', label='Validation loss (' + format(history[l][-1], '.5f') + ')')
    
    plt.title('Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig("./py/metrics/loss.png")
    
    ## Accuracy
    plt.figure(2)
    for l in acc_list:
        plt.plot(epochs, history[l], 'b', label='Training accuracy (' + str(format(history[l][-1],'.5f'))+')')
    for l in val_acc_list:    
        plt.plot(epochs, history[l], 'g', label='Validation accuracy (' + str(format(history[l][-1],'.5f'))+')')

    plt.title('Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    # Save before showing: plt.show() can leave an empty figure behind
    plt.savefig("./py/metrics/acc.png")
    plt.show()
    
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        title='Normalized confusion matrix'
    else:
        title='Confusion matrix'

    plt.figure(3)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig("./py/metrics/confuMat.png")
    plt.show()
    
def full_multiclass_report(model,
                           x,
                           y_true,
                           classes):
    
    # 1. Predict classes and store them in y_pred
    y_pred = model.predict(x).argmax(axis=1)
    
    # 2. Print accuracy score
    print("Accuracy : "+ str(accuracy_score(y_true,y_pred)))
    
    print("")
    
    # 3. Print classification report
    print("Classification Report")
    print(classification_report(y_true,y_pred,digits=4))    
    
    # 4. Print and plot the confusion matrix
    cnf_matrix = confusion_matrix(y_true,y_pred)
    print(cnf_matrix)
    plot_confusion_matrix(cnf_matrix,classes=classes)    
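
As a reference for reading the report, here is a minimal sketch of how a confusion matrix is built, how the row normalization used in `plot_confusion_matrix` works, and how overall accuracy falls out of the diagonal. It uses plain lists instead of NumPy, and the labels are invented for illustration only; the helper names are hypothetical.

```python
def confusion_counts(y_true, y_pred, n_classes):
    """cm[i][j] = number of samples with true class i predicted as class j."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def normalize_rows(cm):
    """Divide each row by its total, so the diagonal holds per-class recall."""
    normalized = []
    for row in cm:
        total = sum(row)
        normalized.append([c / total if total else 0.0 for c in row])
    return normalized


def accuracy(cm):
    """Overall accuracy: correctly classified samples over all samples."""
    total = sum(sum(row) for row in cm)
    return sum(cm[i][i] for i in range(len(cm))) / total


y_true = [0, 0, 1, 1, 1, 2]  # invented labels, for illustration only
y_pred = [0, 1, 1, 1, 0, 2]
cm = confusion_counts(y_true, y_pred, n_classes=3)

for row in normalize_rows(cm):
    print(['%.2f' % v for v in row])
print('accuracy: %.4f' % accuracy(cm))
```

Rows correspond to true labels and columns to predictions, matching the axis labels set in `plot_confusion_matrix`.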

In [None]:
# Draw a random subset of the test set (indices are sampled with replacement)
howManyTest = 0.2  # fraction of the test set to draw

thisInd = np.random.randint(0, len(X_test), size=int(len(X_test) * howManyTest))
X_conf, y_conf = X_test[thisInd], y_test[thisInd]

'''
## Only if you have a previous model + history
# Get the model
model = models.load_model('./py/model/model_1230.h5')

# Get the history
with open('./history/history_1230.pkl', 'rb') as hist:
    history = pickle.load(hist)
'''

# Get the graphics
plot_history(history)
X_test = X_test.reshape(X_test.shape[0], X_train.shape[1], X_train.shape[2], X_train.shape[3], 1)
full_multiclass_report(model,
                       X_test,
                       y_test.argmax(axis=1),
                       [1,2,3,4,5])