In [1]:
from datetime import date
from glob import glob
import json
import math
import os
import sys
import time
import pickle

import gspread
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal

In [2]:
## Make the local checkout containing the `tbd_eeg` package importable.
## Guarded so re-running this cell does not keep appending duplicate sys.path entries. ##
code_root = r'C:\Users\lesliec\code'
if code_root not in sys.path:
    sys.path.append(code_root)

In [3]:
from tbd_eeg.tbd_eeg.data_analysis.eegutils import EEGexp
from tbd_eeg.tbd_eeg.data_analysis.Utilities.utilities import (
    get_stim_events, get_evoked_traces, find_nearest_ind, qualitycheck_trials)
from allensdk.core.mouse_connectivity_cache import MouseConnectivityCache

In [4]:
# Interactive matplotlib figures via the classic-Notebook ("nbagg") backend.
# NOTE(review): this backend is not available in JupyterLab / Notebook >= 7; there,
# `%matplotlib widget` (requires ipympl) is the usual replacement. None of the cells
# shown here draw figures.
%matplotlib notebook

# Load CCF, 25 um resolution

In [5]:
## Allen CCF: structure ontology (str_tree) + annotation volume (annot), fetched via allensdk ##
# Presumably must match the CCFres column of the subjects table (25 in the rows shown),
# since unit CCF coordinates are used as voxel indices into `annot`.
ccf_resolution_um = 25
mcc = MouseConnectivityCache(resolution=ccf_resolution_um)
str_tree = mcc.get_structure_tree()
annot, annot_info = mcc.get_annotation_volume()

2025-05-19 10:37:29,967 allensdk.api.api.retrieve_file_over_http INFO     Downloading URL: http://download.alleninstitute.org/informatics-archive/current-release/mouse_ccf/annotation/ccf_2017/annotation_25.nrrd


# Functions

In [6]:
## Developed in NPX_find_bursts_testing.ipynb, this version is faster and only returns start times and spike counts ##
## For thalamus ##
def find_bursts_THunit(spike_times, pre_silence=0.1, first_isi=0.005, within_isi=0.004):
    """Detect thalamic (LTS-style) bursts in one unit's spike train.

    A burst starts at a spike that is preceded by more than `pre_silence` of silence
    and followed by an ISI shorter than `first_isi`. The burst then extends over every
    following spike whose preceding ISI is shorter than `within_isi`.

    Parameters
    ----------
    spike_times : array-like
        1-D, sorted spike times, in the same units as the thresholds (the defaults
        presumably mean seconds).
    pre_silence, first_isi, within_isi : float
        ISI criteria. The defaults are the values this function has always used.

    Returns
    -------
    burst_starts : np.ndarray
        Time of the first spike of each burst.
    burst_counts : np.ndarray
        Number of spikes in each burst (same length as `burst_starts`).

    Notes
    -----
    NOTE(review): the first intra-burst ISI is tested against `first_isi` when detecting
    a start, but against `within_isi` when counting. A burst whose first ISI lies in
    [within_isi, first_isi) is therefore reported with a count of 1. Kept as-is;
    confirm the intended criterion.
    """
    spike_times = np.asarray(spike_times)
    isis = np.diff(spike_times)  # isis[k] = spike_times[k+1] - spike_times[k]
    pre_isis, post_isis = isis[:-1], isis[1:]  # pre_isis[j] / post_isis[j] surround spike j+1

    ## Find starts ##
    start_inds = np.nonzero((pre_isis > pre_silence) & (post_isis < first_isi))[0] + 1
    if len(start_inds) == 0:
        return np.array([]), np.array([])

    ## Count the spikes belonging to each burst ##
    burst_counts = []
    for start in start_inds:
        count = 1
        k = start  # isis[start] is the ISI right after the start spike
        # Bounded by len(isis), not len(pre_isis), so that a burst running up to the
        # last spike of the train is counted in full.
        while k < len(isis) and isis[k] < within_isi:
            k += 1
            count += 1
        burst_counts.append(count)

    return spike_times[start_inds], np.array(burst_counts)

