In [1]:
from datetime import date
from glob import glob
import json
import math
import os
import pickle
import sys
import time

import gspread
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from scipy import integrate, interpolate, stats
import pingouin as pg

In [2]:
sys.path.append(r'C:\Users\lesliec\code')

In [3]:
from tbd_eeg.tbd_eeg.data_analysis.eegutils import EEGexp

In [4]:
%matplotlib notebook

In [5]:
sigalpha = 0.05

In [6]:
areas_of_interest = {
    'MO': ['MOp1', 'MOp2/3', 'MOp5', 'MOp6a', 'MOp6b', 'MOs1', 'MOs2/3', 'MOs5', 'MOs6a', 'MOs6b'],
    'SM-TH': ['AV', 'CL', 'MD', 'PO', 'PF', 'VAL', 'VPL', 'VPM', 'VM'],
    'RT-TH': ['RT'],
}
areacolors = {'MO': 'blueviolet', 'SM-TH': 'limegreen', 'RT-TH': 'salmon'}

#### Functions

In [7]:
def get_stim_event_inds(stim_table, stim_type, stim_param, sweep, trials='all'):
    if trials == 'resting':
        return stim_table[
            (stim_table['stim_type'] == stim_type) &
            (stim_table['parameter'] == stim_param) &
            (stim_table['sweep'] == sweep) &
            (stim_table['good'] == True) &
            (stim_table['resting_trial'] == True)
        ].index.values
    elif trials == 'running':
        return stim_table[
            (stim_table['stim_type'] == stim_type) &
            (stim_table['parameter'] == stim_param) &
            (stim_table['sweep'] == sweep) &
            (stim_table['good'] == True) &
            (stim_table['resting_trial'] == False)
        ].index.values
    else:
        return stim_table[
            (stim_table['stim_type'] == stim_type) &
            (stim_table['parameter'] == stim_param) &
            (stim_table['sweep'] == sweep) &
            (stim_table['good'] == True)
        ].index.values

In [8]:
def get_zFR(trig_FR, FRtime):
    
    trig_Z = np.zeros_like(trig_FR) * np.nan # try NaNs, it works
    baseline_avg = np.mean(trig_FR[FRtime < 0, :], axis=0)
    baseline_std = np.std(trig_FR[FRtime < 0, :], axis=0)
    nonzero_inds = np.nonzero(baseline_std)[0]
    trig_Z[:, nonzero_inds] = (trig_FR[:, nonzero_inds] - baseline_avg[None, nonzero_inds]) / baseline_std[None, nonzero_inds]
    
    return trig_Z

In [9]:
def p_stars(test_pval):
    if test_pval < 0.001:
        return '***'
    elif test_pval < 0.01:
        return '**'
    elif test_pval < 0.05:
        return '*'
    else:
        return 'n.s.'

#### Load metadata for experiments

In [10]:
_gc = gspread.service_account() # need a key file to access the account
_sh = _gc.open('Templeton-log_exp') # open the spreadsheet
_df = pd.DataFrame(_sh.sheet1.get()) # load the first worksheet
metadata = _df.T.set_index(0).T # put it in a nicely formatted dataframe

In [11]:
plotsdir = r'C:\Users\lesliec\OneDrive - Allen Institute\data\plots\psilocybin_exp'

### Load subjects

In [12]:
multisub_file = r"C:\Users\lesliec\OneDrive - Allen Institute\data\brain_states_subjects.csv"
subject_df = pd.read_csv(multisub_file, converters={'mouse': str}).astype({'analyze': bool})

In [13]:
subject_df.head()

Unnamed: 0,exp_type,mouse,experiment,sweep_states,stim_depth,bad_chs,analyze,data_loc,CCF_res,notes
0,psilocybin,657903,pilot_aw_psi_2023-01-13_12-18-22,"awake,psilocybin",deep,none,False,F:\psi_exp\mouse657903\pilot_aw_psi_2023-01-13...,25,"big lesion in CTX, stim closer to ORB anyway, ..."
1,psilocybin,666193,pilot_aw_psi_2023-02-16_10-55-48,psilocybin,deep,all,True,F:\psi_exp\mouse666193\pilot_aw_psi_2023-02-16...,25,no EEG
2,psilocybin,666194,pilot_aw_psi_2023-02-23_10-40-34,psilocybin,deep,none,True,F:\psi_exp\mouse666194\pilot_aw_psi_2023-02-23...,25,
3,psilocybin,666196,pilot_aw_psi_2023-03-16_10-21-29,"awake,psilocybin,psilocybin,psilocybin,psilocy...",deep,all,True,F:\psi_exp\mouse666196\pilot_aw_psi_2023-03-16...,25,EEG bad?
4,psilocybin,669118,pilot_aw_psi_2023-03-24_09-55-33,"awake,psilocybin,psilocybin,psilocybin,psilocy...",deep,none,True,F:\psi_exp\mouse669118\pilot_aw_psi_2023-03-24...,25,"only has probe B, F"


## Get multi-subject metrics

In [14]:
call_saline_awake = True

skip_states = ['recovery']
psilocybin_window = 30 * 60 # min to include as "psilocybin" -> seconds; I tried 60 min on 5/10
substates = {'resting': True, 'running': False}
trial_threshold = 30 # must have more than this number of trials to be included in analysis
trial_max = 125 # limit some states that have many trials
unit_threshold = 25 # must have at least this number of units to be included in analysis, 5/10 was 5
time_bin = 0.0025 # size of time bins (s) for firing rate

sig_test_window = [0.075, 0.3]
sigalpha = 0.05 # significance threshold for unit activation
burst_window = [0.075, 0.3]
early_window = [0.002, 0.025]
late_window = [0.075, 0.3]

In [15]:
all_subjects_states_info = []
all_subjects_data = {}
for indi, subrow in subject_df.iterrows():
    if not subrow.analyze:
        print('Skipping {} - {}, experiment excluded from analysis.\n'.format(subrow.mouse, subrow.exp_type))
        continue
#     elif subrow.bad_chs == 'all':
#         print('Skipping {} - {}, all EEG chs are bad.\n'.format(subrow.mouse, subrow.exp_type))
#         continue
    print('{}: {}'.format(subrow.mouse, subrow.experiment))
    if subrow.mouse in all_subjects_data.keys():
        all_subjects_data[subrow.mouse][subrow.exp_type] = {}
    else:
        all_subjects_data[subrow.mouse] = {}
        all_subjects_data[subrow.mouse][subrow.exp_type] = {}
    
    ## Load EEGexp and stim_log ##
    exp = EEGexp(subrow.data_loc, preprocess=False, make_stim_csv=False)
    stim_log = pd.read_csv(exp.stimulus_log_file).astype({'parameter': str})

    ### Get all states in experiment ###
    all_sweeps = np.unique(stim_log['sweep'].values)

    ## Get state label for each sweep ##
    sweep_state_list = []
    for char in subrow.sweep_states.split(','):
        sweep_state_list.append(char)
    if len(sweep_state_list) == 1:
        sweep_state_list = sweep_state_list * len(all_sweeps)
    stim_log['state'] = stim_log.apply(lambda x: sweep_state_list[x.sweep], axis=1)
    if call_saline_awake:
        stim_log['state'] = ['awake' if x == 'saline' else x for x in stim_log['state'].values]
    ## Get depth label for each sweep ##
    sweep_depth_list = []
    for char in subrow.stim_depth.split(','):
        sweep_depth_list.append(char)
    if len(sweep_depth_list) == 1:
        sweep_depth_list = sweep_depth_list * len(all_sweeps)
    stim_log['stim_depth'] = stim_log.apply(lambda x: sweep_depth_list[x.sweep], axis=1)
    ## Get list of states and the middle current ##
    states = np.unique(stim_log['state'].values)
    currents = np.unique(stim_log[stim_log['stim_type'] == 'biphasic']['parameter'].values).astype(int)
    if len(currents) > 1:
        ch_curr = str(currents[1])
    else:
        ch_curr = str(currents[0])

    ## Load unit info ##
    fn_units_info = os.path.join(exp.data_folder, 'evoked_data', 'all_units_info.csv')
    if os.path.exists(fn_units_info):
        unit_info = pd.read_csv(fn_units_info)
        with open(os.path.join(exp.data_folder, 'evoked_data', 'units_event_spikes.pkl'), 'rb') as unit_file:
            all_unit_event_spikes = pickle.load(unit_file)
    else:
        print(' {} not found. Not analyzing this subject.\n'.format(fn_units_info))
        continue

    ## Make time bins for event spikes ##
    bins = np.arange(all_unit_event_spikes['event_window'][0], all_unit_event_spikes['event_window'][1] + time_bin, time_bin)
    timex = bins[:-1] + time_bin/2
    
    ## Separate unit info by ROI ##
    ROI_unit_info = {}
    for region in np.unique(unit_info['parent_region'].values):
        udf = unit_info[unit_info['parent_region'] == region]
        if len(udf) >= unit_threshold:
            ROI_unit_info[region] = udf.sort_values(by='depth').reset_index(drop=True)
