## STEP 1: LOAD DEPENDENCIES, SET GLOBAL PARAMETERS, DEFINE UTILITY FUNCTIONS

In [1]:
# Import Libraries and Library Settings
import numpy as np
import scipy.io as sio
import scipy.signal as signal
import scipy.stats as stats
from os.path import dirname, join as pjoin
import time
import matplotlib.pyplot as plt
%matplotlib tk
from matplotlib.animation import FuncAnimation
# import glob
import os
# import re
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.widgets import Button
from tkinter import simpledialog
import tkinter as tk
from sklearn.mixture import GaussianMixture

# Notebook-wide display settings (plotting & troubleshooting)
plt.rcParams.update({'animation.html': 'jshtml'})  # show animations as an embedded JS player
np.set_printoptions(threshold=np.inf)  # print arrays in full instead of summarizing with '...'

# Define Functions in Use

# If you have a channel's row and col, return the channel's index
def map2idx(ch_row,ch_col):
    """Convert a channel's (row, col) on the 32x32 array to its flat index.

    The index is row-major: ``row*32 + col``.

    Parameters
    ----------
    ch_row, ch_col : int
        Row and column of the channel, each in 0..31.

    Returns
    -------
    int
        Channel index in 0..1023.

    Raises
    ------
    ValueError
        If the row or column is outside 0..31. (Previously the function
        printed a message and then failed with UnboundLocalError.)
    """
    if ch_row > 31 or ch_row < 0:
        raise ValueError('Row out of range')
    if ch_col > 31 or ch_col < 0:
        raise ValueError('Col out of range')
    return int(ch_row*32 + ch_col)

# If you have a channel index, return the channel's row and col
def idx2map(ch_idx):
    """Convert a flat channel index to its (row, col) on the 32x32 array.

    Inverse of the row-major mapping ``idx = row*32 + col``.

    Parameters
    ----------
    ch_idx : int or float
        Channel index in 0..1023. Float indices (e.g. entries of ``chId``,
        which is a float array) are truncated to int first.

    Returns
    -------
    tuple of (int, int)
        ``(row, col)``.

    Raises
    ------
    ValueError
        If the index is outside 0..1023. (Previously the function printed a
        message and then failed with UnboundLocalError.)
    """
    if ch_idx > 1023 or ch_idx < 0:
        raise ValueError('Chan num out of range')
    ch_row, ch_col = divmod(int(ch_idx), 32)
    return (ch_row, ch_col)

def _first_of_runs(idx):
    """Return only the first index of each run of consecutive integers in sorted ``idx``."""
    idx = np.asarray(idx)
    if idx.size == 0:
        return idx
    keep = np.empty(idx.size, dtype=bool)
    keep[0] = True
    keep[1:] = np.diff(idx) != 1
    return idx[keep]


def newSpikeThres(newThres):
    """Recompute every channel's spike threshold and spike-onset indices.

    Reuses the per-channel Gaussian-mixture fit from Step 8 (``mean_hat``,
    ``std_hat``, ``noiseIdx``) and sets each threshold to
    ``noise_mean - newThres * noise_std``. Results are written to the globals
    ``spikeThres`` and ``spanIdx``; ``spikeDev`` is set to ``newThres``.

    If ``hiLITEspikes`` is False, a Tk dialog asks whether to proceed
    (1 = yes, 0 = no); the fit arrays only exist if Step 8 has been run.

    Parameters
    ----------
    newThres : float
        Number of noise standard deviations below the noise mean at which a
        sample is counted as a spike.
    """
    global spikeDev
    global spikeThres
    global spanIdx
    spikeDev = newThres

    answer = '1'
    if not hiLITEspikes:
        reply = simpledialog.askstring(title='Spike Highlighting Off',prompt='Gaussian Mixture (Step 8) Needs to be Run First. Proceed? (1=Yes, 0=No)')
        # askstring returns the typed text (None on cancel), never an int, so
        # compare against strings; the old int comparison could never match.
        answer = reply.strip() if reply is not None else None

    if answer == '0':
        return
    if answer != '1':
        print('Not acceptable answer. Type 1 or 0')
        return

    for k in range(0,numChan):
        start = time.time()
        row, col = chMap[0,k], chMap[1,k]
        noise = int(noiseIdx[row,col])
        # Determine the negative threshold for a spike
        spikeThres[row,col] = mean_hat[row,col,noise] - spikeDev*std_hat[row,col,noise]
        # Locate all sub-threshold samples, then keep one onset per spike
        below = np.where(dataFilt[row,col,:] < spikeThres[row,col])[0]
        spanIdx[k] = _first_of_runs(below)

        end = time.time()
        text1 = 'Estimated Time Remaining: ' + str.format('{0:.2f}', (end-start)*(numChan-k)/60 ) + ' min'
        text2 = str(k+1)+'/'+str(numChan)+' Channels Filtered'
        print( text1 + ' ' + text2, end="\r" )

TO DOs:
- OPTION TO PLOT THE CHANNELS WITH THE MOST SPIKES FIRST
- GENERATE A LIST OF WHICH FIGURE NUMBER EACH CHANNEL IS ON (spikeCntIdx)
- PARALLELIZE THE CHANNEL FILTERING AND GAUSSIAN MIXTURE CODE (OR FIND A FASTER FILTER)
- MAKE THE X-AXIS LEGIBLE ON ANY DISPLAY (NO NUMBERS SQUISHED TOGETHER)
- DON'T HIGHLIGHT PHYSICALLY IMPOSSIBLE EVENTS AS SPIKES (preferably using a distribution-based measure)

## STEP 2: INITIALIZE VARIABLES FOR LOADING AND CONVERTING DATA

