In [1]:
from datetime import date
from glob import glob
import json
import math
import os
import pickle
import sys
import time

import gspread
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from scipy import integrate, interpolate, stats

In [2]:
# Make the local tbd_eeg checkout importable; TBD_EEG_CODE_DIR overrides the author's default path.
code_dir = os.environ.get('TBD_EEG_CODE_DIR', r'C:\Users\lesliec\code')
if code_dir not in sys.path:  # re-running the cell should not pile up duplicate entries
    sys.path.append(code_dir)

In [3]:
from tbd_eeg.tbd_eeg.data_analysis.eegutils import EEGexp

In [4]:
# Interactive (Javascript-rendered) matplotlib figures for the plotting cells below.
%matplotlib notebook

#### Functions

In [5]:
def get_stim_event_inds(stim_table, stim_type, stim_param, sweep, trials='all'):
    """Return the index labels of good stimulus events matching a stim type, parameter and sweep.

    Parameters
    ----------
    stim_table : pandas.DataFrame
        Stimulus log with 'stim_type', 'parameter', 'sweep' and 'good' columns, plus
        'resting_trial' when ``trials`` is 'resting' or 'running'.
    stim_type, stim_param, sweep
        Values that the 'stim_type', 'parameter' and 'sweep' columns must equal.
    trials : str
        'resting' keeps rows with resting_trial == True, 'running' keeps rows with
        resting_trial == False; any other value applies no resting/running filter.

    Returns
    -------
    numpy.ndarray
        Index values of the matching rows, in table order.
    """
    keep = (
        (stim_table['stim_type'] == stim_type)
        & (stim_table['parameter'] == stim_param)
        & (stim_table['sweep'] == sweep)
        & (stim_table['good'] == True)  # elementwise comparison, not identity
    )
    if trials == 'resting':
        keep = keep & (stim_table['resting_trial'] == True)
    elif trials == 'running':
        keep = keep & (stim_table['resting_trial'] == False)
    return stim_table[keep].index.values

In [6]:
def get_zFR(trig_FR, FRtime):
    """Z-score each column of a firing-rate matrix against its own pre-stimulus baseline.

    Parameters
    ----------
    trig_FR : numpy.ndarray, shape (n_time, n_units)
        Firing rate per time bin (rows) for each unit (columns).
    FRtime : numpy.ndarray, shape (n_time,)
        Time of each row; rows with FRtime < 0 form the baseline.

    Returns
    -------
    numpy.ndarray
        Same shape as ``trig_FR``. Columns whose baseline standard deviation is exactly 0
        are left as NaN instead of being divided by zero.
    """
    baseline = trig_FR[FRtime < 0, :]
    mu = np.mean(baseline, axis=0)
    sigma = np.std(baseline, axis=0)

    # NaN-filled output; zeros * NaN yields the same float dtype numpy would pick for trig_FR
    trig_Z = np.zeros_like(trig_FR) * np.nan
    has_spread = sigma != 0  # NaN != 0 is True, matching np.nonzero semantics
    trig_Z[:, has_spread] = (trig_FR[:, has_spread] - mu[has_spread]) / sigma[has_spread]
    return trig_Z

#### Load Templeton-log_exp to get metadata for experiments

In [7]:
# Pull the experiment log from Google Sheets (gspread needs a service-account key file).
_gc = gspread.service_account()
_sh = _gc.open('Templeton-log_exp')
_df = pd.DataFrame(_sh.sheet1.get())  # raw rows of the first worksheet; row 0 holds the headers
# Promote row 0 to the column header: transpose, index by the old column 0, transpose back.
metadata = _df.transpose().set_index(0).transpose()

In [8]:
# Root folder for saved figures; set PSI_PLOTS_DIR to override the author's default location.
plotsdir = os.environ.get(
    'PSI_PLOTS_DIR', r'C:\Users\lesliec\OneDrive - Allen Institute\data\plots\psilocybin_exp'
)

### Load subjects

In [9]:
# Table of subjects/experiments; set BRAIN_STATES_SUBJECTS_CSV to override the author's default path.
multisub_file = os.environ.get(
    'BRAIN_STATES_SUBJECTS_CSV',
    r"C:\Users\lesliec\OneDrive - Allen Institute\data\brain_states_subjects.csv",
)
# Keep 'mouse' as text: later cells use it as a str.contains pattern against the metadata log.
subject_df = pd.read_csv(multisub_file, converters={'mouse': str}).astype({'analyze': bool})

In [10]:
subject_df.head(10)