#             print('{}: {:d} units'.format(region, len(udf)))
            
    ## Get event_inds ##
    state_event_inds = {}
    for statei in states:
        if statei in skip_states:
            continue
        if statei == 'psilocybin':
            exp_meta = metadata[(
                (metadata['mouse_name'].str.contains(subrow.mouse)) &
                (metadata['exp_name'].str.contains(subrow.experiment))
            )].squeeze()
            stim_log['onset_from_inj2'] = stim_log['onset'] - float(exp_meta['Second injection time'])
            for substi, trialtest in substates.items():
                eventinds = stim_log[
                    (stim_log['stim_type'] == 'biphasic') &
                    (stim_log['parameter'] == ch_curr) &
                    (stim_log['stim_depth'] == 'deep') &
                    (stim_log['good'] == True) &
                    (stim_log['resting_trial'] == trialtest) &
                    (stim_log['state'] == statei) &
                    (stim_log['onset_from_inj2'] < psilocybin_window)
                ].index.values
                state_event_inds[statei + '_' + substi] = eventinds
        elif statei == 'awake':
            for substi, trialtest in substates.items():
                eventinds = stim_log[
                    (stim_log['stim_type'] == 'biphasic') &
                    (stim_log['parameter'] == ch_curr) &
                    (stim_log['stim_depth'] == 'deep') &
                    (stim_log['good'] == True) &
                    (stim_log['resting_trial'] == trialtest) &
                    (stim_log['state'] == statei)
                ].index.values
                state_event_inds[statei + '_' + substi] = eventinds
        else:
            eventinds = stim_log[
                (stim_log['stim_type'] == 'biphasic') &
                (stim_log['parameter'] == ch_curr) &
                (stim_log['stim_depth'] == 'deep') &
                (stim_log['good'] == True) &
                (stim_log['resting_trial'] == True) &
                (stim_log['state'] == statei)
            ].index.values
            state_event_inds[statei] = eventinds
              
    all_subjects_data[subrow.mouse][subrow.exp_type] = {
        'unit_metrics': {},
        'unit_zscores': {},
        'pop_fr': {},
        'trial_counts': {},
        'region_counts': {},
        'spike_latencies': {}
    }
    ## Get metrics for each state/region #
    for statei, event_inds in state_event_inds.items():
        if len(event_inds) < trial_threshold:
            print(' Only {:d} trials for {} state, not analyzing.'.format(len(event_inds), statei))
            continue
        elif len(event_inds) > trial_max:
            print(' {} has {:d} trials, downsampling trials to {:d}.'.format(statei, len(event_inds), trial_max))
            event_inds = np.random.choice(event_inds, size=trial_max, replace=False)
        all_subjects_data[subrow.mouse][subrow.exp_type]['trial_counts'][statei] = len(event_inds)
        all_subjects_data[subrow.mouse][subrow.exp_type]['unit_metrics'][statei] = {}
        all_subjects_data[subrow.mouse][subrow.exp_type]['unit_zscores'][statei] = {}
        all_subjects_data[subrow.mouse][subrow.exp_type]['pop_fr'][statei] = {}
        all_subjects_data[subrow.mouse][subrow.exp_type]['spike_latencies'][statei] = {}
        regcount = 0
        for regi, regdf in ROI_unit_info.items():
            regcount += 1
            unit_firing_rates = np.zeros((len(timex), len(regdf)), dtype=float) * np.nan
            sig_evoked_units = np.zeros(len(regdf), dtype=int)
            burst_trials = np.zeros(len(regdf), dtype=float)
            mean_burst_counts = np.zeros(len(regdf), dtype=float)
            early_spike_times = np.zeros((len(event_inds), len(regdf)), dtype=float) * np.nan
            late_spike_times = np.zeros((len(event_inds), len(regdf)), dtype=float) * np.nan

            ## Get unit event spike times ##
            reg_unit_metrics = []
            for ii, unitrow in regdf.iterrows():
                unit_event_spikes = [all_unit_event_spikes['event_spikes'][unitrow.unit_id][ei] for ei in event_inds]
                unit_event_bursts = [all_unit_event_spikes['event_bursts'][unitrow.unit_id]['times'][ei] for ei in event_inds]
                ## Get firing rates ##    
                unit_event_counts, edges = np.histogram(np.concatenate(unit_event_spikes), bins)
                unit_firing_rates[:, ii] = unit_event_counts / (time_bin * len(event_inds))

                prespikes = np.zeros(len(event_inds), dtype=int)
                postspikes = np.zeros(len(event_inds), dtype=int)
                trial_counts = np.zeros(len(event_inds), dtype=int)
                burst_counts = np.zeros(len(event_inds), dtype=int)
                for jj, uspikesi in enumerate(unit_event_spikes):
                    ## Count spikes ##
                    prespikes[jj] = np.sum((uspikesi >= -sig_test_window[1]) & (uspikesi <= -sig_test_window[0]))
                    postspikes[jj] = np.sum((uspikesi >= sig_test_window[0]) & (uspikesi <= sig_test_window[1]))

                    ## Count bursts ##
                    windowbursts = np.nonzero(
                        (unit_event_bursts[jj] >= burst_window[0]) & (unit_event_bursts[jj] <= burst_window[1]))[0]
                    if len(windowbursts) > 0:
                        trial_counts[jj] = 1
                    burst_counts[jj] = len(np.nonzero(unit_event_bursts[jj] >= burst_window[0])[0])

                    ## Find first spikes ##
                    earlyspikes = np.nonzero((uspikesi >= early_window[0]) & (uspikesi <= early_window[1]))[0]
                    if len(earlyspikes) > 0:
                        early_spike_times[jj, ii] = uspikesi[earlyspikes[0]]
                    latespikes = np.nonzero((uspikesi >= late_window[0]) & (uspikesi <= late_window[1]))[0]
                    if len(latespikes) > 0:
                        late_spike_times[jj, ii] = uspikesi[latespikes[0]]

                wstat, pval = stats.wilcoxon(x=postspikes, y=prespikes, zero_method='zsplit')
                spcount = np.mean(postspikes) - np.mean(prespikes)

                reg_unit_metrics.append([
                    unitrow.unit_id, pval, spcount, np.nanmedian(early_spike_times[:,ii]) * 1E3,
                    np.nanmedian(late_spike_times[:,ii]) * 1E3, np.mean(trial_counts), np.mean(burst_counts),
                    np.mean(unit_firing_rates[timex < 0, ii])
                ])
                sig_evoked_units[ii] = pval < sigalpha
                burst_trials[ii] = np.mean(trial_counts) # fraction of trials with burst
                mean_burst_counts[ii] = np.mean(burst_counts) # avg number of evoked bursts
            unit_metrics_df = regdf.merge(
                pd.DataFrame(reg_unit_metrics, columns=[
                    'unit_id', 'p_value', 'mean_spike_diff', 'early_latency', 'late_latency',
                    'burst_prob', 'burst_count', 'baselineFR'
                ]), on='unit_id', how='left')
            all_subjects_data[subrow.mouse][subrow.exp_type]['unit_metrics'][statei][regi] = unit_metrics_df
            all_subjects_data[subrow.mouse][subrow.exp_type]['unit_zscores'][statei][regi] = [
                timex, get_zFR(unit_firing_rates, timex)]
            all_subjects_data[subrow.mouse][subrow.exp_type]['pop_fr'][statei][regi] = [
                timex, np.mean(unit_firing_rates, axis=1)]
            all_subjects_data[subrow.mouse][subrow.exp_type]['spike_latencies'][statei][regi] = [
                early_spike_times * 1E3, late_spike_times * 1E3]
            
            ## Get first and rebound spike latencies across the population for each trial ##
            latencies = []
            for latmat, wink, bink in zip([early_spike_times, late_spike_times], [early_window, late_window], [50, 100]):
                pop_lat = []
                for tspikes in latmat:
                    tc, tb = np.histogram(tspikes, bins=bink, range=(wink[0], wink[1]))
                    if np.sum(tc) > 0:
                        pop_lat.append(tb[np.argmax(tc)])
                latencies.append(np.median(pop_lat))
            
            all_subjects_states_info.append([
                subrow.mouse, subrow.exp_type, statei, len(event_inds), regi, len(regdf),
                np.mean(unit_firing_rates[timex < 0, :]), np.mean(sig_evoked_units), np.median(burst_trials),
                np.mean(mean_burst_counts), latencies[0] * 1E3, latencies[1] * 1E3
            ])
            all_subjects_data[subrow.mouse][subrow.exp_type]['region_counts'][statei] = regcount
    print('')
all_subs_unit_stats_df = pd.DataFrame(all_subjects_states_info, columns=[
    'mouse', 'exp_type', 'state', 'trial_count', 'region', 'num_units',
    'baselineFR', 'fraction_sig', 'burst_prob', 'burst_count', 'early_latency', 'late_latency'
])

Skipping 657903 - psilocybin, experiment excluded from analysis.

666193: pilot_aw_psi_2023-02-16_10-55-48
This data does not contain an EEG recording.
Experiment type: electrical stimulation


  overwrite_input=overwrite_input)



666194: pilot_aw_psi_2023-02-23_10-40-34
Experiment type: electrical stimulation
 psilocybin_resting has 187 trials, downsampling trials to 125.

666196: pilot_aw_psi_2023-03-16_10-21-29
Experiment type: electrical and sensory stimulation
 Only 10 trials for awake_resting state, not analyzing.
 psilocybin_running has 271 trials, downsampling trials to 125.

669118: pilot_aw_psi_2023-03-24_09-55-33
Experiment type: electrical stimulation
 Only 3 trials for awake_resting state, not analyzing.
 Only 17 trials for psilocybin_resting state, not analyzing.
 psilocybin_running has 305 trials, downsampling trials to 125.

669117: pilot_aw_psi_2023-03-30_11-37-07
Experiment type: electrical stimulation
 Only 16 trials for awake_running state, not analyzing.
 psilocybin_resting has 252 trials, downsampling trials to 125.

Skipping 673449 - psilocybin, experiment excluded from analysis.

673449: aw_psi_d2_2023-04-20_10-05-31
Experiment type: electrical stimulation
 F:\psi_exp\mouse673449\aw_psi_

In [16]:
all_subs_unit_stats_df['subject'] = all_subs_unit_stats_df.apply(lambda row: row['mouse'] + '_' + row['exp_type'], axis=1)

In [17]:
all_subs_unit_stats_df.head()

Unnamed: 0,mouse,exp_type,state,trial_count,region,num_units,baselineFR,fraction_sig,burst_prob,burst_count,early_latency,late_latency,subject
0,666193,psilocybin,psilocybin_resting,125,HIP,153,5.443791,0.026144,0.0,0.087111,7.06,86.25,666193_psilocybin
1,666193,psilocybin,psilocybin_resting,125,ILA,258,1.351163,0.01938,0.0,0.005085,6.6,106.5,666193_psilocybin
2,666193,psilocybin,psilocybin_resting,125,MO,158,2.571646,0.329114,0.0,0.026076,3.38,185.25,666193_psilocybin
3,666193,psilocybin,psilocybin_resting,125,OLF,75,3.40384,0.013333,0.0,0.032107,6.6,97.5,666193_psilocybin
4,666193,psilocybin,psilocybin_resting,125,PL,119,3.801714,0.176471,0.0,0.002891,3.38,93.0,666193_psilocybin


In [21]:
tempdf = all_subs_unit_stats_df[
    (all_subs_unit_stats_df['state'] == 'psilocybin_resting') &
    (all_subs_unit_stats_df['region'] == 'RT-TH')
]
tempdf

Unnamed: 0,mouse,exp_type,state,trial_count,region,num_units,baselineFR,fraction_sig,burst_prob,burst_count,early_latency,late_latency,subject
57,666196,psilocybin,psilocybin_resting,58,RT-TH,35,16.642611,0.514286,0.051724,0.360591,6.83,153.75,666196_psilocybin


## Make multi-sub figure

In [18]:
all_states = {
    'awake_resting': 'o', 'awake_running': 'X',
    'psilocybin_resting': 's', 'psilocybin_running': 'P',
    'isoflurane': '^', 'urethane': 'v',
}
stlabels = ['AW\nrest', 'AW\nrun', 'PSI\nrest', 'PSI\nrun', 'ISO', 'UR']

In [19]:
metric_dict = {
    'baselineFR': ['Baseline firing rate', '(Hz)'],
    'fraction_sig': ['Significant response', 'Fraction of pop.'],
    'early_latency': ['First spike latency', 'Time from stim (ms)'],
    'late_latency': ['Rebound spike latency', 'Time from stim (ms)'],
    'burst_prob': ['Burst probability', 'Fraction of trials'],
    'burst_count': ['Burst count', 'Mean evoked\nburst count'],
}

In [20]:
plot_regions = ['MO', 'SM-TH', 'RT-TH']

## Figure 1, all areas separate ##
fig, axs = plt.subplots(nrows=len(plot_regions), ncols=len(metric_dict), figsize=(13, 6), sharex=True, sharey='col')
fig.set_tight_layout({'rect': [0.01, 0, 1, 0.98]})

for ii, regi in enumerate(plot_regions):
    for kk, (statei, stmarker) in enumerate(all_states.items()):
        tempdf = all_subs_unit_stats_df[
            (all_subs_unit_stats_df['region'] == regi) & (all_subs_unit_stats_df['state'] == statei)
        ]
        xs = np.random.normal(kk, 0.1, len(tempdf))
        for jj, (met, metinfo) in enumerate(metric_dict.items()):
            axs[ii,jj].boxplot(tempdf[met].values, positions=[kk], widths=[0.8], showfliers=False)
            axs[ii,jj].scatter(xs, tempdf[met].values, c='k', marker=stmarker, alpha=0.3)
            if kk == 0:
                axs[ii,jj].set_ylabel(metinfo[1])
                if ii == 0:
                    axs[ii,jj].set_title(metinfo[0])

axs[0,0].set_xticks(list(range(len(stlabels))))
axs[0,0].set_xticklabels(stlabels)
fig.text(0.005, 0.805, 'MO', va='center', rotation='vertical', fontsize=14)
fig.text(0.005, 0.5, 'MO-TH', va='center', rotation='vertical', fontsize=14)
fig.text(0.005, 0.21, 'RT', va='center', rotation='vertical', fontsize=14)

## Save ##
figname = 'allSUBS_unitmetrics_allstates_regsep.png'
# fig.savefig(os.path.join(plotsdir, figname), transparent=False, dpi=300)

<IPython.core.display.Javascript object>

