This creates epoched files for evoked EEG and units responses for multiple subjects. Designed for use with main experimental subjects.

In [1]:
from datetime import date
from glob import glob
import json
import math
import os
from pathlib import Path
import pickle
import sys
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal

In [2]:
sys.path.append(r'C:\Users\lesliec\code')

In [3]:
from tbd_eeg.tbd_eeg.data_analysis.eegutils import EEGexp
from tbd_eeg.tbd_eeg.data_analysis.Utilities.utilities import (
    get_stim_events, get_evoked_traces, find_nearest_ind, qualitycheck_trials)
from allensdk.core.mouse_connectivity_cache import MouseConnectivityCache

In [4]:
%matplotlib widget

#### Functions

In [5]:
## Developed in NPX_find_bursts_testing.ipynb, this version is faster and only returns start times and spike counts ##
## For thalamus ##
def find_bursts_THunit(spike_times):
    
    preISIs = np.diff(spike_times)[:-1]
    postISIs = np.diff(spike_times)[1:]
    ## Find starts ##
    bs_inds = np.nonzero((preISIs > 0.1) * (postISIs < 0.004))[0]
    if len(bs_inds) == 0:
        return np.array([]), np.array([])
    
    burst_starts = bs_inds + 1 # +1 corrects for the actual spike ind
    ## Loop through burst starts to find spikes that belong to the burst
    burst_counts = []
    for st_ind in bs_inds:
        spkind = st_ind+1
        bcount = 1
        while (spkind < len(preISIs)) and (preISIs[spkind] < 0.004):
            spkind += 1
            bcount += 1
        burst_counts.append(bcount)
    
    return spike_times[burst_starts], np.array(burst_counts)

In [6]:
## Finds bursts in non-thalamic units ##
def find_bursts_otherunit(spike_times):
    ISI_threshold = 0.01 # ISI less than or equal to 15 ms
    spike_count_thresh = 3 # at least this number of spikes to be considered burst

    preISIs = np.insert(np.diff(spike_times), 0, 1.0)
    burst_starts = []
    burst_counts = []
    spkind = 0
    while spkind < len(spike_times):
        tempevent = [spike_times[spkind]]
        spkind += 1
        while (spkind < len(spike_times)) and (preISIs[spkind] < ISI_threshold):
            tempevent.append(spike_times[spkind])
            spkind += 1
        if len(tempevent) >= spike_count_thresh:
            burst_starts.append(tempevent[0])
            burst_counts.append(len(tempevent))
        del tempevent
    
    return np.array(burst_starts), np.array(burst_counts)

In [7]:
def find_closest_region(sunit_info, struct_tree, annot):
    ## Finds a grey matter region above/below an unknown region ##
    Vind = sunit_info.CCF_DV
    vent_sip = struct_tree.get_structures_by_id([annot[sunit_info.CCF_AP, Vind, sunit_info.CCF_ML]])[0]['structure_id_path']
    while not struct_tree.structure_descends_from(vent_sip[-1], 8):
        Vind += 1
        vent_sip = struct_tree.get_structures_by_id([annot[sunit_info.CCF_AP, Vind, sunit_info.CCF_ML]])[0]['structure_id_path']

    Dind = sunit_info.CCF_DV
    dors_sip = struct_tree.get_structures_by_id([annot[sunit_info.CCF_AP, Dind, sunit_info.CCF_ML]])[0]['structure_id_path']
    while not struct_tree.structure_descends_from(dors_sip[-1], 8):
        Dind -= 1
        dors_sip = struct_tree.get_structures_by_id([annot[sunit_info.CCF_AP, Dind, sunit_info.CCF_ML]])[0]['structure_id_path']

    if (Vind - sunit_info.CCF_DV) <= (sunit_info.CCF_DV - Dind):
        return struct_tree.get_structures_by_id([vent_sip[-1]])[0]['acronym']
    elif (Vind - sunit_info.CCF_DV) > (sunit_info.CCF_DV - Dind):
        return struct_tree.get_structures_by_id([dors_sip[-1]])[0]['acronym']