In [7]:
# Recording selection used by Step 4: buffer files are read from the relative
# directory <date_piece>/<datarun>/ and named <datarun>_<k>.mat.
# NOTE(review): `path` is an absolute, machine-specific location that no cell
# shown here references; Step 4 resolves date_piece/datarun relative to the
# notebook's working directory.
path = '/Users/vision/Desktop/Xilinx'
date_piece = '2022-02-18-1'  # recording folder
datarun = 'data000'  # run sub-folder and buffer-file prefix
buffer_num = 0  # index k of the first buffer file Step 4 loads

string = 'gmem1' # Name of data variable in the mat files

## STEP 3: INITIALIZE BUTTON FUNCTIONS

In [3]:
def callback_next_btn1(event):
    """'Reset' button: return every trace in the clicked figure to its first window.

    Each trace's x-range becomes ``[t0, t0 + timeWin]`` where ``t0`` is the
    time of that channel's first recorded sample (``startIdx``); the y-range
    is forced to ``[ymi, yma]`` when ``manSety`` is True.
    """
    # Act on the figure whose button was clicked; plt.gcf() can be another
    # figure when several channel figures are open.
    canvas = getattr(event, 'canvas', None)
    fig = canvas.figure if canvas is not None else plt.gcf()
    figNum = fig.number
    print('Reset button pressed')
    # The first nPlots axes are traces; the rest hold buttons. The last page
    # may have fewer than plotPerFig traces.
    nPlots = max(0, min(plotPerFig, numChan - figNum*plotPerFig))
    for m, ax in enumerate(fig.get_axes()[:nPlots]):
        t0 = times[int(startIdx[figNum*plotPerFig+m])]
        # Set xlim to the time window of the first recorded values
        ax.set_xlim([t0, t0+int(timeWin)])
        if manSety == True:
            # Target this subplot explicitly (plt.ylim hit the current axes)
            ax.set_ylim([ymi,yma])
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
    fig.canvas.draw_idle()
        
def callback_next_btn2(event):
    """'Next' button: shift every trace in the clicked figure forward by ``timeWin`` ms.

    The y-range is forced to ``[ymi, yma]`` when ``manSety`` is True.
    """
    # Act on the figure whose button was clicked, not whatever plt.gcf() is.
    canvas = getattr(event, 'canvas', None)
    fig = canvas.figure if canvas is not None else plt.gcf()
    print('Next button pressed')
    # Only the first nPlots axes are traces; the remaining ones are buttons.
    nPlots = max(0, min(plotPerFig, numChan - fig.number*plotPerFig))
    for ax in fig.get_axes()[:nPlots]:
        xLeft, xRight = ax.get_xlim()
        ax.set_xlim([xLeft+timeWin,xRight+timeWin])
        if manSety == True:
            ax.set_ylim([ymi,yma])
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
    fig.canvas.draw_idle()
        
def callback_next_btn3(event):
    """'Back' button: shift every trace in the clicked figure back by ``timeWin`` ms.

    The y-range is forced to ``[ymi, yma]`` when ``manSety`` is True.
    """
    # Act on the figure whose button was clicked, not whatever plt.gcf() is.
    canvas = getattr(event, 'canvas', None)
    fig = canvas.figure if canvas is not None else plt.gcf()
    print('Back button pressed')
    # Only the first nPlots axes are traces; the remaining ones are buttons.
    nPlots = max(0, min(plotPerFig, numChan - fig.number*plotPerFig))
    for ax in fig.get_axes()[:nPlots]:
        xLeft, xRight = ax.get_xlim()
        ax.set_xlim([xLeft-timeWin,xRight-timeWin])
        if manSety == True:
            ax.set_ylim([ymi,yma])
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
    fig.canvas.draw_idle()
        
def callback_next_btn4(event):
    """'Time Window' button: ask for a new window width (ms) and apply it.

    Each trace keeps its current left edge and gets right edge
    ``left + timeWin``. Cancelling the dialog, or entering anything other
    than a positive whole number, leaves ``timeWin`` and the plots unchanged.
    """
    global timeWin
    # Act on the figure whose button was clicked, not whatever plt.gcf() is.
    canvas = getattr(event, 'canvas', None)
    fig = canvas.figure if canvas is not None else plt.gcf()
    print('Time Window button pressed')
    timeStr = simpledialog.askstring(title='Time Change',prompt='Enter desired time window (ms):')
    try:
        newWin = int(timeStr)
    except (TypeError, ValueError):  # cancelled (None) or not an integer
        newWin = 0
    if newWin <= 0:
        print('Time window unchanged: enter a positive whole number of ms')
        return
    timeWin = newWin

    # Only the first nPlots axes are traces; the remaining ones are buttons.
    nPlots = max(0, min(plotPerFig, numChan - fig.number*plotPerFig))
    for ax in fig.get_axes()[:nPlots]:
        xLeft, xRight = ax.get_xlim()
        ax.set_xlim([xLeft,xLeft+timeWin])
        if manSety == True:
            ax.set_ylim([ymi,yma])
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
    fig.canvas.draw_idle()
        