In [7]:
## Finds bursts in non-thalamic units ##
def find_bursts_otherunit(spike_times, isi_threshold=0.005, min_spikes=3):
    """Detect bursts in a non-thalamic unit's spike train.

    Consecutive spikes are chained into one event while each spike's preceding ISI is
    strictly shorter than `isi_threshold`. An event with at least `min_spikes` spikes
    is reported as a burst.

    Parameters
    ----------
    spike_times : array-like
        1-D, sorted spike times (presumably seconds, matching the default threshold).
    isi_threshold : float
        Exclusive upper bound on the ISI inside a burst (default 0.005).
        NOTE(review): the comparison is strict `<`, as it has always been in code.
    min_spikes : int
        Minimum number of spikes for an event to count as a burst (default 3).

    Returns
    -------
    burst_starts : np.ndarray
        Time of the first spike of each burst.
    burst_counts : np.ndarray
        Number of spikes in each burst.
    """
    spike_times = np.asarray(spike_times)
    if spike_times.size == 0:
        return np.array([]), np.array([])

    # A new event begins at the first spike and at every spike whose preceding ISI is
    # >= threshold, i.e. wherever "keep chaining while ISI < threshold" stops.
    event_start = np.ones(spike_times.size, dtype=bool)
    event_start[1:] = np.diff(spike_times) >= isi_threshold
    start_inds = np.flatnonzero(event_start)
    # Event size = distance to the next event start (or to the end of the train).
    event_sizes = np.diff(np.append(start_inds, spike_times.size))

    is_burst = event_sizes >= min_spikes
    return spike_times[start_inds[is_burst]], event_sizes[is_burst]

In [8]:
def find_closest_region(sunit_info, struct_tree, annot):
    """Return the acronym of the nearest grey-matter structure along the DV axis.

    Starting at the unit's voxel (CCF_AP, CCF_DV, CCF_ML), step one voxel at a time
    ventrally (increasing DV index) and dorsally (decreasing DV index) until a voxel
    whose structure descends from structure 8 (grey matter) is found. The closer hit
    wins; ties go to the ventral hit.

    A direction that leaves the brain (annotation value 0) or the annotation volume
    before reaching grey matter counts as "no hit" for that direction.

    Parameters
    ----------
    sunit_info : row-like
        Anything with integer attributes CCF_AP, CCF_DV, CCF_ML (e.g. a DataFrame row).
    struct_tree : allensdk StructureTree
    annot : np.ndarray
        Annotation volume indexed [AP, DV, ML]; 0 marks voxels outside the brain.

    Returns
    -------
    str
        Acronym of the closest grey-matter structure.

    Raises
    ------
    ValueError
        If neither direction reaches grey matter.
    """
    ap, dv0, ml = int(sunit_info.CCF_AP), int(sunit_info.CCF_DV), int(sunit_info.CCF_ML)

    def _search(step):
        # (distance in voxels, structure id) of the first grey voxel in this direction, or None.
        dv = dv0
        while 0 <= dv < annot.shape[1]:
            voxel_id = annot[ap, dv, ml]
            if voxel_id == 0:  # left the brain
                return None
            struct_id = struct_tree.get_structures_by_id([voxel_id])[0]['structure_id_path'][-1]
            if struct_tree.structure_descends_from(struct_id, 8):
                return abs(dv - dv0), struct_id
            dv += step
        return None

    ventral, dorsal = _search(+1), _search(-1)
    if ventral is None and dorsal is None:
        raise ValueError(
            'No grey-matter structure found above or below voxel ({}, {}, {}).'.format(ap, dv0, ml))
    if dorsal is None or (ventral is not None and ventral[0] <= dorsal[0]):
        best_id = ventral[1]
    else:
        best_id = dorsal[1]
    return struct_tree.get_structures_by_id([best_id])[0]['acronym']

In [9]:
def get_region_from_children(test_id, parent_id, struct_tree):
    """Return the acronym of the direct child of `parent_id` that contains `test_id`.

    Children are checked in `struct_tree.child_ids` order and the first one that
    `test_id` descends from wins. If no child matches (e.g. `test_id` is `parent_id`
    itself, or the parent has no children), the parent's own acronym is returned.

    Parameters
    ----------
    test_id : int
        Structure id being classified.
    parent_id : int
        Structure id whose direct children are tested.
    struct_tree : allensdk StructureTree

    Returns
    -------
    str
        Acronym of the matching child, or of `parent_id` when none matches.
    """
    for child_id in struct_tree.child_ids([parent_id])[0]:
        if struct_tree.structure_descends_from(test_id, child_id):
            return struct_tree.get_structures_by_id([child_id])[0]['acronym']
    # No matching child: fall back to the parent. Unexpected lookup errors propagate
    # instead of being hidden by a catch-all.
    return struct_tree.get_structures_by_id([parent_id])[0]['acronym']

