In this notebook, we will explore some data analyses using the Hilbert transform of the EEG signal.

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm, patches
import matplotlib.gridspec as gridspec
from tqdm.auto import tqdm
from scipy import signal
import umap

from tbd_eeg.data_analysis.eegutils import *
from tbd_eeg.data_analysis.Utilities import utilities as utils
from tbd_eeg.data_analysis.Utilities import filters

from ipympl.backend_nbagg import Canvas
Canvas.header_visible.default_value = False
%matplotlib widget

In [2]:
# one matplotlib colormap per anesthesia epoch label
epoch_cms = dict(
    pre=cm.Reds,
    iso_high=cm.PuOr,
    iso_low=cm.PuOr_r,
    early_recovery=cm.Blues,
    late_recovery=cm.Greens,
)

In [3]:
data_folder = "/allen/programs/braintv/workgroups/nc-ophys/Leslie/eeg_pilot/mouse507190/pilot1_2020-02-28_10-33-11/recording1/"

# set the sample_rate for all data analysis
sample_rate = 2500

# load experiment metadata and eeg data
exp = EEGexp(data_folder)
eegdata = exp.load_eegdata(frequency=sample_rate, return_type='pd')

# locate valid channels (some channels can be disconnected and we want to ignore them in the analysis)
print('Identifying valid channels...')
# Screen channels on the first `sample_rate*300` samples only.
# eegdata carries a float time index, so a plain `eegdata[:n]` slice is
# label-based (every row whose time is <= n, i.e. the whole recording here);
# .iloc makes the slice positional, which is what a sample count calls for.
median_amplitude = eegdata.iloc[:sample_rate*300].apply(
    utils.median_amplitude, raw=True, axis=0, distance=sample_rate
)
# channels whose median amplitude is below 2000 are kept as valid
valid_channels = median_amplitude.index[median_amplitude < 2000].values
print('The following channels seem to be correctly connected and report valid data:')
print(list(valid_channels))

# load other data (running, iso etc)
print('Loading other data...')
running_speed = exp.load_running(return_type='pd')
iso = exp.load_analog_iso(return_type='pd')

# automatically annotate anesthesia epochs from the iso trace
# first time the iso signal exceeds 4
iso_first_on = iso.gt(4).idxmax()
print('iso on at', iso_first_on)
# first time after that where the signal sits strictly between 1 and 4
_iso_after_on = iso[iso.index > iso_first_on]
iso_first_mid = (_iso_after_on.gt(1) & _iso_after_on.lt(4)).idxmax()
print('iso reduced at', iso_first_mid)
# last time the signal exceeds 1 (first hit when scanning backwards)
iso_first_off = iso.gt(1).iloc[::-1].idxmax()
print('iso off at', iso_first_off)

# annotate artifacts with power in high frequencies
print('Annotating artifacts...')
# per-channel annotation on the valid channels, averaged across channels
hf_annots = (
    eegdata[valid_channels]
    .apply(find_hf_annotations, axis=0, sample_rate=sample_rate, fmin=300, pmin=0.25)
    .mean(axis=1)
    .rename('artifact')
)
# first time after iso is switched off where the averaged annotation exceeds 4
recovery_first_jump = (hf_annots > 4)[hf_annots.index > iso_first_off].idxmax()

# Piecewise-constant epoch labels: each epoch is pinned at both of its ends,
# 1 ms on either side of every transition, so that a nearest-neighbour lookup
# on this series returns the epoch a given time falls into.
_epoch_names = ['pre', 'iso_high', 'iso_low', 'early_recovery', 'late_recovery']
_transitions = [iso_first_on, iso_first_mid, iso_first_off, recovery_first_jump]
epochs = pd.Series(
    index=(
        [0]
        + [t + offset for t in _transitions for offset in (-0.001, 0.001)]
        + [eegdata.index[-1]]
    ),
    data=[name for name in _epoch_names for _rep in range(2)],
    dtype=pd.CategoricalDtype(categories=_epoch_names, ordered=True)
)

