In [1]:
from keras.models import Model
from keras.layers import Dense, Input, LSTM, RepeatVector, GRU
import numpy as np
import mne
import pickle
import os
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import train_test_split
import h5py
from random import shuffle

Using TensorFlow backend.


### Load data

In [2]:
def read_h5_file(file_name, scaler = None, preprocess = False):
    h5_file = h5py.File(train_eeg_dir + file_name, 'r')
    a_group_key = list(h5_file.keys())[0]
    eeg_data = np.array(h5_file[a_group_key]).T
    if preprocess:
        eeg_data = scaler.transform(eeg_data)
    return eeg_data

In [3]:
def train_scaler(scaler, train_eeg_names, log = False):
    i = 0
    for eeg_name in train_eeg_names:
        if log:
            print("{} from {}".format(i, len(train_eeg_names)))
            print("reading:{}".format(eeg_name))
        data = read_h5_file(eeg_name)
        i = i+1
        scaler.fit(data)
        if log:
            print("trained on {}".format(eeg_name))

In [4]:
def save_scaler(path,scaler):
    pickle.dump(scaler, open(path, 'wb'))

In [5]:
def load_scaler(path):
    scaler = pickle.load(open(path, 'rb'))
    return scaler

In [6]:
train_eeg_dir = "./data/train/"
trained_scaler_path = "StandardScaler.p"

Files different from the other, were deleted in some experiments

In [7]:
bad_files = ["zavrib_post_eeg_eyesopen15021500_processed.h5","shuhova_08022017_rest_eeg_processed.h5","zavrin_15021500_eyesclosed_post_eeg_processed.h5"]

In [8]:
train_eeg_dir = "./data/train/"
all_train_eeg_names = [x for x in os.listdir(train_eeg_dir) 
                 if x[-3:] == ".h5" and x not in bad_files]
eeg_num = len(all_train_eeg_names)
print("Number of EEG overall:", eeg_num)

Number of EEG overall: 29


In [9]:
if trained_scaler_path:
    scaler = load_scaler(trained_scaler_path)
else:
    scaler = StandardScaler()
    print("Params before training ", scaler.get_params())
    train_scaler(scaler, all_train_eeg_names, log = True)
    print("Params after training ", scaler.get_params())
    save_scaler("StandardScaler.p", scaler)

### Train-test files split

In [10]:
overall_epoch_num = 10
file_epoch_num = 2
batch_size = 20
hist_path = "train_hist.txt"

test_eeg_name = all_train_eeg_names[5]
train_eeg_names = all_train_eeg_names[:5] + all_train_eeg_names[6:]
print("test_eeg_name is ", test_eeg_name)
test_data = read_h5_file(test_eeg_name, scaler, True)

test_eeg_name is  2205_miloslavov_post_eeg_processed.h5


In [11]:
channels_num = 58

### GRU

In [12]:
def create_gru_ae(encoding_dim = 58):

    encoder_inputs = Input(shape=(None, channels_num))
    encoder = GRU(encoding_dim, return_state=True)
    encoder_outputs, encoder_state = encoder(encoder_inputs)

    encoder = Model(encoder_inputs, encoder_state)
    
    print("Encoder summary: ")
    encoder.summary()

    decoder_inputs = Input(shape=(None, channels_num))
    
    decoder_gru = GRU(encoding_dim, return_sequences=True)
    decoder_outputs = decoder_gru(decoder_inputs, initial_state=encoder_state)
    
    autoencoder = Model([encoder_inputs, decoder_inputs], decoder_outputs)
    autoencoder.compile(optimizer='adam', loss="mse")
    
    print("Autoencoder summary: ")
    autoencoder.summary()
    
    return encoder, autoencoder

In [13]:
dim = 58
timestep = 1

In [14]:
encoder, autoencoder = create_gru_ae()

Encoder summary: 
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, None, 58)          0         
_________________________________________________________________
gru_1 (GRU)                  [(None, 58), (None, 58)]  20358     
Total params: 20,358
Trainable params: 20,358
Non-trainable params: 0
_________________________________________________________________
Autoencoder summary: 
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None, 58)     0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None, 58)     0                                            

In [16]:
hist_path = "train_hist_gru_{}_{}.txt".format(dim, timestep)
hist = [[],[]]
epoch_num = 10

