In [1]:
%load_ext autoreload
%autoreload 2

from os import path
import json
import gc

import numpy as np
import pandas as pd
import scipy as sp

import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib as mpl
import seaborn as sns
from tqdm.auto import tqdm
# import pandarallel
from IPython.utils.capture import capture_output
with capture_output():
    tqdm.pandas()
#     pandarallel.pandarallel.initialize(progress_bar=True)
from mne.time_frequency import psd_array_multitaper

from tbd_eeg.data_analysis.eegutils import EEGexp
from tbd_eeg.data_analysis.Utilities.utilities import get_stim_events, find_nearest_ind

from ipympl.backend_nbagg import Canvas
Canvas.header_visible.default_value = False
%matplotlib widget

In [2]:
# Frequency bands of interest: band name -> (lower edge, upper edge).
# The edges are used as inclusive label slices on the PSD frequency index.
boi_defs = {
    'delta' : (0.8, 1.5),
    'theta' : (3, 6),
    'alpha' : (8, 13),
    'beta' : (15, 30),
    'lower_gamma' : (30, 60),
    'upper_gamma' : (60, 100)
}
# One colour per band, sampled evenly from the `brg` colormap. The spacing is
# derived from the number of bands so that adding/removing a band keeps the
# colours evenly spread (previously a hard-coded 6).
boi_colors = {
    _band : cm.brg(i/len(boi_defs)) for i, _band in enumerate(boi_defs)
}

# EEG channels that are averaged into the `hipp_avg` signal
hipp_channels = [4, 5, 7, 8, 21, 22, 24, 25]

In [3]:
def bin_spikes(spikes, bin_size_ms=1, t_start=0, t_end=1e5):
    """Mark which fixed-width time bins contain at least one spike.

    Parameters
    ----------
    spikes : array-like
        Spike times, in the same unit as `t_start` / `t_end` (the bin count
        is computed as ``(t_end - t_start) * 1000 / bin_size_ms``).
    bin_size_ms : float
        Bin width.
    t_start, t_end : float
        Edges of the binned interval; bins tile ``[t_start, t_end)``.

    Returns
    -------
    pd.Series of bool
        Indexed by the left edge of each bin; True where the bin holds one
        or more spikes.
    """
    n_bins = int((t_end - t_start) * 1000 / bin_size_ms)
    times = np.linspace(t_start, t_end, n_bins, endpoint=False)
    occupied = np.zeros(n_bins, dtype=bool)

    spikes = np.asarray(spikes, dtype=float)
    # bin 0 covers [t_start, ...), so a spike exactly at t_start belongs to it
    spikes = spikes[(spikes >= t_start) & (spikes < t_end)]
    if n_bins > 0 and spikes.size:
        bin_idx = (n_bins * (spikes - t_start) / (t_end - t_start)).astype(int)
        # floating-point rounding can push a spike just below t_end to n_bins
        np.clip(bin_idx, 0, n_bins - 1, out=bin_idx)
        occupied[bin_idx] = True
    return pd.Series(occupied, index=times, dtype=bool)

def load_spikes(rec_folder, probe, repeat=False):
    """Load stimulus-aligned, binned spike trains for one probe.

    The result is cached as a pickle; the cache is used unless `repeat` is
    True or the file does not exist.

    NOTE: besides its arguments this function reads the notebook-level
    globals `gmetadata` and `expt` to label the sweeps of the stimulus log.

    Parameters
    ----------
    rec_folder : str
        Recording folder passed to `EEGexp`.
    probe : str
        Key into `exp.ephys_params` selecting the probe.
    repeat : bool
        If True, recompute from the raw spike files even if a cache exists.

    Returns
    -------
    pd.DataFrame of bool
        Rows: MultiIndex built from the stimulus log (time relative to
        stimulus onset plus the stimulus descriptors). Columns: MultiIndex
        of unit id and unit annotations (`FS_RS`, `peak_channel` and, when
        available, `area` and `layer`).
    """
    exp = EEGexp(rec_folder, preprocess=False, make_stim_csv=False)
    stim_log = pd.read_csv(exp.stimulus_log_file)
    # label each sweep number with the brain state listed in the metadata
    # sheet, e.g. 0 -> '0_urethane'
    stim_log['sweep'] = stim_log.sweep.map(
        {
            i : f'{i}_{k}' for i, k in enumerate(
                gmetadata[
                    gmetadata.exp_name==expt
                ]['brain states'].values[0].split('/')
            )
        }
    )
    stim_log.rename_axis(index='stim_id', inplace=True)
    
    fname = (
        f'../tiny-blue-dot/zap-n-zip/sg/spikes_stim_aligned/{exp.mouse}_'
        f'{path.basename(path.dirname(exp.experiment_folder))}_{probe}.pkl')
    if path.exists(fname) and not repeat:
        spikes = pd.read_pickle(fname)
    else:
        print('Loading and preprocessing spiking data.')
        # read raw spike time data
        _spike_times = np.load(
            exp.ephys_params[probe]['spike_times'], mmap_mode='r'
        )
        _spike_clusters = np.load(
            exp.ephys_params[probe]['spike_clusters'], mmap_mode='r'
        )
        cluster_metrics = pd.read_csv(
            exp.ephys_params[probe]['cluster_metrics'], index_col=1
        ).drop('Unnamed: 0', axis=1, errors='ignore')
        cluster_groups = pd.read_csv(
            exp.ephys_params[probe]['cluster_group'],
            sep='\t', index_col=0
        )

        # rearrange into spike times for each cluster
        spike_df = pd.DataFrame(
            index=_spike_clusters, data=_spike_times, columns=['time']
        )
        spike_times = spike_df.groupby(level=0).apply(lambda g: g.values[:, 0])

        # keep only good clusters and drop 'noise'
        cluster_metrics = cluster_metrics[cluster_groups.group.isin(['good'])]
        cluster_metrics = cluster_metrics[
            (cluster_metrics.isi_viol<0.5)&(cluster_metrics.amplitude_cutoff<0.1)
        ]
        spike_times = spike_times.loc[cluster_metrics.index]

        # bin spikes into 1ms bins
        spikes = {}
        t_start = spike_times.apply(lambda x: x.min()).min().round(3)
        t_end = spike_times.apply(lambda x: x.max()).max().round(3)
        for u, t in tqdm(spike_times.items(), total=len(spike_times)):
            spikes[u] = bin_spikes(t, t_start=t_start, t_end=t_end)
        spikes = pd.concat(spikes, axis=1, names='units').rename_axis('time')

        # keep only those spikes that are in a window of interest around
        # stimulus times: the window opens 1 before each onset and spans
        # 4000 bins
        idx = stim_log.reset_index().set_index('onset').rename_axis('time')
        idx.index = idx.index - 1
        idx = idx.reindex(
            spikes.index, method='ffill', limit=1000*4
        ).reset_index().dropna()
        spikes = spikes.loc[idx.time]
        
        def _reset_index_time(df):
            # express time relative to the stimulus onset of this trial
            df['time'] = (df.time - df.time.iloc[0] - 1).round(3)
            return df
        idx = idx.groupby('stim_id').apply(_reset_index_time).drop(
            ['offset', 'duration'], axis=1
        )
        
        # set index and columns with useful information
        spikes.index = pd.MultiIndex.from_frame(idx)
        spikes = spikes.sort_index()
        spike_waveforms = np.load(exp.ephys_params[probe]['waveforms'])
        width = {}
        for u in cluster_metrics.index:
            wav = spike_waveforms[u, cluster_metrics.loc[u, 'peak_channel'], :]
            # trough-to-peak distance in samples, converted with the AP rate
            wav_duration = np.abs(np.argmin(wav)-np.argmax(wav))
            width[u] = wav_duration / exp.ephys_params[probe]['ap_sample_rate']
        cluster_metrics = cluster_metrics.join(
            pd.Series(width, name='waveform_width')
        )
        _cm_names = ['FS_RS', 'peak_channel']

        try:
            with open(exp.ephys_params[probe]['probe_info'], 'r') as f:
                areas = json.load(f)['area_ch']
            areas = pd.Series(areas, name='area')
            # raw string: '\d' in a plain literal is an invalid escape sequence
            layers = areas.str.extract(r'(\d.*)')[0].fillna('').rename('layer')
            areas = areas.str.rstrip('12/3456ab').fillna('')
            cluster_metrics['area'] = cluster_metrics.peak_channel.map(
                lambda x: areas.loc[x]
            )
            cluster_metrics['layer'] = cluster_metrics.peak_channel.map(
                lambda x: layers.loc[x]
            )
            _cm_names += ['area', 'layer']
        except Exception:
            # best effort: area labels are optional, but do not swallow
            # KeyboardInterrupt / SystemExit like a bare `except:` would
            print('Area labels not found.')
        
        cluster_metrics['FS_RS'] = cluster_metrics.waveform_width.map(
            lambda x: 'FS' if x<0.0004 else 'RS'
        )
        spikes.columns = pd.MultiIndex.from_frame(
            cluster_metrics[_cm_names].reset_index()
            .rename({'cluster_id':'unit'}, axis=1)
        )
        # remove stimulation artifact: blank every bin whose time (first,
        # sorted index level) lies in [-0.001, 0.002]. This must be a slice;
        # `spikes.loc[-0.001, 0.002]` is a (row, column) key instead.
        spikes.loc[-0.001:0.002] = False
        
        spikes.to_pickle(fname)
    return spikes

