In [1]:
%load_ext autoreload
%autoreload 2

from os import path
import json

import numpy as np
import pandas as pd
import scipy as sp

import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib as mpl
import seaborn as sns
from tqdm.auto import tqdm
# import pandarallel
from IPython.utils.capture import capture_output
# register pandas' `progress_apply` (used for the per-trial band-power loop);
# the registration output is captured so this cell stays quiet
with capture_output():
    tqdm.pandas()
#     pandarallel.pandarallel.initialize(progress_bar=True)
from mne.time_frequency import psd_array_multitaper

# project-local helpers (tbd_eeg package)
from tbd_eeg.data_analysis.eegutils import EEGexp
from tbd_eeg.data_analysis.Utilities.utilities import get_stim_events, find_nearest_ind

# interactive ipympl figures, with the ipympl figure header hidden
from ipympl.backend_nbagg import Canvas
Canvas.header_visible.default_value = False
%matplotlib widget

In [2]:
# Frequency bands of interest: name -> (low, high) integration limits in Hz.
boi_defs = dict(
    delta=(0.8, 1.5),
    theta=(3, 6),
    alpha=(8, 13),
    beta=(15, 30),
    lower_gamma=(30, 60),
    upper_gamma=(60, 100),
)
# one colour per band, evenly spaced along the 'brg' colormap
boi_colors = {
    band_name: cm.brg(k / len(boi_defs))
    for k, band_name in enumerate(boi_defs)
}

# EEG channels averaged into the 'hipp_avg' signal
hipp_channels = [4, 5, 7, 8, 21, 22, 24, 25]

In [3]:
def bin_spikes(spikes, bin_size_ms=1, t_start=0, t_end=1e5):
    """Bin spike times into a boolean spike / no-spike train.

    Parameters
    ----------
    spikes : array-like of float
        Spike times, in the same unit as ``t_start``/``t_end`` (the ``*1000``
        below converts that unit to ms, i.e. times are expected in seconds).
    bin_size_ms : float
        Width of each bin in milliseconds.
    t_start, t_end : float
        Half-open interval ``[t_start, t_end)`` to bin; spikes outside it
        are ignored.

    Returns
    -------
    pd.Series of bool
        One entry per bin, indexed by the bin's left edge; True when at least
        one spike fell into that bin.
    """
    # round() rather than int(): floating-point error can put the product
    # just below an integer (e.g. (0.3 - 0.1) * 1000 -> 199.99999999999997),
    # which truncation would turn into one bin too few.
    n_bins = int(round((t_end - t_start) * 1000 / bin_size_ms))
    spikes = np.asarray(spikes, dtype=float)
    # t_start is inclusive, t_end exclusive
    spikes = spikes[(spikes >= t_start) & (spikes < t_end)]
    bin_idx = (n_bins * (spikes - t_start) / (t_end - t_start)).astype(int)
    # guard against rounding pushing a spike just below t_end into bin n_bins
    bin_idx = np.minimum(bin_idx, n_bins - 1)
    counts = np.bincount(bin_idx, minlength=n_bins)
    times = np.linspace(t_start, t_end, n_bins, endpoint=False)
    return pd.Series(counts, index=times, dtype=bool)