In [22]:
## Fig 2, early latency ##
locs = np.arange((len(areas_of_interest) + 1) * len(stlabels)).reshape((len(stlabels), (len(areas_of_interest) + 1)))
fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)

for jj, (regi, rcolor) in enumerate(areacolors.items()):
    for ii, (statei, stmarker) in enumerate(all_states.items()):
        tempdf = all_subs_unit_stats_df[
            (all_subs_unit_stats_df['region'] == regi) & (all_subs_unit_stats_df['state'] == statei)
        ]
        xs = np.zeros(len(tempdf)) + locs[ii,jj] # np.random.normal(locs[ii,jj], 0.1, len(tempdf))
        ax.scatter(xs, tempdf['early_latency'].values, c='k', marker=stmarker, alpha=0.3)
        ax.boxplot(
            tempdf['early_latency'].values, positions=[locs[ii,jj]], widths=[0.8], showfliers=False,
            medianprops={'color': rcolor, 'linewidth': 2}, boxprops={'color': rcolor},
        )
ax.set_title('First spike latency across subjects')
ax.set_ylabel('Latency from stim (ms)')
ax.set_xticks(np.mean(locs[:,:-1], axis=1))
ax.set_xticklabels(stlabels)

## Add subject lines ##
for ii, statei in enumerate(all_states.keys()):
    stdf = all_subs_unit_stats_df[all_subs_unit_stats_df['state'] == statei]
    for subi in np.unique(stdf['subject'].values):
        subdf = stdf[stdf['subject'] == subi]
        tempvals = np.zeros(len(areacolors), dtype=float)
        for kk, regi in enumerate(areacolors.keys()):
            if regi in subdf['region'].values:
                tempvals[kk] = subdf[subdf['region'] == regi].squeeze().early_latency
            else:
                tempvals[kk] = np.nan
        for kk in range(len(areacolors) - 1):
            ax.plot([locs[ii,kk], locs[ii,kk+1]], tempvals[kk:kk+2], color='k', linewidth=1.0, alpha=0.5)
        
## Save ##
figname = 'allSUBS_early_latency_sublines.png'
# fig.savefig(os.path.join(plotsdir, figname), transparent=False, dpi=300)

<IPython.core.display.Javascript object>

In [23]:
## Fig 3, late latency ##
locs = np.arange((len(areas_of_interest) + 1) * len(stlabels)).reshape((len(stlabels), (len(areas_of_interest) + 1)))
fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)

for jj, (regi, rcolor) in enumerate(areacolors.items()):
    for ii, (statei, stmarker) in enumerate(all_states.items()):
        tempdf = all_subs_unit_stats_df[
            (all_subs_unit_stats_df['region'] == regi) & (all_subs_unit_stats_df['state'] == statei)
        ]
        xs = np.zeros(len(tempdf)) + locs[ii,jj] # np.random.normal(locs[ii,jj], 0.1, len(tempdf))
        ax.scatter(xs, tempdf['late_latency'].values, c='k', marker=stmarker, alpha=0.3)
        ax.boxplot(
            tempdf['late_latency'].values, positions=[locs[ii,jj]], widths=[0.8], showfliers=False,
            medianprops={'color': rcolor, 'linewidth': 2}, boxprops={'color': rcolor},
        )
ax.set_title('Rebound spike latency across subjects')
ax.set_ylabel('Latency from stim (ms)')
ax.set_xticks(np.mean(locs[:,:-1], axis=1))
ax.set_xticklabels(stlabels)

## Add subject lines ##
for ii, statei in enumerate(all_states.keys()):
    stdf = all_subs_unit_stats_df[all_subs_unit_stats_df['state'] == statei]
    for subi in np.unique(stdf['subject'].values):
        subdf = stdf[stdf['subject'] == subi]
        tempvals = np.zeros(len(areacolors), dtype=float)
        for kk, regi in enumerate(areacolors.keys()):
            if regi in subdf['region'].values:
                tempvals[kk] = subdf[subdf['region'] == regi].squeeze().late_latency
            else:
                tempvals[kk] = np.nan
        for kk in range(len(areacolors) - 1):
            ax.plot([locs[ii,kk], locs[ii,kk+1]], tempvals[kk:kk+2], color='k', linewidth=1.0, alpha=0.5)

## Save ##
figname = 'allSUBS_rebound_latency_sublines.png'
# fig.savefig(os.path.join(plotsdir, figname), transparent=False, dpi=300)

<IPython.core.display.Javascript object>

### Plot spike latency

In [24]:
plot_areas = {'MO': 'blueviolet', 'SM-TH': 'limegreen'}
plot_metric = 'late_latency'
plot_title = 'Rebound spike latency across subjects'

regdfs = []
for regi in plot_areas.keys():
    regdfs.append(all_subs_unit_stats_df[(all_subs_unit_stats_df['region'] == regi)])
regionsdf = pd.concat(regdfs)
regionsdf.head()

Unnamed: 0,mouse,exp_type,state,trial_count,region,num_units,baselineFR,fraction_sig,burst_prob,burst_count,early_latency,late_latency,subject
2,666193,psilocybin,psilocybin_resting,125,MO,158,2.571646,0.329114,0.0,0.026076,3.38,185.25,666193_psilocybin
12,666193,psilocybin,psilocybin_running,115,MO,158,4.144496,0.06962,0.0,0.019483,3.38,111.0,666193_psilocybin
41,666196,psilocybin,awake_running,110,MO,28,2.856656,0.321429,0.0,0.000974,3.84,158.25,666196_psilocybin
53,666196,psilocybin,psilocybin_resting,58,MO,28,2.649938,0.214286,0.0,0.001232,4.76,176.25,666196_psilocybin
65,666196,psilocybin,psilocybin_running,125,MO,28,2.226429,0.392857,0.0,0.001429,3.84,169.5,666196_psilocybin


In [25]:
for regi in plot_areas.keys():
    print(regi)
    for statei in all_states.keys():
        statevals = regionsdf[(regionsdf['region'] == regi) & (regionsdf['state'] == statei)][plot_metric].values
        if len(statevals) > 2:
            swstat, swp = stats.shapiro(statevals)
            if swp < sigalpha:
                normtag = 'NOT normal'
            else:
                normtag = 'normal'
        else:
            normtag = 'cannot run Shapiro-Wilk'
        print(' {}: N={:d}; {}; mean={:.2f}, SEM={:.2f}'.format(
            statei, len(statevals), normtag, np.mean(statevals), np.std(statevals)/np.sqrt(len(statevals))
        ))
    print('')

MO
 awake_resting: N=15; normal; mean=147.00, SEM=7.45
 awake_running: N=8; normal; mean=113.81, SEM=9.73
 psilocybin_resting: N=3; normal; mean=152.25, SEM=23.37
 psilocybin_running: N=3; normal; mean=125.25, SEM=18.44
 isoflurane: N=9; normal; mean=174.12, SEM=13.22
 urethane: N=9; normal; mean=159.25, SEM=13.32

SM-TH
 awake_resting: N=14; normal; mean=134.22, SEM=8.73
 awake_running: N=9; NOT normal; mean=108.00, SEM=9.82
 psilocybin_resting: N=4; normal; mean=127.59, SEM=15.84
 psilocybin_running: N=5; NOT normal; mean=93.00, SEM=12.09
 isoflurane: N=8; normal; mean=163.17, SEM=14.08
 urethane: N=9; NOT normal; mean=117.62, SEM=14.85



Apply a two-way ANOVA

In [26]:
metricANOVA = pg.anova(data=regionsdf, dv=plot_metric, between=['region', 'state'], detailed=True)
for indi, anovarow in metricANOVA.iterrows():
    if anovarow['p-unc'] < sigalpha:
        print('There is a significant effect of {} on {}.'.format(anovarow['Source'], plot_metric))
    else:
        print('There is NOT a significant effect of {} on {}.'.format(anovarow['Source'], plot_metric))

There is a significant effect of region on late_latency.
There is a significant effect of state on late_latency.
There is NOT a significant effect of region * state on late_latency.
There is NOT a significant effect of Residual on late_latency.


If there is a significant interaction -> run all the comparisons (the order of the "between" factors matters!).

In [23]:
metricposthoc = pg.pairwise_tests(data=regionsdf, dv=plot_metric, between=['region', 'state'], padjust='fdr_bh')
metricposthoc

Unnamed: 0,Contrast,region,A,B,Paired,Parametric,T,dof,alternative,p-unc,p-corr,p-adjust,BF10,hedges
0,region,-,MO,SM-TH,False,True,3.370913,86.43462,two-sided,0.001122,,,27.627,0.67792
1,state,-,awake_resting,awake_running,False,True,2.821748,31.610488,two-sided,0.008187,0.017544,fdr_bh,6.367,0.864333
2,state,-,awake_resting,isoflurane,False,True,-3.820116,34.671737,two-sided,0.000529,0.002645,fdr_bh,64.407,-1.135709
3,state,-,awake_resting,psilocybin_resting,False,True,-0.437614,10.507752,two-sided,0.670525,0.773682,fdr_bh,0.408,-0.161996
4,state,-,awake_resting,psilocybin_running,False,True,2.102788,9.599017,two-sided,0.062927,0.078659,fdr_bh,1.78,0.93662
5,state,-,awake_resting,urethane,False,True,-3.24252,33.060141,two-sided,0.002707,0.006767,fdr_bh,16.054,-0.982943
6,state,-,awake_running,isoflurane,False,True,-5.831553,32.0,two-sided,2e-06,2.7e-05,fdr_bh,7297.356,-1.952957
7,state,-,awake_running,psilocybin_resting,False,True,-2.587059,14.361401,two-sided,0.021175,0.039703,fdr_bh,3.501,-1.009967
8,state,-,awake_running,psilocybin_running,False,True,0.204982,12.081219,two-sided,0.840999,0.840999,fdr_bh,0.392,0.089918
9,state,-,awake_running,urethane,False,True,-5.25172,32.986099,two-sided,9e-06,6.6e-05,fdr_bh,1766.813,-1.733498


In [24]:
sigdf = metricposthoc[metricposthoc['p-corr'] < sigalpha]
sigdf

Unnamed: 0,Contrast,region,A,B,Paired,Parametric,T,dof,alternative,p-unc,p-corr,p-adjust,BF10,hedges
1,state,-,awake_resting,awake_running,False,True,2.821748,31.610488,two-sided,0.008187,0.017544,fdr_bh,6.367,0.864333
2,state,-,awake_resting,isoflurane,False,True,-3.820116,34.671737,two-sided,0.000529,0.002645,fdr_bh,64.407,-1.135709
5,state,-,awake_resting,urethane,False,True,-3.24252,33.060141,two-sided,0.002707,0.006767,fdr_bh,16.054,-0.982943
6,state,-,awake_running,isoflurane,False,True,-5.831553,32.0,two-sided,2e-06,2.7e-05,fdr_bh,7297.356,-1.952957
7,state,-,awake_running,psilocybin_resting,False,True,-2.587059,14.361401,two-sided,0.021175,0.039703,fdr_bh,3.501,-1.009967
9,state,-,awake_running,urethane,False,True,-5.25172,32.986099,two-sided,9e-06,6.6e-05,fdr_bh,1766.813,-1.733498
10,state,-,isoflurane,psilocybin_resting,False,True,2.47753,12.806535,two-sided,0.027982,0.046637,fdr_bh,2.962,1.013797
11,state,-,isoflurane,psilocybin_running,False,True,4.292557,11.062013,two-sided,0.001256,0.004709,fdr_bh,85.162,1.961446
15,state,-,psilocybin_running,urethane,False,True,-3.987488,12.217905,two-sided,0.001741,0.005223,fdr_bh,48.258,-1.716281
24,region * state,MO,awake_running,urethane,False,True,-3.291412,13.59269,two-sided,0.005542,0.020781,fdr_bh,8.689,-1.477056