In [10]:
def get_parent_region(region_acronym, struct_tree):
    """Map a CCF region acronym to the coarse parent-region label used to group units.

    Structure ids below are Allen CCF ontology ids.
      * Under 567 (cerebrum):
        - 315 (Isocortex): the isocortical child area.
        - 698: 'OLF'.
        - 1089 (HPF): the child subregion.
        - 703 (CTXsp): the child subregion.
        - 477: 'STR'.
        - 803: 'PAL'.
        - Anything else: 'unassigned'.
      * Under 343 (brain stem):
        - 1097: 'HY'.
        - Rest of 1129: 'TH'.
        - Otherwise: the child of 343 that contains it.
      * Everything else: 'unassigned'.

    1129 is the Interbrain, which contains both the thalamus and the hypothalamus
    (1097). HY must therefore be tested before the 1129 -> 'TH' catch-all; otherwise
    every hypothalamic region would be labelled 'TH'.

    Parameters
    ----------
    region_acronym : str
        A valid CCF acronym (raises KeyError from the structure tree otherwise).
    struct_tree : allensdk StructureTree

    Returns
    -------
    str
        Parent-region label.
    """
    reg_id = struct_tree.get_structures_by_acronym([region_acronym])[0]['id']

    def _within(ancestor_id):
        return struct_tree.structure_descends_from(reg_id, ancestor_id)

    if _within(567):  # cerebrum
        if _within(315):  # Isocortex -> report the cortical area
            return get_region_from_children(reg_id, 315, struct_tree)
        if _within(698):
            return 'OLF'
        if _within(1089):  # HPF -> report the subregion
            return get_region_from_children(reg_id, 1089, struct_tree)
        if _within(703):  # CTXsp -> report the subregion
            return get_region_from_children(reg_id, 703, struct_tree)
        if _within(477):
            return 'STR'
        if _within(803):
            return 'PAL'
        return 'unassigned'

    if _within(343):  # brain stem
        if _within(1097):  # hypothalamus: must precede the interbrain check below
            return 'HY'
        if _within(1129):  # rest of the interbrain
            # TODO (optional finer split): 'RT-TH' for RT; 'SM-TH' for
            # AV, CL, MD, PO, PF, VAL, VPL, VPM, VM; 'other-TH' otherwise.
            return 'TH'
        return get_region_from_children(reg_id, 343, struct_tree)

    return 'unassigned'

In [11]:
def _ccf_acronym_from_annotation(unit_row, struct_tree, annot):
    """Look up a CCF acronym in the annotation volume for a unit whose 'region' is not a CCF acronym."""
    ap, dv, ml = unit_row.CCF_AP, unit_row.CCF_DV, unit_row.CCF_ML
    if unit_row.depth <= 0:
        # Unit placed above the brain: use the first annotated voxel (smallest DV index)
        # in this AP/ML column.
        surface_dv = np.nonzero(annot[ap, :, ml])[0][0]
        return struct_tree.get_structures_by_id([annot[ap, surface_dv, ml]])[0]['acronym']
    # Otherwise step toward smaller ML indices until the voxel is annotated (non-zero).
    while annot[ap, dv, ml] == 0:
        ml -= 1
    struct_id = struct_tree.get_structures_by_id([annot[ap, dv, ml]])[0]['structure_id_path'][-1]
    return struct_tree.get_structures_by_id([struct_id])[0]['acronym']


