In this notebook, we develop code to extract the instantaneous phase of different EEG rhythms at every time point.  
See http://www.scholarpedia.org/article/Hilbert_transform_for_brain_waves for details on the Hilbert transform.  
  
The *Causal phase prediction* section is based on phastimate, an autoregressive algorithm for causal phase estimation: https://github.com/bnplab/phastimate

In [1]:
%load_ext autoreload
%autoreload 2
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm, patches
import matplotlib.gridspec as gridspec
import scipy
from scipy import signal
from tqdm.auto import tqdm
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    tqdm.pandas()
from spectrum import aryule

from tbd_eeg.data_analysis.eegutils import *
from tbd_eeg.data_analysis.Utilities import utilities as utils
from tbd_eeg.data_analysis.Utilities import filters

from plot_electrodes import *

from ipympl.backend_nbagg import Canvas
Canvas.header_visible.default_value = False
%matplotlib widget

In [2]:
# Colormap assigned to each epoch label (labels match the `epochs` categories below).
epoch_cms = dict(
    pre=cm.Reds,
    iso_high=cm.PuOr,
    iso_low=cm.PuOr_r,
    early_recovery=cm.Blues,
    late_recovery=cm.Greens,
)

# Frequency bands as (low, high) cutoffs in Hz, used for band-pass filtering
# and for locating spectral peaks.
bands = dict(
    delta=(1, 4),
    theta=(5, 8),
    alpha=(9, 13),
    beta=(15, 35),
)

In [3]:
data_folder = "../tiny-blue-dot/zap-n-zip/EEG_exp/mouse521886/estim1_2020-07-16_13-37-02/experiment1/recording1/"

# set the sample_rate for all data analysis
sample_rate = 200

# load experiment metadata and eeg data
exp = EEGexp(data_folder)
eegdata = exp.load_eegdata(frequency=sample_rate, return_type='pd')

# locate valid channels (some channels can be disconnected and we want to ignore them in the analysis)
print('Identifying valid channels...')
median_amplitude = eegdata[:sample_rate*300].apply(
    utils.median_amplitude, raw=True, axis=0, distance=sample_rate
)
valid_channels = median_amplitude.index[median_amplitude < 2000].values
print('The following channels seem to be correctly connected and report valid data:')
print(list(valid_channels))

# load other data (running, iso etc)
print('Loading other data...')
running_speed = exp.load_running(return_type='pd')
iso = exp.load_analog_iso(return_type='pd')

# automatically annotate anesthesia epochs
iso_first_on = (iso>4).idxmax()
print('iso on at', iso_first_on)
iso_first_mid = ((iso[iso.index>iso_first_on]>1)&(iso[iso.index>iso_first_on]<4)).idxmax()
print('iso reduced at', iso_first_mid)
iso_first_off = (iso>1)[::-1].idxmax()
print('iso off at', iso_first_off)

# annotate artifacts with power in high frequencies
print('Annotating artifacts...')
hf_annots = pd.Series(
    eegdata[valid_channels].apply(
        find_hf_annotations, axis=0,
        sample_rate=sample_rate, fmin=300, pmin=0.25
    ).mean(axis=1),
    name='artifact'
)
recovery_first_jump = (hf_annots>4)[hf_annots.index>iso_first_off].idxmax()

epochs = pd.Series(
    index = [0, iso_first_on-0.001, iso_first_on+0.001, iso_first_mid-0.001,
             iso_first_mid+0.001, iso_first_off-0.001, iso_first_off+0.001,
             recovery_first_jump-0.001, recovery_first_jump+0.001, eegdata.index[-1]],
    data=['pre', 'pre', 'iso_high', 'iso_high', 'iso_low', 'iso_low',
          'early_recovery', 'early_recovery', 'late_recovery', 'late_recovery'],
    dtype=pd.CategoricalDtype(
        categories=['pre', 'iso_high', 'iso_low', 'early_recovery', 'late_recovery'],
        ordered=True
    )
)

The settings.xml file was not found.


Was the recording done on NP4? [y/n]  y


Experiment type: electrical stimulation.
SomnoSuite log file not found.
Identifying valid channels...
The following channels seem to be correctly connected and report valid data:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
Loading other data...
iso on at 1439.6
iso reduced at 1658.38
iso off at 2965.83
Annotating artifacts...


In [4]:
# reindex eegdata to include all stimulus information in its index
T_PRESTIM = 2.5 # length (s) of the window before stimulus onset (original note: ~4-this is the response time)
if exp.experiment_type != 'electrical stimulation':
    # Everything below needs the stimulus log; fail clearly here instead of crashing on
    # `None['id'] = ...` with an opaque TypeError.
    raise ValueError(
        f'This notebook expects an electrical stimulation experiment, got {exp.experiment_type!r}'
    )
stimuli = pd.read_csv(exp.stimulus_log_file)
stimuli['id'] = range(len(stimuli))
stimuli['window_onset'] = stimuli.onset - T_PRESTIM
# Level 0 is the time axis both on the first run (flat index) and on a re-run (the
# MultiIndex built below), so executing this cell twice gives the same result.
_time_index = eegdata.index.get_level_values(0)
_stimuli = stimuli.set_index('window_onset').reindex(_time_index, method='ffill')
# samples before the first stimulus window get "no stimulus" placeholders (id == -1)
_stimuli = _stimuli.fillna(dict(stim_type=-1, amplitude=0, onset=0, sweep=-1, id=-1))
# remaining gaps (e.g. duration/offset) are filled from the first stimulus' row
_onset = _stimuli[_stimuli.id==0].iloc[0]
_stimuli = _stimuli.fillna(_onset)
metastim = pd.MultiIndex.from_arrays(
    [_stimuli.index.rename('time'),
    _stimuli.amplitude,
    _stimuli.duration,
    _stimuli.onset,
    _stimuli.offset,
    _stimuli.sweep,
    _stimuli.id],
)
eegdata.index = metastim
eegdata = eegdata.loc[eegdata.index.dropna()]
# effective sampling rate from the actual time stamps (may differ from the requested 200 Hz)
sample_rate = 1/np.diff(eegdata.index.get_level_values('time')).mean()
eegdata

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31
time,amplitude,duration,onset,offset,sweep,id,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1
12.932482,0.0,400.0,0.00000,134.81291,-1.0,-1.0,12,7,6,11,17,16,11,22,32,70,76,88,92,89,62,48,50,59,33,31,14,21,18,12,19,5,5,7,1,5,-61,-25
12.937282,0.0,400.0,0.00000,134.81291,-1.0,-1.0,-2,2,4,-1,4,9,-1,9,28,21,32,35,45,30,29,37,34,43,33,25,17,21,19,12,14,10,8,12,12,8,21,66
12.942082,0.0,400.0,0.00000,134.81291,-1.0,-1.0,12,8,7,16,17,14,16,25,35,43,55,63,67,59,54,59,56,63,57,40,35,28,32,24,21,20,18,13,16,18,18,87
12.946882,0.0,400.0,0.00000,134.81291,-1.0,-1.0,19,11,10,22,23,20,23,30,36,7,31,52,57,49,48,52,50,57,54,35,31,26,26,21,21,15,12,11,9,11,-3,23
12.951682,0.0,400.0,0.00000,134.81291,-1.0,-1.0,2,6,11,6,10,19,8,17,29,-1,10,16,23,18,20,27,23,31,31,27,26,22,24,23,26,20,19,20,20,21,31,56
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4238.544615,50.0,400.0,4185.48398,4185.48459,2.0,899.0,-21,-6,0,-17,-18,-3,-14,0,18,14,23,20,9,19,24,16,17,-11,24,29,22,24,25,20,19,13,14,8,10,10,-12,13
4238.549415,50.0,400.0,4185.48398,4185.48459,2.0,899.0,-17,-15,-13,-18,-27,-24,-24,-23,-16,-26,-12,15,-1,11,6,-1,-5,-26,-11,0,-4,-6,1,1,-5,1,9,-9,-2,14,-15,35
4238.554215,50.0,400.0,4185.48398,4185.48459,2.0,899.0,-11,-22,-28,-9,-22,-35,-7,-18,-28,10,0,18,4,12,12,1,2,-15,-3,-12,-17,-22,-15,-21,-25,-24,-17,-30,-31,-16,-71,-32
4238.559015,50.0,400.0,4185.48398,4185.48459,2.0,899.0,3,-11,-16,0,-12,-30,-1,-13,-34,-13,-21,-20,-26,-28,-20,-29,-23,-29,-20,-28,-28,-32,-28,-28,-30,-31,-26,-27,-29,-21,-47,-43


