# Fine Tuning of the convolutional network

In [1]:
%pylab
%matplotlib inline

import glob
import os
import mne
CORPORA_PATH = "~/corpora/sets"

file_path = os.path.expanduser(CORPORA_PATH)
files = glob.glob(os.path.join(file_path, "*.set"))

def normalize_subject(X):
    mean = X.mean(axis=(0, 2)).reshape(-1, 1)
    std = X.std(axis=(0, 2)).reshape(-1, 1)
    return (X - mean) / std

def load_data(filename, normalize=True):
    data_mne = mne.io.read_raw_eeglab(filename, preload=True, event_id={"0": 1, "1": 2})
    data_mne.filter(0, 20)
    events = mne.find_events(data_mne)
    epochs = mne.Epochs(
        data_mne, events,
        baseline=(None, 0), tmin=-0.1, tmax=0.7)

    epochs.load_data()
    
    ch_names = epochs.ch_names
    
    X = epochs.get_data()[:, :-1]
    y = (events[:, 2] == 2).astype('float')

    if len(events) != len(epochs):
        raise ValueError("Epochs events mismatch")
    if normalize: 
        X = normalize_subject(X)
    X = X[..., np.newaxis]
    
    return X, y 


Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
from keras.models import load_model
channels = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4', 'STI 014']

model = load_model("models/model.h5")

Using TensorFlow backend.
  return f(*args, **kwds)


In [3]:
X, y = load_data(files[143])

Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped


In [4]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test= train_test_split(X, y, test_size=0.1, stratify=y)

In [5]:


from sklearn.metrics import precision_score, recall_score, roc_auc_score, accuracy_score

def get_metrics(model, X_test, y_test):
    y_pred = model.predict_classes(X_test)
    y_prob = model.predict(X_test)

    precision = precision_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    auc = roc_auc_score(y_test, y_prob)
    accuracy = accuracy_score(y_test, y_pred)

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall, 
        "roc_auc": auc
    }
    
    
get_metrics(model, X_test, y_test)

{'accuracy': 0.53888888888888886,
 'precision': 0.22680412371134021,
 'recall': 0.73333333333333328,
 'roc_auc': 0.69155555555555548}

In [6]:
model.layers

[<keras.layers.convolutional.Conv2D at 0x7f3963894c18>,
 <keras.layers.convolutional.Conv2D at 0x7f3963894f28>,
 <keras.layers.core.Flatten at 0x7f3963894ef0>,
 <keras.layers.core.Dropout at 0x7f396386a4a8>,
 <keras.layers.core.Dense at 0x7f39638d1dd8>,
 <keras.layers.core.Dense at 0x7f3963880fd0>]

Let's fix the first two convolutional layers

In [7]:

for i in range(4):
    model.layers[i].trainable = False



model.compile(loss='binary_crossentropy', # using the cross-entropy loss function
              optimizer='rmsprop', 
              metrics=['accuracy']) # reporting the accuracy
[(l, "Trainable: {}".format(l.trainable)) for l in model.layers]

[(<keras.layers.convolutional.Conv2D at 0x7f3963894c18>, 'Trainable: False'),
 (<keras.layers.convolutional.Conv2D at 0x7f3963894f28>, 'Trainable: False'),
 (<keras.layers.core.Flatten at 0x7f3963894ef0>, 'Trainable: False'),
 (<keras.layers.core.Dropout at 0x7f396386a4a8>, 'Trainable: False'),
 (<keras.layers.core.Dense at 0x7f39638d1dd8>, 'Trainable: True'),
 (<keras.layers.core.Dense at 0x7f3963880fd0>, 'Trainable: True')]

In [8]:
from keras.callbacks import ModelCheckpoint, EarlyStopping

early_stopping = EarlyStopping(monitor='val_loss', patience=3)

model.fit(
    X_train, y_train, epochs=30, 
    batch_size=64, class_weight={0:1, 1:6}, validation_split=0.01,
    callbacks=[early_stopping]
)

Train on 1603 samples, validate on 17 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30


<keras.callbacks.History at 0x7f395ff58ac8>

In [9]:
get_metrics(model, X_test, y_test)

{'accuracy': 0.69999999999999996,
 'precision': 0.29310344827586204,
 'recall': 0.56666666666666665,
 'roc_auc': 0.72088888888888891}

In [10]:
from os.path import basename
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import backend as K

def fix_layers(model, fixed_layers):
    for i in range(fixed_layers):
        model.layers[i].trainable = False
    
    model.compile(loss='binary_crossentropy',
              optimizer='rmsprop', 
              metrics=['accuracy'])
    