def add_parent_region_to_df(unit_info_df, struct_tree, annot):
    """Return a copy of `unit_info_df` with a 'parent_region' column appended.

    The units are processed in three passes.
      1. Region names that are not CCF acronyms (e.g. NaN) are replaced by a lookup in
         the annotation volume at the unit's CCF coordinates.
      2. Regions that are not grey matter (not under structure 8) are replaced by the
         nearest grey-matter region along DV (find_closest_region).
      3. Each adjusted region is mapped to a coarse label (get_parent_region).

    Parameters
    ----------
    unit_info_df : pd.DataFrame
        Needs columns 'region', 'depth', 'CCF_AP', 'CCF_DV', 'CCF_ML'. The CCF columns
        are integer voxel indices into `annot`. Any index is accepted; it is not
        modified.
    struct_tree : allensdk StructureTree
    annot : np.ndarray
        Annotation volume indexed [AP, DV, ML]; 0 marks voxels outside the brain.

    Returns
    -------
    pd.DataFrame
        Copy of the input with 'parent_region' added as the last column.
    """
    units = unit_info_df.copy()
    # Results are collected positionally in lists, so index labels and the dtype of the
    # original 'region' column (e.g. all-NaN float) do not matter.
    unit_rows = [row for _, row in units.iterrows()]

    ## 1) Make every region name a CCF acronym (unknown names / NaN raise KeyError) ##
    adj_regions = []
    for row in unit_rows:
        try:
            struct_tree.get_structures_by_acronym([row.region])
        except KeyError:
            adj_regions.append(_ccf_acronym_from_annotation(row, struct_tree, annot))
        else:
            adj_regions.append(row.region)

    ## 2) Re-assign any non-grey-matter region to the closest grey-matter region ##
    for pos, row in enumerate(unit_rows):
        reg_id = struct_tree.get_structures_by_acronym([adj_regions[pos]])[0]['id']
        if not struct_tree.structure_descends_from(reg_id, 8):
            adj_regions[pos] = find_closest_region(row, struct_tree, annot)

    ## 3) Assign a parent region to each adjusted CCF region ##
    units['parent_region'] = [get_parent_region(acronym, struct_tree) for acronym in adj_regions]
    return units

# Load subjects.csv file

In [12]:
## Subjects table: one row per recording session. Mouse IDs are kept as strings
## rather than being parsed as numbers. ##
multisub_file = (
    r"C:\Users\lesliec\OneDrive - Allen Institute\analysis"
    r"\GAT1-KO_analyses\GAT1_control_NPephys_subjects.csv"
)
all_sessions_df = pd.read_csv(multisub_file, converters=dict(mouse=str))

In [13]:
# Preview the first rows of the subjects table
all_sessions_df.head(n=5)

Unnamed: 0,genotype,mouse,experiment,sweep_states,bad_chs,CCFres,NPX_analysis,EEG_analysis,data_loc,notes
0,GAT1-KO,645606,EEGNPXspont_estim_2022-12-20_12-26-39,awake,none,25,True,True,E:\GAT1_EEG_pilot\mouse645606\EEGNPXspont_esti...,
1,GAT1-KO,644565,EEGNPXspont_estim_2022-12-22_10-36-08,awake,none,25,True,True,E:\GAT1_EEG_pilot\mouse644565\EEGNPXspont_esti...,
2,GAT1-KO,672785,EEGNPXspont_estim_2023-07-05_12-39-59,awake,all,25,True,False,E:\GAT1_EEG_pilot\mouse672785\EEGNPXspont_esti...,EEG has low amplitude signals and all chs look...
3,GAT1-KO,672789,EEGNPXspont_estim_2023-07-13_13-28-01,awake,none,25,True,True,E:\GAT1_EEG_pilot\mouse672789\EEGNPXspont_esti...,
4,wildtype,654181,estim_vis_2022-11-22_09-42-58,"awake,isoflurane",781113,25,True,True,F:\psi_exp\mouse654181\estim_vis_2022-11-22_09...,good control mouse with RT units


In [14]:
## The unit-processing loop branches on `if not exprow.NPX_analysis`. If pandas parsed
## this column as strings, 'False' would be truthy and every session would be analysed,
## so require a real boolean column (not just an eyeball check of row 0). ##
assert all_sessions_df['NPX_analysis'].dtype == bool, (
    'NPX_analysis parsed as {}, expected bool'.format(all_sessions_df['NPX_analysis'].dtype))
type(all_sessions_df.iloc[0].NPX_analysis)

numpy.bool_

### Check trial quality for EEG signals

## Set parameters

In [16]:
## Analysis parameters used by the processing cells below ##
# Regenerate output files even when they already exist on disk
overwrite_existing_files = True