def load_spikes(rec_folder, probe, repeat=False, exp_name=None):
    """Load one probe's spikes, binned at 1 ms and cut around stimulus onsets.

    The result is cached as a pickle in
    ``../tiny-blue-dot/zap-n-zip/sg/spikes_stim_aligned/``. That cache is
    returned as-is unless ``repeat`` is True.

    Parameters
    ----------
    rec_folder : str
        Recording folder handed to ``EEGexp``.
    probe : str
        Key of the probe in ``exp.ephys_params``.
    repeat : bool, default False
        Recompute from the spike-sorting output even if the cache exists.
    exp_name : str, optional
        Experiment name used to look up the per-sweep brain states in the
        notebook-level ``gmetadata`` sheet. Defaults to the notebook-level
        ``expt`` (the previous implicit behaviour); pass it explicitly when
        ``rec_folder`` is not the currently selected experiment.

    Returns
    -------
    pd.DataFrame of bool
        Rows: one per 1 ms bin of every stimulus window, indexed by ``time``
        (relative to onset), ``stim_id`` and the remaining stimulus-log
        columns. Columns: one per unit, indexed by unit id, FS/RS class, peak
        channel and, when probe area labels are available, area and layer.
    """
    exp = EEGexp(rec_folder, preprocess=False, make_stim_csv=False)
    fname = (
        f'../tiny-blue-dot/zap-n-zip/sg/spikes_stim_aligned/{exp.mouse}_'
        f'{path.basename(path.dirname(exp.experiment_folder))}_{probe}.pkl')
    if path.exists(fname) and not repeat:
        return pd.read_pickle(fname)

    if exp_name is None:
        exp_name = expt
    # stimulus table, with sweeps relabelled '<index>_<brain state>'
    stim_log = pd.read_csv(exp.stimulus_log_file)
    _states = gmetadata[gmetadata.exp_name == exp_name]['brain states'].values[0]
    stim_log['sweep'] = stim_log.sweep.map(
        {i: f'{i}_{k}' for i, k in enumerate(_states.split('/'))}
    )
    stim_log = stim_log.rename_axis(index='stim_id')

    print('Loading and preprocessing spiking data.')
    # read raw spike time data
    _spike_times = np.load(
        exp.ephys_params[probe]['spike_times'], mmap_mode='r'
    )
    _spike_clusters = np.load(
        exp.ephys_params[probe]['spike_clusters'], mmap_mode='r'
    )
    cluster_metrics = pd.read_csv(
        exp.ephys_params[probe]['cluster_metrics'], index_col=1
    ).drop('Unnamed: 0', axis=1, errors='ignore')
    cluster_groups = pd.read_csv(
        exp.ephys_params[probe]['cluster_group'],
        sep='\t', index_col=0
    )

    # rearrange into an array of spike times for each cluster
    spike_df = pd.DataFrame(
        index=_spike_clusters, data=_spike_times, columns=['time']
    )
    spike_times = spike_df.groupby(level=0).apply(lambda g: g.values[:, 0])

    # keep only 'good' clusters that also pass the ISI-violation and
    # amplitude-cutoff thresholds
    cluster_metrics = cluster_metrics[cluster_groups.group.isin(['good'])]
    cluster_metrics = cluster_metrics[
        (cluster_metrics.isi_viol < 0.5) & (cluster_metrics.amplitude_cutoff < 0.1)
    ]
    spike_times = spike_times.loc[cluster_metrics.index]

    # bin spikes into 1 ms bins spanning all kept spikes
    spikes = {}
    t_start = spike_times.apply(lambda x: x.min()).min().round(3)
    t_end = spike_times.apply(lambda x: x.max()).max().round(3)
    for u, t in tqdm(spike_times.items(), total=len(spike_times)):
        spikes[u] = bin_spikes(t, t_start=t_start, t_end=t_end)
    spikes = pd.concat(spikes, axis=1, names='units').rename_axis('time')

    # keep only bins in a window around each stimulus: onsets are shifted
    # back by 1 s and the trial row is forward-filled for up to 4000 bins
    idx = stim_log.reset_index().set_index('onset').rename_axis('time')
    idx.index = idx.index - 1
    idx = idx.reindex(
        spikes.index, method='ffill', limit=1000*4
    ).reset_index().dropna()
    spikes = spikes.loc[idx.time]

    # time relative to stimulus onset (first bin of each window = -1 s).
    # transform keeps row order and, unlike groupby.apply, does not depend
    # on whether pandas hands the grouping column to the applied function.
    idx = idx.assign(
        time=(idx.time - idx.groupby('stim_id').time.transform('first') - 1).round(3)
    ).drop(['offset', 'duration'], axis=1)

    # set index and columns with useful information
    spikes.index = pd.MultiIndex.from_frame(idx)
    spikes = spikes.sort_index()
    # waveform width: trough-to-peak distance on the peak channel, in
    # samples divided by the AP sample rate
    spike_waveforms = np.load(exp.ephys_params[probe]['waveforms'])
    width = {}
    for u in cluster_metrics.index:
        wav = spike_waveforms[u, cluster_metrics.loc[u, 'peak_channel'], :]
        wav_duration = np.abs(np.argmin(wav) - np.argmax(wav))
        width[u] = wav_duration / exp.ephys_params[probe]['ap_sample_rate']
    cluster_metrics = cluster_metrics.join(
        pd.Series(width, name='waveform_width')
    )
    _cm_names = ['FS_RS', 'peak_channel']

    try:
        with open(exp.ephys_params[probe]['probe_info'], 'r') as f:
            areas = json.load(f)['area_ch']
        areas = pd.Series(areas, name='area')
        # layer = label from its first digit on; area = label with trailing
        # layer characters stripped
        layers = areas.str.extract(r'(\d.*)')[0].fillna('').rename('layer')
        areas = areas.str.rstrip('12/3456ab').fillna('')
        cluster_metrics['area'] = cluster_metrics.peak_channel.map(
            lambda x: areas.loc[x]
        )
        cluster_metrics['layer'] = cluster_metrics.peak_channel.map(
            lambda x: layers.loc[x]
        )
        _cm_names += ['area', 'layer']
    except Exception as err:
        # area labels are optional: report why they are missing and go on
        print(f'Area labels not found. ({err!r})')

    cluster_metrics['FS_RS'] = cluster_metrics.waveform_width.map(
        lambda x: 'FS' if x < 0.0004 else 'RS'
    )
    spikes.columns = pd.MultiIndex.from_frame(
        cluster_metrics[_cm_names].reset_index()
        .rename({'cluster_id': 'unit'}, axis=1)
    )
    # remove stimulation artifact: blank every bin from -1 ms to +2 ms around
    # onset. Select on the 'time' level explicitly; `spikes.loc[a, b]` would
    # be parsed as a (row, column) key rather than as a time range.
    _t = spikes.index.get_level_values('time')
    spikes.loc[(_t >= -0.001) & (_t <= 0.002)] = False

    spikes.to_pickle(fname)
    return spikes

