## STEP 1: LOAD DEPENDENCIES, SET GLOBAL PARAMETERS, DEFINE UTILITY FUNCTIONS

In [1]:
# Import Libraries and Library Settings
import numpy as np
import scipy.io as sio
import scipy.signal as signal
from os.path import dirname, join as pjoin
import time
from datetime import datetime
import matplotlib.pyplot as plt
%matplotlib tk
from matplotlib.animation import FuncAnimation
import glob
import os
import re
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.widgets import Button
from tkinter import simpledialog
import tkinter as tk
from sklearn.mixture import GaussianMixture

# Setting Params for plotting & troubleshooting
np.set_printoptions(threshold=np.inf)
plt.rcParams['animation.html']='jshtml'

# Define Functions in Use

# If you have a channel's row and col, return the channel's index
def map2idx(ch_row,ch_col):
    if ch_row > 31 or ch_row <0:
        print('Row out of range')
    elif ch_col >31 or ch_col<0:
        print('Col out of range')
    else:
        ch_idx = int(ch_row*32 + ch_col)
    return ch_idx

# If you have a channel index, return the channel's row and col
def idx2map(ch_idx):
    if ch_idx > 1023 or ch_idx < 0:
        print('Chan num out of range')
    else:
        ch_row = int(ch_idx/32)
        ch_col = int(ch_idx - ch_row*32)
    return (ch_row,ch_col)

 TO DOS:
#OPTION TO PLOT CHANNELS WITH MOST SPIKES FIRST
#GENERATE LIST OF WHAT FIG NUM EACH CHANNEL IS ON (spikeCntIdx)
#PARALLELIZE CHANNEL FILTERING AND GAUSSIANMIXTURE CODE (OR FIND FASTER FILTER)
#MAKE X AXIS LOOK GOOD ON ANY DISPLAY (NO SQUISHING NUMBERS TOGETHER)
#DONT HIGHLIGHT PHYSICALLY IMPOSSIBLE EVENTS AS SPIKES (would prefer to make this as a distribution measure)

## STEP 2: INITIALIZE VARIABLES FOR LOADING AND CONVERTING DATA

In [2]:
# For current variable names
path = '/Users/vision/Desktop/Xilinx'
date_piece = '2022-01-14-1'
datarun = 'data001'
# date_piece = '2022-01-14-0'
# datarun = 'data009'
buffer_num = 0

string = 'gmem1' #variable of the data within the mat files

## STEP 3: INITIALIZE BUTTON FUNCTIONS

In [3]:
def callback_next_btn1(event):
    fig = plt.gcf()
    figNum = plt.gcf().number
    print('Reset button pressed')
    for m in range(0,plotPerFig):
        axs = fig.get_axes()
        ax = axs[m]
        ax.set_xlim([times[int(startIdx[figNum*plotPerFig+m])],times[int(startIdx[figNum*plotPerFig+m])]+int(timeWin)])
        # Set xlim to the time window of the first recorded values
        if manSety == True:
            plt.ylim([ymi,yma])
        ax = plt.gca()
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
        
        plt.draw()
        
def callback_next_btn2(event):
    fig = plt.gcf()
    print('Next button pressed')
    for m in range(0,plotPerFig):
        axs = fig.get_axes()
        ax = axs[m]
        [xLeft, xRight] = ax.get_xlim()
        ax.set_xlim([xLeft+timeWin,xRight+timeWin])
        # Set xlim to the time window of the first recorded values
        if manSety == True:
            plt.ylim([ymi,yma])
        ax = plt.gca()
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
        
        plt.draw()
        
def callback_next_btn3(event):
    fig = plt.gcf()
    print('Back button pressed')
    for m in range(0,plotPerFig):
        axs = fig.get_axes()
        ax = axs[m]
        [xLeft, xRight] = ax.get_xlim()
        ax.set_xlim([xLeft-timeWin,xRight-timeWin])
        # Set xlim to the time window of the first recorded values
        if manSety == True:
            plt.ylim([ymi,yma])
        ax = plt.gca()
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
        
        plt.draw()
        
def callback_next_btn4(event):
    fig = plt.gcf()
    print('Time Window button pressed')
    global timeWin
    timeStr = simpledialog.askstring(title='Time Change',prompt='Enter desired time window (ms):')
    timeWin = int(timeStr)

    for m in range(0,plotPerFig):
        axs = fig.get_axes()
        ax = axs[m]
        [xLeft, xRight] = ax.get_xlim()
        ax.set_xlim([xLeft,xLeft+timeWin])
        # Set xlim to the time window of the first recorded values
        plt.axes(ax)
        if manSety == True:
            plt.ylim([ymi,yma])
             
        ax = plt.gca()
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
        
        plt.draw()
        