# Peri-event window [start, end] relative to each stimulus onset
# (same time units as the stim log 'onset' column, presumably seconds)
event_window = [-2.0, 2.0]

# EEG preprocessing switches: mask the electrical-stim artifact, high-pass (0.1 Hz), low-pass (100 Hz)
apply_mask, apply_hpass, apply_lpass = True, True, True

# Units are kept only if they have strictly more spikes than this
spike_count_threshold = 20

## Process running signal and EEG

In [18]:
## Per session, saved under <data_folder>/evoked_data/:
##   - event-aligned running speed, with per-trial mean speed written back to the stim log;
##   - event-aligned EEG traces. ##
for rowi, exprow in all_sessions_df.iterrows():
    print('{}: {}'.format(exprow.mouse, exprow.experiment))
    exp = EEGexp(exprow.data_loc, preprocess=False, make_stim_csv=False)

    ## Set file names ##
    running_file = os.path.join(exp.data_folder, 'running_signal.npy')
    raw_running_file = os.path.join(exp.data_folder, 'raw_running_signal.npy')
    running_ts_file = os.path.join(exp.data_folder, 'running_timestamps_master_clock.npy')
    evoked_folder = os.path.join(exp.data_folder, 'evoked_data')
    os.makedirs(evoked_folder, exist_ok=True)
    event_running_file = os.path.join(evoked_folder, 'event_running_speed.npy')
    event_running_ts_file = os.path.join(evoked_folder, 'event_running_times.npy')
    event_EEGtraces_file = os.path.join(evoked_folder, 'event_EEGtraces.npy')
    event_EEGtraces_ts_file = os.path.join(evoked_folder, 'event_EEGtraces_times.npy')

    ## Load stim log ##
    stim_log = pd.read_csv(exp.stimulus_log_file)
    all_event_times = stim_log['onset'].values

    ## Load running signal (extracted from sync once, then cached as .npy) ##
    if os.path.exists(running_file):
        run_signal = np.load(running_file)
        run_timestamps = np.load(running_ts_file)
    else:
        print('  Loading running from sync and saving...')
        run_signal, raw_speed, run_timestamps = exp.load_running()
        np.save(running_file, run_signal, allow_pickle=False)
        np.save(raw_running_file, raw_speed, allow_pickle=False)
        np.save(running_ts_file, run_timestamps, allow_pickle=False)

    ## Event-aligned running speed ##
    if not os.path.exists(event_running_file) or overwrite_existing_files:
        print('  Getting event-related running...')
        # Sample offsets around each event. The *100 / /100 assume a 100 Hz running signal (TODO confirm).
        rinds = np.arange(-int(-event_window[0] * 100), int(event_window[1] * 100))
        event_inds = np.array([find_nearest_ind(run_timestamps, x) for x in all_event_times], dtype=int)
        # (n_offsets, n_events) indices into run_signal. Samples that fall outside the
        # recording become NaN instead of wrapping around (negative indices) or raising IndexError.
        sample_inds = rinds[:, np.newaxis] + event_inds[np.newaxis, :]
        in_range = (sample_inds >= 0) & (sample_inds < len(run_signal))
        event_run_speed = np.where(
            in_range, run_signal[np.clip(sample_inds, 0, len(run_signal) - 1)], np.nan)
        event_run_times = rinds / 100
        ## Save ##
        np.save(event_running_file, event_run_speed, allow_pickle=False)
        np.save(event_running_ts_file, event_run_times, allow_pickle=False)
        ## Add mean speed over [-0.5, 0.5) around each event to stim_log ##
        evinds = np.nonzero((event_run_times >= -0.5) & (event_run_times < 0.5))[0]
        mean_speed = np.mean(event_run_speed[evinds, :], axis=0)  # NaN if that window leaves the recording
        stim_log['mean_speed'] = mean_speed
        stim_log['resting_trial'] = stim_log['mean_speed'] == 0
        # NOTE: rewrites the session's stimulus log CSV in place with the two columns above
        stim_log.to_csv(exp.stimulus_log_file, index=False)

    ## Event-aligned EEG ##
    if any('recording' in x for x in exp.experiment_data):
        # NOTE(review): EEG traces are made for every session; the subjects table's
        # EEG_analysis / bad_chs columns are not consulted here.
        if not os.path.exists(event_EEGtraces_file) or overwrite_existing_files:
            ## Load EEG data and preprocess ##
            print('  Loading EEG data...')
            datai, tsi = exp.load_eegdata()
            eeg_chs = np.arange(0, datai.shape[1])
            eeg_fs = exp.ephys_params['EEG']['sample_rate']

            ## Mask estim artifact ##
            # For each biphasic pulse, rows val .. val+mask_samples-1 (val = 2 samples before the
            # nearest sample) are replaced by rows val, val-1, ..., val-mask_samples+1 (a mirror image).
            if apply_mask:
                mask_samples = int(0.002 * eeg_fs)
                for etime in stim_log.loc[stim_log['stim_type'] == 'biphasic', 'onset'].to_numpy():
                    val = find_nearest_ind(tsi, etime) - 2
                    if val < mask_samples - 1 or val + mask_samples > len(datai):
                        print('   Skipping artifact mask at t={:.3f}: too close to the recording edge.'.format(etime))
                        continue
                    datai[val:val + mask_samples, :] = datai[val - mask_samples + 1:val + 1, :][::-1]

            ## Apply high-pass filter (0.1 Hz) ##
            if apply_hpass:
                hpb, hpa = signal.butter(3, 0.1 / (eeg_fs / 2), btype='highpass')
                datai = signal.filtfilt(hpb, hpa, datai, axis=0)

            ## Get evoked traces ##
            print('  Getting EEG traces...')
            event_traces, event_ts = get_evoked_traces(
                datai, tsi, all_event_times, -event_window[0], event_window[1], eeg_fs)

            ## Apply lowpass filter (100 Hz) ##
            if apply_lpass:
                lpb, lpa = signal.butter(3, 100 / (eeg_fs / 2), btype='low')
                event_traces = signal.filtfilt(lpb, lpa, event_traces, axis=0)

            ## Save ##
            print('   ...saving {}.'.format(event_EEGtraces_file))
            np.save(event_EEGtraces_file, event_traces, allow_pickle=False)
            np.save(event_EEGtraces_ts_file, event_ts, allow_pickle=False)
        else:
            print('  Not creating EEG traces file, it already exists (overwrite_existing_files=False).')
    else:
        print('  No EEG in this recording.')

    print('')

