# Detector network workflow

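Trains a YASS neural-network spike detector on a synthetic dataset built from pre-computed templates and a noise covariance model, then evaluates it with a confusion matrix on the detector's test split.

Tested with Python 3.6.7 (from miniconda).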

## Part 0. Dependencies

In [None]:
%%bash
# Install the notebook's dependencies with whichever `pip` bash resolves.
# NOTE(review): make sure this is the same environment the kernel runs in.
set -e  # stop at the first failed install instead of continuing silently

# yass at the pinned commit presumably targets the TensorFlow 1.x API
# (TODO confirm); the specifier is quoted so bash does not treat '<' as a
# redirection.
pip install numpy seaborn sklearn-evaluation "tensorflow<2"
# GitHub no longer serves the unauthenticated git:// protocol; use https.
# TODO: pin dstools to a commit, like yass below, for reproducibility.
pip install git+https://github.com/edublancas/dstools
pip install git+https://github.com/paninski-lab/yass@8ddce299fe52901af0b35f3a49dda86f61ca2c6e

In [None]:
%%bash
# Download the pre-computed inputs (templates, yass config, noise covariance,
# probe geometry) into a scratch directory.
set -e  # abort on the first failed download

OUT_DIR=/tmp/spike-sorting
mkdir -p "$OUT_DIR"

# -f: fail on HTTP errors instead of saving an error page as the data file
# -sS: no progress meter, but still report errors
# -L: follow redirects
# URLs are quoted because '?' is a glob character in bash.
curl -fsSL "https://dl.dropboxusercontent.com/s/sylnygjmcvkmi4z/templates.npy?dl=0" -o "$OUT_DIR/templates.npy"
curl -fsSL "https://dl.dropboxusercontent.com/s/smk83ob73y9z7p0/config.yaml?dl=0" -o "$OUT_DIR/config.yaml"
curl -fsSL "https://dl.dropboxusercontent.com/s/mfp5vcu9b53ws91/noise_cov.npz?dl=0" -o "$OUT_DIR/noise_cov.npz"
curl -fsSL "https://dl.dropboxusercontent.com/s/k9qa7vttuzrsmr4/geometry.txt?dl=0" -o "$OUT_DIR/geometry.txt"

## Part 1. Train/Test set creation

In [None]:
import datetime
import logging
from pathlib import Path
from os.path import expanduser
from os import path

import yass
from yass import read_config
from yass.augment import make
from yass.neuralnetwork import NeuralNetDetector
from yass.batch import RecordingsReader
from yass.augment.noise import noise_cov
from yass.templates import TemplatesProcessor
from yass.geometry import make_channel_index


import seaborn as sns
import numpy as np
from dstools import plot
import sklearn_evaluation.plot as skplot

# logging.basicConfig(level=logging.DEBUG)

In [None]:
# Location of the standardized recording (presumably the output of a previous
# YASS preprocessing run). It is only read when the noise covariance is
# re-estimated rather than loaded from the downloaded .npz file.
path_to_data = expanduser('~/data')
path_to_experiment = path.join(path_to_data, 'retinal/sample_output')
path_to_standarized = path.join(path_to_experiment, 'preprocess', 'standarized.bin')

In [None]:
# Register the downloaded config with yass (the second argument is presumably
# the output directory), then load it. Later cells read CONFIG.
yass.set_config('/tmp/spike-sorting/config.yaml', '/tmp/spike-sorting/output')
CONFIG = read_config()

### 1.1 Loading templates

In [None]:
# Load the pre-computed templates. The unpacking below requires a 3-D array;
# its axes are named (n_templates, waveform_length, <third axis>).
raw_templates = np.load('/tmp/spike-sorting/templates.npy')
if raw_templates.ndim != 3:
    raise ValueError('Expected a 3-D templates array '
                     '(n_templates, waveform_length, ...), got shape {}'
                     .format(raw_templates.shape))
n_templates, waveform_length, _ = raw_templates.shape
# Show the shape as the cell's rich output rather than via print().
raw_templates.shape

In [None]:
# crop templates spatially
processor = TemplatesProcessor(raw_templates)
templates = (processor
            .crop_spatially(CONFIG.neigh_channels, CONFIG.geom)
            .values)
templates.shape

In [None]:
# Preview 9 of the cropped templates in a grid of at most 3 columns.
plot.grid_from_array(templates,
                     axis=0,
                     elements=9,
                     max_cols=3,
                     auto_figsize=4)

### 1.2 Estimating noise covariance structure

In [None]:
# Build the channel index from the neighbor map and geometry. Row 0 (presumably
# channel 0 and its neighbors) is the channel set used to estimate the noise
# covariance below.
ch_idx = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)
selected_channels = ch_idx[0]

selected_channels

In [None]:
# True: use the downloaded noise covariance.
# False: re-estimate it from the standardized recording at path_to_standarized,
#        restricted to selected_channels (slow, and needs that file on disk).
LOAD_NOISE_COV = True


if LOAD_NOISE_COV:
    # np.load on an .npz returns an NpzFile that keeps the file handle open;
    # use it as a context manager so the handle is released once both arrays
    # have been read into memory.
    with np.load('/tmp/spike-sorting/noise_cov.npz') as cov:
        spatial_sig, temporal_sig = cov['spatial_sig'], cov['temporal_sig']
else:
    rec = RecordingsReader(path_to_standarized, loader='array').data[:, selected_channels]
    # Both size arguments are the template waveform length (axis 1).
    (spatial_sig,
     temporal_sig) = noise_cov(rec, templates.shape[1], templates.shape[1])

In [None]:
# Parameters for the synthetic training set built in section 1.3.
min_amplitude = 4    # passed as minimum_amplitude
max_amplitude = 60   # passed as maximum_amplitude
n_clean_per_template = 200
# Total number of clean positive examples; the noise-example count is derived
# from it.
n_positive_total = n_templates * n_clean_per_template
n_collided_per_spike = 0
# Probability vector forwarded to the template-based generator via
# from_templates_kwargs. Most of the mass goes to the first slot.
# NOTE(review): what each slot means is defined inside yass.augment; confirm
# there before changing.
probs = [0.6, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04]
# A probability vector must sum to 1. Fail here with a clear message rather
# than deep inside dataset generation.
assert np.isclose(sum(probs), 1.0), \
    'probs must sum to 1, got {}'.format(sum(probs))

### 1.3 Create dataset

In [None]:
# Generate the labelled detector dataset. Label 1 marks spikes and label 0
# marks negatives (see section 1.4).
X, y = make.training_data_detect(
    templates=templates,
    minimum_amplitude=min_amplitude,
    maximum_amplitude=max_amplitude,
    n_clean_per_template=n_clean_per_template,
    n_collided_per_spike=n_collided_per_spike,
    # NOTE(review): a fraction is passed to an ``n_``-prefixed argument;
    # confirm yass treats it as a ratio rather than a count.
    n_temporally_misaligned_per_spike=0.25,
    # Spelling ("misaliged") presumably mirrors the yass keyword; leave as-is.
    n_spatially_misaliged_per_spike=0,
    # Half as many noise examples as clean positive examples.
    n_noise=int(n_positive_total * 0.5),
    spatial_SIG=spatial_sig,
    temporal_SIG=temporal_sig,
    add_noise_kwargs={'reject_cancelling_noise': False},
    from_templates_kwargs={'probabilities': probs},
)

In [None]:
# Shapes of the positive (label 1) and negative (label 0) examples.
positive_shape = X[y == 1].shape
negative_shape = X[y == 0].shape
positive_shape, negative_shape

### 1.4 Plot some examples

In [None]:
# positive examples (spikes)
plot.grid_from_array(X[y == 1], axis=0,
                     elements=9, auto_figsize=3,
                     sharey=False)

In [None]:
# negative examples: noise and non-centered spikes
plot.grid_from_array(X[y == 0], axis=0,
                     elements=9, auto_figsize=3,
                     sharey=False)

## Part 2. Network training

In [None]:
# Detector training hyperparameters.
# NOTE: these only take effect if they are passed to NeuralNetDetector.
n_iter = 5000
n_batch = 512
l2_reg_scale = 5e-8
train_step_size = 1e-4
filters_detect = [32, 32]   # presumably filter counts per conv layer

In [None]:
# The network input shape comes from the generated data; per this unpacking,
# X is (n_examples, waveform_length, n_neighbors).
_, waveform_length, n_neighbors = X.shape

# Pass every training hyperparameter declared above explicitly. Previously
# n_batch, l2_reg_scale and train_step_size were never used, so the detector
# trained with its own defaults instead of the values set in this notebook.
detector = NeuralNetDetector('/tmp/spike-sorting/my-detector-network.ckpt',
                             filters_detect,
                             waveform_length, n_neighbors,
                             threshold=0.5,  # presumably the decision threshold on the network output
                             channel_index=CONFIG.channel_index,
                             n_iter=n_iter,
                             n_batch=n_batch,
                             l2_reg_scale=l2_reg_scale,
                             train_step_size=train_step_size)

detector.fit(X, y)

## Part 3. Network evaluation

In [None]:
# Predict labels for the detector's test split (x_test is presumably set
# during detector.fit).
held_out = detector.x_test
preds = detector.predict(held_out)

In [None]:
# Normalized confusion matrix on the test split. Label 0 is shown as 'Noise',
# label 1 as 'Spike'.
skplot.confusion_matrix(detector.y_test,
                        preds,
                        target_names=['Noise', 'Spike'],
                        normalize=True)

## Part 4. Cleanup

In [None]:
%%bash
# Remove the scratch directory: the downloaded inputs, the yass output
# directory, and the detector checkpoint path used in Part 2.
rm -rf /tmp/spike-sorting