In [8]:
def get_region_from_children(test_id, parent_id, struct_tree):
    try:
        child_ind = np.nonzero([
            struct_tree.structure_descends_from(test_id, x) for x in struct_tree.child_ids([parent_id])[0]
        ])[0][0]
        return struct_tree.get_structures_by_id([struct_tree.child_ids([parent_id])[0][child_ind]])[0]['acronym']
    except:
        return struct_tree.get_structures_by_id([parent_id])[0]['acronym']

In [9]:
def get_parent_region(region_acronym, struct_tree):
    # areas_of_interest = {
    #     'SM-TH': ['AV', 'CL', 'MD', 'PO', 'PF', 'VAL', 'VPL', 'VPM', 'VM'],
    # }
    reg_id = struct_tree.get_structures_by_acronym([region_acronym])[0]['id']
    if struct_tree.structure_descends_from(reg_id, 567):
        if struct_tree.structure_descends_from(reg_id, 315):
            return get_region_from_children(reg_id, 315, struct_tree)
        elif struct_tree.structure_descends_from(reg_id, 698):
            return 'OLF'
        elif struct_tree.structure_descends_from(reg_id, 1089):
            return get_region_from_children(reg_id, 1089, struct_tree)
        elif struct_tree.structure_descends_from(reg_id, 703):
            return get_region_from_children(reg_id, 703, struct_tree)
        elif struct_tree.structure_descends_from(reg_id, 477):
            return 'STR'
        elif struct_tree.structure_descends_from(reg_id, 803):
            return 'PAL'
        else:
            return 'unassigned'
    elif struct_tree.structure_descends_from(reg_id, 343):
        if struct_tree.structure_descends_from(reg_id, 1129):
            return 'TH'
            # if region_acronym == 'RT':
            #     return 'RT-TH'
            # elif region_acronym in areas_of_interest['SM-TH']:
            #     return 'SM-TH'
            # else:
            #     return 'other-TH'
        elif struct_tree.structure_descends_from(reg_id, 1097):
            return 'HY'
        else:
            return get_region_from_children(reg_id, 343, struct_tree)
    else:
        return 'unassigned'

In [10]:
def add_parent_region_to_df(unit_info_df, struct_tree, annot):
    ## First, make sure all names in region column correspond to a CCF region (removes nan values) ##
    adj_regions = unit_info_df['region'].values.copy()
    for indi, rowi in unit_info_df.iterrows():
        try:
            str_info = struct_tree.get_structures_by_acronym([rowi.region])[0]
        except KeyError:
            if rowi.depth <= 0: # unit was placed above brain
                new_region_id = annot[rowi.CCF_AP, np.nonzero(annot[rowi.CCF_AP, :, rowi.CCF_ML])[0][0], rowi.CCF_ML]
                adj_regions[indi] = struct_tree.get_structures_by_id([new_region_id])[0]['acronym']
            else:
                Lind = rowi.CCF_ML
                while annot[rowi.CCF_AP, rowi.CCF_DV, Lind] == 0:
                    Lind -= 1
                new_region_id = struct_tree.get_structures_by_id(
                    [annot[rowi.CCF_AP, rowi.CCF_DV, Lind]])[0]['structure_id_path'][-1]
                adj_regions[indi] = struct_tree.get_structures_by_id([new_region_id])[0]['acronym']
    unit_info_df['adj_region'] = adj_regions
    
    ## Second, re-assign any non-grey matter areas to the closest region ##
    adj_regions = unit_info_df['adj_region'].values.copy()
    for indi, rowi in unit_info_df.iterrows():
        reg_id = struct_tree.get_structures_by_acronym([rowi.adj_region])[0]['id']
        if not struct_tree.structure_descends_from(reg_id, 8):
            adj_regions[indi] = find_closest_region(rowi, struct_tree, annot)
    unit_info_df['adj_region'] = adj_regions
    
    ## Finally, assign a parent region to each adjusted CCF region ##
    parent_regions = unit_info_df['adj_region'].values.copy()
    for indi, rowi in unit_info_df.iterrows():
        parent_regions[indi] = get_parent_region(rowi.adj_region, struct_tree)
    unit_info_df['parent_region'] = parent_regions
    
    return unit_info_df.drop('adj_region', axis=1)

