## Optimizing neural network parameters

In [1]:
import itertools
import platform
from os import path
from pathlib import Path

import numpy as np
from scipy.io import loadmat

import yass
from yass.augment import make
from yass.neuralnetwork import NeuralNetDetector, NeuralNetTriage

import util

In [2]:
# Machine switch: a macOS kernel ('Darwin') is treated as the local
# development machine, anything else as the lab server.
LOCAL = platform.system() == 'Darwin'

home = Path.home()

# The repository checkout lives in a different folder on each machine, and
# each machine has its own yass configuration file.
# NOTE(review): the local config location assumes this notebook sits in
# <repo>/nnet, which is what the previous relative path
# "../config/49-local.yaml" implied - confirm.
if LOCAL:
    path_to_repo = home / 'dev/lab/private-yass'
    config_name = '49-local.yaml'
else:
    path_to_repo = home / 'dev/private-yass'
    config_name = '49-lab.yaml'

# yass configuration
path_to_config = str(path_to_repo / 'config' / config_name)
path_to_ground_truth = str(home / 'data/groundtruth_ej49_data1_set1.mat')
path_to_data = str(home / 'data/tmp')

# Set the configuration exactly once. Previously it was set twice and the
# second call always loaded the lab file, silently overriding the local
# configuration chosen on the development machine.
yass.set_config(path_to_config)
CONFIG = yass.read_config()

# folder of this notebook inside the checkout; trained models are written
# under it
path_to_here = str(path_to_repo / 'nnet')

In [3]:
# load ground truth
# (kept in a named variable: assigning to `_` would shadow IPython's
# "last output" shortcut for the rest of the session)
ground_truth = loadmat(path_to_ground_truth)

# Stack the two ground-truth arrays side by side. The rest of the notebook
# reads column 0 as the value that needs an alignment shift and column 1 as
# the template id.
# presumably 'spt_gt' holds spike times and 'L_gt' the labels - TODO confirm
spike_train = np.hstack([ground_truth['spt_gt'], ground_truth['L_gt']])

# drop the first two rows and the last one
# NOTE(review): the reason for discarding these rows is not visible here
spike_train = spike_train[2:-1]

# shift the ids in column 1 down by one
# (presumably converting 1-based ids to 0-based - TODO confirm)
spike_train[:, 1] -= 1

# compensate alignment
spike_train[:, 0] += 10

In [12]:
# training set parameters
# Every setting is a list of candidate values; the grid below is the
# cartesian product of all of them.

n_spikes = [20000]
min_amplitude = [5]
max_amplitude = [90]

# use every template: ids in column 1 run from 0 to the maximum id
n_templates = np.max(spike_train[:,1]) + 1
chosen_templates = [np.arange(n_templates)]

# Materialized as a list: itertools.product returns a one-shot iterator, so
# a second run of the training cell would otherwise loop zero times without
# any error.
set_parameters = list(itertools.product(n_spikes, min_amplitude,
                                        max_amplitude, chosen_templates))

# model parameters
n_iter = 5000
# NOTE(review): l2_reg_scale is defined but is not part of the grid below,
# so these values are never swept
l2_reg_scale = [0.00000005, 0.0000005, 0.000005]
filters_size = [[32, 16]]

model_parameters = filters_size

# full grid: one (training-set settings, filters) pair per training run
parameters = list(itertools.product(set_parameters, model_parameters))

In [13]:
# All checkpoints and data dumps are written to <path_to_here>/models.
models_dir = path.join(path_to_here, 'models')

# Create the output folder if it is missing. Parent folders are deliberately
# not created, so a wrong path_to_here fails here instead of after training
# has already started.
Path(models_dir).mkdir(exist_ok=True)

# One training run (detector + triage) per point of the parameter grid.
for (n_spikes, min_amplitude,
     max_amplitude, chosen_templates), filters in parameters:

    # identifier shared by every file written in this iteration
    # (a timestamp, judging by the recorded output)
    dir_name = util.directory()

    # make training data
    # (x_ae / y_ae are returned as well but are not used in this cell)
    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make.training_data(CONFIG,
                                      spike_train,
                                      chosen_templates,
                                      min_amplitude,
                                      max_amplitude,
                                      n_spikes,
                                      path_to_data)

    # detector
    detect_name = path.join(models_dir, 'detect-'+dir_name+'.ckpt')

    # network input dimensions are taken from the generated data
    _, waveform_length, n_neighbors = x_detect.shape

    detector = NeuralNetDetector(detect_name, filters,
                                 waveform_length, n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=n_iter)

    detector.fit(x_detect, y_detect)

    # triage
    triage_name = path.join(models_dir, 'triage-'+dir_name+'.ckpt')

    _, waveform_length, n_neighbors = x_triage.shape

    triage = NeuralNetTriage(triage_name, filters,
                             waveform_length=waveform_length,
                             threshold=0.5,
                             n_neighbors=n_neighbors,
                             n_iter=n_iter)

    triage.fit(x_triage, y_triage)

    # Save the generated data sets next to the models. These are the same
    # arrays that were passed to fit() above, not a separate held-out set.
    arrays_to_save = [('x-triage', x_triage), ('y-triage', y_triage),
                      ('x-detect', x_detect), ('y-detect', y_detect)]

    for prefix, array in arrays_to_save:
        np.save(path.join(models_dir, prefix+'-'+dir_name+'.npy'), array)

    # summary line: where the models went and which settings produced them
    print(detect_name, triage_name, n_spikes,
          min_amplitude, max_amplitude, filters)

100%|██████████| 3/3 [00:08<00:00,  2.76s/it]
Tr loss: 0.010302701, Val loss: 0.001811245: 100%|██████████| 5000/5000 [01:21<00:00, 61.35it/s]    
Tr loss: 0.05112303, Val loss: 0.06348427: 100%|██████████| 5000/5000 [01:04<00:00, 77.41it/s]  


/home/Edu/dev/private-yass/nnet/models/detect-12-Jul-2018@00-30-40.ckpt /home/Edu/dev/private-yass/nnet/models/triage-12-Jul-2018@00-30-40.ckpt 20000 5 90 [32, 16]
