In [1]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
from hs2 import herdingspikes
from probe import NeuroPixel
from sklearn.cluster import DBSCAN
import numpy as np
import heapq

%matplotlib inline

In [2]:
# detectData(data, neighbours, spikefilename, shapefilename, channels, sfd, thres, maa = None, maxsl = None, minsl = None, ahpthr = None, tpre = 1, tpost = 2)
# MinAvgAmp minimal avg. amplitude of peak (in units of Qd)
# MaxSl dead time in frames after peak, used for further testing
# MinSl length considered for determining avg. spike amplitude
# AHPthr signal should go below that threshold within MaxSl-Slmin frames

user = "Cole"

if user == "Hennig":
    data_path = '/disk/scratch/mhennig/neuropixel/data/rawDataSample.bin'
elif user == "Martino":
    data_path = "data/rawDataSample.bin"
else:
    data_path = '/home/cole/neuropixeldata/rawDataSample.bin'

to_localize = True
cutout_start = 10
cutout_end = 30
threshold = 12
masking = None
file_name = 'ProcessedSpikes'


In [3]:
Probe = NeuroPixel(data_file_path=data_path, fps=30000, masked_channels=masking)

H = herdingspikes(Probe)

In [4]:
H.DetectFromRaw(to_localize, cutout_start, cutout_end, threshold,
                maa=0, maxsl=12, minsl=3, ahpthr=0)

# # OR

# H.LoadDetected()

# Sampling rate: 30000
# Localization On
# Not Masking any Channels
# Number of recorded channels: 385
# Analysing frames: 1800000, Seconds:60.0
# Frames before spike in cutout: 30
# Frames after spike in cutout: 66
# tcuts: 42 55
# tInc: 200000
# Analysing 200000 frames; -42 200055
# Analysing 200000 frames; 199958 400055
# Analysing 200000 frames; 399958 600055
# Analysing 200000 frames; 599958 800055
# Analysing 200000 frames; 799958 1000055
# Analysing 200000 frames; 999958 1200055
# Analysing 200000 frames; 1199958 1400055
# Analysing 200000 frames; 1399958 1600055
# Analysing 199945 frames; 1599958 1800000
# Time taken for detection: 0:00:14.995852
# Time per frame: 0:00:00.008331
# Time per sample: 0:00:00.000022
Detected and read 385019 spikes.


In [5]:
H.CombinedClustering(alpha=40,
                    bandwidth=20, bin_seeding=True, min_bin_freq=10,
                    pca_ncomponents=2, pca_whiten=True,
                    n_jobs=-1)

Fitting PCA using 385019 spikes


  self.spikes.cl = clusterer.labels_


Number of estimated clusters: 191


In [6]:
from sklearn.decomposition import PCA
from random import shuffle
import random
from sklearn import svm
from sklearn import preprocessing


pca_whiten = True
pca = PCA(n_components=10, whiten=pca_whiten)
pca.fit(np.array(list(H.spikes.Shape)))

PCA(copy=True, iterated_power='auto', n_components=10, random_state=None,
  svd_solver='auto', tol=0.0, whiten=True)

In [7]:
def getClosestClusters(cluster_id, num_neighbors):
    cluster_distances = [(cl_id, np.sqrt(((center - H.centerz[cluster_id])**2).sum())) for cl_id, center in enumerate(H.centerz)]
    closest_clusters = heapq.nsmallest(num_neighbors,cluster_distances, key=lambda X: X[1])
    return closest_clusters

In [8]:
def getRepresentativeWaveforms(cluster_id, num_waveforms):
    cluster = H.in_cl[cluster_id]
    spikes_dist_from_cluster = [(spike, np.sqrt(((H.fourvec[spike] - H.centerz[cluster_id])**2).sum())) for spike in cluster]
    representative_spikes = heapq.nsmallest(num_waveforms, spikes_dist_from_cluster, key=lambda X: X[1])
    representative_cutouts = np.array(list(H.spikes.Shape[[spike[0] for spike in representative_spikes]]))
    return (cluster_id, representative_cutouts, representative_spikes) 

In [9]:
def createTrainingSet(cluster_representatives):
    X = []
    Y = []
    for cluster in cluster_representatives:
        cl_id = cluster[0]
        cutouts = cluster[1]
        for i, example in enumerate(pca.transform(np.array(list(cutouts)))):
            spike_id = cluster[2][i][0]
            np.append(example, [H.spikes.x[spike_id]])
            np.append(example, [H.spikes.y[spike_id]])
            X.append(example)
            Y.append(cl_id)
            
    c = list(zip(X, Y))
    random.shuffle(c)
    X, Y = zip(*c)
    return (X, Y)