# Load subjects.csv file

In [11]:
multisub_file = r"C:\Users\lesliec\OneDrive - Allen Institute\Shared Documents - Lab 328\Projects\CL-CM stim\subject_metadata.csv"
subject_df = pd.read_csv(multisub_file)
subject_df.head()

Unnamed: 0,mouse,sex,DOB,strain,exp_folder,histology,EEG,stim_tip_distance
0,mouse771424,F,,C57BL/6J,THstim_d1_2024-11-14_11-28-16,True,False,300
1,mouse771424,F,,C57BL/6J,THstim_d2_2024-11-15_10-51-49,True,False,300
2,mouse771425,F,,C57BL/6J,THstim_d1_2024-11-21_10-59-24,True,False,300
3,mouse771425,F,,C57BL/6J,THstim_d2_2024-11-22_10-49-58,True,False,300
4,mouse771426,M,,C57BL/6J,THstim_d1_2024-12-19_12-19-39,True,False,100


# Process data

In [44]:
overwrite_existing_files = True
data_dir = Path(r"P:\\")

## Process the running signal

In [12]:
for rowi, exprow in subject_df.iterrows():
    print('{}: {}'.format(exprow.mouse, exprow.exp_folder))
    data_path = os.path.join(data_dir, exprow.mouse, exprow.exp_folder, 'experiment1', 'recording1')
    exp = EEGexp(data_path, preprocess=False, make_stim_csv=False)

    ## Set file names ##
    running_file = os.path.join(exp.data_folder, 'running_signal.npy')
    raw_running_file = os.path.join(exp.data_folder, 'raw_running_signal.npy')
    running_ts_file = os.path.join(exp.data_folder, 'running_timestamps_master_clock.npy')

    if not os.path.exists(running_ts_file) or overwrite_existing_files:
        run_signal, raw_signal, run_timestamps = exp.load_running()
        np.save(running_file, run_signal, allow_pickle=False)
        np.save(raw_running_file, raw_signal, allow_pickle=False)
        np.save(running_ts_file, run_timestamps, allow_pickle=False)
        print(' Saved running signals.\n')
    else:
        print(' Running signals already exist?\n')

mouse771424: THstim_d1_2024-11-14_11-28-16
Experiment type: electrical stimulation
 Saved running signals.

mouse771424: THstim_d2_2024-11-15_10-51-49
Experiment type: electrical stimulation
 Saved running signals.

mouse771425: THstim_d1_2024-11-21_10-59-24
Experiment type: electrical stimulation
 Saved running signals.

mouse771425: THstim_d2_2024-11-22_10-49-58
Experiment type: electrical stimulation
 Saved running signals.

mouse771426: THstim_d1_2024-12-19_12-19-39
Experiment type: electrical stimulation
 Saved running signals.

mouse771426: THstim_d2_2024-12-20_09-28-35
Experiment type: electrical stimulation
 Saved running signals.

mouse771427: THstim_d1_2025-01-22_10-36-51
Experiment type: electrical stimulation
mouse771427: THstim_d2_2025-01-23_10-25-10
Experiment type: electrical stimulation
mouse771427: THstim_d3_2025-01-24_10-40-15
Experiment type: electrical and sensory stimulation


## Process units

In [45]:
mcc = MouseConnectivityCache(resolution=25)
str_tree = mcc.get_structure_tree()
annot, annot_info = mcc.get_annotation_volume()

spike_count_threshold = 20