Unnamed: 0,exp_type,mouse,experiment,sweep_states,stim_depth,bad_chs,analyze,data_loc,CCF_res,notes
0,psilocybin,657903,pilot_aw_psi_2023-01-13_12-18-22,"awake,psilocybin",deep,none,False,F:\psi_exp\mouse657903\pilot_aw_psi_2023-01-13...,25,"big lesion in CTX, stim closer to ORB anyway, ..."
1,psilocybin,666193,pilot_aw_psi_2023-02-16_10-55-48,psilocybin,deep,all,True,F:\psi_exp\mouse666193\pilot_aw_psi_2023-02-16...,25,no EEG
2,psilocybin,666194,pilot_aw_psi_2023-02-23_10-40-34,psilocybin,deep,none,True,F:\psi_exp\mouse666194\pilot_aw_psi_2023-02-23...,25,
3,psilocybin,666196,pilot_aw_psi_2023-03-16_10-21-29,"awake,psilocybin,psilocybin,psilocybin,psilocy...",deep,all,True,F:\psi_exp\mouse666196\pilot_aw_psi_2023-03-16...,25,EEG bad?
4,psilocybin,669118,pilot_aw_psi_2023-03-24_09-55-33,"awake,psilocybin,psilocybin,psilocybin,psilocy...",deep,none,True,F:\psi_exp\mouse669118\pilot_aw_psi_2023-03-24...,25,"only has probe B, F"
5,psilocybin,669117,pilot_aw_psi_2023-03-30_11-37-07,"awake,psilocybin,psilocybin,psilocybin,psilocy...",deep,691314,True,F:\psi_exp\mouse669117\pilot_aw_psi_2023-03-30...,25,
6,psilocybin,673449,aw_psi_2023-04-19_11-23-26,"awake,psilocybin,psilocybin",deep,0,False,F:\psi_exp\mouse673449\aw_psi_2023-04-19_11-23...,25,running a lot this day
7,psilocybin,673449,aw_psi_d2_2023-04-20_10-05-31,"awake,psilocybin,psilocybin",deep,1314,True,F:\psi_exp\mouse673449\aw_psi_d2_2023-04-20_10...,25,
8,psilocybin,676726,aw_psi_2023-05-03_11-08-22,"awake,psilocybin,psilocybin",deep,29,True,F:\psi_exp\mouse676726\aw_psi_2023-05-03_11-08...,25,
9,psilocybin,676727,aw_psi_2023-05-10_09-49-12,"awake,psilocybin,psilocybin",deep,29,True,F:\psi_exp\mouse676727\aw_psi_2023-05-10_09-49...,25,


## Testing metrics on single subjects with plots

In [12]:
# ---- State / trial / unit selection ----
skip_states = ['recovery']            # states excluded from analysis
psilocybin_window = 90 * 60           # keep trials < 90 min after the 2nd injection (seconds)
substates = {'resting': True, 'running': False}  # substate name -> required stim_log 'resting_trial'
trial_threshold = 30                  # states with fewer than this many trials are skipped
trial_max = 125                       # states with more trials than this are randomly downsampled
unit_threshold = 5                    # a region needs at least this many units to be analyzed
time_bin = 0.0025                     # firing-rate bin width (s)

# ---- Analysis windows (s, relative to stim onset) ----
sig_test_window = [0.075, 0.3]        # post window; the pre window is its mirror image before onset
sigalpha = 0.05                       # significance threshold for unit activation
burst_window = [0.075, 0.3]
early_window = [0.002, 0.025]
late_window = [0.075, 0.3]

In [13]:
subrow = subject_df.iloc[9]  # subject analyzed below (row position in subject_df)
print(subrow.mouse)

# Warning only: the following cells still run for this subject.
skip_msg = 'Skipping {} - {} for now, missing data.\n'.format(subrow.mouse, subrow.exp_type)
if not subrow.analyze:
    print(skip_msg)

551399


Add state and stimulus-depth labels to stim_log and choose the stimulus current to analyze

In [14]:
## Load EEGexp and stim_log ##
exp = EEGexp(subrow.data_loc, preprocess=False, make_stim_csv=False)
stim_log = pd.read_csv(exp.stimulus_log_file).astype({'parameter': str})

### Get all states in experiment ###
all_sweeps = np.unique(stim_log['sweep'].values)

## Get state label for each sweep ##
# The subject table gives one comma-separated label per sweep, or one label shared by all sweeps.
sweep_state_list = subrow.sweep_states.split(',')
if len(sweep_state_list) == 1:
    sweep_state_list = sweep_state_list * len(all_sweeps)
