# Fine Tuning of the convolutional network

In [1]:
%pylab
%matplotlib inline

import glob
import os
import mne
CORPORA_PATH = "~/corpora/sets"

file_path = os.path.expanduser(CORPORA_PATH)
files = glob.glob(os.path.join(file_path, "*.set"))

def normalize_subject(X):
    mean = X.mean(axis=(0, 2)).reshape(-1, 1)
    std = X.std(axis=(0, 2)).reshape(-1, 1)
    return (X - mean) / std

def load_data(filename, normalize=True):
    data_mne = mne.io.read_raw_eeglab(filename, preload=True, event_id={"0": 1, "1": 2})
    data_mne.filter(0, 20)
    events = mne.find_events(data_mne)
    epochs = mne.Epochs(
        data_mne, events,
        baseline=(None, 0), tmin=-0.1, tmax=0.7)

    epochs.load_data()
    
    ch_names = epochs.ch_names
    
    X = epochs.get_data()[:, :-1]
    y = (events[:, 2] == 2).astype('float')

    if len(events) != len(epochs):
        raise ValueError("Epochs events mismatch")
    if normalize: 
        X = normalize_subject(X)
    X = X[..., np.newaxis]
    
    return X, y 


Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
from keras.models import load_model
channels = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4', 'STI 014']

model = load_model("models/model.h5")

Using TensorFlow backend.
  return f(*args, **kwds)


In [3]:
X, y = load_data(files[143])

Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped


In [4]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test= train_test_split(X, y, test_size=0.1, stratify=y)

In [5]:


from sklearn.metrics import precision_score, recall_score, roc_auc_score, accuracy_score

def get_metrics(model, X_test, y_test):
    y_pred = model.predict_classes(X_test)
    y_prob = model.predict(X_test)

    precision = precision_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    auc = roc_auc_score(y_test, y_prob)
    accuracy = accuracy_score(y_test, y_pred)

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall, 
        "roc_auc": auc
    }
    
    
get_metrics(model, X_test, y_test)

{'accuracy': 0.5444444444444444,
 'precision': 0.24509803921568626,
 'recall': 0.83333333333333337,
 'roc_auc': 0.6651111111111111}

In [6]:
model.layers

[<keras.layers.convolutional.Conv2D at 0x7fb5fe916c18>,
 <keras.layers.convolutional.Conv2D at 0x7fb5fe916f28>,
 <keras.layers.core.Flatten at 0x7fb5fe916ef0>,
 <keras.layers.core.Dropout at 0x7fb5fe8e94a8>,
 <keras.layers.core.Dense at 0x7fb5fe8d2dd8>,
 <keras.layers.core.Dense at 0x7fb5fe8fffd0>]

Let's fix the first two convolutional layers

In [7]:

for i in range(4):
    model.layers[i].trainable = False



model.compile(loss='binary_crossentropy', # using the cross-entropy loss function
              optimizer='rmsprop', 
              metrics=['accuracy']) # reporting the accuracy
[(l, "Trainable: {}".format(l.trainable)) for l in model.layers]

[(<keras.layers.convolutional.Conv2D at 0x7fb5fe916c18>, 'Trainable: False'),
 (<keras.layers.convolutional.Conv2D at 0x7fb5fe916f28>, 'Trainable: False'),
 (<keras.layers.core.Flatten at 0x7fb5fe916ef0>, 'Trainable: False'),
 (<keras.layers.core.Dropout at 0x7fb5fe8e94a8>, 'Trainable: False'),
 (<keras.layers.core.Dense at 0x7fb5fe8d2dd8>, 'Trainable: True'),
 (<keras.layers.core.Dense at 0x7fb5fe8fffd0>, 'Trainable: True')]

In [8]:
from keras.callbacks import ModelCheckpoint, EarlyStopping

early_stopping = EarlyStopping(monitor='val_loss', patience=3)

model.fit(
    X_train, y_train, epochs=30, 
    batch_size=64, class_weight={0:1, 1:6}, validation_split=0.01,
    callbacks=[early_stopping]
)

Train on 1603 samples, validate on 17 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30


<keras.callbacks.History at 0x7fb5fafd9ac8>

In [9]:
get_metrics(model, X_test, y_test)

{'accuracy': 0.56666666666666665,
 'precision': 0.24468085106382978,
 'recall': 0.76666666666666672,
 'roc_auc': 0.68711111111111112}

In [10]:
from os.path import basename
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import backend as K

def fix_layers(model, fixed_layers):
    for i in range(fixed_layers):
        model.layers[i].trainable = False
    
    model.compile(loss='binary_crossentropy',
              optimizer='rmsprop', 
              metrics=['accuracy'])
    
def fine_tune(fixed_layers):
    model = load_model("models/model.h5")
    
    fix_layers(model, fixed_layers)
    
    early_stopping = EarlyStopping(monitor='val_loss', patience=3)

    model.fit(
        X_train, y_train, epochs=10, 
        batch_size=64, class_weight={0:1, 1:6}, validation_split=0.01,
        callbacks=[early_stopping]
    )
    
    return model

