In this notebook, we will import and apply some preprocessing to the EEG data, and store it for quick access in the future.  
Preprocessing would involve:
1. Low-pass filtering to remove noise above 1000 Hz
1. Downsampling
1. Annotating time intervals that show motion artefacts

Where should this go in the pipeline? Should the preprocessing be part of the lfp class or EEGexp class?

Downsample then filter, or filter then downsample?  
Downsampling first would alias high-frequency noise into the retained frequency band. So we filter first, then downsample.

In [1]:
%load_ext autoreload
%autoreload 2

import os
from glob import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from scipy import signal

from tbd_eeg.data_analysis.eegutils import EEGexp
from tbd_eeg.data_analysis.Utilities import filters

%matplotlib widget

In [2]:
# Set to True to draw a quick-look figure of channel `ch` after each
# preprocessing step (raw, low-pass filtered, downsampled).
plot_on_the_go = False
# Channel (column of the EEG DataFrame) shown in the quick-look figures.
# NOTE: this is reassigned to 5 in the spectrogram cell further down.
ch = 3

# Raw data

In [3]:
# Recording to process. The path below is the default; set the EEG_DATA_FOLDER
# environment variable to run the notebook on another recording without
# editing this cell.
data_folder = os.environ.get(
    "EEG_DATA_FOLDER",
    r"/allen/programs/braintv/workgroups/nc-ophys/Leslie/eeg_pilot/mouse505550/pilot1_2020-03-02_10-08-51/recording1/",
)
exp = EEGexp(data_folder)

# Raw samples (presumably memory-mapped, judging by the helper's name) wrapped
# in a DataFrame indexed by the EEG timestamps; each column is one channel.
eegdata = exp.memmap_EEGdata()
timestamps = np.load(exp.eegtimestamps_file)
eegdata = pd.DataFrame(data=eegdata, index=timestamps)

if plot_on_the_go:
    f, ax = plt.subplots(1, 1, figsize=(12, 2))
    eegdata[ch].plot(ax=ax)

Loading /allen/programs/braintv/workgroups/nc-ophys/Leslie/eeg_pilot/mouse505550/pilot1_2020-03-02_10-08-51/recording1/continuous/Rhythm_FPGA-111.0/continuous.dat


In [4]:
# TODO: unfinished experiment, kept for reference only (nothing in this cell runs).
# Idea: reshape the raw .dat file through a memory map instead of loading it.
# NOTE(review): np.memmap defaults to dtype=uint8, so the snippet below most
# likely needs an explicit dtype matching the recording format -- confirm.
# # get back to working on reshaping data without loading into memory
# datafiles = sorted(glob(exp.data_folder + '/**/*.dat', recursive=True))[0]
# data = np.memmap(datafiles)
# data.reshape(int(data.size/exp.num_chs), exp.num_chs)[:, exp.intanNNmap]

# Low-pass filter

In [5]:
def _lowpass_channel(samples):
    """Low-pass filter one channel, passed by pandas as a 1-D array."""
    # butter_filter is given a 2-D (samples x 1) array; the single filtered
    # column is pulled back out afterwards.
    filtered = filters.butter_filter(
        np.expand_dims(samples, 1),
        sampling_frequency=exp.sample_rate,
        cutoff_frequency=1000,
        filter_order=2,
        ftype='low'
    )
    return filtered[:, 0]


# NOTE(review): the 1000 Hz cutoff equals the Nyquist frequency of the 2000 Hz
# rate used after downsampling. Assuming butter_filter is a standard
# Butterworth design, a 2nd-order filter attenuates little just above the
# cutoff, so some aliasing may remain -- confirm whether a lower cutoff or a
# higher order is wanted.
eegdata_lp = eegdata.apply(_lowpass_channel, raw=True, axis=0)

if plot_on_the_go:
    f, ax = plt.subplots(1, 1, figsize=(12, 2))
    eegdata_lp[ch].plot(ax=ax)

# Downsample to 2000Hz

In [6]:
# Keep every `ds_step`-th sample so the result is sampled at (about) 2000 Hz.
# The step is clamped to 1 so a recording sampled below 2000 Hz is passed
# through unchanged instead of failing with "slice step cannot be zero".
ds_step = max(1, int(exp.sample_rate / 2000))
if exp.sample_rate / ds_step != 2000:
    # Later cells hard-code 2000 Hz, so make a mismatch visible.
    print('WARNING: effective rate after downsampling is',
          exp.sample_rate / ds_step, 'Hz, not 2000 Hz')
eegdata_lp_ds = eegdata_lp[::ds_step]

# Drop the full-rate data to free memory. This makes the cell non-repeatable:
# re-run the loading and filtering cells before running it again.
del eegdata_lp
del eegdata # comment out to keep the raw full-rate data available