# NOTE(review): sweep numbers are used as 0-based positions into the label list - confirm
stim_log['state'] = stim_log['sweep'].map(lambda sweep: sweep_state_list[sweep])

## Get depth label for each sweep ##
sweep_depth_list = subrow.stim_depth.split(',')
if len(sweep_depth_list) == 1:
    sweep_depth_list = sweep_depth_list * len(all_sweeps)
stim_log['stim_depth'] = stim_log['sweep'].map(lambda sweep: sweep_depth_list[sweep])

## Get list of states and the middle current ##
states = np.unique(stim_log['state'].values)
# Convert to int BEFORE np.unique so currents sort numerically; sorting the strings would put
# e.g. '100' ahead of '50' and pick the wrong "middle" current.
currents = np.unique(stim_log.loc[stim_log['stim_type'] == 'biphasic', 'parameter'].astype(int).values)
# middle of the sorted currents (index 1 for 2 or 3 currents, index 0 for a single current)
ch_curr = str(currents[len(currents) // 2])

print(states)

Experiment type: electrical stimulation
['awake' 'isoflurane' 'recovery']


All-units info now includes parent region

In [15]:
## Load unit info ##
fn_units_info = os.path.join(exp.data_folder, 'evoked_data', 'all_units_info.csv')
if os.path.exists(fn_units_info):
    unit_info = pd.read_csv(fn_units_info)
    with open(os.path.join(exp.data_folder, 'evoked_data', 'units_event_spikes.pkl'), 'rb') as unit_file:
        all_unit_event_spikes = pickle.load(unit_file)
else:
    print('{} not found. Not analyzing this subject.'.format(fn_units_stats))
#     continue

## Make time bins for event spikes ##
bins = np.arange(all_unit_event_spikes['event_window'][0], all_unit_event_spikes['event_window'][1] + time_bin, time_bin)
timex = bins[:-1] + time_bin/2

unit_info.head()

Unnamed: 0,unit_id,probe,peak_ch,depth,spike_duration,region,CCF_AP,CCF_DV,CCF_ML,parent_region
0,B0,probeB,0,3460,0.508208,VAL,624,424,420,SM-TH
1,B2,probeB,0,3460,0.467002,VAL,624,424,420,SM-TH
2,B4,probeB,1,3460,0.535678,VAL,624,424,420,SM-TH
3,B5,probeB,1,3460,0.357119,VAL,624,424,420,SM-TH
4,B6,probeB,1,3460,0.425796,VAL,624,424,420,SM-TH


In [16]:
np.unique(unit_info['parent_region'].to_numpy())

array(['ACA', 'ILA', 'MO', 'PL', 'SM-TH', 'SS', 'other-TH'], dtype=object)

In [17]:
# Keep parent regions with at least unit_threshold units; within a region, order units by depth.
ROI_unit_info = {}
for region, udf in unit_info.groupby('parent_region'):
    if len(udf) < unit_threshold:
        continue
    ROI_unit_info[region] = udf.sort_values(by='depth').reset_index(drop=True)
    print('{}: {:d} units'.format(region, len(udf)))

ACA: 19 units
ILA: 15 units
MO: 99 units
PL: 205 units
SM-TH: 82 units
SS: 120 units
other-TH: 53 units


Get event indices (stim_log rows) for each state

In [18]:
state_event_inds = {}
for statei in states:
    if statei in skip_states:
        print('Skipping {}.'.format(statei))
        continue
    # Trials shared by every state: good, deep, biphasic stimuli at the chosen current in this state
    state_mask = (
        (stim_log['stim_type'] == 'biphasic') &
        (stim_log['parameter'] == ch_curr) &
        (stim_log['stim_depth'] == 'deep') &
        (stim_log['good'] == True) &
        (stim_log['state'] == statei)
    )
    if statei in ['psilocybin', 'saline']:
        # Look up this experiment in the log: plain substring match (regex=False), and na=False
        # so blank spreadsheet cells do not put NaN into the boolean mask.
        exp_meta = metadata[
            metadata['mouse_name'].str.contains(subrow.mouse, regex=False, na=False) &
            metadata['exp_name'].str.contains(subrow.experiment, regex=False, na=False)
        ]
        if len(exp_meta) != 1:
            raise ValueError('Expected 1 metadata row for {} / {}, found {:d}.'.format(
                subrow.mouse, subrow.experiment, len(exp_meta)))
        exp_meta = exp_meta.squeeze()
        stim_log['onset_from_inj2'] = stim_log['onset'] - float(exp_meta['Second injection time'])
        # keep only trials within psilocybin_window of the second injection
        state_mask = state_mask & (stim_log['onset_from_inj2'] < psilocybin_window)
    if statei in ['psilocybin', 'saline', 'awake']:
        # split into resting / running substates
        for substi, trialtest in substates.items():
            state_event_inds[statei + '_' + substi] = stim_log[
                state_mask & (stim_log['resting_trial'] == trialtest)
            ].index.values
    else:
        # any other state (e.g. isoflurane): resting trials only
        state_event_inds[statei] = stim_log[state_mask & (stim_log['resting_trial'] == True)].index.values
print(state_event_inds.keys())

Skipping recovery.
dict_keys(['awake_resting', 'awake_running', 'isoflurane'])


In [19]:
## Get metrics for each state/region ##
# Seeded generator so the random trial downsampling is reproducible on re-run.
trial_rng = np.random.default_rng(0)


def _nanmedian_or_nan(values):
    """np.nanmedian that quietly returns NaN (no All-NaN RuntimeWarning) when every value is NaN."""
    values = np.asarray(values, dtype=float)
    return np.nanmedian(values) if np.any(~np.isnan(values)) else np.nan


all_data = {
    'unit_metrics': {}, 'unit_zscores': {}, 'pop_fr': {}, 'trial_counts': {}, 'spike_latencies': {}
}
for statei, event_inds in state_event_inds.items():
    if len(event_inds) < trial_threshold:
        print(' Only {:d} trials for {} state, not analyzing.'.format(len(event_inds), statei))
        continue
    elif len(event_inds) > trial_max:
        print(' {} has {:d} trials, downsampling trials to {:d}.'.format(statei, len(event_inds), trial_max))
        # sorted so the kept trials stay in stim_log order
        event_inds = np.sort(trial_rng.choice(event_inds, size=trial_max, replace=False))
    n_trials = len(event_inds)
    all_data['trial_counts'][statei] = n_trials
    for metric_key in ['unit_metrics', 'unit_zscores', 'pop_fr', 'spike_latencies']:
        all_data[metric_key][statei] = {}

    for regi, regdf in ROI_unit_info.items():
        n_units = len(regdf)
        unit_firing_rates = np.full((len(timex), n_units), np.nan)
        # first spike time per (trial, unit) inside early_window / late_window; NaN = no spike
        early_spike_times = np.full((n_trials, n_units), np.nan)
        late_spike_times = np.full((n_trials, n_units), np.nan)

        reg_unit_metrics = []
        # regdf was reset_index'ed in ROI_unit_info, so the enumerate position is the unit's column
        for ii, unit_id in enumerate(regdf['unit_id']):
            unit_event_spikes = [all_unit_event_spikes['event_spikes'][unit_id][ei] for ei in event_inds]
            unit_event_bursts = [all_unit_event_spikes['event_bursts'][unit_id]['times'][ei] for ei in event_inds]

            ## Trial-averaged firing rate (Hz) ##
            unit_event_counts = np.histogram(np.concatenate(unit_event_spikes), bins)[0]
            unit_firing_rates[:, ii] = unit_event_counts / (time_bin * n_trials)

            prespikes = np.zeros(n_trials, dtype=int)
            postspikes = np.zeros(n_trials, dtype=int)
            has_burst = np.zeros(n_trials, dtype=int)
            burst_counts = np.zeros(n_trials, dtype=int)
            for jj, (uspikesi, ubursti) in enumerate(zip(unit_event_spikes, unit_event_bursts)):
                ## Spike counts in mirrored pre-/post-stim windows ##
                prespikes[jj] = np.sum((uspikesi >= -sig_test_window[1]) & (uspikesi <= -sig_test_window[0]))
                postspikes[jj] = np.sum((uspikesi >= sig_test_window[0]) & (uspikesi <= sig_test_window[1]))

                ## Bursts inside burst_window ##
                # Both burst metrics use the same window (burst_count previously had no upper bound).
                n_window_bursts = np.sum((ubursti >= burst_window[0]) & (ubursti <= burst_window[1]))
                has_burst[jj] = n_window_bursts > 0
                burst_counts[jj] = n_window_bursts

                ## First spike in each latency window ##
                early = uspikesi[(uspikesi >= early_window[0]) & (uspikesi <= early_window[1])]
                if early.size > 0:
                    early_spike_times[jj, ii] = early[0]
                late = uspikesi[(uspikesi >= late_window[0]) & (uspikesi <= late_window[1])]
                if late.size > 0:
                    late_spike_times[jj, ii] = late[0]

            # Paired test of post- vs pre-stim counts; 'zsplit' splits zero differences between ranks
            pval = stats.wilcoxon(x=postspikes, y=prespikes, zero_method='zsplit').pvalue
            reg_unit_metrics.append([
                unit_id, pval, np.mean(postspikes) - np.mean(prespikes),
                _nanmedian_or_nan(early_spike_times[:, ii]), _nanmedian_or_nan(late_spike_times[:, ii]),
                np.mean(has_burst),      # fraction of trials with a burst in burst_window
                np.mean(burst_counts),   # mean number of bursts in burst_window per trial
                np.mean(unit_firing_rates[timex < 0, ii])
            ])
        unit_metrics_df = regdf.merge(
            pd.DataFrame(reg_unit_metrics, columns=[
                'unit_id', 'p_value', 'mean_spike_diff', 'early_latency', 'late_latency',
                'burst_prob', 'burst_count', 'baselineFR'
            ]), on='unit_id', how='left')
        all_data['unit_metrics'][statei][regi] = unit_metrics_df
        all_data['unit_zscores'][statei][regi] = [timex, get_zFR(unit_firing_rates, timex)]
        all_data['pop_fr'][statei][regi] = [timex, np.mean(unit_firing_rates, axis=1)]
        all_data['spike_latencies'][statei][regi] = [early_spike_times, late_spike_times]

  overwrite_input=overwrite_input)


 Only 12 trials for awake_running state, not analyzing.


### Plotting firing rate scatter for 2 states

In [20]:
for regi, df in ROI_unit_info.items():
    print(f'{regi}: {len(df):d} units')

ACA: 19 units
ILA: 15 units
MO: 99 units
PL: 205 units
SM-TH: 82 units
SS: 120 units
other-TH: 53 units


In [35]:
# States for the baseline scatter: [x-axis state, y-axis state]
compare_states = ['awake_running', 'saline_running']
# Area lists used for other subjects (mouse ID in the trailing comment):
# plot_areas = ['MO', 'PL', 'ILA', 'ACA', 'RSP', 'SS', 'VIS', 'OLF', 'SM-TH', 'RT-TH', 'other-TH', 'HIP'] # 669117
# plot_areas = ['MO', 'PL', 'ORB', 'FRP', 'SS', 'OLF', 'SM-TH', 'other-TH', 'HIP'] # 669118
plot_areas = ['MO', 'ORB', 'FRP', 'STR', 'SS', 'VIS', 'OLF', 'HIP', 'SM-TH', 'RT-TH', 'other-TH'] # 666196
n_plot_areas = len(plot_areas)
acolors = plt.cm.hsv(np.linspace(0, 1, n_plot_areas))  # one hue per area

Scatter plot of baseline firing rates in "compare_states"

In [36]:
## Baseline firing rate: compare_states[0] (x) vs compare_states[1] (y), one panel per area ##
missing_states = [s for s in compare_states if s not in all_data['unit_metrics']]
if missing_states:
    raise KeyError('States not analyzed for this subject: {} (available: {})'.format(
        missing_states, list(all_data['unit_metrics'].keys())))
# plot_areas may list regions from other subjects; keep those present in both states
scatter_areas = [
    regi for regi in plot_areas
    if all(regi in all_data['unit_metrics'][statei] for statei in compare_states)
]
ncols = 4
nrows = max(1, math.ceil(len(scatter_areas) / ncols))  # grow the grid instead of silently dropping areas
fig, axs = plt.subplots(nrows, ncols, figsize=(12, 3 * nrows), constrained_layout=True, squeeze=False)

for ax, regi in zip(axs.flatten(), scatter_areas):
    x_vals, y_vals = [all_data['unit_metrics'][statei][regi]['baselineFR'].values for statei in compare_states[:2]]
    min_val = min(np.min(x_vals), np.min(y_vals))
    max_val = max(np.max(x_vals), np.max(y_vals))
    ax.plot([min_val, max_val], [min_val, max_val], color='r', linestyle='dashed', alpha=0.3)  # unity line
    ax.scatter(x_vals, y_vals, color='k', alpha=0.6)
    ax.set_title('{} (n={:d})'.format(regi, len(x_vals)))
for ax in axs.flatten()[len(scatter_areas):]:
    ax.axis('off')  # hide unused grid cells

axs[0,0].set_xlabel('{}\n(baseline firing rate, Hz)'.format(compare_states[0]))
axs[0,0].set_ylabel('{}\n(baseline firing rate, Hz)'.format(compare_states[1]))

## Save ##
figdir = os.path.join(plotsdir, 'ind_sub_evoked_units')
os.makedirs(figdir, exist_ok=True)
figname = '{}_{}_baselineFR_allregions.png'.format(subrow.mouse, subrow.exp_type)
fig.savefig(os.path.join(figdir, figname), transparent=False, dpi=300)

<IPython.core.display.Javascript object>

### Summary across all areas/states

In [36]:
Zlim = 5             # color-scale limit (|z|) for the unit heatmaps
plwin = [-0.2, 0.6]  # x-limits (s)
acolors = plt.cm.hsv(np.linspace(0, 1, len(ROI_unit_info)+2))
n_states = len(all_data['unit_metrics'])

fig = plt.figure(figsize=(n_states*2.4, 5))
gs = fig.add_gridspec(ncols=1, nrows=2, height_ratios=[1, 2.5], left=0.08, right=0.97, top=0.9, bottom=0.1, hspace=0.1)
# squeeze=False keeps the [jj] / [ii, jj] indexing valid even when only one state was analyzed
popaxs = gs[0].subgridspec(ncols=n_states, nrows=1, wspace=0.14).subplots(sharex=True, sharey=True, squeeze=False)[0]
Zaxs = gs[1].subgridspec(ncols=n_states, nrows=len(ROI_unit_info), hspace=0.08, wspace=0.14).subplots(sharex=True, squeeze=False)

for jj, (statei, statedict) in enumerate(all_data['unit_zscores'].items()):
    popaxs[jj].axvline(0, color='k', alpha=0.25)
    for ii, (regi, (ztime, zmat)) in enumerate(statedict.items()):
        fr_time, fr_trace = all_data['pop_fr'][statei][regi]
        popaxs[jj].plot(fr_time, fr_trace, color=acolors[ii], linewidth=1.2, alpha=0.8, label=regi)
        # zmat is (n_time, n_units): transpose so each image row is a unit; the y extent
        # therefore spans the number of units (shape[1]), not the number of time bins.
        Zaxs[ii, jj].imshow(
            zmat.T, cmap='bwr', interpolation='none', aspect='auto', origin='upper', vmin=-Zlim, vmax=Zlim,
            extent=[ztime[0], ztime[-1], 0, zmat.shape[1]],
        )
        Zaxs[ii, jj].axvline(0, color='k', alpha=0.25)
        Zaxs[ii, jj].set_ylabel('{} (n={:d})'.format(regi, zmat.shape[1]))
        Zaxs[ii, jj].set_yticks([])

    Zaxs[-1, jj].set_xlim(plwin)
    Zaxs[-1, jj].set_xlabel('Time from stim onset (s)')

    popaxs[jj].set_xlim(plwin)
    popaxs[jj].set_xticklabels([])
    # trial_counts is the number of trials actually analyzed (after downsampling to trial_max)
    popaxs[jj].set_title('{} ({:d} trials)'.format(statei, all_data['trial_counts'][statei]))
popaxs[0].set_ylabel('Pop. FR (Hz)')
popaxs[0].legend();

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x28fdabea208>

In [40]:
# Colors are set here so this cell does not depend on whichever earlier cell last set acolors.
acolors = plt.cm.hsv(np.linspace(0, 1, len(ROI_unit_info)+2))
jitter_rng = np.random.default_rng(0)  # fixed seed -> reproducible scatter jitter
n_states = len(all_data['unit_metrics'])
# x positions: one row per state, one column per region plus an empty slot as a gap between states
locs = np.arange((len(ROI_unit_info) + 1) * n_states).reshape((n_states, len(ROI_unit_info) + 1))
burst_region = 'SM-TH'


def _colored_boxplot(ax, values, position, color):
    """Draw one box (no fliers) at `position` from the non-NaN `values`, with box/median in `color`."""
    values = np.asarray(values, dtype=float)
    ax.boxplot(
        values[~np.isnan(values)], positions=[position], widths=[0.8], showfliers=False,
        medianprops={'color': color, 'linewidth': 2}, boxprops={'color': color},
    )


fig = plt.figure(figsize=(12, 8))
gs = fig.add_gridspec(ncols=2, nrows=1, left=0.15, right=0.95, top=0.9, bottom=0.08, wspace=0.25)
axs = gs[0].subgridspec(ncols=1, nrows=3, hspace=0.15).subplots(sharex=True)
bxs = gs[1].subgridspec(ncols=1, nrows=2, hspace=0.15).subplots(sharex=True)

for ii, (statei, statemets) in enumerate(all_data['unit_metrics'].items()):
    for jj, (regi, reg_metrics) in enumerate(statemets.items()):
        total_units = len(reg_metrics)
        xs = jitter_rng.normal(locs[ii, jj], 0.08, total_units)
        ## Fraction of total units that are significantly activated (excited+inhibited) ##
        # local name so the global `sig_units` DataFrame used in later cells is not overwritten
        n_sig_units = np.sum(reg_metrics['p_value'].values < sigalpha)
        axs[0].bar(locs[ii, jj], n_sig_units / total_units, color=acolors[jj])

        ## Early (first) and late (rebound) spike latencies, in ms ##
        for ax, lat_col in zip(axs[1:], ['early_latency', 'late_latency']):
            lats = reg_metrics[lat_col].values * 1E3
            ax.scatter(xs, lats, color='k', marker='o', alpha=0.4)
            _colored_boxplot(ax, lats, locs[ii, jj], acolors[jj])

        if regi == burst_region:
            xs2 = jitter_rng.normal(ii, 0.08, total_units)
            ## Burst probability and mean burst count ##
            for ax, burst_col in zip(bxs, ['burst_prob', 'burst_count']):
                ax.scatter(xs2, reg_metrics[burst_col].values, color='k', marker='o', alpha=0.4)
                _colored_boxplot(ax, reg_metrics[burst_col].values, ii, acolors[jj])

axs[0].set_ylabel('Fraction of total units\nw/ significant response')
axs[1].set_ylabel('First spike latency\nfrom stim onset (ms)')
axs[2].set_ylabel('Rebound spike latency\nfrom stim onset (ms)')
axs[2].set_xticks(np.mean(locs[:,:-1], axis=1))
axs[2].set_xticklabels(all_data['unit_metrics'].keys())

regleg = []
for jj, (regi, reg_metrics) in enumerate(statemets.items()):
    regleg.append(Patch(facecolor=acolors[jj], label='{} (n={:d})'.format(regi, len(reg_metrics))))
axs[0].legend(handles=regleg)

bxs[0].set_ylabel('Burst probability ({})'.format(burst_region))
bxs[1].set_ylabel('Mean burst count ({})'.format(burst_region))
bxs[1].set_xticks(range(n_states))
bxs[1].set_xticklabels(all_data['unit_metrics'].keys());

<IPython.core.display.Javascript object>

[Text(0, 0, 'awake_resting'), Text(1, 0, 'isoflurane')]

### Examining spike latencies

In [50]:
# Colors/jitter are set here so this cell does not depend on earlier cells' acolors.
acolors = plt.cm.hsv(np.linspace(0, 1, len(ROI_unit_info)+2))
jitter_rng = np.random.default_rng(0)
n_states = len(all_data['spike_latencies'])
locs = np.arange((len(ROI_unit_info) + 1) * n_states).reshape((n_states, len(ROI_unit_info) + 1))
latency_windows = [early_window, late_window]  # s, for [first-spike, rebound-spike] latencies
latency_nbins = [50, 100]                      # histogram bins per window

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(9,3.5), constrained_layout=True, sharex=True)

