This is necessary for having figures directly in the notebook.

In [1]:
# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline
# Load the autoreload extension and use mode 2 (reload all modules before
# running each cell), so edits to imported code are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

Import necessary modules

In [2]:
import numpy as np
import pandas as pd
from matplotlib import pyplot
import seaborn as sns

import tridesclous as tdc

Download the Locust dataset from Zenodo.

Get the first trial as a standard numpy array with shape (nb_sample, nb_channel).

In [14]:
from urllib.request import urlretrieve
import os
import h5py

name = 'locust20010201.hdf5'
distantfile = 'https://zenodo.org/record/21589/files/'+name
localfile = name
if not os.path.exists(localfile):
    # Download under a temporary name and rename only once the transfer has
    # finished: an interrupted download must not leave a truncated file that
    # the existence check above would accept on the next run.
    partfile = localfile + '.part'
    urlretrieve(distantfile, partfile)
    os.replace(partfile, localfile)

# Create a numpy array from the first trial.
ch_names = ['ch09','ch11','ch13','ch16']
# The context manager closes the HDF5 file once the data is in memory.
# `[...]` reads each dataset completely into a numpy array, so nothing refers
# to the file after the block. The comprehension variable is `ch_name` so it
# does not shadow the file name stored in `name`.
with h5py.File(localfile, 'r') as hdf:
    trial = hdf['Continuous_1']['trial_01']
    # One row per channel after stacking; transpose to (nb_sample, nb_channel).
    numpy_sigs = np.array([trial[ch_name][...] for ch_name in ch_names]).transpose()
#numpy_sigs = (numpy_sigs.astype('float32') - 2**15.) / 2**15


# Signals = pandas.DataFrame
Each segment of data is a pandas.DataFrame:
  * the index is the time in seconds.
  * the columns are the channel names.



In [15]:
# Build the time axis of the recording, then wrap the raw array in a
# DataFrame indexed by time with one column per channel.
sampling_rate = 15000.  # in Hz
t_start = 0.  # time of the first sample, in seconds
nb_sample = numpy_sigs.shape[0]
times = np.arange(nb_sample, dtype='float64') / sampling_rate + t_start

signals = pd.DataFrame(data=numpy_sigs, index=times, columns=ch_names)





So we can access the data by time or by sample position with DataFrame.loc and DataFrame.iloc.
See http://pandas.pydata.org/pandas-docs/stable/indexing.html

In [16]:
# Two ways to cut a window out of the signals. Each result gets its own name
# so the positional example is not immediately overwritten by the time-based
# one.
# Positional slicing (.iloc) excludes the end position.
chunk_by_sample = signals.iloc[45225:45450]  # slicing by sample
# Label slicing (.loc) on the time index includes both end points.
chunk = signals.loc[3.015:3.030]  # slicing by time

chunk

Signals is a pure pandas.DataFrame, so we can use all pandas facilities:

In [20]:
# Summary statistics (count, mean, std, quantiles, ...) for each column,
# i.e. for each channel of the signals DataFrame.
signals.describe()


In [21]:
# tdc.median_mad returns a pair, unpacked here as (med, mad) — presumably the
# per-channel median and median absolute deviation; confirm in tridesclous.
med, mad = tdc.median_mad(signals)
# Last expression of the cell: display `mad`.
mad

Plotting is also easy.

In [22]:
# Plot every column of the selected chunk against its index (the time).
chunk.plot()

tridesclous has some functions that work directly on this kind of (signals) DataFrame:
  * normalize_signals
  * derivative_signals
  * rectify_signals


In [23]:
# Three derived views of the raw signals: normalized, derivative, and the
# normalized signals rectified with a threshold of -4.
normed_sigs = tdc.normalize_signals(signals)
deriv_sigs = tdc.derivative_signals(signals)
retified_sigs = tdc.rectify_signals(normed_sigs, threshold=-4)

# Show the same time window of each view, one panel per view.
t_from, t_to = 3.14, 3.22
fig, axs = pyplot.subplots(ncols=3, figsize=(15, 8))
for ax, sigs in zip(axs, (normed_sigs, deriv_sigs, retified_sigs)):
    sigs[t_from:t_to].plot(ax=ax)
# Give the rectified panel the same y-range as the normalized panel.
axs[2].set_ylim(axs[0].get_ylim())


# filter

The SignalFilter class helps to:
  * apply a high pass filter
  * smooth with a boxcar