In [10]:
def getAllSpikeShapesinCluster(cluster_id):
    return np.array(list([H.spikes.Shape[spike] for spike in H.in_cl[cluster_id]]))

In [11]:
def getAllSpikePositionsinCluster(cluster_id):
    return [(H.spikes.x[spike], H.spikes.y[spike]) for spike in H.in_cl[cluster_id]]

In [16]:
def getProbabilitiesinClusterSVM(testing_set, testing_positions, labels, clf):
    correct = 0
    wrong = 0
    probabilities = []
    testing_set_transforms = pca.transform(testing_set)
    for i, transform in enumerate(testing_set_transforms):
        np.append(transform, [testing_positions[i][0]])
        np.append(transform, [testing_positions[i][1]])
    for label in labels:
        for prediction in clf.predict(testing_set_transforms):
            if(prediction == label):
                correct += 1
            if(prediction != label):
                wrong += 1
        probabilities.append((label, correct/(correct + wrong)))
        correct = 0
        wrong = 0
    return probabilities

In [17]:
cluster_representatives = []

#Get closest clusters to 0 (Minowski Distance)
for i in range(10):
    
    closest_clusters = getClosestClusters(i, 5)

    #Get all representative waveforms from all neighbors
    for cluster in closest_clusters:
        representative_waveforms = getRepresentativeWaveforms(cluster[0], 100)
        cluster_representatives.append(representative_waveforms)


    #Getting Training Data from all neighbors
    X, Y = createTrainingSet(cluster_representatives)
    
    #Create and Train the classifier
    clf = svm.SVC(kernel = 'rbf')
    clf.fit(X , Y)

    #Run cluster 0 through the classifier and return probabilities
    testing_set = getAllSpikeShapesinCluster(i)
    positions_set = getAllSpikePositionsinCluster(i)
    labels = [cluster_id for cluster_id, _ in closest_clusters]
    probabilities = getProbabilitiesinClusterSVM(testing_set, positions_set, labels, clf)
    probabilities.sort(key=lambda X: -X[1])
    cluster_representatives = []
    
    print("Cluster " + str(i), probabilities)