In [46]:
for eind, exprow in subject_df.iterrows():
    print('{}: {}'.format(exprow.mouse, exprow.exp_folder))
    data_path = os.path.join(data_dir, exprow.mouse, exprow.exp_folder, 'experiment1', 'recording1')
    exp = EEGexp(data_path, preprocess=False, make_stim_csv=False)

    probe_list = [x.replace('_sorted', '') for x in exp.experiment_data if 'probe' in x]
    if len(probe_list) == 0:
        print(' This experiment has no probe data, not making spike times files.\n')
        continue

    ## Set file names ##
    unit_info_file = os.path.join(exp.data_folder, 'all_units_info.csv')
    unit_allspiketimes_file = os.path.join(exp.data_folder, 'units_allspikes.pkl')

    if overwrite_existing_files:
        pass # will overwrite all subjects' files
    else:
        if os.path.exists(unit_info_file):
            print(' {} already exists, skipping analysis.\n'.format(unit_info_file))
            continue

    ## Get probe info ##
    print(' Getting probe info...')
    start = time.time()
    all_units_info = []
    unit_allspiketimes = {}
    for probe_name in probe_list:
        print('  {}'.format(probe_name))
        ## Load probe_info.json ##
        with open(exp.ephys_params[probe_name]['probe_info']) as data_file:
            data = json.load(data_file)
        npx_allch = np.array(data['channel'])
        surface_ch = int(data['surface_channel'])
        allch_z = np.array(data['vertical_pos'])
        # ref_mask = np.array(data['mask'])
        # npx_chs = np.array([x for x in npx_allch if ref_mask[x] and x <= surface_ch])
        probe_ch_depths = allch_z[surface_ch] - allch_z
        
        ## Load the unit info ##
        cluster_group = pd.read_csv(exp.ephys_params[probe_name]['cluster_group'], sep='\t')
        cluster_metrics = pd.read_csv(exp.ephys_params[probe_name]['cluster_metrics'])
        spike_clusters = np.load(exp.ephys_params[probe_name]['spike_clusters'])
        spike_times = np.load(exp.ephys_params[probe_name]['spike_times'])
        
        if not np.array_equal(cluster_group['cluster_id'].values.astype('int'), np.unique(spike_clusters)):
            print('   IDs from cluster_group.tsv DO NOT match spike_clusters.npy. This may mean there are unsorted units, check in phy.')
            continue
        if np.array_equal(cluster_group['cluster_id'].values.astype('int'), cluster_metrics['cluster_id'].values.astype('int')):
            unit_metrics = pd.merge(cluster_group.rename(columns={'group':'label'}), cluster_metrics, on='cluster_id')
        else:
            print('   IDs from cluster_group DO NOT match cluster_metrics.')
            continue
        
        ## Select only "good" units ##
        unit_metrics['spike_count'] = [np.sum(spike_clusters == x) for x in unit_metrics.cluster_id.values]
        good_units = unit_metrics[(unit_metrics['label'] == 'good') & (unit_metrics['spike_count'] > spike_count_threshold)]
        tempcoords = np.array([[int(y) for y in x.replace('[','').replace(']','').replace(' ','').split(',')] for x in good_units.ccf_coord.values])
        
        ## Now reorganize metrics to save ##
        probe_units = pd.DataFrame([probe_name[-1] + str(x) for x in good_units.cluster_id.values], columns=['unit_id'])
        probe_units['probe'] = [probe_name] * len(good_units)
        probe_units['peak_ch'] = good_units['peak_channel'].values
        probe_units['depth'] = [probe_ch_depths[pkch] for pkch in good_units.peak_channel.values]
        probe_units['spike_duration'] = good_units['duration'].values
        probe_units['region'] = good_units['area'].values
        probe_units['CCF_AP'], probe_units['CCF_DV'], probe_units['CCF_ML'] = tempcoords[:,0], tempcoords[:,1], tempcoords[:,2]
        probe_units['firing_rate'] = good_units['firing_rate'].values
        probe_units['presence_ratio'] = good_units['presence_ratio'].values
        probe_units['isi_viol'] = good_units['isi_viol'].values
        probe_units['amplitude_cutoff'] = good_units['amplitude_cutoff'].values
        probe_units['spike_count'] = good_units['spike_count'].values
        
        ## Add parent region column ##
        probe_units = add_parent_region_to_df(probe_units, str_tree, annot)
        all_units_info.append(probe_units)

        for uniti, (uind, unitrow) in zip(good_units.cluster_id.values, probe_units.iterrows()):
            spikesi = np.squeeze(spike_times[spike_clusters == uniti])
            if unitrow.parent_region == 'TH':
                burstsi, burst_counts = find_bursts_THunit(spikesi)
            else:
                burstsi, burst_counts = find_bursts_otherunit(spikesi)
            unit_allspiketimes[unitrow.unit_id] = {
                'spikes': spikesi,
                'bursts': burstsi,
                'burst_counts': burst_counts,
            }

    ## Now combine all probe units dfs ##
    all_units_info_df = pd.concat(all_units_info, ignore_index=True)
    all_units_info_df.to_csv(unit_info_file, index=False)
    pickle.dump(unit_allspiketimes, open(unit_allspiketimes_file, 'wb'))

    end = time.time()
    print('   Time to get unit spike times and save: {:.2f} min\n'.format((end-start)/60))
    ## After each subject, delete common variables ##
    del probe_units, all_units_info, all_units_info_df, unit_allspiketimes

