## Optimizing neural network parameters

In [1]:
import itertools
import util
import platform
from pathlib import Path
from os import path

import numpy as np
from scipy.io import loadmat

import yass
from yass.augment import make
from yass.neuralnetwork import NeuralNetDetector, NeuralNetTriage
from yass.util import get_version

from dstools.params import make_grid
from dstools.reproducibility import make_path

In [2]:
# True on macOS (platform.system() reports 'Darwin'), which is treated as
# the local development machine; anything else is treated as the lab machine
LOCAL = platform.system() == 'Darwin'

# the project lives in a different folder on each machine
if LOCAL:
    path_to_here = path.expanduser('~/dev/lab/private-yass/nnet')
else:
    path_to_here = path.expanduser('~/dev/private-yass/nnet')

path_to_models = path.join(path_to_here, 'models')

# yass configuration
# Pick the configuration file once, based on the machine. Previously the
# platform-specific file was loaded first and then unconditionally replaced
# by the lab file, so the local configuration was never the one in effect.
# The config folder sits next to the nnet folder (formerly '../config').
home = Path.home()
config_name = '49-local.yaml' if LOCAL else '49-lab.yaml'
path_to_config = path.join(path.dirname(path_to_here), 'config', config_name)
path_to_ground_truth = str(home / 'data/groundtruth_ej49_data1_set1.mat')
path_to_data = str(home / 'data/tmp')

yass.set_config(path_to_config)
CONFIG = yass.read_config()

In [3]:
# Load the ground truth file and place its 'spt_gt' and 'L_gt' entries side
# by side (presumably spike times and template labels - TODO confirm)
ground_truth = loadmat(path_to_ground_truth)
spike_train = np.hstack((ground_truth['spt_gt'], ground_truth['L_gt']))

# discard the first two rows and the last one
spike_train = spike_train[2:-1]

# shift the second column down by one (looks like a one-based to zero-based
# id conversion - TODO confirm)
spike_train[:, 1] -= 1

# compensate alignment
spike_train[:, 0] += 10

In [4]:
# largest id found in the second column of the spike train, plus one
n_templates = spike_train[:, 1].max() + 1

# A hand-picked subset of templates used at some point, kept for reference:
# chosen_templates = [0, 1, 2, 3, 4, 5, 7, 8, 9, 11, 12, 13,
#                     14, 15, 16, 17, 18, 19, 22, 23, 24, 25,
#                     26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
#                     36, 37, 38, 39, 41, 42, 43, 44, 45, 46,
#                     47, 48]

# candidate values for building the training set; every key maps to the
# list of values to try (chosen_templates has a single candidate: all ids)
set_params = {
    'n_spikes': [10000, 20000, 30000],
    'min_amplitude': [2, 5, 8],
    'max_amplitude': [20, 60, 100],
    'chosen_templates': [np.arange(n_templates)],
}

# candidate values forwarded to the network constructors
model_params = {
    'n_iter': [5000, 7000],
    'l2_reg_scale': [5e-8, 5e-7, 5e-6],
    'filters_size': [[16, 8], [32, 16]],
}

In [5]:
# turn each dictionary of candidate values into its grid (presumably one
# dictionary per combination - see dstools.params.make_grid)
set_grid = make_grid(set_params)
model_grid = make_grid(model_params)

# Pair every training-set configuration with every model configuration.
# itertools.product returns a single-pass iterator, so it is stored as a
# list: otherwise re-running the training cell after a partial or complete
# pass would silently loop over nothing.
grid = list(itertools.product(set_grid, model_grid))

In [None]:
# Train one detector and one triage network for every combination of
# training-set and model parameters in `grid`, then save the data sets
# generated for that combination.
# The loop targets are deliberately not named set_params/model_params so the
# dictionaries of candidate values defined earlier are not overwritten.
for set_kwargs, model_kwargs in grid:

    # name used to tag the arrays saved for this run (presumably unique per
    # call - see util.directory)
    dir_name = util.directory()

    # chosen_templates is one of the training-set grid keys but it is handed
    # to make.training_data positionally, so take it out of the keyword
    # arguments (on a copy) to avoid passing it twice
    set_kwargs = dict(set_kwargs)
    chosen_templates = set_kwargs.pop('chosen_templates')

    # make training data
    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make.training_data(CONFIG,
                                      spike_train,
                                      chosen_templates,
                                      path_to_data,
                                      **set_kwargs)

    # detector
    detect_name = make_path(path_to_models, 'detect', extension='ckpt')

    # the last two dimensions of the data fix the network input size
    _, waveform_length, n_neighbors = x_detect.shape

    detector = NeuralNetDetector(path_to_model=detect_name,
                                 waveform_length=waveform_length,
                                 n_neighbors=n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 **model_kwargs)

    detector.fit(x_detect, y_detect)

    # triage
    triage_name = make_path(path_to_models, 'triage', extension='ckpt')

    _, waveform_length, n_neighbors = x_triage.shape

    triage = NeuralNetTriage(path_to_model=triage_name,
                             waveform_length=waveform_length,
                             threshold=0.5,
                             n_neighbors=n_neighbors,
                             **model_kwargs)

    triage.fit(x_triage, y_triage)

    # store models metadata
    metadata = dict(yass_version=get_version())
    # TODO: save name, loss, computed metrics, set and model params,
    # yass version

    # save the generated data sets for future evaluation
    # TODO: make these file names match the model names used above
    arrays_to_save = (('x-triage', x_triage), ('y-triage', y_triage),
                      ('x-detect', x_detect), ('y-detect', y_detect))

    for prefix, array in arrays_to_save:
        np.save(path.join(path_to_models, prefix + '-' + dir_name + '.npy'),
                array)