Comparison between pre- and post-stimulus connectivity across brain regions

In [1]:
%load_ext autoreload
%autoreload 2
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm, patches
import matplotlib.gridspec as gridspec
from scipy import signal
from tqdm.auto import tqdm
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    tqdm.pandas()

from tbd_eeg.data_analysis.eegutils import *
from tbd_eeg.data_analysis.Utilities import utilities as utils
from tbd_eeg.data_analysis.Utilities import filters

from plot_electrodes import *

from ipympl.backend_nbagg import Canvas
Canvas.header_visible.default_value = False

from simclus import simclus

%matplotlib widget

In [2]:
# Colormap used for each recording epoch; the keys are the epoch category names used by `epochs`.
epoch_cms = dict(
    pre=cm.Reds,
    iso_high=cm.PuOr,
    iso_low=cm.PuOr_r,
    early_recovery=cm.Blues,
    late_recovery=cm.Greens,
)

In [3]:
data_folder = "../tiny-blue-dot/zap-n-zip/EEG_exp/mouse521886/estim1_2020-07-16_13-37-02/experiment1/recording1/"

# set the sample_rate for all data analysis
sample_rate = 200

# load experiment metadata and eeg data
exp = EEGexp(data_folder)
eegdata = exp.load_eegdata(frequency=sample_rate, return_type='pd')

# locate valid channels (some channels can be disconnected and we want to ignore them in the analysis).
# Only the first `valid_check_s` seconds are inspected. `.iloc` makes this slice positional:
# the index of `eegdata` appears to hold timestamps in seconds (it is sliced with stimulus onset
# times further down), and pandas treats a plain `eegdata[:n]` slice on a float index as a
# *label* slice (everything up to t = n seconds) rather than the first n samples.
valid_check_s = 300
max_median_amplitude = 2000  # channels at or above this median amplitude are treated as disconnected
print('Identifying valid channels...')
median_amplitude = eegdata.iloc[:sample_rate*valid_check_s].apply(
    utils.median_amplitude, raw=True, axis=0, distance=sample_rate
)
valid_channels = median_amplitude.index[median_amplitude < max_median_amplitude].values
print('The following channels seem to be correctly connected and report valid data:')
print(list(valid_channels))

# load other data (running, iso etc)
print('Loading other data...')
running_speed = exp.load_running(return_type='pd')
iso = exp.load_analog_iso(return_type='pd')

# automatically annotate anesthesia epochs from the analog iso signal.
# idxmax() on a boolean Series returns the first True label, but silently returns the first
# label of the Series when nothing is True -- so check that each threshold is really crossed.
assert (iso > 4).any(), 'iso signal never exceeds 4: cannot locate anesthesia onset'
iso_first_on = (iso>4).idxmax()
print('iso on at', iso_first_on)
iso_after_on = iso[iso.index>iso_first_on]
iso_is_mid = (iso_after_on>1)&(iso_after_on<4)
assert iso_is_mid.any(), 'iso signal never enters (1, 4) after onset: cannot locate reduction'
iso_first_mid = iso_is_mid.idxmax()
print('iso reduced at', iso_first_mid)
# last time point with iso > 1 (first True of the reversed series)
iso_first_off = (iso>1)[::-1].idxmax()
print('iso off at', iso_first_off)

# annotate artifacts with power in high frequencies
print('Annotating artifacts...')
hf_annots = pd.Series(
    eegdata[valid_channels].apply(
        find_hf_annotations, axis=0,
        sample_rate=sample_rate, fmin=300, pmin=0.25
    ).mean(axis=1),
    name='artifact'
)
hf_jump_after_off = (hf_annots>4)[hf_annots.index>iso_first_off]
assert hf_jump_after_off.any(), 'no high-frequency annotation above 4 after iso off: cannot locate recovery jump'
recovery_first_jump = hf_jump_after_off.idxmax()