If there is not a significant interaction, but there is a significant main effect, you can run post hoc tests on the factor with an effect.

In [28]:
metricposthoc = pg.pairwise_tukey(data=regionsdf, dv=plot_metric, between='state')
metricposthoc

Unnamed: 0,A,B,mean(A),mean(B),diff,se,T,p-tukey,hedges
0,awake_resting,awake_running,140.831897,110.735294,30.096602,11.548393,2.606129,0.106346,0.950288
1,awake_resting,isoflurane,140.831897,168.970588,-28.138692,11.548393,-2.436589,0.154914,-0.775411
2,awake_resting,psilocybin_resting,140.831897,138.160714,2.671182,15.920963,0.167778,0.999981,0.077524
3,awake_resting,psilocybin_running,140.831897,105.09375,35.738147,15.098123,2.367059,0.179098,1.071275
4,awake_resting,urethane,140.831897,138.4375,2.394397,11.344354,0.211065,0.99994,0.060279
5,awake_running,isoflurane,110.735294,168.970588,-58.235294,12.967511,-4.490861,0.000296,-1.581529
6,awake_running,psilocybin_resting,110.735294,138.160714,-27.42542,16.978458,-1.615307,0.590731,-0.802031
7,awake_running,psilocybin_running,110.735294,105.09375,5.641544,16.209389,0.348042,0.999306,0.17374
8,awake_running,urethane,110.735294,138.4375,-27.702206,12.786139,-2.166581,0.263707,-0.668941
9,isoflurane,psilocybin_resting,168.970588,138.160714,30.809874,16.978458,1.814645,0.46171,0.722259


In [29]:
locs = np.arange((len(plot_areas) + 1) * len(all_states)).reshape((len(all_states), (len(plot_areas) + 1)))
fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
stargs = {'xycoords': 'data', 'fontsize': 12, 'fontweight': 'bold', 'ha': 'left', 'va': 'center'}

for jj, (regi, rcolor) in enumerate(plot_areas.items()):
    for ii, (statei, stmarker) in enumerate(all_states.items()):
        tempdf = all_subs_unit_stats_df[
            (all_subs_unit_stats_df['region'] == regi) & (all_subs_unit_stats_df['state'] == statei)
        ]
        xs = np.zeros(len(tempdf)) + locs[ii,jj] # np.random.normal(locs[ii,jj], 0.1, len(tempdf))
        ax.scatter(xs, tempdf[plot_metric].values, c='k', marker=stmarker, alpha=0.3)
        ax.boxplot(
            tempdf[plot_metric].values, positions=[locs[ii,jj]], widths=[0.8], showfliers=False,
            medianprops={'color': rcolor, 'linewidth': 2}, boxprops={'color': rcolor},
        )
ax.set_title(plot_title)
ax.set_ylabel('Latency from stim (ms)')
ax.set_xticks(np.mean(locs[:,:-1], axis=1))
ax.set_xticklabels(stlabels)

## Add subject lines ##
for ii, statei in enumerate(all_states.keys()):
    stdf = all_subs_unit_stats_df[all_subs_unit_stats_df['state'] == statei]
    for subi in np.unique(stdf['subject'].values):
        subdf = stdf[stdf['subject'] == subi]
        tempvals = np.zeros(len(plot_areas), dtype=float)
        for kk, regi in enumerate(plot_areas.keys()):
            if regi in subdf['region'].values:
                tempvals[kk] = subdf[subdf['region'] == regi].squeeze()[plot_metric]
            else:
                tempvals[kk] = np.nan
        for kk in range(len(plot_areas) - 1):
            ax.plot([locs[ii,kk], locs[ii,kk+1]], tempvals[kk:kk+2], color='k', linewidth=1.0, alpha=0.5)

## Add stats ## for late_latency
# ax.plot([3,15], [105,105], color=plot_areas['MO'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[24]['p-corr']), xy=(3,105), **stargs)

# ax.plot([1,4], [220,220], color=plot_areas['SM-TH'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[31]['p-corr']), xy=(1,220), **stargs)
# ax.plot([1,13], [230,230], color=plot_areas['SM-TH'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[32]['p-corr']), xy=(1,230), **stargs)
# ax.plot([1,10], [225,225], color=plot_areas['SM-TH'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[34]['p-corr']), xy=(1,225), **stargs)

# ax.plot([4,13], [240,240], color=plot_areas['SM-TH'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[36]['p-corr']), xy=(4,240), **stargs)
# ax.plot([4,7], [235,235], color=plot_areas['SM-TH'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[37]['p-corr']), xy=(4,235), **stargs)
# ax.plot([4,16], [245,245], color=plot_areas['SM-TH'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[39]['p-corr']), xy=(4,245), **stargs)

# ax.plot([7,10], [250,250], color=plot_areas['SM-TH'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[43]['p-corr']), xy=(7,250), **stargs)
# ax.plot([7,13], [255,255], color=plot_areas['SM-TH'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[40]['p-corr']), xy=(7,255), **stargs)

# ax.plot([10,13], [260,260], color=plot_areas['SM-TH'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[41]['p-corr']), xy=(10,260), **stargs)
# ax.plot([10,16], [265,265], color=plot_areas['SM-TH'], linewidth=2)
# ax.annotate(p_stars(metricposthoc.iloc[45]['p-corr']), xy=(10,265), **stargs)
        
## Save ##
figname = 'allSUBS_{}_sublines_MO_TH_nostats.png'.format(plot_metric)
# fig.savefig(os.path.join(plotsdir, figname), transparent=False, dpi=300)

<IPython.core.display.Javascript object>

Plotting latency in a new way

In [31]:
plot_areas = {'MO': 'blueviolet', 'SM-TH': 'limegreen'}
plot_metric = 'late_latency'
plot_title = 'rebound spike latency'
ystart = 220
delta = 6

fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(13, 4), constrained_layout=True, sharex=True)
stargs = {'xycoords': 'data', 'fontsize': 12, 'fontweight': 'bold', 'ha': 'left', 'va': 'center'}
axs[0].get_shared_y_axes().join(axs[0], axs[1])

for jj, (regi, rcolor) in enumerate(plot_areas.items()):
    regdf = all_subs_unit_stats_df[all_subs_unit_stats_df['region'] == regi]
    for ii, (statei, stmarker) in enumerate(all_states.items()):
        tempdf = regdf[regdf['state'] == statei]
        xs = np.random.normal(ii, 0.1, len(tempdf)) # np.zeros(len(tempdf)) + ii
        axs[jj].scatter(xs, tempdf[plot_metric].values, c='k', marker=stmarker, alpha=0.3)
        axs[jj].boxplot(
            tempdf[plot_metric].values, positions=[ii], widths=[0.8], showfliers=False,
            medianprops={'color': rcolor, 'linewidth': 2}, boxprops={'color': rcolor},
        )
    axs[jj].set_title(regi + ' ' + plot_title)
    
    ## Stats ##
    metricANOVA = pg.anova(data=regdf, dv=plot_metric, between='state', detailed=True)
    print('{}: ANOVA between states p-val = {:.3E}'.format(regi, metricANOVA['p-unc'].values[0]))
    if metricANOVA['p-unc'].values[0] < sigalpha:
        metricposthoc = pg.pairwise_tests(data=regdf, dv=plot_metric, between='state', padjust='fdr_bh')
        sigdf = metricposthoc[metricposthoc['p-corr'] < sigalpha].reset_index(drop=True)
        print(' Significant effect of state on {}. Posthoc tests find {:d} sig diffs.'.format(plot_metric, len(sigdf)))
#         for indi, sigrow in sigdf.iterrows():
#             xA = np.nonzero(np.array(list(all_states.keys())) == sigrow.A)[0][0]
#             xB = np.nonzero(np.array(list(all_states.keys())) == sigrow.B)[0][0]
#             axs[jj].plot([xA, xB], [ystart + indi * delta, ystart + indi * delta], color='k', linewidth=2)
#             axs[jj].annotate(p_stars(sigrow['p-corr']), xy=(np.min([xA, xB]), ystart + indi * delta), **stargs)
    print('')


## Difference between areas ##
axs[2].axhline(0, color='k', linestyle='dashed', alpha=0.25)
all_states_subvals = []
for ii, (statei, stmarker) in enumerate(all_states.items()):
    stdf = all_subs_unit_stats_df[all_subs_unit_stats_df['state'] == statei]
    subvals = np.zeros(len(np.unique(stdf['subject'].values)), dtype=float) * np.nan
    for jj, subi in enumerate(np.unique(stdf['subject'].values)):
        subdf = stdf[stdf['subject'] == subi]
        tempvals = np.zeros(len(plot_areas), dtype=float) * np.nan
        for kk, regi in enumerate(plot_areas.keys()):
            if regi in subdf['region'].values:
                tempvals[kk] = subdf[subdf['region'] == regi].squeeze()[plot_metric]
        subvals[jj] = tempvals[0] - tempvals[1]
        all_states_subvals.append([statei, subi, subvals[jj]])
    xs = np.random.normal(ii, 0.1, len(subvals))
    axs[2].scatter(xs, subvals, c='k', marker=stmarker, alpha=0.3)
    axs[2].boxplot(subvals[~np.isnan(subvals)], positions=[ii], widths=[0.8], showfliers=False)
lat_diff_df = pd.DataFrame(all_states_subvals, columns=['state', 'subject', 'latency_diff'])
metricANOVA = pg.anova(data=lat_diff_df, dv='latency_diff', between='state', detailed=True)
print('{}: ANOVA between states p-val = {:.3E}'.format('latency_diff', metricANOVA['p-unc'].values[0]))
if metricANOVA['p-unc'].values[0] < sigalpha:
    metricposthoc = pg.pairwise_tests(data=lat_diff_df, dv='latency_diff', between='state', padjust='fdr_bh')
    sigdf = metricposthoc[metricposthoc['p-corr'] < sigalpha].reset_index(drop=True)
    print(' Significant effect of state on {}. Posthoc tests find {:d} sig diffs.'.format(plot_metric, len(sigdf)))
#     yst = 60
#     deltay = 5
#     for indi, sigrow in sigdf.iterrows():
#         xA = np.nonzero(np.array(list(all_states.keys())) == sigrow.A)[0][0]
#         xB = np.nonzero(np.array(list(all_states.keys())) == sigrow.B)[0][0]
#         axs[2].plot([xA, xB], [yst + indi * deltay, yst + indi * deltay], color='k', linewidth=2)
#         axs[2].annotate(p_stars(sigrow['p-corr']), xy=(np.min([xA, xB]), yst + indi * deltay), **stargs)
    
    
axs[0].set_ylabel('Latency from stim (ms)')
axs[0].set_xticks(range(len(all_states)))
axs[0].set_xticklabels(stlabels)
axs[2].set_title('Rebound spike onset difference (MO - SM-TH)')
axs[2].set_ylabel('Difference (ms)')
        
## Save ##
figname = 'allSUBS_{}_separated.png'.format(plot_metric)
# fig.savefig(os.path.join(plotsdir, figname), transparent=False, dpi=300)

<IPython.core.display.Javascript object>

MO: ANOVA between states p-val = 3.273E-02
 Significant effect of state on late_latency. Posthoc tests find 0 sig diffs.

SM-TH: ANOVA between states p-val = 2.199E-02
 Significant effect of state on late_latency. Posthoc tests find 0 sig diffs.

latency_diff: ANOVA between states p-val = 2.865E-01


### Plot baseline firing rates

In [50]:
plot_region = 'SM-TH'
plot_metric = 'baselineFR'

regiondf = all_subs_unit_stats_df[all_subs_unit_stats_df['region'] == plot_region]