def callback_next_btn5(event):
    """'Next Fig' button: draw another page of channel traces.

    Asks (Tk dialogs) for the page number ``gk`` and whether to close the
    figure whose button was clicked. It then builds figure ``gk``:
    - up to ``plotPerFig`` subplots for channels ``gk*plotPerFig`` onward;
    - optional yellow spike highlights;
    - the same navigation buttons as Step 9.

    Buttons are stored in ``butTrack`` so they are not garbage-collected.
    An out-of-range or non-numeric page number falls back to page 0; a
    cancelled page dialog does nothing.
    """
    global gk
    canvas = getattr(event, 'canvas', None)
    fig_cur = canvas.figure if canvas is not None else plt.gcf()
    print('Next Fig Pressed')
    figStr = simpledialog.askstring(title='Plot Next Figure',prompt='Enter Figure # To Display: 0 thru ' + str(numFigs-1))
    if figStr is None:  # dialog cancelled
        return
    try:
        gk = int(figStr)
    except ValueError:
        gk = -1  # handled by the out-of-range check below

    closeStr = simpledialog.askstring(title='Close Current Figure?',prompt='Enter 1 to close current figure and 0 to remain open')
    try:
        closeYN = int(closeStr)
    except (TypeError, ValueError):
        closeYN = None
    if closeYN == 1:
        plt.close(fig_cur)
    elif closeYN != 0:
        print('Only enter 1 or 0 to close or keep open current figure')

    # A page exists while its first channel exists (covers a partial last page)
    if gk < 0 or gk*plotPerFig >= numChan:
        gk = 0
        print('Requested Figure Out of Range')

    # Generate figure and define style. Clear it in case figure gk is already
    # open, so traces and buttons are not stacked a second time.
    fig = plt.figure(gk,figsize=(18,12),facecolor='white',constrained_layout=False)
    fig.clf()
    plt.style.use('fivethirtyeight')

    # Channels (positions in chMap) on this page
    firstCh = gk*plotPerFig
    lastCh = min(firstCh + plotPerFig, numChan) - 1

    # Figure Title
    title = 'Figure ' + str(gk) + ':'
    startElec = str(int(chId[firstCh]))

    # Generate subplots, stopping at the last channel on this page
    lastSam = len(times) - 1
    for m, ch in enumerate(range(firstCh, lastCh + 1), start=1):
        plt.subplot(numRows,numCols,m)
        plt.plot(times,dataFilt[chMap[0,ch],chMap[1,ch],:],linewidth=1,color='k')

        # Set xlim to the time window of the first recorded values
        t0 = times[int(startIdx[ch])]
        plt.xlim([t0, t0+int(timeWin)])

        if manSety == True:
            plt.ylim([ymi,yma])
        ax = plt.gca()
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
        text_elec = 'Ch # ' + str(int(chId[ch]))
        ax.annotate(text_elec,(0.6,0.9),xycoords='axes fraction',va='center',fontsize = 11,color='blue')

        if hiLITEspikes:
            # Shade 5 samples before to 15 after each spike onset, clipped to
            # the recording. Entries may be None for channels Step 8 skipped.
            spanStarts = spanIdx[ch]
            if spanStarts is not None:
                for s in spanStarts:
                    plt.axvspan(times[max(s-5, 0)], times[min(s+15, lastSam)], ec = 'white', color='yellow',alpha=0.75)

    # Adjust the subplots to take the entire screen
    plt.subplots_adjust(left=0.03,right=0.978,bottom=0.05,top=0.95,wspace=0.25,hspace=0.2)

    # Figure Title and axis labels (labels on the last trace, title on the first)
    stopElec = str(int(chId[lastCh]))
    figTitle = title + ' Ch ' + startElec + ' to Ch ' + stopElec
    axs = fig.get_axes()
    ax = axs[-1]
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('ADC')
    ax.yaxis.set_label_position("right")
    ax = axs[0]
    ax.set_title(figTitle,fontsize=16,ha='center',va='bottom',fontweight='bold')

    plt.xlabel('Time (ms)')
    plt.ylabel('ADC',loc='bottom')

    # Add Buttons: (axes rect, label, color, hovercolor, callback); butTrack slot = list position
    buttonSpecs = [
        ([0.7, 0.955, 0.05, 0.04], "Reset", 'pink', 'red', callback_next_btn1),
        ([0.9, 0.955, 0.05, 0.04], "Next", 'lightblue', 'blue', callback_next_btn2),
        ([0.85, 0.955, 0.05, 0.04], "Back", 'lightblue', 'blue', callback_next_btn3),
        ([0.5, 0.955, 0.1, 0.04], "Alt Time Window", 'lightgreen', 'green', callback_next_btn4),
        ([0.25, 0.955, 0.05, 0.04], "Next Fig", 'gold', 'goldenrod', callback_next_btn5),
        ([0.45, 0.955, 0.05, 0.04], "Y Scale", 'lightgreen', 'green', callback_next_btn6),
    ]
    for slot, (rect, label, color, hover, callback) in enumerate(buttonSpecs):
        btn = Button(plt.axes(rect), label=label, color=color, hovercolor=hover)
        butTrack[slot, gk] = btn
        btn.on_clicked(callback)

    # Flush the plot out
    plt.show()
    
def callback_next_btn6(event):
    """'Y Scale' button: ask for new y-limits (ADC) and apply them to every trace.

    On valid input, sets the globals ``ymi``/``yma`` and ``manSety = True``,
    so later navigation keeps these limits. Cancelling either dialog, or
    entering a non-number, leaves all limits unchanged.
    """
    global ymi, yma, manSety
    # Act on the figure whose button was clicked, not whatever plt.gcf() is.
    canvas = getattr(event, 'canvas', None)
    fig = canvas.figure if canvas is not None else plt.gcf()
    print('Y Scale button pressed')
    ymiStr = simpledialog.askstring(title='Y Scale Change',prompt='Enter desired ymin (ADC):')
    ymaStr = simpledialog.askstring(title='Y Scale Change',prompt='Enter desired ymax (ADC):')
    try:
        newMin, newMax = float(ymiStr), float(ymaStr)
    except (TypeError, ValueError):  # cancelled (None) or not a number
        print('Y scale unchanged: enter numeric ymin and ymax')
        return
    ymi, yma = newMin, newMax
    manSety = True

    # Only the first nPlots axes are traces; the remaining ones are buttons.
    nPlots = max(0, min(plotPerFig, numChan - fig.number*plotPerFig))
    for ax in fig.get_axes()[:nPlots]:
        ax.set_ylim([ymi,yma])
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
    fig.canvas.draw_idle()