In [24]:
# Named `signal_filter` rather than `filter` so that the Python builtin
# `filter` is not shadowed for the rest of the notebook.
# NOTE(review): highpass_freq = 0. presumably disables the high-pass stage so
# only the boxcar smoothing (box_smooth = 5) is applied — confirm in tridesclous.
signal_filter = tdc.SignalFilter(signals, highpass_freq = 0., box_smooth = 5)
filtered_sigs = signal_filter.get_filtered_data()

# Peak detection
The class  PeakDetector offers facilities:
  * to detect peaks.

This returns peaks_pos as sample positions (integer indices into the signals).

Getting the peak times is easy:
peaks_time = signals.index[peaks_pos]



In [32]:
def plot_peaks_in_window(peaks_index, sigs_window, rectified_window,
                         t_from, t_to, ax_sigs, ax_rect, title):
    """Plot one column of the comparison figure.

    Draws the windowed signals on `ax_sigs` and the windowed rectified sum on
    `ax_rect`, then overlays black markers at the detected peaks that fall
    inside [t_from, t_to].

    Parameters
    ----------
    peaks_index : pandas.Index
        Times of all detected peaks (labels taken from the signals index).
    sigs_window : pandas.DataFrame
        Signals already cut to [t_from, t_to].
    rectified_window : pandas.Series
        Rectified sum already cut to [t_from, t_to].
    t_from, t_to : float
        Window limits, both inclusive (same convention as label slicing).
    ax_sigs, ax_rect : matplotlib axes
        Target axes for the signals and for the rectified sum.
    title : str
        Title set on `ax_sigs`.
    """
    sigs_window.plot(ax = ax_sigs)
    rectified_window.plot(ax = ax_rect)
    ax_sigs.set_title(title)

    # Keep only the peaks inside the window BEFORE the label lookup: asking
    # .loc for labels that are absent from the windowed data raises KeyError.
    in_window = peaks_index[(peaks_index >= t_from) & (peaks_index <= t_to)]
    if len(in_window) == 0:
        return
    sigs_window.loc[in_window].plot(marker = 'o', linestyle = 'None', ax = ax_sigs, color = 'k')
    rectified_window.loc[in_window].plot(marker = 'o', linestyle = 'None', ax = ax_rect, color = 'k')


# The two n_span values being compared. They feed both the detection calls and
# the panel titles, so the titles cannot drift away from what was computed.
n_span_bad, n_span_ok = 2, 5

peakdetector = tdc.PeakDetector(filtered_sigs)
peaks_pos_bad = peakdetector.detect_peaks(threshold=-5, peak_sign = '-', n_span = n_span_bad)
peaks_index_bad = signals.index[peaks_pos_bad]

# The "ok" detection is run last on purpose: later cells receive `peakdetector`
# itself, which presumably keeps the result of its most recent call.
peaks_pos_ok = peakdetector.detect_peaks(threshold=-5, peak_sign = '-', n_span = n_span_ok)
peaks_index_ok = signals.index[peaks_pos_ok]

fig, axs = pyplot.subplots(nrows = 2, ncols = 2, figsize = (15, 8))

# Short window used for the visual comparison.
t1, t2 = 3.163, 3.166
normed_chunk = normed_sigs[t1:t2]
chunk_rectified = peakdetector.rectified_sigs.sum(axis=1)[t1:t2]

# Left column: bad setting. Right column: OK setting.
plot_peaks_in_window(peaks_index_bad, normed_chunk, chunk_rectified, t1, t2,
                     axs[0,0], axs[1,0], 'n_span=%d' % n_span_bad)
plot_peaks_in_window(peaks_index_ok, normed_chunk, chunk_rectified, t1, t2,
                     axs[0,1], axs[1,1], 'n_span=%d' % n_span_ok)

# Same y-range on every panel so the columns can be compared directly.
for ax in axs.flatten():
    ax.set_ylim(-30, 10)

# Extract waveform

The class WaveformExtractor offers facilities to:
    * extract waveforms
    * extract noise (=fake waveforms in between peaks)
    * keep or exclude good events
    * find good limits for the cut.
   
The waveforms object is also a pandas.DataFrame with:
   * index is peak_pos
   * columns is a MultiIndex (channels, samples) where samples goes from n_left to n_right [-10, -9, ..., 0, 1, ...,  30]. 0 is the peak.