In [51]:
for statei in all_states.keys():
    statevals = regiondf[regiondf['state'] == statei][plot_metric].values
    print(statei)
    if len(statevals) > 2:
        swstat, swp = stats.shapiro(statevals)
        if swp < sigalpha:
            normtag = 'NOT normal'
        else:
            normtag = 'normal'
    else:
        normtag = 'cannot run Shapiro-Wilk'
    print(' N={:d}; {}; mean={:.1f}, SEM={:.1f}\n'.format(
        len(statevals), normtag, np.mean(statevals), np.std(statevals)/np.sqrt(len(statevals))
    ))

awake_resting
 N=14; NOT normal; mean=7.7, SEM=1.0

awake_running
 N=9; normal; mean=12.7, SEM=1.3

psilocybin_resting
 N=4; normal; mean=7.4, SEM=0.6

psilocybin_running
 N=5; normal; mean=12.0, SEM=1.6

isoflurane
 N=8; normal; mean=0.4, SEM=0.1

urethane
 N=9; normal; mean=2.3, SEM=0.2



Apply a one-way ANOVA

In [52]:
metricANOVA = pg.anova(data=regiondf, dv=plot_metric, between='state', detailed=True)
print('ANOVA between states p-val = {:.3E}'.format(metricANOVA['p-unc'].values[0]))
if metricANOVA['p-unc'].values[0] < sigalpha:
    print('There is a significant effect of state on {}, now perform posthoc tests.'.format(plot_metric))
    metricposthoc = pg.pairwise_tests(data=regiondf, dv=plot_metric, between='state', padjust='fdr_bh')

ANOVA between states p-val = 1.654E-10
There is a significant effect of state on baselineFR, now perform posthoc tests.


In [53]:
sigdf = metricposthoc[metricposthoc['p-corr'] < sigalpha]
sigdf

Unnamed: 0,Contrast,A,B,Paired,Parametric,T,dof,alternative,p-unc,p-corr,p-adjust,BF10,hedges
0,state,awake_resting,awake_running,False,True,-2.960593,16.912992,two-sided,0.008796131,0.011995,fdr_bh,6.377,-1.225633
1,state,awake_resting,isoflurane,False,True,6.897097,13.359198,two-sided,9.437873e-06,4.7e-05,fdr_bh,10100.0,2.210057
4,state,awake_resting,urethane,False,True,5.057333,14.173155,two-sided,0.0001684796,0.000505,fdr_bh,343.704,1.678221
5,state,awake_running,isoflurane,False,True,9.162728,8.136435,two-sided,1.453399e-05,5.5e-05,fdr_bh,53810.0,3.973277
6,state,awake_running,psilocybin_resting,False,True,3.545174,10.761029,two-sided,0.004746029,0.00791,fdr_bh,8.732,1.41656
8,state,awake_running,urethane,False,True,7.705106,16.0,two-sided,8.995985e-07,1.3e-05,fdr_bh,11460.0,3.459259
9,state,isoflurane,psilocybin_resting,False,True,-10.077365,3.19866,two-sided,0.001573429,0.003934,fdr_bh,5955.798,-8.036434
10,state,isoflurane,psilocybin_running,False,True,-6.410907,4.036892,two-sided,0.002947451,0.005526,fdr_bh,313.129,-4.404608
11,state,isoflurane,urethane,False,True,-7.246742,12.288946,two-sided,8.918342e-06,4.7e-05,fdr_bh,4064.02,-3.223812
13,state,psilocybin_resting,urethane,False,True,7.140701,3.669661,two-sided,0.002803806,0.005526,fdr_bh,677.603,5.211647


Plot it

In [57]:
fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)

for ii, (statei, stmarker) in enumerate(all_states.items()):
    tempdf = regiondf[regiondf['state'] == statei]
    xs = np.random.normal(ii, 0.1, len(tempdf)) # np.zeros(len(tempdf)) + locs[ii,jj]
    ax.scatter(xs, tempdf[plot_metric].values, c='k', marker=stmarker, alpha=0.3)
    ax.boxplot(
        tempdf[plot_metric].values, positions=[ii], widths=[0.8], showfliers=False,
#         medianprops={'color': rcolor, 'linewidth': 2}, boxprops={'color': rcolor},
    )
ax.set_title('Baseline firing rate: {}'.format(plot_region))
ax.set_ylabel('Firing rate (Hz)')
ax.set_xticklabels(stlabels)

## Add stats ##
for (indi, sigrow), yval in zip(sigdf.iterrows(), [21, 22, 23, 25, 24, 26, 27, 29, 31, 28, 30]):
    xA = np.nonzero(np.array(list(all_states.keys())) == sigrow.A)[0][0]
    xB = np.nonzero(np.array(list(all_states.keys())) == sigrow.B)[0][0]
    ax.plot([xA, xB], [yval, yval], color='k', linewidth=2)
    ax.annotate(
        p_stars(sigrow['p-corr']), xy=(np.min([xA, xB]), yval), xycoords='data', # np.min([xA, xB]); (xA + xB)/2
        fontsize=12, fontweight='bold', ha='left', va='center'
    )
    
## Save ##
figname = '{}_{}_all_subs_states.png'.format(plot_region, plot_metric)
# fig.savefig(os.path.join(plotsdir, figname), transparent=False, dpi=300)

<IPython.core.display.Javascript object>

### Plot rebound burst probability

In [86]:
plot_region = 'SM-TH'
plot_metric = 'burst_count'

regiondf = all_subs_unit_stats_df[all_subs_unit_stats_df['region'] == plot_region]

In [87]:
for statei in all_states.keys():
    statevals = regiondf[regiondf['state'] == statei][plot_metric].values
    print(statei)
    if len(statevals) > 2:
        swstat, swp = stats.shapiro(statevals)
        if swp < sigalpha:
            normtag = 'NOT normal'
        else:
            normtag = 'normal'
    else:
        normtag = 'cannot run Shapiro-Wilk'
    print(' N={:d}; {}; mean={:.3f}, SEM={:.3f}\n'.format(
        len(statevals), normtag, np.mean(statevals), np.std(statevals)/np.sqrt(len(statevals))
    ))

awake_resting
 N=14; normal; mean=0.413, SEM=0.058

awake_running
 N=9; normal; mean=0.120, SEM=0.031

psilocybin_resting
 N=4; normal; mean=0.337, SEM=0.086

psilocybin_running
 N=5; normal; mean=0.138, SEM=0.048

isoflurane
 N=8; NOT normal; mean=0.085, SEM=0.043

urethane
 N=9; normal; mean=0.305, SEM=0.037



Apply a Kruskal-Wallis H-test (non-parametric one-way ANOVA due to non-normal distributions)

In [88]:
dist_normal = True

In [89]:
if dist_normal:
    metricANOVA = pg.anova(data=regiondf, dv=plot_metric, between='state', detailed=True)
    print('ANOVA between states p-val = {:.3E}'.format(metricANOVA['p-unc'].values[0]))
    if metricANOVA['p-unc'].values[0] < sigalpha:
        print('There is a significant effect of state on {}, now perform posthoc tests.'.format(plot_metric))
        metricposthoc = pg.pairwise_tests(data=regiondf, dv=plot_metric, between='state', parametric=True, padjust='fdr_bh')
else:
    metricANOVA = pg.kruskal(data=regiondf, dv=plot_metric, between='state', detailed=True)
    print('ANOVA between states p-val = {:.3E}'.format(metricANOVA['p-unc'].values[0]))
    if metricANOVA['p-unc'].values[0] < sigalpha:
        print('There is a significant effect of state on {}, now perform posthoc tests.'.format(plot_metric))
        metricposthoc = pg.pairwise_tests(data=regiondf, dv=plot_metric, between='state', parametric=False, padjust='fdr_bh')

ANOVA between states p-val = 1.334E-04
There is a significant effect of state on burst_count, now perform posthoc tests.


In [90]:
metricposthoc

Unnamed: 0,Contrast,A,B,Paired,Parametric,T,dof,alternative,p-unc,p-corr,p-adjust,BF10,hedges
0,state,awake_resting,awake_running,False,True,4.285634,19.18155,two-sided,0.000392,0.002938,fdr_bh,73.983,1.512935
1,state,awake_resting,isoflurane,False,True,4.344828,19.963803,two-sided,0.000315,0.002938,fdr_bh,75.562,1.606791
2,state,awake_resting,psilocybin_resting,False,True,0.657842,5.432614,two-sided,0.537495,0.628238,fdr_bh,0.53,0.330253
3,state,awake_resting,psilocybin_running,False,True,3.407795,13.547084,two-sided,0.004429,0.013286,fdr_bh,11.075,1.284435
4,state,awake_resting,urethane,False,True,1.506898,20.395513,two-sided,0.14717,0.220755,fdr_bh,0.851,0.544738
5,state,awake_running,isoflurane,False,True,0.62233,13.017436,two-sided,0.544473,0.628238,fdr_bh,0.479,0.291913
6,state,awake_running,psilocybin_resting,False,True,-2.078311,3.681629,two-sided,0.112233,0.210436,fdr_bh,1.581,-1.512958
7,state,awake_running,psilocybin_running,False,True,-0.279741,7.017679,two-sided,0.787742,0.787742,fdr_bh,0.469,-0.155555
8,state,awake_running,urethane,False,True,-3.631701,16.0,two-sided,0.002244,0.009456,fdr_bh,15.824,-1.630476
9,state,isoflurane,psilocybin_resting,False,True,-2.308582,4.345539,two-sided,0.076923,0.164836,fdr_bh,1.969,-1.514451


In [91]:
sigdf = metricposthoc[metricposthoc['p-corr'] < sigalpha]
sigdf

Unnamed: 0,Contrast,A,B,Paired,Parametric,T,dof,alternative,p-unc,p-corr,p-adjust,BF10,hedges
0,state,awake_resting,awake_running,False,True,4.285634,19.18155,two-sided,0.000392,0.002938,fdr_bh,73.983,1.512935
1,state,awake_resting,isoflurane,False,True,4.344828,19.963803,two-sided,0.000315,0.002938,fdr_bh,75.562,1.606791
3,state,awake_resting,psilocybin_running,False,True,3.407795,13.547084,two-sided,0.004429,0.013286,fdr_bh,11.075,1.284435
8,state,awake_running,urethane,False,True,-3.631701,16.0,two-sided,0.002244,0.009456,fdr_bh,15.824,-1.630476
11,state,isoflurane,urethane,False,True,-3.657089,14.244856,two-sided,0.002522,0.009456,fdr_bh,15.274,-1.697849


Plot it

In [94]:
fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)

for ii, (statei, stmarker) in enumerate(all_states.items()):
    tempdf = regiondf[regiondf['state'] == statei]
    xs = np.random.normal(ii, 0.1, len(tempdf)) # np.zeros(len(tempdf)) + locs[ii,jj]
    ax.scatter(xs, tempdf[plot_metric].values, c='k', marker=stmarker, alpha=0.3)
    ax.boxplot(
        tempdf[plot_metric].values, positions=[ii], widths=[0.8], showfliers=False,
#         medianprops={'color': rcolor, 'linewidth': 2}, boxprops={'color': rcolor},
    )
ax.set_title('Rebound burst count: {}'.format(plot_region))
ax.set_ylabel('Mean evoked burst count')
ax.set_xticklabels(stlabels)

## Add stats ##
for (indi, sigrow), yval in zip(sigdf.iterrows(), [0.75, 0.85, 0.8, 0.9, 0.95]):
    xA = np.nonzero(np.array(list(all_states.keys())) == sigrow.A)[0][0]
    xB = np.nonzero(np.array(list(all_states.keys())) == sigrow.B)[0][0]
    ax.plot([xA, xB], [yval, yval], color='k', linewidth=2)
    ax.annotate(
        p_stars(sigrow['p-corr']), xy=(np.min([xA, xB]), yval), xycoords='data', # np.min([xA, xB]); (xA + xB)/2
        fontsize=12, fontweight='bold', ha='left', va='center'
    )
    