In [5]:
# Electrode layout reference figure (numbers=True presumably draws channel numbers on the map).
f, ax = plt.subplots(1, 1, figsize=(2.3, 2.3), constrained_layout=True)
plot_electrode_map(ax, labels=False, numbers=True, box=False, s=50)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

---

# Characterizing effects of a hard boundary and predictability

In [6]:
# look at a particular channel, and prestimulus data to evaluate baseline
# id == -1 marks samples before the first stimulus window (set by the fillna in the
# reindexing cell), so this selects pre-stimulation baseline data.
tid = -1
ch = 7
data = eegdata.xs(tid, level='id')[ch]
time = data.index.get_level_values('time')
# keep only the time level so the series can be sliced in seconds with .loc
data.index = time
display(data)

time
12.932482     22
12.937282      9
12.942082     25
12.946882     30
12.951682     17
              ..
132.289457    14
132.294257    27
132.299057    -6
132.303857   -28
132.308657   -11
Name: 7, Length: 24871, dtype: int16

In [78]:
def _filter_data(data, band='delta', order=3):
    """Band-pass filter a Series, keeping its index.

    Parameters
    ----------
    data : pd.Series
        Signal to filter.
    band : str
        Key into the global ``bands`` dict giving the (low, high) cutoffs in Hz.
    order : int
        Filter order passed to ``filters.butter_bandpass_filter``.

    Returns
    -------
    pd.Series
        Filtered signal with the same index as ``data``.
    """
    low, high = bands[band]
    # the filter helper is called with a 2-D (n_samples, 1) array
    column = data.values[:, np.newaxis]
    filtered = filters.butter_bandpass_filter(column, low, high, sample_rate, order)
    return pd.Series(filtered[:, 0], index=data.index)

def _linear_predict(data_filtered, pred_len_s=0.1, buf_s=2):
    hilbert = pd.Series(signal.hilbert(data_filtered), index=data_filtered.index, name='hilbert')
    ampl = hilbert.abs().rename('amplitude') # amplitude of signal
    angl = hilbert.apply(np.angle).rename('angle') # phase of signal
    uangl = pd.Series(np.unwrap(angl), index=angl.index, name='unwrapped_angle')
    len_fit = int(pred_len_s*sample_rate)
    buf_n = int(buf_s*sample_rate)
    m, c = np.polyfit(uangl.index[buf_n:-len_fit], uangl.values[buf_n:-len_fit], 1)
    uangl_pred = pd.Series(
        m*uangl.index + c,
#         np.r_[uangl.values[:-len_fit], m*uangl.index[-len_fit:] + c],
        index=uangl.index, name='predicted_unwrapped_angle'
    )
#     print(uangl.values[buf_n:-len_fit].mean())
#     uangl_pred = uangl_pred - uangl_pred.values[buf_n:-len_fit].mean() + uangl.values[buf_n:-len_fit].mean()
#     print(uangl_pred.values[buf_n:-len_fit].mean())
    angl_pred = ((uangl_pred + np.pi) % (2 * np.pi) - np.pi).rename('predicted_angle')
    return pd.concat([hilbert, ampl, angl, uangl, uangl_pred, angl_pred], axis=1)

In [79]:
# Single-window demo: band-pass 90-95 s, fit the unwrapped phase between 92 s
# (default buf_s=2 after the start) and 94 s (pred_len_s=1 before the end), then
# compare the fitted/extrapolated phase with the Hilbert phase.
data_filtered = _filter_data(data.loc[90:95])
_ret = _linear_predict(data_filtered, 1)
f, (ax, axu) = plt.subplots(2, 1, figsize=(7, 4), tight_layout=True)
# vertical lines mark the boundaries of the fit window
ax.axvline(92, c='k')
ax.axvline(94, c='k')
_ret.angle.plot(ax=ax)
_ret.predicted_angle.plot(ax=ax)
_ret.unwrapped_angle.plot(ax=axu)
_ret.predicted_unwrapped_angle.plot(ax=axu);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [80]:
# for a chunk of time, compare actual phase (boundary far away) and phase if filtered only up to a closer boundary
# Then also use the larger chunk to filter and predict phase in region where edge effect of hilbert is not strong
def _txd_data(data, edge_s=2, pred_len_s=3, buf_s=3, band='delta', order=3):
    """Filter + Hilbert a window twice: as-is ("full") and truncated ``edge_s`` s early ("part").

    Parameters
    ----------
    data : pd.Series
        Raw signal indexed by time (s).
    edge_s : float
        How many seconds are cut from the end for the "part" version, so its right
        boundary falls inside the "full" window.
    pred_len_s, buf_s : float
        Forwarded to ``_linear_predict`` for the full window (the part window uses buf_s=0).
    band, order :
        Forwarded to ``_filter_data``.

    Returns
    -------
    pd.DataFrame
        Columns are a (``'full'``/``'part'``, quantity) MultiIndex, sorted; quantities are
        ``filtered`` plus the ``_linear_predict`` outputs.
    """
    # filter the entire window and the window truncated edge_s seconds before its end
    filtered_full = _filter_data(data, band, order).rename('filtered')
    filtered_part = _filter_data(data.loc[:data.index[-1]-edge_s], band, order).rename('filtered')
    # hilbert transform filtered timeseries (and fit the linear phase model)
    hilbert_full = _linear_predict(filtered_full, pred_len_s, buf_s)
    hilbert_part = _linear_predict(filtered_part, pred_len_s, 0)
    return pd.concat(
        [filtered_full, filtered_part, hilbert_full, hilbert_part],
        axis=1, keys=['full', 'part', 'full', 'part']
    ).sort_index(axis=1)

In [81]:
# Parameters for the single-window comparison on [85, 95] s: the "part" version is
# truncated edge_s seconds early; the linear phase fit skips the first buf_s seconds
# and leaves the last pred_len seconds for prediction.
edge_s = 3
buf_s = 3.5
pred_len = 5.5
band = 'delta'
order = 3
t_lim = (85, 95)
_comparison_data = _txd_data(data.loc[t_lim[0]:t_lim[1]], edge_s, pred_len, buf_s, band, order)
# flatten the (full/part, quantity) column MultiIndex into names such as 'full_angle'
_comparison_data.columns = _comparison_data.columns.map(lambda x: f'{x[0]}_{x[1]}')
display(_comparison_data.head())