test_decoder_states = (-1)*np.zeros_like(test_data).reshape(-1,timestep,channels_num)

for epoch in range(epoch_num):
    shuffle(train_eeg_names)
    for name in train_eeg_names:
        train_data = read_h5_file(name, scaler, True)
        train_data = train_data[-len(train_data) % (timestep*channels_num):]
        train_data = train_data.reshape(-1,timestep,channels_num)
        np.random.shuffle(train_data)
        print("epoch: {}, file: {}".format(epoch, name))
        decoder_inputs = (-1)*np.zeros_like(train_data).reshape(-1,timestep,channels_num)
        history = autoencoder.fit([train_data, decoder_inputs], 
                                  train_data, 
                                  verbose=1, 
                                  epochs=1,
                                  batch_size = 20,
                                  validation_data=([test_data.reshape(-1,timestep,channels_num), test_decoder_states], test_data.reshape(-1,timestep,channels_num)))
        encoder.save('GRU_encoder_name.p')
        autoencoder.save('GRU_autoencoder_name.p')
        hist[0].append(history.history["loss"])
        hist[1].append(history.history["val_loss"])
        with open(hist_path, 'wb') as fp:
            pickle.dump(hist, fp)

epoch: 0, file: 2403_kutuzova_posteeg_processed.h5
Train on 604156 samples, validate on 623250 samples
Epoch 1/1
  8220/604156 [..............................] - ETA: 5:15 - loss: 0.2556   

KeyboardInterrupt: 

### LSTM

In [17]:
channels_num = 58

In [18]:
def create_ae(encoding_dim = 58, timesteps = 1):  
    input_data = Input(shape=(timesteps,channels_num))

    encoder_lstm = LSTM(encoding_dim, return_state=True)
    
    encoder_outputs, state_h, state_c = encoder_lstm(input_data)
    encoder_states = [state_h, state_c]
    
    encoder = Model(input_data, encoder_states)
    print("Encoder summary: ")
    encoder.summary()
    
    decoded = RepeatVector(timesteps)(encoder_outputs)
    
    decoder_lstm = LSTM(channels_num)
    
    decoder_outputs = decoder_lstm(decoded, initial_state=encoder_states)
    
    autoencoder = Model(input_data, decoder_outputs)
    autoencoder.compile(optimizer='adam', loss="mse")
    
    print("Autoencoder summary: ")
    autoencoder.summary()
    
    return encoder, autoencoder

In [19]:
enc_dim = 58
timestep = 1
hist_path = "train_hist.txt"
hist = [[],[]]

In [20]:
encoder, autoencoder = create_ae(enc_dim)

Encoder summary: 
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         (None, 1, 58)             0         
_________________________________________________________________
lstm_1 (LSTM)                [(None, 58), (None, 58),  27144     
Total params: 27,144
Trainable params: 27,144
Non-trainable params: 0
_________________________________________________________________
Autoencoder summary: 
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 1, 58)        0                                            
__________________________________________________________________________________________________
lstm_1 (LSTM)                   [(None, 58), (None,  27144       input_3[0][0]                    

In [21]:
for epoch in range(overall_epoch_num):
    for name in train_eeg_names:
        train_data = read_h5_file(name, scaler, True)
        #if len(train_data) > epoch*learn_file_length:
        #    train_data = train_data[:learn_file_length]
        print("epoch: {}, file: {}".format(epoch, name))
        history = autoencoder.fit(train_data.reshape(-1,1,58), train_data, 
                        verbose=1, 
                        epochs=file_epoch_num,
                        #batch_size = 2**(overall_epoch_num - epoch),
                        batch_size = 20,
                        validation_data=(test_data.reshape(-1,1,58), test_data))
        encoder.save('RNN_encoder.p')
        autoencoder.save('RNN_autoencoder.p')
        hist[0].append(history.history["loss"])
        hist[1].append(history.history["val_loss"])
        with open(hist_path, 'wb') as fp:
            pickle.dump(hist, fp)

epoch: 0, file: 2403_kutuzova_posteeg_processed.h5
Train on 604200 samples, validate on 623250 samples
Epoch 1/2
  3040/604200 [..............................] - ETA: 10:28 - loss: 0.3611   

KeyboardInterrupt: 