# Piecewise-constant epoch labels: every epoch is represented by two points placed 0.001
# inside its boundaries, so `epochs.reindex(times, method='nearest')` maps any time to its epoch.
epochs = pd.Series(
    index = [0, iso_first_on-0.001, iso_first_on+0.001, iso_first_mid-0.001,
             iso_first_mid+0.001, iso_first_off-0.001, iso_first_off+0.001,
             recovery_first_jump-0.001, recovery_first_jump+0.001, eegdata.index[-1]],
    data=['pre', 'pre', 'iso_high', 'iso_high', 'iso_low', 'iso_low',
          'early_recovery', 'early_recovery', 'late_recovery', 'late_recovery'],
    dtype=pd.CategoricalDtype(
        categories=['pre', 'iso_high', 'iso_low', 'early_recovery', 'late_recovery'],
        ordered=True
    )
)

The settings.xml file was not found.


Was the recording done on NP4? [y/n]  y


Experiment type: electrical stimulation.
SomnoSuite log file not found.
Identifying valid channels...
The following channels seem to be correctly connected and report valid data:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
Loading other data...
iso on at 1439.6
iso reduced at 1658.38
iso off at 2965.83
Annotating artifacts...


In [4]:
# load stimuli: the stimulus log (one row per stimulus) only exists for electrical-stimulation experiments
stimuli = (
    pd.read_csv(exp.stimulus_log_file)
    if exp.experiment_type == 'electrical stimulation'
    else None
)
stimuli

Unnamed: 0,stim_type,amplitude,duration,onset,offset,sweep
0,biphasic,50,400.0,134.81231,134.81291,0
1,biphasic,20,400.0,138.40947,138.41007,0
2,biphasic,50,400.0,142.72804,142.72864,0
3,biphasic,100,400.0,147.04601,147.04661,0
4,biphasic,100,400.0,151.26887,151.26948,0
...,...,...,...,...,...,...
895,biphasic,20,400.0,4168.44474,4168.44535,2
896,biphasic,100,400.0,4172.79589,4172.79649,2
897,biphasic,100,400.0,4176.85687,4176.85748,2
898,biphasic,100,400.0,4181.28689,4181.28750,2


In [7]:
# Reference map of the electrode positions, with labels
f, ax = plt.subplots(figsize=(2, 2), constrained_layout=False)
plot_electrode_map(ax, labels=True, s=50)
ax.set_axis_off();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

---

# Extract windows around stimuli

In [8]:
# Gap between the offset of one stimulus and the onset of the next one
f, ax = plt.subplots(figsize=(2.5, 2), tight_layout=True)
stimuli.onset.sub(stimuli.offset.shift(1)).hist(ax=ax, grid=False, bins=np.linspace(3, 5, 41))
ax.set_xlabel('inter stimulus interval', fontsize=10);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
winsize_s = 1.5 # response dies down within 1 sec typically, so this is a decent pre and post window size

# in addition to responses, also assign epochs to windows
# (the tuple name lets this Series be joined onto the two-level (channel, time) columns of `responses`)
win_epochs = epochs.reindex(stimuli.onset, method='nearest').rename(('epoch', ''))
win_epochs.index = stimuli.index

def get_windows(col):
    """Cut a +/- `winsize_s` window around every stimulus onset out of one channel.

    Parameters
    ----------
    col : pd.Series
        One channel of `eegdata`. Assumed to be indexed by time in seconds, since it is
        sliced with `.loc` using the stimulus onset times.

    Returns
    -------
    pd.Series
        All windows concatenated, indexed by (stimulus index, time relative to onset).
        Each window is mapped by nearest-neighbour lookup onto the same grid of
        `2*winsize_s*sample_rate + 1` evenly spaced relative times.
    """
    # round() instead of int(): int() truncates, so a float product such as 57.99999999999999
    # would silently drop one point from the grid.
    n_points = int(round(winsize_s * 2 * sample_rate)) + 1
    rel_time = pd.Index(np.linspace(-winsize_s, winsize_s, n_points))
    windows = {}
    for stim_i, onset in stimuli.onset.items():
        window = col.loc[onset - winsize_s:onset + winsize_s]
        # re-express the (new) window's index relative to the onset; `col` itself is untouched
        window.index = window.index - onset
        windows[stim_i] = window.reindex(rel_time, method='nearest')
    return pd.concat(windows)