Identifying valid channels...
The following channels seem to be correctly connected and report valid data:
[0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
Loading other data...


  return eval(self.dfile['analog_meta'].value)


iso on at 644.112
iso reduced at 1190.048
iso off at 1867.552
Annotating artifacts...


In [4]:
# validate data and artifact annotation
f, ax = plt.subplots(1, 1, figsize=(12, 2), tight_layout=True)
# every 50th sample of channels 13 and 4
eegdata.loc[:, [13, 4]].iloc[::50].plot(ax=ax, alpha=0.7)
# artifact annotation on a secondary y-axis
ax2 = ax.twinx()
hf_annots.plot(ax=ax2, c='r', lw=0.5)
# move the lower limit to minus the midpoint of the current range, keep the upper limit
_ylo, _yhi = ax2.get_ylim()
ax2.set_ylim(-(_ylo + _yhi)/2, _yhi);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [5]:
# Electrode groups: group name -> channel indices belonging to that group.
# The order of the channels inside each list matters: it is used below to
# derive the within-group ordering.
_egroups = {
    'left_front' : [11, 12, 13, 14],
    'right_front' : [18, 17, 16, 15],
    'left_front_middle' : [9, 10],
    'right_front_middle' : [20, 19],
    'left_back_middle' : [3, 4, 6, 7],
    'right_back_middle' : [26, 25, 23, 22],
    'left_back_middle_center' : [5, 8],
    'right_back_middle_center' : [24, 21],
    'left_back' : [1, 2],
    'right_back' : [28, 27],
    'left_bottom' : [0],
    'right_bottom' : [29],
}

# Add bookkeeping columns to the channel coordinate table:
#   z     - constant 0
#   group - group name from _egroups
#   gid   - position of the group in _egroups (0..11)
#   wgid  - temporary within-group sort key (dropped at the end of the cell)
exp.ch_coordinates['z'] = 0
exp.ch_coordinates['group'] = ''
exp.ch_coordinates['gid'] = 0
exp.ch_coordinates['wgid'] = 0
for i, (g, idx) in enumerate(_egroups.items()):
    exp.ch_coordinates.loc[idx, 'group'] = g
    exp.ch_coordinates.loc[idx, 'gid'] = i
    # the k-th smallest channel of the group receives the k-th listed channel
    # as its sort key; for the ascending / descending lists used above, sorting
    # by wgid therefore reproduces the listed order within the group
    exp.ch_coordinates.loc[sorted(idx), 'wgid'] = idx
# temporarily order rows by group, then by the within-group key
exp.ch_coordinates = exp.ch_coordinates.sort_values(['gid', 'wgid'])
# 'order' is a single display rank over all channels: channels in 'left'
# groups are numbered 0..len(_left)-1 in the current row order ...
exp.ch_coordinates['order'] = 0
_left = exp.ch_coordinates.index[exp.ch_coordinates.group.str.contains('left')]
exp.ch_coordinates.loc[_left, 'order'] = range(len(_left))
# ... and channels in 'right' groups continue from len(_left), numbered in
# reverse row order (the last right-hand row gets len(_left))
_right = exp.ch_coordinates.index[exp.ch_coordinates.group.str.contains('right')]
exp.ch_coordinates.loc[_right, 'order'] = len(_left)+np.arange(len(_right))[::-1]
# restore channel-index row order and remove the temporary key
exp.ch_coordinates.sort_index(inplace=True)
exp.ch_coordinates.drop('wgid', inplace=True, axis=1)

In [6]:
# define a function to quickly plot the electrode map with or without borders
def plot_electrode_map(ax, highlight=None, labels=True, cmap=cm.Paired, s=50):
    """Scatter the electrodes at their (ML, AP) coordinates, coloured by group.

    Reads the notebook-level ``exp.ch_coordinates`` table (columns 'ML', 'AP',
    'gid' and 'group').

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    highlight : str or None
        If it equals one of the values in the 'group' column, only that group
        keeps its colour and every other electrode is drawn in translucent
        grey. Any other value (including None) colours all groups.
    labels : bool
        If True, annotate each electrode with its index and set the title
        'Electrode map'. If False, clear the axis labels instead.
    cmap : matplotlib colormap
        Called as ``cmap(gid/12, 0.9)`` to obtain a group's colour.
    s : scalar
        Marker size forwarded to the scatter plot.
    """
    coords = exp.ch_coordinates
    if highlight in set(coords.group):
        grey = cm.Greys(0.5, 0.5)
        colors = np.array(coords.apply(
            lambda row: cmap(row.gid/12, 0.9) if row.group == highlight else grey,
            axis=1
        ))
    else:
        # Normalise gid by 12, like the highlight branch above and draw_groups,
        # so a group gets the same colour everywhere (this branch used /11).
        colors = np.array(coords.gid.map(lambda gid: cmap(gid/12, 0.9)))
    coords.plot(
        kind='scatter', x='ML', y='AP', marker='o', ax=ax, legend=False, c=colors, s=s
    )
    if labels:
        for i in coords.index:
            # label position: (ML, AP + 0.2), i.e. slightly offset along y
            ax.annotate(i, coords.loc[i, ['AP', 'ML']][::-1]+[0, 0.2], xycoords='data', ha='center')
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 3)
    if labels:
        ax.set_title('Electrode map')
    else:
        ax.set_xlabel('')
        ax.set_ylabel('')
    ax.set_xticks([])
    ax.set_yticks([])

