In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
from hs2 import herdingspikes
from probe import NeuralProbe, NeuroPixel, BioCam
import numpy as np

%matplotlib inline

In [None]:
# Location of the raw BioCam recording (.brw file).
# NOTE(review): this is an absolute, machine-specific path -- point it at a
# local copy of the recording before running the notebook.
data_path = '/disk/scratch/mhennig/P29_16_07_14/raw/P29_16_05_14_retina02_left_stim3_fullarray_fullfieldHDF5.brw'

# Build the probe object for this recording and hand it to the herdingspikes
# object `H`, which every later cell uses for detection, clustering and plotting.
Probe = BioCam(data_path)
H = herdingspikes(Probe)

In [None]:
# Probe.show(figwidth=12, show_neighbors=[-20])

# Detect and localise spikes

In [None]:
# Detection parameters.
# The notes below refer to a function called detectData with the signature
#   detectData(data, neighbours, spikefilename, shapefilename, channels, sfd, thres,
#              maa=None, maxsl=None, minsl=None, ahpthr=None, tpre=1, tpost=2)
# They presumably also describe the keywords of the same name that are passed
# to H.DetectFromRaw further down -- TODO confirm.
#   MinAvgAmp (maa):    minimal average amplitude of the peak (in units of Qd)
#   MaxSl     (maxsl):  dead time in frames after the peak, used for further testing
#   MinSl     (minsl):  length considered for determining the average spike amplitude
#   AHPthr    (ahpthr): the signal should go below this threshold within MaxSl-MinSl frames


# These four values are passed positionally to H.DetectFromRaw and
# cutout_start is reused by H.PlotTracesChannels.
to_localize = True
# NOTE(review): the trailing "#10" looks like a previously tried value.
cutout_start = 6#10
cutout_end = 30
threshold = 16

In [None]:
# IPython introspection: show the help for DetectFromRaw (development aid only;
# this cell can be dropped from a finished notebook).
H.DetectFromRaw?

In [None]:
# Run spike detection on the raw recording with the parameters chosen above.
H.DetectFromRaw(
    to_localize,
    cutout_start,
    cutout_end,
    threshold,
    maa=0,
    maxsl=12,
    minsl=3,
    ahpthr=0.,
    out_file_name="ProcessedSpikes_biocam",
)

# Alternative: load a previously saved detection result instead of detecting again.
# H.LoadDetected("ProcessedSpikes_biocam.bin")

# Report how many events ended up in H.spikes.
print('{} spikes detected'.format(H.spikes.shape[0]))

In [None]:
# Plot example traces on a wide figure.
# NOTE(review): the meaning of the first argument (1022) is not visible from
# this notebook -- presumably an event or channel index picked as an example;
# confirm against PlotTracesChannels. cutout_start is the value used for detection.
plt.figure(figsize=(10, 5))
H.PlotTracesChannels(1022, cutout_start=cutout_start)

In [None]:
# Overview of all detected events on a square figure, with both axes
# restricted to the range 0..64.
plt.figure(figsize=(10, 10))
H.PlotAll(invert=True, s=1, alpha=0.05)
plt.xlim(0, 64)
plt.ylim(0, 64)

# Cluster with DBSCAN

In [None]:
%%time
from sklearn.cluster import DBSCAN

# Clustering settings.
# NOTE(review): the trailing "#15#0.12" on eps looks like previously tried values.
# eps and min_samples are forwarded as keyword arguments together with the
# DBSCAN class; alpha presumably weights the shape (PCA) features against
# the spike positions -- TODO confirm against CombinedClustering.
eps = 0.1#15#0.12
alpha = 0.25
min_samples = 5

# Cluster the detected events with DBSCAN, using 2 whitened PCA components
# and all available cores (n_jobs=-1).
H.CombinedClustering(eps=eps, alpha=alpha, clustering_algorithm=DBSCAN,
                    min_samples=min_samples, pca_ncomponents=2, pca_whiten=True,
                    n_jobs=-1)

In [None]:
# Same overview plot as before, drawn again now that clustering has run;
# both axes are restricted to the range 0..64.
plt.figure(figsize=(10, 10))
H.PlotAll(invert=True, s=1, alpha=0.05)
plt.xlim(0, 64)
plt.ylim(0, 64)

In [None]:
# Hand-picked selection passed to PlotShapes (presumably cluster ids -- TODO confirm).
units = (18,0,1,10,100)

H.PlotShapes(units)

# Plot all units and unclustered spikes in a small region

In [None]:
plt.figure(figsize=(10,10))
ax = plt.axes(facecolor='k')

# Use the third-largest cluster as the centre of the region of interest.
# `cl` is reused by the following cell.
largest = np.argsort(H.clusters['Size'].values)[::-1]
cl = largest[2]
print(cl)

cx, cy = H.clusters['ctr_x'][cl], H.clusters['ctr_y'][cl]