Unnamed: 0_level_0,full_amplitude,full_angle,full_filtered,full_hilbert,full_predicted_angle,full_predicted_unwrapped_angle,full_unwrapped_angle,part_amplitude,part_angle,part_filtered,part_hilbert,part_predicted_angle,part_predicted_unwrapped_angle,part_unwrapped_angle
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
85.000051,8.407235,-2.326949,-5.768438,-5.768438-6.116104j,-0.285792,12.280579,-2.326949,7.505419,2.447397,-5.768432,-5.768432+4.801720j,2.887265,2.887265,2.447397
85.004851,8.22326,-2.459666,-6.384209,-6.384209-5.183038j,-0.206849,12.359521,-2.459666,6.527314,2.931803,-6.384202,-6.384202+1.359337j,2.969547,2.969547,2.931803
85.009651,8.655827,-2.498588,-6.927239,-6.927239-5.190057j,-0.127907,12.438464,-2.498588,7.0136,2.984496,-6.927232,-6.927232+1.097283j,3.051828,3.051828,2.984496
85.014451,9.08747,-2.519433,-7.384675,-7.384675-5.296102j,-0.048965,12.517406,-2.519433,7.401539,-3.074062,-7.384668,-7.384668-0.499453j,3.13411,3.13411,3.209124
85.019251,9.638822,-2.503907,-7.744559,-7.744559-5.738353j,0.029978,12.596348,-2.503907,7.818747,-3.003721,-7.744552,-7.744552-1.074574j,-3.066794,3.216392,3.279465


In [82]:
# how does actual phase get affected due to the edge?
# Look at the 3 s just before the truncated edge (t_lim[1]-edge_s).
# NOTE(review): if data ends slightly before t_lim[1], the last sample(s) of part_angle in
# this range are NaN -- a likely source of the RuntimeWarning from np.unwrap below.
_times = _comparison_data.loc[t_lim[1]-edge_s-3:t_lim[1]-edge_s].index
f, ax = plt.subplots(1, 1, figsize=(8, 2), tight_layout=True)
_comparison_data.full_angle.loc[_times].plot(ax=ax, label='actual phase')
_comparison_data.part_angle.loc[_times].plot(ax=ax, label='edge-affected phase')
ax2 = ax.twinx()
pd.Series(
    np.unwrap(
        _comparison_data.full_angle.loc[_times]-_comparison_data.part_angle.loc[_times]
    ), index=_times
).plot(ax=ax2, label='difference', c='k', lw=0.5)
ax.legend(fontsize=8, loc=2)
ax.set_ylabel('angle')
ax2.set_ylabel('difference');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  ddmod = mod(dd + pi, 2*pi) - pi


In [83]:
# how quickly does predicted phase deviate from actual?
# Full-window columns only, from the end of the fit buffer to the end of the window.
_times = _comparison_data.loc[t_lim[0]+buf_s:].index
f, ax = plt.subplots(1, 1, figsize=(8, 2), tight_layout=True)
_comparison_data.full_angle.loc[_times].plot(ax=ax, label='actual')
_comparison_data.full_predicted_angle.loc[_times].plot(ax=ax, label='predicted')
# black lines: prediction onset (end of the fit) and start of the fit window
ax.axvline(t_lim[1]-pred_len, c='k')
ax.axvline(t_lim[0]+buf_s, c='k')
ax2 = ax.twinx()
# unwrapped so the deviation accumulates instead of jumping by 2*pi
pd.Series(
    np.unwrap(
        _comparison_data.full_angle.loc[_times]-_comparison_data.full_predicted_angle.loc[_times]
    ), index=_times
).plot(ax=ax2, label='difference', c='k', lw=0.5)
ax.legend(fontsize=8, loc=2)
ax.set_ylabel('angle')
ax2.set_ylabel('difference');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [97]:
# repeat the above for multiple random times
def _run_comparison(t0, edge_s, buf_s, pred_len, band, order, t_tot=10):
    """Edge and prediction phase errors (radians) for the window [t0, t0 + t_tot] of ``data``.

    Parameters
    ----------
    t0 : float
        Window start (s), used to slice the global ``data`` series.
    edge_s, buf_s, pred_len, band, order :
        Forwarded to ``_txd_data``.
    t_tot : float
        Window length (s).

    Returns
    -------
    edge_effect : pd.Series
        Actual minus edge-affected phase over the 3 s before the truncated edge,
        wrapped to [-pi, pi).
    prediction_effect : pd.Series
        Unwrapped actual minus linearly predicted phase from ``t0 + buf_s`` on, offset so it
        is zero at the sample ``pred_len`` seconds before the end (the prediction onset).
    """
    t_lim = (t0, t0+t_tot)
    _comparison_data = _txd_data(data.loc[t_lim[0]:t_lim[1]], edge_s, pred_len, buf_s, band, order)
    _comparison_data.columns = _comparison_data.columns.map(lambda x: f'{x[0]}_{x[1]}')

    # Edge effect over the last 3 s before the truncated edge. The difference of two
    # wrapped phases lies in (-2*pi, 2*pi) and jumps by ~2*pi whenever they straddle the
    # +-pi cut; wrap it back to [-pi, pi) so abs()/mean() statistics are not inflated.
    _times = _comparison_data.loc[t_lim[1]-edge_s-3:t_lim[1]-edge_s].index
    _phase_diff = _comparison_data.full_angle.loc[_times]-_comparison_data.part_angle.loc[_times]
    edge_effect = pd.Series((_phase_diff.values + np.pi) % (2 * np.pi) - np.pi, index=_times)

    # Prediction effect from the end of the fit buffer onward.
    _times = _comparison_data.loc[t_lim[0]+buf_s:].index
    prediction_effect = pd.Series(
        np.unwrap(
            _comparison_data.full_angle.loc[_times]-_comparison_data.full_predicted_angle.loc[_times],
        ),
        index=_times
    )
    prediction_effect = prediction_effect - prediction_effect.values[-int(pred_len*sample_rate)]

    return edge_effect, prediction_effect

In [184]:
# Monte-Carlo over random windows: collect edge and prediction phase errors (degrees).
band = 'alpha'
order = 3
edge_s = 3
buf_s = 3.75
pred_len = 6
t_tot = 10
n_samples = 1000
# Seeded generator so the random window onsets (and the figures below) are reproducible
# on Restart & Run All; onsets are drawn uniformly in [20, 120) s.
_rng = np.random.default_rng(0)
t0s = 20 + _rng.random(n_samples)*100
edge_effects, pred_effects = {}, {}
for t0 in tqdm(t0s):
    _e, _p = _run_comparison(t0, edge_s, buf_s, pred_len, band, order, t_tot)
    # align all windows on a common axis: edge effects relative to the truncated edge,
    # prediction effects relative to the prediction onset (t0 + t_tot - pred_len)
    _e.index = _e.index-_e.index[-1]
    _p.index = _p.index - (t0+t_tot-pred_len)
    # the -3 s span matches the fixed 3-s edge window in _run_comparison
    edge_effects[t0] = _e.reindex(np.linspace(-3, 0, 601), method='nearest')
    pred_effects[t0] = _p.reindex(np.linspace(buf_s-(t_tot-pred_len), pred_len, 1000), method='nearest')