# function to show electrode groups along an axis instead of electrode numbers
def draw_groups(ax, cmap=cm.Paired):
    """Replace the tick labels of a 30x30 matrix plot with coloured group bars.

    Sets the axes limits to cover indices 0..29 (y axis inverted), removes the
    ticks, and draws one coloured bar per electrode group just outside the
    axes along the x axis (at y=30) and along the y axis (at x=-2). Groups are
    taken from the notebook-level ``exp.ch_coordinates`` table in 'order'
    order; the bar along the x axis carries the group name as its label.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to decorate.
    cmap : matplotlib colormap
        Called as ``cmap(gid/12, 0.9)`` to obtain a group's colour.
    """
    ax.set_xlim(-0.5, 29.5)
    ax.set_ylim(29.5, -0.5)
    ax.set_xticks([])
    ax.set_yticks([])
    # Iterate over the groups explicitly instead of driving a side-effecting
    # closure through groupby().apply(): apply() is meant for computations and
    # older pandas versions could evaluate the first group more than once.
    ordered = exp.ch_coordinates.sort_values('order')
    nt = 0  # number of electrodes already covered by previous groups
    for _, df in ordered.groupby('gid', sort=False):
        n = len(df)
        color = cmap(df.gid.iloc[0]/12, 0.9)
        start = nt+0.05-0.5
        # bar along the x axis (drawn outside the axes, hence clip_on=False)
        ax.add_patch(patches.Rectangle(
            (start, 30), n-0.05, 1, clip_on=False, color=color, label=df.group.iloc[0]
        ))
        # bar along the y axis; same length as the x bar so both leave the
        # same small gap between neighbouring groups
        ax.add_patch(patches.Rectangle(
            (-2, start), 1, n-0.05, clip_on=False, color=color
        ))
        nt += n

In [7]:
# window data and get valid windows
# samples whose artifact annotation exceeds this threshold are treated as invalid
thresh = 4
invalid_times = hf_annots.gt(thresh)

def drop_artifacts(df, winsize, sample_rate):
    """Return the sample labels of one window, or all-NaN if the window is unusable.

    Parameters
    ----------
    df : DataFrame
        Rows of a single window; must have an 'artifact' column.
    winsize : number
        Window length; ``winsize * sample_rate`` is the expected sample count.
    sample_rate : number
        Samples per unit of ``winsize``.

    Returns
    -------
    Series of length ``int(winsize * sample_rate)``: the first that many index
    labels of ``df``, or NaN everywhere when ``df`` has fewer rows than
    expected or its 'artifact' column sums to more than zero.
    """
    n_expected = winsize*sample_rate
    n_samples = int(n_expected)
    too_short = len(df) < n_expected
    has_artifact = df['artifact'].sum() > 0
    if too_short or has_artifact:
        return pd.Series([np.nan]*n_samples)
    return pd.Series(df.index[:n_samples])

winsize_n_samples = 4096*4
winsize_s = winsize_n_samples / sample_rate

# integer window id of every sample (sample time divided by the window length)
windows = pd.Series(
    (eegdata.index/winsize_s).astype(int), index=eegdata.index, name='window'
)
# window id next to the artifact flag of the nearest annotated time
aligned_windows = pd.concat(
    [windows, invalid_times.reindex(windows.index, method='nearest')], axis=1
)
# one row per window holding its sample times; windows that are too short or
# contain an artifact come back as NaN rows and are dropped
valid_windows = (
    aligned_windows
    .groupby('window')
    .apply(drop_artifacts, winsize=winsize_s, sample_rate=sample_rate)
    .dropna()
)
# label each window with the epoch nearest to its mean sample time
win_center = valid_windows.apply(lambda row: row.mean(), axis=1)
win_epoch = pd.Series(
    data=epochs.reindex(win_center, method='nearest').values,
    index=win_center.index
)
valid_windows['epoch'] = win_epoch

# same table indexed by (epoch, window)
valid_wins_by_epoch = valid_windows.set_index('epoch', append=True).swaplevel()
print('Number of workable windows (size {0:.2f} s) in each epoch:'.format(winsize_s))
for _epoch in valid_wins_by_epoch.index.levels[0]:
    print(_epoch, len(valid_wins_by_epoch.loc[_epoch]))