responses = eegdata.progress_apply(get_windows).unstack()

HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))




# Mean responses across electrodes
* Why do the response directions of the frontal electrodes flip with increasing stimulus strength?
* The response amplitude (and shape) of the frontal electrodes (especially on the left) is only very mildly affected by the anesthesia state – perhaps they are not useful for quantifying brain state?
* Amplitude is largest in the pre-anesthesia epoch and smallest in the anesthetized state.
* The post-stimulus minimum is nearly (or completely) absent in the anesthetized state.

In [10]:
def plot_mean_response(col, ax, cmap=cm.Paired, alpha=0.8):
    """Plot the trial-averaged response of one channel as mean +/- 1 standard deviation.

    Parameters
    ----------
    col : pd.Series
        Responses of a single channel (``col.name`` is the channel id). Assumed to be
        indexed by (stimulus, time), as produced by ``df.stack()`` in
        ``plot_mean_response_df``, so that ``col.unstack()`` gives a stimuli x time frame.
    ax : matplotlib Axes
        Axes to draw on.
    cmap : matplotlib colormap
        Colour source; the colour is picked from the channel's ``gid`` in
        ``exp.ch_coordinates`` scaled by 1/11.
    alpha : float
        Opacity of the mean line (the +/- SD band is drawn with alpha 0.05).

    Channels that have no entry in ``exp.ch_coordinates.gid`` are skipped. Returns None.
    """
    try:
        gid = exp.ch_coordinates.gid[col.name]
    except KeyError:
        # No group id for this channel: nothing to plot. Only KeyError is caught so that
        # genuine errors (e.g. a broken `exp`, KeyboardInterrupt) are no longer hidden.
        return
    # gid/11 maps the group id into the colormap's [0, 1] input range
    # (presumably gids run 0..11 -- TODO confirm against exp.ch_coordinates)
    c = cmap(gid/11, alpha)
    cb = cmap(gid/11, 0.05)
    trials = col.unstack()  # rows: stimuli, columns: time relative to onset
    mn, sd = trials.mean(), trials.std()
    mn.plot(ax=ax, c=c, legend=False)
    ax.fill_between(mn.index.astype(float), (mn-sd).values, (mn+sd).values, color=cb)
    return

def plot_mean_response_df(df, ax=None, title=None):
    """Overlay the mean +/- SD response of every channel in ``df`` on one axes.

    Parameters
    ----------
    df : pd.DataFrame
        Rows are stimuli; columns are a (channel, time) MultiIndex, as in ``responses``.
    ax : matplotlib Axes, optional
        Target axes. When omitted, a new figure is created with the traces on the left
        and an electrode map (without labels) on the right.
    title : str, optional
        Axes title. Defaults to ``'amplitude: <df.name>'`` when ``df`` carries a ``name``
        (as group frames do inside ``groupby(...).apply``), otherwise to ``'amplitude'``.

    Returns
    -------
    The axes holding the response traces.
    """
    if ax is None:
        _, (ax, axm) = plt.subplots(1, 2, figsize=(8, 2.5), gridspec_kw=dict(width_ratios=[3, 1]), tight_layout=True)
        plot_electrode_map(axm, labels=False)
        axm.axis('off')
    if title is None:
        # a plain DataFrame has no `.name`; accessing it directly would raise AttributeError
        name = getattr(df, 'name', None)
        title = 'amplitude' if name is None else f'amplitude: {name}'
    # stack() moves the time level into the rows, so each column is one channel
    df.stack().apply(plot_mean_response, args=(ax,))
    ax.set_title(title)
    ax.set_xlim(xmin=-0.25)
    return ax

# Epochs compared here, in chronological order (iso_high and early_recovery are excluded),
# and the electrodes drawn in every panel.
plot_epochs = ['pre', 'iso_low', 'late_recovery']
plot_channels = [10, 11]

_responses = responses.join(
    pd.concat([
        stimuli.amplitude.rename(('amplitude', '')),
        win_epochs
    ], axis=1)
)
# sorted() gives a deterministic, ascending row order (iterating a set does not)
amplitudes = sorted(_responses.amplitude.unique())