## STEP 4: LOAD AND CONVERT ALL THE DATA
Load the data from one of three sources by running only the matching cell below: the original .mat buffer files, a compressed file with a few spikes, or a compressed file with many spikes.

In [8]:
# Original Data
# Buffer files are <date_piece>/<datarun>/<datarun>_<k>.mat; each holds a
# vector of packed 32-bit words in the variable named by `string`.
fileDir = pjoin(date_piece,datarun)
# Count only this run's .mat buffers so unrelated directory entries
# (e.g. macOS .DS_Store) are not mistaken for buffers.
bufDir = [f for f in os.listdir(fileDir) if f.startswith(datarun + '_') and f.endswith('.mat')]
num_of_buf = len(bufDir)
bramdepth = 65536
SAMPLE_PERIOD_MS = 0.05  # 20 kHz sampling rate -> 0.05 ms between samples

# Initialize Variables
timeTrack = 0  # time (ms) at which the next buffer starts
cntTrack = 0   # number of samples stored in dataAll so far
dataAll = np.zeros((32,32,int(bramdepth*num_of_buf/2))) #Largest possible value of dataAll, perfect recording, only double cnt
cntAll = np.zeros((32,32,int(bramdepth*num_of_buf/2)))
times = np.zeros((int(bramdepth*num_of_buf/2)))

#Idea: plot in order of which buffers have the most 'spikes'

# Process all the Data
for k in range(buffer_num,num_of_buf):
    start = time.time()

    # Load Data from this Loop's Buffer
    file_next = pjoin(date_piece,datarun,datarun+'_'+str(k)+'.mat')
    mat_contents = sio.loadmat(file_next)
    dataRaw = mat_contents[string][0][:]

    # Initialize Variables Needed for Each Buffer
    chan_index_pre = 1025 # not a real channel index (max 1023), so it never matches
    cnt_pre = 0 #Check for cnt changes, double cnt
    N = 0 #Sample times (DOES NOT ALLOW NON-COLLISION FREE SAMPLES)
    data_real = np.zeros((32,32,len(dataRaw)-2))  #Initialize to max possible length. Note: Throw out first two values b/c garbo
    cnt_real = np.zeros((32,32,len(dataRaw)-2))

    # Convert data and remove double/triple counts.
    # NOTE(review): the loop stops before the buffer's last word, so that word
    # is never decoded - confirm this is intentional.
    for i in range(2,len(dataRaw)-1):
        # Decode with bit operations (bit 31 = MSB); each field equals the
        # noted character slice of np.binary_repr(word, 32), without building strings.
        word = int(dataRaw[i])
        cnt = (word >> 18) & 0x3   # chars [12:14]
        row = (word >> 5) & 0x1F   # chars [22:27]
        col = word & 0x1F          # chars [27:32]
        chan_index = row*32 + col

        # Only record the unique non-double count sample
        if i == 2 or cnt_pre != cnt or chan_index != chan_index_pre:
            # Sample time only changes when cnt changes
            if cnt != cnt_pre:
                N += 1
            # On the occurrence the first cnt is not 0, make sure sample time is 0
            if i == 2:
                N = 0
            cnt_pre = cnt
            chan_index_pre = chan_index
            data_real[row, col, N] = (word >> 10) & 0xFF  # chars [14:22]: sample value
            cnt_real[row, col, N] = cnt

    # Append samples 0..N-1 right after the previous buffer. Sample N, the one
    # still being filled when the buffer ended (presumably incomplete), is not
    # copied. Delays between buffers are ignored, so times are an estimate.
    dataAll[:,:,cntTrack:cntTrack+N] = data_real[:,:,:N]
    cntAll[:,:,cntTrack:cntTrack+N] = cnt_real[:,:,:N]
    # Exactly one sample period apart. The former linspace over N points
    # spaced later buffers at N*0.05/(N-1) ms.
    times[cntTrack:cntTrack+N] = timeTrack + np.arange(N)*SAMPLE_PERIOD_MS

    # Update for the next buffer file
    timeTrack += N*SAMPLE_PERIOD_MS
    cntTrack += N

    end = time.time()
    text1 = 'Estimated Time Remaining: ' + str.format('{0:.2f}', (end-start)*(num_of_buf-k)/60 ) + ' min'
    text2 = file_next
    print( text1 + ' ' + text2, end="\r" )

# Truncate dataAll, cntAll, and times to the cntTrack samples actually written
# (the remaining zeros are unused preallocated space)
dataAll = dataAll[:,:,:cntTrack]
cntAll = cntAll[:,:,:cntTrack]
times = times[:cntTrack]
        

Estimated Time Remaining: 0.01 min 2022-02-18-1/data000/data000_255.mat

In [28]:
# Data with Litke Spikes
# The with-block closes the .npz file once the arrays have been read out
# (np.load keeps an NpzFile open until it is closed).
with np.load("debugData/compressed_dataAll_2022-01-14-1_data001_litkeSpikes.npz") as dataAllLoad:
    dataAll = dataAllLoad['modified_data']
    dataAllSmall = dataAllLoad['small_spike_idx']
    dataAllLarge = dataAllLoad['large_spike_idx']
# 20 kHz sampling: sample n is at n*0.05 ms (same spacing as the original-data loader)
times = np.arange(np.shape(dataAll)[2])*0.05

