## Training neural networks

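This notebook describes the workflow for training the triage neural network: loading the train/test sets, building and fitting a convolutional classifier, and saving the model together with provenance metadata.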

In [1]:
# Reload imported modules before each cell runs, so edits to libraries such as
# yass are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from os import getcwd
from pathlib import Path
from functools import reduce

import yaml
import numpy as np
import keras
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten, MaxPooling2D, Dropout
from dstools.reproducibility import make_filename
from dstools.reproducibility.util import git_hash_in_path, get_version
from dstools.util import save

from yass import util

Using TensorFlow backend.


In [3]:
# Git commit of the directory this notebook runs from, recorded for provenance
# (printed below and stored with the saved model's metadata).
here = getcwd()
here_version = git_hash_in_path(here)

In [4]:
# For reference: the code versions (yass library, this notebooks repo) used
# for this run.
yass_version = get_version('yass')
print('YASS version is: {}'.format(yass_version))
print('nbs version is: {}'.format(here_version))

YASS version is: d5925c9 options for specifying distribution of clean spikes
nbs version is: 1a9c9bb noteboks for new triage network


In [5]:
# Everything for the triage network lives under ~/data/triage:
# datasets in sets/, trained models (and their metadata) in models/.
path_to_data = Path('~/data/triage').expanduser()
path_to_sets = path_to_data / 'sets'
path_to_output = path_to_data / 'models'

## Loading training data

In [8]:
# The four arrays share a timestamp prefix and a shape suffix; only the
# "x-train" token differs between them. Derive the other three names by
# replacing exactly that token: replacing a bare 'x' or 'train' would also
# rewrite any other occurrence of those characters in the file name.
x_name = '2018-08-21T15-31-09:x-train-31wf7ch.npy'
assert 'x-train' in x_name, "x_name must contain the 'x-train' token"

path_to_x_train = path_to_sets / x_name
path_to_y_train = path_to_sets / x_name.replace('x-train', 'y-train')

path_to_x_test = path_to_sets / x_name.replace('x-train', 'x-test')
path_to_y_test = path_to_sets / x_name.replace('x-train', 'y-test')

x_train = np.load(path_to_x_train)
y_train = np.load(path_to_y_train)
x_test = np.load(path_to_x_test)
y_test = np.load(path_to_y_test)

# every example must come with exactly one label
assert len(x_train) == len(y_train), 'x/y train size mismatch'
assert len(x_test) == len(y_test), 'x/y test size mismatch'

In [9]:
def make_model(x_train, input_shape=None):
    """Build and compile the convolutional triage classifier.

    Architecture: three 5x5 Conv2D layers with 10 filters each and 'same'
    padding (activations: relu, relu, linear), then Flatten and a single
    sigmoid unit. Compiled with binary cross-entropy, Adam (lr=0.001) and an
    accuracy metric.

    Parameters
    ----------
    x_train : np.ndarray
        Training inputs. Only used to infer ``input_shape`` when it is not
        given, as ``x_train.shape[1:]``; presumably shaped
        (n, window, channels, 1) - TODO confirm for other callers.
    input_shape : tuple, optional
        Shape of one example, channels last, e.g. ``(wf, ch, 1)``.

    Returns
    -------
    keras.models.Sequential
        The compiled model. Its summary is printed as a side effect.
    """
    if input_shape is None:
        input_shape = x_train.shape[1:]

    model = Sequential()

    # 'same' padding keeps the (window, channels) grid unchanged through all
    # three convolutions. Only the first layer declares the input shape. The
    # third layer is linear; the only output non-linearity is the sigmoid.
    for i, activation in enumerate(('relu', 'relu', 'linear')):
        shape_kwargs = {'input_shape': input_shape} if i == 0 else {}
        model.add(Conv2D(10, kernel_size=(5, 5),
                         padding='same', activation=activation, use_bias=True,
                         data_format="channels_last", **shape_kwargs))

    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

    model.summary()

    opt = keras.optimizers.adam(lr=0.001)

    model.compile(loss='binary_crossentropy',
                  optimizer=opt,
                  metrics=['accuracy'])

    return model


## Triage training

In [10]:
# Conv2D (channels_last) needs a trailing channel axis:
# (n, window, channels) -> (n, window, channels, 1).
# reshape is used instead of [..., np.newaxis] so this cell is idempotent.
# Re-running it on already-4D arrays leaves the shape unchanged instead of
# adding a fifth axis.
x_train = x_train.reshape(x_train.shape[:3] + (1,))
x_test = x_test.reshape(x_test.shape[:3] + (1,))
assert x_train.shape[1:] == x_test.shape[1:], 'train/test example shapes differ'

_, wf, ch, _ = x_train.shape

m = make_model(x_train, (wf, ch, 1))

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 31, 7, 10)         260       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 31, 7, 10)         2510      
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 31, 7, 10)         2510      
_________________________________________________________________
flatten_1 (Flatten)          (None, 2170)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 2171      
Total params: 7,451
Trainable params: 7,451
Non-trainable params: 0
_________________________________________________________________


In [11]:
# Keep the returned History so per-epoch loss/accuracy can be inspected or
# plotted later instead of being lost as a bare cell output.
# batch_size=10000 means only a handful of gradient steps per epoch.
# NOTE(review): the test set doubles as validation data here, so it should not
# also be used for model selection or early stopping.
history = m.fit(x_train, y_train,
                batch_size=10000, epochs=100, shuffle=True,
                validation_data=(x_test, y_test))

Train on 48240 samples, validate on 23760 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100


Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100


<keras.callbacks.History at 0x7f4133eaecc0>

In [12]:
# Provenance saved next to the model: the data files it was trained on and the
# code versions that produced it.
metadata = {
    'path_to_x_train': str(path_to_x_train),
    'path_to_y_train': str(path_to_y_train),
    'path_to_x_test': str(path_to_x_test),
    'path_to_y_test': str(path_to_y_test),
    'yass_version': get_version('yass'),
    # Reuse the hash computed (and printed) at the top of the notebook so the
    # reported and saved notebook versions cannot disagree.
    'nb_version': here_version,
}
metadata

{'path_to_x_train': '/home/Edu/data/triage/sets/2018-08-21T15-31-09:x-train-31wf7ch.npy',
 'path_to_y_train': '/home/Edu/data/triage/sets/2018-08-21T15-31-09:y-train-31wf7ch.npy',
 'path_to_x_test': '/home/Edu/data/triage/sets/2018-08-21T15-31-09:x-test-31wf7ch.npy',
 'path_to_y_test': '/home/Edu/data/triage/sets/2018-08-21T15-31-09:y-test-31wf7ch.npy',
 'yass_version': 'd5925c9 options for specifying distribution of clean spikes',
 'nb_version': '1a9c9bb noteboks for new triage network'}

In [13]:
# Name the artifacts after the model's input geometry. make_filename builds
# one name per extension: the .h5 file holds the Keras model, and the .yaml
# file holds the provenance metadata.
_, wf, ch, _ = m.input_shape
suffix = f'triage-{wf}wf{ch}ch'

model_name, metadata_name = make_filename(sufix=suffix,
                                          extension=('h5', 'yaml'))
path_to_model = str(path_to_output / model_name)
path_to_metadata = str(path_to_output / metadata_name)

m.save(path_to_model)
save(metadata, path_to_metadata)

In [14]:
# Report where the trained model was written.
print(str(path_to_model))

/home/Edu/data/triage/models/2018-08-21T15-35-13:triage-31wf7ch.h5