mouse771424: THstim_d1_2024-11-14_11-28-16
Experiment type: electrical stimulation
 Getting probe info...
  probeA
  probeD
  probeE
  probeF
   Time to get unit spike times and save: 0.67 min

mouse771424: THstim_d2_2024-11-15_10-51-49
Experiment type: electrical stimulation
 Getting probe info...
  probeA
  probeD
  probeE
  probeF
   Time to get unit spike times and save: 0.44 min

mouse771425: THstim_d1_2024-11-21_10-59-24
Experiment type: electrical stimulation
 Getting probe info...
  probeD
  probeE
  probeF
   Time to get unit spike times and save: 0.70 min

mouse771425: THstim_d2_2024-11-22_10-49-58
Experiment type: electrical stimulation
 Getting probe info...
  probeA
  probeD
  probeE
  probeF
   Time to get unit spike times and save: 0.99 min

mouse771426: THstim_d1_2024-12-19_12-19-39
Experiment type: electrical stimulation
 Getting probe info...
  probeA
  probeD
  probeE
  probeF
   Time to get unit spike times and save: 0.30 min

mouse771426: THstim_d2_2024-12-20_09-28

### Testing on a single subject

In [15]:
exprow = subject_df.iloc[8]
print('{}: {}'.format(exprow.mouse, exprow.exp_folder))
data_path = os.path.join(data_dir, exprow.mouse, exprow.exp_folder, 'experiment1', 'recording1')
exp = EEGexp(data_path, preprocess=False, make_stim_csv=False)

mouse771427: THstim_d3_2025-01-24_10-40-15
Experiment type: electrical and sensory stimulation


In [16]:
probe_list = [x.replace('_sorted', '') for x in exp.experiment_data if 'probe' in x]
print(probe_list)

['probeD', 'probeE', 'probeF']


Loop through probes

In [43]:
probe_name = 'probeD'
spike_count_threshold = 20

## Load probe_info.json ##
with open(exp.ephys_params[probe_name]['probe_info']) as data_file:
    data = json.load(data_file)
npx_allch = np.array(data['channel'])
surface_ch = int(data['surface_channel'])
allch_z = np.array(data['vertical_pos'])
ref_mask = np.array(data['mask'])
npx_chs = np.array([x for x in npx_allch if ref_mask[x] and x <= surface_ch])
probe_ch_depths = allch_z[surface_ch] - allch_z

## Load the unit info ##
cluster_group = pd.read_csv(exp.ephys_params[probe_name]['cluster_group'], sep='\t')
cluster_metrics = pd.read_csv(exp.ephys_params[probe_name]['cluster_metrics'])
spike_clusters = np.load(exp.ephys_params[probe_name]['spike_clusters'])
spike_times = np.load(exp.ephys_params[probe_name]['spike_times'])