In [6]:
# Data with Many Litke Spikes
# The with-block closes the .npz file once the arrays have been read out.
with np.load("compressed_dataAll_2022-01-14-1_data001_litkeSpikes_many.npz") as dataAllLoad:
    dataAll = dataAllLoad['modified_data']
    dataAllLarge = dataAllLoad['large_spike_idx']
# 20 kHz sampling: sample n is at n*0.05 ms (same spacing as the original-data loader)
times = np.arange(np.shape(dataAll)[2])*0.05

## STEP 5: IDENTIFY RELEVANT CHANNELS

In [9]:
# A channel counts as recorded when it has at least one nonzero sample.
# (Counting is the slow part: ~30 s for the whole array on a 4-channel recording.)
numSam = np.count_nonzero(dataAll, axis=2)  # (32, 32) nonzero-sample counts
numChan = np.count_nonzero(numSam)

# Coordinates of recorded channels: chMap[0] = rows, chMap[1] = cols
findCoors = np.nonzero(numSam)
chMap = np.array(findCoors)
chId = np.zeros((numChan))      # flat channel index (row*32 + col)
startIdx = np.zeros((numChan))  # position of each channel's first nonzero sample
for k, (row, col) in enumerate(chMap.T):
    start = time.time()
    chId[k] = map2idx(row, col)
    startIdx[k] = np.argmax(dataAll[row, col, :] != 0)
    end = time.time()
    remaining_min = (end - start) * (numChan - k) / 60
    text1 = 'Estimated Time Remaining: ' + '{0:.2f}'.format(remaining_min) + ' min'
    text2 = '{}/{} Channels Filtered'.format(k + 1, numChan)
    print(text1 + ' ' + text2, end="\r")

Estimated Time Remaining: 0.00 min 1024/1024 Channels Filtered

## STEP 6: FILTERING THE DATA

In [10]:
# CHOOSE FILTER SETTING
# Filter: set filtType to one of the names below (used by the APPLY FILTER cell)
filtType = 'modHierlemann'

# Options, with approximate run times:
# fastBandpass = filtfilt, passband[250,4000], order = 1 (~8 min)
# fasterBandpass = sosfiltfilt, passband[250,4000], order = 1 (~7.5 min)
# modHierlemann = filtfilt, FIR, [250,4000], 100 taps (~4 min)
# auto = filtfilt, pass[250,4000],stop[5,6000], maxPassLoss = 3, minStopLoss = 30, determines order and cutoff needed to reach this (~50 min)
# hObandpass = filtfilt, passband [250,4000], order = 5 (~40 min)
# Hierlemann = filtfilt,FIR highpass,[100],75 taps (~4 min)
# Litke = filtfilt, [250,2000],order = 2 (~18 min)
# highpass = filtfilt, [250], order = 5 (~26 min)
# none = no filtering of data (~0 min)



In [11]:
## APPLY FILTER
# Filters every recorded channel (chMap) with the filter named by filtType.
# Unrecorded channels stay zero. Every name is tested in a single if-chain.
# The old code started a second `if` chain after 'Hierlemann', so that filter
# fell through to the "not recognized" branch and was overwritten with raw data.
FS_HZ = 20e3 * 1.0    # sampling rate (Hz)
NYQ_HZ = 0.5 * FS_HZ  # Nyquist frequency, for filters designed with normalized cutoffs

def _make_channel_filter(filt_type):
    """Return (apply, order_msg) for filt_type.

    apply maps one channel's 1-D trace to its filtered trace (None for 'none'
    or an unrecognized name). order_msg is the line printed before filtering,
    or None.
    """
    if filt_type == 'Hierlemann':
        # One cutoff with pass_zero=False -> 75-tap FIR highpass at 100 Hz
        taps = signal.firwin(75, [100.0, ], pass_zero=False, fs=FS_HZ)
        return (lambda x: signal.filtfilt(taps, [1], x)), None
    if filt_type == 'modHierlemann':
        # 100-tap FIR bandpass, 250-4000 Hz
        taps = signal.firwin(100, [250.0, 4000.0], pass_zero=False, fs=FS_HZ)
        return (lambda x: signal.filtfilt(taps, [1], x)), None
    if filt_type == 'highpass':
        b, a = signal.butter(5, [250/NYQ_HZ], btype="highpass", analog=False)
        return (lambda x: signal.filtfilt(b, a, x)), None
    if filt_type == 'hObandpass':
        b, a = signal.butter(5, [250/NYQ_HZ, 4000/NYQ_HZ], btype="bandpass", analog=False)
        return (lambda x: signal.filtfilt(b, a, x)), 'Order = ' + str(5)
    if filt_type == 'auto':
        # buttord picks the lowest Butterworth order with <= 3 dB passband loss
        # and >= 30 dB stopband attenuation
        order, normal_cutoff = signal.buttord([250,4000], [5,6000], 3, 30, fs=FS_HZ)
        b, a = signal.butter(order, normal_cutoff, btype='bandpass', fs=FS_HZ)
        return (lambda x: signal.filtfilt(b, a, x)), 'Order = ' + str(order)
    if filt_type == 'fastBandpass':
        b, a = signal.butter(1, [250/NYQ_HZ, 4000/NYQ_HZ], btype='bandpass', analog=False)
        return (lambda x: signal.filtfilt(b, a, x)), 'Order = ' + str(1)
    if filt_type == 'fasterBandpass':
        sos = signal.butter(1, [250/NYQ_HZ, 4000/NYQ_HZ], btype='bandpass', output='sos')
        return (lambda x: signal.sosfiltfilt(sos, x)), 'Order = ' + str(1)
    if filt_type == 'Litke':
        b, a = signal.butter(2, [250/NYQ_HZ, 2000/NYQ_HZ], btype="bandpass", analog=False)
        return (lambda x: signal.filtfilt(b, a, x)), None
    return None, None

