***Generating simulated testing data***

In [1]:
import os
import numpy as np
import logging


from yass.augment.choose import choose_templates
from yass.augment.crop import crop_templates
from yass.augment.noise import noise_cov
from yass.templates.util import get_templates
from yass.util import load_yaml

def make_testing_data(CONFIG, data_length, ptp_att_std, spike_train,
                      chosen_templates, n_per_cluster, data_folder,
                      seed=None):
    """Simulate noisy test spikes from the templates of a sorted recording.

    Templates are computed from the standarized recording found in
    ``data_folder``, restricted to ``chosen_templates``, cropped, randomly
    rescaled and finally corrupted with noise coloured by the temporal and
    spatial structure returned by ``noise_cov``.

    Parameters
    ----------
    CONFIG
        Configuration object. The attributes read here are ``spike_size``,
        ``neigh_channels``, ``geom`` and ``recordings.n_channels``.
    data_length : int
        Multiplier applied to ``CONFIG.spike_size``; the product is the
        temporal size passed to ``get_templates`` and ``crop_templates``.
    ptp_att_std : float
        Standard deviation of the normal (mean 1) factor by which every
        simulated spike is scaled.
    spike_train
        Spike train forwarded unchanged to ``get_templates``.
    chosen_templates
        Index (e.g. a list of ints) used to select templates along the
        first axis of the transposed template array.
    n_per_cluster : int
        Number of simulated spikes generated for each selected template.
    data_folder : str
        Folder holding ``standarized.bin`` and ``standarized.yaml``.
    seed : int or None, optional
        When given, a dedicated ``numpy.random.RandomState(seed)`` draws the
        scaling factors and the noise so the output is reproducible. With
        the default ``None`` the global numpy random state is used.

    Returns
    -------
    x_noisy : ndarray, shape (n_per_cluster * K, T, C)
        Scaled templates with the noise already added (K is the number of
        selected templates; T and C are the last two dimensions of the
        cropped templates).
    noise : ndarray, same shape as ``x_noisy``
        The noise that was added.
    ids : ndarray of int, shape (n_per_cluster * K,)
        For every simulated spike, the position (0..K-1) of its template
        within the selected templates.
    ptp : ndarray, shape (K,)
        Peak-to-peak value of each cropped template at index 0 of its last
        axis.

    Raises
    ------
    ValueError
        If ``standarized.bin`` is missing from ``data_folder`` or if the
        template selection is empty.
    """
    logger = logging.getLogger(__name__)

    path_to_data = os.path.join(data_folder, 'standarized.bin')
    path_to_config = os.path.join(data_folder, 'standarized.yaml')

    # make sure standarized data already exists
    if not os.path.exists(path_to_data):
        raise ValueError('Standarized data does not exist in: {}, this is '
                         'needed to generate training data, run the '
                         'preprocesor first to generate it'
                         .format(path_to_data))

    PARAMS = load_yaml(path_to_config)

    # Use the global numpy random state unless a seed was requested.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Temporal size shared by template extraction and cropping.
    window = data_length * CONFIG.spike_size

    logger.info('Getting templates...')

    # get templates and move the last axis of the result to the front
    templates, _ = get_templates(spike_train, path_to_data, window)
    templates = np.transpose(templates, (2, 1, 0))

    logger.info('Got templates ndarray of shape: {}'.format(templates.shape))

    # keep only the templates requested by the caller
    templates = templates[chosen_templates]
    if templates.shape[0] == 0:
        raise ValueError("Coulndt find any good templates...")

    logger.info('Good looking templates of shape: {}'.format(templates.shape))

    # align and crop templates
    templates = crop_templates(templates, window,
                               CONFIG.neigh_channels, CONFIG.geom)

    # determine noise covariance structure
    spatial_SIG, temporal_SIG = noise_cov(path_to_data,
                                          PARAMS['dtype'],
                                          CONFIG.recordings.n_channels,
                                          PARAMS['data_order'],
                                          CONFIG.neigh_channels,
                                          CONFIG.geom,
                                          templates.shape[1])

    K, n_times, n_chan = templates.shape
    n_total = n_per_cluster * K

    ################
    # clean spikes #
    ################
    # Peak-to-peak of every template at index 0 of the last axis; cast to
    # float64 to match the dtype of the np.zeros buffer used previously.
    ptp = np.ptp(templates[:, :, 0], axis=1).astype(np.float64)

    x_clean = np.zeros((n_total, n_times, n_chan))
    ids = np.zeros(n_total, dtype=int)

    for k in range(K):
        rows = slice(k * n_per_cluster, (k + 1) * n_per_cluster)
        # One multiplicative factor per simulated spike, centred on 1.
        scale = rng.normal(1, ptp_att_std, n_per_cluster)
        scale = scale[:, np.newaxis, np.newaxis]
        x_clean[rows] = templates[k][np.newaxis, :, :] * scale
        ids[rows] = k

    #########
    # noise #
    #########
    noise = rng.normal(size=[n_total, n_times, n_chan])

    # Apply the temporal structure channel by channel.
    for c in range(n_chan):
        noise[:, :, c] = np.matmul(noise[:, :, c], temporal_SIG)

    # Apply the spatial structure once, after every channel has been
    # processed (this reshape used to sit inside the loop above).
    reshaped_noise = np.reshape(noise, (-1, n_chan))
    noise = np.reshape(np.matmul(reshaped_noise, spatial_SIG),
                       [n_total, n_times, n_chan])

    # The first returned array already contains the noise.
    x_noisy = x_clean + noise

    return x_noisy, noise, ids, ptp

  return f(*args, **kwds)