edge_effects = pd.concat(edge_effects, axis=1)*180/np.pi
pred_effects = pd.concat(pred_effects, axis=1)*180/np.pi

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [185]:
# Mean (over random windows) absolute phase deviation, in degrees, caused by the
# right edge, as a function of time from that edge.
f, ax = plt.subplots(1, 1, figsize=(6, 2), tight_layout=True)
# edge_effects.mean(1).plot(yerr=edge_effects.std(1), alpha=0.6, ax=ax)
# ax.axhline(20, c='k', lw=0.5)
# ax.axhline(-20, c='k', lw=0.5)
# ax.set_ylabel('phase deviation\n(mean +- SD)')
edge_effects.abs().mean(1).plot(alpha=0.6, ax=ax)
ax.set_xlabel('time from right edge (s)')
ax.set_ylabel('mean absolute\ndeviation')
ax.annotate(f'{band} band filtering + Hilbert', (0.01, 0.95), xycoords='axes fraction', va='top');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [186]:
# Mean absolute deviation (degrees) of the linearly extrapolated phase from the actual
# phase, vs. time from prediction onset (only up to 1 s after onset).
f, ax = plt.subplots(1, 1, figsize=(6, 2), tight_layout=True)
pred_effects.abs().loc[:1].mean(1).plot(
#     yerr=pred_effects.loc[:1].std(1),#/np.sqrt(n_samples),
    alpha=0.6, ax=ax
)
ax.set_xlim(-0.1, 0.25)
ax.set_ylim(-10, 70)
ax.axvline(0, c='k')
ax.set_xlabel('time from prediction onset (s)')
# ax.axhline(20, c='k', lw=0.5)
# ax.axhline(-20, c='k', lw=0.5)
# ax.set_ylabel('phase deviation\n(mean +- SD)')
ax.set_ylabel('mean absolute\ndeviation')
# fit length = window minus prediction horizon minus start buffer
ax.annotate(f'{band} linear phase extrapolation (using {t_tot-(pred_len+buf_s)}s of data)', (0.01, 0.05), xycoords='axes fraction', va='bottom');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [188]:
def _plot_effects(axes, band, order=3, edge_s=3, buf_s=3.75, pred_len=6, t_tot=10, n_samples=1000, seed=None):
    """Monte-Carlo estimate of edge and prediction phase errors for one band, plotted on ``axes``.

    Draws ``n_samples`` random window onsets in [20, 120) s and runs ``_run_comparison`` on
    each. Plots the mean absolute deviation (degrees) of the edge-affected phase on
    ``axes[0]`` (vs. time from the truncated edge) and of the linearly extrapolated phase
    on ``axes[1]`` (vs. time from prediction onset, up to 1 s).

    Parameters
    ----------
    axes : sequence of two matplotlib Axes
    band : str
        Key into ``bands``; also used as the legend label and progress-bar description.
    order, edge_s, buf_s, pred_len, t_tot :
        Forwarded to ``_run_comparison``.
    n_samples : int
        Number of random windows.
    seed : int or None
        If given, onsets come from ``np.random.default_rng(seed)`` (reproducible);
        if None the global NumPy RNG is used, as before.

    Returns
    -------
    (pd.DataFrame, pd.DataFrame)
        Edge and prediction deviations in degrees, one column per window onset.
    """
    if seed is None:
        t0s = 20 + np.random.random(n_samples)*100
    else:
        t0s = 20 + np.random.default_rng(seed).random(n_samples)*100
    edge_effects, pred_effects = {}, {}
    for t0 in tqdm(t0s, desc=band):
        _e, _p = _run_comparison(t0, edge_s, buf_s, pred_len, band, order, t_tot)
        # align windows: edge effects relative to the truncated edge, prediction effects
        # relative to the prediction onset
        _e.index = _e.index-_e.index[-1]
        _p.index = _p.index - (t0+t_tot-pred_len)
        # the -3 s span matches the fixed 3-s edge window in _run_comparison
        edge_effects[t0] = _e.reindex(np.linspace(-3, 0, 601), method='nearest')
        pred_effects[t0] = _p.reindex(np.linspace(buf_s-(t_tot-pred_len), pred_len, 1000), method='nearest')
    edge_effects = pd.concat(edge_effects, axis=1)*180/np.pi
    pred_effects = pd.concat(pred_effects, axis=1)*180/np.pi

    edge_effects.abs().mean(1).plot(alpha=0.6, ax=axes[0], label=band)
    pred_effects.abs().loc[:1].mean(1).plot(alpha=0.6, ax=axes[1], label=band)
    return edge_effects, pred_effects

# Edge and prediction errors for every band side by side.
# NOTE: the loop rebinds the global `band`; after this cell it is the last key of `bands`.
f, axes = plt.subplots(1, 2, figsize=(8, 4), tight_layout=True)
for band in bands:
    _plot_effects(axes, band)
axes[0].set_xlabel('time from right edge (s)')
axes[0].set_ylabel('mean absolute\ndeviation')
axes[0].legend(loc=2)

axes[1].set_xlim(-0.1, 0.25)
axes[1].set_ylim(-10, 70)
axes[1].axvline(0, c='k')
axes[1].set_xlabel('time from prediction onset (s)')
axes[1].set_ylabel('mean absolute\ndeviation');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




Text(0, 0.5, 'mean absolute\ndeviation')

## Phase autocorrelation

In [190]:
# inspect the single-channel baseline series `data` selected earlier
data

time
12.932482     22
12.937282      9
12.942082     25
12.946882     30
12.951682     17
              ..
132.289457    14
132.294257    27
132.299057    -6
132.303857   -28
132.308657   -11
Name: 7, Length: 24871, dtype: int16

In [204]:
# Autocorrelation of the Hilbert phase for lags of 0..99 samples, per band,
# plotted against negative lag in seconds.
# NOTE(review): Series.autocorr is a linear (Pearson) correlation of the wrapped angle,
# not a circular statistic; also, the loop rebinds the global `band`.
f, ax = plt.subplots(1, 1, figsize=(4, 3), tight_layout=True)
for band in bands:
    _data_filtered = _filter_data(data, band)
    _hilbert = _linear_predict(_data_filtered)
    _x = np.arange(100)
    _y = []
    for __x in _x:
        _y.append(_hilbert.angle.autocorr(__x))
    _y = pd.Series(_y[::-1], index=-_x[::-1]/sample_rate)
    _y.plot(ax=ax, label=band)
ax.legend(loc=2)
ax.set_xlabel('lag (s)')
ax.set_ylabel('phase autocorrelation');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Phase detection - let's look at theta

In [6]:
def _reset_phase(df, i=0):
    return df - df.iloc[i]

def _wrap(col):
    return (col + np.pi) % (2 * np.pi) - np.pi

In [None]:
# frequency band to work with, as (low, high) in Hz
# NOTE(review): these duplicate the `bands` dict; `fband = bands['theta']` would keep them in sync.
# fband = (15, 35) # beta
# fband = (9, 13) # alpha
fband = (5, 8) # theta
# fband = (1, 4) # delta

In [None]:
# Band-pass filter every valid channel (column-wise, with a tqdm progress bar).
eegdfltd = eegdata[valid_channels].progress_apply(
    filters.butter_bandpass_filter,
    lowcut=fband[0], highcut=fband[1],
    sampling_frequency=sample_rate, filter_order=3
)
# example channels used in the plots below
ch1 = 10
ch2 = 15

In [11]:
# # visualize a couple of channels
# ch1 = 0
# ch2 = 31
# _win = (819, 1821)
# f, ax = plt.subplots(1, 1, figsize=(9, 2), constrained_layout=True)
# eegdata.loc[_win[0]:_win[1], ch1].plot(ax=ax, alpha=0.3, label='_')
# eegdata.loc[_win[0]:_win[1], ch2].plot(ax=ax, alpha=0.3, label='_')
# # ax.set_prop_cycle(None)
# ax2 = ax.twinx()
# eegdfltd.loc[_win[0]:_win[1], ch1].plot(ax=ax2, alpha=0.9)
# eegdfltd.loc[_win[0]:_win[1], ch2].plot(ax=ax2, alpha=0.9)
# ax2.legend();

In [12]:
# # visualize all filtered channels
# _win = (1200, 2200)
# f, ax = plt.subplots(1, 1, figsize=(9, 2), constrained_layout=True)
# ax.set_prop_cycle(c=np.array(exp.ch_coordinates.gid.map(lambda x: cm.Paired(x/11, 0.9))))
# eegdfltd.loc[_win[0]:_win[1]].plot(ax=ax, legend=None)
# ax.set_title('All channels, filter band ({0:d}, {1:d})'.format(*fband));

In [13]:
# Hilbert transform
# analytic signal of every filtered channel, computed along time (axis 0)
_hil = signal.hilbert(eegdfltd, axis=0)
hilbert = pd.DataFrame(_hil, columns=eegdfltd.columns, index=eegdfltd.index)
ampl = hilbert.abs() # amplitude of signal
angl = hilbert.apply(np.angle) # phase of signal
uangl = angl.apply(np.unwrap) # unwrapped phase of signal