# One row per amplitude, one column per epoch; rows share their y axis.
# The epochs are chosen *before* pairing them with axes, so no column is wasted on a skipped epoch.
f, axes = plt.subplots(
    len(amplitudes), len(plot_epochs), figsize=(9, 5), tight_layout=True,
    sharex=True, sharey='row', squeeze=False
)
for i_row, amp in enumerate(amplitudes):
    for i_col, ep in enumerate(plot_epochs):
        ax = axes[i_row, i_col]
        if i_col == 0:
            ax.set_ylabel(f'amplitude {amp}')
        selected = _responses[(_responses.amplitude==amp)&(_responses.epoch==ep)]
        plot_mean_response_df(
            selected.drop(['epoch', 'amplitude'], axis=1)[plot_channels],
            ax=ax, title=f'{ep}' if i_row == 0 else ''
        )

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

The rowNum attribute was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use ax.get_subplotspec().rowspan.start instead.
  layout[ax.rowNum, ax.colNum] = ax.get_visible()
The colNum attribute was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use ax.get_subplotspec().colspan.start instead.
  layout[ax.rowNum, ax.colNum] = ax.get_visible()
The rowNum attribute was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use ax.get_subplotspec().rowspan.start instead.
  if not layout[ax.rowNum + 1, ax.colNum]:
The colNum attribute was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use ax.get_subplotspec().colspan.start instead.
  if not layout[ax.rowNum + 1, ax.colNum]:
The rowNum attribute was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use ax.get_subplotspec().rowspan.start instead.
  layout[ax.rowNum, ax.colNum] = ax.get_visible()
The colNum attribute was

# How consistent are individual responses?
We will use similarity clustering to evaluate this.

## Test similarity metric
Compare cosine and correlation similarity to determine which metric to use (the cell below currently uses the Euclidean metric).

In [27]:
def plot_features_similarity(features, similarity):
    """Show the features next to their similarity matrix, plus an epoch strip.

    The first two panels are drawn by ``simclus.plot_features_similarity``. The narrow
    third panel shows, in greyscale, the ``win_epochs`` category code of every row. The
    rows are in cluster order when ``similarity`` has a ``'clusters'`` column and in the
    original order otherwise.

    Returns the array of the three axes.
    """
    fig, axes = plt.subplots(
        1, 3, figsize=(9, 3), sharey=True, tight_layout=True,
        gridspec_kw=dict(width_ratios=[2, 1, 0.05])
    )
    extent = [features.columns[0], features.columns[-1], 0, len(features)]
    simclus.plot_features_similarity(features, similarity, axes[:2], extent=extent)

    if 'clusters' in similarity.columns:
        row_order = similarity.sort_values('clusters').index
    else:
        row_order = similarity.index
    epoch_strip = win_epochs.cat.codes[row_order].values[:, np.newaxis]
    axes[2].imshow(epoch_strip, aspect='auto', cmap=cm.Greys)
    axes[2].axis('off')
    return axes

In [29]:
# the response timeseries for a single electrode acts as the set of features
etd = 20
amp = 100
features = responses[etd].loc[stimuli.index[stimuli.amplitude==amp], 0.01:0.3]
# features = (features-features.mean())#/features.std()

# compute similarity
similarity = simclus.get_normalized_similarity(features, algo='euclidean', zscore=True)

# cluster by similarity
similarity = simclus.cluster_by_similarity(similarity, entropy_merge=2, **dict(n_components_range=(1, 6), cv_types=['diag']))

axes = plot_features_similarity(features, similarity.drop('clusters', axis=1))
axes[0].set_ylabel('original order')
axes[0].set_title(f'LFP electrode {etd}')
axes = plot_features_similarity(features, similarity)
axes[0].set_ylabel('cluster groups')
axes[0].set_title(f'LFP electrode {etd}');

# clusters changed from 3 to 3.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Heterogeneity of clusters within epochs
On average, across how many clusters are the trials of each epoch (`pre`, `iso_low`, etc.) distributed?