if not np.array_equal(cluster_group['cluster_id'].values.astype('int'), np.unique(spike_clusters)):
    print('IDs from cluster_group.tsv DO NOT match spike_clusters.npy. This may mean there are unsorted units, check in phy.')
if np.array_equal(cluster_group['cluster_id'].values.astype('int'), cluster_metrics['cluster_id'].values.astype('int')):
    unit_metrics = pd.merge(cluster_group.rename(columns={'group':'label'}), cluster_metrics, on='cluster_id')
else:
    print('IDs from cluster_group DO NOT match cluster_metrics.')

## Select only "good" units ##
unit_metrics['spike_count'] = [np.sum(spike_clusters == x) for x in unit_metrics.cluster_id.values]
good_units = unit_metrics[(unit_metrics['label'] == 'good') & (unit_metrics['spike_count'] > spike_count_threshold)]
tempcoords = np.array([[int(y) for y in x.replace('[','').replace(']','').replace(' ','').split(',')] for x in good_units.ccf_coord.values])

## Now reorganize metrics to save ##
probe_units = pd.DataFrame([probe_name[-1] + str(x) for x in good_units.cluster_id.values], columns=['unit_id'])
probe_units['probe'] = [probe_name] * len(good_units)
probe_units['peak_ch'] = good_units['peak_channel'].values
probe_units['depth'] = [probe_ch_depths[pkch] for pkch in good_units.peak_channel.values]
probe_units['spike_duration'] = good_units['duration'].values
probe_units['region'] = good_units['area'].values
probe_units['CCF_AP'], probe_units['CCF_DV'], probe_units['CCF_ML'] = tempcoords[:,0], tempcoords[:,1], tempcoords[:,2]
probe_units['firing_rate'] = good_units['firing_rate'].values
probe_units['presence_ratio'] = good_units['presence_ratio'].values
probe_units['isi_viol'] = good_units['isi_viol'].values
probe_units['amplitude_cutoff'] = good_units['amplitude_cutoff'].values
probe_units['spike_count'] = good_units['spike_count'].values

In [42]:
## Add parent region column ##
probe_units = add_parent_region_to_df(probe_units, str_tree, annot)
probe_units.head()

Unnamed: 0,unit_id,probe,peak_ch,depth,spike_duration,region,CCF_AP,CCF_DV,CCF_ML,firing_rate,presence_ratio,isi_viol,amplitude_cutoff,spike_count,parent_region
0,D0,probeD,9,3620,0.274707,MRN,329,160,199,1.869871,0.37,0.021218,0.027682,9447.0,MB
1,D3,probeD,1,3700,0.247236,MRN,328,163,201,0.007917,0.26,0.0,0.4581,,MB
2,D7,probeD,1,3700,0.315913,MRN,328,163,201,2.930395,0.73,0.172786,0.315145,,MB
3,D10,probeD,1,3700,0.61809,MRN,328,163,201,0.008313,0.27,0.0,0.5,40.0,MB
4,D12,probeD,5,3660,0.329648,MRN,329,161,200,6.395605,0.98,0.041715,0.157468,,MB


In [19]:
print(np.unique(probe_units.parent_region.values))

['MB' 'RHP' 'VIS']


In [20]:
probe_units.spike_count.values