In [14]:
# example result of Hilbert transform
_win = (1200, 1210)
f, (ax, axp) = plt.subplots(2, 1, figsize=(11, 5), sharex=True, constrained_layout=True)
# faint filtered traces (label '_' hides them from the legend)
eegdfltd.loc[_win[0]:_win[1], ch1].plot(ax=ax, alpha=0.3, label='_')
eegdfltd.loc[_win[0]:_win[1], ch2].plot(ax=ax, alpha=0.3, label='_')
ax.set_prop_cycle(None)  # reuse the same colors for the matching envelopes
ampl.loc[_win[0]:_win[1], ch1].plot(ax=ax, alpha=0.6)
ampl.loc[_win[0]:_win[1], ch2].plot(ax=ax, alpha=0.6)
# the across-channel summaries are medians, so label them as such
ampl.loc[_win[0]:_win[1]].median(axis=1).plot(ax=ax, alpha=0.6, c='k', label='median amplitude')
angl.loc[_win[0]:_win[1], [ch1]].plot(ax=axp, alpha=0.6)
angl.loc[_win[0]:_win[1], [ch2]].plot(ax=axp, alpha=0.6)
# NOTE(review): a linear median of wrapped phases is distorted near +-pi; a circular mean
# (angle of the mean unit phasor) may be more appropriate.
angl.loc[_win[0]:_win[1]].median(axis=1).plot(ax=axp, alpha=0.3, c='k', label='median phase')
ax.legend(loc=1)
axp.legend(loc=1)
ax.set_ylabel('signal and amplitude')
axp.set_ylabel('phase')
axp.set_xlabel('time (s)');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Overall, phases drift quickly
However, this says little on its own, because phase discontinuities accumulate over time. We should instead look for phase synchrony within short windows.

In [18]:
# phase drift between channels over time
# Unwrapped phases are re-zeroed per stimulus id (at the second-to-last sample of each
# group), then expressed relative to the across-channel mean.
# NOTE(review): after the reindexing cell, eegdata (and hence uangl) has a MultiIndex;
# the `reindex(uangl.index, ...)` and time slicing below assume a flat time index -- TODO confirm.
_win = (0, 120)
_sig = uangl.join(stimuli.set_index('onset').reindex(uangl.index, method='nearest').id).groupby('id').apply(_reset_phase, i=-2)[valid_channels]
# (the commented line below refers to `_phase_reset`, which is not defined; `_reset_phase` was probably meant)
# _sig = _phase_reset(uangl.subtract(uangl.mean(axis=1), axis=0))
f, (ax, axe) = plt.subplots(1, 2, figsize=(7, 3), constrained_layout=True)
# color each channel by its electrode group id
ax.set_prop_cycle(c=np.array(exp.ch_coordinates.gid.map(lambda x: cm.Paired(x/11, 0.9))))
# angl.subtract(angl.mean(axis=1), axis=0).loc[_win[0]:_win[1]].plot(ax=ax, alpha=0.6, legend=False) #.loc[_win[0]:_win[1]]
_reset_phase(_sig.subtract(_sig.mean(axis=1), axis=0).loc[_win[0]:_win[1]]).plot(ax=ax, alpha=0.6, legend=False)
ax.set_xlabel('time (s)')
ax.set_ylabel('phase offset from mean (rad)\n(cumulative)');
plot_electrode_map(axe, labels=False, s=50)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Phase synchrony in short windows

In [19]:
# window the unwrapped angle
# window number = integer part of (time / window size)
# NOTE(review): this assumes uangl has a flat numeric time index; it also adds a 'win'
# column to uangl in place, so re-running earlier cells that use uangl sees that column.
ps_winsize_s = 1
ps_wins = np.array(uangl.index / ps_winsize_s, dtype=int)
uangl['win'] = ps_wins

# compute relative phases within each window
# each window's phases are re-zeroed at its first sample and wrapped to [-pi, pi)
local_rel_phases = uangl.groupby('win').apply(lambda df: _wrap(_reset_phase(df.drop('win', axis=1)))).join(uangl.win)

In [20]:
# example evolution of phases within the window
# three hand-picked windows; phases are unwrapped within the window and shown relative to the channel mean
f, axes = plt.subplots(1, 3, figsize=(11, 3), tight_layout=True, sharey=True)
for ax, w in zip(axes, [16, 13, 14]):
    _phases = local_rel_phases[local_rel_phases.win==w][valid_channels].apply(np.unwrap)
#     _phases.plot(ax=ax, legend=False, alpha=0.6)
    _phases.subtract(_phases.mean(axis=1), axis=0).plot(ax=ax, legend=False, alpha=0.6)
axes[1].set_xlabel('time (s)')
axes[0].set_ylabel('phase (relative to mean)');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [21]:
# we can look at the dispersion at the last time point as a proxy for how synchronized or not the waves were
def _compute_phase_dispersion(rel_phases):
    """Standard deviation across valid channels of the unwrapped relative phase at a window's last sample."""
    unwrapped = rel_phases[valid_channels].apply(np.unwrap)
    final_phases = unwrapped.iloc[-1]
    return final_phases.std()

phase_dispersion = local_rel_phases.groupby('win').apply(_compute_phase_dispersion)

In [22]:
# how does phase dispersion look over time?
# x values are window numbers; they equal seconds only because ps_winsize_s == 1.
f, ax = plt.subplots(1, 1, figsize=(8, 2), tight_layout=True)
phase_dispersion.plot(ax=ax)
ax.set_xlabel('time (s)')
ax.set_ylabel('dispersion')
# red lines: isoflurane on / reduced / off, converted to window numbers
ax.vlines(np.array([iso_first_on, iso_first_mid, iso_first_off])/ps_winsize_s, 0, 7, colors='r')
running_speed.plot(ax=ax.twinx(), c='k', lw=0.5);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

---

# Causal phase prediction

In [6]:
# Band analysed in this section; `_bandpass_filter` below reads this global.
band = 'theta'

In [7]:
# look at a particular channel, and prestimulus data to evaluate baseline
# Unlike the earlier selection, `data` keeps its remaining MultiIndex levels (time ... sweep).
tid = -1
ch = 7
data = eegdata.xs(tid, level='id')[ch]
time = data.index.get_level_values('time')
display(data)
# hard-coded "onset" (s) separating the data used for spectral estimation below
t_onset = 120.0#data.index.get_level_values('onset').unique()[0]
f, ax = plt.subplots(1, 1, figsize=(8, 2), tight_layout=True)
ax.plot(data.index.get_level_values('time'), data, label=f'ch{ch}')
ax.axvline(t_onset, c='k')
ax.set_xlabel('time (s)')
ax.set_ylabel('voltage');

time        amplitude  duration  onset  offset     sweep
12.932482   0.0        400.0     0.0    134.81291  -1.0     22
12.937282   0.0        400.0     0.0    134.81291  -1.0      9
12.942082   0.0        400.0     0.0    134.81291  -1.0     25
12.946882   0.0        400.0     0.0    134.81291  -1.0     30
12.951682   0.0        400.0     0.0    134.81291  -1.0     17
                                                            ..
132.289457  0.0        400.0     0.0    134.81291  -1.0     14
132.294257  0.0        400.0     0.0    134.81291  -1.0     27
132.299057  0.0        400.0     0.0    134.81291  -1.0     -6
132.303857  0.0        400.0     0.0    134.81291  -1.0    -28
132.308657  0.0        400.0     0.0    134.81291  -1.0    -11
Name: 7, Length: 24871, dtype: int16

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Step 1: Determine peak frequency and SNR

In [8]:
prestim_data = data.loc(axis=0)[pd.IndexSlice[:t_onset]] # prestimulus data