def callback_next_btn5(event):
    fig_cur = plt.gcf()
    print('Next Fig Pressed')
    global gk
    figStr = simpledialog.askstring(title='Plot Next Figure',prompt='Enter Figure # To Display: 0 thru ' + str(numFigs-1))
    gk = int(figStr)
    
    closeStr = simpledialog.askstring(title='Close Current Figure?',prompt='Enter 1 to close current figure and 0 to remain open')
    closeYN = int(closeStr)
    if closeYN == 1:
        plt.close(fig_cur)
    elif closeYN != 0 and closeYN != 1:
        closeYN = 0
        print('Only enter 1 or 0 to close or keep open current figure')
    
    if gk < 0 or gk > numFigs:
        gk = 0
        print('Requested Figure Out of Range')
        
    # Generate figure and define style
    fig = plt.figure(gk,figsize=(18,12),facecolor='white',constrained_layout=False)
    plt.style.use('fivethirtyeight')

    # Figure Title
    title = 'Figure ' + str(gk) + ':'
    startElec = str(int(chId[gk*plotPerFig]))

    # Generate subplots
    for m in range(1,plotPerFig+1):
        plt.subplot(numRows,numCols,m)
        # Check if this range has a potential spike
        plt.plot(times,dataFilt[chMap[0,gk*plotPerFig+m-1],chMap[1,gk*plotPerFig+m-1],:],linewidth=1,color='k')

        # Set xlim to the time window of the first recorded values
        plt.xlim([times[int(startIdx[gk*plotPerFig+m-1])],times[int(startIdx[gk*plotPerFig+m-1])]+int(timeWin)])

        if manSety == True:
            plt.ylim([ymi,yma])
        ax = plt.gca()
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
        text_elec = 'Ch # ' + str(int(chId[gk*plotPerFig+m-1]))
        ax.annotate(text_elec,(0.6,0.9),xycoords='axes fraction',va='center',fontsize = 11,color='blue')

        # Check for spikes & Plot Spans for All:
        spanStarts = spanIdx[gk*plotPerFig+m-1]
        if np.shape(spanStarts) != (0,):
            for i in range(0,len(spanStarts)):
                if spanStarts[i] <= 5:
                    plt.axvspan(times[0], times[spanStarts[i]+15], ec = 'white', color='yellow',alpha=0.75)
                else:
                    plt.axvspan(times[spanStarts[i]-5], times[spanStarts[i]+15], ec = 'white', color='yellow',alpha=0.75)


        # Update so we plot all electrodes
        if gk*plotPerFig+m-1 >= numChan:
            break

    # Adjust the subplots to take the entire screen        
    plt.subplots_adjust(left=0.03,right=0.978,bottom=0.05,top=0.95,wspace=0.25,hspace=0.2)

    # Figure Title 
    stopElec = str(int(chId[gk*plotPerFig+m-1]))
    figTitle = title + ' Ch ' + startElec + ' to Ch ' + stopElec
    axs = fig.get_axes()
    ax = axs[-1]
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('ADC')
    ax.yaxis.set_label_position("right")
    ax = axs[0]
    ax.set_title(figTitle,fontsize=16,ha='center',va='bottom',fontweight='bold')

    plt.xlabel('Time (ms)')
    plt.ylabel('ADC',loc='bottom')

    # Add Buttons
    # Reset
    axButn1 = plt.axes(([0.7, 0.955, 0.05, 0.04]))
    btn1 = Button(axButn1, label="Reset",color='pink',hovercolor='red')
    butTrack[0,gk] = btn1
    butTrack[0,gk].on_clicked(callback_next_btn1)
    # Next
    axButn2 = plt.axes(([0.9, 0.955, 0.05, 0.04]))
    btn2 = Button(axButn2, label="Next",color='lightblue',hovercolor='blue')
    butTrack[1,gk] = btn2
    butTrack[1,gk].on_clicked(callback_next_btn2)
    # Back
    axButn3 = plt.axes(([0.85, 0.955, 0.05, 0.04]))
    btn3 = Button(axButn3, label="Back",color='lightblue',hovercolor='blue')
    butTrack[2,gk] = btn3
    butTrack[2,gk].on_clicked(callback_next_btn3)
    # Button for Time Window
    axButn4 = plt.axes(([0.5, 0.955, 0.1, 0.04]))
    btn4 = Button(axButn4, label="Alt Time Window", color = 'lightgreen',hovercolor='green')
    butTrack[3,gk] = btn4
    butTrack[3,gk].on_clicked(callback_next_btn4)
    # Button for Next Figure
    nextFig = 0
    axButn5 = plt.axes(([0.25,0.955,0.05,0.04]))
    btn5 = Button(axButn5, label="Next Fig",color='gold',hovercolor='goldenrod')
    butTrack[4,gk] = btn5
    butTrack[4,gk].on_clicked(callback_next_btn5)
    # Button for Scale Adjust
    axButn6 = plt.axes(([0.45, 0.955, 0.05, 0.04]))
    btn6 = Button(axButn6, label="Y Scale", color = 'lightgreen',hovercolor='green')
    butTrack[5,gk] = btn6 
    butTrack[5,gk].on_clicked(callback_next_btn6)

    # Flush the plot out
    plt.show()
    
