In [1]:
from pylab import *
from scipy import *
from scipy import stats, io, linalg
import numpy as np
import struct
import tables as tb
from phy.io import KwikModel
from attrdict import AttrDict
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os as os
import codecs as codecs
import matplotlib.gridspec as gridspec
import matplotlib.ticker as mtick
from scipy.stats import chi2

In [None]:
def getexpinfo(kwik) :
    """Decode experiment, measurement and shank numbers from a .kwik file path.

    Parameters
    ----------
    kwik : str
        Path of the .kwik file.  The name is expected to contain a recording date
        (e.g. '150716'), a measurement tag directly before '_ele' (e.g. '1_ele', or
        '-2_ele' for experiment 30) and an electrode-range token (e.g. 'ele01_ele08').

    Returns
    -------
    (exp, meas, shank) : tuple of int

    Raises
    ------
    ValueError
        If any of the three numbers cannot be determined from ``kwik``.
    """
    import re

    # Recording date -> experiment number.  Later matches override earlier ones.
    exp_by_date = {
        '150716': 22, '151027': 23, '151103': 24, '151105': 25,
        '151110': 26, '151112': 27, '151116': 28, '151118': 29,
        '151208': 30, '151210': 31, '151214': 32,
    }
    # Electrode-range token -> shank number.
    shank_by_electrodes = {
        'ele01_ele08': 1, 'ele09_ele16': 2, 'ele17_ele24': 3, 'ele25_ele32': 4,
        'ele33_ele40': 5, 'ele41_ele48': 6, 'ele49_ele56': 7, 'ele57_ele64': 8,
    }

    exp = meas = shank = None
    for date, number in exp_by_date.items():
        if date in kwik:
            exp = number
    for token, number in shank_by_electrodes.items():
        if token in kwik:
            shank = number

    # Collapse every 'eleXX_eleYY' token to 'ele' before looking for the measurement
    # tag, so digits inside the electrode range (the '3_ele' in 'ele33_ele40', the
    # '1_ele' in 'ele41_ele48') are not mistaken for the measurement number.
    tagged = re.sub(r'ele\d\d_ele\d\d', 'ele', kwik)
    for number in (1, 2, 3):
        if '%d_ele' % number in tagged:
            meas = number
    if exp == 30:
        # Experiment 30 files are tagged '-2' .. '-5'; these are measurements 1 .. 4.
        for number in (2, 3, 4, 5):
            if '-%d_ele' % number in tagged:
                meas = number - 1

    missing = [label for label, value in (('experiment', exp), ('measurement', meas), ('shank', shank))
               if value is None]
    if missing:
        raise ValueError('Could not determine %s from %r' % (', '.join(missing), kwik))
    return exp, meas, shank

def getwaveforms(kwik, UseAll=True) :
    """Extract spike waveforms, cluster labels and cluster groups from a kwik file.

    The kwik file is loaded with phy's ``KwikModel``, which supplies the
    waveforms (from the associated dat file), the per-spike cluster ids and
    the cluster-group dictionary.

    Parameters
    ----------
    kwik : str
        Path to the .kwik file.
    UseAll : bool
        If True, return every waveform, including spikes in clusters that were
        not included in the analysis.  If False, keep only spikes of clusters
        whose group code is 2 (treated as "good" here) plus the spikes of the
        largest non-good cluster, which is used as the noise cluster.  If there
        is no non-good cluster, only good spikes are returned.

    Returns
    -------
    waves : ndarray, shape ``model.waveforms.shape`` restricted along axis 0
        Downstream code treats this as (n_spikes, n_samples, n_channels) --
        TODO confirm against the phy version in use.
    clusters : array of cluster ids, aligned with axis 0 of ``waves``
    cluster_groups : dict mapping cluster id -> group code
    """
    model = KwikModel(kwik) # load kwik model from file
    clusters = model.spike_clusters
    cluster_groups = model.cluster_groups
    # Copy waveforms spike by spike, indexing model.waveforms one spike at a time.
    waves = np.zeros(model.waveforms.shape)
    for i in np.arange(waves.shape[0]) :
        waves[i,:,:] = model.waveforms[i]
    if UseAll :
        return waves, clusters, cluster_groups

    # Size of every non-good cluster; the largest one (first on ties) is the noise cluster.
    nongood_sizes = {key: np.count_nonzero(clusters == key)
                     for key, value in cluster_groups.items() if value != 2}
    keep_ids = [key for key, value in cluster_groups.items() if value == 2]
    if nongood_sizes :
        keep_ids.append(max(nongood_sizes, key=nongood_sizes.get))
    # Spike indices of good + noise clusters, in their original order.
    keep = np.flatnonzero(np.isin(clusters, keep_ids))
    return waves[keep,:,:], clusters[keep], cluster_groups