def get_mean_firing(rec_folder, probe, repeat_spikes=False):
    """Return trial-averaged firing rates for one probe, cached on disk.

    Spikes from `load_spikes` are averaged over trials sharing the same
    (`stim_type`, `parameter`, `sweep`, `time`), multiplied by 1000, rounded
    and stored as int16.

    Parameters
    ----------
    rec_folder : str
        Recording folder passed to `EEGexp`.
    probe : str
        Probe name, forwarded to `load_spikes`.
    repeat_spikes : bool
        If True, spikes are re-extracted (`load_spikes(..., repeat=True)`)
        and the firing-rate cache is rebuilt. Previously an existing cache
        file was returned regardless, so this flag had no effect.

    Returns
    -------
    pd.DataFrame of int16
    """
    exp = EEGexp(rec_folder, preprocess=False, make_stim_csv=False)
    fname = (
        f'../tiny-blue-dot/zap-n-zip/sg/spikes_stim_aligned/fr_{exp.mouse}_'
        f'{path.basename(path.dirname(exp.experiment_folder))}_{probe}.pkl')
    # only trust the cache when no recomputation was requested
    if path.exists(fname) and not repeat_spikes:
        return pd.DataFrame(pd.read_pickle(fname), dtype='int16')
    spikes = load_spikes(rec_folder, probe, repeat=repeat_spikes)
    fr = pd.DataFrame((
        spikes.groupby(['stim_type', 'parameter', 'sweep', 'time']).mean()*1000
    ).round(), dtype='int16')
    fr.to_pickle(fname)
    return fr

# Select an experiment to analyze

In [4]:
# Access the Google sheet that holds the experiment metadata.
# Setting up the permissions:
# 1. install gspread (pip install gspread / conda install gspread)
# 2. copy the service_account.json file to '~/.config/gspread/service_account.json'
# 3. run the following:
import gspread
_gc = gspread.service_account() # need a key file to access the account (step 2)
_sh = _gc.open('Zap_Zip-log_exp') # open the spreadsheet
_df = pd.DataFrame(_sh.sheet1.get()) # load the first worksheet
# promote the first row of the sheet to column names
gmetadata = _df.T.set_index(0).T # put it in a nicely formatted dataframe

In [5]:
# experiments whose 'brain states' entry contains "ur" (case-insensitive)
gmetadata.loc[gmetadata['brain states'].str.contains('ur', case=False)]

Unnamed: 0,mouse_name,exp_name,brain states,stimulation,visual_stim,audio_stim,ISI (sec),stimulus duration (msec),Current (uA),Cortical Area stimulation,N trials per stimulus,EEG bad_channels,Npx,Units Sorted (X),Brain slices (X),Pupil tracking pre-processing,Brain areas assignment,"CCF coordinates stim electrode (surface,tip)","CCF area stim electrode (surface,tip)",Notes
24,mouse547868,estim_vis_2020-12-04_11-24-26,awake/UR/UR,electrical/sensory,white,,[3.5 4.5],0.2/250,20/40/60,M2,120/100,,A,X,X,,,,,during the 2nd UR we delivered only electrical...
45,mouse582386,urethane_2021-07-15_11-36-58,urethane/urethane/urethane/urethane/urethane,electrical,,,[3.5 4.5],0.2,60,M2,120,,"F,B",X,X,,X,not visible,not visible,EEG filter at 0.1Hz. stim electrode with 300 u...
51,mouse586469,urethane_2021-10-08_12-21-28,urethane/urethane/urethane,electrical,,,[3.5 4.5],0.2,60,SS-cortex,120,,"B,C",X,Tissuecyte,,,,,EEG filter at 0.1Hz. Urethane IP. Signals look...
53,mouse590478,urethane_2021-10-29_11-51-38,urethane/urethane,electrical,,,[3.5 4.5],0.2,70/80/100,SS-cortex,120,,B,X,Tissuecyte,,,,,EEG filter at 0.1Hz. Signals look OK. Pupil vi...
55,mouse590480,urethane_vis_2021-11-12_11-14-43,urethane/urethane/urethane,electrical/sensory,white,,[3.5 4.5],0.2/250,30/50/70,SS-cortex,120,8101322.0,"F,B",X,Tissuecyte,,,,,EEG filter at 0.1Hz. Visual stimuli delivered ...
57,mouse590481,urethane_vis_2021-11-19_10-59-04,urethane/urethane/urethane/urethane,electrical/sensory,white,,[3.5 4.5],0.2/250,30/50/70,SS-cortex,120,158101314.0,"F,B,C",,Tissuecyte,,,,,EEG filter at 0.1Hz. IP injection. Visual stim...


In [6]:
mouse = 'mouse582386'
# Notes for mouse582386:
# - the last block of URE is burst suppression (very high URE);
# - surgical level are blocks 3, 4a
# - fat mouse
# - blocks 1 and 2 should have oscillations based on pupil
# - blocks 3 and 4 should be slower
# - block 5 is burst suppression
# - rereference with saline channels

# first experiment of this mouse whose name contains "ur" (case-insensitive)
expt = gmetadata[gmetadata.mouse_name==mouse].exp_name
expt = expt[expt.str.contains('ur', case=False)].values[0]

# The bad-channel cell is a comma-separated list of channel numbers. It may
# also be empty or missing altogether (None/NaN), in which case there are no
# bad channels; blank items from stray commas are skipped.
_bad_raw = gmetadata[(gmetadata.mouse_name==mouse)&(gmetadata.exp_name==expt)]['EEG bad_channels'].iloc[0]
if pd.isna(_bad_raw):
    _bad_raw = ''