## Save ##
figname = '{}_{}_all_subs_states.png'.format(plot_region, plot_metric)
# fig.savefig(os.path.join(plotsdir, figname), transparent=False, dpi=300)

<IPython.core.display.Javascript object>

#### Test plotting each subject

666193: 5 states/3 regions; 669118: 3/2; 655955: 1/3; 657903: 4/1

In [15]:
all_states = {
    'awake_resting': 'o', 'awake_running': 'X',
    'psilocybin_resting': 's', 'psilocybin_running': 'P',
    'isoflurane': '^', 'urethane': 'v',
}

In [28]:
subi = '666196'
for exp_type, sdict in all_subjects_data[subi].items():
    print(exp_type)
    print(sdict['unit_metrics'].keys())
    print('')

psilocybin
dict_keys(['awake_running', 'psilocybin_resting', 'psilocybin_running'])

saline
dict_keys(['awake_resting', 'awake_running'])



In [32]:
exp_states = {
    'psilocybin': ['awake_running', 'psilocybin_resting', 'psilocybin_running'],
    'saline': ['awake_resting']
}

subject_dict = {
    'unit_metrics': {},
    'unit_zscores': {},
    'pop_fr': {},
    'trial_counts': {},
    'region_counts': {}
}
for exp_type, state_list in exp_states.items():
    for statei in state_list:
        for keyi in subject_dict.keys():
            subject_dict[keyi][statei] = all_subjects_data[subi][exp_type][keyi][statei]

In [40]:
Zlim = 5
plwin = [-0.2, 0.6]
burst_region = 'SM-TH'

num_states = len(subject_dict['unit_zscores'])
state_reg_counts = []
for statei, regdict in subject_dict['unit_zscores'].items():
    state_reg_counts.append(np.sum([1 for x in areacolors.keys() if x in subject_dict['unit_zscores']['awake_running'].keys()]))
num_regions = np.max(state_reg_counts)
if (num_states == 1) & (num_regions == 1):
    print('Not making figures for {}.'.format(subi))
pop_firing_rates = subject_dict['pop_fr']

### Unit z-scores figure ###
fig = plt.figure(figsize=(13, 6))
gs = fig.add_gridspec(ncols=1, nrows=2, height_ratios=[1, 2.5], left=0.04, right=0.98, top=0.94, bottom=0.08, hspace=0.08)
popgs = gs[0].subgridspec(ncols=5, nrows=1, wspace=0.14)
Zgs = gs[1].subgridspec(ncols=5, nrows=3, hspace=0.08, wspace=0.16)

jj = 0
for statei in all_states.keys():
    if statei not in subject_dict['unit_zscores'].keys():
        continue
    if jj == 0:
        popax = fig.add_subplot(popgs[jj])
    else:
        popax = fig.add_subplot(popgs[jj], sharey=popax)
    popax.axvline(0, color='k', alpha=0.25)
    
    ii = 0
    for regi, rcolor in areacolors.items():
        if regi not in subject_dict['unit_zscores'][statei].keys():
            continue
        popax.plot(
            pop_firing_rates[statei][regi][0], pop_firing_rates[statei][regi][1],
            color=rcolor, linewidth=1.2, alpha=0.8, label=regi
        )
        datai = subject_dict['unit_zscores'][statei][regi]
        Zax = fig.add_subplot(Zgs[ii,jj])
        Zax.imshow(
            datai[1].T, cmap='bwr', interpolation='none', aspect='auto', origin='upper', vmin=-Zlim, vmax=Zlim,
            extent=[datai[0][0], datai[0][-1], 0, datai[1].shape[0]],
        )
        Zax.axvline(0, color='k', alpha=0.25)
        Zax.set_ylabel('{} (n={:d})'.format(regi, datai[1].shape[1]))
        Zax.set_yticks([])
        Zax.set_xlim(plwin)
        if ii != subject_dict['region_counts'][statei]-1:
            Zax.set_xticklabels([])
        else:
            Zax.set_xlabel('Time from stim onset (s)')
        ii += 1
    if jj == 0:
        popax.set_ylabel('Pop. FR (Hz)')
        popax.legend()
    popax.set_xlim(plwin)
    popax.set_xticklabels([])
    popax.set_title('{} ({:d} trials)'.format(statei, subject_dict['trial_counts'][statei]))
    jj += 1
        
## Title ##
fig.text(0.12, 0.98, subi, rotation='horizontal', va='center', ha='center', fontsize=12)
## Save ##
figname = '{}_evoked_unit_zscores_allstates.png'.format(subi)
fig.savefig(os.path.join(plotsdir, 'ind_sub_evoked_units', figname), transparent=False, dpi=300)



### Metrics figure ###
locs = np.arange((num_regions + 1) * num_states).reshape((num_states, (num_regions + 1)))
metfig = plt.figure(figsize=(10, 6))
gs = metfig.add_gridspec(ncols=2, nrows=1, left=0.08, right=0.95, top=0.93, bottom=0.12, wspace=0.25)
axs = gs[0].subgridspec(ncols=1, nrows=3, hspace=0.15).subplots(sharex=True)
bxs = gs[1].subgridspec(ncols=1, nrows=2, hspace=0.15).subplots(sharex=True)

ii = 0
stlabels2 = []
for statei in all_states.keys():
    if statei not in subject_dict['unit_metrics'].keys():
        continue
    jj = 0
    for regi, rcolor in areacolors.items():
        if regi not in subject_dict['unit_metrics'][statei].keys():
            continue
        reg_metrics = subject_dict['unit_metrics'][statei][regi]
        total_units = len(reg_metrics)
        xs = np.random.normal(locs[ii,jj], 0.08, total_units)
        ## Fraction of total units that are significantly activated (excited+inhibited) ##
        sig_units = np.sum(reg_metrics['p_value'].values < sigalpha)
        axs[0].bar(locs[ii,jj], sig_units/total_units, color=rcolor)

        ## Early spike latency ##
        early_lats = reg_metrics['early_latency'].values * 1E3
        axs[1].scatter(xs, early_lats, color='k', marker='o', alpha=0.4)
        axs[1].boxplot(
            early_lats[~np.isnan(early_lats)], positions=[locs[ii,jj]], widths=[0.8], showfliers=False, 
            medianprops={'color': rcolor, 'linewidth': 2}, boxprops={'color': rcolor},
        )

        ## Late spike latency ##
        late_lats = reg_metrics['late_latency'].values * 1E3
        axs[2].scatter(xs, late_lats, color='k', marker='o', alpha=0.4)
        axs[2].boxplot(
            late_lats[~np.isnan(late_lats)], positions=[locs[ii,jj]], widths=[0.8], showfliers=False, 
            medianprops={'color': rcolor, 'linewidth': 2}, boxprops={'color': rcolor},
        )

        if regi == burst_region:
            ## Burst probability ##
            bxs[0].scatter(xs, reg_metrics['burst_prob'].values, color='k', marker='o', alpha=0.4)
            bxs[0].boxplot(
                reg_metrics['burst_prob'].values, positions=[locs[ii,jj]], widths=[0.8], showfliers=False, 
                medianprops={'color': rcolor, 'linewidth': 2}, boxprops={'color': rcolor},
            )

            ## Burst count ##
            bxs[1].scatter(xs, reg_metrics['burst_count'].values, color='k', marker='o', alpha=0.4)
            bxs[1].boxplot(
                reg_metrics['burst_count'].values, positions=[locs[ii,jj]], widths=[0.8], showfliers=False, 
                medianprops={'color': rcolor, 'linewidth': 2}, boxprops={'color': rcolor},
            )
        jj += 1
    if ii == 0:
        regleg = []
        for regi, rcolor in areacolors.items():
            reg_metrics = subject_dict['unit_metrics'][statei][regi]
            regleg.append(Patch(facecolor=rcolor, label='{} (n={:d})'.format(regi, len(reg_metrics))))
        axs[0].legend(handles=regleg)
    templabel = statei.split('_')
    if len(templabel) > 1:
        stlabels2.append(templabel[0] + '\n' + templabel[1])
    else:
        stlabels2.append(templabel[0])
    ii += 1

axs[0].set_ylabel('Fraction of total units\nw/ significant response')
axs[1].set_ylabel('First spike latency\nfrom stim onset (ms)')
axs[2].set_ylabel('Rebound spike latency\nfrom stim onset (ms)')
axs[2].set_xticks(np.mean(locs[:,:-1], axis=1))
axs[2].set_xticklabels(stlabels2)

bxs[0].set_ylabel('Burst probability ({})'.format(burst_region))
bxs[1].set_ylabel('Mean burst count ({})'.format(burst_region))
bxs[1].set_xticks(locs[:,1])
bxs[1].set_xticklabels(stlabels2)

## Title ##
metfig.text(0.5, 0.98, subi, rotation='horizontal', va='center', ha='center', fontsize=12)
## Save ##
mfigname = '{}_unit_metrics_allstates.png'.format(subi)
metfig.savefig(os.path.join(plotsdir, 'ind_sub_evoked_units', mfigname), transparent=False, dpi=300)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Make single subject example plots

This may not work correctly due to some subjects having duplicate states across recordings, see above for solution.

## Testing metrics on single subjects with plots

In [15]:
skip_states = ['recovery']
psilocybin_window = 90 * 60 # min to include as "psilocybin" -> seconds
substates = {'resting': True, 'running': False}
trial_threshold = 18 # must have more than this number of trials to be included in analysis
trial_max = 125 # limit some states that have many trials
unit_threshold = 5 # must have st least this number of units to be included in analysis
time_bin = 0.0025 # size of time bins (s) for firing rate

sig_test_window = [0.0, 0.3]
sigalpha = 0.05 # significance threshold for unit activation
burst_window = [0.075, 0.3]
early_window = [0.002, 0.025]
late_window = [0.075, 0.3]

In [16]:
subrow = subject_df.iloc[4]
print(subrow.mouse)

if not subrow.analyze:
    print('Skipping {} - {} for now, missing data.\n'.format(subrow.mouse, subrow.exp_type))

669118


Add state and stim depth info to stim_log and get stim current for analysis

In [17]:
## Load EEGexp and stim_log ##
exp = EEGexp(subrow.data_loc, preprocess=False, make_stim_csv=False)
stim_log = pd.read_csv(exp.stimulus_log_file).astype({'parameter': str})

### Get all states in experiment ###
all_sweeps = np.unique(stim_log['sweep'].values)

## Get state label for each sweep ##
sweep_state_list = []
for char in subrow.sweep_states.split(','):
    sweep_state_list.append(char)
if len(sweep_state_list) == 1:
    sweep_state_list = sweep_state_list * len(all_sweeps)
stim_log['state'] = stim_log.apply(lambda x: sweep_state_list[x.sweep], axis=1)
## Get depth label for each sweep ##
sweep_depth_list = []
for char in subrow.stim_depth.split(','):
    sweep_depth_list.append(char)
if len(sweep_depth_list) == 1:
    sweep_depth_list = sweep_depth_list * len(all_sweeps)
stim_log['stim_depth'] = stim_log.apply(lambda x: sweep_depth_list[x.sweep], axis=1)
## Get list of states and the middle current ##
states = np.unique(stim_log['state'].values)
currents = np.unique(stim_log[stim_log['stim_type'] == 'biphasic']['parameter'].values).astype(int)
if len(currents) > 1:
    ch_curr = str(currents[1])
else:
    ch_curr = str(currents[0])
    
print(states)

Experiment type: electrical stimulation
['awake' 'psilocybin']


Add area of interest tag to all-units info