def get_mean_firing(rec_folder, probe, repeat_spikes=False):
    """Trial-averaged firing rate of every unit, per 1 ms bin, as int16 Hz.

    Parameters
    ----------
    rec_folder : str
        Recording folder handed to ``EEGexp``.
    probe : str
        Key of the probe in ``exp.ephys_params``.
    repeat_spikes : bool, default False
        Recompute the binned spikes (``load_spikes(..., repeat=True)``). The
        cached rate table is then also recomputed, because it is derived from
        those spikes.

    Returns
    -------
    pd.DataFrame of int16
        Rows grouped by stim_type, parameter, sweep and time; one column per
        unit. Cached as ``fr_<mouse>_<experiment>_<probe>.pkl``.
    """
    exp = EEGexp(rec_folder, preprocess=False, make_stim_csv=False)
    fname = (
        f'../tiny-blue-dot/zap-n-zip/sg/spikes_stim_aligned/fr_{exp.mouse}_'
        f'{path.basename(path.dirname(exp.experiment_folder))}_{probe}.pkl')
    if path.exists(fname) and not repeat_spikes:
        return pd.DataFrame(pd.read_pickle(fname), dtype='int16')
    spikes = load_spikes(rec_folder, probe, repeat=repeat_spikes)
    # spikes are booleans in 1 ms bins: mean over trials * 1000 = rate in Hz
    fr = (
        spikes.groupby(['stim_type', 'parameter', 'sweep', 'time']).mean() * 1000
    ).round().astype('int16')
    fr.to_pickle(fname)
    return fr

# Select an experiment to analyze

In [4]:
# accessing the Google sheet with experiment metadata in python
# setting up the permissions:
# 1. install gspread (pip install gspread / conda install gspread)
# 2. copy the service_account.json file to '~/.config/gspread/service_account.json'
# 3. run the following:
import gspread
_gc = gspread.service_account() # need a key file to access the account (step 2)
_sh = _gc.open('Zap_Zip-log_exp') # open the spreadsheet
_df = pd.DataFrame(_sh.sheet1.get()) # load the first worksheet
# assuming get() yields one list per sheet row, _df has one row per sheet
# row; transposing, promoting the first sheet row (column 0 after the
# transpose) to the index and transposing back turns it into column names
gmetadata = _df.T.set_index(0).T # put it in a nicely formatted dataframe

In [5]:
# urethane experiments in the metadata sheet, skipping the first three rows
gmetadata.loc[gmetadata['brain states'].str.contains('ur', case=False)].iloc[3:]

Unnamed: 0,mouse_name,exp_name,brain states,stimulation,visual_stim,audio_stim,ISI (sec),stimulus duration (msec),Current (uA),Cortical Area stimulation,N trials per stimulus,EEG bad_channels,Npx,Units Sorted (X),Brain slices (X),Pupil tracking pre-processing,Brain areas assignment,"CCF coordinates stim electrode (surface,tip)","CCF area stim electrode (surface,tip)",Notes
53,mouse590478,urethane_2021-10-29_11-51-38,urethane/urethane,electrical,,,[3.5 4.5],0.2,70/80/100,SS-cortex,120,,B,X,Tissuecyte,,,,,EEG filter at 0.1Hz. Signals look OK. Pupil vi...
55,mouse590480,urethane_vis_2021-11-12_11-14-43,urethane/urethane/urethane,electrical/sensory,white,,[3.5 4.5],0.2/250,30/50/70,SS-cortex,120,8101322.0,"F,B",X,Tissuecyte,,,,,EEG filter at 0.1Hz. Visual stimuli delivered ...
57,mouse590481,urethane_vis_2021-11-19_10-59-04,urethane/urethane/urethane/urethane,electrical/sensory,white,,[3.5 4.5],0.2/250,30/50/70,SS-cortex,120,,"F,B,C",,Tissuecyte,,,,,EEG filter at 0.1Hz. IP injection. Visual stim...


In [6]:
mouse = 'mouse590480'
expt = gmetadata[gmetadata.mouse_name==mouse].exp_name
expt = expt[expt.str.contains('ur', case=False)].values[0]

# comma-separated EEG channel list from the sheet. The cell can be blank
# (several experiments list no bad channels) or missing, so skip empty
# entries instead of failing on int('').
bad_channels = gmetadata[
    (gmetadata.mouse_name==mouse) & (gmetadata.exp_name==expt)
]['EEG bad_channels'].iloc[0] or ''
bad_channels = [int(x) for x in bad_channels.split(',') if x.strip()]

# Load LFP, EEG and spiking data

In [7]:
rec_folder = (
    f'../tiny-blue-dot/zap-n-zip/EEG_exp/'
    f'{mouse}/{expt}/experiment1/recording1/'
)
exp = EEGexp(rec_folder, preprocess=False, make_stim_csv=False)

# load LFP data: memory-mapped int16, reshaped to (samples x channels) and
# indexed by the saved LFP timestamps
probes = [k for k in exp.ephys_params.keys() if 'probe' in k]
sample_rate_lfp = exp.ephys_params[probes[0]]['lfp_sample_rate']

lfp, timestamps = {}, {}
for probe in probes:
    lfp[probe] = np.memmap(
        exp.ephys_params[probe]['lfp_continuous'],
        dtype='int16', mode='r'
    )
    # reshape(-1, n) infers the sample count exactly (no float division)
    lfp[probe] = pd.DataFrame(
        lfp[probe].reshape(-1, exp.ephys_params[probe]['num_chs']),
        index=np.load(exp.ephys_params[probe]['lfp_timestamps'])
    )