chanFilter, orderMsg = _make_channel_filter(filtType)
if chanFilter is None:
    # 'none' (or an unknown name): keep the data unfiltered
    dataFilt = np.copy(dataAll)
    if filtType != 'none':
        print('Filter not recognized. Options include Hierlemann, modHierlemann, highpass, hObandpass, auto, fastBandpass, fasterBandpass, Litke or none')
else:
    if orderMsg is not None:
        print(orderMsg)
    dataFilt = np.zeros((np.shape(dataAll)))
    for k in range(0,numChan):
        start = time.time()
        row, col = chMap[0,k], chMap[1,k]
        dataFilt[row,col,:] = chanFilter(dataAll[row,col,:])
        end = time.time()
        text1 = 'Estimated Time Remaining: ' + str.format('{0:.2f}', (end-start)*(numChan-k)/60 ) + ' min'
        text2 = str(k+1)+'/'+str(numChan)+' Channels Filtered'
        print( text1 + ' ' + text2, end="\r" )

Estimated Time Remaining: 0.00 min 1024/1024 Channels Filtered

## STEP 7: CUSTOMIZE PLOT SETTINGS

In [12]:
# Spike ID
timeWin = 20 #timeInterval (ms) shown per subplot; also the step size of the Next/Back buttons
# Threshold = noise mean - spikeDev * noise std (Step 8).
# NOTE: only the negative tail is thresholded, so for Gaussian noise the per-sample
# false-positive rate is P(Z < -spikeDev): ~1/2000 at 3.291 (3.291 is the two-sided
# 99.9% value); ~3.09 gives 1/1000, i.e. ~1 FP per 50 ms at 20 kHz.
spikeDev = 3.5 # 3.291 is 1/1000 False Positives (1 FP in ~50ms) - see NOTE above
hiLITEspikes = True # Do you want to highlight potential spikes?

# Plotting
plotPerFig = 36 # How many plots per fig?
numRows = 6 # How many rows in your plot? (Step 9 uses enough columns to fit plotPerFig)
manSety = True # Do you want to manually set the ylim of the plots?
ymi = -20 #If yes, ylim bottom
yma = 20 #If yes, ylim top

## STEP 8: IDENTIFY POTENTIAL SPIKES

In [13]:
if hiLITEspikes:

    # Per-channel 2-component Gaussian-mixture fit: means, std devs, weights
    mean_hat = np.zeros((32,32,2))
    std_hat = np.zeros((32,32,2))
    w_hat = np.zeros((32,32,2))
    noiseIdx = np.zeros((32,32))    # which component (0/1) is the noise
    spikeThres = np.zeros((32,32))  # negative spike threshold per channel
    # Spike-onset sample indices for each recorded channel (ordered like chMap).
    # Channels with no recorded samples keep an empty array (not None), so the
    # plotting cells can handle every entry.
    spanIdx = [np.array([], dtype=np.intp) for _ in range(len(chId))]

    # Loop through all channels
    for k in range(0,numChan):
        start = time.time()
        row, col = chMap[0,k], chMap[1,k]
        trace = dataFilt[row,col,:]
        # Find at what times a channel was recorded
        samInd = np.flatnonzero(np.abs(trace) > 0.00001)
        # Skip channels that were never recorded
        if samInd.size == 0:
            continue
        # All data from the first to the last recorded sample (inclusive)
        y = trace[samInd[0]:samInd[-1]+1]
        # Separate data into 2 gaussian distributions (do this in the case many spikes make a distribution, still works if very few spikes exist because of outlier noise forms a distribution).
        # random_state makes the fit reproducible across runs.
        gm = GaussianMixture(n_components=2, random_state=0).fit(y.reshape(-1, 1))
        mean_hat[row,col,:] = gm.means_.flatten()
        w_hat[row,col,:] = gm.weights_.flatten()
        std_hat[row,col,:] = np.sqrt(gm.covariances_).flatten()
        # The noise component is the one whose mean is closest to zero
        noise = int(np.argmin(np.abs(mean_hat[row,col,:])))
        noiseIdx[row,col] = noise
        # Determine the negative threshold for a spike
        spikeThres[row,col] = mean_hat[row,col,noise] - spikeDev*std_hat[row,col,noise]
        # Locate all sub-threshold samples; keep only the first of each
        # consecutive run, so one spike gives one onset
        below = np.flatnonzero(trace < spikeThres[row,col])
        if below.size:
            below = below[np.concatenate(([True], np.diff(below) != 1))]
        spanIdx[k] = below

        end = time.time()
        text1 = 'Estimated Time Remaining: ' + str.format('{0:.2f}', (end-start)*(numChan-k)/60 ) + ' min'
        text2 = str(k+1)+'/'+str(numChan)+' Channels Filtered'
        print( text1 + ' ' + text2, end="\r" )

    # Report Channels with Most Spikes & Their Figure Numbers
    # spikeCnt = np.zeros((400))
    # for i in range(0,len(spanIdx)):
    #     spikeCnt[i] = len(spanIdx[i])
    # spikeCntIdx = np.argsort(spikeCnt)
    # spikeCntIdx = spikeCntIdx[::-1]
    # print(chId[spikeCntIdx])



Estimated Time Remaining: 0.00 min 1024/1024 Channels Filtered

## Want to change the spike threshold?

In [16]:
# Re-threshold all channels using the existing Step 8 fit (no refitting):
# newSpikeThres updates the globals spikeDev, spikeThres and spanIdx.
newThres = 3.5  # noise standard deviations below the noise mean
newSpikeThres(newThres)

Estimated Time Remaining: 0.00 min 1024/1024 Channels Filtered