def getfeatures(kwik, numpcs=3, UseAll=True) :
    """Load the features KlustaKwik clustered on, reshaped per electrode.

    Use this instead of computing new features (``getwaveforms`` + ``PCA``).

    Parameters
    ----------
    kwik : str
        Path to the .kwik file, loaded with phy's ``KwikModel``.
    numpcs : int
        Number of feature values stored per electrode in ``model.features``.
    UseAll : bool
        If True, return features for every spike.  If False, keep only spikes of
        clusters whose group code is 2 (treated as "good") plus the spikes of the
        largest non-good cluster, used as the noise cluster.  If there is no
        non-good cluster, only good spikes are returned.

    Returns
    -------
    pcs : ndarray, shape (n_spikes, numpcs, numelectrodes)
    clusters : array of cluster ids aligned with axis 0 of ``pcs``
    cluster_groups : dict mapping cluster id -> group code
    numelectrodes : int
    """
    model = KwikModel(kwik) # load kwik model from file
    clusters = model.spike_clusters
    cluster_groups = model.cluster_groups
    features = model.features
    # Each feature row holds numpcs values per electrode, so its length fixes the
    # electrode count.  Computed for both UseAll modes (built-in int; np.int no
    # longer exists in NumPy).
    numelectrodes = int(features.shape[1] // numpcs)
    pcs = np.zeros([features.shape[0], numpcs, numelectrodes])
    for i in np.arange(features.shape[0]) :
        # order='F': the first numpcs values belong to electrode 0, the next numpcs
        # to electrode 1, and so on.
        pcs[i,:,:] = np.reshape(features[i], [numpcs, numelectrodes], order='F')
    if UseAll :
        return pcs, clusters, cluster_groups, numelectrodes

    # Size of every non-good cluster; the largest one (first on ties) is the noise cluster.
    nongood_sizes = {key: np.count_nonzero(clusters == key)
                     for key, value in cluster_groups.items() if value != 2}
    keep_ids = [key for key, value in cluster_groups.items() if value == 2]
    if nongood_sizes :
        keep_ids.append(max(nongood_sizes, key=nongood_sizes.get))
    # Spike indices of good + noise clusters, in their original order.
    keep = np.flatnonzero(np.isin(clusters, keep_ids))
    return pcs[keep,:,:], clusters[keep], cluster_groups, numelectrodes

def EnergyNorm(waveformdata, enorm=True) :
    """Compute per-channel waveform energy and optionally normalise by it.

    The energy of spike i on channel j is
    ``sqrt(sum(waveformdata[i, :, j] ** 2)) / n_samples``.  This is the L2 norm
    divided by the number of samples, not the RMS.

    Parameters
    ----------
    waveformdata : ndarray, shape (n_spikes, n_samples, n_channels)
    enorm : bool
        If True, return the waveforms divided by their energy.  If False, return
        ``waveformdata`` itself unchanged; the energies are still computed.

    Returns
    -------
    waveformsnorm : ndarray, same shape as ``waveformdata``
    E : ndarray of float64, shape (n_spikes, 1, n_channels)
    """
    # Work in float64 so squaring integer (e.g. int16) samples cannot overflow.
    data = np.asarray(waveformdata, dtype=float)
    # keepdims keeps E as (n_spikes, 1, n_channels) so it broadcasts over samples.
    E = np.sqrt(np.sum(np.square(data), axis=1, keepdims=True)) / data.shape[1]
    if not enorm :
        return waveformdata, E
    # A channel whose samples are all zero has E == 0, so its normalised values are nan/inf.
    return data / E, E

def PCA(waveformdata, numpcs=3) :
    """Project each channel's waveforms onto that channel's leading principal components.

    PCA is done independently for every channel (axis 2).  The channel's
    waveforms are mean-centred, the eigenvectors of their sample covariance are
    ordered by decreasing eigenvalue, and the centred data is projected onto the
    first ``numpcs`` of them.

    Returns
    -------
    ndarray, shape (n_spikes, numpcs, n_channels)
        PC scores; component k of channel ch is ``result[:, k, ch]``.
    """
    n_spikes = waveformdata.shape[0]
    n_channels = waveformdata.shape[2]
    scores = np.zeros([n_spikes, numpcs, n_channels])
    for channel in range(n_channels) :
        samples = waveformdata[:, :, channel]
        centred = samples - samples.mean(axis=0)
        eigvals, eigvecs = linalg.eigh(np.cov(centred, rowvar=False))
        # eigh lists eigenvalues in ascending order; pick the numpcs largest, biggest first.
        top = np.argsort(eigvals)[::-1][:numpcs]
        basis = eigvecs[:, top]
        scores[:, :, channel] = np.dot(basis.T, centred.T).T
    return scores

def CalcMahalDist(waveformdata, E, pcs, clusters, cluster_groups, pcsinc=3, numelectrodes=8, enorm=True) :
    """Squared Mahalanobis distance from every spike to every good cluster.

    Each spike's feature vector is built from consecutive blocks that are
    ``numelectrodes`` columns wide.  The blocks are the per-channel energy
    ``E[:, 0, :]`` (only when ``enorm`` is True), followed by PC 1 .. PC
    ``pcsinc`` from ``pcs``.  Each good cluster (group code 2) gets a mean and a
    covariance estimated from its own spikes.  The distance is
    D[i, j] = (x_i - mu_j)^T C_j^{-1} (x_i - mu_j).

    Parameters
    ----------
    waveformdata : ndarray
        Only its shape is used: axis 0 = spikes, axis 2 = channels.
    E : ndarray, shape (n_spikes, 1, n_channels)
        Channel energies (e.g. from EnergyNorm).  Ignored when ``enorm`` is False.
    pcs : ndarray, shape (n_spikes, n_pcs, n_channels)
    clusters : ndarray, shape (n_spikes,)
        Cluster id of every spike.
    cluster_groups : dict
        Cluster id -> group code; code 2 marks a good cluster.
    pcsinc : int
        Number of leading PCs per channel to include, 1 .. pcs.shape[1].
    numelectrodes : int
        Width of each feature block (normally the channel count).
    enorm : bool
        Prepend the energy block to the features.

    Returns
    -------
    D : ndarray, shape (n_spikes, n_good); column j refers to goodclusts[j]
    goodclusts : sorted ndarray of good cluster ids
    featurevecs : ndarray, shape (n_spikes, (pcsinc + enorm) * numelectrodes)
    muC, covC : dicts of per-cluster feature mean and covariance

    Raises
    ------
    ValueError
        If ``pcsinc`` is outside 1 .. pcs.shape[1].
    """
    if not 1 <= pcsinc <= pcs.shape[1] :
        raise ValueError('pcsinc must be between 1 and %d, got %r' % (pcs.shape[1], pcsinc))

    goodclusts = np.sort([key for key, value in cluster_groups.items() if value == 2])
    nspikes = waveformdata.shape[0]
    nchan = waveformdata.shape[2]
    clusters = np.asarray(clusters)

    # Feature blocks in column order: [energy], PC1, PC2, ...
    blocks = [E[:, 0, :nchan]] if enorm else []
    blocks += [pcs[:, k, :nchan] for k in range(pcsinc)]
    featurevecs = np.zeros([nspikes, len(blocks) * numelectrodes])
    for b, block in enumerate(blocks) :
        start = b * numelectrodes
        featurevecs[:, start:start + nchan] = block

    muC = {} # cluster centres in feature space
    covC = {} # cluster covariances in feature space
    for clust in goodclusts :
        members = featurevecs[clusters == clust]
        muC[clust] = members.mean(axis=0)
        covC[clust] = np.cov(members, rowvar=False) # np.cov centres the data itself

    D = np.zeros([nspikes, len(goodclusts)])
    for j, clust in enumerate(goodclusts) :
        diff = featurevecs - muC[clust]
        precision = linalg.inv(covC[clust])
        # Row-wise quadratic form diff_i^T * precision * diff_i for all spikes at once.
        D[:, j] = np.sum(diff.dot(precision) * diff, axis=1)
    return D, goodclusts, featurevecs, muC, covC

def computeLratio(D, clusters, goodclusts, df=24) :
    """L-ratio of every good cluster.

    For the cluster in column j of ``D``, this sums the chi-square survival
    probability (``df`` degrees of freedom) of the squared Mahalanobis distance
    of every spike OUTSIDE the cluster.  The sum is divided by the number of
    spikes inside the cluster.

    Parameters
    ----------
    D : ndarray, shape (n_spikes, n_good)
        Squared Mahalanobis distances; column j refers to goodclusts[j].
    clusters : ndarray, shape (n_spikes,)
        Cluster id of every spike.
    goodclusts : sequence of cluster ids
    df : int
        Degrees of freedom, i.e. the number of features used to compute ``D``.

    Returns
    -------
    dict mapping cluster id -> L-ratio (NaN for a cluster with no spikes).
    """
    clusters = np.asarray(clusters)
    Lr = {}
    for j, c in enumerate(goodclusts) :
        in_cluster = clusters == c
        n_in = np.count_nonzero(in_cluster)
        # chi2.sf is 1 - cdf evaluated without cancellation, so distant spikes add a
        # tiny positive amount instead of rounding to exactly 0.
        outside = chi2.sf(D[~in_cluster, j], df)
        Lr[c] = outside.sum() / n_in if n_in else np.nan
    return Lr

def computeIsolationDistance(D, clusters, goodclusts) :
    """Isolation distance of every good cluster.

    For the cluster in column j of ``D`` with n spikes, the isolation distance is
    the n-th smallest squared Mahalanobis distance among spikes outside the
    cluster.

    Parameters
    ----------
    D : ndarray, shape (n_spikes, n_good)
        Squared Mahalanobis distances; column j refers to goodclusts[j].
    clusters : ndarray, shape (n_spikes,)
        Cluster id of every spike.
    goodclusts : sequence of cluster ids

    Returns
    -------
    dict mapping cluster id -> isolation distance.  The value is NaN when the
    cluster is empty or has more spikes than there are spikes outside it, because
    the measure is undefined in those cases.
    """
    clusters = np.asarray(clusters)
    Id = {}
    for j, c in enumerate(goodclusts) :
        in_cluster = clusters == c
        n_in = np.count_nonzero(in_cluster)
        outside = np.sort(D[~in_cluster, j]) # out-of-cluster distances, smallest first
        if n_in == 0 or n_in > outside.size :
            Id[c] = np.nan
        else :
            Id[c] = outside[n_in - 1] # n-th smallest (0-based index n-1)
    return Id

def ClusterSummary(D, Lr, Id, goodclusts, pdf_files_directory, exp, meas, shank, clusters=None) :
    """Save two PDF figures summarising cluster quality for one recording.

    1. ``<prefix>Exp<e>_Meas<m>_Shank<s>_ClusterMetrics.pdf``: for every good
       cluster, overlaid histograms of log10(D) for spikes inside (blue) and
       outside (green) the cluster.  The legend title shows the cluster's
       L-ratio and isolation distance.
    2. ``<prefix>Exp<e>_Meas<m>_Shank<s>_MetricSummary.pdf``: a scatter of
       L-ratio against isolation distance for all good clusters, with marginal
       box plots.

    Parameters
    ----------
    D : ndarray, shape (n_spikes, n_good)
        Squared Mahalanobis distances; column i refers to goodclusts[i].
    Lr, Id : dict
        L-ratio and isolation distance keyed by cluster id.
    goodclusts : sequence of cluster ids, in the column order of ``D``.
    pdf_files_directory : str
        Prefix that the file names are appended to directly, so it should end
        with a path separator.
    exp, meas, shank :
        Used in figure titles and file names.
    clusters : ndarray, optional
        Cluster id of every spike (rows of ``D``).  Older callers did not pass
        this and relied on a notebook-global ``clusters``.  That global is still
        used when the argument is omitted.

    Raises
    ------
    ValueError
        If no cluster labels are available.
    """
    if clusters is None :
        clusters = globals().get('clusters')
        if clusters is None :
            raise ValueError('ClusterSummary needs per-spike cluster labels; pass clusters=...')
    clusters = np.asarray(clusters)
    numclusts = len(goodclusts)
    if numclusts == 0 :
        return # nothing to plot; np.min/np.max below would fail on empty input

    MIN, MAX = 0.1, 10.0
    numbins = 500
    # NOTE(review): the histograms are of log10(D), but the bins are log-spaced over
    # [MIN, MAX] and the x axis is log-scaled as well.  As a result, only log10(D)
    # values in [0.1, 10] are counted.  This is kept so the existing figures are
    # reproduced; confirm it is intended.
    bins = 10 ** np.linspace(np.log10(MIN), np.log10(MAX), numbins)
    tag = 'Exp' + str(exp) + '_Meas' + str(meas) + '_Shank' + str(shank)

    # Column i of D holds distances to goodclusts[i]; split it into spikes inside
    # and outside that cluster.
    inside = [D[clusters == c, i] for i, c in enumerate(goodclusts)]
    outside = [D[clusters != c, i] for i, c in enumerate(goodclusts)]

    # Raw counts on the plotting bins give one normalisation shared by every panel
    # (the total count over all histograms) and each panel's y-limit.
    nGOOD = np.array([np.histogram(np.log10(d), bins=bins)[0] for d in inside])
    nNOISE = np.array([np.histogram(np.log10(d), bins=bins)[0] for d in outside])
    normnum = 1 / np.sum(nGOOD + nNOISE)
    panel_heights = np.max(nGOOD, axis=1) * normnum

    xmin = np.log10(0.99 * np.min(D))
    xmax = np.log10(1.1 * np.max(D))

    # ---- Figure 1: per-cluster distance histograms --------------------------------
    fig = plt.figure()
    nrows = numclusts // 3 + 1
    for i, c in enumerate(goodclusts) :
        GOOD, NOISE = inside[i], outside[i]
        ax1 = fig.add_subplot(nrows, 3, i + 1, frame_on=True)
        ax1.hist(np.log10(GOOD), bins=bins, color='b', alpha=0.5, edgecolor='none', histtype='stepfilled', label='n = ' + str(len(GOOD)), weights=np.repeat(normnum, len(GOOD)))
        ax1.hist(np.log10(NOISE), bins=bins, color='g', alpha=0.5, edgecolor='none', histtype='stepfilled', label='n = ' + str(len(NOISE)), weights=np.repeat(normnum, len(NOISE)))
        ymax = 1.1 * panel_heights[i]
        ax1.set_ylim(0, ymax)
        ax1.set_xlim(xmin, xmax)
        # Booleans, not 'on'/'off' strings: modern matplotlib no longer accepts the strings.
        ax1.tick_params(axis='y', which='both', left=True, right=False, labelsize=4, width=0.8)
        ax1.set_yticks(np.around(np.linspace(0, ymax, 4), decimals=6), minor=False)
        ax1.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2e'))
        ax1.set_xscale('log')
        ax1.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=False, width=0.8)
        if i == numclusts - 1 :
            # Only the last panel carries x tick labels and the axis label.
            ax1.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True, labelsize=4, width=0.8)
            ax1.xaxis.set_minor_formatter(mtick.FormatStrFormatter('%.1f'))
            ax1.set_xlabel(r'$\log_{10}(D^2)$', fontsize=6)
            ax1.xaxis.set_label_coords(0.5, -0.05)
        ax1.set_title('Nrn' + str(c), fontsize=8)
        handles, labels = ax1.get_legend_handles_labels()
        legend = ax1.legend(handles, labels, loc=[0.1, 0.5], prop={'size': 4})
        # '%.0f' rounds half-to-even like np.around and prints 'nan' for an undefined Id.
        legend.set_title(title='Lr = ' + str(np.around(Lr[c], decimals=2)) + '\nId = ' + '%.0f' % Id[c] + '\n', prop={'size': 6})
        legend.get_frame().set_linewidth(0.8)
        legend.get_frame().set_edgecolor('red')
        ax1.spines['bottom'].set_linewidth(0.8)
        ax1.spines['bottom'].set_color('black')
        ax1.spines['left'].set_linewidth(0.8)
        ax1.spines['left'].set_color('black')
        ax1.spines['right'].set_linewidth(0)
        ax1.spines['top'].set_linewidth(0)

    fig.suptitle('Cluster Statistics for ' + tag, fontsize=12)
    fig.savefig(pdf_files_directory + tag + '_ClusterMetrics.pdf', format='pdf')
    plt.close(fig) # free the figure; this runs once per kwik file in a loop

    # ---- Figure 2: L-ratio vs isolation distance with marginal box plots ----------
    fig = plt.figure()
    left, width = 0.1, 0.65
    bottom, panel_height = 0.1, 0.65
    bottom_h = left_h = left + width + 0.02

    axScatter = fig.add_axes([left, bottom, width, panel_height])
    axBoxx = fig.add_axes([left, bottom_h, width, 0.2])
    axBoxy = fig.add_axes([left_h, bottom, 0.2, panel_height])
    axScatter.spines['bottom'].set_linewidth(0.8)
    axScatter.spines['bottom'].set_color('black')
    axScatter.spines['left'].set_linewidth(0.8)
    axScatter.spines['left'].set_color('black')
    axScatter.spines['right'].set_linewidth(0)
    axScatter.spines['top'].set_linewidth(0)
    axScatter.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True, labelsize=8, width=0.8)
    axScatter.tick_params(axis='y', which='both', left=True, right=False, labelleft=True, labelsize=8, width=0.8)
    axScatter.set_xlabel(r'$L_{ratio}$', fontsize=12)
    axScatter.set_ylabel(r'$I_{dist}$', fontsize=12)
    for ax in (axBoxx, axBoxy) :
        for side in ('bottom', 'left', 'right', 'top') :
            ax.spines[side].set_linewidth(0)
        ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)

    Lratio = np.array([Lr[c] for c in goodclusts], dtype=float)
    Idist = np.array([Id[c] for c in goodclusts], dtype=float)
    names = list(goodclusts)

    # nanmax so a cluster with an undefined metric does not break the axis limits.
    xlimit = np.nanmax(np.fabs(Lratio)) * 1.1
    ylimit = np.nanmax(np.fabs(Idist)) * 1.1
    axScatter.set_xlim((0, xlimit))
    axScatter.set_ylim((0, ylimit))
    axBoxx.set_xlim(axScatter.get_xlim())
    axBoxy.set_ylim(axScatter.get_ylim())

    axScatter.scatter(Lratio, Idist)
    for i, txt in enumerate(names) :
        axScatter.annotate(txt, (Lratio[i], Idist[i]))

    axBoxx.boxplot(Lratio, notch=False, sym='rs', vert=False)
    axBoxy.boxplot(Idist)

    fig.suptitle('Summary Statistics for ' + tag, fontsize=12)
    fig.savefig(pdf_files_directory + tag + '_MetricSummary.pdf', format='pdf')
    plt.close(fig)