# obtain a power spectral density estimate (Welch's method)
_frequencies, _power = signal.welch(prestim_data, fs=sample_rate)
# find local maxima of the PSD
_idx_peaks = signal.find_peaks(_power)[0]
# keep the peaks strictly inside the band of interest (stays an integer array even if empty)
_in_band = (_frequencies[_idx_peaks] > bands[band][0]) & (_frequencies[_idx_peaks] < bands[band][1])
_idx_pois = _idx_peaks[_in_band]
if _idx_pois.size == 0:
    raise ValueError(f'No spectral peak found inside the {band} band {bands[band]} Hz')
_idx_peak = _power[_idx_pois].argmax()  # position within _idx_pois, NOT within _power
# map back to the full spectrum before indexing _power / _frequencies
_peak_power = _power[_idx_pois[_idx_peak]]
peak_frequency = _frequencies[_idx_pois[_idx_peak]]
# TODO: Evaluate SNR and test for goodness of signal before proceeding
# The peak seems to be very strong, so I am skipping this for now

# plotting
f, ax = plt.subplots(1, 1, figsize=(4, 3), tight_layout=True)
ax.plot(_frequencies, _power)
ax.set_xlabel('frequency (Hz)')
ax.set_ylabel('PSD')
ax.set_xlim(-1, 50)
# circle every PSD peak; the selected in-band peak is red
ax.scatter(
    _frequencies[_idx_peaks], _power[_idx_peaks], facecolors='none',
    edgecolors=['r' if x==_idx_pois[_idx_peak] else 'b' for x in _idx_peaks]
)
for _name, _band in bands.items():
    ax.axvspan(*_band, alpha=0.15, facecolor='r' if _name==band else 'b')
    ax.annotate(_name, (_band[0], 0), rotation=90)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Step 2: Determine phase and amplitude using the entire recording

In [9]:
# Step 2 is not implemented yet: only the linear detrend is computed (result is a NumPy array).
_data = signal.detrend(data)
# generate a set of filters, filter the signal using these, then take Hilbert transform, look at mean and variance of estimated phase

## Step 3: Optimize parameters

In [10]:
# to be implemented

## Step 4: Causal phase estimation

In [11]:
# how does filter order affect the filtered signal?
# The trace labelled 'raw' is the linearly detrended signal.
_data = signal.detrend(data)
_frequencies, _power = signal.welch(_data, fs=sample_rate)

f, (ax, ax2) = plt.subplots(
    1, 2, figsize=(12, 2), tight_layout=True,
    gridspec_kw=dict(width_ratios=[4, 1])
)
ax.plot(
    data.index.get_level_values('time'), _data,
    c='k', lw=0.6, alpha=0.5, label='raw'
)
ax2.plot(_frequencies, _power, c='k', alpha=0.6)
for filter_order in range(1, 8):
    _data_filtered = filters.butter_bandpass_filter(
        _data, bands[band][0], bands[band][1],
        sample_rate, filter_order
    )
    # skip orders whose output contains NaN (presumably a numerically unstable filter design)
    if np.isnan(_data_filtered).sum()>0:
        continue
    _frequencies, _power = signal.welch(_data_filtered, fs=sample_rate)
    ax.plot(
        data.index.get_level_values('time'), _data_filtered,
        alpha=0.5, label=filter_order
    )
    ax2.plot(_frequencies, _power, alpha=0.6)
ax.legend(ncol=2, loc=1, fontsize=8)
ax2.set_xlim(bands[band][0]-1, bands[band][1]+1)
ax2.set_ylim(1, 100)
ax2.set_yscale('log')
ax.set_xlabel('time (s)')
ax.set_ylabel('signal')
ax2.set_xlabel('frequency (Hz)')
ax2.set_ylabel('PSD');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [12]:
# filter the data
FILTER_ORDER = 2

def _bandpass_filter(data, filt=None):
    """Detrend ``data`` and band-pass it to the current global ``band``.

    ``filt`` is accepted but currently unused: the standard Butterworth helper from
    ``filters`` (implemented by Irene) is always applied with ``FILTER_ORDER``.
    Returns a Series with the same index as ``data``.
    """
    lowcut, highcut = bands[band]
    detrended = signal.detrend(data)
    filtered = filters.butter_bandpass_filter(
        detrended, lowcut, highcut, sample_rate, FILTER_ORDER
    )
    return pd.Series(filtered, index=data.index)

_data_filtered = _bandpass_filter(data.droplevel(list(range(1, 6))))

In [13]:
# fit AR parameters
def _fit_aryule(data_filtered, arlen_ms=50):
    """Fit an autoregressive model with ``spectrum.aryule`` (Yule-Walker).

    The model order is ``arlen_ms`` milliseconds worth of samples at the global
    ``sample_rate``. The returned coefficients are reversed and negated, which is the
    form ``_predict`` consumes via ``(x[i-order:i] * rho).sum()``.

    Returns
    -------
    (rho, P, k)
        Transformed AR coefficients plus the second and third outputs of ``aryule``
        unchanged.
    """
    order = int(arlen_ms*sample_rate/1000)
    ar_coeffs, noise_var, reflection = aryule(data_filtered, order)
    return -1*ar_coeffs[::-1], noise_var, reflection

# Plot the fitted AR coefficients against their lag. `_predict` multiplies rho[j] with
# x[i - len(rho) + j], i.e. rho[j] belongs to lag len(rho) - j samples (the last
# coefficient is the 1-sample lag), so place each one at -lag/sample_rate (in ms).
_rho, _P, _K = _fit_aryule(_data_filtered, arlen_ms=1000)
f, ax = plt.subplots(1, 1, figsize=(4, 2), tight_layout=True)
ax.plot(-np.arange(len(_rho), 0, -1)/sample_rate*1000, _rho, 'o')
ax.set_xlabel('Regression time (ms)')
ax.set_ylabel('AR Coefficient');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
# predict timeseries
def _predict(data_filtered, rho, prediction_len_ms=200, arlen_ms=50):
    """Extend ``data_filtered`` with an autoregressive forecast.

    Appends ``int(prediction_len_ms/1000*sample_rate) + 1`` samples. Each new sample is
    ``(previous `order` samples * rho).sum()``, where ``order`` is ``arlen_ms``
    milliseconds of samples at the global ``sample_rate``; ``rho`` must have that length.
    New samples are time-stamped after the last input time at the mean input spacing.

    Returns
    -------
    pd.Series
        Original samples followed by the forecast.
    """
    order = int(arlen_ms*sample_rate/1000)
    n_new = int(prediction_len_ms/1000*sample_rate)+1
    t = data_filtered.index
    step = np.diff(t).mean()
    extended = np.concatenate([np.asarray(data_filtered), np.zeros(n_new)])
    n_obs = len(data_filtered)
    # recursive: each forecast feeds into the next one
    for k in range(n_obs, n_obs + n_new):
        extended[k] = (extended[k-order:k]*rho).sum()
    future_t = np.arange(1, n_new + 1)*step + t[-1]
    return pd.Series(extended, index=t.append(pd.Index(future_t)))