## STEP 9: PLOTTING ALL THOSE TRACES

In [17]:
# Make the correct number of figures
# Pages of plotPerFig channels, rounded up so the final, partially filled page
# still gets a figure (and is offered by the 'Next Fig' button).
numFigs = -(-numChan // plotPerFig)
numCols = -(-plotPerFig // numRows)  # enough columns to hold plotPerFig subplots
butTrack = {}  # (button slot, figure #) -> Button; keeping references keeps the buttons alive

# Prep for Plotting
plt.close('all')
plt.ioff()
gk = 0

# Generate figure and define style
fig = plt.figure(gk,figsize=(18,12),facecolor='white',constrained_layout=False)
plt.style.use('fivethirtyeight')

# Channels (positions in chMap) on this page
firstCh = gk*plotPerFig
lastCh = min(firstCh + plotPerFig, numChan) - 1

# Figure Title
title = 'Figure ' + str(gk) + ':'
startElec = str(int(chId[firstCh]))

# Generate subplots, stopping at the last channel so recordings with fewer
# than plotPerFig channels do not index past chMap
lastSam = len(times) - 1
for m, ch in enumerate(range(firstCh, lastCh + 1), start=1):
    plt.subplot(numRows,numCols,m)

    # Plot time & amplitude information
    plt.plot(times,dataFilt[chMap[0,ch],chMap[1,ch],:],linewidth=1,color='k')
    t0 = times[int(startIdx[ch])]
    plt.xlim([t0, t0+int(timeWin)])
    # Aesthetics
    if manSety == True:
        plt.ylim([ymi,yma])
    ax = plt.gca()
    ax.ticklabel_format(useOffset=False)
    ax.relim()
    ax.autoscale_view()
    text_elec = 'Ch # ' + str(int(chId[ch]))
    ax.annotate(text_elec,(0.6,0.9),xycoords='axes fraction',va='center',fontsize = 11,color='blue')

    if hiLITEspikes:
        # Shade 5 samples before to 15 after each spike onset, clipped to the
        # recording. Entries may be None for channels Step 8 skipped.
        spanStarts = spanIdx[ch]
        if spanStarts is not None:
            for s in spanStarts:
                plt.axvspan(times[max(s-5, 0)], times[min(s+15, lastSam)], ec = 'white', color='yellow',alpha=0.75)

# Adjust the subplots to take the entire screen
plt.subplots_adjust(left=0.03,right=0.978,bottom=0.05,top=0.95,wspace=0.25,hspace=0.2)

# Figure Title and axis labels (labels on the last trace, title on the first)
stopElec = str(int(chId[lastCh]))
figTitle = title + ' Ch ' + startElec + ' to Ch ' + stopElec
axs = fig.get_axes()
ax = axs[-1]
ax.set_xlabel('Time (ms)')
ax.set_ylabel('ADC')
ax.yaxis.set_label_position("right")
ax = axs[0]
ax.set_title(figTitle,fontsize=16,ha='center',va='bottom',fontweight='bold')

# Add Buttons: (axes rect, label, color, hovercolor, callback); butTrack slot = list position
buttonSpecs = [
    ([0.7, 0.955, 0.05, 0.04], "Reset", 'pink', 'red', callback_next_btn1),
    ([0.9, 0.955, 0.05, 0.04], "Next", 'lightblue', 'blue', callback_next_btn2),
    ([0.85, 0.955, 0.05, 0.04], "Back", 'lightblue', 'blue', callback_next_btn3),
    ([0.5, 0.955, 0.1, 0.04], "Time Window", 'lightgreen', 'green', callback_next_btn4),
    ([0.25, 0.955, 0.05, 0.04], "Next Fig", 'gold', 'goldenrod', callback_next_btn5),
    ([0.45, 0.955, 0.05, 0.04], "Y Scale", 'lightgreen', 'green', callback_next_btn6),
]
for slot, (rect, label, color, hover, callback) in enumerate(buttonSpecs):
    btn = Button(plt.axes(rect), label=label, color=color, hovercolor=hover)
    butTrack[slot, gk] = btn
    btn.on_clicked(callback)

# Flush the plot out
plt.show()

## OPTIONAL: Frequency Investigation FFT of Raw or Filtered Data

In [None]:
plt.close('all')
from scipy.fft import fft, fftfreq

# Channels (positions in chMap) to show; the 5x2 grid holds up to 10
FFT_CHANNELS = range(0, min(10, numChan))
dt = 1/(20e3)  # sample period (s) at 20 kHz

fig = plt.figure(0,figsize=(18,12),facecolor='white',constrained_layout=False)
plt.style.use('fivethirtyeight')

for plotNum, k in enumerate(FFT_CHANNELS, start=1):
    plt.subplot(5,2,plotNum)
    # Swap dataFilt for dataAll here to inspect the unfiltered data instead
    trace = dataFilt[chMap[0,k],chMap[1,k],:]
    # Use the span from the first to the last recorded (nonzero) sample
    samInd = np.flatnonzero(np.abs(trace) > 0.00001)
    if samInd.size == 0:
        continue  # channel never recorded: leave its subplot empty
    y = trace[samInd[0]:samInd[-1]+1]

    Nsam = len(y)
    yf = fft(y)
    xf = fftfreq(Nsam,dt)[:Nsam//2]

    # Single-sided amplitude spectrum
    plt.plot(xf,2.0/Nsam*np.abs(yf[0:Nsam//2]),linewidth=1,color='r')
    plt.grid()

    ax = plt.gca()
    text_elec = 'Ch # ' + str(int(map2idx(chMap[0,k],chMap[1,k])))
    ax.annotate(text_elec,(0.6,0.9),xycoords='axes fraction',va='center',fontsize = 11,color='blue')

    plt.xlabel('Frequency (Hz)')
    plt.xlim([0,6000])
    plt.ylim([0,1])

# NOTE: the title is fixed text; edit it to match the data loaded in Step 4
plt.suptitle('After Filter with Litke Spikes')
plt.show()
        

## OPTIONAL: Amplitude Investigation

In [13]:
# Plot histogram of amplitudes from the timepoints where recording took place,
# overlaid with the two fitted Gaussian components.
# Adjust plot number if needed
plt.close(1)
fig = plt.figure(1,figsize=(18,12),facecolor='white',constrained_layout=False)
plt.style.use('fivethirtyeight')

# Channels (positions in chMap) to show; the 5x2 grid holds 10 at a time
AMP_CHANNELS = range(10, min(20, numChan))
Z_FP_1000 = 3.291  # two-sided 99.9% z-score

for plotNum, k in enumerate(AMP_CHANNELS, start=1):
    plt.subplot(5,2,plotNum)

    # Filtered Data: first to last recorded (nonzero) sample, inclusive
    trace = dataFilt[chMap[0,k],chMap[1,k],:]
    samInd = np.flatnonzero(np.abs(trace) > 0.00001)
    if samInd.size == 0:
        continue  # channel never recorded: leave its subplot empty
    y = trace[samInd[0]:samInd[-1]+1]
    plt.hist(x=y, bins='auto', density=True, alpha=0.7,rwidth=0.85)

    # Use local names so the Step 8 arrays std_hat / noiseIdx are not overwritten
    # (the old code replaced them, breaking newSpikeThres and later plotting)
    gm = GaussianMixture(n_components=2, random_state=0).fit(y.reshape(-1, 1))
    chMeans = gm.means_.flatten()
    chStds = np.sqrt(gm.covariances_).flatten()

    plt.grid(axis='y',alpha=0.75)
    plt.xlabel('Amplitude')
    plt.ylabel('Probability density')  # density=True histogram, not raw counts
    plt.xlim([-20,20])
    plt.ylim([0,0.2])

    xmin, xmax = plt.xlim()
    xs = np.linspace(xmin,xmax,1000)
    plt.plot(xs, stats.norm.pdf(xs, chMeans[0], chStds[0]), 'k', linewidth=2)
    plt.plot(xs, stats.norm.pdf(xs, chMeans[1], chStds[1]), 'r', linewidth=2)

    # The noise component is the one whose mean is closest to 0
    chNoise = np.argmin(np.abs(chMeans))
    print(Z_FP_1000*chStds[chNoise])
    print(chMeans[chNoise]-Z_FP_1000*chStds[chNoise])

    ax = plt.gca()
    text_elec = 'Ch # ' + str(int(map2idx(chMap[0,k],chMap[1,k])))
    ax.annotate(text_elec,(0.6,0.9),xycoords='axes fraction',va='center',fontsize = 11,color='blue')

plt.suptitle('After Filter with Litke Spikes')
plt.show()

8.76700477196227
-8.041724002097506
8.238227371776336
-7.534153395506474
8.504167505548022
-7.808684790032875
7.4221301251524725
-6.696823618922469
7.512566602890332
-6.794273902665952
7.148297711550928
-6.43894621358195
7.493231416296627
-6.82060286924475
7.4192398954901275
-6.676485756775496
7.539597175894101
-6.805238700857252
7.276774180522459
-6.567904536709561


## DEPRECATED: FURTHER FREQUENCY INVESTIGATION

In [None]:
# Further Frequency Investigation comparing raw data and litke data 
# yf1 = raw data yf2 = litke spikes
# Note: I took the difference of FFTs but since litke and raw have different number of points, its approximated. Could make exact by making 0's in the raw data where the litke spikes are
# yf2 = []
# xf2 = []

# for k in range(20,30):

#     samInd = np.asarray(np.where(abs(dataAll2[chMap[0,k],chMap[1,k],:])>0.00001))[0,:]
#     y = dataAll2[chMap[0,k],chMap[1,k],samInd[0]:samInd[-1]]
    
#     Nsam = len(y)
#     dt = 1/(20e3)
#     x = np.linspace(0.0,Nsam*dt,Nsam,endpoint=False)
    
#     yf = fft(y)
#     yf2.append(yf)
#     xf = fftfreq(Nsam,dt)[:Nsam//2]
#     xf2.append(xf)
    
# fig = plt.figure(3,figsize=(18,12),facecolor='white',constrained_layout=False)
# plt.style.use('fivethirtyeight')
# start = 0

# for k in range(0,10):
#     while start == 0:
#         kstart = k
#         start +=1
#     plt.subplot(5,2,k+1)
    
#     lenDif = np.asarray(np.shape(yf2[k]))-np.asarray(np.shape(yf1[k]))
#     print(lenDif)
#     yf1_len = np.append(yf1[k],np.zeros((lenDif)))
    
#     yf = abs(yf2[k][:]-yf1_len[:])
#     Nsam = len(yf)

#     plt.plot(xf2[k][:],2.0/Nsam*np.abs(yf[0:Nsam//2]),linewidth=1,color='r')
             
             
#     plt.grid()
    
#     ax = plt.gca()
#     text_elec = 'Ch # ' + str(int(map2idx(chMap[0,k+20],chMap[1,k+20])))
#     ax.annotate(text_elec,(0.6,0.9),xycoords='axes fraction',va='center',fontsize = 11,color='blue')


#     plt.xlabel('Frequency (Hz)')
#     plt.xlim([0,10000])
#     plt.ylim([0,1])
    
# plt.suptitle('Approximate Difference of FFTs')  
# plt.show()