# load EEG data
sample_rate_eeg = exp.ephys_params['EEG']['sample_rate']
eegdata = exp.load_eegdata(return_type='pd').drop(bad_channels, axis=1)
# add common average and hippocampal average signals
eegdata['common_avg'] = eegdata.mean(1)
eegdata['hipp_avg'] = eegdata[
    [x for x in hipp_channels if x not in bad_channels]
].mean(1)

# loading the stimulus table
stim_log = pd.read_csv(exp.stimulus_log_file)
stim_log['sweep'] = stim_log.sweep.map(
    {
        i : f'{i}_{k}' for i, k in enumerate(
            gmetadata[
                gmetadata.exp_name==expt
            ]['brain states'].values[0].split('/')
        )
    }
)
stim_log.rename_axis(index='stim_id', inplace=True)

# load spikes
fn_spikes = (
    f'../tiny-blue-dot/zap-n-zip/sg/spikes_stim_aligned/{exp.mouse}_'
    f'{path.basename(path.dirname(exp.experiment_folder))}_all.pkl'
)
if path.exists(fn_spikes):
    spikes = pd.read_pickle(fn_spikes)
else:
    spikes = {}
    for probe in probes:
        spikes[probe] = load_spikes(rec_folder, probe, repeat=False)
    # Name only the new 'probe' level; the per-probe column levels keep their
    # own names (3 levels, or 5 when area/layer labels were found), which a
    # fixed four-name list cannot match. The unit level is then named
    # explicitly.
    spikes = pd.concat(spikes, axis=1, names=['probe'])
    spikes.columns = spikes.columns.set_names('unit', level=1)
    spikes = spikes.swaplevel('time', 'sweep')
    spikes.sort_index(inplace=True)
    spikes.to_pickle(fn_spikes)

Experiment type: electrical and sensory stimulation
SomnoSuite log file not found.


# Window eegdata and extract band powers

In [8]:
pre_win_size = 2.5  # seconds of EEG kept before each stimulus onset
# Label every EEG sample with the trial whose window it belongs to: onsets
# are shifted back by pre_win_size and the trial row is forward-filled for up
# to 4 s of samples; samples outside any window become NaN and are dropped.
idx = (
    stim_log.reset_index()
    .assign(onset=lambda d: d.onset - pre_win_size)
    .set_index('onset')
    .rename_axis('time')
    .reindex(
        eegdata.index.rename('time'),
        method='ffill', limit=sample_rate_eeg*4
    )
    .reset_index()
    .dropna()
)
eegpre = eegdata.loc[idx.time].copy()

# time relative to onset (first sample of each window = -pre_win_size).
# transform keeps row order and, unlike groupby.apply, does not depend on
# whether pandas hands the grouping column to the applied function.
idx = idx.assign(
    time=(
        idx.time - idx.groupby('stim_id').time.transform('first') - pre_win_size
    ).round(4)
).drop(['offset', 'duration'], axis=1)

# set index and columns with useful information
eegpre.index = pd.MultiIndex.from_frame(idx)
eegpre = eegpre.swaplevel('time', 'sweep').sort_index()

