In [1]:
import tensorflow as tf

import numpy as np

import pandas as pd

import mne

import os

import keras

import seaborn as sns

from collections import Counter

from tqdm import tqdm_notebook
from tqdm import tqdm_notebook as tqdm

import matplotlib.pyplot as plt

Using TensorFlow backend.


In [2]:
import tensorflow as tf
from keras.layers import Dense,Activation,Dropout, Input, Conv1D, MaxPool1D, Flatten
from keras.layers import LSTM,Bidirectional,TimeDistributed #could try TimeDistributed(Dense(...))
from keras.models import Sequential, load_model
from keras import optimizers,regularizers
from keras.layers.normalization import BatchNormalization
import keras.backend.tensorflow_backend as KTF

In [3]:
from scipy.signal import stft

# Prepare constants

In [6]:
CHUNK_FREQUENCY = 250
CHUNK_TIME = CHUNK_FREQUENCY
FFT_STEPS_NUM = 125
BATCH_SIZE = 128

In [7]:
CHANNELS = ['EEG P3-REF', 'EEG FP2-REF', 'EEG T5-REF', 'EEG O1-REF', 'EEG T4-REF', 'EEG FP1-REF', 'EEG F7-REF', 'EEG F3-REF', 'EEG CZ-REF', 'EEG T6-REF', 'EEG F4-REF', 'EEG PZ-REF', 'EEG A1-REF', 'EEG A2-REF', 'EEG F8-REF', 'EEG P4-REF', 'EEG C4-REF', 'EEG FZ-REF', 'EEG C3-REF', 'EEG T3-REF', 'EEG O2-REF']

In [8]:
MAX_LENGTH = 66000

# Prepare dataset

In [9]:
train_df = pd.read_csv("../processed-data/train.csv")
val_df = pd.read_csv("../processed-data/val.csv")

# Generator functions + speed tests

In [10]:
def iterate_files(df, batch_size=BATCH_SIZE):
    files = df[["full_path", "length_chunks"]].drop_duplicates().values
    file_chunks = []
    for file, chunks in files:
        for chunk in range(chunks + 1):
            file_chunks.append((file, chunk))
    files_number_to_add = (batch_size - len(file_chunks) % batch_size)
    files_to_add = file_chunks[0:files_number_to_add]
    file_chunks = file_chunks + files_to_add
    
    file_chunks_index = list(range(len(file_chunks)))
    file_chunks_index = np.random.choice(file_chunks_index, len(file_chunks_index), replace=False)
    file_chunks = [file_chunks[i] for i in file_chunks_index]
    for files in zip(*[iter(file_chunks)]*batch_size):
        yield files

In [11]:
files = next(iterate_files(train_df))

