In [1]:
import os
import sys
import json
import time
import pickle
import gspread
import numpy as np
import pandas as pd
from datetime import date
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.gridspec import GridSpec
from matplotlib.colorbar import Colorbar
from matplotlib.patches import Patch
from scipy import integrate, signal, stats, fftpack

In [2]:
sys.path.append(r'C:\Users\lesliec\code')

In [3]:
from tbd_eeg.tbd_eeg.data_analysis.eegutils import EEGexp
from tbd_eeg.tbd_eeg.data_analysis.Utilities.utilities import (
    get_stim_events,
    get_evoked_traces,
    get_evoked_firing_rates,
    find_nearest_ind
)
from allensdk.brain_observatory.ecephys.lfp_subsampling.subsampling import remove_lfp_offset
from allensdk.core.mouse_connectivity_cache import MouseConnectivityCache
from PCIst.PCIst.pci_st import calc_PCIst, dimensionality_reduction

In [4]:
%matplotlib notebook

In [5]:
plt.rcParams.update({'font.size': 12})

Load CCF for identifying cortical areas

In [6]:
mcc = MouseConnectivityCache(resolution=10)
str_tree = mcc.get_structure_tree()

Load Zap_Zip-log_exp to get metadata for experiments

In [7]:
_gc = gspread.service_account() # need a key file to access the account
_sh = _gc.open('Zap_Zip-log_exp') # open the spreadsheet
_df = pd.DataFrame(_sh.sheet1.get()) # load the first worksheet
zzmetadata = _df.T.set_index(0).T # put it in a nicely formatted dataframe

Define areas of interest to plot population activity

In [8]:
areas_of_interest = {
    'MO': [
        'MOp1', 'MOp2/3', 'MOp5', 'MOp6a', 'MOp6b',
        'MOs1', 'MOs2/3', 'MOs5', 'MOs6a', 'MOs6b'
    ],
    'ACA': [
        'ACAd1', 'ACAd2/3', 'ACAd5', 'ACAd6a', 'ACAd6b',
        'ACAv1', 'ACAv2/3', 'ACAv5', 'ACAv6a', 'ACAv6b'
    ],
    'SS': [
        'SSp-bfd1', 'SSp-bfd2/3', 'SSp-bfd4', 'SSp-bfd5', 'SSp-bfd6a', 'SSp-bfd6b',
        'SSp-ll1', 'SSp-ll2/3', 'SSp-ll4', 'SSp-ll5', 'SSp-ll6a', 'SSp-ll6b',
        'SSp-tr1', 'SSp-tr2/3', 'SSp-tr4', 'SSp-tr5', 'SSp-tr6a', 'SSp-tr6b'
    ],
    'VIS': [
        'VISp1', 'VISp2/3', 'VISp4', 'VISp5', 'VISp6a', 'VISp6b',
        'VISam1', 'VISam2/3', 'VISam4', 'VISam5', 'VISam6a', 'VISam6b',
        'VISpm1', 'VISpm2/3', 'VISpm4', 'VISpm5', 'VISpm6a', 'VISpm6b',
        'VISrl1', 'VISrl2/3', 'VISrl4', 'VISrl5', 'VISrl6a', 'VISrl6b',
    ],
    'MO-TH': [
        'AV', 'CL', 'MD', 'PO', 'RT', 'VAL', 'VPL', 'VPM', 'VM'
    ],
}

In [9]:
area_colors = {
    'MO': (31/255, 157/255, 90/255), # [31, 157, 90] 'blue'
    'ACA': (64/255, 166/255, 102/255), # [64, 166, 102] 'deepskyblue'
    'SS': (24/255, 128/255, 100/255), # [24, 128, 100] 'blueviolet'
    'VIS': (8/255, 133/255, 140/255), # [8, 133, 140] 'green'
    'MO-TH': (255/255, 112/255, 128/255), # [255, 112, 128] 'steelblue'
#     'VIS-TH': 'olivedrab'
}