def getkwiks(startdir='.') :
    """Recursively collect the paths of all files ending in '.kwik' under ``startdir``.

    Each path is built by joining the walked directory onto the file name.
    Paths are returned in the order ``os.walk`` visits them.
    """
    return [
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(startdir)
        for name in names
        if name.endswith('.kwik')
    ]

In [None]:
# Batch-compute cluster-quality metrics (L-ratio, isolation distance) for every
# .kwik file below the current directory and save the summary PDFs.
#
# Configuration:
#   method     -- 'kkfeatures': reuse the features KlustaKwik clustered on (getfeatures);
#                 'waveforms' : extract waveforms and compute per-channel PCA here.
#   numpcs     -- principal components per electrode used as features.
#   output_dir -- prefix for the saved PDF files.
method = 'kkfeatures'
numpcs = 3
output_dir = './'

kwiks = getkwiks(startdir='.')

for kwik in kwiks :
    if method == 'kkfeatures' :
        pcs, clusters, cluster_groups, numelectrodes = getfeatures(kwik, numpcs, False)
        waveformdata = pcs
        E = pcs[:,0,:] # not read: CalcMahalDist is called with enorm=False below
    elif method == 'waveforms' :
        waveformdata, clusters, cluster_groups = getwaveforms(kwik, False)
        waveformdata, E = EnergyNorm(waveformdata, enorm=False)
        pcs = PCA(waveformdata, numpcs=numpcs)
        numelectrodes = waveformdata.shape[2] # one feature block column per channel
    else :
        raise ValueError("method must be 'kkfeatures' or 'waveforms', got %r" % (method,))

    exp, meas, shank = getexpinfo(kwik)
    D, goodclusts, featurevecs, muC, covC = CalcMahalDist(waveformdata, E, pcs, clusters, cluster_groups, pcsinc=numpcs, enorm=False, numelectrodes=numelectrodes)
    # Chi-square degrees of freedom = dimensionality of the feature space used for D.
    Lr = computeLratio(D, clusters, goodclusts, df=featurevecs.shape[1])
    Id = computeIsolationDistance(D, clusters, goodclusts)
    # ClusterSummary reads the notebook-global `clusters` assigned above.
    ClusterSummary(D, Lr, Id, goodclusts, output_dir, exp, meas, shank)