In [18]:
## Load unit info ##
fn_units_info = os.path.join(exp.data_folder, 'evoked_data', 'all_units_info.csv')
if os.path.exists(fn_units_info):
    unit_info = pd.read_csv(fn_units_info)
    with open(os.path.join(exp.data_folder, 'evoked_data', 'units_event_spikes.pkl'), 'rb') as unit_file:
        all_unit_event_spikes = pickle.load(unit_file)
else:
    print('{} not found. Not analyzing this subject.'.format(fn_units_stats))
#     continue

## Make time bins for event spikes ##
bins = np.arange(all_unit_event_spikes['event_window'][0], all_unit_event_spikes['event_window'][1] + time_bin, time_bin)
timex = bins[:-1] + time_bin/2

### Add parent region to df ###
parent_list = []
for regi in unit_info['region'].values:
    parent_region = [key for key in list(areas_of_interest.keys()) if regi in areas_of_interest[key]]
    if len(parent_region) == 1:
        parent_list.append(parent_region[0])
    else:
        parent_list.append('notROI')
unit_info['parent'] = parent_list

unit_info.head()

Unnamed: 0,unit_id,probe,peak_ch,depth,spike_duration,region,CCF_AP,CCF_DV,CCF_ML,parent
0,B2,probeB,0,3700,0.467002,VAL,260,179,179,MO-TH
1,B3,probeB,1,3700,0.549414,VAL,260,179,179,MO-TH
2,B6,probeB,1,3700,0.425796,VAL,260,179,179,MO-TH
3,B9,probeB,2,3680,0.480737,VAL,260,178,179,MO-TH
4,B10,probeB,4,3660,0.480737,VAL,260,177,179,MO-TH


In [19]:
ROI_unit_info = {}
for region in areas_of_interest.keys():
    udf = unit_info[unit_info['parent'] == region]
    if len(udf) >= unit_threshold:
        ROI_unit_info[region] = udf.sort_values(by='depth').reset_index(drop=True)
        print('{}: {:d} units'.format(region, len(udf)))

MO: 7 units
MO-TH: 62 units


Get eventinds

In [20]:
state_event_inds = {}
for statei in states:
    if statei in skip_states:
        print('Skipping {}.'.format(statei))
        continue
    if statei == 'psilocybin':
        exp_meta = metadata[(
            (metadata['mouse_name'].str.contains(subrow.mouse)) &
            (metadata['exp_name'].str.contains(subrow.experiment))
        )].squeeze()
        stim_log['onset_from_inj2'] = stim_log['onset'] - float(exp_meta['Second injection time'])
        for substi, trialtest in substates.items():
            eventinds = stim_log[
                (stim_log['stim_type'] == 'biphasic') &
                (stim_log['parameter'] == ch_curr) &
                (stim_log['stim_depth'] == 'deep') &
                (stim_log['good'] == True) &
                (stim_log['resting_trial'] == trialtest) &
                (stim_log['state'] == statei) &
                (stim_log['onset_from_inj2'] < psilocybin_window)
            ].index.values
            state_event_inds[statei + '_' + substi] = eventinds
    elif statei == 'awake':
        for substi, trialtest in substates.items():
            eventinds = stim_log[
                (stim_log['stim_type'] == 'biphasic') &
                (stim_log['parameter'] == ch_curr) &
                (stim_log['stim_depth'] == 'deep') &
                (stim_log['good'] == True) &
                (stim_log['resting_trial'] == trialtest) &
                (stim_log['state'] == statei)
            ].index.values
            state_event_inds[statei + '_' + substi] = eventinds
    else:
        eventinds = stim_log[
            (stim_log['stim_type'] == 'biphasic') &
            (stim_log['parameter'] == ch_curr) &
            (stim_log['stim_depth'] == 'deep') &
            (stim_log['good'] == True) &
            (stim_log['resting_trial'] == True) &
            (stim_log['state'] == statei)
        ].index.values
        state_event_inds[statei] = eventinds
print(state_event_inds.keys())

dict_keys(['awake_resting', 'awake_running', 'psilocybin_resting', 'psilocybin_running'])


In [26]:
unit_metrics = {}
unit_zscores = {}
pop_firing_rates = {}
for statei, event_inds in state_event_inds.items():
    if len(event_inds) < trial_threshold:
        print('Only {:d} trials for {} state, not analyzing.'.format(len(event_inds), statei))
        continue
    elif len(event_inds) > trial_max:
        print('{} has {:d} trials, downsampling trials to {:d}.'.format(statei, len(event_inds), trial_max))
        event_inds = np.random.choice(event_inds, size=trial_max, replace=False)
    unit_metrics[statei] = {}
    unit_zscores[statei] = {}
    pop_firing_rates[statei] = {}
    for regi, regdf in ROI_unit_info.items():

        unit_firing_rates = np.zeros((len(timex), len(regdf)), dtype=float) * np.nan
    #     sig_evoked_units = np.zeros(len(regdf), dtype=int)
    #     burst_trials = np.zeros(len(regdf), dtype=float)
    #     early_spike_latency = np.zeros(len(regdf), dtype=float)
    #     late_spike_latency = np.zeros(len(regdf), dtype=float)

        ## Get unit event spike times ##
        reg_unit_metrics = []
        for ii, unitrow in regdf.iterrows():
            unit_event_spikes = [all_unit_event_spikes['event_spikes'][unitrow.unit_id][ei] for ei in event_inds]
            unit_event_bursts = [all_unit_event_spikes['event_bursts'][unitrow.unit_id]['times'][ei] for ei in event_inds]
            ## Get firing rates ##    
            unit_event_counts, edges = np.histogram(np.concatenate(unit_event_spikes), bins)
            unit_firing_rates[:, ii] = unit_event_counts / (time_bin * len(event_inds))


            prespikes = np.zeros(len(event_inds), dtype=int)
            postspikes = np.zeros(len(event_inds), dtype=int)
            trial_counts = np.zeros(len(event_inds), dtype=int)
            burst_counts = np.zeros(len(event_inds), dtype=int)
            earlyfirstspikes = []
            latefirstspikes = []
            for jj, uspikesi in enumerate(unit_event_spikes):
                ## Count spikes ##
                prespikes[jj] = np.sum((uspikesi >= -sig_test_window[1]) & (uspikesi <= -sig_test_window[0]))
                postspikes[jj] = np.sum((uspikesi >= sig_test_window[0]) & (uspikesi <= sig_test_window[1]))

                ## Count bursts ##
                windowbursts = np.nonzero(
                    (unit_event_bursts[jj] >= burst_window[0]) & (unit_event_bursts[jj] <= burst_window[1]))[0]
                if len(windowbursts) > 0:
                    trial_counts[jj] = 1
                burst_counts[jj] = len(np.nonzero(unit_event_bursts[jj] >= burst_window[0])[0])

                ## Find first spikes ##
                earlyspikes = np.nonzero((uspikesi >= early_window[0]) & (uspikesi <= early_window[1]))[0]
                if len(earlyspikes) > 0:
                    earlyfirstspikes.append(uspikesi[earlyspikes[0]])
                latespikes = np.nonzero((uspikesi >= late_window[0]) & (uspikesi <= late_window[1]))[0]
                if len(latespikes) > 0:
                    latefirstspikes.append(uspikesi[latespikes[0]])

            wstat, pval = stats.wilcoxon(x=postspikes, y=prespikes, zero_method='zsplit')

            reg_unit_metrics.append([
                unitrow.unit_id, pval, np.median(earlyfirstspikes), np.median(latefirstspikes),
                np.mean(trial_counts), np.mean(burst_counts)
            ])
    #         sig_evoked_units[ii] = pval < sigalpha
    #         burst_trials[ii] = np.mean(trial_counts) # fraction of trials with burst
    #         early_spike_latency[ii] = np.median(earlyfirstspikes)
    #         late_spike_latency[ii] = np.median(latefirstspikes)
        unit_metrics_df = regdf.merge(
            pd.DataFrame(reg_unit_metrics,
                         columns=['unit_id', 'p_value', 'early_latency', 'late_latency', 'burst_prob', 'burst_count']),
            on='unit_id', how='left'
        )
        unit_metrics[statei][regi] = unit_metrics_df
        unit_zscores[statei][regi] = [timex, get_zFR(unit_firing_rates, timex)]
        pop_firing_rates[statei][regi] = [timex, np.mean(unit_firing_rates, axis=1)]

Only 3 trials for awake_resting state, not analyzing.
psilocybin_running has 997 trials, downsampling trials to 125.


##### Make some plots

In [27]:
Zlim = 5
plwin = [-0.2, 0.6]

fig = plt.figure(figsize=(len(unit_zscores)*2.4, 5))
gs = fig.add_gridspec(ncols=1, nrows=2, height_ratios=[1, 2.5], left=0.08, right=0.97, top=0.9, bottom=0.1, hspace=0.1)
popaxs = gs[0].subgridspec(ncols=len(unit_zscores), nrows=1, wspace=0.14).subplots(sharex=True, sharey=True)
Zaxs = gs[1].subgridspec(ncols=len(unit_zscores), nrows=len(ROI_unit_info), hspace=0.08, wspace=0.14).subplots(sharex=True)

for jj, (statei, statedict) in enumerate(unit_zscores.items()):
    popaxs[jj].axvline(0, color='k', alpha=0.25)
    for ii, (regi, datai) in enumerate(statedict.items()):
        popaxs[jj].plot(
            pop_firing_rates[statei][regi][0], pop_firing_rates[statei][regi][1],
            color=areacolors[regi], linewidth=1.2, alpha=0.8, label=regi
        )
        imunit = Zaxs[ii,jj].imshow(
            datai[1].T, cmap='bwr', interpolation='none', aspect='auto', origin='upper', vmin=-Zlim, vmax=Zlim,
            extent=[datai[0][0], datai[0][-1], 0, datai[1].shape[0]],
        )
        Zaxs[ii,jj].axvline(0, color='k', alpha=0.25)
        Zaxs[ii,jj].set_ylabel('{} (n={:d})'.format(regi, datai[1].shape[1]))
        Zaxs[ii,jj].set_yticks([])

    Zaxs[ii,jj].set_xlim(plwin)
    Zaxs[ii,jj].set_xlabel('Time from stim onset (s)')

    popaxs[jj].set_xlim(plwin)
    popaxs[jj].set_xticklabels([])
    popaxs[jj].set_title('{} ({:d} trials)'.format(statei, len(state_event_inds[statei])))
popaxs[0].set_ylabel('Pop. FR (Hz)')
popaxs[0].legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x226bb366588>

In [28]:
locs = np.arange((len(ROI_unit_info) + 1) * len(unit_metrics)).reshape((len(unit_metrics), (len(ROI_unit_info) + 1)))
burst_region = 'MO-TH'

fig = plt.figure(figsize=(12, 8))
gs = fig.add_gridspec(ncols=2, nrows=1, left=0.15, right=0.95, top=0.9, bottom=0.08, wspace=0.25)
axs = gs[0].subgridspec(ncols=1, nrows=3, hspace=0.15).subplots(sharex=True)
bxs = gs[1].subgridspec(ncols=1, nrows=2, hspace=0.15).subplots(sharex=True)