In [10]:
# state_colors = {
#     'awake': (120/255, 156/255, 74/255),
#     'anesthetized': (130/255, 122/255, 163/255),
#     'recovery': (93/255, 167/255, 229/255)
# }
state_colors = {
    'resting': 'royalblue',
    'running': 'seagreen',
    'anesthetized': 'indianred',
}

#### Functions

In [11]:
def get_stim_event_inds(stim_table, stim_type, stim_param, sweep, trials='resting'):
    if trials == 'resting':
        return stim_table[
            (stim_table['stim_type'] == stim_type) &
            (stim_table['parameter'] == stim_param) &
            (stim_table['sweep'] == sweep) &
            (stim_table['good'] == True) &
            (stim_table['resting_trial'] == True)
        ].index.values
    elif trials == 'running':
        return stim_table[
            (stim_table['stim_type'] == stim_type) &
            (stim_table['parameter'] == stim_param) &
            (stim_table['sweep'] == sweep) &
            (stim_table['good'] == True) &
            (stim_table['resting_trial'] == False)
        ].index.values
    else:
        return stim_table[
            (stim_table['stim_type'] == stim_type) &
            (stim_table['parameter'] == stim_param) &
            (stim_table['sweep'] == sweep) &
            (stim_table['good'] == True)
        ].index.values

In [12]:
def get_stim_event_times(stim_table, stim_type, stim_param, sweep, trials='resting'):
    if trials == 'resting':
        return stim_table[
            (stim_table['stim_type'] == stim_type) &
            (stim_table['parameter'] == stim_param) &
            (stim_table['sweep'] == sweep) &
            (stim_table['good'] == True) &
            (stim_table['resting_trial'] == True)
        ].onset.values
    elif trials == 'running':
        return stim_table[
            (stim_table['stim_type'] == stim_type) &
            (stim_table['parameter'] == stim_param) &
            (stim_table['sweep'] == sweep) &
            (stim_table['good'] == True) &
            (stim_table['resting_trial'] == False)
        ].onset.values
    else:
        return stim_table[
            (stim_table['stim_type'] == stim_type) &
            (stim_table['parameter'] == stim_param) &
            (stim_table['sweep'] == sweep) &
            (stim_table['good'] == True)
        ].onset.values

In [13]:
def get_zscore_fr(trig_FR, time_bins):
    
    trig_Z = np.zeros_like(trig_FR) * np.nan # try NaNs, it works
#     trig_Z = np.zeros_like(trig_FR) # try zeros
    baseline_bins = time_bins < 0
    baseline_avg = np.mean(trig_FR[:,baseline_bins[:-1]], axis=1)
    baseline_std = np.std(trig_FR[:,baseline_bins[:-1]], axis=1)
    nonzero_inds = np.nonzero(baseline_std)[0]
    trig_Z[nonzero_inds,:] = (trig_FR[nonzero_inds,:] - baseline_avg[nonzero_inds,None]) / baseline_std[nonzero_inds,None]
    
    return trig_Z, nonzero_inds

## Load subject

In [14]:
recfolder = r'F:\psi_exp\mouse669117\pilot_aw_2023-03-29_11-09-15\experiment1\recording1'
# exp = EEGexp(recfolder, preprocess=True, make_stim_csv=True)
exp = EEGexp(recfolder, preprocess=False, make_stim_csv=False)

Experiment type: electrical stimulation


In [15]:
plotsdir = r'C:\Users\lesliec\OneDrive - Allen Institute\data\plots\manuscript_figs\burst_analyses'

## Test with one subject

In [16]:
evoked_data_folder = os.path.join(exp.data_folder, 'evoked_data')