for ii, (statei, statelats) in enumerate(all_data['spike_latencies'].items()):
    for jj, (regi, reg_lats) in enumerate(statelats.items()):
        for kk, (latmat, wink, bink) in enumerate(zip(reg_lats, latency_windows, latency_nbins)):
            # Population latency per trial = modal bin of the units' first-spike times on that trial
            pop_lat = []
            for tspikes in latmat:  # one row per trial, one column per unit (NaN = no spike)
                # NaNs fall outside `range`, so np.histogram does not count them
                tc, tb = np.histogram(tspikes, bins=bink, range=(wink[0], wink[1]))
                if np.sum(tc) > 0:
                    kmax = np.argmax(tc)
                    # bin center (not left edge), converted s -> ms to match the y-axis labels
                    pop_lat.append(0.5 * (tb[kmax] + tb[kmax + 1]) * 1E3)

            xs = jitter_rng.normal(locs[ii,jj], 0.08, len(pop_lat))
            axs[kk].scatter(xs, pop_lat, color='k', marker='o', alpha=0.4)
            axs[kk].boxplot(
                pop_lat, positions=[locs[ii,jj]], widths=[0.8], showfliers=False,
                medianprops={'color': acolors[jj], 'linewidth': 2}, boxprops={'color': acolors[jj]},
            )