if plot_on_the_go:
    f, ax = plt.subplots(1, 1, figsize=(12, 2))
    eegdata_lp_ds[ch].plot(ax=ax)

# Annotate artefacts due to motion

## Identify unconnected channels
They are 30 and 31, plus others in some cases

In [7]:
def median_amplitude(data, sample_rate=2000, duration=300, min_peak_distance=0.01):
    """Median height of the rectified-signal peaks at the start of a channel.

    Parameters
    ----------
    data : pd.Series
        Samples of one channel.
    sample_rate : float
        Sampling rate of `data` in Hz (default 2000, the downsampled rate).
    duration : float
        Length in seconds of the leading segment that is analysed (default 300).
    min_peak_distance : float
        Minimum separation between detected peaks in seconds (default 0.01).

    Returns
    -------
    float
        Median of |data| at the detected peaks (NaN if no peak is found).
    """
    n_samples = int(sample_rate * duration)
    # Rectify once and reuse it for both peak detection and the median.
    amplitude = data.abs().values[:n_samples]
    peaks, _ = signal.find_peaks(amplitude, distance=sample_rate * min_peak_distance)
    return np.median(amplitude[peaks])

In [8]:
# Per-channel median peak amplitude. The result gets its own name so that it
# does not overwrite the median_amplitude function; otherwise this cell
# fails when it is run a second time.
median_amplitudes = eegdata_lp_ds.apply(median_amplitude, axis=0)
# Channels with a very large median amplitude are treated as unconnected.
active_channels = median_amplitudes.index[median_amplitudes < 1500]
median_amplitudes

0       719.624825
1       480.487103
2       437.554127
3       703.361702
4       632.450952
5       574.245522
6       755.077336
7       689.285770
8       758.496281
9       904.261988
10      949.545130
11      941.256338
12      903.608083
13      896.314522
14      869.803090
15      873.733912
16      879.680995
17      895.212696
18      911.931058
19      877.456712
20      876.732936
21      843.788109
22      867.560646
23      852.173806
24      655.483230
25      789.975831
26      859.701420
27      570.533649
28      643.806023
29      911.620407
30    32684.204003
31    29694.638679
dtype: float64

## Can we use running data to identify the artefacts?
Not explored much. Looks unlikely.

In [9]:
sync_data = exp.load_sync_dataset()
linear_velocity, runtime = exp.load_running(sync_data)
speed_raw = pd.DataFrame(index=runtime, data=linear_velocity, columns=['speed'])
# running_speed is rescaled (zero mean, standard deviation of 200) so as to
# look good on the plots next to the EEG traces
speed_centered = speed_raw - speed_raw.mean()
running_speed = speed_centered*200/speed_raw.std()

  return eval(self.dfile['analog_meta'].value)


In [10]:
# look at 580:582 as an example of artefact
if plot_on_the_go:
    f, ax = plt.subplots(1, 1, figsize=(12, 2), sharex=True, tight_layout=True)
    running_speed.plot(ax=ax, c=cm.Greys(0.5, 0.5), label='Running speed')
    # eegdata_ds[30].plot(ax=ax2, c=cm.Blues(0.6, 0.4))
    eegdata_lp_ds[2].plot(ax=ax, c=cm.Reds(0.6, 0.4), label='2')
    eegdata_lp_ds[8].plot(ax=ax, c=cm.Greens(0.6, 0.4), label='8')
    ax.legend(loc=1, ncol=2);

## Identifying artefacts using the spectrogram

In [11]:
def normalize_spec(Sxx):
    """Min-max normalise every row of a spectrogram to the range [0, 1].

    Parameters
    ----------
    Sxx : array-like, shape (n_rows, n_columns)
        Spectrogram; each row is scaled independently using its own minimum
        and maximum along axis 1.

    Returns
    -------
    np.ndarray
        Array of the same shape as `Sxx`. A row whose values are all equal
        maps to zeros (rather than NaN from a 0/0 division).
    """
    Sxx = np.asarray(Sxx)
    row_min = Sxx.min(axis=1, keepdims=True)
    row_range = Sxx.max(axis=1, keepdims=True) - row_min
    # Avoid dividing by zero for constant rows: their numerator is already 0.
    safe_range = np.where(row_range > 0, row_range, 1.0)
    return (Sxx - row_min) / safe_range

In [12]:
# NOTE: this overrides the channel chosen in the configuration cell; the
# overlay plot further down also uses this value.
ch = 5
trace = eegdata_lp_ds[ch]