def fine_tune(fixed_layers):
    model = load_model("models/model.h5")
    
    fix_layers(model, fixed_layers)
    
    early_stopping = EarlyStopping(monitor='val_loss', patience=3)

    model.fit(
        X_train, y_train, epochs=10, 
        batch_size=64, class_weight={0:1, 1:6}, validation_split=0.01,
        callbacks=[early_stopping]
    )
    
    return model

def get_analysis(filename, fixed_layers=4):
    K.clear_session()
    model = load_model("models/model.h5")
    X, y = load_data(files[143])
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, stratify=y)
    
    ret = {"file" : basename(filename)}
    metrics = {"ft_0_{}".format(k):v for k,v in get_metrics(model, X_test, y_test).items()}
    ret.update(metrics)
    
    model = fine_tune(4)
    metrics = {"ft_4_{}".format(k):v for k,v in get_metrics(model, X_test, y_test).items()}
    ret.update(metrics)
    
    model = fine_tune(5)
    metrics = {"ft_5_{}".format(k):v for k,v in get_metrics(model, X_test, y_test).items()}
    ret.update(metrics)
    
    K.clear_session()
    return ret

get_analysis(files[100], 4)

Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10


{'file': 'PruebasMuseo_28970001.set',
 'ft_0_accuracy': 0.5444444444444444,
 'ft_0_precision': 0.23469387755102042,
 'ft_0_recall': 0.76666666666666672,
 'ft_0_roc_auc': 0.67955555555555558,
 'ft_4_accuracy': 0.71111111111111114,
 'ft_4_precision': 0.35135135135135137,
 'ft_4_recall': 0.8666666666666667,
 'ft_4_roc_auc': 0.8626666666666668,
 'ft_5_accuracy': 0.57222222222222219,
 'ft_5_precision': 0.21686746987951808,
 'ft_5_recall': 0.59999999999999998,
 'ft_5_roc_auc': 0.68600000000000005}

In [11]:
analysis = [get_analysis(file) for file in files[-10:]]
    

Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated

Epoch 8/10
Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Event

Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 events and 104 original time points ...
0 bad epochs dropped
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Reading /home/ubuntu/corpora/sets/PruebasMuseo_7488001.fdt
Reading 0 ... 63231  =      0.000 ...   493.992 secs...
Setting up low-pass filter at 20 Hz
h_trans_bandwidth chosen to be 5.0 Hz
Filter length of 169 samples (1.320 sec) selected
1800 events found
Events id: [1 2]
1800 matching events found
0 projection items activated
Loading data for 1800 

Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Train on 1603 samples, validate on 17 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [12]:
import pandas as pd

df = pd.DataFrame(df)

Unnamed: 0,file,ft_0_accuracy,ft_0_precision,ft_0_recall,ft_0_roc_auc,ft_4_accuracy,ft_4_precision,ft_4_recall,ft_4_roc_auc,ft_5_accuracy,ft_5_precision,ft_5_recall,ft_5_roc_auc
0,PruebasMuseo_21668001.set,0.622222,0.284091,0.833333,0.726889,0.75,0.377049,0.766667,0.863333,0.627778,0.277108,0.766667,0.732
1,PruebasMuseo_27157001.set,0.588889,0.244186,0.7,0.716667,0.8,0.446429,0.833333,0.889111,0.611111,0.261905,0.733333,0.728
2,PruebasMuseo_13235001.set,0.55,0.213483,0.633333,0.649333,0.733333,0.344828,0.666667,0.824444,0.561111,0.211765,0.6,0.650222
3,PruebasMuseo_31056001.set,0.594444,0.258427,0.766667,0.703111,0.777778,0.395833,0.633333,0.827111,0.616667,0.259259,0.7,0.688444
4,PruebasMuseo_9809001.set,0.605556,0.253012,0.7,0.681333,0.85,0.533333,0.8,0.906,0.611111,0.256098,0.7,0.692222
5,PruebasMuseo_25302001.set,0.566667,0.233333,0.7,0.682667,0.766667,0.388889,0.7,0.863556,0.572222,0.235955,0.7,0.68
6,PruebasMuseo_18046001.set,0.511111,0.21,0.7,0.644444,0.577778,0.265306,0.866667,0.802,0.561111,0.230769,0.7,0.650667
7,PruebasMuseo_16943001.set,0.6,0.255814,0.733333,0.693333,0.755556,0.383333,0.766667,0.848,0.611111,0.261905,0.733333,0.680222
8,PruebasMuseo_7488001.set,0.538889,0.226804,0.733333,0.679111,0.75,0.38806,0.866667,0.893778,0.555556,0.228261,0.7,0.681333
9,PruebasMuseo_12168001.set,0.611111,0.267442,0.766667,0.744222,0.822222,0.477273,0.7,0.853556,0.611111,0.267442,0.766667,0.754


In [14]:
df.to_csv()

AttributeError: 'list' object has no attribute 'to_csv'