Number of workable windows (size 6.55 s) in each epoch:
pre 29
iso_high 63
iso_low 102
early_recovery 39
late_recovery 123


# Hilbert transform the data

# PSD

In [8]:
# PSD of every channel for a subset of the epochs
# (iterate over valid_wins_by_epoch.index.levels[0] to cover all of them)
for ep in ('iso_low', 'pre', 'late_recovery'):
    # stack the samples of all valid windows of this epoch: (n_samples, n_channels)
    _ep_windows = valid_wins_by_epoch.loc[ep]
    _data = np.array(
        [eegdata.loc[_ep_windows.loc[w]].values for w in _ep_windows.index]
    )
    _data = _data.reshape(-1, _data.shape[-1])

    f, (ax, axe) = plt.subplots(
        1, 2, figsize=(8, 3), constrained_layout=True,
        gridspec_kw={'width_ratios':[2.5, 1]}
    )
    # one PSD line per channel, coloured by electrode group
    for c in eegdata.columns:
        _group_color = cm.Paired(exp.ch_coordinates.loc[c, 'gid']/12, 0.5)
        ax.psd(
            _data[:, c], detrend='linear', Fs=sample_rate, NFFT=winsize_n_samples,
            label=c, c=_group_color
        )
    ax.set_ylim(10, 50)
    ax.set_xlim(0.5, 100)
    ax.set_xscale('log')
    ax.set_yscale('log')
    # electrode map next to the PSD as a colour key
    plot_electrode_map(axe)
    ax.set_title(ep);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
# bands of interest as (low, high) pairs, roughly based on the spectra above;
# the bands are contiguous, so they are built from consecutive band edges
_band_edges = [1, 10, 17, 25, 50, 100]
boi = list(zip(_band_edges[:-1], _band_edges[1:]))

In [10]:
# filter data by bands
# band (low, high) -> band-passed copy of eegdata (filter applied column by column)
filtered_data = {}
for band in tqdm(boi):
    lowcut, highcut = band
    filtered_data[band] = eegdata.apply(
        filters.butter_bandpass_filter,
        lowcut=lowcut,
        highcut=highcut,
        sampling_frequency=sample_rate,
        filter_order=3,
    )

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




In [37]:
# analytic signal of the first winsize_n_samples samples of channel 0 in the (50, 100) band
# The data has a float time index, so `[:winsize_n_samples]` would slice by
# label (all rows with time <= 16384, i.e. the whole recording); .iloc takes
# the intended number of leading samples.
_hil = signal.hilbert(filtered_data[(50, 100)][0].iloc[:winsize_n_samples])

In [38]:
f, (ax, axa) = plt.subplots(2, 1, figsize=(8, 4), sharex=True)
# Plot exactly the leading samples that `_hil` was computed from. The slice is
# positional (.iloc) because the data has a float time index, on which a plain
# `[:n]` slice would be label-based.
_band_trace = filtered_data[(50, 100)][0].iloc[:len(_hil)].values
ax.plot(_band_trace, label='filtered signal')
# magnitude of the analytic signal, mirrored to bracket the trace from both sides
ax.plot(np.abs(_hil), label='|hilbert|')
ax.plot(-np.abs(_hil), label='-|hilbert|')
ax.legend(loc='upper right')
# unwrapped angle of the analytic signal (np.angle returns radians)
axa.plot(np.unwrap(np.angle(_hil)))
axa.set_ylabel('unwrapped phase (rad)')
axa.set_xlabel('sample');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [58]:
# apply Hilbert transform on each band
hilbert_data = {}
for band, data in filtered_data.items():
    

{(1,
  10):                       0           1          2           3           4  \
 29.833911    179.754529  134.554246 -12.203937  310.194006  288.631675   
 29.834311    180.349244  132.483192 -17.746028  306.743403  283.862425   
 29.834711    180.897240  130.367923 -23.283493  303.216344  279.010283   
 29.835111    181.398070  128.208870 -28.815122  299.613160  274.076101   
 29.835511    181.851314  126.006487 -34.339717  295.934218  269.060775   
 ...                 ...         ...        ...         ...         ...   
 3642.282076   -0.137466   -0.422038  -0.623872   -0.412206   -0.694910   
 3642.282476   -0.119817   -0.362815  -0.534596   -0.357591   -0.599403   
 3642.282876   -0.103715   -0.309758  -0.454780   -0.308075   -0.513480   
 3642.283276   -0.089105   -0.262480  -0.383841   -0.263425   -0.436593   
 3642.283676   -0.075927   -0.220600  -0.321197   -0.223401   -0.368192   
 
                       5           6           7           8           9  \
 29.833911 