In [9]:
# one example trial (sweep '0_urethane', stim_id 1, biphasic stimulation)
df = eegpre.loc[('0_urethane', 1.0, 'biphasic'), :]
df

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4,5,6,7,9,11,12,14,15,16,17,18,19,20,21,23,24,25,26,27,28,29,common_avg,hipp_avg
parameter,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1
30,-2.5000,-9.360000,-42.509998,-41.144998,-55.379998,-74.879997,-70.784997,-80.924997,-101.789996,-109.784996,-113.879996,-114.659996,-131.429995,-91.649997,-40.754999,-81.509997,11.505000,-45.824998,13.065000,-60.059998,-14.819999,-63.179998,-35.294999,-25.154999,-25.154999,-12.870000,9.165000,-54.194998,-67.664998
30,-2.4996,-4.485000,-30.419999,-40.169999,-49.334998,-69.029997,-67.664998,-71.564997,-98.474996,-103.154996,-103.934996,-105.494996,-116.219996,-89.504997,-23.984999,-77.024997,17.744999,-39.389999,19.304999,-53.429998,-17.939999,-56.939998,-28.859999,-16.769999,-20.669999,-4.095000,19.694999,-47.377498,-62.399998
30,-2.4992,-3.705000,-38.414999,-52.649998,-40.364999,-77.219997,-72.929997,-82.289997,-98.474996,-104.519996,-107.444996,-100.424996,-113.099996,-94.574997,-36.074999,-81.119997,5.265000,-44.459998,16.769999,-61.619998,-21.254999,-63.179998,-38.219999,-22.619999,-15.404999,-14.429999,8.580000,-52.072498,-68.607497
30,-2.4988,-6.825000,-41.339998,-53.819998,-62.594998,-82.874997,-76.829997,-90.869997,-98.279996,-95.549996,-102.179996,-103.544996,-116.024996,-82.094997,-23.594999,-68.834997,10.530000,-36.269999,15.209999,-53.624998,-25.349999,-69.614997,-45.434998,-25.349999,-28.079999,-14.234999,13.455000,-52.462498,-71.109997
30,-2.4984,-6.825000,-31.589999,-41.924998,-57.524998,-76.829997,-74.294997,-88.334997,-94.184997,-102.569996,-84.044997,-97.109996,-104.519996,-78.779997,-22.034999,-72.734997,13.065000,-45.824998,6.825000,-61.814998,-19.109999,-69.419997,-40.169999,-26.129999,-44.264998,-20.864999,-0.780000,-51.607498,-69.452497
30,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30,1.4592,-44.069998,-38.219999,-18.524999,-69.809997,-54.989998,-38.024999,-67.079998,-53.039998,-45.044998,-35.879999,-26.714999,-30.029999,-38.219999,-42.899998,-48.944998,-55.379998,-57.524998,-56.744998,-53.429998,-84.044997,-48.554998,-74.294997,-76.439997,-29.639999,-41.924998,-65.129998,-49.792498,-53.722498
30,1.4596,-42.704998,-46.604998,-23.984999,-57.914998,-70.589997,-45.824998,-61.424998,-63.179998,-54.404998,-42.899998,-33.149999,-42.314998,-32.759999,-31.979999,-48.749998,-48.554998,-45.239998,-55.574998,-46.019998,-72.344997,-40.169999,-70.199997,-75.854997,-21.449999,-37.244999,-51.869998,-48.577498,-55.997498
30,1.4600,-40.559999,-33.344999,-15.209999,-64.154998,-63.374998,-32.954999,-49.724998,-65.324998,-49.724998,-47.189998,-33.149999,-33.344999,-32.759999,-40.949998,-42.509998,-47.969998,-46.799998,-51.674998,-41.339998,-71.174997,-37.244999,-61.814998,-70.004997,-23.594999,-31.979999,-60.449998,-45.704998,-50.342498
30,1.4604,-30.224999,-27.494999,-8.970000,-62.984998,-64.934998,-24.959999,-46.994998,-50.504998,-38.024999,-36.269999,-26.714999,-29.639999,-23.204999,-24.764999,-39.974999,-46.604998,-53.819998,-62.009998,-42.119998,-71.369997,-36.074999,-62.789998,-67.859998,-12.090000,-21.449999,-43.679998,-40.597499,-46.897498


## Extract band powers for each window

In [10]:
def get_band_powers(df):
    """Integrated multitaper power per frequency band for every channel.

    Parameters
    ----------
    df : pd.DataFrame
        One EEG window: rows = samples (at the notebook-level
        ``sample_rate_eeg``), columns = channels.

    Returns
    -------
    pd.DataFrame
        One row per band in ``boi_defs`` plus ``'total'`` (whole 0-100 Hz
        spectrum), index named ``'band'``; one column per channel of ``df``.
    """
    # scipy.integrate.simps was deprecated in SciPy 1.12 and removed in 1.14;
    # simpson (SciPy >= 1.6) replaces it.
    try:
        from scipy.integrate import simpson
    except ImportError:  # SciPy < 1.6
        from scipy.integrate import simps as simpson

    with capture_output():
        _psd, _freqs = psd_array_multitaper(
            df.T, sample_rate_eeg, 0, 100, adaptive=True,
            low_bias=False, normalization='full', n_jobs=6, verbose=0
        )
    # rows = frequencies, columns = channels
    _psd = pd.DataFrame(_psd.T, index=_freqs, columns=df.columns)

    def _integrate(psd):
        # integrate every channel (column) over frequency
        return pd.Series(
            simpson(psd.values, x=psd.index.values, axis=0), index=psd.columns
        )

    _integrated_power = pd.concat(
        {
            _band: _integrate(_psd.loc[_lo:_hi])
            for _band, (_lo, _hi) in boi_defs.items()
        },
        axis=1,
    )
    _integrated_power['total'] = _integrate(_psd)
    return _integrated_power.T.rename_axis('band')

In [11]:
# band powers of the example trial, from pre-stimulus EEG only
# (time <= -5 ms); one line per channel
_pre_stim = df.swaplevel().sort_index()[:-0.005].swaplevel().sort_index()
_band_power_example = get_band_powers(_pre_stim)

f, ax = plt.subplots(figsize=(4, 3), tight_layout=True)
_band_power_example.drop('total').plot(ax=ax, cmap=cm.winter)
ax.set_yscale('log', nonpositive='mask')
ax.set(xlabel='band', ylabel='power')
# break two-word band names (e.g. lower_gamma) over two lines
ax.set_xticklabels(
    [lbl.get_text().replace('_', '\n') for lbl in ax.get_xticklabels()]
)
ax.get_legend().set_visible(False)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# Band powers of every trial from the pre-stimulus EEG only (time <= -5 ms):
# 'time' is swapped to the front so the float slice applies to it, then
# swapped back. Grouping by levels 0-3 (every level before 'time') calls
# get_band_powers once per trial window.
band_powers = (
    eegpre.swaplevel('sweep', 'time').sort_index()[:-0.005]
    .swaplevel('sweep', 'time').sort_index()
    .groupby(level=[0, 1, 2, 3]).progress_apply(get_band_powers)
)
# each band's power divided by the same trial's 'total' (0-100 Hz) power
band_powers_normalized = (
    band_powers.T / band_powers.xs('total', level='band').T
).T
band_powers