In [33]:
# Start with a deliberately large sweep around each peak.
n_left, n_right = -30, 50
waveformextractor = tdc.WaveformExtractor(peakdetector, n_left=n_left, n_right=n_right)
med, mad = tdc.median_mad(waveformextractor.long_waveforms)

# Shared y-limits so the waveform and noise curves are drawn on the same scale.
med_ylim, mad_ylim = (-6, 2), (0, 5)
fig, axs = pyplot.subplots(nrows=2)
ax_med, ax_mad = axs
med.plot(ax=ax_med, ylim=med_ylim)
mad.plot(ax=ax_mad, ylim=mad_ylim)

# Make some noise snippets over the same sweep and draw them in red to compare.
noise = waveformextractor.extract_noise(n_left, n_right, size=1000, safety_factor=2)
med_noise, mad_noise = tdc.median_mad(noise)
med_noise.plot(ax=ax_med, ylim=med_ylim, color='r')
mad_noise.plot(ax=ax_mad, ylim=mad_ylim, color='r')

In [34]:
# Find the good limits of the waveform cut, using a MAD threshold of 1.1.
limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold = 1.1)
print(limit_left, limit_right)
waveformextractor.plot_good_limit()
# Waveforms adjusted by the extractor — presumably cut to the limits found
# just above; confirm in tridesclous. Used as the clustering input.
short_wf = waveformextractor.get_ajusted_waveforms()

# Projection and Clustering

The class Clustering offers facilities to:
  * project waveforms with: PCA, ...
  * cluster them with kmeans, EM+GMM


In [37]:
# Work on the shortened waveforms (see good limits).
clustering = tdc.Clustering(short_wf)

# Do a PCA and keep the first 4 components.
# The bare `features` expression that used to follow this call was dropped:
# only the last expression of a cell is rendered, so it displayed nothing.
# To inspect the projection, evaluate `features.head()` in a cell of its own.
features = clustering.project(method = 'pca', n_components = 4)

clustering.plot_explained_variance_ratio()
clustering.plot_waveform_variance()

clustering.plot_projection(plot_density = False)

In [38]:
# Try to cluster the projected waveforms.
n_clusters = 7
labels = clustering.find_clusters(n_clusters)
# Features and labels side by side in a single frame.
df = pd.concat(objs=[features, labels], axis='columns')

clustering.plot_projection(plot_density=False)

In [39]:
# Build the catalogue from the clustering; it is passed to the Peeler later.
catalogue = clustering.construct_catalogue()
# Plot the catalogue twice: once with sameax = True, once with sameax = False.
clustering.plot_catalogue(sameax = True)
clustering.plot_catalogue(sameax = False)


# interactive windows

This works only on localhost when PyQt4 + pyqtgraph are installed.

Do not forget the %gui qt4 magic.

In [40]:
# Enable IPython's Qt4 event loop integration so the window stays responsive
# while the kernel is idle.
%gui qt4
import pyqtgraph as pg
# Create the Qt application needed before any window can be shown.
app = pg.mkQApp()
# Build the interactive catalogue window from the three processing objects.
win = tdc.CatalogueWindow.from_classes(peakdetector, waveformextractor, clustering)
win.show()

# Peeler
The Peeler class helps to:
   * estimate jitter
   * predict spike trains
   * subtract and get residuals


In [41]:
# Use a dedicated name for the Peeler input instead of rebinding `signals`:
# overwriting the raw DataFrame would make earlier cells operate on normalized
# data when they are re-run.
peeler_sigs = peakdetector.normed_sigs
peeler = tdc.Peeler(peeler_sigs, catalogue,  limit_left, limit_right,
                        threshold=-4, peak_sign = '-', n_span = 5)

#Peel at level=0
prediction0, residuals0 = peeler.peel()
# Top panel: prediction. Bottom panel: residuals.
fig, axs = pyplot.subplots(nrows = 2)
axs[0].plot(prediction0)
axs[1].plot(residuals0)



In [43]:
# Peel at level=1.
# NOTE(review): this is the same peel() call as for level 0, so each call
# presumably advances the peeler by one level and re-running this cell would
# peel further — confirm in tridesclous.
prediction1, residuals1 = peeler.peel()
fig, axs = pyplot.subplots(nrows=2)
ax_prediction, ax_residuals = axs
ax_prediction.plot(prediction1)
ax_residuals.plot(residuals1)


In [44]:
# Plot the spike trains held by the peeler.
peeler.plot_spiketrains()