def callback_next_btn6(event):
    fig = plt.gcf()
    print('Y Scale button pressed')
    global ymi, yma, manSety
    ymiStr = simpledialog.askstring(title='Y Scale Change',prompt='Enter desired ymin (ADC):')
    ymi = int(ymiStr)
    ymaStr = simpledialog.askstring(title='Y Scale Change',prompt='Enter desired ymax (ADC):')
    yma = int(ymaStr)

    for m in range(0,plotPerFig):
        axs = fig.get_axes()
        ax = axs[m]
        ax.set_ylim([ymi,yma])
        manSety = True
        plt.axes(ax)
             
        ax = plt.gca()
        ax.ticklabel_format(useOffset=False)
        ax.relim()
        ax.autoscale_view()
        
        plt.draw()

## STEP 4: LOAD AND CONVERT ALL THE DATA
Load from original files, compressed with a few spikes, or compressed with many spikes

In [4]:
# Load from original files
# Prep file for loading
fileDir = pjoin(date_piece,datarun)
bufDir = os.listdir(fileDir)
num_of_buf = len(bufDir)
bramdepth = 65536

# Initialize Variables
timeTrack = 0
cntTrack = 0
dataAll = np.zeros((32,32,int(bramdepth*num_of_buf/2))) #Largest possible value of dataAll, perfect recording, only double cnt
cntAll = np.zeros((32,32,int(bramdepth*num_of_buf/2)))
times = np.zeros((int(bramdepth*num_of_buf/2)))

#Idea: plot in order of which buffers have the most 'spikes'

# Process all the Data

for k in range(0,num_of_buf):
    
    print(k)
    
    # Load Data from this Loop's Buffer
    file_next = pjoin(date_piece,datarun,datarun+'_'+str(k)+'.mat')
    print(file_next)
    mat_contents = sio.loadmat(file_next)
    dataRaw = mat_contents[string][0][:]
    
    # Initialize Variables Needed for Each Buffer
    chan_index_pre = 1025 #Check for chan changes, double cnt
    cnt_pre = 0 #Check for cnt changes, double cnt
    N = 0 #Sample times (DOES NOT ALLOW NON-COLLISION FREE SAMPLES)
    data_real = np.zeros((32,32,len(dataRaw)-2))  #Initialize to max possible length. Note: Throw out first two values b/c garbo
    cnt_real = np.zeros((32,32,len(dataRaw)-2))
    
    # Convert data and remove double/triple counts
    for i in range(2,len(dataRaw)-1): 
            # Convert bit number into binary
            word = (np.binary_repr(dataRaw[i],32))

            # Break that binary into it's respective pieces and convert to bit number
            cnt = int(word[12:14],2)
            col = int(word[27:32],2)
            row = int(word[22:27],2)
            chan_index = row*32 + col

            # Only record the unique non-double count sample
            if(i==2 or (cnt_pre != cnt or chan_index != chan_index_pre)):

                # Sample time only changes when cnt changes
                if cnt != cnt_pre:
                    N += 1

                # On the occurance the first cnt is not 0, make sure sample time is 0
                if i == 2:
                    N = 0

                # Update variables
                cnt_pre = cnt
                chan_index_pre = chan_index

                # Record pertinent data
                data_real[row][col][N] = int(word[14:22],2)
                cnt_real[row][col][N] = cnt
    
    # Determine time estimate and sample counts for the total combined buffers
    if timeTrack == 0: #For the first buffer, we assume the first sample comes in at time 0
        dataAll[:,:,:N] = data_real[:,:,:N]
        cntAll[:,:,:N] = cnt_real[:,:,:N]
        endTime = N*0.05 # 20kHz sampling rate, means time_recording (ms) = num_sam*0.05ms
        new_times = np.linspace(0,endTime,N+1)
        times[0:len(new_times)] = new_times
    elif timeTrack != 0: # For buffers after the first, we place these values directly after the previous buffer (note this does not take into account communication delays - hence an estimate)
        dataAll[:,:,cntTrack:cntTrack+N] = data_real[:,:,:N]
        cntAll[:,:,cntTrack:cntTrack+N] = cnt_real[:,:,:N]
        endTime = N*0.05
        new_times = np.linspace(timeTrack,timeTrack+endTime,N)
        times[cntTrack:cntTrack+N] = new_times
    
    # Update for the next buffer file
    timeTrack += endTime
    cntTrack += N
    
# Truncate dataAll, cntAll, and times to remove the 0's representing lost data potential due to triple counts
dataAll = dataAll[:,:,:cntTrack-1]
cntAll = cntAll[:,:,:cntTrack-1]
times = times[:cntTrack-1]
        