## Characterize distribution of band powers across trials

In [None]:
electrode_id = 'hipp_avg' # choose an electrode to plot metrics
bois = ['delta', 'theta', 'beta', 'upper_gamma']
stim_type = 'biphasic'

# one wide table per band: a column per (sweep, parameter), holding this
# electrode's band power in each trial of the chosen stimulus type
_bp_stim = band_powers[electrode_id].xs(stim_type, level='stim_type')
dfs = {
    boi: _bp_stim.xs(boi, level='band').unstack(['sweep', 'parameter'])
    for boi in bois
}

sweeps = dfs[bois[0]].columns.remove_unused_levels().levels[0]
n_rows, n_cols = len(sweeps), len(dfs)

f = plt.figure(figsize=(3*n_cols, 2.2*n_rows), tight_layout=True)
f.suptitle(f'electrode {electrode_id}')
gs = plt.GridSpec(n_rows, n_cols, figure=f)
axes = []

# grid: one row per sweep, one column per band; each column shares the x
# axis of its top panel
for i, sw in enumerate(sweeps):
    row = []
    for j, boi in enumerate(dfs.keys()):
        ax = f.add_subplot(gs[i, j], sharex=axes[0][j] if i > 0 else None)
        row.append(ax)
        sns.histplot(
            dfs[boi][sw], common_bins=True, common_norm=True,
            ax=ax, element='step'
        )
        _legend = ax.get_legend()
        plt.setp(_legend.texts, fontsize=8)
        _legend.get_title().set_fontsize(9)
        if i == n_rows - 1:
            ax.set_xlabel(f'{boi} power')
        else:
            ax.tick_params(labelbottom=False)
        ax.set_ylabel(f'{sw}\ncount' if j == 0 else '')
    axes.append(row)

In [None]:
# same as above figure, but with band powers normalized by the total
# (0-100 Hz) power of each trial; x labels say "relative" accordingly

dfs = {
    boi : (
        band_powers_normalized[electrode_id]
        .xs(stim_type, level='stim_type')
        .xs(boi, level='band')
        .unstack(['sweep', 'parameter'])
    ) for boi in bois
}

sweeps = dfs[bois[0]].columns.remove_unused_levels().levels[0]

f = plt.figure(
    figsize=(3*len(dfs), 2.2*len(sweeps)), tight_layout=True
)
f.suptitle(f'electrode {electrode_id}')
gs = plt.GridSpec(
    len(sweeps), len(dfs), figure=f
)
axes = []

# grid: one row per sweep, one column per band; each column shares the x
# axis of its top panel
for i, sw in enumerate(sweeps):
    axes.append([])
    for j, boi in enumerate(dfs.keys()):
        if i>0:
            ax = f.add_subplot(gs[i, j], sharex=axes[0][j])
        else:
            ax = f.add_subplot(gs[i, j])
        axes[i].append(ax)
        sns.histplot(
            dfs[boi][sw], common_bins=True, common_norm=True,
            ax=ax, element='step'
        )
        for t in ax.get_legend().texts:
            t.set_fontsize(8)
        ax.get_legend().get_title().set_fontsize(9)
        if i==len(sweeps)-1:
            ax.set_xlabel(f'relative {boi} power')
        else:
            ax.tick_params(labelbottom=False)
        if j!=0:
            ax.set_ylabel('')
        else:
            ax.set_ylabel(f'{sw}\ncount')

In [None]:
x_col = 'common_avg'
y_col = 'hipp_avg'
f, ax = plt.subplots(figsize=(4, 3), tight_layout=True)
# hollow markers, one colour per band; one point per trial
for i, boi in enumerate(bois):
    _bp = band_powers.xs(boi, level='band')
    _bp.plot.scatter(
        x=x_col, y=y_col, ax=ax, label=boi,
        edgecolor=f'C{i}', color='white', alpha=0.2
    )
ax.set(
    xlabel=f'channel {x_col}', ylabel=f'channel {y_col}',
    xscale='log', yscale='log',
)
ax.legend(fontsize=9)

In [None]:
# same as above figure but with normalized band powers
f, ax = plt.subplots(figsize=(4, 3), tight_layout=True)
for i, boi in enumerate(bois):
    _bp = band_powers_normalized.xs(boi, level='band')
    _bp.plot.scatter(
        x=x_col, y=y_col, ax=ax, label=boi,
        edgecolor=f'C{i}', color='white', alpha=0.2
    )
ax.set(
    xlabel=f'channel {x_col}', ylabel=f'channel {y_col}',
    xscale='log', yscale='log',
)
ax.legend(fontsize=9)

