In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging
import os

import numpy as np

from yass.neuralnet import NeuralNetDetector
from yass.config import Config

from neural_clustering.explore import (SpikeTrainExplorer,
                                       RecordingExplorer)
from neural_clustering import config


import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'neural_clustering.explorer'

In [None]:
# Global plotting defaults for the notebook: the ggplot theme and a square 10x10-inch figure.
plt.style.use('ggplot')
plt.rc('figure', figsize=(10, 10))

In [None]:
# Only surface ERROR-level (and more severe) log records in cell output.
logging.basicConfig(level='ERROR')

# Loading configuration files and YASS output

YASS is a Python package for spike sorting that is being developed by Peter Lee (PhD, Statistics department) and me: https://github.com/paninski-lab/yass

Someone in the lab implemented a truncated Dirichlet process mixture model (DPMM) in NumPy. Since the code is hard to debug and its author is the only person who understands it, I want to see whether we can use Edward instead. That would let us iterate quickly and prototype new models easily, without having to write a custom inference algorithm every time.


In [None]:
# Load the two configuration files this notebook relies on:
#   * cfg_yass - YASS settings, read through yass.config.Config
#   * cfg      - this project's settings, read through neural_clustering.config
cfg_yass = Config.from_yaml(
    '../yass_config/local_7ch.yaml'
)
cfg = config.load(
    '../config.yaml'
)

In [None]:
# Load the arrays YASS wrote to <cfg['root']>/yass/<name>.npy, unpacked in the order listed.
# NOTE(review): spike_times / clear_index are indexed per entry further down; if YASS
# saved them as ragged (object-dtype) arrays, NumPy >= 1.16.3 needs
# np.load(..., allow_pickle=True) to read them - confirm against the installed NumPy.
files = ['score', 'clear_index', 'spike_times', 'spike_train', 'spike_left', 'templates']

_yass_paths = [os.path.join(cfg['root'], 'yass/{}.npy'.format(f)) for f in files]
score, clear_index, spike_times, spike_train, spike_left, templates = map(np.load, _yass_paths)

# Loading raw recordings, geometry file and projection matrix

In [None]:
# Input files for the explorers, all located under the YASS root directory:
#   * raw recordings
#   * standardized recordings (raw recordings + filtering + standardization)
#   * geometry file (position of every electrode)
path_to_raw_recordings, path_to_recordings, path_to_geometry = (
    os.path.join(cfg_yass.root, name)
    for name in ('7ch.bin', 'tmp/standarized.bin', cfg_yass.geomFile)
)

In [None]:
# Projection matrix used later to reduce waveform dimensionality.
# NOTE(review): load_w_ae presumably returns the detector's autoencoder weights - confirm in YASS.
_detector = NeuralNetDetector(cfg_yass)
proj = _detector.load_w_ae()

# Initialize explorers

In [None]:
# Initialize the explorers: objects that implement plotting helpers for the YASS output.


def _recording_explorer(path, dtype):
    """Build a RecordingExplorer over `path` using this notebook's geometry and YASS settings.

    Parameters
    ----------
    path : str
        Binary recordings file to read.
    dtype : str
        NumPy dtype name of the samples stored in ``path``.

    Returns
    -------
    RecordingExplorer
        Explorer sharing the geometry file, window size, channel count and
        neighbor radius taken from ``cfg_yass``.
    """
    return RecordingExplorer(path,
                             path_to_geometry,
                             dtype=dtype,
                             window_size=cfg_yass.spikeSize,
                             n_channels=cfg_yass.nChan,
                             neighbor_radius=cfg_yass.spatialRadius)


# Standardized recordings are read as float64, raw recordings as int16.
explorer_rec = _recording_explorer(path_to_recordings, dtype='float64')
explorer_raw = _recording_explorer(path_to_raw_recordings, dtype='int16')

# Explorer for the sorted output (templates + spike train), backed by the
# standardized recordings and the projection matrix loaded above.
explorer_train = SpikeTrainExplorer(templates,
                                    spike_train,
                                    explorer_rec,
                                    proj)