Cluster 0 [(0, 0.6330893840967026), (166, 0.20549296932818026), (164, 0.08535482279417811), (119, 0.07458268234520188), (6, 0.0014801414357371926)]
Cluster 1 [(1, 0.3280223063134834), (134, 0.27384983071101376), (56, 0.24039036048595896), (80, 0.1525592511451902), (30, 0.005178251344353714)]
Cluster 2 [(37, 0.2815433936340737), (2, 0.2807419280970918), (127, 0.14987405541561713), (143, 0.1478131440348065), (158, 0.14002747881841082)]
Cluster 3 [(3, 0.6601039103089964), (8, 0.22559474979491387), (86, 0.11156685808039377), (36, 0.002461033634126333), (126, 0.00027344818156959256)]
Cluster 4 [(4, 0.534874684282106), (43, 0.26306586360986983), (84, 0.1206528074606567), (63, 0.08024091703905188), (73, 0.0011657276083155237)]
Cluster 5 [(5, 0.8002343292325718), (76, 0.08904510837727006), (155, 0.058875219683655534), (75, 0.04657293497363796), (48, 0.005272407732864675)]
Cluster 6 [(6, 0.9037494284407864), (0, 0.09625057155921353), (166, 0.0), (164, 0.0), (119, 0.0)]
Cluster 7 [(7, 0.66428775

In [14]:
cluster_representatives = []

#Get closest clusters to 0 (Minowski Distance)
for i in range(2):
    
    closest_clusters = getClosestClusters(i, 5)

    #Get all representative waveforms from all neighbors
    for cluster in closest_clusters:
        representative_waveforms = getRepresentativeWaveforms(cluster[0], 100)
        cluster_representatives.append(representative_waveforms)
    print(cluster_representatives)

    #Getting Training Data from all neighbors
    X, Y = createTestSet(cluster_representatives)

    #Create and Train the classifier
    clf = svm.SVC(kernel = 'rbf', probability=True)
    clf.fit(X, Y)

    #Run cluster 0 through the classifier and return probabilities
    testing_set = getAllSpikeShapesinCluster(i)
    labels = [cluster_id for cluster_id, _ in closest_clusters]
    probabilities = getProbabilitiesVerboseinClusterSVM(testing_set, labels, clf)
    probabilities.sort(key=lambda X: -X[1])
    cluster_representatives = []
    
    print("Cluster " + str(i), probabilities)

[(0, array([[  8,   5,  -7, ...,  -9, -16, -13],
       [ -8,   2,  -4, ..., -17, -16,  -8],
       [  1,   4,   8, ..., -11, -16, -19],
       ..., 
       [  2,  -2,  -1, ...,   4,   8,   1],
       [-11, -10,  -3, ..., -19, -23, -20],
       [ -5,  -3,  -3, ...,   4,  -2,  -1]], dtype=int32), [(154608, 3.3305478103307768), (48922, 3.3448402065448937), (164277, 3.3999625074950721), (21705, 3.4014166840608437), (151834, 3.4652758905793952), (148521, 4.7083302994225225), (281929, 4.8158376555711486), (360502, 4.9475865034057254), (185189, 4.9627803839780587), (265581, 4.9873680940991436), (28155, 5.0576461532803965), (240964, 5.1096575189930045), (355239, 5.3196074793877965), (298137, 5.4355532836237392), (361078, 5.4675674340940263), (134474, 5.4823322651838353), (260812, 5.4976973041598995), (345968, 5.5040198502164923), (37033, 5.6128845801213112), (44613, 5.6935154809156865), (181097, 5.7157942664796995), (28076, 5.7572967075667218), (326712, 6.0104750200944101), (370059, 6.1159514

NameError: name 'createTestSet' is not defined

In [None]:
cutouts = np.array(list(H.spikes.Shape[[spike[0] for spike in farthest_spikes]])) 
cutouts2 = np.array(list(H.spikes.Shape[[spike[0] for spike in closest_spikes]]))
cutouts3 = np.array(list(H.spikes.Shape[[spike[0] for spike in closest_spikes165]]))
cutouts4 = np.array(list(H.spikes.Shape[[spike[0] for spike in closest_spikes167]]))
cutouts5 = np.array(list(H.spikes.Shape[[spike[0] for spike in closest_spikes6]]))
cutouts6 = np.array(list(H.spikes.Shape[[spike[0] for spike in closest_spikes134]]))

cutouts_whole_0 = np.array(list(H.spikes.Shape[[spike for spike in cluster_0]])) 
cutouts_whole_170 = np.array(list(H.spikes.Shape[[spike for spike in cluster_134]])) 

#f, (ax1, ax2, ax3, ax4, ax5, ax6) = plt.subplots(6, sharex=True, sharey=True)
f, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
f.set_figheight(10)
f.set_figwidth(10)

#ax1.plot(cutouts[0] - np.median(cutouts[0], axis=0), 'red', lw=2)
ax1.plot(np.mean(cutouts2, axis=0), 'red', lw=2)

ax1.set_title("Farthest Spikes Cluster 0")

cluster = 6

if cluster == 0:
    #ax2.plot(np.mean(cutouts2[:100], axis=0), 'green', lw=2)
    ax2.plot(cutouts2[20], 'green', lw=2)
    ax2.set_title("Closest Spikes Cluster 0")
if cluster == 16:
    ax2.plot(np.mean(cutouts3[:100], axis=0), 'green', lw=2)
    ax2.set_title("Closest Spikes Cluster 165")
    
if cluster == 167:
    ax2.plot(np.mean(cutouts4[:100], axis=0))
    #ax2.plot(cutouts4[0], 'green', lw=2)
    ax2.set_title("Closest Spikes Cluster 167")
    
if cluster == 6:
    ax2.plot(np.mean(cutouts5[:50], axis=0) - np.median(np.mean(cutouts5[:1], axis=0), axis=0), 'green', lw=2)
    #ax2.plot(np.mean(cutouts5[:10], axis=0), 'green', lw=2)
    #ax2.plot(cutouts5[0], 'green', lw=2)
    ax2.set_title("Closest Spikes Cluster 6")
    
if cluster == 134:
    ax2.plot(np.mean(cutouts6, axis=0), 'green', lw=2)
    #ax2.plot(cutouts6[0], 'green', lw=2)
    ax2.set_title("Closest Spikes Cluster 134")

spike = farthest_spikes[0][0]
spike_dist_from_cluster = [(cl_id, np.sqrt(((H.fourvec[spike] - center)**2).sum())) for cl_id, center in enumerate(H.centerz)]
closest_clusters = heapq.nsmallest(50,spike_dist_from_cluster, key=lambda X: X[1])
print(closest_clusters)
#tuples0 = [(spike, np.sqrt(((H.fourvec[spike] - H.centerz[0])**2).sum())) for cl_id, center in enumerate(H.centerz)]
    
#ax3.plot(cutouts3[0], 'blue', lw=2)
#ax3.set_title("Closest Spikes Cluster 165")
#ax4.plot(cutouts4[0], 'blue', lw=2)
#ax4.set_title("Closest Spikes Cluster 167")
#ax5.plot(cutouts5[0], 'blue', lw=2)
#ax5.set_title("Closest Spikes Cluster 6")
#ax6.plot(cutouts6[0], 'blue', lw=2)
#ax6.set_title("Closest Spikes Cluster 170")                   
#ax1.plot(np.mean(cutouts, axis=0), 'red', lw=4)
#ax2.plot(np.mean(cutouts2, axis=0), 'green', lw=4)

In [None]:
import pywt
from statsmodels.robust import mad

sym4= pywt.Wavelet('sym4')
cutout_far0 = cutouts[0][1:] - np.mean(cutouts[0])
cutout_close0 = cutouts2[0][1:] - np.mean(cutouts2[0])

#cutout_center0 = cutout_center0 - np.mean(cutout_center0)

#for element in cutout_center0:
 #   if(element < 0):
 #       if (element > -10):
  #          element = 0
 #   else:
 #       if(element > 5):
  #          element = 0

f, (ax1, ax2, ax3, ax4) = plt.subplots(4, sharex=True, sharey=True)
f.set_figheight(10)
f.set_figwidth(5)

#ax1.plot(cutouts[0] - np.median(cutouts[0], axis=0), 'red', lw=2)
ax1.plot(cutout_far0, 'red', lw=2)
ax1.set_title("Farthest Spike Cluster 0")
ax3.plot(cutout_close0, 'red', lw=2)
ax3.set_title("Closest Spike Cluster 0")


#SWT transform of cutout and thresholding
coefficients = pywt.swt(cutout_far0, sym4, level=2)
coefficients2 = pywt.swt(cutout_close0, sym4, level=2)

universal = False
if(universal):
    sigma = np.median(coefficients[-1][1])/0.6745
    #sigma = mad(coefficients[-1]) #omit smoothing coefficients
    uthresh = sigma*np.sqrt(2*np.log(len(coefficients[0])))
    new_coefficients = coefficients[:]
    new_coefficients = pywt.threshold(new_coefficients, uthresh, 'hard')
    ax2.plot(pywt.iswt(new_coefficients,'sym4') , 'blue', lw=2)
    
    sigma = np.median(coefficients[-1][1])/0.6745
    #sigma = mad(coefficients2[-1]) #omit smoothing coefficients
    uthresh = sigma*np.sqrt(2*np.log(len(coefficients2[0])))
    new_coefficients = coefficients2[:]
    new_coefficients = pywt.threshold(new_coefficients, uthresh, 'hard')
    ax4.plot(pywt.iswt(new_coefficients,'sym4') , 'blue', lw=2)
else:
    new_coefficients = coefficients[:]
    new_coefficients = pywt.threshold(new_coefficients, 10, 'soft')
    ax2.plot(pywt.iswt(new_coefficients,'sym4') , 'green', lw=2)
    
    new_coefficients = coefficients2[:]
    new_coefficients = pywt.threshold(new_coefficients, 10, 'soft')
    ax4.plot(pywt.iswt(new_coefficients,'sym4') , 'blue', lw=2)
    
#print(cutouts_mean)
#ax2.plot(cA1, 'blue', lw=2)
#ax2.plot(cD1, 'green', lw=2)
#ax2.plot(cA2, 'red', lw=2)
##ax2.plot(cD2, 'purple', lw=2)

#ax1.plot(pywt.iswt([(cA1, cD1)],'sym4') , 'blue', lw=2)

ax2.set_title("SWT Reconstruction Farthest Cluster 0")
ax4.set_title("SWT Reconstruction Closest Cluster 0")

#if cluster == 0:
 #   ax2.plot(np.mean(cutouts2[:100], axis=0), 'green', lw=2)
    #ax2.plot(cutouts2[0], 'green', lw=2)
 #  ax2.set_title("Closest Spikes Cluster 0")

In [None]:
sym4= pywt.Wavelet('sym4')
cutout_far0 = cutouts[0][:40] - np.mean(cutouts[0])

#cutout_center0 = cutout_center0 - np.mean(cutout_center0)

#for element in cutout_center0:
 #   if(element < 0):
 #       if (element > -10):
  #          element = 0
 #   else:
 #       if(element > 5):
  #          element = 0

f, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
f.set_figheight(10)
f.set_figwidth(10)

#ax1.plot(cutouts[0] - np.median(cutouts[0], axis=0), 'red', lw=2)
ax1.plot(cutout_far0, 'red', lw=2)
ax1.set_title("Farthest Spike Cluster 0")


#SWT transform of cutout and thresholding
coefficients = pywt.swt(cutout_far0, sym4, level=2)

universal = False
if(universal):
    sigma = mad(coefficients[-1]) #omit smoothing coefficients
    uthresh = sigma*np.sqrt(2*np.log(len(coefficients[0])))
    new_coefficients = coefficients[:]
    new_coefficients = pywt.threshold(new_coefficients, uthresh, 'hard')
    ax2.plot(pywt.iswt(new_coefficients,'sym4') , 'blue', lw=2)
    
else:
    new_coefficients = coefficients[:]
    new_coefficients = pywt.threshold(new_coefficients, 15, 'hard')
    ax2.plot(pywt.iswt(new_coefficients,'sym4') , 'green', lw=2)
#print(cutouts_mean)
#ax2.plot(cA1, 'blue', lw=2)
#ax2.plot(cD1, 'green', lw=2)
#ax2.plot(cA2, 'red', lw=2)
##ax2.plot(cD2, 'purple', lw=2)

#ax1.plot(pywt.iswt([(cA1, cD1)],'sym4') , 'blue', lw=2)

ax2.set_title("SWT Reconstruction Farthest Cluster 0")
print(new_coefficients)
#if cluster == 0:
 #   ax2.plot(np.mean(cutouts2[:100], axis=0), 'green', lw=2)
    #ax2.plot(cutouts2[0], 'green', lw=2)
 #  ax2.set_title("Closest Spikes Cluster 0")

In [None]:
X = []
from random import shuffle
import random
clusters = [cutouts2[:80], cutouts3[:80], cutouts4[:80], cutouts5[:80], cutouts6[:80]]
for cutouts_clusters in clusters:
    for cutout in cutouts_clusters:
        coefficients = pywt.swt(cutout[1:], sym4, level=2)
        new_coefficients = coefficients[:]
        new_coefficients = pywt.threshold(new_coefficients, 10, 'soft')
        data_point = []
        for level in new_coefficients:
            for coefficient_list in level:
                for element in coefficient_list:
                    data_point = [element] + data_point
        X.append(data_point)
Y = []        
for i in range(len(X)):
    if i < 50:
        Y.append(0)
    elif i < 100:
        Y.append(165)
    elif i < 150:
        Y.append(167)
    elif i < 200:
        Y.append(6)
    elif i < 250:
        Y.append(170)

c = list(zip(X, Y))
random.shuffle(c)
X, Y = zip(*c)


In [None]:
from sklearn import svm
clf = svm.SVC()
clf.fit(X, Y)

#cutouts2 = cluster 0
#cutouts3 = cluster 165
#cutouts4 = 167
#cutouts5 = 6
#cutouts6 = 170

cutout_full0 = np.array(list([H.spikes.Shape[spike] for spike in cluster_0])) 

to_predict = []
for cutout in cutout_full0:
    coefficients = pywt.swt(cutout[1:], sym4, level=2)
    new_coefficients = coefficients[:]
    new_coefficients = pywt.threshold(new_coefficients, 10, 'soft')
    data_point = []
    for level in new_coefficients:
            for coefficient_list in level:
                for element in coefficient_list:
                    data_point = [element] + data_point
    to_predict.append(data_point)
correct = 0
wrong = 0
for elem in clf.predict(to_predict):
    if(elem == 170):
        correct += 1
    if(elem !=170):
        wrong += 1
print(correct/(correct + wrong))