In [1]:
%load_ext line_profiler
%load_ext memory_profiler

%matplotlib inline

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import util
import models

from os.path import expanduser
from os import path
import logging

logging.basicConfig(level=logging.INFO)

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
from sklearn_evaluation import plot as sk_plot

import yass
from yass import read_config, preprocess, detect
from yass.augment import make
from yass.neuralnetwork import NeuralNetDetector, NeuralNetTriage, AutoEncoder 
from yass.explore import RecordingExplorer, SpikeTrainExplorer

import numpy as np
from scipy.io import loadmat

from dstools import plot

Using TensorFlow backend.


In [4]:
yass.set_config("../config/49-lab.yaml")
CONFIG = read_config()

In [5]:
path_to_data = path.expanduser('~/data')
path_to_ground_truth = path.join(path_to_data,
                                 'groundtruth_ej49_data1_set1.mat')
path_to_standarized = path.join(path_to_data, 'tmp', 'preprocess',
                               'standarized.bin')
path_to_geom = path.join(path_to_data, 'ej49_geometry1.txt')

path_to_here = path.expanduser('~/dev/private-yass/nnet')

In [6]:
# load ground truth
_ = loadmat(path_to_ground_truth)
gt = np.hstack([_['spt_gt'], _['L_gt']])

gt = gt[2:-1]
gt[:, 1] = gt[:, 1] - 1

# compensate alignment
gt[:, 0] = gt[:, 0] + 5

## Generating training data

In [7]:
n_isolated_spikes = 500
min_amplitude = 5
n_templates = np.max(gt[:,1]) + 1
chosen_templates = np.arange(n_templates)

In [8]:
(x_detect, y_detect,
 x_triage, y_triage,
 x_ae, y_ae) = make.training_data(CONFIG, gt, chosen_templates,
                                  min_amplitude, n_isolated_spikes,
                                  data_folder='/home/Edu/data/nnet/',
                                  multi_channel=True)

100%|██████████| 3/3 [00:08<00:00,  2.69s/it]


UnboundLocalError: local variable 'x_to_collide_all' referenced before assignment

In [None]:
x_triage.shape

In [None]:
print(f'Training set sizes:\n\tdetect:{x_detect.shape}\n\ttriage:{x_triage.shape}\n\tautoencoder:{x_ae.shape}')

In [None]:
# compute amplitude for positive and negative samples in the detect training set
x_detect_positive = x_detect[y_detect == 1]
x_detect_negative = x_detect[y_detect == 0]

x_triage_positive = x_triage[y_triage == 1]
x_triage_negative = x_triage[y_triage == 0]

In [None]:
plt.rcParams['figure.figsize'] = (10, 10)
plot.grid_from_array(x_triage_positive, axis=0, elements=10, sharey=True)

In [None]:
plt.rcParams['figure.figsize'] = (10, 10)
plot.grid_from_array(x_triage_negative, axis=0, elements=10, sharey=False)

## Manually generating collisions

In [8]:
import os
from yass.templates import preprocess
from yass.augment.util import (make_noisy, make_clean, make_collided,
                               make_misaligned, make_noise)

yass._enable_debug_mode()

collision_ratio = 1
multi_channel = True

data_folder = '/home/Edu/data/nnet/'
path_to_data = os.path.join(data_folder, 'preprocess', 'standarized.bin')

templates, templates_uncropped = preprocess(CONFIG, gt,
                                            path_to_data,
                                            chosen_templates)

_, _, n_neigh = templates.shape
K, _, n_channels = templates_uncropped.shape

# make training data set
R = CONFIG.spike_size
amps = np.max(np.abs(templates), axis=1)

# make clean augmented spikes
nk = int(np.ceil(n_isolated_spikes/K))
max_amp = np.max(amps) * 1.5
max_shift = 2*R

# make clean spikes
x_clean = make_clean(templates, min_amplitude, max_amp, nk)

INFO:yass.templates.preprocess:Getting templates...
INFO:yass.templates.util:Computing templates...
INFO:yass.batch.batch:Applying function yass.templates.util.compute_weighted_templates...
100%|██████████| 3/3 [00:08<00:00,  2.73s/it]
INFO:yass.batch.batch:yass.templates.util.compute_weighted_templates took 8.1987 seconds


In [9]:
x_clean.shape

(533, 61, 7)

In [13]:
# make collided spikes
x_collision, x_to_collide = make_collided(x_clean, collision_ratio, templates,
                                          R, multi_channel, n_neigh)

In [12]:
_.shape

(533, 61, 7)

In [None]:
plt.rcParams['figure.figsize'] = (10, 10)
plot.grid_from_array(x_clean, axis=0, elements=range(10), sharey=True)

In [None]:
plt.rcParams['figure.figsize'] = (10, 10)
plot.grid_from_array(x_collision, axis=0, elements=range(10), sharey=True)

In [None]:
plt.rcParams['figure.figsize'] = (10, 10)
plot.grid_from_array(x_to_collide, axis=0, elements=range(10), sharey=False)