## Load unit info ##
fn_units_info = os.path.join(evoked_data_folder, 'all_units_info.csv')
if os.path.exists(fn_units_info):
    unit_info = pd.read_csv(fn_units_info)
    with open(os.path.join(evoked_data_folder, 'units_allspikes.pkl'), 'rb') as unit_file:
        all_unit_spikes = pickle.load(unit_file)
else:
    print('  {} not found. Not analyzing this subject\n.'.format(fn_units_stats))

In [17]:
regdf = unit_info[unit_info['parent_region'] == 'MO']
regdf.head()

Unnamed: 0,unit_id,probe,peak_ch,depth,spike_duration,region,CCF_AP,CCF_DV,CCF_ML,parent_region
1368,F901,probeF,264,1060,0.755444,MOs6a,129,104,190,MO
1369,F900,probeF,264,1060,0.631826,MOs6a,129,104,190,MO
1370,F532,probeF,264,1060,0.59062,MOs6a,129,104,190,MO
1371,F535,probeF,265,1060,0.247236,MOs6a,129,103,190,MO
1372,F534,probeF,265,1060,0.714238,MOs6a,129,103,190,MO


In [18]:
type(all_unit_spikes['F900'])

dict

In [23]:
len(all_unit_spikes['F535']['spikes'])

282068

### Method 1: loop through all spikes, test pre- and post- ISI

Test with one spike train

In [31]:
s1 = time.time()
alluspikes = all_unit_spikes['C24']['spikes']

burst_starts_m1 = []
burst_all_m1 = []
IN_BURST = False
for spi in range(1, len(alluspikes)-1):
    preISI = alluspikes[spi] - alluspikes[spi-1]
    postISI = alluspikes[spi+1] - alluspikes[spi]
    if (preISI > 0.1) & (postISI < 0.004):
        burst_starts_m1.append(spi)
        burst_all_m1.append(spi)
        IN_BURST = True
    elif IN_BURST & (preISI < 0.004):
        burst_all_m1.append(spi)
    else:
        IN_BURST = False
# print(len(burst_starts_m1))
e1 = time.time()
print((e1-s1))

0.08395862579345703


In [42]:
print(len(burst_starts_m1))

170


### Method 2: find starts using arrays, loop through starts

Test with one spike train

In [48]:
s2 = time.time()
## Burst-finding method 2: using arrays and a loop ##
alluspikes = all_unit_spikes['C28']['spikes']

preISIs = np.diff(alluspikes)[:-1]
postISIs = np.diff(alluspikes)[1:]

bs_inds = np.nonzero((preISIs > 0.1) * (postISIs < 0.004))[0]
burst_starts_m2 = bs_inds + 1 # +1 corrects for the actual spike ind

## Loop through burst starts to find spikes that belong to the burst
allbinds = bs_inds.tolist()
for st_ind in bs_inds:
    spkind = st_ind+1
    while preISIs[spkind] < 0.004:
        allbinds.append(spkind)
        spkind += 1
burst_all_m2 = np.sort(allbinds) + 1 # +1 corrects for the actual spike ind
e2 = time.time()
print((e2-s2))

0.001997709274291992


### Method 3: find bursts using ISI threshold only (for bursting in ctx, too)

Test with one spike train

In [34]:
ISI_threshold = 0.015 # ISI less than or equal to 15 ms
spike_count_thresh = 3 # at least this number of spikes to be considered burst

s3 = time.time()
## Burst-finding method 3: using an ISI threshold (pre spike ISI) only ##
alluspikes = all_unit_spikes['F535']['spikes']
preISIs = np.insert(np.diff(alluspikes), 0, 1.0)

burst_list = []
spkind = 0
while spkind < len(alluspikes):
    tempevent = [alluspikes[spkind]]
    spkind += 1
    while (spkind < len(alluspikes)) and (preISIs[spkind] < ISI_threshold):
        tempevent.append(alluspikes[spkind])
        spkind += 1
    if len(tempevent) >= spike_count_thresh:
        burst_list.append(tempevent)
    del tempevent
    