bad_channels = [int(tok) for tok in str(_bad_raw).split(',') if tok.strip()]

# Load LFP, EEG and spiking data

In [7]:
rec_folder = f'../tiny-blue-dot/zap-n-zip/EEG_exp/{mouse}/{expt}/experiment1/recording1/'
exp = EEGexp(rec_folder, preprocess=False, make_stim_csv=False)

# load LFP data
probes = [name for name in exp.ephys_params.keys() if 'probe' in name]
sample_rate_lfp = exp.ephys_params[probes[0]]['lfp_sample_rate']

# cache files for the preprocessed LFP share a common prefix
_lfp_cache_prefix = (
    f'../tiny-blue-dot/zap-n-zip/sg/lfp_preprocessed/{exp.mouse}_'
    f'{path.basename(path.dirname(exp.experiment_folder))}'
)
fn_mean_lfp = f'{_lfp_cache_prefix}_by_area.pkl'
fn_lfp = f'{_lfp_cache_prefix}_all_channels.pkl'
# True always rebuilds the preprocessed LFP from the raw recording (what the
# former hard-coded `if 1:` did). Set to False to reuse the cached pickles
# whenever both of them exist.
recompute_lfp = True
if recompute_lfp or not (path.exists(fn_mean_lfp) and path.exists(fn_lfp)):
    lfp, lfp_by_area, timestamps = {}, {}, {}
    _idx = None
    for probe in tqdm(probes, desc='loading LFPs'):
        # memory-map the flat int16 file and view it as (samples, channels)
        lfp[probe] = np.memmap(
            exp.ephys_params[probe]['lfp_continuous'],
            dtype='int16', mode='r'
        )
        lfp[probe] = pd.DataFrame(np.reshape(lfp[probe], (
            int(lfp[probe].size/exp.ephys_params[probe]['num_chs']),
            exp.ephys_params[probe]['num_chs']
        )), index=np.load(
            exp.ephys_params[probe]['lfp_timestamps']
        ))

        # add area and layers
        with open(exp.ephys_params[probe]['probe_info'], 'r') as f:
            data = json.load(f)
        areas = pd.Series(data['area_ch'], name='area')
        # raw string: '\d' in a plain literal is an invalid escape sequence
        layers = areas.str.extract(r'(\d.*)')[0].fillna('').rename('layer')
        areas = areas.str.rstrip('12/3456ab').fillna('')

        cols = lfp[probe].columns.rename('channel').to_frame()
        cols['area'] = areas
        cols['layer'] = layers
        # channels from the surface channel upwards are labelled saline,
        # and from the air channel upwards air
        cols.loc[data['surface_channel']:, 'area'] = 'saline'
        cols.loc[data['air_channel']:, 'area'] = 'air'
        cols['offset'] = data['offset']

        lfp[probe].columns = pd.MultiIndex.from_frame(
            cols.drop('offset', axis=1)
        )

        # rereference to saline and compute mean over areas
        # (keep every 5th sample and put all probes on one shared time grid)
        _lfp = lfp[probe][::5]
        if _idx is None:
            _idx = np.arange(
                _lfp.index[0]-_lfp.index[0]%0.002,
                _lfp.index[-1], 5/sample_rate_lfp
            )
        _lfp = _lfp.reindex(_idx, method='nearest')
        # reference = mean of the last 6 channels labelled saline
        _top_saline_channels = _lfp.columns.get_level_values('channel')[
            _lfp.columns.get_level_values('area').isin(['saline'])
        ][-6:]

        _lfp = (_lfp.T - _lfp[_top_saline_channels].mean(1)).T
        lfp_by_area[probe] = (
            _lfp.groupby('area', axis=1)
            .mean().ffill().bfill().astype('int16')
        )
        # keep every 4th channel of the rereferenced signal
        lfp[probe] = _lfp.iloc[:, ::4].ffill().bfill().astype('int16')
    gc.collect()
    lfp_by_area = pd.concat(
        lfp_by_area, axis=1,
        names=['probe']+lfp_by_area[probe].columns.names
    ).ffill().bfill().astype('int16')
    lfp_by_area.to_pickle(fn_mean_lfp)
    lfp = pd.concat(
        lfp, axis=1, names=['probe']+lfp[probe].columns.names
    ).ffill().bfill().astype('int16')
    lfp.to_pickle(fn_lfp)
else:
    lfp_by_area = pd.read_pickle(fn_mean_lfp)
    lfp = pd.read_pickle(fn_lfp)
# the stored LFP keeps every 5th sample, so its rate is a fifth of the raw one
sample_rate_lfp = sample_rate_lfp / 5

# load EEG data, without the channels flagged as bad
sample_rate_eeg = exp.ephys_params['EEG']['sample_rate']
eegdata = exp.load_eegdata(return_type='pd').drop(columns=bad_channels)

# add common average and hippocampal average signals
_hipp_good = [ch for ch in hipp_channels if ch not in bad_channels]
eegdata['common_avg'] = eegdata.mean(axis=1)
eegdata['hipp_avg'] = eegdata[_hipp_good].mean(axis=1)

# loading the stimulus table
stim_log = pd.read_csv(exp.stimulus_log_file)

# label each sweep number with its brain state from the metadata sheet,
# e.g. 0 -> '0_urethane'
_brain_states = gmetadata[
    gmetadata.exp_name==expt
]['brain states'].values[0].split('/')
_sweep_labels = {i : f'{i}_{state}' for i, state in enumerate(_brain_states)}
stim_log['sweep'] = stim_log.sweep.map(_sweep_labels)
stim_log = stim_log.rename_axis(index='stim_id')

# load spikes (all probes combined), using the cached pickle when present
_spikes_prefix = (
    f'../tiny-blue-dot/zap-n-zip/sg/spikes_stim_aligned/{exp.mouse}_'
    f'{path.basename(path.dirname(exp.experiment_folder))}'
)
fn_spikes = f'{_spikes_prefix}_all.pkl'
if not path.exists(fn_spikes):
    spikes = {}
    for probe in probes:
        spikes[probe] = load_spikes(rec_folder, probe, repeat=False)
    # stack the probes side by side under an extra 'probe' column level
    _column_levels = ['probe'] + list(spikes[probe].columns.names)
    spikes = pd.concat(spikes, axis=1, names=_column_levels)
    spikes = spikes.swaplevel('time', 'sweep')
    spikes = spikes.sort_index()
    spikes.to_pickle(fn_spikes)
else:
    spikes = pd.read_pickle(fn_spikes)

Experiment type: electrical stimulation
SomnoSuite log file not found.