***Loading files (please have your config file ready)***

In [3]:
import os
import numpy as np
import tensorflow as tf
#import h5py
import progressbar
%matplotlib inline
import matplotlib.pyplot as plt
#import panda as pd
import pickle
import logging
import scipy.io as sio
import yass
from yass import read_config
from yass.augment import make_training_data, save_detect_network_params, save_triage_network_params, save_ae_network_params
from yass.augment import train_detector, train_ae, train_triage
from yass import preprocess
import matplotlib.pyplot as plt


# Point yass at the YAML configuration file for this recording.
# NOTE(review): the path is absolute and machine-specific; change it to your own file.
yass.set_config("/ssd/data/shenghao/retinal/configuration_retinal.yaml")
# CONFIG is passed to make_testing_data further down.
# Presumably read_config() returns the configuration set on the line above; confirm against the yass docs.
CONFIG = read_config()

***Load Spike Train***

To train the neural network, you need a recording together with a spike-sorting result. The result does not need to be perfect.
If you do not have a sorting result yet, you can run yass with the threshold detection option: in your configuration file, set spikes.detection = threshold.

spike_train is a matrix of size (number of spikes x 2). Each row represents an individual spike. The first column is the spike time, given as a temporal location (sample index) in the recording rather than in milliseconds or seconds. The second column is the spike ID.

In [5]:
import numpy as np
# load ground truth
# make spikeTrain
import scipy.io
# Read the ground-truth sorting result from the MATLAB file.
kk = scipy.io.loadmat(
    '/ssd/data/shenghao/retinal/groundtruth_ej49_data1_set1.mat')

# spt_gt: time of each spike, one entry per spike (10 is added to the stored values).
# L_gt: cluster index of each spike, one entry per spike (1 is subtracted from the stored values).
spt_gt = kk['spt_gt'] + 10
L_gt = kk['L_gt'] - 1

# Column 0 is the spike time, column 1 the cluster index.
spike_train = np.concatenate([spt_gt, L_gt], axis=1)
print(spike_train[:10])

[[ 17  39]
 [ 23  26]
 [ 61  28]
 [ 80  15]
 [126  10]
 [251  40]
 [278   6]
 [288  42]
 [468  36]
 [545  38]]


***Generating testing data***


In [102]:
#input parameters for x_test

# Location of the standarized recording used to build the templates.
data_folder = '/ssd/data/shenghao/retinal/tmp'

# Simulation settings.
data_length = 12
n_per_cluster = 200
ptp_att_std = 0.01

# should be your own number
chosen_templates = [
    0, 1, 2, 3, 4, 5, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19,
    22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
    38, 39, 41, 42, 43, 44, 45, 46, 47, 48,
]

# Build the simulated spikes.
x_clean, noise, ids, ptp = make_testing_data(
    CONFIG,
    data_length,
    ptp_att_std,
    spike_train,
    chosen_templates,
    n_per_cluster,
    data_folder,
)

# Store all four arrays in a single MATLAB file.
sio.savemat(
    'retinal_testing_may19.mat',
    mdict={'x_clean': x_clean, 'noise': noise, 'ids': ids, 'ptp': ptp},
)