In [None]:
#!pip install numpy==1.19

import os
import fnmatch
from scipy.signal import butter, filtfilt, fftconvolve,find_peaks
#import numpy as np
!pip install stumpy
import stumpy
import umap
import numpy as np
from sklearn.cluster import DBSCAN
from scipy.interpolate import interp1d

import pickle
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline


def butter_bandpass(lowcut, highcut, fs, order=3):
    """Design a Butterworth band-pass filter.

    Args:
        lowcut: lower pass-band edge in Hz.
        highcut: upper pass-band edge in Hz.
        fs: sampling rate in Hz.
        order: filter order passed to ``scipy.signal.butter`` (default 3).

    Returns:
        ``(b, a)`` numerator / denominator coefficients.
    """
    nyquist = fs / 2
    normalised_band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalised_band, btype='band')
def butter_bandpass_filter(data, lowcut, highcut, fs, order=3):
    """Zero-phase band-pass filter ``data`` along its last axis.

    The Butterworth coefficients from ``butter_bandpass`` are applied forward
    and backward with ``scipy.signal.filtfilt``, so no phase shift is added.
    """
    coefficients = butter_bandpass(lowcut, highcut, fs, order=order)
    return filtfilt(*coefficients, data)
def sliding_std(filtered_dat, sliding_win_size=20000):
    """Local (sliding-window) standard deviation of every channel.

    Args:
        filtered_dat: array of shape (channels, samples).
        sliding_win_size: window length in samples (default 20000, i.e. 1 s at
            the 20 kHz rate used in this notebook).

    Returns:
        Array of shape (samples, channels). Note that this layout is transposed
        relative to the input. Each entry is the RMS of the locally de-meaned
        signal over a window centred on that sample. The convolution is
        zero-padded, so values within half a window of either end are biased.
    """
    # Box-car averaging kernel; after the transpose, time runs along axis 0.
    kernel = np.ones((sliding_win_size, 1)) / sliding_win_size
    sig = filtered_dat.T
    sliding_means = fftconvolve(sig, kernel, mode='same')
    centred = sig - sliding_means
    sliding_vars = fftconvolve(centred * centred, kernel, mode='same')
    # FFT round-off can leave tiny negative "variances" (e.g. on flat stretches).
    # np.sqrt would turn them into NaN, and NaN thresholds never pass a height test.
    return np.sqrt(np.clip(sliding_vars, 0, None))
def get_peaks(filtered_dat, sliding_stds, peakwidth=None, n_std=5):
    """Detect spikes per channel as peaks of |signal| above ``n_std`` local stds.

    Both negative and positive deflections are found because detection runs on
    the absolute value.

    Args:
        filtered_dat: array of shape (channels, samples) of filtered data.
        sliding_stds: array of shape (samples, channels) of local standard
            deviations, laid out as returned by ``sliding_std``.
        peakwidth: minimum peak width in samples. It is measured at 30% of the
            peak's prominence below its top (``rel_height=0.3``) and is meant to
            reject few-sample-wide spurious discharges. ``None`` (default)
            disables the width criterion.
        n_std: detection threshold in multiples of the local std (default 5).

    Returns:
        List with one array of peak sample indices per channel.
    """
    peaks = []
    for i, channel in enumerate(filtered_dat):
        # Column i of the std array holds the local std of channel i.
        p, _ = find_peaks(np.abs(channel), height=n_std * sliding_stds[:, i],
                          width=peakwidth, rel_height=0.3)
        peaks.append(p)
    return peaks





  @numba.jit()
  @numba.jit()
  @numba.jit()


## The Analyser class below helps create a more compact notebook, especially when a lot of data is being handled.

### The analyser expects the data as a NumPy array. Load your recording file (.rhs, etc.) and save it with NumPy as `raw.npy`, with shape `(channels, samples)`, inside the data folder. The current analyser was custom-fit to our needs; please note that:

* The data is sampled at 20 kHz and band-pass filtered between 210 and 9500 Hz (3rd-order Butterworth applied forward and backward with `filtfilt`, i.e. acausal / zero-phase).
* The standard deviation is computed in a sliding 1 s (20000-sample) window; see `sliding_std` in the cell above.
* First, events larger than 5 local (within 1 s) standard deviations are collected and clustered.
    * Then, smaller events are fished out with template matching.
* Non-electrophysiological clusters are discarded by visual inspection, although heuristics could be built to automate this (peak half-width ratios, etc.).
* The waveforms extracted for clustering are 60 samples long: 30 samples before the peak maximum, the peak itself, and 29 after. The template used for matching is the central 40 samples of a cluster's median waveform.
* Clustering is done via 2-component UMAP (from the `umap-learn` package) followed by density-based clustering (DBSCAN).


In [None]:

class Analyser:
    """Spike-detection pipeline for one recording stored as ``<path>/raw.npy``.

    Typical use, with each step storing results as attributes and/or files in ``path``:

    1. ``Analyser(path)`` loads ``raw.npy`` (channels, samples). It band-pass filters
       the data (210-9500 Hz) and computes the sliding local std.
    2. ``get_peaks_and_get_waveforms`` finds events above 5 local stds and cuts a
       snippet around each.
    3. ``extract_shape_clusters`` embeds the snippets with 2-D UMAP, clusters them
       with DBSCAN and plots every cluster.
    4. ``match_cluster`` template-matches chosen cluster medians on every channel.
    5. ``save_spike_cutouts`` / ``produce_sample_detections`` export the matches.
    """

    # Sampling rate assumed throughout the pipeline [Hz].
    FS = 20000
    # Snippets span SNIP_HALF samples before the peak to SNIP_HALF - 1 after it.
    SNIP_HALF = 30

    def __init__(self, path):
        """Load ``<path>/raw.npy``, filter it and compute the local std.

        Sets ``data`` (raw), ``filtdata`` (band-passed, same shape) and ``std``
        (shape (samples, channels), see ``sliding_std``).
        """
        self.path = path
        print('Loading data')
        self.data = np.load(os.path.join(path, 'raw.npy'))
        print('Data loaded')
        print('Filtering and getting windowed std')
        self.filtdata = butter_bandpass_filter(self.data, lowcut=210, highcut=9500, fs=self.FS)
        self.std = sliding_std(self.filtdata)
        print('Done')
        self.is_clust = False

    def get_peaks_and_get_waveforms(self):
        """Detect large events on every channel and cut a snippet around each.

        An event is a peak of |filtdata| above 5x the local std (``self.std``) and
        at least 60 samples from a larger peak on the same channel. Events whose
        snippet would run past either end of the recording are dropped, so row i
        of ``all_chan_peaks`` always describes row i of ``snips``.

        Sets:
            all_chan_peaks: int array (n_events, 2) of (channel, sample index).
            snips: array (n_events, 2 * SNIP_HALF); the peak sits at column SNIP_HALF.
        """
        half = self.SNIP_HALF
        n_total = self.filtdata.shape[1]
        all_chan_peaks = []
        for idx in range(len(self.filtdata)):
            p, _ = find_peaks(np.abs(self.filtdata[idx]), height=5 * self.std[:, idx], distance=60)
            # Keep only peaks that have a complete snippet window.
            p = p[(p >= half) & (p + half <= n_total)]
            all_chan_peaks.append(np.column_stack((np.full(len(p), idx), p)))
        self.all_chan_peaks = np.concatenate(all_chan_peaks).astype(int)
        # Broadcast (n_events, 1) channel rows against (n_events, 2*half) sample
        # columns; this also yields shape (0, 2*half) when nothing was found.
        cols = self.all_chan_peaks[:, 1:2] + np.arange(-half, half)
        self.snips = self.filtdata[self.all_chan_peaks[:, 0:1], cols]

    def extract_shape_clusters(self, show_clusters=True, close_plots=False):
        """Cluster the detected waveforms and optionally plot every cluster.

        The snippets are embedded in 2-D with UMAP and clustered with DBSCAN
        (``eps=0.5``). The fitted DBSCAN object is kept in ``self.clust``, where
        label -1 marks noise. The clustering runs only once (guarded by
        ``is_clust``), so re-detecting waveforms afterwards does not refresh it.

        Args:
            show_clusters: plot the median +/- std waveform of each cluster and save
                the figure to ``<path>/clusters.svg``.
            close_plots: close the current figure afterwards.
        """
        if not self.is_clust:
            reducer = umap.UMAP(n_components=2)
            print('Dimensionality reduction on detected waveforms')
            projections = reducer.fit_transform(self.snips)
            self.clust = DBSCAN(eps=0.5)
            self.clust.fit(projections)
            self.is_clust = True
        if show_clusters:
            labs = self.clust.labels_
            n_clusters = np.max(labs) + 1
            if n_clusters < 1:
                print('No clusters found: every waveform was labelled as noise')
            else:
                # squeeze=False keeps a 2-D Axes array even for a single cluster.
                fig, ax = plt.subplots(1, n_clusters, figsize=(n_clusters * 3, 3),
                                       sharex=True, squeeze=False)
                ax = ax[0]
                l = self.snips.shape[1]
                x = np.linspace(-l // 2, l // 2, l) / (self.FS / 1000)  # ms
                for i in range(n_clusters):
                    members = self.snips[labs == i]
                    avg = np.median(members, axis=0)
                    std = np.std(members, axis=0)
                    ax[i].fill_between(x, y1=avg - std, y2=avg + std, alpha=0.3, color='b')
                    ax[i].plot(x, avg, color='r', linewidth=0.5)
                    ax[i].set_title('Cluster ' + str(i) + ', ' + str(np.round(100 * np.sum(labs == i) / len(labs), 1)) + '%')
                ax[0].set_ylabel('Voltage [uV]')
                ax[0].set_xlabel('Time [ms]')
                fig.savefig(os.path.join(self.path, 'clusters.svg'), bbox_inches='tight', pad_inches=0)
        if close_plots:
            plt.close()

    def match_cluster(self, cluster_indices):
        """Template-match the median waveform of each chosen cluster on every channel.

        For each label in ``cluster_indices`` (a list, e.g. ``[2]``), the template
        is the central 40 samples of the cluster's median snippet. It is slid along
        every channel with ``stumpy.mass``. Local minima of the distance profile
        below the threshold, at least 60 samples apart, become detections.

        Sets ``channel_spindx``, an int array (n_detections, 2) of (channel, start
        index of the matched window), and saves it to ``<path>/channel_spindx.npy``.
        Detections of different clusters are concatenated, so a spike matching
        several templates appears once per template.

        Raises:
            ValueError: if a requested cluster label has no member waveforms.
        """
        half = self.SNIP_HALF
        labels = self.clust.labels_
        channel_spindx = []
        for c in cluster_indices:
            members = self.snips[labels == c]
            if len(members) == 0:
                raise ValueError('Cluster ' + str(c) + ' has no member waveforms')
            # Central 40 samples of the median snippet (peak at template offset 20).
            waveform = np.median(members, axis=0)[half - 20:half + 20]
            wave_length = len(waveform)
            match = np.asarray([stumpy.mass(waveform, filt) for filt in self.filtdata])
            # The z-normalised distance is at most 2*sqrt(m), so 0.3 is a
            # length-independent fraction of that maximum.
            threshold = 0.3 * (2 * np.sqrt(wave_length))
            for chan in range(len(match)):
                # Peak-pick the inverted profile so each crossing is counted once.
                p, _ = find_peaks(-match[chan] + threshold, height=0, distance=60)
                channel_spindx.append(np.column_stack((np.full(len(p), chan), p)))
        if channel_spindx:
            self.channel_spindx = np.concatenate(channel_spindx).astype(int)
        else:
            self.channel_spindx = np.empty((0, 2), dtype=int)
        np.save(os.path.join(self.path, 'channel_spindx.npy'), self.channel_spindx)

    def save_spike_cutouts(self):
        """Pickle one cut-out per detection to ``<path>/spike_shapes.pickle``.

        Each cut-out spans samples idx-5 to idx+54 of the filtered channel. The
        stored dict maps the channel number (as a string) to a list of 1-D arrays.
        """
        self.spikes_dict = {}
        ch_sp = self.channel_spindx
        for chan in np.unique(ch_sp[:, 0]):
            a_chan = ch_sp[ch_sp[:, 0] == chan].astype(int)
            # NOTE(review): a detection within 5 samples of the start gives a negative
            # slice start (typically an empty array); near the end the slice is shorter.
            sigs = [self.filtdata[el[0], el[1] - 5:el[1] + 55] for el in a_chan]
            self.spikes_dict[str(int(chan))] = sigs
        with open(os.path.join(self.path, 'spike_shapes.pickle'), 'wb') as handle:
            pickle.dump(self.spikes_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def produce_sample_detections(self, n_samples, window_size):
        """Save ``n_samples`` randomly chosen detections as SVG plots for checking.

        Each plot shows up to ``window_size`` samples either side of the detection
        index, clipped at the recording edges. The 40-sample matched window that
        starts at the detection is highlighted in red. Files go to
        ``<path>/sample_traces/<n>.svg``.
        """
        if len(self.channel_spindx) == 0:
            print('No detections to plot')
            return
        out_dir = os.path.join(self.path, 'sample_traces')
        os.makedirs(out_dir, exist_ok=True)
        n_total = self.filtdata.shape[1]
        for n in range(n_samples):
            channel, idx = self.channel_spindx[np.random.randint(len(self.channel_spindx))]
            # Clip to the recording: otherwise detections near either end give a
            # trace shorter than its time axis and ax.plot raises.
            start = max(idx - window_size, 0)
            stop = min(idx + window_size, n_total)
            trace = self.filtdata[channel, start:stop]
            x = 1000 * (np.arange(start, stop) - idx) / self.FS  # ms from detection
            matched = slice(idx - start, idx - start + 40)
            fig, ax = plt.subplots(1, 1, figsize=(10, 3))
            ax.plot(x, trace)
            ax.plot(x[matched], trace[matched], color='r')
            ax.set_xlabel('Time [ms]')
            ax.set_ylabel('Voltage [uV]')
            ax.set_title('Channel:' + str(channel) + ', ' + 'Time stamp : ' + str(idx))
            fig.savefig(os.path.join(out_dir, str(n) + '.svg'), bbox_inches='tight', pad_inches=0.3)
            plt.close(fig)

## Example workflow

In [None]:
file_path_to_npy_data = "./data/"  # folder holding raw.npy; all outputs are written here too

In [None]:
# Load + filter + local std (constructor), then detect large events and cluster
# their waveforms; the cluster figure is saved as clusters.svg in the data folder.
experiment = Analyser(path=file_path_to_npy_data)
experiment.get_peaks_and_get_waveforms()
experiment.extract_shape_clusters(show_clusters=True, close_plots=False)

### At this point we have extracted the large events. In our recording, clusters 0 and 1 are noise and cluster 2 is electrophysiological, so we use only the good cluster(s) as templates for matching. Inspect the cluster plot to choose the indices for your own data.

In [None]:
# Template-match only the clusters judged electrophysiological (here just cluster 2);
# pass e.g. [0, 1, 2] to use several templates.
experiment.match_cluster(cluster_indices=[2])
experiment.save_spike_cutouts()
# 20 random detections, each plotted with 400 samples on either side.
experiment.produce_sample_detections(n_samples=20, window_size=400)


## At this point we have all the time stamps, and we have produced:
* verification traces in `data/sample_traces/`,
* the spike waveforms, saved in `spike_shapes.pickle`,
* `channel_spindx.npy`, which stores the (channel, sample index) pair of every template-match detection.

### Getting the firing rates

In [None]:
chan_spindx = np.load(os.path.join(file_path_to_npy_data, 'channel_spindx.npy'))
# or, if it's still in memory: chan_spindx = experiment.channel_spindx
# Number of detections for every channel index 0..max(channel). np.bincount includes
# the highest channel, which the previous range(np.max(...)) loop left out.
spike_total = np.bincount(chan_spindx[:, 0]).tolist()

In [None]:
# Recording length in samples (time axis of the filtered data) and in seconds.
file_duration = experiment.filtdata.shape[1]
file_duration_in_s = file_duration / 20000  # 20 kHz sampling
# Mean firing rate of each channel over the whole recording [Hz].
fire_rate_hz = [n_spikes / file_duration_in_s for n_spikes in spike_total]


In [None]:
# Gaussian smoothing kernel spanning 5 s (seconds x sampling rate), normalised to
# unit sum so the smoothed spike train is expressed in spikes per sample.
width = 5 * 20000
t = np.linspace(-width / 2, width / 2, width)
gauss_kern = np.exp(-t ** 2 / (0.1 * width ** 2))
gauss_kern /= np.sum(gauss_kern)

density = []
resolution = int(file_duration_in_s)  # how many pixels to show the density plot on
# Load the detections once, not once per channel.
ch_sp = np.load(os.path.join(file_path_to_npy_data, 'channel_spindx.npy'))
for chan in range(len(fire_rate_hz)):
    spike_train = np.zeros(int(file_duration))
    spike_train[ch_sp[ch_sp[:, 0] == chan, 1]] = 1  # 1 where there's a spike
    smoothed = fftconvolve(spike_train, gauss_kern, mode='same')
    x = np.arange(len(smoothed))
    fy = interp1d(x, smoothed)
    # Linearly resample the smoothed train onto `resolution` evenly spaced points.
    density.append(fy(np.linspace(0, x[-1], resolution)))


In [None]:
from scipy.stats import iqr
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
import matplotlib.cm as cm

all_density = np.concatenate(density)
mi = np.min(all_density)
# iqr over rng=(0, q) is the q-th percentile minus the minimum, so mi + iqr(...) is the q-th percentile.
low = mi + iqr(all_density, rng=(0, 0.5))   # 0.5th percentile (not used below)
high = mi + iqr(all_density, rng=(0, 99.5))  # 99.5th percentile: colour-scale ceiling

# pyplot.get_cmap works on old and new Matplotlib; matplotlib.cm.get_cmap was removed in 3.9.
bluz = plt.get_cmap('Blues', 1024)
newcolors = bluz(np.linspace(0, 1, 1024))
# Flatten the low end of the map so near-zero rates share one background colour.
newcolors[10:50, :] = newcolors[50]
newcolors[:10, :] = newcolors[0]
newcmp = ListedColormap(newcolors)

# The unit-sum kernel gives spikes per sample; len(gauss_kern) / 5 is samples per
# second (the kernel spans 5 s), so this factor converts to Hz.
hz_scale = len(gauss_kern) / 5
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
im = ax.imshow(np.asarray(density) * hz_scale, aspect='auto', interpolation=None,
               cmap=newcmp, vmin=0.1, vmax=high * hz_scale)
fig.colorbar(im, ax=ax)
ax.set_ylabel('Channel index')
ax.set_xlabel('Time [s]')
ax.set_title('Local firing rate [Hz]')