645606: EEGNPXspont_estim_2022-12-20_12-26-39
Experiment type: electrical stimulation
  Getting event-related running...
  Loading EEG data...
  Getting EEG traces...
   ...saving E:\GAT1_EEG_pilot\mouse645606\EEGNPXspont_estim_2022-12-20_12-26-39\experiment1\recording1\evoked_data\event_EEGtraces.npy.

644565: EEGNPXspont_estim_2022-12-22_10-36-08
Experiment type: electrical stimulation
  Getting event-related running...
  Loading EEG data...
  Getting EEG traces...
   ...saving E:\GAT1_EEG_pilot\mouse644565\EEGNPXspont_estim_2022-12-22_10-36-08\experiment1\recording1\evoked_data\event_EEGtraces.npy.

672785: EEGNPXspont_estim_2023-07-05_12-39-59
Experiment type: electrical and sensory stimulation
  Getting event-related running...
  Loading EEG data...
  Getting EEG traces...
   ...saving E:\GAT1_EEG_pilot\mouse672785\EEGNPXspont_estim_2023-07-05_12-39-59\experiment1\recording1\evoked_data\event_EEGtraces.npy.

672789: EEGNPXspont_estim_2023-07-13_13-28-01
Experiment type: electrical

## Process units

In [22]:
## For each session with NPX_analysis=True:
##   - collect "good" units from every probe and assign CCF parent regions;
##   - detect bursts and build event-aligned spike/burst rasters;
##   - save all_units_info.csv, units_allspikes.pkl and units_event_spikes.pkl to evoked_data/. ##
for rowi, exprow in all_sessions_df.iterrows():
    print('{}: {}'.format(exprow.mouse, exprow.experiment))
    if not exprow.NPX_analysis:
        print('   Not analyzing probes on this session (NPX_analysis=False).\n')
        continue
    exp = EEGexp(exprow.data_loc, preprocess=False, make_stim_csv=False)

    probe_list = [x.replace('_sorted', '') for x in exp.experiment_data if 'probe' in x]
    if len(probe_list) == 0:
        print('  This experiment has no probe data, not making spike times files.\n')
        continue

    ## Set file names ##
    evoked_folder = os.path.join(exp.data_folder, 'evoked_data')
    os.makedirs(evoked_folder, exist_ok=True)
    unit_info_file = os.path.join(evoked_folder, 'all_units_info.csv')
    unit_allspiketimes_file = os.path.join(evoked_folder, 'units_allspikes.pkl')
    unit_eventspikes_file = os.path.join(evoked_folder, 'units_event_spikes.pkl')
    if not overwrite_existing_files and os.path.exists(unit_info_file):
        print('  {} already exists, skipping analysis.\n'.format(unit_info_file))
        continue

    ## Load stim log ##
    stim_log = pd.read_csv(exp.stimulus_log_file)
    all_event_times = stim_log['onset'].values

    ## Get probe info ##
    print(' Getting probe info...')
    start = time.time()
    all_units_info = []
    unit_allspiketimes = {}
    unit_eventspikestimes = {'event_window': event_window, 'event_spikes': {}, 'event_bursts': {}}
    for probe_name in probe_list:
        print('  {}'.format(probe_name))
        ## Load probe_info.json: channel depths measured from the surface channel ##
        with open(exp.ephys_params[probe_name]['probe_info']) as data_file:
            data = json.load(data_file)
        npx_allch = np.array(data['channel'])
        surface_ch = int(data['surface_channel'])
        allch_z = np.array(data['vertical_pos'])
        probe_ch_depths = allch_z[surface_ch] - allch_z

        ## Load the unit info ##
        cluster_group = pd.read_csv(exp.ephys_params[probe_name]['cluster_group'], sep='\t')
        cluster_metrics = pd.read_csv(exp.ephys_params[probe_name]['cluster_metrics'])
        spike_clusters = np.load(exp.ephys_params[probe_name]['spike_clusters'])
        spike_times = np.load(exp.ephys_params[probe_name]['spike_times'])

        if not np.array_equal(cluster_group['cluster_id'].values.astype('int'), np.unique(spike_clusters)):
            print('   IDs from cluster_group.tsv DO NOT match spike_clusters.npy. This may mean there are unsorted units, check in phy.')
            continue
        if np.array_equal(cluster_group['cluster_id'].values.astype('int'), cluster_metrics['cluster_id'].values.astype('int')):
            unit_metrics = pd.merge(cluster_group.rename(columns={'group':'label'}), cluster_metrics, on='cluster_id')
        else:
            print('   IDs from cluster_group DO NOT match cluster_metrics.')
            continue

        ## Select only "good" units with enough spikes ##
        # Count spikes for all clusters in one pass instead of scanning spike_clusters once per cluster.
        clu_ids, clu_counts = np.unique(spike_clusters, return_counts=True)
        counts_by_id = dict(zip(clu_ids.tolist(), clu_counts.tolist()))
        unit_metrics['spike_count'] = [counts_by_id.get(int(x), 0) for x in unit_metrics.cluster_id.values]
        good_units = unit_metrics[(unit_metrics['label'] == 'good') & (unit_metrics['spike_count'] > spike_count_threshold)]
        # "[ap, dv, ml]" strings -> (n_units, 3) int array. reshape keeps it 2-D when no unit is good.
        tempcoords = np.array(
            [[int(y) for y in x.replace('[', '').replace(']', '').replace(' ', '').split(',')]
             for x in good_units.ccf_coord.values],
            dtype=int).reshape(-1, 3)

        ## Now reorganize metrics to save ##
        probe_units = pd.DataFrame([probe_name[-1] + str(x) for x in good_units.cluster_id.values], columns=['unit_id'])
        probe_units['probe'] = [probe_name] * len(good_units)
        probe_units['peak_ch'] = good_units['peak_channel'].values
        probe_units['depth'] = [probe_ch_depths[pkch] for pkch in good_units.peak_channel.values]
        probe_units['spike_duration'] = good_units['duration'].values
        probe_units['region'] = good_units['area'].values
        probe_units['CCF_AP'], probe_units['CCF_DV'], probe_units['CCF_ML'] = tempcoords[:,0], tempcoords[:,1], tempcoords[:,2]
        probe_units['firing_rate'] = good_units['firing_rate'].values
        probe_units['presence_ratio'] = good_units['presence_ratio'].values
        probe_units['isi_viol'] = good_units['isi_viol'].values
        probe_units['amplitude_cutoff'] = good_units['amplitude_cutoff'].values
        probe_units['spike_count'] = good_units['spike_count'].values

        ## Add parent region column ##
        probe_units = add_parent_region_to_df(probe_units, str_tree, annot)
        all_units_info.append(probe_units)

        ## Bursts and event-aligned rasters per unit (rows of probe_units follow good_units order) ##
        for uniti, (uind, unitrow) in zip(good_units.cluster_id.values, probe_units.iterrows()):
            spikesi = np.squeeze(spike_times[spike_clusters == uniti])
            if unitrow.parent_region == 'TH':
                burstsi, burst_counts = find_bursts_THunit(spikesi)
            else:
                burstsi, burst_counts = find_bursts_otherunit(spikesi)
            unit_allspiketimes[unitrow.unit_id] = {
                'spikes': spikesi,
                'bursts': burstsi,
                'burst_counts': burst_counts,
            }

            event_raster = []
            burst_raster = []
            burst_count_raster = []
            for eventi in all_event_times:
                spikeinds = np.nonzero((spikesi >= eventi + event_window[0]) & (spikesi <= eventi + event_window[1]))[0]
                event_raster.append(spikesi[spikeinds] - eventi)
                burstinds = np.nonzero((burstsi >= eventi + event_window[0]) & (burstsi <= eventi + event_window[1]))[0]
                burst_raster.append(burstsi[burstinds] - eventi)
                burst_count_raster.append(burst_counts[burstinds])
            unit_eventspikestimes['event_spikes'][unitrow.unit_id] = event_raster
            unit_eventspikestimes['event_bursts'][unitrow.unit_id] = {'times': burst_raster, 'counts': burst_count_raster}

    ## Nothing to save if every probe was skipped (pd.concat([]) would raise) ##
    if not all_units_info:
        print('   No usable probe data for this session, nothing saved.\n')
        continue

    ## Now combine all probe units dfs ##
    all_units_info_df = pd.concat(all_units_info, ignore_index=True)
    all_units_info_df.to_csv(unit_info_file, index=False)
    # `with` closes (and flushes) the pickle files deterministically
    with open(unit_allspiketimes_file, 'wb') as pkl_file:
        pickle.dump(unit_allspiketimes, pkl_file)
    with open(unit_eventspikes_file, 'wb') as pkl_file:
        pickle.dump(unit_eventspikestimes, pkl_file)

    end = time.time()
    print('   Time to get unit spike times and save: {:.2f} min\n'.format((end-start)/60))
    ## After each subject, delete the large per-session containers ##
    # probe_units is not deleted: it is unbound when every probe of a session was skipped.
    del all_units_info, all_units_info_df, unit_allspiketimes, stim_log, all_event_times, unit_eventspikestimes