# Spectrogram of the downsampled (2000 Hz) trace, normalised per frequency row.
freqs, seg_times, power = signal.spectrogram(trace, 2000, nperseg=512)
power_norm = normalize_spec(power)
shown_rows = freqs < 5000
hf_rows = freqs > 200
# spectrogram() returns times relative to the first sample; shift them onto
# the index of the trace.
seg_times = seg_times + trace.index[0]
# Per time bin: number of rows above 200 Hz whose normalised power exceeds 0.2.
hf_count = (power_norm[hf_rows, :] > 0.2).sum(axis=0)

fig, (ax, ax2) = plt.subplots(2, 1, figsize=(12, 4), sharex=True)
# top panel: spectrogram with the high-frequency count overlaid in black
spgm = ax.pcolormesh(seg_times, freqs[shown_rows], power_norm[shown_rows, :], cmap='YlOrRd', vmin=0, vmax=1)
ax.set_ylabel('Frequency (Hz)')
ax.set_xlabel('Time (s)')
ax.twinx().plot(seg_times, hf_count, c='k')
# bottom panel: EEG trace and (rescaled) running speed
trace.plot(ax=ax2, c=cm.Reds(0.6, 0.4))
running_speed.plot(ax=ax2, c=cm.Greys(0.5, 0.5))
ax2.set_ylim(-5000, 5000);
# ax2.set_xlim(560, 620);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
# The analog isoflurane signal is optional; presumably not every recording
# has one. `iso` is set to None when loading fails, so later cells see an
# explicit "not available" value rather than an undefined or stale variable.
try:
    iso_level, isotime = exp.load_analog_iso(sync_data)
    iso = pd.DataFrame(index=isotime, data=iso_level)
except Exception as err:  # still best-effort, but no longer silent and no longer traps KeyboardInterrupt
    iso = None
    print('Isoflurane trace not loaded:', repr(err))

  return eval(self.dfile['analog_meta'].value)


## Filter spectrogram of all channels to find artefacts

In [61]:
def find_hf_annotations(data_ch, sample_rate=2000, nperseg=512, fmin=200, power=0.2):
    """Count, per spectrogram time bin, the strong high-frequency rows of a channel.

    Parameters
    ----------
    data_ch : pd.Series
        Samples of one channel; the first index value is used as time offset.
    sample_rate : float
        Sampling rate passed to `signal.spectrogram`.
    nperseg : int
        Segment length passed to `signal.spectrogram`.
    fmin : float
        Only frequency rows strictly above this value are counted.
    power : float
        Threshold applied to the row-normalised spectrogram.

    Returns
    -------
    pd.Series
        Indexed by segment time (shifted by the first index value of
        `data_ch`); each value is the number of rows above `fmin` whose
        normalised power exceeds `power`.
    """
    freqs, seg_times, spec = signal.spectrogram(data_ch, sample_rate, nperseg=nperseg)
    above_power = normalize_spec(spec)[freqs > fmin, :] > power
    counts = above_power.sum(axis=0)
    return pd.Series(data=counts, index=seg_times + data_ch.index[0])

In [77]:
def get_windows(annots, coalesce=0, min_length=0):
    '''
    Turn a series of binary annotations into a list of (start, end) windows.

    annots: series with binary annotations (any truthy value counts as set);
        the index must be sorted in increasing order
    coalesce: merge a window into the previous one when its first flagged
        sample lies no more than this far after the end of the previous window
    min_length: only windows strictly longer than this are returned

    Each window starts at the index label just before the first flagged sample
    and ends at the first unflagged label after the run. At the edges of the
    series the window is clipped to the first / last label, so a run that
    begins on the first sample or is still open at the end is kept.

    returns: list of (start, end) tuples of index labels
    '''
    times = annots.index
    flags = annots.values.astype(bool)
    n = len(times)

    wins = []
    prev_end = None
    i = 0
    while i < n:
        if not flags[i]:
            i += 1
            continue
        # advance j to the first unflagged sample after the run starting at i
        j = i + 1
        while j < n and flags[j]:
            j += 1
        start = times[i - 1] if i > 0 else times[i]  # no wrap-around at i == 0
        end = times[j] if j < n else times[-1]       # close a run open at the end
        # merging is only possible when there is a previous window
        if wins and times[i] <= prev_end + coalesce:
            wins[-1] = (wins[-1][0], end)
        else:
            wins.append((start, end))
        prev_end = end
        i = j
    return [w for w in wins if w[1] - w[0] > min_length]

In [16]:
# test the function on sample data
sample_ts = pd.Series(index=range(20), data=[0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0.2, 0, 0, 0])
get_windows(sample_ts, coalesce=0)

[(3, 5), (6, 8), (8, 10), (13, 17)]