axs[0].set_ylabel('First spike latency\nfrom stim onset (ms)')
axs[1].set_ylabel('Rebound spike latency\nfrom stim onset (ms)')
axs[0].set_xticks(np.mean(locs[:,:-1], axis=1))
axs[0].set_xticklabels(all_data['spike_latencies'].keys())

regleg = []
for jj, (regi, reg_lats) in enumerate(statelats.items()):
    regleg.append(Patch(facecolor=acolors[jj], label='{} (n={:d})'.format(regi, reg_lats[0].shape[1])))
axs[0].legend(handles=regleg);

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x28f8540ec08>

#### Plot one state

In [22]:
statei = 'awake_resting'  # state shown in the single-state cells below
n_pop_regions = len(all_data['pop_fr'][statei])
print(n_pop_regions)

7


In [24]:
pop_fr_awake_resting = all_data['pop_fr']['awake_resting']
pop_fr_awake_resting.keys()

dict_keys(['ACA', 'ILA', 'MO', 'PL', 'SM-TH', 'SS', 'other-TH'])

In [25]:
plot_areas = ['MO', 'PL', 'ILA', 'ACA', 'SS', 'SM-TH', 'other-TH']  # trace order in the plot below
n_plot_areas = len(plot_areas)
print(n_plot_areas)
acolors = plt.cm.hsv(np.linspace(0, 1, n_plot_areas))  # one hue per area