645606: EEGNPXspont_estim_2022-12-20_12-26-39
Experiment type: electrical stimulation
 Getting probe info...
  probeB
  probeC
  probeF
   Time to get unit spike times and save: 1.53 min

644565: EEGNPXspont_estim_2022-12-22_10-36-08
Experiment type: electrical stimulation
 Getting probe info...
  probeB
  probeC
  probeF
   Time to get unit spike times and save: 1.44 min

672785: EEGNPXspont_estim_2023-07-05_12-39-59
Experiment type: electrical and sensory stimulation
 Getting probe info...
  probeB
  probeC
  probeD
  probeF
   Time to get unit spike times and save: 5.80 min

672789: EEGNPXspont_estim_2023-07-13_13-28-01
Experiment type: electrical and sensory stimulation
 Getting probe info...
  probeB
  probeC
  probeD
  probeF
   Time to get unit spike times and save: 4.28 min

654181: estim_vis_2022-11-22_09-42-58
Experiment type: electrical and sensory stimulation
 Getting probe info...
  probeB
  probeC
  probeF
   Time to get unit spike times and save: 2.94 min

669118: pilot_

In [21]:
# Debug: last unit handled by the unit-processing loop (relies on its loop variable `unitrow`)
print(unitrow['unit_id'])

B0