array([  9447,     40,      6,  14805,     42,  32312, 114746, 116360,
        14924,  42674,  17657,    142,     31, 218195,  34801,     12,
       136522, 140276, 102046,   6785, 101833,  31264,  21339,  55547,
           28, 104193,  23132,     16,  15051,  63218,   4861,     25,
       197405,     32,  73147,  47288,      6, 130581,  10987,  28972,
          265,    221,  27654, 131489,     51,  95909,     33,     50,
            6,     50,  82921,  10385,  12010,   9651,  69487,   9064,
        39303,      2,  12663,  34226,  54585,   1633,  47629,  70484,
          501,  84338,  71143, 161154,   4169,  28293,  74298,  23758,
           13,   7821,  18234,  34639,   3900,      3,   1700,   1718,
        89287,  64997,  16507,  12520, 136903,  38629,    326,   3523,
         4010,   7453,  41013,  33792,  10475,  32306,  32399,  53756,
        21126, 108555,     17,     35,     33, 114043,  12785,  23336,
        19825,    115,  14871,  30052,   2449,  16460,   4840,  26515,
      

Now loop through units

In [38]:
unitrow = probe_units.iloc[238]
uniti = good_units.cluster_id.values[238]
print(unitrow.unit_id)
print(uniti)

D523
523


In [27]:
if unitrow.parent_region == 'TH':
    burstsi, burst_counts = find_bursts_THunit(spikesi)
else:
    burstsi, burst_counts = find_bursts_otherunit(spikesi)

ValueError: diff requires input that is at least one dimensional

In [24]:
print(len(burstsi))

0


# Old stuff

### Check trial quality for EEG signals

### Set parameters

### Process running signal and EEG

### Process units

In [15]:
for mouse, explist in subjects.items():
    for exptype, dataloc in explist.items():
        print('{}: {}'.format(mouse, exptype))
        exp = EEGexp(dataloc, preprocess=False, make_stim_csv=False)
        
        probe_list = [x.replace('_sorted', '') for x in exp.experiment_data if 'probe' in x]
        if len(probe_list) == 0:
            print(' This experiment has no probe data, not making spike times files.\n')
            continue
        
        ## Set file names ##
        evoked_folder = os.path.join(exp.data_folder, 'evoked_data')
        if not os.path.exists(evoked_folder):
            os.mkdir(evoked_folder)
        unit_info_file = os.path.join(evoked_folder, 'all_units_info.csv')
        unit_allspiketimes_file = os.path.join(evoked_folder, 'units_allspikes.pkl')
        unit_eventspikes_file = os.path.join(evoked_folder, 'units_event_spikes.pkl')
        if overwrite_existing_files:
            pass # will overwrite all subjects' files
        else:
            if os.path.exists(unit_info_file):
                print('  {} already exists, skipping analysis.\n'.format(unit_info_file))
                continue

        ## Load stim log ##
        stim_log = pd.read_csv(exp.stimulus_log_file)
        all_event_times = stim_log['onset'].values
        
        ## Get probe info ##
        print('  Getting probe info...')
        probe_data = {}
        for pbi, probei in enumerate(probe_list):
            probe_data[probei] = {}
            ## Load probe_info.json ##
            with open(exp.ephys_params[probei]['probe_info']) as data_file:
                data = json.load(data_file)
            npx_allch = np.array(data['channel'])
            surface_ch = int(data['surface_channel'])
            allch_z = np.array(data['vertical_pos'])
            ref_mask = np.array(data['mask'])
            npx_chs = np.array([x for x in npx_allch if ref_mask[x] and x <= surface_ch])
            probe_data[probei]['ch_depths'] = allch_z[surface_ch] - allch_z
            
            ## Select units and get peak chs ##
            select_units, peak_chs, unit_metrics = exp.get_probe_units(probei)
            ## Sort units ##
            probe_data[probei]['units'] = select_units[np.squeeze(np.argsort(peak_chs))]
            probe_data[probei]['chs'] = peak_chs[np.squeeze(np.argsort(peak_chs))]
            probe_data[probei]['duration'] = unit_metrics.duration.values[np.squeeze(np.argsort(peak_chs))]
            
            ## Load spike times and cluster ids ##
            probe_data[probei]['spike_times'] = np.load(exp.ephys_params[probei]['spike_times'])
            probe_data[probei]['spike_clusters'] = np.load(exp.ephys_params[probei]['spike_clusters'])
            
            if 'area_ch' in data.keys():
                probe_data[probei]['areas'] = unit_metrics.area.values[np.squeeze(np.argsort(peak_chs))]
                probe_data[probei]['CCF_coords'] = unit_metrics.ccf_coord.values[np.squeeze(np.argsort(peak_chs))]
                
        ## Get unit info, spikes and event-spikes, then save files ##
        print('  Getting spike times...')
        start = time.time()
        all_units_info = []
        unit_allspiketimes = {}
        unit_eventspikestimes = {'event_window': event_window, 'event_spikes': {}, 'event_bursts': {}}
        for probei, pdata in probe_data.items():
            for unitind, uniti in enumerate(pdata['units']):
                unit_name = probei[-1] + str(uniti)
                spikesi = np.squeeze(pdata['spike_times'][pdata['spike_clusters'] == uniti])
                if spikesi.size < 50:
                    continue
                unit_allspiketimes[unit_name] = {}
                
                ## Gather unit info ##
                if 'areas' in pdata.keys():
                    unit_region = pdata['areas'][unitind]
                    unit_coords = [
                        int(x) for x in pdata['CCF_coords'][unitind].replace('[','').replace(']','').replace(' ','').split(',')
                    ]
                else:
                    unit_region = 'none'
                    unit_coords = [-1, -1, -1]
                all_units_info.append([
                    unit_name, probei, pdata['chs'][unitind], pdata['ch_depths'][pdata['chs'][unitind]],
                    pdata['duration'][unitind], unit_region, unit_coords[0], unit_coords[1], unit_coords[2]
                ])

                ## Get all and event spike times ##
                unit_allspiketimes[unit_name]['spikes'] = spikesi
                burstsi, burst_counts = find_bursts_indunit(spikesi)
                unit_allspiketimes[unit_name]['bursts'] = burstsi
                unit_allspiketimes[unit_name]['burst_counts'] = burst_counts
                event_raster = []
                burst_raster = []
                burst_count_raster = []
                for eventi in all_event_times:
                    spikeinds = np.nonzero((spikesi >= eventi + event_window[0]) & (spikesi <= eventi + event_window[1]))[0]
                    event_raster.append(spikesi[spikeinds] - eventi)
                    burstinds = np.nonzero((burstsi >= eventi + event_window[0]) & (burstsi <= eventi + event_window[1]))[0]
                    burst_raster.append(burstsi[burstinds] - eventi)
                    burst_count_raster.append(burst_counts[burstinds])
                unit_eventspikestimes['event_spikes'][unit_name] = event_raster
                unit_eventspikestimes['event_bursts'][unit_name] = {'times': burst_raster, 'counts': burst_count_raster}

        ## Save the data files to mouse's recordingX\evoked_data folder ##
        all_units_info_df = pd.DataFrame(
            all_units_info,
            columns=['unit_id', 'probe', 'peak_ch', 'depth', 'spike_duration', 'region', 'CCF_AP', 'CCF_DV', 'CCF_ML']
        )
        
        ## Add parent region column ##
        if len(np.unique(all_units_info_df['region'].values.astype(str))) > 1:
            print('  Adding parent region...')
#             sub_CCF_res = subject_df[subject_df['mouse'] == mouse]['CCF_res'].iloc[0]
            sub_CCF_res = 25
            mcc = MouseConnectivityCache(resolution=sub_CCF_res)
            str_tree = mcc.get_structure_tree()
            annot, annot_info = mcc.get_annotation_volume()
            all_units_info_df = add_parent_region_to_df(all_units_info_df, str_tree, annot)
        
        all_units_info_df.to_csv(unit_info_file, index=False)
        pickle.dump(unit_allspiketimes, open(unit_allspiketimes_file, 'wb'))
        pickle.dump(unit_eventspikestimes, open(unit_eventspikes_file, 'wb'))

        end = time.time()
        print('  Time to get unit spike times and save: {:.2f} min\n'.format((end-start)/60))
        ## After each subject, delete common variables ##
        del stim_log, probe_data, all_event_times, all_units_info, all_units_info_df, unit_allspiketimes, unit_eventspikestimes

655955: saline
Experiment type: electrical stimulation
  Getting probe info...
  Getting spike times...
  Adding parent region...
  Time to get unit spike times and save: 0.99 min