# Forecast 1 s beyond a 2-s window (127-129 s) and overlay it on the filtered data.
# NOTE(review): _rho was fitted on the whole filtered series (including the forecast
# period), so this demo is not strictly causal; `arywpredict` below fits only on past data.
_twin = (127, 129)
_data_with_predictions = _predict(_data_filtered.loc[_twin[0]:_twin[1]], _rho, prediction_len_ms=1000, arlen_ms=1000)
f, ax = plt.subplots(1, 1, figsize=(10, 2), tight_layout=True)
_data_filtered.loc[_twin[0]:].plot(ax=ax, label='data')
_data_with_predictions.plot(ax=ax, label='prediction')
# black line: end of the data used as forecast input
ax.axvline(_twin[1], c='k', lw=0.5)
ax.set_xlim(_twin[0], _data_with_predictions.index[-1])
ax.legend(loc=2);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# full function to predict using yule-walker
def arywpredict(data, filt=None, redge_ms=500, arlen_ms=1000, prediction_len_ms=500):
    """Causal AR (Yule-Walker) forecast of band-passed ``data``.

    The whole series is band-pass filtered, but the AR model is fit on, and the forecast
    started from, only the filtered samples up to ``redge_ms`` ms before the last time
    point, so the final ``redge_ms`` of filtered data can serve as a reference.

    Returns
    -------
    (pd.Series, pd.Series)
        The truncated filtered series extended by the forecast, and the full filtered series.
    """
    filtered = _bandpass_filter(data, filt=filt)
    cutoff = data.index[-1] - redge_ms/1000
    history = filtered.loc[:cutoff]
    rho, _, _ = _fit_aryule(history, arlen_ms=arlen_ms)
    forecast = _predict(
        history, rho,
        prediction_len_ms=prediction_len_ms, arlen_ms=arlen_ms
    )
    return forecast, filtered

In [16]:
def hilbert_tx(series):
    """Return the analytic signal of ``series`` with its amplitude and phase.

    Parameters
    ----------
    series : pd.Series
        Real-valued signal; the transform is taken along its single axis.

    Returns
    -------
    pd.DataFrame
        Same index as ``series``, with columns
        ``hilert`` (complex analytic signal from ``scipy.signal.hilbert``; the
        column name is kept as spelled for existing callers),
        ``amplitude`` (its magnitude),
        ``angle`` (its wrapped phase in radians) and
        ``uangl`` (the phase after ``np.unwrap``).
    """
    analytic = pd.Series(signal.hilbert(series, axis=0), index=series.index, name='hilert')
    amplitude = analytic.abs().rename('amplitude')
    phase = pd.Series(np.angle(analytic.to_numpy()), index=analytic.index, name='angle')
    unwrapped = pd.Series(np.unwrap(phase.to_numpy()), index=phase.index, name='uangl')
    return pd.concat([analytic, amplitude, phase, unwrapped], axis=1)

# Worked example: one AR forecast on channel 7, comparing the Hilbert
# amplitude and phase of the forecast with those of the filtered data.
# NOTE: this cell rebinds notebook globals used by earlier cells (data, time,
# _data_filtered, _data_with_predictions, arlen_ms, ...), so re-running an
# earlier cell afterwards operates on this window instead.
tid = -1  # rows whose 'id' index level is -1 (presumably the unstimulated segment — TODO confirm)
ch = 7
data = eegdata.xs(tid, level='id')[ch]
time = data.index.get_level_values('time')
t_onset = 120.0  # not used in this cell; alternative: data.index.get_level_values('onset').unique()[0]

_tpred = 16  # time (s) at which the forecast should start
arlen_ms = 2000
predlen_ms = 1000
redge_ms = 1000

# Plots start 1.5 AR-lengths before the forecast start (_tmax is not used below).
_tmin = _tpred-arlen_ms*1.5/1000
_tmax = _tpred+predlen_ms*1.5/1000
# Pass 5 AR-lengths of history plus redge_ms of "future"; arywpredict withholds
# the last redge_ms, so the forecast starts at about _tpred.
# droplevel removes index levels 1-5, presumably leaving only time — confirm.
_data_with_predictions, _data_filtered = arywpredict(
    data.droplevel(list(range(1, 6))).loc[_tpred-arlen_ms*5/1000:_tpred+redge_ms/1000],
    redge_ms=redge_ms, filt=None, arlen_ms=arlen_ms, prediction_len_ms=predlen_ms
)
_hilbert_real = hilbert_tx(_data_filtered).loc[_tmin:]
_hilbert_pred = hilbert_tx(_data_with_predictions).loc[_tmin:]
f, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 3), tight_layout=True, sharex=True)
_data_filtered[_tmin:].plot(ax=ax1, label='_filtered data', alpha=0.6)
_data_with_predictions[_tmin:].plot(ax=ax1, label='_prediction', alpha=0.6)
# Restart the colour cycle so each amplitude envelope reuses its trace's colour.
ax1.set_prop_cycle(None)
_hilbert_real.amplitude[_tmin:].plot(ax=ax1, label='actual')
_hilbert_pred.amplitude[_tmin:].plot(ax=ax1, label='predicted')
_hilbert_real.angle[_tmin:].plot(ax=ax2, marker=None, label='actual')
_hilbert_pred.angle[_tmin:].plot(ax=ax2, marker=None, label='predicted')
# Vertical line = requested forecast start.
ax1.axvline(_tpred, c='k')
ax2.axvline(_tpred, c='k')
ax1.legend(loc=2, fontsize=9)
ax2.legend(loc=2, fontsize=9)
ax2.set_xlabel('time (s)');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Characterize accuracy of prediction

In [17]:
def _get_phase_amplitude_accuracy(T, ch=7, arlen_ms=1000, predlen_ms=1000, redge_ms=1000):
    """Measure how far an AR forecast started at time ``T`` drifts from the data.

    A window of channel ``ch`` from the ``id == -1`` rows of the global
    ``eegdata`` (from ``T - 5*arlen_ms`` to ``T + redge_ms``) is passed to
    ``arywpredict``. That function withholds the last ``redge_ms``, so the
    forecast starts at approximately ``T``. The forecast and the filtered data
    are both Hilbert-transformed, and their unwrapped phases are re-referenced
    to zero at the sample nearest ``T``. As a naive baseline, a straight line
    fitted to the pre-``T`` unwrapped phase is extrapolated forward.

    Parameters
    ----------
    T : float
        Forecast start, in the units of the time index (seconds in this notebook).
    ch : int
        Channel column of ``eegdata``.
    arlen_ms, predlen_ms, redge_ms : float
        AR window length, forecast length and withheld right edge, all in ms
        (forwarded to ``arywpredict``).

    Returns
    -------
    pd.DataFrame
        Indexed by time since ``T`` in seconds, rounded to the nearest ms, with
        these columns:

        - ``signal``: forecast minus filtered voltage.
        - ``amplitude``: forecast minus actual Hilbert amplitude, divided by the
          mean actual amplitude.
        - ``uangl``: AR forecast minus actual unwrapped phase (rad).
        - ``linear``: linear extrapolation minus actual unwrapped phase (rad).
    """
    tid = -1
    # Keep only index level 0 (presumably time) so .loc slices by time.
    data = eegdata.xs(tid, level='id')[ch].droplevel(list(range(1, 6)))
    window = data.loc[T - arlen_ms * 5 / 1000:T + redge_ms / 1000]
    pred_signal, filt_signal = arywpredict(
        window, redge_ms=redge_ms, filt=None, arlen_ms=arlen_ms, prediction_len_ms=predlen_ms
    )
    hil_real = hilbert_tx(filt_signal)
    hil_pred = hilbert_tx(pred_signal)
    # Re-reference each unwrapped phase so it is 0 at the sample nearest T.
    for hil in (hil_real, hil_pred):
        hil['uangl'] = hil.uangl - hil.uangl.reindex([T], method='nearest').values[0]

    # Baseline: extrapolate a straight line fitted to the unwrapped phase before T.
    past_phase = hil_pred.uangl.loc[:T]
    slope, intercept = np.polyfit(past_phase.index, past_phase.values, 1)
    hil_real['linear'] = hil_real.uangl
    hil_pred['linear'] = hil_pred.uangl
    future_times = hil_pred.loc[T:].index.to_numpy()
    hil_pred.loc[T:, 'linear'] = slope * future_times + intercept

    hil_diff = hil_pred.loc[T:] - hil_real.reindex(hil_pred.index, method='nearest').loc[T:]
    hil_diff['amplitude'] = hil_diff.amplitude / hil_real.amplitude.mean()
    sig_diff = pred_signal.loc[T:] - filt_signal.reindex(pred_signal.index, method='nearest').loc[T:]
    diff = pd.concat(
        [sig_diff.rename('signal'), hil_diff.amplitude, hil_diff.uangl, hil_diff.linear],
        axis=1,
    )
    # Round (not truncate) to whole ms. Float noise such as 4.9999999 ms must map
    # to 5 ms; otherwise the indices of forecasts started at different T stop
    # lining up when their results are concatenated.
    diff.index = np.round(1000 * (diff.index.to_numpy() - T)).astype(int) / 1000
    # Every row already satisfies t >= T, so this keeps all rows.
    return diff.loc[-1:]