0
2022-01-14-1/data001/data001_0.mat
1
2022-01-14-1/data001/data001_1.mat
2
2022-01-14-1/data001/data001_2.mat
3
2022-01-14-1/data001/data001_3.mat
4
2022-01-14-1/data001/data001_4.mat
5
2022-01-14-1/data001/data001_5.mat
6
2022-01-14-1/data001/data001_6.mat
7
2022-01-14-1/data001/data001_7.mat
8
2022-01-14-1/data001/data001_8.mat
9
2022-01-14-1/data001/data001_9.mat
10
2022-01-14-1/data001/data001_10.mat
11
2022-01-14-1/data001/data001_11.mat
12
2022-01-14-1/data001/data001_12.mat
13
2022-01-14-1/data001/data001_13.mat
14
2022-01-14-1/data001/data001_14.mat
15
2022-01-14-1/data001/data001_15.mat
16
2022-01-14-1/data001/data001_16.mat
17
2022-01-14-1/data001/data001_17.mat
18
2022-01-14-1/data001/data001_18.mat
19
2022-01-14-1/data001/data001_19.mat
20
2022-01-14-1/data001/data001_20.mat
21
2022-01-14-1/data001/data001_21.mat
22
2022-01-14-1/data001/data001_22.mat
23
2022-01-14-1/data001/data001_23.mat
24
2022-01-14-1/data001/data001_24.mat
25
2022-01-14-1/data001/data001_25.mat
26
202

In [51]:
# Data with Litke Spikes
dataAllLoad = np.load("compressed_dataAll_2022-01-14-1_data001_litkeSpikes.npz")
dataAll = dataAllLoad['modified_data']
dataAllSmall = dataAllLoad['small_spike_idx']
dataAllLarge = dataAllLoad['large_spike_idx']
times = np.linspace(0,np.shape(dataAll)[2]*0.05,np.shape(dataAll)[2])

In [4]:
# Data with Many Litke Spikes
dataAllLoad = np.load("compressed_dataAll_2022-01-14-1_data001_litkeSpikes_many.npz")
dataAll = dataAllLoad['modified_data']
dataAllLarge = dataAllLoad['large_spike_idx']
times = np.linspace(0,np.shape(dataAll)[2]*0.05,np.shape(dataAll)[2])

## STEP 5: FILTER THE DATA

In [5]:
# Choose Filter Type
filtType = 'bandpass' # filtType can currently be 'Hierlemann','highpass','bandpass', 'Litke', or none'
# Note: Hierlemann is relatively fast, highpass is a little slow, none is the fastest


In [6]:
# Filter the data
# Future update: only calculate for channels recorded not all
dataFilt = np.zeros((np.shape(dataAll)))
if filtType == 'Hierlemann':
    BP_LOW_CUTOFF = 100.0
    NUM_TAPS = 75
    TAPS = signal.firwin(NUM_TAPS,
                         [BP_LOW_CUTOFF, ],
                         pass_zero=False,
                         fs=20e3 * 1.0)
    a = 1
    for k in range(0,32):
        for m in range(0,32):
            dataFilt[k,m,:] = signal.filtfilt(TAPS,[a],dataAll[k,m,:])
            print(k,m)
            
elif filtType == 'highpass':
    nyq = 0.5 * (20e3*1.0)
    cutoff = 250/nyq
    b, a = signal.butter(5, [cutoff],btype="highpass",analog=False)
    for k in range(0,32):
        for m in range(0,32):
            dataFilt[k,m,:] = signal.filtfilt(b, a,dataAll[k,m,:])
            print(k,m)
    
elif filtType == 'bandpass':
    nyq = 0.5 * (20e3*1.0)
    cutoff1 = 250/nyq
    cutoff2 = 5000/nyq
    b, a = signal.butter(5, [cutoff1, cutoff2],btype="bandpass",analog=False)
    for k in range(0,32):
        for m in range(0,32):
            dataFilt[k,m,:] = signal.filtfilt(b, a,dataAll[k,m,:])
            print(k,m)
            
elif filtType == 'bandpass2':
    nyq = 0.5 * (20e3*1.0)
    cutoff1 = 250/nyq
    cutoff2 = 4000/nyq
    b, a = signal.butter(5, [cutoff1, cutoff2],btype="bandpass",analog=False)
    for k in range(0,32):
        for m in range(0,32):
            dataFilt[k,m,:] = signal.filtfilt(b, a,dataAll[k,m,:])
            print(k,m)
            
elif filtType == 'Litke':
    nyq = 0.5 * (20e3*1.0)
    cutoff1 = 250/nyq
    cutoff2 = 2000/nyq
    b, a = signal.butter(5, [cutoff1, cutoff2],btype="bandpass",analog=False)
    for k in range(0,32):
        for m in range(0,32):
            dataFilt[k,m,:] = signal.filtfilt(b, a,dataAll[k,m,:])
            print(k,m)

elif filtType == 'none':
    dataFilt = np.copy(dataAll)