# Unclustered spikes (label -1) closer than 1 to the centre, drawn in red.
# A boolean mask selects the rows directly, so positional indices from
# np.where are never used as index labels.
unclustered = H.spikes[H.spikes.cl == -1]
x, y = unclustered.x.values, unclustered.y.values
nearby = np.sqrt((cx - x)**2 + (cy - y)**2) < 1
ax.scatter(x[nearby], y[nearby], c='r', s=3)

# All clusters whose centre lies closer than 2 to the chosen centre.
ctr_dists = np.sqrt((cx - H.clusters['ctr_x'].values)**2
                    + (cy - H.clusters['ctr_y'].values)**2)
clInds = np.where(ctr_dists < 2)[0]
for cl_t in clInds:
    # Separate names so the region centre (cx, cy) is not overwritten.
    label_x, label_y = H.clusters['ctr_x'][cl_t], H.clusters['ctr_y'][cl_t]
    members = H.spikes[H.spikes.cl == cl_t]
    # `color=` rather than `c=`: a single RGBA tuple passed as `c` is treated as
    # per-point values to colour-map when the cluster happens to have 4 spikes.
    ax.scatter(members.x, members.y,
               color=plt.cm.hsv(H.clusters['Color'][cl_t]), s=3, alpha=0.2)
    ax.text(label_x - 0.1, label_y, str(cl_t), fontsize=16, color='w')
ax.axis('equal');


In [None]:
plt.figure(figsize=(14,14))
print(cl)

# Centre of the region: the cluster `cl` chosen in the previous cell.
cx, cy = H.clusters['ctr_x'][cl], H.clusters['ctr_y'][cl]

shape_ylim = (-300, 100)   # common y-range so the panels are comparable
n_traces = 20              # at most this many waveforms per panel

# Panel 1: waveforms of unclustered spikes (label -1) closer than 2 to the centre.
ax = plt.subplot(4,4,1)
unclustered = H.spikes[H.spikes.cl == -1]
dists = np.sqrt((cx - unclustered.x.values)**2 + (cy - unclustered.y.values)**2)
for shape in unclustered.Shape.values[dists < 2][:n_traces]:
    ax.plot(shape, 'r')
ax.set_ylim(shape_ylim)
ax.set_title('unclustered')

# Panels 2..16: one panel per cluster whose centre lies closer than 2 to the
# region centre (the 4x4 grid leaves room for at most 15 of them).
ctr_dists = np.sqrt((cx - H.clusters['ctr_x'].values)**2
                    + (cy - H.clusters['ctr_y'].values)**2)
clInds = np.where(ctr_dists < 2)[0]
for panel, cl_t in enumerate(clInds[:15], start=2):
    ax = plt.subplot(4,4,panel)
    member_shapes = H.spikes.Shape[H.spikes.cl == cl_t]
    # The waveform loop has its own variable, so the panel counter is not
    # clobbered; .iloc keeps the "first n" selection purely positional.
    for shape in member_shapes.iloc[:n_traces]:
        ax.plot(shape, 'k')
    ax.set_ylim(shape_ylim)
    ax.set_title('cluster '+str(cl_t))
    



In [None]:
# Compare waveforms of unclustered (label -1) and clustered (label > -1) events:
# each panel overlays a sample of individual shapes and, in blue, the mean over
# all events of that group.
panels = (
    # (subplot, row mask, events to draw, line colour, title)
    (121, H.spikes.cl == -1, slice(0, 200, 2), 'r', 'unclustered'),
    # NOTE(review): the clustered sample starts at event 20, as in the original
    # cell; the reason for the offset is not documented.
    (122, H.spikes.cl > -1, slice(20, 200, 2), 'k', 'clustered'),
)
for position, mask, sample, colour, title in panels:
    ax = plt.subplot(position)
    shapes = H.spikes['Shape'][mask]
    # Slicing (instead of indexing up to a fixed 200) simply yields fewer traces
    # when a group has fewer events, rather than raising an IndexError.
    for shape in shapes.iloc[sample]:
        ax.plot(shape, colour)
    ax.plot(np.mean(shapes), 'b', lw=2)
    ax.set_title(title)
    ax.set_ylim((-300, 100))

# Mean Shift Clustering

In [None]:
%%time
# H.CombinedClustering(alpha=40,
#                     bandwidth = 20, bin_seeding=True, min_bin_freq=10,
#                     pca_ncomponents=2, pca_whiten=True,
#                     n_jobs=-1)
H.CombinedClustering(alpha=0.4,
                    bandwidth = 0.3, bin_seeding=True, min_bin_freq=10,
                    pca_ncomponents=2, pca_whiten=True,
                    n_jobs=-1)

plt.figure(figsize=(10, 10))
H.PlotAll(invert=True, s=1)
# plt.xlim((1350,1600))
plt.title("MeanShift, bandwidth=.3, min_bin_freq=10")