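### This notebook loads subjects with probe recordings and extracts each unit's spike and burst times (saved as a pkl file). It also collects each unit's event-related spike and burst times for all stimulus events (saved as a second pkl file) and saves per-unit metadata (cell type, depth, region, parent area) as a CSV.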

In [1]:
import os
import sys
import json
import time
import gspread
import pickle
import numpy as np
import pandas as pd
from scipy import signal, fftpack, stats, ndimage
import statsmodels.stats.multitest as multitest

In [2]:
# Make the local tbd_eeg package (imported below) importable.
# NOTE(review): hard-coded user-specific path — adjust for your machine.
sys.path.append(r'C:\Users\lesliec\code')

In [3]:
from tbd_eeg.tbd_eeg.data_analysis.eegutils import EEGexp
from tbd_eeg.tbd_eeg.data_analysis.Utilities.utilities import (
    get_stim_events,
    get_evoked_traces,
    get_evoked_firing_rates,
    find_nearest_ind
)
from allensdk.brain_observatory.ecephys.lfp_subsampling.subsampling import remove_lfp_offset
from allensdk.core.mouse_connectivity_cache import MouseConnectivityCache

In [4]:
%matplotlib notebook

Load CCF for identifying cortical areas

In [5]:
# Allen Mouse Connectivity cache for the CCF; resolution=10 presumably means 10 um voxels — confirm.
mcc = MouseConnectivityCache(resolution=10)
# Structure tree of CCF brain regions (hierarchy of region acronyms/ids)
str_tree = mcc.get_structure_tree()

Load Zap_Zip-log_exp to get metadata for experiments

In [6]:
_gc = gspread.service_account() # need a key file to access the account (presumably gspread's default credentials location)
_sh = _gc.open('Zap_Zip-log_exp') # open the spreadsheet
_df = pd.DataFrame(_sh.sheet1.get()) # load the first worksheet
# Use the sheet's first row as column headers: transpose, move former row 0 into the index,
# then transpose back so those values become the column labels.
zzmetadata = _df.T.set_index(0).T # put it in a nicely formatted dataframe

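Define areas of interest (groups of CCF region acronyms) and a plot color for each group, used to assign units to parent areas and to plot population activity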

In [7]:
# Layer suffixes used to expand cortical region acronyms into per-layer CCF acronyms.
# The MO and ACA groups below omit layer 4; the SS and VIS groups include it.
_LAYERS_NO_L4 = ('1', '2/3', '5', '6a', '6b')
_LAYERS_WITH_L4 = ('1', '2/3', '4', '5', '6a', '6b')


def _expand_layers(regions, layers):
    """Return ``region + layer`` for every region (outer loop) and layer (inner loop)."""
    return [region + layer for region in regions for layer in layers]


# Areas of interest: parent group name -> list of CCF acronyms assigned to that group.
areas_of_interest = {
    'MO': _expand_layers(('MOp', 'MOs'), _LAYERS_NO_L4),
    'ACA': _expand_layers(('ACAd', 'ACAv'), _LAYERS_NO_L4),
    'SS': _expand_layers(('SSp-bfd', 'SSp-ll', 'SSp-tr'), _LAYERS_WITH_L4),
    'VIS': _expand_layers(('VISp', 'VISam', 'VISpm', 'VISrl'), _LAYERS_WITH_L4),
    # This set (including RT) is the default.
    # Possible future alternative: drop RT and add PF, i.e.
    # 'AV', 'CL', 'MD', 'PO', 'PF', 'VAL', 'VPL', 'VPM', 'VM'
    'MO-TH': ['AV', 'CL', 'MD', 'PO', 'RT', 'VAL', 'VPL', 'VPM', 'VM'],
}

# Plot color for each areas_of_interest group.
area_colors = dict(
    zip(
        ('MO', 'ACA', 'SS', 'VIS', 'MO-TH'),
        ('blue', 'deepskyblue', 'blueviolet', 'green', 'steelblue'),
    )
)

#### Functions

In [8]:
## Developed in NPX_find_bursts_testing.ipynb, this version is faster and only returns start times and spike counts ##
def find_bursts(unit_ids, all_spikes_dict, min_pre_isi=0.1, max_burst_isi=0.004):
    """Detect bursts in the spike trains of several units.

    A burst starts at a spike whose preceding ISI is > ``min_pre_isi`` and whose
    following ISI is < ``max_burst_isi``. Subsequent spikes belong to the burst while
    each successive ISI stays < ``max_burst_isi``. The first and last spikes of a train
    can never start a burst (they lack a pre- or post-ISI).

    Parameters
    ----------
    unit_ids : iterable
        Keys of ``all_spikes_dict`` to process.
    all_spikes_dict : dict
        Maps unit id -> sorted spike times (same time units as the thresholds).
    min_pre_isi : float
        Minimum silent interval before a burst's first spike (default 0.1).
    max_burst_isi : float
        Maximum ISI between consecutive spikes inside a burst (default 0.004).

    Returns
    -------
    dict
        ``{uid: {'start_times': ndarray, 'burst_spike_counts': ndarray}}``.
        Units with no bursts are omitted.
    """
    burst_info = {}
    for uid in unit_ids:
        # ravel guards against 0-d (single spike) or column-vector inputs
        alluspikes = np.ravel(all_spikes_dict[uid])
        isis = np.diff(alluspikes)
        # For spike i (1 <= i <= n-2): pre-ISI is isis[i-1], post-ISI is isis[i].
        pre_isis = isis[:-1]
        post_isis = isis[1:]
        ## Find starts ##
        bs_inds = np.nonzero((pre_isis > min_pre_isi) & (post_isis < max_burst_isi))[0]
        if len(bs_inds) == 0:
            continue
        burst_starts = bs_inds + 1  # +1 converts the ISI index to the burst's first spike index
        ## Count the spikes belonging to each burst ##
        burst_counts = []
        for start in burst_starts:
            # Scan the full ISI array (not pre_isis, which drops the last ISI) so that a
            # burst ending on the final spike of the train is counted completely.
            isi_ind = start
            bcount = 1
            while isi_ind < len(isis) and isis[isi_ind] < max_burst_isi:
                isi_ind += 1
                bcount += 1
            burst_counts.append(bcount)
        ## Store burst info ##
        burst_info[uid] = {
            'start_times': alluspikes[burst_starts],
            'burst_spike_counts': np.array(burst_counts),
        }
    return burst_info

In [9]:
## Developed in NPX_find_bursts_testing.ipynb, this version is faster and only returns start times and spike counts ##
def find_bursts_indunit(spike_times, min_pre_isi=0.1, max_burst_isi=0.004):
    """Detect bursts in a single unit's spike train.

    A burst starts at a spike whose preceding ISI is > ``min_pre_isi`` and whose
    following ISI is < ``max_burst_isi``. Subsequent spikes belong to the burst while
    each successive ISI stays < ``max_burst_isi``.

    Parameters
    ----------
    spike_times : array-like
        Sorted spike times (same time units as the thresholds). 0-d or column-shaped
        arrays are flattened first.
    min_pre_isi : float
        Minimum silent interval before a burst's first spike (default 0.1).
    max_burst_isi : float
        Maximum ISI between consecutive spikes inside a burst (default 0.004).

    Returns
    -------
    start_times : ndarray
        Time of the first spike of each burst (empty if there are no bursts).
    burst_spike_counts : ndarray
        Number of spikes in each burst (always >= 2; empty if there are no bursts).
    """
    # ravel (not squeeze) so a single-spike unit becomes shape (1,) instead of 0-d,
    # which np.diff cannot handle.
    spike_times = np.ravel(spike_times)
    isis = np.diff(spike_times)
    # For spike i (1 <= i <= n-2): pre-ISI is isis[i-1], post-ISI is isis[i].
    pre_isis = isis[:-1]
    post_isis = isis[1:]
    ## Find starts ##
    bs_inds = np.nonzero((pre_isis > min_pre_isi) & (post_isis < max_burst_isi))[0]
    if len(bs_inds) == 0:
        return np.array([]), np.array([])

    burst_starts = bs_inds + 1  # +1 converts the ISI index to the burst's first spike index
    ## Count the spikes belonging to each burst ##
    burst_counts = []
    for start in burst_starts:
        # Scan the full ISI array (not pre_isis, which drops the last ISI) so that a
        # burst ending on the final spike of the train is counted completely.
        isi_ind = start
        bcount = 1
        while isi_ind < len(isis) and isis[isi_ind] < max_burst_isi:
            isi_ind += 1
            bcount += 1
        burst_counts.append(bcount)

    return spike_times[burst_starts], np.array(burst_counts)

### Load subjects from file

In [10]:
# Load the subject list: a JSON object mapping mouse number -> recording folder
# (iterated as `mouse_num, rec_folder` below).
# NOTE(review): hard-coded user-specific path — adjust for your machine.
with open(r'C:\Users\lesliec\OneDrive - Allen Institute\data\all_EEG_subjects.json') as subjects_file:
    multi_sub_dict = json.load(subjects_file)

### Load EEG_exp and gather units' spike times

In [13]:
# True: recompute and overwrite every subject's output files.
# False: skip subjects whose evoked_data/all_units_info.csv already exists.
overwrite_existing_files = False

# Event window relative to each event onset; both edges are inclusive.
event_window = [-2.0, 2.0] # find event-related spikes within this window (s)

In [14]:
for mouse_num, rec_folder in multi_sub_dict.items():
    exp = EEGexp(rec_folder, preprocess=False, make_stim_csv=False)
    # Experiment tag = folder-name text between the mouse ID and the recording year.
    # NOTE(review): +12 presumably skips 'mouse' + 6-digit ID + separator — confirm naming scheme.
    exp_tag = exp.experiment_folder[exp.experiment_folder.find('mouse')+12:exp.experiment_folder.find(str(exp.date.year))-1]
    print('{}: {}'.format(mouse_num, exp_tag))

    probe_list = [x.replace('_sorted', '') for x in exp.experiment_data if 'probe' in x]
    if len(probe_list) == 0:
        print(' This experiment has no probe data, not making spike times files.\n')
        continue

    ## Set file names ##
    evoked_folder = os.path.join(exp.data_folder, 'evoked_data')
    if not os.path.exists(evoked_folder):
        os.mkdir(evoked_folder)
    unit_info_file = os.path.join(evoked_folder, 'all_units_info.csv')
    unit_allspiketimes_file = os.path.join(evoked_folder, 'units_allspikes.pkl')
    unit_eventspikes_file = os.path.join(evoked_folder, 'units_event_spikes.pkl')
    # The unit-info CSV is written last (after both pickles), so its existence marks a
    # completed subject; skip such subjects unless overwriting was requested.
    if not overwrite_existing_files and os.path.exists(unit_info_file):
        print('  {} already exists, skipping analysis.\n'.format(unit_info_file))
        continue

    ## Load stim log ##
    stim_log = pd.read_csv(exp.stimulus_log_file)
    all_event_times = stim_log['onset'].values

    ## Get probe info ##
    print('  Getting probe info...')
    probe_locs = np.ones((len(probe_list)), dtype=bool)  # set False for probes lacking area assignments
    probe_data = {}
    for pbi, probei in enumerate(probe_list):
        ## Load probe_info.json ##
        with open(exp.ephys_params[probei]['probe_info']) as data_file:
            data = json.load(data_file)
        if 'area_ch' in data.keys():
            probe_data[probei] = {}
            npx_allch = np.array(data['channel'])
            surface_ch = int(data['surface_channel'])
            allch_z = np.array(data['vertical_pos'])
            ref_mask = np.array(data['mask'])
            # Unmasked channels at or below the surface channel (not used further in this loop)
            npx_chs = np.array([x for x in npx_allch if ref_mask[x] and x <= surface_ch])
            # Per-channel depth measured from the surface channel's vertical position
            probe_data[probei]['ch_depths'] = allch_z[surface_ch] - allch_z

            ## Select units and get peak chs ##
            select_units, peak_chs, unit_metrics = exp.get_probe_units(probei)
            # Waveform duration <= 0.4 -> fast-spiking ('FS'), otherwise regular-spiking ('RS')
            unit_metrics['cell_type'] = unit_metrics['duration'].apply(lambda x: 'FS' if x <= 0.4 else 'RS')
            ## Sort units by peak channel ##
            # np.ravel (not np.squeeze) keeps a 1-D index array even when the probe has a
            # single unit; a 0-d index would turn 'units' into a non-iterable scalar.
            sort_order = np.ravel(np.argsort(peak_chs))
            probe_data[probei]['units'] = select_units[sort_order]
            probe_data[probei]['chs'] = peak_chs[sort_order]
            probe_data[probei]['cell_type'] = unit_metrics['cell_type'].values[sort_order]
            probe_data[probei]['areas'] = unit_metrics.area.values[sort_order]

            ## Load spike times and cluster ids ##
            probe_data[probei]['spike_times'] = np.load(exp.ephys_params[probei]['spike_times'])
            probe_data[probei]['spike_clusters'] = np.load(exp.ephys_params[probei]['spike_clusters'])
        else:
            print('  {} does not have area assignments, not processing.'.format(probei))
            probe_locs[pbi] = False

    if not probe_locs.any():
        print('  NO area assignments for any probes, not analyzing.\n')
        continue

    ## Get unit info, spikes and event-spikes, then save files ##
    print('  Getting spike times...')
    start = time.time()
    all_units_info = []
    unit_allspiketimes = {}
    unit_eventspikestimes = {'event_window': event_window, 'event_spikes': {}, 'event_bursts': {}}
    for probei, pdata in probe_data.items():
        for unitind, uniti in enumerate(pdata['units']):
            # Unit name = last character of the probe name + cluster id (e.g. 'B' + '123')
            unit_name = probei[-1] + str(uniti)
            unit_allspiketimes[unit_name] = {}
            ## Gather unit info: [unit_id, cell_type, depth, region, parent] ##
            unit_info = [unit_name]
            unit_region = pdata['areas'][unitind]
            unit_info.extend([pdata['cell_type'][unitind], pdata['ch_depths'][pdata['chs'][unitind]], unit_region])
            # Parent = the areas_of_interest group containing this region, else 'notROI'
            parent_region = [key for key in areas_of_interest if unit_region in areas_of_interest[key]]
            unit_info.append(parent_region[0] if len(parent_region) == 1 else 'notROI')
            all_units_info.append(unit_info)

            ## Get all and event spike times ##
            # np.ravel (not np.squeeze) keeps a 1-D array even for a single-spike unit;
            # a 0-d array would make np.diff in find_bursts_indunit raise.
            spikesi = np.ravel(pdata['spike_times'][pdata['spike_clusters'] == uniti])
            unit_allspiketimes[unit_name]['spikes'] = spikesi
            burstsi, burst_counts = find_bursts_indunit(spikesi)
            unit_allspiketimes[unit_name]['bursts'] = burstsi
            unit_allspiketimes[unit_name]['burst_counts'] = burst_counts
            event_raster = []
            burst_raster = []
            burst_count_raster = []
            for eventi in all_event_times:
                # Spikes/bursts within [event + window[0], event + window[1]] (inclusive),
                # stored relative to the event onset.
                win_start = eventi + event_window[0]
                win_end = eventi + event_window[1]
                spikeinds = np.nonzero((spikesi >= win_start) & (spikesi <= win_end))[0]
                event_raster.append(spikesi[spikeinds] - eventi)
                burstinds = np.nonzero((burstsi >= win_start) & (burstsi <= win_end))[0]
                burst_raster.append(burstsi[burstinds] - eventi)
                burst_count_raster.append(burst_counts[burstinds])
            unit_eventspikestimes['event_spikes'][unit_name] = event_raster
            unit_eventspikestimes['event_bursts'][unit_name] = {'times': burst_raster, 'counts': burst_count_raster}

    ## Save the data files to mouse's recordingX\evoked_data folder ##
    # Pickles first, CSV last: the CSV is the "already done" marker checked above, so an
    # interrupted save is redone on the next run instead of being silently skipped.
    # `with` guarantees the pickle file handles are closed.
    with open(unit_allspiketimes_file, 'wb') as allspikes_fh:
        pickle.dump(unit_allspiketimes, allspikes_fh)
    with open(unit_eventspikes_file, 'wb') as eventspikes_fh:
        pickle.dump(unit_eventspikestimes, eventspikes_fh)
    all_units_info_df = pd.DataFrame(
        all_units_info, columns=['unit_id', 'cell_type', 'depth', 'region', 'parent'])
    all_units_info_df.to_csv(unit_info_file, index=False)

    end = time.time()
    print('  Time to get unit spike times and save: {:.2f} min\n'.format((end-start)/60))
    ## After each subject, delete common variables ##
    del stim_log, probe_data, all_event_times, all_units_info, all_units_info_df, unit_allspiketimes, unit_eventspikestimes

Experiment type: electrical stimulation
SomnoSuite log file not found.
521885: estim1
 This experiment has no probe data, not making spike times files.

Experiment type: electrical stimulation
SomnoSuite log file not found.
521886: estim1
 This experiment has no probe data, not making spike times files.

Experiment type: electrical stimulation
SomnoSuite log file not found.
521887: estim1
 This experiment has no probe data, not making spike times files.

Experiment type: electrical stimulation
SomnoSuite log file not found.
543393: estim1
  Getting probe info...
  probeB does not have area assignments, not processing.
  NO area assignments for any probes, not analyzing.

Experiment type: electrical stimulation
SomnoSuite log file not found.
543394: estim1
  Getting probe info...
  probeB does not have area assignments, not processing.
  NO area assignments for any probes, not analyzing.

Experiment type: electrical stimulation
SomnoSuite log file not found.
543395: estim1
  Getting pro