else:
    dataFilt = np.copy(dataAll)
    print('Filter not recognized. Options include Hierlemann, highpass, bandpass, Litke or none')

0 0
0 1
0 2
0 3
0 4
0 5
0 6
0 7
0 8
0 9
0 10
0 11
0 12
0 13
0 14
0 15
0 16
0 17
0 18
0 19
0 20
0 21
0 22
0 23
0 24
0 25
0 26
0 27
0 28
0 29
0 30
0 31
1 0
1 1
1 2
1 3
1 4
1 5
1 6
1 7
1 8
1 9
1 10
1 11
1 12
1 13
1 14
1 15
1 16
1 17
1 18
1 19
1 20
1 21
1 22
1 23
1 24
1 25
1 26
1 27
1 28
1 29
1 30
1 31
2 0
2 1
2 2
2 3
2 4
2 5
2 6
2 7
2 8
2 9
2 10
2 11
2 12
2 13
2 14
2 15
2 16
2 17
2 18
2 19
2 20
2 21
2 22
2 23
2 24
2 25
2 26
2 27
2 28
2 29
2 30
2 31
3 0
3 1
3 2
3 3
3 4
3 5
3 6
3 7
3 8
3 9
3 10
3 11
3 12
3 13
3 14
3 15
3 16
3 17
3 18
3 19
3 20
3 21
3 22
3 23
3 24
3 25
3 26
3 27
3 28
3 29
3 30
3 31
4 0
4 1
4 2
4 3
4 4
4 5
4 6
4 7
4 8
4 9
4 10
4 11
4 12
4 13
4 14
4 15
4 16
4 17
4 18
4 19
4 20
4 21
4 22
4 23
4 24
4 25
4 26
4 27
4 28
4 29
4 30
4 31
5 0
5 1
5 2
5 3
5 4
5 5
5 6
5 7
5 8
5 9
5 10
5 11
5 12
5 13
5 14
5 15
5 16
5 17
5 18
5 19
5 20
5 21
5 22
5 23
5 24
5 25
5 26
5 27
5 28
5 29
5 30
5 31
6 0
6 1
6 2
6 3
6 4
6 5
6 6
6 7
6 8
6 9
6 10
6 11
6 12
6 13
6 14
6 15
6 16
6 17
6 18
6 19
6 20
6 21


## STEP 6: IDENTIFY RELEVANT CHANNELS 

In [7]:
# Find which channels were recorded
numSam = np.zeros((32,32))
numSam = np.count_nonzero(dataAll,axis=2)
numChan = np.count_nonzero(numSam)

# Map and Identify recorded channels
findCoors = np.nonzero(numSam)
chMap = np.array(findCoors)
chId = np.zeros((numChan))
startIdx = np.zeros((numChan))
for k in range(0,numChan):
    chId[k] = map2idx(chMap[0,k],chMap[1,k])
    startIdx[k] = (dataAll[chMap[0,k],chMap[1,k],:]!=0).argmax(axis=0)

## STEP 7: CUSTOMIZE PLOT SETTINGS

In [8]:
# Spike ID
timeWin = 20 #timeInterval you wish to view and find spikes in
spikeDev = 3.291 # 3.291 is 1/1000 False Positives
cutoff = -30

# Plotting
plotPerFig = 36 # How many plots per fig?
numRows = 6 # How many rows in your plot?
manSety = True # Do you want to manually set the ylim of the plots?
ymi = -20 #If yes, ylim bottom
yma = 20 #If yes, ylim top

## STEP 8: IDENTIFY POTENTIAL SPIKES

In [9]:
# Intialize variables for Gaussian Mixture
mean_hat = np.zeros((32,32,2))
std_hat = np.zeros((32,32,2))
w_hat = np.zeros((32,32,2))
noiseIdx = np.zeros((32,32))
spikeThres = np.zeros((32,32))
spans = np.zeros(np.shape(dataFilt))
spanIdx = [None]*len(chId)
cnt = 0

# Loop through each channel
for k in range(0,32):
    for m in range(0,32):
        # Find when recordings happened for specific channel
        samInd = np.asarray(np.where(abs(dataFilt[k,m,:])>0.00001))[0,:]
        # If channel was not recorded, go back to top
        if np.shape(samInd) == (0,):
            continue
        # Specify data only at time of recording for specific channel
        y = dataFilt[k,m,samInd[0]:samInd[-1]]
        gmSam = np.reshape(y,(len(y),1))
        # Split data into two normal distributions (in case spikes form a separate distribution, still works with very few spikes because second distribution will be made of outlier noise events)
        gm = GaussianMixture(n_components=2).fit(gmSam)
        mean_hat[k,m,:] = gm.means_.flatten()
        w_hat[k,m,:] = gm.weights_.flatten()
        std_hat[k,m,:] = np.sqrt(gm.covariances_).flatten()
        
        # Which of the distributions is the noise (centered at 0)
        noiseIdx[k,m] = np.argmin(np.abs(mean_hat[k,m,:]))
        # Find negative spike threshold
        spikeThres[k,m] = mean_hat[k,m,int(noiseIdx[k,m])]-spikeDev*std_hat[k,m,int(noiseIdx[k,m])]
        # Identify all spikes
        spanIdx[cnt] = np.where(dataFilt[k,m,:]<spikeThres[k,m])
        spanIdx[cnt] = np.asarray(spanIdx[cnt][0])     
        
        # Remove multiple samples indicated for a single spike
        for x in reversed(range(0,len(spanIdx[cnt]))):
            if spanIdx[cnt][x] - spanIdx[cnt][x-1] == 1:
                spanIdx[cnt] = np.delete(spanIdx[cnt],x)
        
        cnt += 1
    