# Errors of a single forecast started at t = 100 s. Only as many columns of
# _diff as there are panels get plotted, because zip stops at the shorter input.
_diff = _get_phase_amplitude_accuracy(100)
f, axes = plt.subplots(1, 3, figsize=(9, 2), tight_layout=True)
for c, ax in zip(_diff.columns, axes):
    _diff[c].plot(ax=ax)
    ax.set_ylabel(f'delta {c}')
axes[1].set_xlabel('time from starting prediction (s)');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
# Draw 500 distinct prediction start times on a 0.1 s grid within 20-120 s, the
# initial no-stimulation period. A fixed seed makes the run reproducible.
# Sampling without replacement guarantees 500 distinct keys in _diffs; with
# replacement, duplicate times collapse into a single dict entry.
_rng = np.random.default_rng(0)
T_preds = 20 + _rng.choice(1000, size=500, replace=False) / 1000 * 100
_diffs = {}
for T in tqdm(sorted(T_preds)):
    _diffs[T] = _get_phase_amplitude_accuracy(T)
# Forecasts can have slightly different relative-time grids; ffill/bfill fill
# the gaps left by aligning them on the union index.
_diffs = pd.concat(_diffs, axis=1).ffill().bfill()
_diffs = _diffs.swaplevel(axis=1)
# Phase errors: radians -> degrees (exact conversion, rather than /3.14*180).
_diffs['uangl'] = np.rad2deg(_diffs['uangl'])
_diffs['linear'] = np.rad2deg(_diffs['linear'])

HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




In [19]:
# Median absolute error across all prediction start times, one panel per error type.
f, axes = plt.subplots(1, 3, figsize=(9, 2), tight_layout=True, sharex=True)
for c, ax in zip(_diffs.columns.levels[0], axes):
    (
        _diffs[c]
        .abs()
        .median(axis=1)
        .plot(ax=ax)
    )
    ax.set_ylabel(f'delta {c}')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [20]:
# Median absolute phase error (degrees) across all prediction start times.
# _diffs.uangl is (AR prediction - actual) and _diffs.linear is
# (linear extrapolation - actual). Their difference would compare the two
# predictors with each other rather than either one with the data.
f, ax = plt.subplots(1, 1, figsize=(4, 3), tight_layout=True)
ax.plot(_diffs.uangl.abs().median(axis=1), label='AR prediction-actual')
ax.plot(_diffs.linear.abs().median(axis=1), label='linear prediction-actual')
ax.set_xlabel('time from starting prediction (s)')
ax.set_ylabel('|phase error| (deg)')
ax.legend(loc=2, fontsize=9);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Test phase detection algorithm on sample data

In [21]:
# First ten stimuli delivered at amplitude 100.
stimuli.loc[stimuli['amplitude'] == 100].head(10)

Unnamed: 0,stim_type,amplitude,duration,onset,offset,sweep,id,window_onset
3,biphasic,100,400.0,147.04601,147.04661,0,3,144.54601
4,biphasic,100,400.0,151.26887,151.26948,0,4,148.76887
5,biphasic,100,400.0,154.91916,154.91976,0,5,152.41916
9,biphasic,100,400.0,171.72105,171.72165,0,9,169.22105
10,biphasic,100,400.0,176.02181,176.02242,0,10,173.52181
15,biphasic,100,400.0,195.45212,195.45272,0,15,192.95212
17,biphasic,100,400.0,203.01729,203.01789,0,17,200.51729
24,biphasic,100,400.0,231.21529,231.21589,0,24,228.71529
28,biphasic,100,400.0,246.58409,246.58469,0,28,244.08409
30,biphasic,100,400.0,254.73616,254.73676,0,30,252.23616


In [22]:
# look at a particular channel and stimulus: start an AR forecast delay_ms
# before the onset of stimulus `tid` and compare it with what was recorded.
tid = 5
ch = 4
delay_ms = 200
data = eegdata.xs(tid, level='id')[ch]
time = data.index.get_level_values('time')
t_onset = data.index.get_level_values('onset').unique()[0]

_tpred = t_onset - delay_ms/1000  # forecast start, delay_ms before stimulus onset
arlen_ms = 1000
predlen_ms = 1000
redge_ms = 1000
_tmin = _tpred-1
_tmax = _tpred+predlen_ms*1.5/1000  # not used in this cell

# Top panel: raw channel around the stimulus (black line = stimulus onset).
f, (ax, ax1, ax2) = plt.subplots(3, 1, figsize=(8, 4), tight_layout=True, sharex=True)
data.droplevel(list(range(1, 6))).loc[_tmin:_tpred+redge_ms/1000].plot(
    ax=ax, label=f'ch{ch}'
)
ax.axvline(t_onset, c='k')
ax.set_xlabel('time (s)')
ax.set_ylabel('voltage')

# The window ends redge_ms after _tpred, i.e. past the stimulus onset;
# arywpredict withholds those last redge_ms before fitting and forecasting.
# NOTE(review): the withheld part contains the stimulus response; if the
# band-pass filter is non-causal it can leak into the samples before _tpred.
_data_with_predictions, _data_filtered = arywpredict(
    data.droplevel(list(range(1, 6))).loc[_tpred-arlen_ms*5/1000:_tpred+redge_ms/1000],
    redge_ms=redge_ms, filt=None, arlen_ms=arlen_ms, prediction_len_ms=predlen_ms
)
# _hilbert_data is the Hilbert transform of the *prediction*
# (called _hilbert_pred in the earlier cells).
_hilbert_real = hilbert_tx(_data_filtered).loc[_tmin:]
_hilbert_data = hilbert_tx(_data_with_predictions).loc[_tmin:]
_data_filtered[_tmin:].plot(ax=ax1, label='_filtered data', alpha=0.6)
_data_with_predictions[_tmin:].plot(ax=ax1, label='_prediction', alpha=0.6)
# Restart the colour cycle so each amplitude envelope reuses its trace's colour.
ax1.set_prop_cycle(None)
_hilbert_real.amplitude[_tmin:].plot(ax=ax1, label='actual')
_hilbert_data.amplitude[_tmin:].plot(ax=ax1, label='predicted')
_hilbert_real.angle[_tmin:].plot(ax=ax2, label='actual')
_hilbert_data.angle[_tmin:].plot(ax=ax2, label='predicted')
# Thin black line = forecast start; default-width black line = stimulus onset.
ax1.axvline(_tpred, c='k', lw=0.5)
ax2.axvline(_tpred, c='k', lw=0.5)
ax1.axvline(t_onset, c='k')
ax2.axvline(t_onset, c='k')
ax1.legend(loc=2, fontsize=9)
ax2.legend(loc=2, fontsize=9)
ax2.set_xlabel('time (s)')
ax1.set_ylabel('voltage\n(band passed)')
ax2.set_ylabel('phase (rad)')
# NOTE(review): `band` is not defined in this part of the notebook, and
# arywpredict is called with filt=None. The title relies on a global left over
# from an earlier cell — confirm it matches the filter actually applied.
ax.set_title(f'ch {ch} stimulus {tid} ({stimuli.amplitude[tid]} mA) - {band} band');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

1. On test data, show error in amplitude vs time
1. Error in phase vs time
1. Is AR based prediction different from simple linear extrapolation?