def get_analysis(filename, fixed_layers=4):
    K.clear_session()
    model = load_model("models/model.h5")
    X, y = load_data(files[143])
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, stratify=y)
    
    ret = {"file" : basename(filename)}
    metrics = {"ft_0_{}".format(k):v for k,v in get_metrics(model, X_test, y_test).items()}
    ret.update(metrics)
    
    model = fine_tune(4)
    metrics = {"ft_4_{}".format(k):v for k,v in get_metrics(model, X_test, y_test).items()}
    ret.update(metrics)
    
    model = fine_tune(5)
    metrics = {"ft_5_{}".format(k):v for k,v in get_metrics(model, X_test, y_test).items()}
    ret.update(metrics)
    
    K.clear_session()
    return ret

get_analysis(files[100], 4)

Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10


{'file': 'PruebasMuseo_28970001.set',
 'ft_0_accuracy': 0.58333333333333337,
 'ft_0_precision': 0.24719101123595505,
 'ft_0_recall': 0.73333333333333328,
 'ft_0_roc_auc': 0.74199999999999999,
 'ft_4_accuracy': 0.72222222222222221,
 'ft_4_precision': 0.35294117647058826,
 'ft_4_recall': 0.80000000000000004,
 'ft_4_roc_auc': 0.81355555555555559,
 'ft_5_accuracy': 0.63888888888888884,
 'ft_5_precision': 0.27272727272727271,
 'ft_5_recall': 0.69999999999999996,
 'ft_5_roc_auc': 0.7513333333333333}

In [11]:
analysis = [get_analysis(file) for file in files[-10:]]
    

Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dr

In [14]:
import pandas as pd

df = pd.DataFrame(analysis)

df

Unnamed: 0,file,ft_0_accuracy,ft_0_precision,ft_0_recall,ft_0_roc_auc,ft_4_accuracy,ft_4_precision,ft_4_recall,ft_4_roc_auc,ft_5_accuracy,ft_5_precision,ft_5_recall,ft_5_roc_auc
0,PruebasMuseo_21668001.set,0.605556,0.27957,0.866667,0.744222,0.688889,0.324324,0.8,0.825333,0.611111,0.272727,0.8,0.744444
1,PruebasMuseo_27157001.set,0.6,0.243902,0.666667,0.676,0.705556,0.290909,0.533333,0.765111,0.638889,0.246377,0.566667,0.689333
2,PruebasMuseo_13235001.set,0.55,0.225806,0.7,0.647778,0.788889,0.416667,0.666667,0.814,0.577778,0.22619,0.633333,0.658222
3,PruebasMuseo_31056001.set,0.561111,0.230769,0.7,0.662222,0.705556,0.31746,0.666667,0.739556,0.611111,0.25,0.666667,0.649111
4,PruebasMuseo_9809001.set,0.577778,0.232558,0.666667,0.707778,0.794444,0.431373,0.733333,0.810667,0.622222,0.25,0.633333,0.726667
5,PruebasMuseo_25302001.set,0.55,0.225806,0.7,0.663778,0.677778,0.28125,0.6,0.780667,0.6,0.2375,0.633333,0.688222
6,PruebasMuseo_18046001.set,0.588889,0.244186,0.7,0.624222,0.733333,0.354839,0.733333,0.734,0.622222,0.25641,0.666667,0.621333
7,PruebasMuseo_16943001.set,0.583333,0.252747,0.766667,0.713333,0.677778,0.310811,0.766667,0.737111,0.616667,0.259259,0.7,0.718444
8,PruebasMuseo_7488001.set,0.616667,0.26506,0.733333,0.688333,0.711111,0.328125,0.7,0.762667,0.638889,0.266667,0.666667,0.684556
9,PruebasMuseo_12168001.set,0.566667,0.23913,0.733333,0.687,0.638889,0.289157,0.8,0.776222,0.616667,0.26506,0.733333,0.693889


In [15]:
df.to_csv("fine_tuning_comparison.csv")

In [16]:
% cat fine_tuning_comparison.csv

,file,ft_0_accuracy,ft_0_precision,ft_0_recall,ft_0_roc_auc,ft_4_accuracy,ft_4_precision,ft_4_recall,ft_4_roc_auc,ft_5_accuracy,ft_5_precision,ft_5_recall,ft_5_roc_auc
0,PruebasMuseo_21668001.set,0.6055555555555555,0.27956989247311825,0.8666666666666667,0.7442222222222222,0.6888888888888889,0.32432432432432434,0.8,0.8253333333333334,0.6111111111111112,0.2727272727272727,0.8,0.7444444444444444
1,PruebasMuseo_27157001.set,0.6,0.24390243902439024,0.6666666666666666,0.676,0.7055555555555556,0.2909090909090909,0.5333333333333333,0.7651111111111111,0.6388888888888888,0.2463768115942029,0.5666666666666667,0.6893333333333332
2,PruebasMuseo_13235001.set,0.55,0.22580645161290322,0.7,0.6477777777777778,0.7888888888888889,0.4166666666666667,0.6666666666666666,0.8140000000000001,0.5777777777777777,0.2261904761904762,0.6333333333333333,0.6582222222222223
3,PruebasMuseo_31056001.set,0.5611111111111111,0.23076923076923078,0.7,0.6622222222222222,0.7055555555555556,0.31746031746031744,0.6666666666666666