e3 = time.time()
print((e3-s3))

0.26186084747314453


In [29]:
print(len(alluspikes))
print(len(preISIs))

282068
282068


In [30]:
preISIs[:10]

array([1.        , 0.04756656, 0.05923321, 0.06933318, 0.05273322,
       0.02559994, 0.02406661, 0.00813332, 0.01476663, 0.01316664])

In [35]:
print(len(burst_list))

15289


In [36]:
burst_list[0:10]

[[74.95432721311225,
  74.96246052893498,
  74.97722716380986,
  74.99039380212945,
  74.9996604488455,
  75.01126042387136],
 [82.89291012175966,
  82.90121010389024,
  82.90451009678551,
  82.91241007977726,
  82.92614338354352,
  82.93774335856938],
 [82.95987664425091, 82.97444327955637, 82.98460992433479],
 [95.5190496050103, 95.52868291760362, 95.53768289822713],
 [95.55534952685846, 95.56944949650197, 95.584049465069],
 [104.16029766752314, 104.17176430950273, 104.18009762489487],
 [108.52722159909831,
  108.53905490695516,
  108.55155488004337,
  108.56538818359434,
  108.57262150135472],
 [108.62792138229696,
  108.63618803116597,
  108.64765467314555,
  108.66198797562004],
 [110.35001767471398,
  110.36385097826494,
  110.37531762024452,
  110.38895092422607,
  110.4035508927931,
  110.41725086329778],
 [112.20958033783978, 112.22211364418956, 112.23238028875267]]

Compare results from both methods

In [22]:
print(burst_starts_m1[0:15])
print(burst_starts_m2[0:15])

[582, 944, 959, 976, 995, 997, 1006, 1023, 1063, 1065, 1082, 1157, 1234, 1252, 2085]
[ 582  944  959  976  995  997 1006 1023 1063 1065 1082 1157 1234 1252
 2085]


In [23]:
print(len(burst_all_m1))
print(len(burst_all_m2))

3015
3015


In [24]:
print(burst_all_m1[0:20])
print(burst_all_m2[0:20])

[582, 583, 944, 945, 959, 960, 976, 977, 995, 996, 997, 998, 999, 1000, 1006, 1007, 1008, 1023, 1024, 1063]
[ 582  583  944  945  959  960  976  977  995  996  997  998  999 1000
 1006 1007 1008 1023 1024 1063]


### Method RT bursts: find bursts with ISI < 4 ms (4-8 spikes)

Test with one spike train

In [54]:
s2 = time.time()
## Burst-finding method 2: using arrays and a loop ##
alluspikes = all_unit_spikes['C28']['spikes']

preISIs = np.diff(alluspikes)[:-1]
postISIs = np.diff(alluspikes)[1:]

bs_inds = np.nonzero((preISIs > 0.004) * (postISIs < 0.004))[0]
RT_burst_starts = bs_inds + 1 # +1 corrects for the actual spike ind

all_bursts = []
allbinds = bs_inds.tolist()
for st_ind in bs_inds:
    spkind = st_ind+1
    while preISIs[spkind] < 0.004:
        allbinds.append(spkind)
        spkind += 1


IN_BURST = True
for st_ind in bs_inds:
    spcount = 1
    spkind = st_ind+1
    while (preISIs[spkind] < 0.004) and (postISIs[spkind] < 0.004):
        spcount = 1
        spkind += 1


# ## Loop through burst starts to find spikes that belong to the burst
# allbinds = bs_inds.tolist()
# for st_ind in bs_inds:
#     spkind = st_ind+1
#     while preISIs[spkind] < 0.004:
#         allbinds.append(spkind)
#         spkind += 1
# burst_all_m2 = np.sort(allbinds) + 1 # +1 corrects for the actual spike ind
# e2 = time.time()
# print((e2-s2))

In [56]:
np.diff(alluspikes[0:10])