for ii, (statei, statemets) in enumerate(unit_metrics.items()):
    for jj, (regi, reg_metrics) in enumerate(statemets.items()):
        total_units = len(reg_metrics)
        xs = np.random.normal(locs[ii,jj], 0.08, total_units)
        ## Fraction of total units that are significantly activated (excited+inhibited) ##
        sig_units = np.sum(reg_metrics['p_value'].values < sigalpha)
        axs[0].bar(locs[ii,jj], sig_units/total_units, color=areacolors[regi])
        
        ## Early spike latency ##
        early_lats = reg_metrics['early_latency'].values * 1E3
        axs[1].scatter(xs, early_lats, color='k', marker='o', alpha=0.4)
        axs[1].boxplot(
            early_lats[~np.isnan(early_lats)], positions=[locs[ii,jj]], widths=[0.8], showfliers=False, 
            medianprops={'color': areacolors[regi], 'linewidth': 2}, boxprops={'color': areacolors[regi]},
        )
        
        ## Late spike latency ##
        late_lats = reg_metrics['late_latency'].values * 1E3
        axs[2].scatter(xs, late_lats, color='k', marker='o', alpha=0.4)
        axs[2].boxplot(
            late_lats[~np.isnan(late_lats)], positions=[locs[ii,jj]], widths=[0.8], showfliers=False, 
            medianprops={'color': areacolors[regi], 'linewidth': 2}, boxprops={'color': areacolors[regi]},
        )
        
        if regi == burst_region:
            ## Burst probability ##
            bxs[0].scatter(xs, reg_metrics['burst_prob'].values, color='k', marker='o', alpha=0.4)
            bxs[0].boxplot(
                reg_metrics['burst_prob'].values, positions=[locs[ii,jj]], widths=[0.8], showfliers=False, 
                medianprops={'color': areacolors[regi], 'linewidth': 2}, boxprops={'color': areacolors[regi]},
            )
            
            ## Burst count ##
            bxs[1].scatter(xs, reg_metrics['burst_count'].values, color='k', marker='o', alpha=0.4)
            bxs[1].boxplot(
                reg_metrics['burst_count'].values, positions=[locs[ii,jj]], widths=[0.8], showfliers=False, 
                medianprops={'color': areacolors[regi], 'linewidth': 2}, boxprops={'color': areacolors[regi]},
            )
            
        
axs[0].set_ylabel('Fraction of total units\nw/ significant response')
axs[1].set_ylabel('First spike latency\nfrom stim onset (ms)')
axs[2].set_ylabel('Rebound spike latency\nfrom stim onset (ms)')
axs[2].set_xticks(np.mean(locs[:,:-1], axis=1))
axs[2].set_xticklabels(unit_metrics.keys())

regleg = []
for regi, reg_metrics in statemets.items():
    regleg.append(Patch(facecolor=areacolors[regi], label='{} (n={:d})'.format(regi, len(reg_metrics))))
axs[0].legend(handles=regleg)

bxs[0].set_ylabel('Burst probability ({})'.format(burst_region))
bxs[1].set_ylabel('Mean burst count ({})'.format(burst_region))
bxs[1].set_xticks(locs[:,1])
bxs[1].set_xticklabels(unit_metrics.keys())

<IPython.core.display.Javascript object>

[Text(1, 0, 'awake_running'),
 Text(4, 0, 'psilocybin_resting'),
 Text(7, 0, 'psilocybin_running')]

### Get states info for all subjects

In [60]:
psilocybin_window = 90 * 60 # min to include as "psilocybin" -> seconds
substates = {'resting': True, 'running': False}

subjects_states_info = []
for indi, row in subject_df.iterrows():
    if not row.analyze:
        print('Skipping {} - {} for now, missing data.\n'.format(row.mouse, row.exp_type))
        continue
    print('{}: {}'.format(row.mouse, row.experiment))
    
    ## Load EEGexp and stim_log ##
    exp = EEGexp(row.data_loc, preprocess=False, make_stim_csv=False)
    stim_log = pd.read_csv(exp.stimulus_log_file).astype({'parameter': str})
    
    ## Load unit info ##
    fn_units_info = os.path.join(exp.data_folder, 'evoked_data', 'all_units_info.csv')
    if os.path.exists(fn_units_info):
        unit_info = pd.read_csv(fn_units_info)
    else:
        print('{} not found. Not analyzing this subject.'.format(fn_units_stats))
        continue
        
    ### Add parent region to df and get unit count ###
    parent_list = []
    for regi in unit_info['region'].values:
        parent_region = [key for key in list(areas_of_interest.keys()) if regi in areas_of_interest[key]]
        if len(parent_region) == 1:
            parent_list.append(parent_region[0])
        else:
            parent_list.append('notROI')
    unit_info['parent'] = parent_list
    
    unit_counts = [len(unit_info)]
    for areai in areas_of_interest.keys():
        area_units = unit_info[unit_info['parent'] == areai].unit_id.values
        unit_counts.append(len(area_units))
        
    ### Get trial counts for all states ###
    all_sweeps = np.unique(stim_log['sweep'].values)
    
    ## Get state label for each sweep ##
    sweep_state_list = []
    for char in row.sweep_states.split(','):
        sweep_state_list.append(char)
    if len(sweep_state_list) == 1:
        sweep_state_list = sweep_state_list * len(all_sweeps)
    stim_log['state'] = stim_log.apply(lambda x: sweep_state_list[x.sweep], axis=1)
    ## Get depth label for each sweep ##
    sweep_depth_list = []
    for char in row.stim_depth.split(','):
        sweep_depth_list.append(char)
    if len(sweep_depth_list) == 1:
        sweep_depth_list = sweep_depth_list * len(all_sweeps)
    stim_log['stim_depth'] = stim_log.apply(lambda x: sweep_depth_list[x.sweep], axis=1)
    ## Get list of states and the middle current ##
    states = np.unique(stim_log['state'].values)
    currents = np.unique(stim_log[stim_log['stim_type'] == 'biphasic']['parameter'].values).astype(int)
    if len(currents) > 1:
        ch_curr = str(currents[1])
    else:
        ch_curr = str(currents[0])
        
    ## Get trials counts ##
    for statei in states:
        if statei == 'psilocybin':
            exp_meta = metadata[(
                (metadata['mouse_name'].str.contains(row.mouse)) &
                (metadata['exp_name'].str.contains(row.experiment))
            )].squeeze()
            stim_log['onset_from_inj2'] = stim_log['onset'] - float(exp_meta['Second injection time'])
            for substi, trialtest in substates.items():
                eventinds = stim_log[
                    (stim_log['stim_type'] == 'biphasic') &
                    (stim_log['parameter'] == ch_curr) &
                    (stim_log['stim_depth'] == 'deep') &
                    (stim_log['good'] == True) &
                    (stim_log['resting_trial'] == trialtest) &
                    (stim_log['state'] == statei) &
                    (stim_log['onset_from_inj2'] < psilocybin_window)
                ].index.values
                templist = [row.mouse, row.experiment, statei + '_' + substi, len(eventinds)]
                subjects_states_info.append(templist + unit_counts)
        elif statei == 'awake':
            for substi, trialtest in substates.items():
                eventinds = stim_log[
                    (stim_log['stim_type'] == 'biphasic') &
                    (stim_log['parameter'] == ch_curr) &
                    (stim_log['stim_depth'] == 'deep') &
                    (stim_log['good'] == True) &
                    (stim_log['resting_trial'] == trialtest) &
                    (stim_log['state'] == statei)
                ].index.values
                templist = [row.mouse, row.experiment, statei + '_' + substi, len(eventinds)]
                subjects_states_info.append(templist + unit_counts)
        else:
            eventinds = stim_log[
                (stim_log['stim_type'] == 'biphasic') &
                (stim_log['parameter'] == ch_curr) &
                (stim_log['stim_depth'] == 'deep') &
                (stim_log['good'] == True) &
                (stim_log['resting_trial'] == True) &
                (stim_log['state'] == statei)
            ].index.values
            templist = [row.mouse, row.experiment, statei, len(eventinds)]
            subjects_states_info.append(templist + unit_counts)
    print('')
    
subjects_states_df = pd.DataFrame(subjects_states_info, columns=[
    'mouse', 'experiment', 'state', 'trial_count', 'total_units', 'MO_units', 'MO-TH_units', 'RT_units'
])

657903: pilot_aw_psi_2023-01-13_12-18-22
Experiment type: electrical stimulation

666193: pilot_aw_psi_2023-02-16_10-55-48
This data does not contain an EEG recording.
Experiment type: electrical stimulation

Skipping 666194 - psilocybin for now, missing data.

Skipping 666196 - psilocybin for now, missing data.

669118: pilot_aw_psi_2023-03-24_09-55-33
Experiment type: electrical stimulation

Skipping 669117 - psilocybin for now, missing data.

582386: urethane_2021-07-15_11-36-58
Experiment type: electrical stimulation

638703: urethane_estim_2022-10-14_12-25-20
Experiment type: electrical stimulation

654181: urethane_vis_2022-11-23_08-30-16
Experiment type: electrical and sensory stimulation

654182: urethane_vis_2022-12-02_11-02-25
Experiment type: electrical and sensory stimulation

631037: urethane_2022-12-07_10-34-51
Experiment type: electrical stimulation

655955: urethane_2022-12-14_10-38-00
Experiment type: electrical stimulation

655956: urethane_2022-12-16_10-45-18
Experim

In [61]:
subjects_states_df.head()

Unnamed: 0,mouse,experiment,state,trial_count,total_units,MO_units,MO-TH_units,RT_units
0,657903,pilot_aw_psi_2023-01-13_12-18-22,awake_resting,100,981,0,124,0
1,657903,pilot_aw_psi_2023-01-13_12-18-22,awake_running,20,981,0,124,0
2,657903,pilot_aw_psi_2023-01-13_12-18-22,psilocybin_resting,19,981,0,124,0
3,657903,pilot_aw_psi_2023-01-13_12-18-22,psilocybin_running,101,981,0,124,0
4,666193,pilot_aw_psi_2023-02-16_10-55-48,psilocybin_resting,514,1347,147,211,0


In [62]:
trial_lim = 18

for whstate in np.unique(subjects_states_df['state'].values):
    subdf = subjects_states_df[subjects_states_df['state'] == whstate]
    subcount = np.sum(subdf['trial_count'].values >= trial_lim)
    print('{}: {:d} subjects'.format(whstate, subcount))

awake_resting: 12 subjects
awake_running: 10 subjects
isoflurane: 10 subjects
psilocybin_resting: 3 subjects
psilocybin_running: 3 subjects
recovery: 7 subjects
urethane: 8 subjects


In [22]:
subdf = subjects_states_df[subjects_states_df['mouse'] == '669118']
subdf.head()

Unnamed: 0,mouse,experiment,state,trial_count,total_units,MO_units,MO-TH_units,RT_units
6,669118,pilot_aw_psi_2023-03-24_09-55-33,awake_resting,3,668,7,62,0
7,669118,pilot_aw_psi_2023-03-24_09-55-33,awake_running,117,668,7,62,0
8,669118,pilot_aw_psi_2023-03-24_09-55-33,psilocybin_resting,19,668,7,62,0
9,669118,pilot_aw_psi_2023-03-24_09-55-33,psilocybin_running,997,668,7,62,0


In [63]:
tempdf = subjects_states_df[subjects_states_df['RT_units'] > 0]
tempdf

Unnamed: 0,mouse,experiment,state,trial_count,total_units,MO_units,MO-TH_units,RT_units
15,655955,urethane_2022-12-14_10-38-00,urethane,221,479,43,124,38
16,655956,urethane_2022-12-16_10-45-18,urethane,240,787,159,104,26
18,654181,estim_vis_2022-11-22_09-42-58,awake_resting,34,772,119,146,27
19,654181,estim_vis_2022-11-22_09-42-58,awake_running,85,772,119,146,27
20,654181,estim_vis_2022-11-22_09-42-58,isoflurane,119,772,119,146,27
55,666193,pilot_aw_2023-02-15_11-44-11,awake_resting,1072,1648,119,176,82
56,666193,pilot_aw_2023-02-15_11-44-11,awake_running,123,1648,119,176,82


In [65]:
subjects_states_df.groupby('mouse')['state'].count()

mouse
551397    4
551399    4
569062    4
569064    4
569068    4
569069    4
569073    4
571619    3
582386    1
631037    1
635397    3
638703    1
654181    4
654182    1
655955    1
655956    1
657903    4
666193    5
669118    4
Name: state, dtype: int64