# Report Channels with Most Spikes & Their Figure Numbers
spikeCnt = np.zeros((400))
for i in range(0,len(spanIdx)):
    spikeCnt[i] = len(spanIdx[i])
spikeCntIdx = np.argsort(spikeCnt)
spikeCntIdx = spikeCntIdx[::-1]
print(chId[spikeCntIdx])



[262. 600. 263. 260. 261. 462. 473. 273. 632. 557. 492. 643. 324. 313.
 327. 256. 296. 641. 514. 407. 328. 365. 541. 373. 561. 478. 486. 611.
 383. 306. 286. 491. 513. 589. 521. 582. 469. 276. 310. 515. 460. 608.
 526. 545. 640. 354. 446. 385. 346. 610. 390. 636. 574. 573. 506. 512.
 449. 411. 381. 380. 290. 562. 532. 297. 359. 556. 406. 493. 454. 269.
 426. 620. 432. 476. 467. 593. 453. 618. 468. 459. 622. 377. 441. 421.
 443. 536. 540. 361. 542. 371. 550. 353. 465. 317. 363. 330. 322. 305.
 504. 502. 499. 433. 345. 487. 461. 364. 270. 639. 281. 588. 280. 266.
 615. 577. 553. 288. 601. 564. 575. 307. 355. 613. 646. 497. 389. 633.
 583. 626. 348. 279. 599. 597. 630. 278. 265. 321. 474. 394. 344. 482.
 315. 418. 396. 302. 452. 325. 531. 559. 431. 413. 285. 340. 368. 299.
 627. 437. 332. 524. 339. 625. 337. 357. 520. 617. 423. 517. 516. 566.
 422. 379. 417. 403. 604. 430. 458. 444. 326. 405. 624. 342. 654. 434.
 607. 435. 410. 408. 268. 448. 409. 442. 602. 445. 304. 463. 552. 560.
 530. 

## STEP 9: PLOTTING ALL THOSE TRACES

In [10]:
# Make the correct number of figures
numFigs = int(numChan/plotPerFig)
numCols = int(plotPerFig/numRows)
if numCols*numRows < plotPerFig:
    numCols += 1
butTrack = {}

# Prep for Plotting
plt.close('all')
elecStep = 0
plt.ioff()
    
gk = 0
# Generate figure and define style
fig = plt.figure(gk,figsize=(18,12),facecolor='white',constrained_layout=False)
plt.style.use('fivethirtyeight')

# Figure Title
title = 'Figure ' + str(gk) + ':'
startElec = str(int(chId[gk*plotPerFig]))

# Find All Potential Spikes for All Datasets & Record Timepoints

# Generate subplots
for m in range(1,plotPerFig+1):
    plt.subplot(numRows,numCols,m)
    
    # Plot time & amplitude information
    plt.plot(times,dataFilt[chMap[0,gk*plotPerFig+m-1],chMap[1,gk*plotPerFig+m-1],:],linewidth=1,color='k')
    plt.xlim([times[int(startIdx[gk*plotPerFig+m-1])],times[int(startIdx[gk*plotPerFig+m-1])]+int(timeWin)])
    if manSety == True:
        plt.ylim([ymi,yma])
    ax = plt.gca()
    ax.ticklabel_format(useOffset=False)
    ax.relim()
    ax.autoscale_view()
    text_elec = 'Ch # ' + str(int(chId[gk*plotPerFig+m-1]))
    ax.annotate(text_elec,(0.6,0.9),xycoords='axes fraction',va='center',fontsize = 11,color='blue')
    
    # Check for spikes & Plot Spans for All:
    spanStarts = spanIdx[gk*plotPerFig+m-1]
    if np.shape(spanStarts) != (0,):
        for i in range(0,len(spanStarts)):
            if spanStarts[i] <= 5:
                plt.axvspan(times[0], times[spanStarts[i]+15], ec = 'white', color='yellow',alpha=0.75)
            else:
                plt.axvspan(times[spanStarts[i]-5], times[spanStarts[i]+15], ec = 'white', color='yellow',alpha=0.75)

    # Update so we plot all electrodes
    elecStep += 1
    if gk*plotPerFig+m-1 >= numChan:
        break

# Adjust the subplots to take the entire screen        
plt.subplots_adjust(left=0.03,right=0.978,bottom=0.05,top=0.95,wspace=0.25,hspace=0.2)

# Figure Title 
stopElec = str(int(chId[gk*plotPerFig+m-1]))
figTitle = title + ' Ch ' + startElec + ' to Ch ' + stopElec
axs = fig.get_axes()
ax = axs[-1]
ax.set_xlabel('Time (ms)')
ax.set_ylabel('ADC')
ax.yaxis.set_label_position("right")
ax = axs[0]
ax.set_title(figTitle,fontsize=16,ha='center',va='bottom',fontweight='bold')

# Add Buttons
# Reset
axButn1 = plt.axes(([0.7, 0.955, 0.05, 0.04]))
btn1 = Button(axButn1, label="Reset",color='pink',hovercolor='red')
butTrack[0,gk] = btn1
butTrack[0,gk].on_clicked(callback_next_btn1)
# Next
axButn2 = plt.axes(([0.9, 0.955, 0.05, 0.04]))
btn2 = Button(axButn2, label="Next",color='lightblue',hovercolor='blue')
butTrack[1,gk] = btn2
butTrack[1,gk].on_clicked(callback_next_btn2)
# Back
axButn3 = plt.axes(([0.85, 0.955, 0.05, 0.04]))
btn3 = Button(axButn3, label="Back",color='lightblue',hovercolor='blue')
butTrack[2,gk] = btn3
butTrack[2,gk].on_clicked(callback_next_btn3)
# Button for Time Window
axButn4 = plt.axes(([0.5, 0.955, 0.1, 0.04]))
btn4 = Button(axButn4, label="Time Window", color = 'lightgreen',hovercolor='green')
butTrack[3,gk] = btn4
butTrack[3,gk].on_clicked(callback_next_btn4)
# Button for Next Figure
nextFig = 0
axButn5 = plt.axes(([0.25,0.955,0.05,0.04]))
btn5 = Button(axButn5, label="Next Fig",color='gold',hovercolor='goldenrod')
butTrack[4,gk] = btn5
butTrack[4,gk].on_clicked(callback_next_btn5)
# Button for Scale Adjust
axButn6 = plt.axes(([0.45, 0.955, 0.05, 0.04]))
btn6 = Button(axButn6, label="Y Scale", color = 'lightgreen',hovercolor='green')
butTrack[5,gk] = btn6 
butTrack[5,gk].on_clicked(callback_next_btn6)

# Flush the plot out
plt.show()

Next Fig Pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Back button pressed
Next Fig Pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed
Next button pressed


## OPTIONAL: Frequency Investigation FFT of Raw or Filtered Data

In [None]:
plt.close('all')    
from scipy.fft import fft, fftfreq

fig = plt.figure(0,figsize=(18,12),facecolor='white',constrained_layout=False)
plt.style.use('fivethirtyeight')
start = 0

for k in range(0,10):
    while start == 0:
        kstart = k
        start +=1
    plt.subplot(5,2,k+(kstart*-1+1))
    samInd = np.asarray(np.where(abs(dataFilt[chMap[0,k],chMap[1,k],:])>0.00001))[0,:]
    y = dataFilt[chMap[0,k],chMap[1,k],samInd[0]:samInd[-1]]
#     samInd = np.asarray(np.where(abs(dataAll[chMap[0,k],chMap[1,k],:])>0.00001))[0,:]
#     y = dataAll[chMap[0,k],chMap[1,k],samInd[0]:samInd[-1]]
    
    Nsam = len(y)
    dt = 1/(20e3)
    x = np.linspace(0.0,Nsam*dt,Nsam,endpoint=False)
    
    yf = fft(y)
    xf = fftfreq(Nsam,dt)[:Nsam//2]
    
    plt.plot(xf,2.0/Nsam*np.abs(yf[0:Nsam//2]),linewidth=1,color='r')
    plt.grid()
    
    ax = plt.gca()
    text_elec = 'Ch # ' + str(int(map2idx(chMap[0,k],chMap[1,k])))
    ax.annotate(text_elec,(0.6,0.9),xycoords='axes fraction',va='center',fontsize = 11,color='blue')


    plt.xlabel('Frequency (Hz)')
    plt.xlim([0,6000])
    plt.ylim([0,1])
    
plt.suptitle('After Filter with Litke Spikes')  
plt.show()
        

## OPTIONAL: Amplitude Investigation

In [59]:
# Plot histogram of amplitudes from the timepoints where recording took place
    
import scipy.stats as stats
plt.close(1)    
from sklearn.mixture import GaussianMixture

fig = plt.figure(1,figsize=(18,12),facecolor='white',constrained_layout=False)
plt.style.use('fivethirtyeight')
start = 0

for k in range(10,20):
    # Call the right subplot
    while start == 0:
        kstart = k
        start +=1
    plt.subplot(5,2,k+(kstart*-1+1))
    
    # Filtered Data: Find the timepoints
    samInd = np.asarray(np.where(abs(dataFilt[chMap[0,k],chMap[1,k],:])>0.00001))[0,:]  
    y = dataFilt[chMap[0,k],chMap[1,k],samInd[0]:samInd[-1]]
    n, bins, patches = plt.hist(x=y, bins='auto', density=True, alpha=0.7,rwidth=0.85)

    gmSam = np.reshape(y,(len(y),1))
    gm = GaussianMixture(n_components=2).fit(gmSam)
    means_hat = gm.means_.flatten()
    weights_hat = gm.weights_.flatten()
    std_hat = np.sqrt(gm.covariances_).flatten()
    
    ystd = np.std(y)
    plt.grid(axis='y',alpha=0.75)
    plt.xlabel('Amplitude')
    plt.ylabel('Frequency')
    plt.xlim([-20,20])
    plt.ylim([0,0.2])
    
    xmin, xmax = plt.xlim()
    x = np.linspace(xmin,xmax,1000)
    mu1_h, sd1_h = means_hat[0], std_hat[0]
    p1 = stats.norm.pdf(x,mu1_h,sd1_h)
    plt.plot(x, p1, 'k', linewidth=2)
    mu2_h, sd2_h = means_hat[1], std_hat[1]
    p2 = stats.norm.pdf(x,mu2_h,sd2_h)
    plt.plot(x, p2, 'r', linewidth=2)
    
    # Find which components has mean closer to 0
    noiseIdx = np.argmin(np.abs(means_hat))
    # How many std? 1/1000 FP means 99.9% of values fall in noise distribution. Z-score is 3.291
    print(3.291*std_hat[noiseIdx])
    print(means_hat[noiseIdx]-3.291*std_hat[noiseIdx])
    
    
    ax = plt.gca()
    text_elec = 'Ch # ' + str(int(map2idx(chMap[0,k],chMap[1,k])))
    ax.annotate(text_elec,(0.6,0.9),xycoords='axes fraction',va='center',fontsize = 11,color='blue')
    
plt.suptitle('After Filter with Litke Spikes')  
plt.show()

8.539353190310349
-8.506555675174807
7.683100822831957
-7.667936325146803
7.910522736809725
-7.902722795615734
6.409887877694772
-6.394264080150428
6.438242258990651
-6.426424339136473
5.9197326920868445
-5.914400598786331
5.86089079152754
-5.851386098150465
6.542342821516092
-6.523837402289213
6.717159427237503
-6.712662596451062
6.044899855404028
-6.033247361998068


## Deprecated: Further Frequency Investigation

In [None]:
# Further Frequency Investigation comparing raw data and litke data 
# yf1 = raw data yf2 = litke spikes
# Note: I took the difference of FFTs but since litke and raw have different number of points, its approximated. Could make exact by making 0's in the raw data where the litke spikes are
# yf2 = []
# xf2 = []

# for k in range(20,30):

#     samInd = np.asarray(np.where(abs(dataAll2[chMap[0,k],chMap[1,k],:])>0.00001))[0,:]
#     y = dataAll2[chMap[0,k],chMap[1,k],samInd[0]:samInd[-1]]
    
#     Nsam = len(y)
#     dt = 1/(20e3)
#     x = np.linspace(0.0,Nsam*dt,Nsam,endpoint=False)
    
#     yf = fft(y)
#     yf2.append(yf)
#     xf = fftfreq(Nsam,dt)[:Nsam//2]
#     xf2.append(xf)
    
# fig = plt.figure(3,figsize=(18,12),facecolor='white',constrained_layout=False)
# plt.style.use('fivethirtyeight')
# start = 0

# for k in range(0,10):
#     while start == 0:
#         kstart = k
#         start +=1
#     plt.subplot(5,2,k+1)
    
#     lenDif = np.asarray(np.shape(yf2[k]))-np.asarray(np.shape(yf1[k]))
#     print(lenDif)
#     yf1_len = np.append(yf1[k],np.zeros((lenDif)))
    
#     yf = abs(yf2[k][:]-yf1_len[:])
#     Nsam = len(yf)

#     plt.plot(xf2[k][:],2.0/Nsam*np.abs(yf[0:Nsam//2]),linewidth=1,color='r')
             
             
#     plt.grid()
    
#     ax = plt.gca()
#     text_elec = 'Ch # ' + str(int(map2idx(chMap[0,k+20],chMap[1,k+20])))
#     ax.annotate(text_elec,(0.6,0.9),xycoords='axes fraction',va='center',fontsize = 11,color='blue')


#     plt.xlabel('Frequency (Hz)')
#     plt.xlim([0,10000])
#     plt.ylim([0,1])
    
# plt.suptitle('Approximate Difference of FFTs')  
# plt.show()