array([0.01469997, 0.11766645, 0.01083331, 0.00283333, 0.00366666,
       0.01173331, 0.01046665, 0.01056665, 0.00446666])

In [55]:
print(RT_burst_starts)

[    3   295   403   470   480   503   622   647   686   711   733   765
   777   821   849   874   916   920   948   976   989  1013  1173  1227
  1263  1314  1319  1389  1506  1562  1570  1677  1827  1974  1977  2001
  2009  2011  2051  2100  2102  2104  2106  2126  2141  2210  2269  2308
  2318  2376  2459  2474  2477  2521  2532  2534  2617  2702  2760  2807
  2881  2912  2928  2957  2962  3005  3036  3087  3193  3195  3337  3358
  3408  3452  3480  3570  3585  3658  3660  3671  3681  3697  3710  3764
  3802  3836  3879  4082  4174  4176  4189  4267  4271  4379  4401  4462
  4502  4530  4546  4607  4622  4625  4733  4747  4863  4866  4907  4939
  4953  5015  5032  5052  5065  5089  5107  5126  5156  5174  5181  5212
  5239  5250  5305  5361  5374  5382  5394  5408  5425  5430  5447  5631
  5682  5692  5699  5715  5737  5746  5834  5842  5982  5997  6130  6132
  6157  6257  6305  6333  6385  6400  6418  6420  6423  6427  6485  6535
  6566  6579  6589  6596  6608  6622  6685  6702  6

### Compare timing of both methods for multiple neurons

Method 1: loop through all spikes

In [25]:
s1 = time.time()

reg_units = regdf['unit_id'].values
for uid in reg_units:
    alluspikes = all_unit_spikes[uid]

    burst_starts_m1 = []
    burst_all_m1 = []
    IN_BURST = False
    for spi in range(1, len(alluspikes)-1):
        preISI = alluspikes[spi] - alluspikes[spi-1]
        postISI = alluspikes[spi+1] - alluspikes[spi]
        if (preISI > 0.1) & (postISI < 0.004):
            burst_starts_m1.append(spi)
            burst_all_m1.append(spi)
            IN_BURST = True
        elif IN_BURST & (preISI < 0.004):
            burst_all_m1.append(spi)
        else:
            IN_BURST = False
    
e1 = time.time()
print('Method 1, all units, search time: {:.2f} s'.format(e1-s1))

Method 1, all units, search time: 14.13 s


Method 2: use arrays first

In [26]:
s2 = time.time()

reg_units = regdf['unit_id'].values
for uid in reg_units:
    alluspikes = all_unit_spikes[uid]
    preISIs = np.diff(alluspikes)[:-1]
    postISIs = np.diff(alluspikes)[1:]

    bs_inds = np.nonzero((preISIs > 0.1) * (postISIs < 0.004))[0]
    burst_starts_m2 = bs_inds + 1 # +1 corrects for the actual spike ind

    ## Loop through burst starts to find spikes that belong to the burst
    allbinds = bs_inds.tolist()
    for st_ind in bs_inds:
        spkind = st_ind+1
        while (spkind < len(preISIs)) and (preISIs[spkind] < 0.004):
            allbinds.append(spkind)
            spkind += 1
    burst_all_m2 = np.sort(allbinds) + 1 # +1 corrects for the actual spike ind

e2 = time.time()
print('Method 2, all units, search time: {:.2f} s'.format(e2-s2))

Method 2, all units, search time: 0.24 s


Compare results from both methods

In [27]:
print(len(burst_starts_m1))
print(len(burst_starts_m2))

470
470


In [28]:
print(burst_starts_m1[0:15])
print(burst_starts_m2[0:15])

[2075, 5226, 10217, 12848, 14794, 15327, 15486, 16158, 17323, 17327, 17365, 17937, 18156, 18605, 20098]
[ 2075  5226 10217 12848 14794 15327 15486 16158 17323 17327 17365 17937
 18156 18605 20098]


In [29]:
print(len(burst_all_m1))
print(len(burst_all_m2))

1048
1048


### Info to save?

This saves the times of the first spike of each burst ['start_times'], the times of all spikes in a burst ['all_times'], the number of spikes belonging to each burst ['burst_spike_counts'], and a list of the spike times that belong to each burst ['burst_spike_times'].
<br>**But probably the most relevant are the ['start_times'] and ['burst_spike_counts'].**

In [43]:
start = time.time()

reg_units = regdf['unit_id'].values
burst_info = {}

for uid in reg_units:
    alluspikes = all_unit_spikes[uid]
    preISIs = np.diff(alluspikes)[:-1]
    postISIs = np.diff(alluspikes)[1:]

    bs_inds = np.nonzero((preISIs > 0.1) * (postISIs < 0.004))[0]
    if len(bs_inds) == 0:
        continue
    burst_starts_m2 = bs_inds + 1 # +1 corrects for the actual spike ind

    ## Loop through burst starts to find spikes that belong to the burst
    allbinds = bs_inds.tolist()
    burst_counts = []
    all_burst_times = []
    for st_ind in bs_inds:
        spkind = st_ind+1
        ind_burst_times = [alluspikes[spkind]]
        bcount = 1
        while (spkind < len(preISIs)) and (preISIs[spkind] < 0.004):
            allbinds.append(spkind)
            spkind += 1
            ind_burst_times.append(alluspikes[spkind])
            bcount += 1
        burst_counts.append(bcount)
        all_burst_times.append(np.array(ind_burst_times))
    burst_all_m2 = np.sort(allbinds) + 1 # +1 corrects for the actual spike ind
    
    ## store burst info ##
    burst_info[uid] = {}
    burst_info[uid]['start_times'] = alluspikes[burst_starts_m2]
    burst_info[uid]['all_times'] = alluspikes[burst_all_m2]
    burst_info[uid]['burst_spike_counts'] = np.array(burst_counts)
    burst_info[uid]['burst_spike_times'] = all_burst_times
    
end = time.time()
print('Method 2, all units, search time: {:.2f} s'.format(end-start))

Method 2, all units, search time: 0.51 s


In [44]:
len(burst_info['B0']['burst_spike_counts'])

1343

In [45]:
len(burst_info['B0']['burst_spike_times'])

1343

In [51]:
test_ind = 5
print(burst_info['B0']['start_times'][test_ind])
print(burst_info['B0']['burst_spike_counts'][test_ind])
print(burst_info['B0']['burst_spike_times'][test_ind])

144.16321360599198
4
[144.16321361 144.16538027 144.16738026 144.17134692]


### Make it into a function

In [None]:
## Developed in NPX_find_bursts_testing.ipynb, this version is faster and only returns start times and spike counts ##
def find_bursts(unit_ids, all_spikes_dict):
    burst_info = {}
    for uid in unit_ids:
        alluspikes = all_spikes_dict[uid]
        preISIs = np.diff(alluspikes)[:-1]
        postISIs = np.diff(alluspikes)[1:]
        ## Find starts ##
        bs_inds = np.nonzero((preISIs > 0.1) * (postISIs < 0.004))[0]
        if len(bs_inds) == 0:
            continue
        burst_starts = bs_inds + 1 # +1 corrects for the actual spike ind
        ## Loop through burst starts to find spikes that belong to the burst
        burst_counts = []
        for st_ind in bs_inds:
            spkind = st_ind+1
            bcount = 1
            while (spkind < len(preISIs)) and (preISIs[spkind] < 0.004):
                spkind += 1
                bcount += 1
            burst_counts.append(bcount)
        ## Store burst info ##
        burst_info[uid] = {}
        burst_info[uid]['start_times'] = alluspikes[burst_starts]
        burst_info[uid]['burst_spike_counts'] = np.array(burst_counts)
    return burst_info