In [None]:
# Sanity check on the raw recordings: number of observations and number of channels.
_raw_shape = explorer_raw.data.shape
print('Observations: {}. Channels: {}'.format(_raw_shape[0], _raw_shape[1]))

# Raw recordings

In [None]:
# Very large canvas so every channel's trace stays readable; plot the window
# from_time=4500 .. to_time=5000 (units as interpreted by plot_series).
plt.rc('figure', figsize=(60, 60))
explorer_raw.plot_series(to_time=5000, from_time=4500)

## Filtered + Standardized Recordings

In [None]:
# Same time window as the raw-recordings plot, but on the filtered + standardized data.
# Set the figure size explicitly instead of inheriting the 60x60 value left over from
# the previous cell, so this cell renders the same when re-run on its own.
plt.rcParams['figure.figsize'] = (60, 60)
explorer_rec.plot_series(from_time=4500, to_time=5000)

# Geometry plot

In [None]:
# Back to a 10x10-inch canvas for the electrode-position plot.
plt.rc('figure', figsize=(10, 10))
explorer_rec.plot_geometry()

# Training data

## Load clear spike times in channel 0

In [None]:
# Clear spikes for the first entry of the YASS output (treated as channel 0 here).
# NOTE(review): column 0 of spike_times[0] is presumably the spike time - confirm with YASS.
clear_indexes = clear_index[0]
_times_first_entry = spike_times[0]
clear_spikes = _times_first_entry[clear_indexes, 0]

In [None]:
# First column of every entry in spike_times, stacked into one array.
_stacked_spike_times = np.vstack(spike_times)
all_spike_times = _stacked_spike_times[:, 0]

_n_clear = clear_spikes.shape[0]
print('Detected {} clear spikes'.format(_n_clear))

In [None]:
# Subtract cfg_yass.BUFF from the clear spike times. The author reports a bug in the
# latest version of YASS that shifts spike times; this compensates for it.
# The shifted times are recomputed from the loaded YASS output (instead of
# `clear_spikes = clear_spikes - ...`) so re-running this cell never shifts twice.
# TODO: the author also noted errors in all_spike_times (previously trimmed with
# np.sort(all_spike_times)[:-6]); still to be checked.
clear_spikes = spike_times[0][clear_indexes, 0] - cfg_yass.BUFF

In [None]:
# Preview the first shifted clear spike times rather than dumping the whole array.
clear_spikes[:10]

# Visualizing a detected spike

In [None]:
# Waveform around the first clear spike on every channel. Use the configured channel
# count (the same value the explorer was built with) instead of a hard-coded 7, so the
# cell keeps working with other YASS configurations.
plt.rcParams['figure.figsize'] = (10, 15)
t = clear_spikes[0]
explorer_rec.plot_waveform(time=t, channels=range(cfg_yass.nChan))

## Load waveforms around spike times

In [None]:
# Read the standardized-recording waveforms around every clear spike time.
waveforms = explorer_rec.read_waveforms(times=clear_spikes)
_waveform_dims = waveforms.shape
print('Training set dimensions: {}'.format(_waveform_dims))

## Reduce the waveforms' temporal dimensionality from 31 to 3 and flatten the data

In [None]:
# Project the waveforms to a lower temporal dimension and flatten them into a 2-D training set.
# NOTE(review): this calls a private SpikeTrainExplorer method; it may change without notice.
waveforms_reduced = explorer_train._reduce_dimension(waveforms,
                                                     flatten=True)
_reduced_dims = waveforms_reduced.shape
print('Training set dimensions: {}'.format(_reduced_dims))

# Save training data

In [None]:
# Persist the reduced, flattened waveforms to <cfg['root']>/training.npy.
output_path = os.path.join(cfg['root'], 'training.npy')
np.save(file=output_path, arr=waveforms_reduced)
print(f'Saved training data in {output_path}')