In [12]:
def get_data(df, file, chunk, channels=CHANNELS, chunk_size=CHUNK_TIME, step_size=CHUNK_TIME // FFT_STEPS_NUM):
    annotations = df[
        (df["full_path"] == file) & \
        (df["label"] == "seiz")
    ][["start", "end"]]
    edf = mne.io.read_raw_edf(file, preload=True, verbose='ERROR')
    edf.filter(2, 60)
    edf_picks = edf.pick_channels(channels)
    data, time = edf_picks[:, chunk * MAX_LENGTH:(chunk + 1) * MAX_LENGTH]
    
    events = time * 0
    for _, (start, end) in annotations.iterrows():
        events += (time >= start) & (time <= end)
    events = (events > 0).astype(int)
    
    del edf
    
    return data, events

In [13]:
def get_data_multiple(df, files, channels=CHANNELS, chunk_size=CHUNK_TIME):
    total_data = []
    total_events = []
    for file, chunk in tqdm_notebook(files):
        data, events = get_data(df, file, chunk)
        total_data.append(data)
        total_events.append(events)
    
    return total_data, total_events

In [14]:
data, events = get_data(train_df, files[0][0], files[0][1])

In [15]:
data, events = get_data_multiple(train_df, files)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




In [16]:
def get_fourier_transform(data, window_size=CHUNK_TIME, step_size=CHUNK_TIME // FFT_STEPS_NUM):
    data = data.T
    return np.log1p(np.abs(stft(
        data, 
        fs=window_size, 
        window="hann", 
        nperseg=window_size, 
        noverlap=window_size - step_size, 
        return_onesided=True, 
        boundary=None,
    )[-1]).T)

In [17]:
class ChunksIterator():
    def __init__(self, data, events, chunk_size=CHUNK_TIME, step_size=CHUNK_TIME // FFT_STEPS_NUM, max_length=MAX_LENGTH, tqdm_enabled=False):
        self.data = data
        self.events = events
        self.chunk_size = chunk_size
        self.step_size = step_size
        self.valid_chunks = 0
        
        max_time = min(max([e.shape[0] for e in events]), max_length)
        self.iterations = range(0, max_time - 2*chunk_size, chunk_size)
        self.iterations = np.random.choice(self.iterations, len(self.iterations), replace=False)
        self.iterations_number = len(self.iterations)
        if tqdm_enabled:
            self.iterations = tqdm(self.iterations)
        self.iterations = iter(self.iterations)
    
    def __iter__(self):
        return self
    
    def __next__(self):
        data = self.data
        events = self.events
        chunk_size = self.chunk_size
        step_size = self.step_size
        chunk_start = next(self.iterations)
        
        data_chunk = [d[:, chunk_start:chunk_start + 2*chunk_size].T for d in data]
        labels_chunk = [e[chunk_start:chunk_start + chunk_size] for e in events]

        zero_data_chunk = np.zeros((chunk_size * 2, len(CHANNELS)))
        zero_labels_chunk = np.zeros((chunk_size, ))

        masks_chunk = [e.shape[0] == chunk_size for e in labels_chunk]
        data_chunk = [d if d.shape[0] == 2*chunk_size else zero_data_chunk for d in data_chunk]
        labels_chunk = [e if e.shape[0] == chunk_size else zero_labels_chunk for e in labels_chunk]

#         data_chunk = [get_fourier_transform(d) for d in data_chunk]
        labels_chunk = [e.max() for e in labels_chunk]
        
        self.valid_chunks += sum(masks_chunk)

        return np.stack(data_chunk).swapaxes(1, 2), np.stack(labels_chunk)[:, np.newaxis], np.array(masks_chunk)

In [18]:
iterator = ChunksIterator(data, events, tqdm_enabled=True)

print("Total iterations:", iterator.iterations_number)

for i, _ in zip(range(20), iterator):
    print(i)
    
print("Valid chunks:", iterator.valid_chunks)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Total iterations: 262
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
Valid chunks: 2229


# TF stft investigation

In [45]:
# stft_input = tf.placeholder(tf.float32)

# stft_tensor = tf.signal.stft(
#     stft_input, 
#     frame_length=CHUNK_TIME,
#     frame_step=CHUNK_TIME // FFT_STEPS_NUM,
#     fft_length=CHUNK_TIME
# )

# s = tf.Session()

# iterator = ChunksIterator(data, events, tqdm_enabled=True)

# data_chunk, labels_chunk, _ = next(iterator)

# data_chunk.shape

# s.run(stft_tensor, {stft_input: data_chunk}).shape

# network_input = keras.layers.Input(batch_shape=data_chunk.shape)

# fft_layer = keras.layers.Lambda(lambda x: tf.log1p(tf.abs(tf.signal.stft(
#     x, 
#     frame_length=CHUNK_TIME,
#     frame_step=CHUNK_TIME // FFT_STEPS_NUM,
#     fft_length=CHUNK_TIME
# ))))(network_input)

# flatten = keras.layers.Flatten()(fft_layer)
# dense_layer = keras.layers.Dense(1, activation='sigmoid')(flatten)

# model = keras.models.Model(inputs=[network_input], outputs=[dense_layer])

# data_chunk.shape

# model.predict(data_chunk).shape

# model.predict(data_chunk).shape

# model.compile(
#     loss="binary_crossentropy",
#     optimizer='adam'
# )

# model.fit([data_chunk], [labels_chunk], epochs=10)

# data_chunk

# Main model architecture

In [46]:
from keras import backend as K

def recall_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

In [47]:
eeg_input = Input(shape=(len(CHANNELS), CHUNK_TIME * 2, ))

fft_layer = keras.layers.Lambda(lambda x: tf.log1p(tf.abs(tf.signal.stft(
    x, 
    frame_length=CHUNK_TIME,
    frame_step=CHUNK_TIME // FFT_STEPS_NUM,
    fft_length=CHUNK_TIME
))))(eeg_input)

x = Flatten()(fft_layer)
x = Dropout(0.2)(x)

for _ in range(3):
    x = Dense(128, kernel_regularizer=regularizers.l2(0.01))(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    
events_output = Dense(1, activation='sigmoid')(x)

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


In [48]:
model = keras.models.Model(inputs=[eeg_input], outputs=[events_output])

In [49]:
model.compile(
    loss="binary_crossentropy",
    optimizer='adam', 
    metrics=['accuracy', recall_m, precision_m]
)

Some tests on small chunks

In [50]:
generator = ChunksIterator(data, events)
for i in range(5):
    next(generator)
x_in, y_end, masks_in = next(generator)

In [51]:
model.predict(x_in).mean()

0.49998707

In [52]:
model.evaluate(x_in, y_end)



[5.8239967823028564, 0.7265625, 0.17500000447034836, 0.2083333358168602]

In [53]:
# model.fit(x_in, y_end, batch_size=x_in.shape[0], epochs=1000, sample_weight=masks_in)

In [54]:
y_end.sum()

18.0

# Training procedure

In [55]:
GENERATOR_PAD_BUG = 1

In [56]:
total_train_history = {}
total_val_history = {}

In [None]:
for epoch in tqdm(range(10)):
    total_train_history[epoch] = []
    total_val_history[epoch] = []
    
    total_train_loss = 0
    total_train_chunks = 0
    
    total_val_loss = 0
    total_val_chunks = 0
    
    
    for train_files in tqdm_notebook(list(iterate_files(train_df))):
        train_data, train_events = get_data_multiple(train_df, train_files)
        train_generator = ChunksIterator(train_data, train_events, tqdm_enabled=True)
        
        train_chunks = train_generator.iterations_number
        
        train_history = model.fit_generator(
            train_generator, 
            epochs=1, 
            steps_per_epoch=train_chunks - GENERATOR_PAD_BUG,
#             class_weight=[1, 10]
        )
        
        total_train_loss += train_history.history['loss'][0] * train_generator.valid_chunks
        total_train_chunks += train_generator.valid_chunks
        
        model.reset_states()
        
        del train_generator
        del train_data
        del train_events
        
        total_train_history[epoch].append(total_train_loss / total_train_chunks)
        print("Train loss:", total_train_history)
        
        model.save_weights("./models/big-fc-fft-model-{}.h5".format(epoch))
        
    for val_files in tqdm_notebook(list(iterate_files(val_df))):
        val_data, val_events = get_data_multiple(val_df, val_files)
        val_generator = ChunksIterator(val_data, val_events, tqdm_enabled=True)
        
        val_chunks = val_generator.iterations_number
        
        val_metrics = model.evaluate_generator(
            val_generator,
            steps=val_chunks - GENERATOR_PAD_BUG
        )
        
        print(val_metrics)
        
        total_val_loss += val_metrics[0] * val_generator.valid_chunks
        total_val_chunks += val_generator.valid_chunks
        
        model.reset_states()
        del val_generator
        del val_data
        del val_events
        
        total_val_history[epoch].append(total_val_loss / total_val_chunks)
        print("Val loss:", total_val_history)
    

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """Entry point for launching an IPython kernel.


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if sys.path[0] == '':


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Instructions for updating:
Use tf.cast instead.
Epoch 1/1
Train loss: {0: [1.6803342033163342]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)






HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026]}



Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.8281069160421233, 0.7990301724137931, 0.1055259703519358, 0.5412440012812157]
Val loss: {0: [0.8281069160421233]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.9925952501223918, 0.7334770114942529, 0.15430688369890738, 0.4918995884700296]
Val loss: {0: [0.8281069160421233, 0.9105895141127434]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.8581205030967449, 0.7743354885057471, 0.07335673726495655, 0.25708096151836074]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.8102228426841941, 0.7866977969348659, 0.17757471779297138, 0.38904802184337856]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872, 0.8722748382238076]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.8567848013735365, 0.7948694923371648, 0.14927875862895756, 0.5273755083709841]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872, 0.8722748382238076, 0.8693136283145201]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.8531065567243145, 0.7894516283524904, 0.11549870695533423, 0.4027369458842095]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872, 0.8722748382238076, 0.8693136283145201, 0.8666636285311622]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.7775972889757704, 0.8061242816091954, 0.19811860030923767, 0.4036065364306457]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872, 0.8722748382238076, 0.8693136283145201, 0.8666636285311622, 0.8542038544133839]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.8703401547738876, 0.7778675766283525, 0.1389885532930208, 0.39979826616144726]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872, 0.8722748382238076, 0.8693136283145201, 0.8666636285311622, 0.8542038544133839, 0.8562569421145861]}



HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305, 0.6116189039656329]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305, 0.6116189039656329, 0.604680796652444]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305, 0.6116189039656329, 0.604680796652444, 0.5928983457905049]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305, 0.6116189039656329, 0.604680796652444, 0.5928983457905049, 0.5897887729925787]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305, 0.6116189039656329, 0.604680796652444, 0.5928983457905049, 0.5897887729925787, 0.5945049200120445]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305, 0.6116189039656329, 0.604680796652444, 0.5928983457905049, 0.5897887729925787, 0.5945049200120445, 0.5871985733130765]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305, 0.6116189039656329, 0.604680796652444, 0.5928983457905049, 0.5897887729925787, 0.5945049200120445, 0.5871985733130765, 0.5790858557272157]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305, 0.6116189039656329, 0.604680796652444, 0.5928983457905049, 0.5897887729925787, 0.5945049200120445, 0.5871985733130765, 0.5790858557272157, 0.5721027988833733]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305, 0.6116189039656329, 0.604680796652444, 0.5928983457905049, 0.5897887729925787, 0.5945049200120445, 0.5871985733130765, 0.5790858557272157, 0.5721027988833733, 0.5724520715470163]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))

Epoch 1/1
Train loss: {0: [1.6803342033163342, 1.192454763004138, 1.048453326514852, 0.9384002263460985, 0.8776981047002004, 0.8198577896726577, 0.7862382492191038, 0.7575672983565611, 0.7432610229920361, 0.7430573872992285, 0.7290655598790535, 0.7210331537519713, 0.7033488912363082, 0.6954681508923458, 0.684481616740026], 1: [0.6179229032490902, 0.5948432306867573, 0.6167410019638201, 0.6165876358039056, 0.6178627264985305, 0.6116189039656329, 0.604680796652444, 0.5928983457905049, 0.5897887729925787, 0.5945049200120445, 0.5871985733130765, 0.5790858557272157, 0.5721027988833733, 0.5724520715470163, 0.5691993314747644]}



HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.7518572983157132, 0.741977969348659, 0.38000217399834674, 0.4029614268249022]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872, 0.8722748382238076, 0.8693136283145201, 0.8666636285311622, 0.8542038544133839, 0.8562569421145861], 1: [0.7518572983157132]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.7639885078901532, 0.7285081417624522, 0.502780687420761, 0.4329612979030244]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872, 0.8722748382238076, 0.8693136283145201, 0.8666636285311622, 0.8542038544133839, 0.8562569421145861], 1: [0.7518572983157132, 0.7579148355656645]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.8483042938499158, 0.7131226053639846, 0.3819472561394118, 0.3427690656077816]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872, 0.8722748382238076, 0.8693136283145201, 0.8666636285311622, 0.8542038544133839, 0.8562569421145861], 1: [0.7518572983157132, 0.7579148355656645, 0.7885375091828664]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.6916697749232881, 0.7746647509578544, 0.5412419626767608, 0.5018718718797311]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872, 0.8722748382238076, 0.8693136283145201, 0.8666636285311622, 0.8542038544133839, 0.8562569421145861], 1: [0.7518572983157132, 0.7579148355656645, 0.7885375091828664, 0.7649469645112733]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=262.0), HTML(value='')))


[0.7073383849699378, 0.7788254310344828, 0.5107353492943263, 0.4344153666861669]
Val loss: {0: [0.8281069160421233, 0.9105895141127434, 0.8924771286702872, 0.8722748382238076, 0.8693136283145201, 0.8666636285311622, 0.8542038544133839, 0.8562569421145861], 1: [0.7518572983157132, 0.7579148355656645, 0.7885375091828664, 0.7649469645112733, 0.7534193625854283]}


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))

# Search of hyperparams