In [None]:
# channel-by-channel correlation (across trials) of each band's power,
# in a 2-row grid of panels
_n_cols = -(-len(bois) // 2)  # ceil, so an odd number of bands is not dropped
f, axes = plt.subplots(
    2, _n_cols, figsize=(2*len(bois)/2, len(bois)),
    sharey=True, sharex=True
)
axes = axes.flatten()

for i, (ax, boi) in enumerate(zip(axes, bois)):
    im = ax.imshow(
        band_powers.xs(boi, level='band').corr(),
        aspect='auto', vmin=-1, vmax=1, cmap=cm.RdBu
    )
    ax.set_title(boi, fontsize=9)
    if i >= _n_cols:         # bottom row
        ax.set_xlabel('channel', fontsize=9)
    if i % _n_cols == 0:     # first column
        ax.set_ylabel('channel', fontsize=9)
    ax.tick_params(labelsize=7)
# hide the spare panel when the number of bands is odd
for ax in axes[len(bois):]:
    ax.set_visible(False)

f.subplots_adjust(right=0.8)
cax = f.add_axes([0.85, 0.15, 0.05, 0.7])
f.colorbar(im, cax=cax, label='correlation')
cax.tick_params(labelsize=7);

In [None]:
# same as above figure, but with normalized band powers
_n_cols = -(-len(bois) // 2)  # ceil, so an odd number of bands is not dropped
f, axes = plt.subplots(
    2, _n_cols, figsize=(2*len(bois)/2, len(bois)),
    sharey=True, sharex=True
)
axes = axes.flatten()

for i, (ax, boi) in enumerate(zip(axes, bois)):
    im = ax.imshow(
        band_powers_normalized.xs(boi, level='band').corr(),
        aspect='auto', vmin=-1, vmax=1, cmap=cm.RdBu
    )
    ax.set_title(boi, fontsize=9)
    if i >= _n_cols:         # bottom row
        ax.set_xlabel('channel', fontsize=9)
    if i % _n_cols == 0:     # first column
        ax.set_ylabel('channel', fontsize=9)
    ax.tick_params(labelsize=7)
# hide the spare panel when the number of bands is odd
for ax in axes[len(bois):]:
    ax.set_visible(False)

f.subplots_adjust(right=0.8)
cax = f.add_axes([0.85, 0.15, 0.05, 0.7])
f.colorbar(im, cax=cax, label='correlation')
cax.tick_params(labelsize=7);

# Characterize responses for different band powers

In [None]:
# set up parameters used for the entire section
boi = 'dot'
electrode_id = 'common_avg'
parameter = '50'

# per-trial band powers of the chosen electrode, one column per band, plus
# 'dot' = log(theta / delta)
powers = band_powers[electrode_id].unstack('band')
powers['dot'] = np.log(powers['theta'] / powers['delta'])
# quartile label (0 = lowest ... 3 = highest) of every column, per trial
powers_discrete = powers.apply(pd.qcut, q=4, labels=range(4))

powers_normalized = band_powers_normalized[electrode_id].unstack('band')
# theta/delta is the same ratio before and after normalisation
powers_normalized['dot'] = np.log(powers['theta'] / powers['delta'])
# 'total' is identically 1 after normalisation, so it cannot be binned
powers_discrete_normalized = (
    powers_normalized.drop('total', axis=1)
    .apply(pd.qcut, q=4, labels=range(4))
)
display(powers, powers_discrete)

In [None]:
# log theta/delta ratio ('dot'): its distribution over trials (left) and
# its value trial by trial in stim_id order (right)
f, (ax, ax2) = plt.subplots(
    1, 2, figsize=(8, 2.4), tight_layout=True,
    gridspec_kw=dict(width_ratios=[1, 4])
)
_dot = powers['dot']
sns.histplot(_dot, ax=ax, bins=80)
ax.set_xlabel('log (4 Hz / 1 Hz)')
_dot_by_trial = _dot.droplevel(['sweep', 'stim_type', 'parameter'])
_dot_by_trial.rolling(1).mean().sort_index().plot(ax=ax2)
ax2.set_ylabel('log (4 Hz / 1 Hz)')
f.suptitle(electrode_id, fontsize=12);

## Mean responses of EEG channels wrt state

In [None]:
post_win_size = 3
# Each window runs from (post_win_size - 4) s before onset to at most
# post_win_size s after it: onsets are shifted back by 4 - post_win_size
# and the trial row is forward-filled for up to 4 s of samples.
idx = (
    stim_log.reset_index()
    .assign(onset=lambda d: d.onset - 4 + post_win_size)
    .set_index('onset')
    .rename_axis('time')
    .reindex(
        eegdata.index.rename('time'),
        method='ffill', limit=sample_rate_eeg*4
    )
    .reset_index()
    .dropna()
)
eegpost = eegdata.loc[idx.time].copy()

# time relative to onset; transform keeps row order and, unlike
# groupby.apply, does not depend on how pandas treats the grouping column
idx = idx.assign(
    time=(
        idx.time - idx.groupby('stim_id').time.transform('first') - 4 + post_win_size
    ).round(4)
).drop(['offset', 'duration'], axis=1)

# set index and columns with useful information
eegpost.index = pd.MultiIndex.from_frame(idx)
eegpost = eegpost.swaplevel('time', 'sweep').sort_index()

# example part of the dataframe
eegpost.loc[('0_urethane', 1.0, 'biphasic')]

### Magnitude of different EEG response components vs band power

## Spiking responses wrt state

### Number of spikes vs state

In [None]:
_spikes = spikes.xs(parameter, level='parameter')

# compute PSTH for each neuron separately using all trials
_mfr = _spikes.groupby(['time']).mean()

# z-score the mean firing rates against the pre-stimulus baseline. The
# baseline stops at -5 ms (like the band-power and spike-count baselines)
# so the blanked stimulation-artifact bins around t = 0 are excluded.
# Units silent during the baseline get NaN instead of +/-inf.
_baseline = _mfr.loc[:-0.005]
_mfr = (_mfr - _baseline.mean()) / _baseline.std().replace(0, np.nan)

# Gaussian smoothing: 10-bin window, sd of 4 bins
_mfr = _mfr.rolling(
    window=10, center=True, win_type='gaussian'
).mean(std=4)

In [None]:
f, ax = plt.subplots(figsize=(7, 2), tight_layout=True)
# one semi-transparent line per unit; the unit MultiIndex is replaced by
# plain positions before plotting
_mfr.set_axis(range(_mfr.shape[1]), axis=1).plot(
    ax=ax, legend=False, c='C0', alpha=0.3
)
ax.set_xlim(-0.2, 0.5);

In [None]:
# per trial and unit: mean spikes per 1 ms bin in the baseline
# (-1.005 s ... -5 ms), early (0 ... 0.2 s) and late (0.3 ... 0.5 s)
# windows; early/late are baseline-subtracted. Sort by time once.
_spikes_by_time = _spikes.swaplevel('sweep', 'time').sort_index()
n_spikes_bs = _spikes_by_time.loc[-1.005:-0.005].groupby('stim_id').mean()

n_spikes_early = _spikes_by_time.loc[0:0.2].groupby('stim_id').mean() - n_spikes_bs
n_spikes_late = _spikes_by_time.loc[0.3:0.5].groupby('stim_id').mean() - n_spikes_bs

# band powers of the same trials, indexed by stim_id so that
# DataFrame.corrwith pairs them with n_spikes_* trial by trial (a positional
# RangeIndex would be matched against stim_id values instead)
_powers = (
    powers.xs(parameter, level='parameter')
    .reset_index()
    .set_index('stim_id')
    .sort_index()
)

In [None]:
# NOTE(review): this overrides the section-wide `boi` ('dot'); cells below
# that group by `boi` will use 'delta' after this cell runs.
boi = 'delta'
f, ax = plt.subplots(figsize=(3, 2.4), tight_layout=True)
# per-neuron correlation of spike count with band power across trials;
# histogram plus a vertical line at the mean, late window first
for _n_spikes, _color, _label in [
    (n_spikes_late, 'C1', 'late'),
    (n_spikes_early, 'C0', 'early'),
]:
    _r = _n_spikes.corrwith(_powers[boi])
    sns.histplot(_r, ax=ax, color=_color, alpha=0.2, element='step')
    ax.axvline(_r.mean(), c=_color, label=_label)
ax.legend(loc=0, fontsize=7)
ax.set_xlabel(f'corr ( # spikes, {boi} power )')
ax.set_ylabel('# neurons');

---

In [None]:
# attach each trial's band-power quartiles (powers_discrete) as extra index
# levels of the spike table, then average over trials within each quartile
_trial_info = (
    spikes.index.to_frame()
    .reset_index('time', drop=True)
    .join(powers_discrete)
)
q_spikes = spikes.set_axis(pd.MultiIndex.from_frame(_trial_info), axis=0)
q_psths = q_spikes.groupby(['sweep', 'parameter', 'time', boi]).mean()

In [None]:
# quartile PSTHs of sweep '0_urethane' at the section-wide stimulation parameter
_q_psth = q_psths.loc[('0_urethane', parameter)]

In [None]:
f, ax = plt.subplots(figsize=(6, 3), tight_layout=True)
# population PSTH (mean over units), one column per quartile of `boi`,
# smoothed with a Gaussian window (10 bins, sd 4 bins)
_smoothed = _q_psth.mean(1).unstack(boi).rolling(
    window=10, center=True, win_type='gaussian'
).mean(std=4)
for _q, _col in _smoothed.items():
    # quartile q -> RdBu at (q + 1) / 5, alpha 0.5
    _col.plot(ax=ax, c=cm.RdBu((_q + 1) / 5, 0.5))
ax.set_xlim(-0.2, 1.2);

In [None]:
# per-unit PSTHs for the lowest quartile (0) of `boi`: units as rows
_q_psth.xs(0, level=boi).transpose()

In [None]:
# unit x time PSTH images for the lowest (0) and highest (3) quartile of
# `boi`; each panel's x extent comes from its own time axis and its y
# extent from its own unit count (instead of a hard-coded [-1, 3] s)
f, axes = plt.subplots(2, 1, figsize=(6, 6), tight_layout=True, sharex=True, sharey=True)
for ax, _q in zip(axes, [0, 3]):
    _img = _q_psth.xs(_q, level=boi).T  # rows = units, columns = time bins
    _t = _img.columns
    ax.imshow(
        _img, aspect='auto', cmap=cm.bwr, vmin=-0.05, vmax=0.05,
        extent=[_t.min(), _t.max(), 0, len(_img)]
    )
    ax.axvline(0, c='k')
    ax.set_title(f'{boi} quartile {_q}', fontsize=9)
    ax.set_ylabel('unit')
axes[0].set_xlim(-0.2, 1.2)
axes[-1].set_xlabel('time (s)');

### inhibition time vs state