In [74]:
# High-frequency counts for every active channel (one column per channel) ...
hf_counts_per_channel = eegdata_lp_ds[active_channels].apply(
    find_hf_annotations, axis=0, fmin=300, power=0.25
)
# ... averaged across channels to give one value per spectrogram time bin.
hf_annots = hf_counts_per_channel.mean(axis=1)

In [72]:
# Cumulative distribution of the channel-averaged high-frequency count;
# used to choose the artefact threshold applied in the following cells.
f, ax = plt.subplots(1, 1, figsize=(4, 3), tight_layout=True)
ax.hist(hf_annots, bins=100, cumulative=True, histtype='step')
ax.set_xlabel('Mean # high-frequency bins above threshold')
ax.set_ylabel('Cumulative # time bins');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [85]:
f, ax = plt.subplots(1, 1, figsize=(12, 3))
# hf_annots.plot(ax=ax.twinx(), c='k')
# hf_annots > X sets a threshold for detecting artifacts
for win in get_windows(hf_annots>4, coalesce=4, min_length=0.2):
    ax.axvspan(win[0], win[1], color=cm.Reds(0.3, 0.3))
eegdata_lp_ds[ch].plot(ax=ax, c=cm.Blues(0.6, 0.3))
running_speed.plot(ax=ax, c=cm.Greys(0.6, 0.3), legend=False)
# The isoflurane trace is optional. Check for it explicitly (undefined or None
# both mean "not available") instead of wrapping the plot call in a bare
# `except: pass`, which would also hide genuine plotting errors.
iso_trace = globals().get('iso')
if iso_trace is not None:
    iso_trace.plot(ax=ax.twinx(), c=cm.Reds(0.8, 0.8), lw=0.5)
ax.set_ylim(-5000, 5000)
ax.set_title('mouse {:s}'.format(exp.mouse), fontsize=8);

  """Entry point for launching an IPython kernel.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Artifact statistics

In [86]:
# automatically annotate anesthesia epochs
iso_first_on = (iso>4).idxmax()[0]
iso_first_mid = ((iso[iso.index>iso_first_on]>1)&(iso[iso.index>iso_first_on]<4)).idxmax()[0]
iso_first_off = (iso>1)[::-1].idxmax()[0]
recovery_first_jump = (hf_annots>5)[hf_annots.index>iso_first_off].idxmax()

epochs = {
    'pre' : (0, iso_first_on),
    'iso_high' : (iso_first_on, iso_first_mid),
    'iso_low' : (iso_first_mid, iso_first_off),
    'early_recovery' : (iso_first_off, recovery_first_jump),
    'late_recovery' : (recovery_first_jump, eegdata_lp_ds.index[-1]),
}

In [88]:
thresh = 5
window_sizes = np.arange(1, 100)
f, ax = plt.subplots(1, 1, figsize=(4, 3), tight_layout=True)
valid_windows = {}
for epoch, epoch_win in epochs.items():
    in_epoch = (hf_annots.index > epoch_win[0]) & (hf_annots.index < epoch_win[1])
    invalid = get_windows((hf_annots > thresh) & in_epoch, coalesce=4, min_length=0.2)
    print(epoch, '# invalid windows:', len(invalid))
    # Valid time is the complement of the invalid windows inside the epoch.
    # This includes the stretch before the first and after the last invalid
    # window, and treats an epoch with 0 or 1 invalid windows consistently.
    edges = [epoch_win[0]] + [t for win in invalid for t in win] + [epoch_win[1]]
    valid = [(lo, hi) for lo, hi in zip(edges[::2], edges[1::2]) if hi > lo]
    valid_windows[epoch] = valid
    valid_lengths = np.array([hi - lo for lo, hi in valid])
    # For each window size x: total length of the valid stretches longer than
    # x, expressed in units of x.
    ax.plot(window_sizes, [(valid_lengths[valid_lengths > x] / x).sum() for x in window_sizes], label=epoch)
ax.axhline(1, c='k', lw=0.5)
ax.axhline(10, c='k', lw=0.5)
ax.axhline(50, c='k', lw=0.5)
ax.set_xlabel('Window size (s)')
ax.set_ylabel('# Valid windows')
ax.legend(loc=1, fontsize=8)
ax.set_title(exp.mouse, fontsize=10)
# `nonposy` was renamed to `nonpositive` in matplotlib 3.3 and later removed;
# fall back to the old keyword on older installations.
try:
    ax.set_yscale('log', nonpositive='mask')
except (TypeError, ValueError):
    ax.set_yscale('log', nonposy='mask');

  


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

pre # invalid windows: 13
iso_high # invalid windows: 12
iso_low # invalid windows: 0
early_recovery # invalid windows: 0
late_recovery # invalid windows: 87