7


In [26]:
plot_window = [-0.2, 0.6]  # x-limits (s)
fig, ax = plt.subplots(figsize=(8,4), constrained_layout=True)

ax.axvline(0, color='k', alpha=0.25)
for ii, regi in enumerate(plot_areas):
    if regi not in all_data['pop_fr'][statei]:
        print('{} not analyzed for {}, skipping.'.format(regi, statei))
        continue
    # local names (not `timex`) so the global bin-center array used by the metrics cell is untouched
    fr_time, fr_trace = all_data['pop_fr'][statei][regi]
    num_units = len(all_data['unit_metrics'][statei][regi])
    ax.plot(fr_time, fr_trace, color=acolors[ii], linewidth=1.2, alpha=0.8, label='{} (n={:d})'.format(regi, num_units))

ax.set_title('{} ({:d} trials)'.format(statei, all_data['trial_counts'][statei]))
ax.set_xlim(plot_window)
ax.set_xlabel('Time from stim onset (s)')
ax.set_ylabel('Population firing rate (Hz)')
ax.legend(loc='upper right');

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x28fcd807d08>

In [36]:
all_data['unit_metrics'][statei]['MO'].iloc[:5]

Unnamed: 0,unit_id,probe,peak_ch,depth,spike_duration,region,CCF_AP,CCF_DV,CCF_ML,parent_region,p_value,mean_spike_diff,early_latency,late_latency,burst_prob,burst_count,baselineFR
0,F546,probeF,339,280,0.79665,MOs2/3,123,73,177,MO,0.518348,-0.032,0.00296,0.237906,0.0,0.0,0.48
1,F528,probeF,318,480,0.61809,MOs5,125,81,180,MO,0.611167,-0.048,0.003366,0.234963,0.0,0.0,0.364
2,F527,probeF,317,500,0.755444,MOs5,125,81,181,MO,0.000144,-0.24,0.017943,0.264899,0.0,0.0,1.608
3,F520,probeF,314,520,0.260972,MOs5,125,82,181,MO,0.000979,-0.384,0.017178,0.220105,0.0,0.0,4.164
4,F506,probeF,300,660,0.700503,MOs5,126,88,184,MO,1.0,0.0,,,0.0,0.0,0.0


In [37]:
## Find significantly activated neurons ##
region = 'MO'
region_metrics = all_data['unit_metrics'][statei][region]
# significant Wilcoxon test AND more spikes post- than pre-stim (sig_test_window)
is_activated = (region_metrics['p_value'] < sigalpha) & (region_metrics['mean_spike_diff'] > 0)
sig_units = region_metrics[is_activated]

print('{}: {:d}/{:d} units are sig. activated in rebound window'.format(
    region, len(sig_units), len(region_metrics)))

MO: 0/48 units are sig. activated in rebound window


In [38]:
sig_units.iloc[:5]

Unnamed: 0,unit_id,probe,peak_ch,depth,spike_duration,region,CCF_AP,CCF_DV,CCF_ML,parent_region,p_value,mean_spike_diff,early_latency,late_latency,burst_prob,burst_count,baselineFR