HBox(children=(FloatProgress(value=0.0, description='loading LFPs', max=2.0, style=ProgressStyle(description_w…




# Window eegdata and extract band powers

In [8]:
def window_ts_dataframe(df, sample_rate, pre_win_size=2.5):
    """Cut a continuous recording into pre-stimulus windows.

    Uses the notebook-level `stim_log` for the stimulus onsets. Every sample
    of `df` that falls inside a window opening `pre_win_size` before an onset
    and lasting `int(sample_rate*4)` samples is tagged with that stimulus;
    only samples with relative time up to -0.1 are returned.

    Parameters
    ----------
    df : pd.DataFrame
        Time-indexed signals (one column per channel).
    sample_rate : float
        Sampling rate of `df`.
    pre_win_size : float
        How long before stimulus onset the window opens.

    Returns
    -------
    pd.DataFrame
        Same columns as `df`; the index is a sorted MultiIndex built from
        the stimulus log columns, with 'time' (relative to onset) and
        'sweep' swapped.
    """
    n_window_samples = int(sample_rate*4)

    def _make_time_relative(trial):
        # time of each sample relative to the stimulus onset of its trial
        trial['time'] = (
            trial.time - trial.time.iloc[0] - pre_win_size
        ).round(4)
        return trial

    # stimulus table indexed by window start, forward-filled over the samples
    events = stim_log.reset_index().set_index('onset').rename_axis('time')
    events.index = events.index - pre_win_size
    events = events.reindex(
        df.index.rename('time'), method='ffill', limit=n_window_samples
    )
    events = events.reset_index().dropna()
    windowed = df.loc[events.time].copy()

    events = events.groupby('stim_id').apply(_make_time_relative)
    events = events.drop(['offset', 'duration'], axis=1)

    # set index and columns with useful information
    windowed.index = pd.MultiIndex.from_frame(events)
    windowed = windowed.sort_index().loc[:-0.1]
    return windowed.swaplevel('time', 'sweep').sort_index()

In [10]:
# pre-stimulus windows of the EEG, the per-channel LFP and the per-area LFP
eegpre = window_ts_dataframe(eegdata, sample_rate=sample_rate_eeg)
lfppre = window_ts_dataframe(lfp, sample_rate=sample_rate_lfp)
lfpmeanpre = window_ts_dataframe(lfp_by_area, sample_rate=sample_rate_lfp)

In [29]:
# One spectrogram segment per pre-stimulus window, so that every image column
# corresponds to one stimulus. The segment length is measured from the data
# (as in the LFP spectrogram cell) rather than hard-coded as 6001 samples,
# which only holds for one particular sample rate and window length.
_nperseg = int(eegpre.groupby('stim_id').size().median())
for channel in ['common_avg', 'hipp_avg']:
    f, (ax_sig, ax) = plt.subplots(
        2, 1, figsize=(8, 4), sharex=True, tight_layout=True,
        gridspec_kw=dict(height_ratios=[1, 4])
    )
    _f, _t, _Sxx = sp.signal.spectrogram(
        eegpre[channel], fs=sample_rate_eeg, nperseg=_nperseg, noverlap=0
    )
    # smooth along the frequency axis with a gaussian window
    _Sxx = pd.DataFrame(_Sxx, index=_f).rolling(
        10, center=True, win_type='gaussian'
    ).mean(std=6).ffill().bfill()
    # weight each row by its frequency, then show on a log scale
    _Sxx = (_Sxx.T*_f).T
    _Sxx = np.log10(_Sxx)

    ax.imshow(
        _Sxx[:100], aspect='auto', extent=[
            0, eegpre.index.levels[1].shape[0], _f[0], 100
        ],
        origin='lower', interpolation='none'#, cmap=cm.copper
    )
    ax.set_ylim(0, 100)
    ax.set_xlabel('stim id')
    ax.set_ylabel('frequency (Hz)')

    # raw trace on top, stretched over the same stim-id axis
    ax_sig.plot(
        np.linspace(
            0, eegpre.index.levels[1].shape[0], len(eegpre)
        ), eegpre[channel].values
    )
    ax_sig.set_title(channel)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [28]:
# same figure as for the EEG, for two area-averaged LFP signals of probe B
for channel in ['CA', 'DG-mo']:
    df = lfpmeanpre['probeB']
    _n_stim = df.index.levels[1].shape[0]

    f, (ax_sig, ax) = plt.subplots(
        2, 1, figsize=(8, 4), sharex=True, tight_layout=True,
        gridspec_kw={'height_ratios': [1, 4]}
    )

    # one segment per stimulus window (median window length in samples)
    _f, _t, _Sxx = sp.signal.spectrogram(
        df[channel], fs=sample_rate_lfp, noverlap=0,
        nperseg=int(df.groupby('stim_id').size().median())
    )
    # smooth along the frequency axis, weight by frequency, take the log
    _Sxx = pd.DataFrame(_Sxx, index=_f)
    _Sxx = _Sxx.rolling(
        10, center=True, win_type='gaussian'
    ).mean(std=6).ffill().bfill()
    _Sxx = np.log10(_Sxx.mul(_f, axis=0))

    ax.imshow(
        _Sxx[:100], aspect='auto', extent=[0, _n_stim, _f[0], 100],
        origin='lower', interpolation='none'
    )
    ax.set_ylim(0, 100)
    ax.set_xlabel('stim id')
    ax.set_ylabel('frequency (Hz)')

    # raw trace on top, stretched over the same stim-id axis
    ax_sig.plot(np.linspace(0, _n_stim, len(df)), df[channel].values)
    ax_sig.set_title(channel)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Extract band powers for each window

In [11]:
def get_band_powers(df, sample_rate=None):
    """Integrate the multitaper power spectrum of each column over each band.

    Parameters
    ----------
    df : pd.DataFrame
        Samples in rows, one column per channel.
    sample_rate : float, optional
        Sampling rate of `df`. Defaults to the notebook-level
        `sample_rate_eeg`, which is what this function always used before.

    Returns
    -------
    pd.DataFrame
        Rows: one per band in `boi_defs` plus 'total' (integral over the
        whole 0-100 spectrum), index named 'band'. Columns: those of `df`.
    """
    if sample_rate is None:
        sample_rate = sample_rate_eeg
    # `simps` was renamed to `simpson` and later removed from SciPy; use
    # whichever the installed version provides
    _integrate = getattr(sp.integrate, 'simpson', None) or sp.integrate.simps

    def _integrate_columns(psd):
        # integral of each column over the frequency index
        return psd.apply(
            lambda col: _integrate(col.values, x=col.index.values)
        )

    with capture_output():
        _psd, _freqs = psd_array_multitaper(
            df.T, sample_rate, 0, 100, adaptive=True,
            low_bias=False, normalization='full', n_jobs=6, verbose=0
        )
    _psd = pd.DataFrame(_psd.T, index=_freqs, columns=df.columns)
    
    _integrated_power = {}
    for _band, _idx in boi_defs.items():
        # label slice on the frequency index, both band edges inclusive
        _integrated_power[_band] = _integrate_columns(
            _psd.loc[_idx[0]:_idx[1]]
        )
    _integrated_power = pd.concat(_integrated_power, axis=1)
    _integrated_power['total'] = _integrate_columns(_psd)
    _integrated_power = _integrated_power.T
    return _integrated_power.rename_axis('band')

In [12]:
# Band powers of a single example trial. The trial is selected explicitly
# (first sweep / stim_id / stim_type in `eegpre`) instead of relying on a
# `df` variable left over from whichever cell happened to run before.
_example_key = eegpre.index[0][:3]
_example_trial = eegpre.loc[_example_key]  # indexed by (parameter, time)
_band_power_example = get_band_powers(
    # bring time to the front to drop samples after -0.005, then restore
    _example_trial.swaplevel().sort_index()[:-0.005].swaplevel().sort_index()
)
f, ax = plt.subplots(figsize=(4, 3), tight_layout=True)
# one line per channel across the bands
_band_power_example.drop('total').plot(ax=ax, cmap=cm.winter)
ax.set_yscale('log', nonpositive='mask')
ax.set_xlabel('band')
ax.set_ylabel('power')
ax.set_xticklabels([
    s.get_text().replace('_', '\n') for s in ax.get_xticklabels()
])
ax.get_legend().set_visible(False)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
# keep samples up to -0.005 on the time level, then restore the level order
_eegpre_trimmed = (
    eegpre.swaplevel('sweep', 'time').sort_index()[:-0.005]
    .swaplevel('sweep', 'time').sort_index()
)
# band powers of every trial (one group per combination of the first four
# index levels)
band_powers = _eegpre_trimmed.groupby(
    level=[0, 1, 2, 3]
).progress_apply(get_band_powers)

# each band as a fraction of the trial's total power
_total_power = band_powers.xs('total', level='band')
band_powers_normalized = (band_powers.T / _total_power.T).T
band_powers

HBox(children=(FloatProgress(value=0.0, max=1320.0), HTML(value='')))




Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,0,1,2,3,4,5,6,7,9,11,12,14,15,16,17,18,19,20,21,23,24,25,26,27,28,29,common_avg,hipp_avg
sweep,stim_id,stim_type,parameter,band,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1
0_urethane,0.0,biphasic,30,delta,682.340484,309.171335,267.599297,1704.966184,1680.195590,680.739846,1989.831959,1855.591978,2269.164002,2946.204641,2971.491359,2936.855222,3402.987913,3488.350615,3025.528248,3239.876258,2638.084657,2156.193904,1381.887434,1452.003659,833.942551,1105.694831,1285.276527,171.029450,235.037759,671.654111,1470.596352,1175.287980
0_urethane,0.0,biphasic,30,theta,105.020284,249.200538,495.428912,329.173257,600.732684,650.691065,528.198924,678.902954,979.785653,2032.573363,2133.721798,1791.094951,2411.259353,2708.505159,2329.385411,2481.680306,1475.756304,1145.462080,958.718247,399.758353,877.228903,489.612560,365.890844,416.145680,256.385436,111.507197,721.962694,630.514584
0_urethane,0.0,biphasic,30,alpha,85.488391,92.549053,130.978533,238.084931,275.837812,191.559656,346.216141,346.242936,429.313263,365.135661,285.647507,287.785820,282.721448,292.891630,298.942901,444.581664,460.358124,615.668040,272.948787,257.277370,230.435277,208.708898,214.357417,138.372111,119.544831,73.654938,182.502799,220.735340
0_urethane,0.0,biphasic,30,beta,71.617851,76.210112,84.555923,240.499161,271.869532,156.693553,338.981510,331.067961,432.298755,461.426148,383.309522,294.308567,325.601275,394.488560,398.547388,565.152892,494.279495,555.681799,249.515774,240.669512,206.847032,231.722611,222.391173,94.817882,131.178466,157.474119,157.125543,170.483562
0_urethane,0.0,biphasic,30,lower_gamma,66.850160,37.833797,34.924869,236.197362,141.344918,63.164495,341.990837,206.071387,344.045374,287.250420,163.971037,112.590874,156.753412,211.041734,194.680363,359.238979,365.623664,631.268822,143.782908,277.097074,126.378418,203.990382,235.858910,50.862250,81.487059,126.687701,80.502746,87.937672
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2_urethane,1199.0,biphasic,70,alpha,7.712237,5.535935,4.722593,14.558271,13.726542,8.294159,16.789392,16.611141,16.892040,11.204561,9.461690,8.957961,7.078797,7.933394,8.897264,8.843037,10.872814,11.654896,10.883342,12.117659,9.457890,9.805517,11.338768,4.405575,4.918955,5.694297,7.877950,10.170193
2_urethane,1199.0,biphasic,70,beta,8.310022,4.885038,5.877639,15.219009,12.543838,7.577812,13.389755,11.737743,10.955361,8.879941,7.290882,6.681587,8.709047,8.646542,8.033354,10.038217,10.216328,11.503264,9.967003,16.852974,9.533951,15.633393,16.854346,5.439571,7.631909,15.727387,6.340927,9.519719
2_urethane,1199.0,biphasic,70,lower_gamma,9.675642,6.084538,6.463249,17.839127,17.408330,8.354885,12.093148,13.393306,10.351327,11.663050,9.639155,12.088107,13.006222,13.430225,10.212087,13.122550,10.180056,11.808334,10.892836,18.725175,10.647627,18.657138,19.258339,7.568619,10.446113,19.436775,6.121847,9.880834
2_urethane,1199.0,biphasic,70,upper_gamma,4.127422,2.300270,1.968971,5.303869,8.018018,3.304278,4.486275,5.547777,4.619790,3.828436,3.058788,6.112576,2.918587,3.266037,2.797084,3.581620,3.424603,4.484611,2.838788,5.132318,2.840058,4.520449,5.044008,2.542893,3.264479,6.246657,1.701282,2.094925


## Characterize distribution of band powers across trials

In [14]:
electrode_id = 'hipp_avg' # choose an electrode to plot metrics
bois = ['delta', 'theta', 'beta', 'upper_gamma']
stim_type = 'biphasic'

# per band: powers of the chosen electrode, one column per (sweep, parameter)
_electrode_powers = band_powers[electrode_id].xs(stim_type, level='stim_type')
dfs = {
    boi : _electrode_powers.xs(boi, level='band').unstack(
        ['sweep', 'parameter']
    ) for boi in bois
}

sweeps = dfs[bois[0]].columns.remove_unused_levels().levels[0]
_n_rows, _n_cols = len(sweeps), len(dfs)

f = plt.figure(figsize=(3*_n_cols, 2.2*_n_rows), tight_layout=True)
f.suptitle(f'electrode {electrode_id}')
gs = plt.GridSpec(_n_rows, _n_cols, figure=f)
axes = []

# grid of histograms: one row per sweep, one column per band
for i, sw in enumerate(sweeps):
    axes.append([])
    for j, boi in enumerate(dfs):
        # every column shares its x axis with the panel in the first row
        ax = f.add_subplot(gs[i, j], sharex=axes[0][j] if i>0 else None)
        axes[i].append(ax)
        sns.histplot(
            dfs[boi][sw], common_bins=True, common_norm=True,
            ax=ax, element='step'
        )
        _legend = ax.get_legend()
        for t in _legend.texts:
            t.set_fontsize(8)
        _legend.get_title().set_fontsize(9)
        # x labels only on the bottom row, y labels only on the first column
        if i==_n_rows-1:
            ax.set_xlabel(f'{boi} power')
        else:
            ax.tick_params(labelbottom=False)
        ax.set_ylabel(f'{sw}\ncount' if j==0 else '')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# same as above figure, but with normalized band powers

# per band: normalized powers of the chosen electrode, one column per
# (sweep, parameter)
_electrode_powers_norm = band_powers_normalized[electrode_id].xs(
    stim_type, level='stim_type'
)
dfs = {
    boi : _electrode_powers_norm.xs(boi, level='band').unstack(
        ['sweep', 'parameter']
    ) for boi in bois
}

sweeps = dfs[bois[0]].columns.remove_unused_levels().levels[0]
_n_rows, _n_cols = len(sweeps), len(dfs)

f = plt.figure(figsize=(3*_n_cols, 2.2*_n_rows), tight_layout=True)
f.suptitle(f'electrode {electrode_id}')
gs = plt.GridSpec(_n_rows, _n_cols, figure=f)
axes = []

# grid of histograms: one row per sweep, one column per band
for i, sw in enumerate(sweeps):
    axes.append([])
    for j, boi in enumerate(dfs):
        # every column shares its x axis with the panel in the first row
        ax = f.add_subplot(gs[i, j], sharex=axes[0][j] if i>0 else None)
        axes[i].append(ax)
        sns.histplot(
            dfs[boi][sw], common_bins=True, common_norm=True,
            ax=ax, element='step'
        )
        _legend = ax.get_legend()
        for t in _legend.texts:
            t.set_fontsize(8)
        _legend.get_title().set_fontsize(9)
        # x labels only on the bottom row, y labels only on the first column
        if i==_n_rows-1:
            ax.set_xlabel(f'{boi} power')
        else:
            ax.tick_params(labelbottom=False)
        ax.set_ylabel(f'{sw}\ncount' if j==0 else '')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:
# band power of one channel against another, one point per trial
x_col = 'common_avg'
y_col = 'hipp_avg'
f, ax = plt.subplots(figsize=(4, 3), tight_layout=True)
for i, boi in enumerate(bois):
    _one_band = band_powers.xs(boi, level='band')
    _one_band.plot.scatter(
        x=x_col, y=y_col, ax=ax, label=boi,
        edgecolor=f'C{i}', color='white', alpha=0.2
    )
ax.set(
    xlabel=f'channel {x_col}', ylabel=f'channel {y_col}',
    xscale='log', yscale='log'
)
ax.legend(fontsize=9)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
# same as above figure but with normalized band powers
f, ax = plt.subplots(figsize=(4, 3), tight_layout=True)
for i, boi in enumerate(bois):
    band_powers_normalized.xs(boi, level='band').plot.scatter(
        x=x_col, y=y_col, ax=ax, label=boi,
        edgecolor=f'C{i}', color='white', alpha=0.2
    )
ax.set_xlabel(f'channel {x_col}')
ax.set_ylabel(f'channel {y_col}')
ax.legend(fontsize=9)
ax.set_xscale('log')
ax.set_yscale('log')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
# channel-by-channel correlation of band power across trials, one panel per band
_n_bands = len(bois)
f, axes = plt.subplots(
    2, _n_bands//2, figsize=(2*_n_bands/2, _n_bands),
    sharey=True, sharex=True
)
axes = axes.flatten()

for i, (ax, boi) in enumerate(zip(axes, bois)):
    _corr = band_powers.xs(boi, level='band').corr()
    im = ax.imshow(_corr, aspect='auto', vmin=-1, vmax=1, cmap=cm.RdBu)
    ax.set_title(boi, fontsize=9)
    # axis labels only on the bottom row and the left column
    if i>=_n_bands/2:
        ax.set_xlabel('channel', fontsize=9)
    if i%2==0:
        ax.set_ylabel('channel', fontsize=9)
    ax.tick_params(labelsize=7)

# shared colorbar in its own axes to the right of the panels
f.subplots_adjust(right=0.8)
cax = f.add_axes([0.85, 0.15, 0.05, 0.7])
f.colorbar(im, cax=cax)
cax.tick_params(labelsize=7);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
# same as above figure, but with normalized band powers
f, axes = plt.subplots(
    2, len(bois)//2, figsize=(2*len(bois)/2, len(bois)),
    sharey=True, sharex=True
)
axes = axes.flatten()

for i, (ax, boi) in enumerate(zip(axes, bois)):
    im = ax.imshow(
        band_powers_normalized.xs(boi, level='band').corr(),
        aspect='auto', vmin=-1, vmax=1, cmap=cm.RdBu
    )
    ax.set_title(boi, fontsize=9)
    if i>=len(bois)/2:
        ax.set_xlabel('channel', fontsize=9)
    if i%2==0:
        ax.set_ylabel('channel', fontsize=9)
    ax.tick_params(labelsize=7)

f.subplots_adjust(right=0.8)
cax = f.add_axes([0.85, 0.15, 0.05, 0.7])
f.colorbar(im, cax=cax)
cax.tick_params(labelsize=7);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Characterize responses for different band powers

In [20]:
# set up parameters used for the entire section
boi = 'dot'
electrode_id = 'common_avg'
parameter = '50'

powers = band_powers[electrode_id].unstack('band')
powers['dot'] = np.log(powers['theta'] / powers['delta'])
powers_discrete = powers.apply(pd.qcut, q=4, labels=range(4))
powers_normalized = band_powers_normalized[electrode_id].unstack('band')
powers_normalized['dot'] = np.log(powers['theta'] / powers['delta'])
powers_discrete_normalized = powers_normalized.drop('total', axis=1).apply(
    pd.qcut, q=4, labels=range(4)
)
display(powers)
display(powers_discrete)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,band,delta,theta,alpha,beta,lower_gamma,upper_gamma,total,dot
sweep,stim_id,stim_type,parameter,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
0_urethane,0.0,biphasic,30,1470.596352,721.962694,182.502799,157.125543,80.502746,16.005655,12860.676990,-0.711450
0_urethane,1.0,biphasic,30,1900.746031,1294.625665,226.535842,159.732198,104.533968,14.978256,16660.113679,-0.384025
0_urethane,2.0,biphasic,50,1692.181727,642.253963,207.313954,93.771877,78.511579,11.189472,12952.591838,-0.968790
0_urethane,3.0,biphasic,70,703.903899,513.547132,276.911495,92.471973,47.094202,11.068056,5474.095110,-0.315300
0_urethane,4.0,biphasic,70,661.265054,335.855201,98.178927,53.759791,56.441089,9.729902,4347.142710,-0.677475
...,...,...,...,...,...,...,...,...,...,...,...
2_urethane,1195.0,biphasic,70,4.334161,9.599387,5.932091,7.607473,7.075072,1.669324,66.309922,0.795171
2_urethane,1196.0,biphasic,70,3.003181,14.429327,5.936746,7.037119,5.671071,2.062488,73.835960,1.569591
2_urethane,1197.0,biphasic,70,2.257633,11.953749,7.548113,8.733232,6.196987,2.013443,63.115225,1.666728
2_urethane,1198.0,biphasic,30,2.996596,6.690662,7.543362,7.875894,6.903180,1.604043,60.682744,0.803236


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,band,delta,theta,alpha,beta,lower_gamma,upper_gamma,total,dot
sweep,stim_id,stim_type,parameter,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
0_urethane,0.0,biphasic,30,3,2,3,3,3,3,3,1
0_urethane,1.0,biphasic,30,3,3,3,3,3,3,3,1
0_urethane,2.0,biphasic,50,3,2,3,2,3,2,3,0
0_urethane,3.0,biphasic,70,1,2,3,2,1,2,1,1
0_urethane,4.0,biphasic,70,1,1,2,1,2,2,1,1
...,...,...,...,...,...,...,...,...,...,...,...
2_urethane,1195.0,biphasic,70,0,0,0,0,0,0,0,3
2_urethane,1196.0,biphasic,70,0,0,0,0,0,0,0,3
2_urethane,1197.0,biphasic,70,0,0,0,0,0,0,0,3
2_urethane,1198.0,biphasic,30,0,0,0,0,0,0,0,3


In [89]:
# distribution of the log theta/delta ratio (left) and its 10-trial rolling
# mean against stim_id (right)
_dot = powers['dot']
f, (ax, ax2) = plt.subplots(
    1, 2, figsize=(8, 2.4), tight_layout=True,
    gridspec_kw={'width_ratios': [1, 4]}
)
sns.histplot(_dot, ax=ax, bins=80)
ax.set_xlabel('log (4 Hz / 1 Hz)')

_dot_by_stim = _dot.droplevel(['sweep', 'stim_type', 'parameter'])
_dot_by_stim.rolling(10).mean().sort_index().plot(ax=ax2)
ax2.set_ylabel('log (4 Hz / 1 Hz)')
f.suptitle(electrode_id, fontsize=12);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# TODO: plot the spectrogram to see changes in power

## Mean responses of EEG channels wrt state

In [22]:
post_win_size = 2
idx = stim_log.reset_index().set_index('onset').rename_axis('time')
idx.index = idx.index - 4 + post_win_size
idx = idx.reindex(
    eegdata.index.rename('time'),
    method='ffill', limit=sample_rate_eeg*4
).reset_index().dropna()
eegpost = eegdata.loc[idx.time].copy()

def _reset_index_time(df):
    """Return a copy of ``df`` whose ``time`` column is re-referenced per group.

    The first ``time`` value of the frame is subtracted from every row and the
    result is shifted by ``-4 + post_win_size`` (module-level variable), then
    rounded to 4 decimals, so the first row maps to ``post_win_size - 4``.

    Parameters
    ----------
    df : pandas.DataFrame
        One group (as passed by ``groupby(...).apply``) with a ``time`` column.

    Returns
    -------
    pandas.DataFrame
        A new frame with the same rows and column order; the input frame is
        left untouched.
    """
    relative_time = (df.time-df.time.iloc[0]-4+post_win_size).round(4)
    # Use assign() instead of writing into `df`: pandas does not support
    # functions that mutate the group object handed to groupby.apply.
    return df.assign(time=relative_time)
# Convert the sample times of every stimulus to per-stimulus relative times and
# discard the stimulus-log columns that should not become index levels.
idx = (
    idx
    .groupby('stim_id')
    .apply(_reset_index_time)
    .drop(columns=['offset', 'duration'])
)

# Label every EEG sample with the remaining stimulus columns, then exchange
# the 'sweep' and 'time' index levels and sort.
eegpost.index = pd.MultiIndex.from_frame(idx)
eegpost = eegpost.swaplevel('sweep', 'time').sort_index()

# Preview: the samples of a single stimulus.
eegpost.loc[('0_urethane', 1.0, 'biphasic')]

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4,5,6,7,9,11,12,14,15,16,17,18,19,20,21,23,24,25,26,27,28,29,common_avg,hipp_avg
parameter,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1
30,-2.0000,-14.234999,17.744999,31.589999,1.560000,8.580000,30.614999,21.449999,45.044998,67.274998,89.894997,7.605000,-0.195000,-2.340000,-0.585000,-10.725000,-3.120000,-5.655000,0.585000,-1.365000,-1.755000,25.154999,-0.975000,-6.045000,1.560000,-23.789999,-33.539999,9.397500,17.842499
30,-1.9996,-15.404999,23.594999,36.464999,-3.705000,9.555000,33.929999,15.404999,41.339998,72.149997,92.039997,15.794999,7.800000,-7.020000,0.975000,0.195000,2.340000,-7.020000,-3.120000,-7.605000,-0.975000,24.569999,-4.095000,-10.140000,4.680000,-21.644999,-47.969998,9.697500,16.282499
30,-1.9992,-4.680000,16.964999,32.369999,-1.170000,5.265000,33.929999,23.789999,45.434998,68.444997,100.424996,23.009999,14.819999,-23.399999,-18.719999,-21.449999,-17.354999,-17.159999,-24.374999,-19.109999,-14.234999,14.234999,-14.624999,-18.914999,-6.240000,-33.149999,-51.284998,3.570000,10.855000
30,-1.9988,-7.020000,25.154999,40.559999,-8.970000,15.794999,39.584999,27.689999,36.659999,60.839998,88.724997,7.410000,-8.775000,-10.920000,-2.535000,-5.070000,-8.580000,-18.134999,-13.844999,-5.460000,-7.020000,22.229999,-11.505000,-16.379999,7.605000,-26.129999,-44.654998,6.817500,16.217499
30,-1.9984,-13.260000,17.354999,33.344999,4.680000,9.555000,31.004999,15.404999,38.024999,54.599998,88.529997,13.065000,-2.535000,-20.084999,-11.505000,-14.624999,-17.159999,-21.059999,-27.884999,-17.159999,-26.129999,12.480000,-16.184999,-30.419999,-16.574999,-26.909999,-44.849998,0.450000,9.620000
30,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30,1.9592,-54.599998,-38.999999,-39.194999,-91.649997,-93.209997,-75.854997,-104.714996,-81.314997,-106.469996,-128.504995,-132.014995,-114.464996,-133.379995,-132.794995,-132.794995,-132.794995,-127.919995,-99.254996,-75.269997,-78.389997,-77.804997,-77.219997,-67.274998,-53.819998,-56.159998,-54.794998,-90.794997,-80.112497
30,1.9596,-56.549998,-46.019998,-44.264998,-93.014997,-85.019997,-73.709997,-103.349996,-91.454997,-113.489996,-124.799995,-129.284995,-116.609996,-122.069996,-125.774995,-127.919995,-126.359995,-123.434995,-105.299996,-78.194997,-78.974997,-77.804997,-72.929997,-67.079998,-42.509998,-53.234998,-44.654998,-89.377497,-79.852497
30,1.9600,-58.109998,-41.924998,-47.384998,-93.599997,-82.484997,-73.319997,-99.449996,-78.194997,-99.254996,-124.994995,-117.974996,-110.174996,-120.119996,-120.119996,-131.819995,-136.304995,-116.999996,-101.204996,-74.099997,-79.169997,-80.729997,-77.024997,-77.024997,-57.134998,-60.644998,-54.989998,-89.009997,-77.642497
30,1.9604,-54.989998,-33.734999,-46.994998,-83.069997,-82.484997,-73.514997,-99.254996,-70.979997,-85.409997,-120.899996,-118.169996,-113.879996,-128.894995,-124.604995,-127.724995,-133.769995,-127.919995,-105.299996,-85.994997,-76.634997,-80.144997,-79.949997,-75.269997,-59.279998,-57.719998,-54.794998,-88.514997,-78.844997


In [33]:
# Pivot 'time' out of the row index: one row per stimulus, columns keyed by
# (original column, time).
_eegpost = eegpost.unstack('time')

# Append the columns of `powers_discrete` to the row index as extra levels.
_eegpost_levels = _eegpost.index.to_frame().join(powers_discrete)
_eegpost.index = pd.MultiIndex.from_frame(_eegpost_levels)

In [79]:
def _gaussian_smooth_columns(df):
    """Smooth ``df`` along its columns with a centred 10-point gaussian window (std=3)."""
    return df.rolling(10, center=True, win_type='gaussian', axis=1).mean(std=3)


# Average all stimuli sharing sweep, parameter and `boi` level, then smooth the
# columns of every top-level column group separately.
eeg_mean_responses = (
    _eegpost
    .groupby(['sweep', 'parameter', boi])
    .mean()
    .groupby(level=0, axis=1)
    .apply(_gaussian_smooth_columns)
)

# Distinct values of the first two index levels (sweep, parameter).
sweeps, params = eeg_mean_responses.index.levels[:2]

In [84]:
electrode_id = 12
# One panel per (sweep, parameter) pair. squeeze=False keeps `axes` 2-D even
# when there is a single sweep or a single parameter, so the row/column
# iteration below works for any grid size.
f, axes = plt.subplots(
    len(sweeps), len(params), figsize=(len(params)*3, len(sweeps)*2),
    tight_layout=True, sharex=True, squeeze=False
)
for row, (axe, sw) in enumerate(zip(axes, sweeps)):
    for col, (ax, pa) in enumerate(zip(axe, params)):
        # Mean response of the selected electrode: after the transpose the
        # remaining row level becomes one line each, plotted against the columns.
        eeg_mean_responses.loc[(sw, pa), electrode_id].T.dropna(how='all').plot(
            ax=ax, alpha=0.6, legend=False
        )
        # label rows on the first column and columns on the first row only
        if col == 0:
            ax.set_ylabel(sw)
        if row == 0:
            ax.set_title(pa)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

30 0
50 1
70 2
white 3
30 0
50 1
70 2
white 3
30 0
50 1
70 2
white 3


### Magnitude of different EEG response components vs band power

## Spiking responses wrt state

### # spikes vs state

In [22]:
# Keep only the rows recorded with the selected stimulus parameter.
_spikes = spikes.xs(parameter, level='parameter')

# PSTH of every column: average of all rows that share the same 'time' value.
_mfr = _spikes.groupby(['time']).mean()

# z-score against the `[:0]` slice of the PSTH
# (presumably the pre-stimulus samples, time <= 0, given a float time index -- confirm).
_baseline = _mfr[:0]
_mfr = (_mfr - _baseline.mean()) / _baseline.std()

# Smooth with a centred 10-point gaussian window (std=4).
_mfr = _mfr.rolling(window=10, center=True, win_type='gaussian').mean(std=4)

In [23]:
f, ax = plt.subplots(figsize=(7, 2), tight_layout=True)
# Replace the column labels by a plain 0..N-1 range (transpose, drop the row
# index, transpose back) and draw every column as a faint line.
_mfr_traces = _mfr.T.reset_index(drop=True).T
_mfr_traces.plot(ax=ax, legend=False, c='C0', alpha=0.3)
ax.set_xlim(-0.2, 0.5);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
# Swap the 'sweep' and 'time' index levels and sort ONCE; the three window
# slices below reuse this frame instead of re-sorting the full spike table for
# every window. (The slices select on the outermost level, presumably 'time'
# after the swap -- confirm against the index layout of `spikes`.)
_spikes_by_time = _spikes.swaplevel('sweep', 'time').sort_index()

# Mean over the window -1.005 .. -0.005 for every stim_id.
n_spikes_bs = _spikes_by_time[-1.005:-0.005].groupby('stim_id').mean()

# Mean over the windows 0 .. 0.2 and 0.3 .. 0.5, with the -1.005 .. -0.005 mean subtracted.
n_spikes_early = _spikes_by_time[0:0.2].groupby('stim_id').mean() - n_spikes_bs
n_spikes_late = _spikes_by_time[0.3:0.5].groupby('stim_id').mean() - n_spikes_bs

# The sorted copy is only needed for the slices above; release it.
del _spikes_by_time

# Band powers of the same stimulus parameter, ordered by stim_id, with the
# index levels turned into regular columns (the result has a positional index).
_powers = powers.xs(parameter, level='parameter').swaplevel('sweep', 'stim_id').sort_index().reset_index()

In [25]:
boi = 'delta'

# `_powers` carries a positional 0..N-1 index (it comes out of reset_index()),
# whereas the spike counts are indexed by stim_id. corrwith() aligns on index
# labels, so the band power is re-keyed by stim_id to pair every stimulus with
# its own power value instead of with whatever row shares its label.
_boi_power = _powers.set_index('stim_id')[boi]

# One correlation value per column of the spike-count frames; computed once
# and reused for both the histogram and the mean marker.
corr_late = n_spikes_late.corrwith(_boi_power)
corr_early = n_spikes_early.corrwith(_boi_power)

f, ax = plt.subplots(figsize=(3, 2.4), tight_layout=True)
sns.histplot(corr_late, ax=ax, color='C1', alpha=0.2, element='step')
ax.axvline(corr_late.mean(), c='C1', label='late')
sns.histplot(corr_early, ax=ax, color='C0', alpha=0.2, element='step')
ax.axvline(corr_early.mean(), c='C0', label='early')
ax.legend(loc=0, fontsize=7)
ax.set_xlabel(f'corr ( # spikes, {boi} power )')
ax.set_ylabel('# neurons');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

---

In [26]:
# Work on a copy so that `spikes` keeps its original index.
q_spikes = spikes.copy()

# Rebuild the row index as a frame, drop 'time' from that frame's own index so
# the join with `powers_discrete` happens on the remaining levels, and use the
# joined columns as the new index levels.
_q_levels = q_spikes.index.to_frame().reset_index('time', drop=True)
_q_levels = _q_levels.join(powers_discrete)
q_spikes.index = pd.MultiIndex.from_frame(_q_levels)

# Mean per sweep, parameter, time value and `boi` level.
q_psths = q_spikes.groupby(['sweep', 'parameter', 'time', boi]).mean()

In [27]:
# Select the PSTHs of a single condition: sweep '0_urethane', parameter '50'
# (the first two index levels of `q_psths`).
_q_psth = q_psths.loc[('0_urethane', '50')]

In [28]:
f, ax = plt.subplots(figsize=(6, 3), tight_layout=True)
# Average across columns, spread the `boi` levels into separate columns and
# smooth every column along time with a centred 10-point gaussian window (std=4).
_q_psth_smoothed = _q_psth.mean(axis=1).unstack(boi).rolling(
    window=10, center=True, win_type='gaussian'
).mean(std=4)
# Plot with an explicit loop rather than DataFrame.apply: apply is meant for
# computations, and some pandas versions invoke the function twice on the
# first column, which would draw that trace twice.
for level, trace in _q_psth_smoothed.items():
    # the colour encodes the `boi` level on the RdBu colormap, alpha 0.5
    trace.plot(ax=ax, c=cm.RdBu((level+1)/5, 0.5))
ax.set_xlim(-0.2, 1.2);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [29]:
# Preview the rows whose `boi` level equals 0, transposed so that the original
# columns become rows and the remaining index ('time') becomes the columns.
_q_psth.xs(0, level=boi).T

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,time,-1.000,-0.999,-0.998,-0.997,-0.996,-0.995,-0.994,-0.993,-0.992,-0.991,-0.990,-0.989,-0.988,-0.987,-0.986,-0.985,-0.984,-0.983,-0.982,-0.981,-0.980,-0.979,-0.978,-0.977,-0.976,...,2.976,2.977,2.978,2.979,2.980,2.981,2.982,2.983,2.984,2.985,2.986,2.987,2.988,2.989,2.990,2.991,2.992,2.993,2.994,2.995,2.996,2.997,2.998,2.999,3.000
probe,unit,FS_RS,peak_channel,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1
probeF,0.0,RS,0,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,
probeF,1.0,RS,1,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,
probeF,3.0,RS,1,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,
probeF,4.0,RS,1,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,
probeF,6.0,RS,2,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,
probeF,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
probeF,969.0,RS,183,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,
probeF,970.0,RS,222,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,
probeF,972.0,RS,256,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,
probeF,973.0,RS,290,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,


In [30]:
f, axes = plt.subplots(2, 1, figsize=(6, 6), tight_layout=True, sharex=True, sharey=True)
# Top panel: `boi` level 0, bottom panel: `boi` level 3. Both images share the
# same extent: x from -1 to 3, y from 0 to the number of rows of the level-0 image.
n_rows = len(_q_psth.xs(0, level=boi).T)
for ax, level in zip(axes, (0, 3)):
    ax.imshow(
        _q_psth.xs(level, level=boi).T,
        aspect='auto', cmap=cm.bwr, vmin=-0.05, vmax=0.05,
        extent=[-1, 3, 0, n_rows]
    )
axes[0].set_xlim(-0.2, 1.2)
# vertical marker at x = 0 on both panels
for ax in axes:
    ax.axvline(0, c='k')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### inhibition time vs state