In [None]:
def apply_clustering(etd, amp=100, time=(0, 0.3), algo='cosine', entropy_merge=2, cluster_kw=None):
    """Cluster the single-trial responses of one electrode by their mutual similarity.

    Parameters
    ----------
    etd : channel id (top-level column of ``responses``).
    amp : stimulus amplitude; only trials with this amplitude are used.
    time : (start, stop) of the response window used as features, relative to the
        stimulus onset (label slice, both ends inclusive).
    algo : similarity metric passed to ``simclus.get_normalized_similarity``.
    entropy_merge : forwarded to ``simclus.cluster_by_similarity``.
    cluster_kw : dict, optional
        Extra keyword arguments for ``simclus.cluster_by_similarity``. Defaults to
        ``dict(n_components_range=(1, 6), cv_types=['diag'])``.

    Returns
    -------
    The frame returned by ``simclus.cluster_by_similarity`` (it carries the
    ``'clusters'`` column used later in this notebook).
    """
    if cluster_kw is None:
        # Built per call: a dict default (with a list inside) would be shared by all calls.
        cluster_kw = dict(n_components_range=(1, 6), cv_types=['diag'])
    features = responses[etd].loc[stimuli.index[stimuli.amplitude==amp], time[0]:time[1]]
    # centre every time point across trials (scaling by features.std() is currently disabled)
    features = features - features.mean()
    similarity = simclus.get_normalized_similarity(features, algo=algo)
    similarity = simclus.cluster_by_similarity(similarity, entropy_merge=entropy_merge, **cluster_kw)
    return similarity

**Apply similarity clustering to all electrodes**

In [None]:
# Similarity clustering of every electrode, separately for each stimulus amplitude.
all_similarity = {}
for amp in tqdm([20, 50, 100]):
    # `amp` must be passed explicitly: apply_clustering defaults to amp=100, so otherwise
    # every amplitude would silently receive the amplitude-100 clustering.
    all_similarity[amp] = pd.concat(
        {c: apply_clustering(c, amp=amp) for c in responses.columns.levels[0]},
        axis=1
    )

**Keep clusters and compute heterogeneity**

In [None]:
# Per amplitude: trials x electrodes frame of cluster labels, with each trial's epoch attached.
all_clusters = {
    amp_key: (
        sim_df
        .swaplevel(axis=1)['clusters']  # columns (channel, field) -> (field, channel); keep labels
        .join(win_epochs.rename('epoch'))
    )
    for amp_key, sim_df in all_similarity.items()
}

In [None]:
# NOTE(review): exploratory cell -- the axes stay empty; for electrode 0 at amplitude 100 the
# trials outside iso_high / early_recovery are only displayed column by column. The histogram
# it was heading towards is kept as the trailing comment.
f, ax = plt.subplots(figsize=(3, 2.5), tight_layout=True)
all_clusters[100][[0, 'epoch']].groupby('epoch').filter(
    lambda group: group.name not in ('iso_high', 'early_recovery')
).apply(
    lambda column: display(column)  # ax.hist(df.drop('epoch', axis=1).values, bins=[-0.5, 0.5, 1.5, 2.5], alpha=0.3)
)

In [None]:
colors = {'pre':'b', 'iso_low':'r', 'late_recovery':'g'}
offset = {'pre':-0.1, 'iso_low':0., 'late_recovery':0.1}
def _plot_het(df, ax):
    if df.name in ['pre', 'iso_low', 'late_recovery']:
        df.drop('epoch', axis=1).astype(int).apply(
            lambda row: ax.plot(row+offset[df.name], row.index, lw=0, marker='o', c=colors[df.name], alpha=0.01, label=df.name),
            axis=1
        )
    return

# One panel per stimulus amplitude: cluster label (x) of each electrode (y) for every trial,
# coloured by epoch via _plot_het.
f, axes = plt.subplots(1, 3, figsize=(7.5, 7), tight_layout=True, sharey=True)
for ax, amp in zip(axes, (20, 50, 100)):
    all_clusters[amp].groupby('epoch').apply(_plot_het, ax=ax)
    ax.set_title(f'amplitude {amp}')
axes[0].invert_yaxis()  # lowest electrode number at the top
axes[0].set_ylabel('electrode')
axes[1].set_xlabel('cluster');