**Step 1: Load yass**

In [None]:
%matplotlib inline

import os
import numpy as np
import tensorflow as tf
import progressbar
import matplotlib.pyplot as plt
import logging
import scipy.io as sio
import yass
from yass import read_config
from yass import preprocess
from yass.augment import make_training_data, save_detect_network_params, save_triage_network_params, save_ae_network_params
from yass.augment import train_detector, train_ae, train_triage


**Step 2: Read Configuration File**

In [None]:
yass.set_config("location/to/config.yaml")
CONFIG = read_config()

**Step 2.1: Preprocessing (optional)**

Make sure the standardized data exists. If it does not, run this step to automatically generate a temporary folder containing the standardized and filtered data.

In [None]:
preprocess.run()

**Step 3: Load Spike Train**

To train the neural networks, you need a recording with a sorting result. The result does not need to be perfect.
If you do not have a sorting result yet, you can run yass with the threshold detection option: in your configuration file, set spikes.detection = threshold.

spike_train is a matrix of size (number of spikes x 2). Each row represents an individual spike. The first column is the spike time, expressed as a sample index in the recording (not in milliseconds or seconds). The second column is the spike ID.
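
As a minimal illustration of this layout, the sketch below builds a small synthetic spike train and converts the sample indices to milliseconds. The 20 kHz sampling rate and all spike values are assumptions made up for the example; use the sampling rate of your own recording.

```python
import numpy as np

# Synthetic spike train: column 0 = spike time in samples, column 1 = spike ID.
spike_train_example = np.array([[120, 0],
                                [450, 1],
                                [980, 0],
                                [1500, 2]])

sampling_rate_hz = 20000  # assumed value, for illustration only

# Convert sample indices to milliseconds.
spike_times_ms = spike_train_example[:, 0] / sampling_rate_hz * 1000.0

# Count how many spikes each ID has.
unit_ids, counts = np.unique(spike_train_example[:, 1], return_counts=True)
```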

Other Python functions can be used to load the data, depending on the file format. For example, use the following code to load binary data, where dtype is the data type the file was written with:

f = open('location/to/spiketrain', 'rb')

spike_train = np.fromfile(f, dtype)
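
A raw binary file stores no shape information, so the array returned by np.fromfile is flat. The sketch below shows a round trip on a synthetic spike train; the int32 data type, the row-by-row storage order, and the file name are assumptions for the example and may not match your own file.

```python
import os
import tempfile
import numpy as np

# Write a small synthetic spike train to a raw binary file (int32 is an assumption).
original = np.array([[120, 0], [450, 1], [980, 0]], dtype=np.int32)
path = os.path.join(tempfile.mkdtemp(), "spiketrain.bin")
original.tofile(path)

# np.fromfile returns a flat array; reshape it back to (number of spikes, 2).
loaded = np.fromfile(path, dtype=np.int32).reshape(-1, 2)
```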

In [None]:
spike_train=np.loadtxt("location/to/spiketrain",dtype=int) 


**Step 3.2: Visually Inspect Templates (optional)**

Not all templates generated from the ground truth necessarily look good. It is advised that you visually inspect the templates one by one and keep only the good ones, i.e. isolated spikes with sufficient absolute amplitude (usually >= 3 standard units).

In [None]:
from yass.templates.util import get_templates
from yass.augment.crop import crop_templates

#if the standardized data is elsewhere, replace the following path accordingly
path_to_data = os.path.join(CONFIG.data.root_folder, 'tmp/standarized.bin') 

templates, _ = get_templates(spike_train, path_to_data, 4*CONFIG.spike_size)
templates = np.transpose(templates, (2, 1, 0))
templates = crop_templates(templates, CONFIG.spike_size,
                           CONFIG.neigh_channels, CONFIG.geom)
for i in range(templates.shape[0]):
    plt.plot(templates[i])
    plt.title('template id: %d'%(i))
    plt.show()
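
The amplitude criterion can be pre-screened numerically before looking at the plots. The sketch below is only an illustration on synthetic data: it assumes a templates array shaped (number of templates, time points, channels) and flags the templates whose peak absolute amplitude is at least 3. It does not replace the visual check of the template shapes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for the templates array: (n_templates, n_timepoints, n_channels).
n_templates, n_timepoints, n_channels = 5, 61, 7
templates_example = 0.1 * rng.standard_normal((n_templates, n_timepoints, n_channels))
templates_example[1, 30, 3] = -6.0  # give template 1 a large negative peak
templates_example[4, 30, 0] = 4.5   # give template 4 a large positive peak

amplitude_threshold = 3  # in standardized units

# Peak absolute amplitude of each template over all time points and channels.
peak_amplitude = np.abs(templates_example).max(axis=(1, 2))

# IDs of the templates that pass the amplitude criterion.
candidate_ids = np.flatnonzero(peak_amplitude >= amplitude_threshold)
```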


**Step 4: Make Training Dataset**

1. CONFIG and spike_train are from Steps 2 and 3.

2. chosen_templates: a vector of the IDs of the templates to use. Not all templates in a spike sorting result look good, so the training dataset should be built from good-looking templates only. Make sure that you do not include bad templates, but keep enough variability in the template shapes. To check the templates visually, see the optional Step 3.2 above.

3. min_amp: the minimum peak absolute amplitude of the augmented spikes. It determines how small the spikes in the training set can be. Default is 3.

4. nspikes: approximately how many isolated spikes to generate for the training set. This is NOT the total number of training data points.

5. noise_ratio, collision_ratio, misalign_ratio, misalign_ratio2: the number of training examples of each type relative to the number of isolated spikes. For example, with noise_ratio=10, collision_ratio=1, misalign_ratio=3, misalign_ratio2=3, for every portion of isolated spikes generated, 10 portions of noise, 1 portion of collisions, 3 portions of spatially and temporally misaligned spikes (for x_detect), and 3 portions of spatially misaligned spikes (for x_triage) are generated. Choose the ratios to reflect the proportions you expect in real data. noise_ratio >= 5 is generally recommended, to reflect the sparse firing pattern of neurons.

6. multi: boolean variable, if True (default), multi-channel data will be generated.
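
To get a feel for the dataset size these settings imply, the arithmetic can be sketched as follows. This only multiplies nspikes by each ratio as described above; the counts are approximate, and how the examples are distributed across the returned arrays is not shown here.

```python
nspikes = 10000
ratios = {
    "noise": 10,
    "collision": 1,
    "misaligned (x_detect)": 3,
    "misaligned (x_triage)": 3,
}

# Approximate number of examples of each type, relative to the isolated spikes.
approx_counts = {name: nspikes * ratio for name, ratio in ratios.items()}
approx_counts["isolated"] = nspikes

for name, count in approx_counts.items():
    print(f"{name:>24}: ~{count}")
```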



In [None]:
min_amp = 3
nspikes = 10000
data_folder="location/to/standarized.bin"
chosen_templates=np.arange(1,164) # should be your own numbers
# remove the bad-looking templates from Step 3.2 by ID (np.delete works on positions and returns a copy)
chosen_templates = np.setdiff1d(chosen_templates, [12, 34, 57, 65, 163])
noise_ratio=10
collision_ratio=1
misalign_ratio=3
misalign_ratio2=3
multi=True
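
When excluding templates by ID, two NumPy details are easy to miss: np.delete returns a new array instead of modifying its input, and it treats its second argument as positions rather than values. The sketch below contrasts it with np.setdiff1d on made-up IDs.

```python
import numpy as np

template_ids = np.arange(1, 11)  # IDs 1..10 (made-up example)
bad_ids = [3, 7, 10]

# np.delete removes by position: positions 3 and 7 hold the IDs 4 and 8.
by_position = np.delete(template_ids, [3, 7])

# np.setdiff1d removes by value, which matches "drop these template IDs".
by_value = np.setdiff1d(template_ids, bad_ids)

# In both cases the original array is left unchanged.
```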

In [None]:
x_detect, y_detect, x_triage, y_triage, x_ae, y_ae = make_training_data(CONFIG, spike_train, chosen_templates,
                                                                         min_amp, nspikes,data_folder,
                                                                         noise_ratio,collision_ratio,misalign_ratio,
                                                                         misalign_ratio2,multi)

**Step 5: Train All Three Neural Networks**

Training parameters:
1. n_iter: the number of training iterations to run
2. n_batch: the size of the mini-batch used in each iteration
3. l2_reg_scale: the scale of the L2 regularization penalty
4. train_step_size: the training step size (learning rate)
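
The role of these four parameters can be illustrated with a toy example. The sketch below runs plain mini-batch gradient descent on a logistic regression with an L2 penalty, using synthetic data. It is not the training loop, optimizer, or network architecture used by yass, and the parameter values are chosen for the toy problem only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy binary classification data (not spike data).
X = rng.standard_normal((2000, 5))
true_w = np.array([1.5, -2.0, 0.5, 0.0, 1.0])
y = (X @ true_w > 0).astype(float)

n_iter = 500           # number of gradient steps
n_batch = 64           # examples per mini-batch
l2_reg_scale = 5e-8    # weight of the L2 penalty on the parameters
train_step_size = 0.1  # learning rate

def loss(w):
    """Mean cross-entropy over the full dataset plus the L2 penalty."""
    p = 1.0 / (1.0 + np.exp(-(X @ w)))
    eps = 1e-12
    cross_entropy = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
    return cross_entropy + l2_reg_scale * np.sum(w ** 2)

w = np.zeros(5)
initial_loss = loss(w)

for _ in range(n_iter):
    # Draw one mini-batch of n_batch examples.
    idx = rng.choice(X.shape[0], size=n_batch, replace=False)
    xb, yb = X[idx], y[idx]
    p = 1.0 / (1.0 + np.exp(-(xb @ w)))
    # Gradient of the data term plus gradient of the L2 term.
    grad = xb.T @ (p - yb) / n_batch + 2 * l2_reg_scale * w
    # One update of size train_step_size.
    w -= train_step_size * grad

final_loss = loss(w)
```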

In [None]:
n_iter = 50000
n_batch = 512
l2_reg_scale = 0.00000005
train_step_size =  0.001

Training the neural network detector
1. detectnet_name: name (including the location) under which the model is saved
2. n_filters_detect: number of filters to use in each layer. It should be a list of length 2

In [None]:
detectnet_name = 'detectnet.ckpt'  # the 'ckpt' suffix is replaced by 'yaml' when saving the parameters
n_filters_detect = [16, 8]

In [None]:
train_detector(x_detect, y_detect, n_filters_detect, n_iter, n_batch, l2_reg_scale, train_step_size, detectnet_name)

In [None]:
# save model parameters
save_detect_network_params(filters = n_filters_detect,
                           size = x_detect.shape[1],
                           n_neighbors = x_detect.shape[2],
                           output_path = detectnet_name.replace('ckpt', 'yaml'))

Training the neural network triage
1. triagenet_name: name (including the location) under which the model is saved
2. n_filters_triage: number of filters to use in each layer. It should be a list of length 2

In [None]:
triagenet_name = 'triagenet.ckpt'  # the 'ckpt' suffix is replaced by 'yaml' when saving the parameters
n_filters_triage = [16, 8]

In [None]:
# run training
train_triage(x_triage, y_triage, n_filters_triage, n_iter, n_batch, l2_reg_scale, train_step_size, triagenet_name)

In [None]:
# save model parameters
save_triage_network_params(filters = n_filters_triage,
                           size = x_triage.shape[1],
                           n_neighbors = x_triage.shape[2],
                           output_path = triagenet_name.replace('ckpt', 'yaml'))

Training the autoencoder
1. ae_name: name (including the location) under which the model is saved
2. n_features: number of latent variables

In [None]:
ae_name = 'ae.ckpt'  # the 'ckpt' suffix is replaced by 'yaml' when saving the parameters
n_features = 3
n_batch = x_ae.shape[0]  # overrides the earlier n_batch: batch size equals the number of autoencoder examples

In [None]:
# run training
train_ae(x_ae, y_ae, n_features, n_iter, n_batch, train_step_size, ae_name)

In [None]:
# save model parameters
save_ae_network_params(n_input = x_ae.shape[1],
                       n_features = n_features,
                       output_path = ae_name.replace('ckpt', 'yaml'))

**You are done!**

**Step 6: When Using yass**

Make sure that you have all your files! You must have **3 '.ckpt'** files and **1 '.yaml'** file for **each neural network model**, for a **total of 12 files**.
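
A quick way to verify this is to count the files on disk. The sketch below assumes the models were saved under names ending in '.ckpt' (for example 'detectnet.ckpt'); the helper count_model_files, the model prefixes, and the placeholder file suffixes are hypothetical and only serve the demonstration.

```python
import glob
import os
import tempfile

def count_model_files(folder, model_prefix):
    """Return (number of '.ckpt' files, whether the '.yaml' file exists) for one model."""
    ckpt_files = glob.glob(os.path.join(folder, model_prefix + ".ckpt*"))
    yaml_exists = os.path.exists(os.path.join(folder, model_prefix + ".yaml"))
    return len(ckpt_files), yaml_exists

# Demonstration on a temporary folder filled with empty placeholder files.
folder = tempfile.mkdtemp()
placeholder_suffixes = [".ckpt.index", ".ckpt.meta", ".ckpt.data-00000-of-00001", ".yaml"]
model_prefixes = ["detectnet", "triagenet", "ae"]
for prefix in model_prefixes:
    for suffix in placeholder_suffixes:
        open(os.path.join(folder, prefix + suffix), "w").close()

report = {prefix: count_model_files(folder, prefix) for prefix in model_prefixes}
total_files = sum(n_ckpt + int(has_yaml) for n_ckpt, has_yaml in report.values())
```

With your own models, point folder at the directory where they were saved and check that every entry of report equals (3, True).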

Also, make sure that the parameters in your configuration file match the parameters used during training:

| Name in config.yaml | Required value |
|---|---|
|spikes.temporal_features|n_features used for training the autoencoder|
|recordings.spike_size_ms|the same value as in the configuration loaded in Step 2|
|neural_network_detector.filename|file name used above to save the neural network detector|
|neural_network_triage.filename|file name used above to save the neural network triage|
|neural_network_autoencoder.filename|file name used above to save the autoencoder|
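
The checks in the table can also be scripted. The sketch below assumes the configuration is available as a nested dictionary (for instance after parsing the YAML file yourself). The helper find_mismatches is hypothetical, and all values are placeholders.

```python
def find_mismatches(config, expected):
    """Compare 'section.key' entries of a nested config dict with expected values."""
    problems = []
    for dotted_key, value in expected.items():
        section, key = dotted_key.split(".")
        found = config.get(section, {}).get(key)
        if found != value:
            problems.append(f"{dotted_key}: expected {value!r}, found {found!r}")
    return problems

# Placeholder configuration, shaped like the entries in the table above.
example_config = {
    "spikes": {"temporal_features": 3},
    "recordings": {"spike_size_ms": 1.5},
    "neural_network_detector": {"filename": "detectnet.ckpt"},
    "neural_network_triage": {"filename": "triagenet.ckpt"},
    "neural_network_autoencoder": {"filename": "ae.ckpt"},
}

# Values used during training (placeholders).
expected = {
    "spikes.temporal_features": 3,
    "recordings.spike_size_ms": 1.5,
    "neural_network_detector.filename": "detectnet.ckpt",
    "neural_network_triage.filename": "triagenet.ckpt",
    "neural_network_autoencoder.filename": "ae.ckpt",
}

mismatches = find_mismatches(example_config, expected)
```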