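In this notebook, I apply UMAP to EEG data to see whether the different states of the mouse (before anesthesia, under high and low isoflurane, and during early and late recovery) are separated in the embedding without any further analysis.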

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm, patches
from scipy import signal
import umap

from tbd_eeg.data_analysis.eegutils import *
from tbd_eeg.data_analysis.Utilities import utilities as utils
import differentiation

%matplotlib widget

In [None]:
# Colormap used to colour each anesthesia epoch in the plots below.
epoch_cms = dict(
    pre=cm.Reds,
    iso_high=cm.PuOr,
    iso_low=cm.PuOr_r,
    early_recovery=cm.Blues,
    late_recovery=cm.Greens,
)

In [None]:
data_folder = "/allen/programs/braintv/workgroups/nc-ophys/Leslie/eeg_pilot/mouse496220/pilot2_2020.01.16/recording1/"

# set the sample_rate for all data analysis
sample_rate = 2500

# load experiment metadata and eeg data
exp = EEGexp(data_folder)
eegdata = exp.load_eegdata(frequency=sample_rate, return_type='pd')

# load other data (running, iso etc)
print('Loading other data...')
running_speed = exp.load_running(return_type='pd')
try:
    iso = exp.load_analog_iso(return_type='pd')
except Exception as err:
    # `except Exception` (not a bare `except:`) lets KeyboardInterrupt/SystemExit
    # propagate, and the reason for the failure is reported instead of hidden.
    print('Could not load isoflurane data:', repr(err))
    iso = None

# locate valid channels (some channels can be disconnected and we want to ignore them in the analysis)
print('Identifying valid channels...')
# Use .iloc to take the first 300 s worth of *samples*. eegdata appears to be
# indexed by time (float index, as the +/-0.001 epoch boundaries below assume);
# plain `[:n]` slicing on a float index is label-based and would select every
# sample with time <= sample_rate*300, i.e. effectively the whole recording.
median_amplitude = eegdata.iloc[:sample_rate*300].apply(
    utils.median_amplitude, raw=True, axis=0, distance=sample_rate
)
valid_channels = median_amplitude.index[median_amplitude < 2000].values
print('The following channels seem to be correctly connected and report valid data:')
print(list(valid_channels))

# annotate artifacts with power in high frequencies
print('Annotating artifacts...')
hf_annots = pd.Series(
    eegdata[valid_channels].apply(
        find_hf_annotations, axis=0,
        sample_rate=sample_rate, fmin=300, pmin=0.25
    ).mean(axis=1),
    name='artifact'
)

# automatically annotate anesthesia epochs -- this requires the isoflurane trace
if iso is None:
    raise RuntimeError('Isoflurane data could not be loaded; cannot annotate anesthesia epochs.')
iso_first_on = (iso>4).idxmax()
print('iso on at', iso_first_on)
iso_first_mid = ((iso[iso.index>iso_first_on]>1)&(iso[iso.index>iso_first_on]<4)).idxmax()
print('iso reduced at', iso_first_mid)
iso_first_off = (iso>1)[::-1].idxmax()
print('iso off at', iso_first_off)
# first time after iso off at which the high-frequency artifact score exceeds 4
recovery_first_jump = (hf_annots>4)[hf_annots.index>iso_first_off].idxmax()

# Every epoch is marked at both ends (+/-0.001 around each transition), so a
# nearest-neighbour reindex of any time onto `epochs` returns the epoch it lies in.
epochs = pd.Series(
    index = [0, iso_first_on-0.001, iso_first_on+0.001, iso_first_mid-0.001,
             iso_first_mid+0.001, iso_first_off-0.001, iso_first_off+0.001,
             recovery_first_jump-0.001, recovery_first_jump+0.001, eegdata.index[-1]],
    data=['pre', 'pre', 'iso_high', 'iso_high', 'iso_low', 'iso_low',
          'early_recovery', 'early_recovery', 'late_recovery', 'late_recovery'],
    dtype=pd.CategoricalDtype(
        categories=['pre', 'iso_high', 'iso_low', 'early_recovery', 'late_recovery'],
        ordered=True
    )
)
# Nearest-neighbour reindexing later needs a sorted index. This fails if the
# detected transitions come out of order, for example when no artifact jump
# follows iso off and recovery_first_jump lands right after iso_first_off.
if not epochs.index.is_monotonic_increasing:
    raise ValueError('epoch boundaries are not in increasing order: {0}'.format(list(epochs.index)))

# UMAP Analysis

## Generate valid windows
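Here, *valid* means windows that contain no high-frequency artifacts.
It turns out we do not need to precompute valid windows: we can cut the recording into fixed-length windows and reject, on the fly, any window that overlaps an artifact, using the boolean Series `invalid_times` (True wherever an artifact was annotated).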

In [None]:
# samples whose high-frequency artifact score exceeds this are treated as invalid
thresh = 4
invalid_times = hf_annots.gt(thresh)
invalid_times[invalid_times]

In [5]:
# this is just to quantify invalid windows, but is not used in later analysis
valid_windows = {}
for epoch_name, epoch_bounds in epochs.groupby(epochs, observed=True):
    # `epochs` holds the first and last time of every epoch
    t_start, t_end = epoch_bounds.index[0], epoch_bounds.index[-1]
    in_epoch = (hf_annots.index > t_start) & (hf_annots.index < t_end)
    invalid = get_windows(invalid_times & in_epoch, coalesce=4, min_length=0.2)
    print(epoch_name, '# invalid windows:', len(invalid))
    # Valid stretches are the gaps between consecutive invalid windows, plus the
    # stretch from the epoch start to the first invalid window and from the last
    # invalid window to the epoch end. An epoch without invalid windows is therefore
    # one single valid stretch; using only the gaps between invalid windows would
    # report no valid stretch at all for it.
    edges = [t_start]
    for win in invalid:
        edges.extend((win[0], win[1]))
    edges.append(t_end)
    valid_windows[epoch_name] = [
        (begin, end) for begin, end in zip(edges[::2], edges[1::2]) if end > begin
    ]

pre # invalid windows: 17
iso_high # invalid windows: 11
iso_low # invalid windows: 0
early_recovery # invalid windows: 0
late_recovery # invalid windows: 77


## Create state vectors

In [6]:
def mean_lfp(aligned_df, sample_rate, winsize, channels=None):
    """
    Mean absolute value of each channel over one window.

    Windows that overlap an annotated artifact, or hold fewer than
    ``winsize * sample_rate`` samples, yield all-NaN values so that they can be
    dropped with ``.dropna()`` after ``groupby('window').apply(...)``.
    (A possible refinement: first compute the signal envelope and return its
    mean amplitude.)

    Parameters
    ----------
    aligned_df : pd.DataFrame
        Samples of a single window: one column per channel plus a boolean-like
        ``artifact`` column.
    sample_rate : float
        Number of samples per second.
    winsize : float
        Expected window length in seconds.
    channels : list-like, optional
        Channel columns to summarise; defaults to the notebook-level
        ``valid_channels``.

    Returns
    -------
    pd.Series
        One value per channel, indexed by channel label.
    """
    if channels is None:
        channels = valid_channels
    has_artifact = aligned_df['artifact'].sum() > 0
    too_short = len(aligned_df) < winsize * sample_rate
    if has_artifact or too_short:
        # keep the per-channel index so rejected rows align with valid ones
        return aligned_df[channels].mean() * np.nan
    return aligned_df[channels].abs().mean()

def spectral_state(aligned_df, sample_rate, winsize):
    """
    Spectral state vector of one window, i.e. the last element returned by
    ``differentiation.spectral_states`` for that window's valid channels.

    A window that overlaps an artifact, or holds fewer than
    ``winsize * sample_rate`` samples, yields an all-NaN Series of the same
    length. The length is obtained by running the same computation on a block of
    zeros, so rejected windows can be dropped with ``.dropna()``.
    """
    n_samples = int(winsize * sample_rate)
    rejected = ((aligned_df.artifact).sum() > 0) | (len(aligned_df) < winsize * sample_rate)
    if rejected:
        # zero placeholder shaped like a valid window; only used to learn the
        # length of the output vector
        signal_block = pd.DataFrame(data=np.zeros((int(sample_rate*winsize), len(valid_channels)))).values
    else:
        signal_block = aligned_df[valid_channels].values
    state = pd.Series(differentiation.spectral_states(
        sample_rate=sample_rate,
        window_length=winsize,
        data=signal_block[:n_samples].T,
    )[-1])
    return state * np.nan if rejected else state

def spectral_differentiation(aligned_df, sample_rate, winsize, state_length):
    if ((aligned_df.artifact).sum() > 0) | (len(aligned_df) < winsize * sample_rate):
        return pd.Series([np.nan]*int((winsize/state_length)*(winsize/state_length-1)/2))
    return pd.Series(
        differentiation.spectral_differentiation(
            aligned_df[valid_channels].values[:int(winsize*sample_rate)].T,
            sample_rate=sample_rate, window_length=state_length
        )
    )

In [8]:
# window length in seconds for the state vectors
WINSIZE_s = 9
# function that turns one window of samples into a state vector
func = spectral_state
# Label every EEG sample with the integer id of the WINSIZE_s-long window it
# falls in (assumes eegdata.index is time in seconds).
windows = pd.Series(
    (eegdata.index / WINSIZE_s).astype(int),
    index=eegdata.index,
    name='window',
)
aligned_windows = pd.concat(
    [
        eegdata[valid_channels],
        windows,
        invalid_times.reindex(windows.index, method='nearest'),  # artifact flag per sample
    ],
    axis=1,
)

In [11]:
# one state vector per window; rejected (all-NaN) windows are dropped
states = (
    aligned_windows
    .groupby('window')
    .apply(func, sample_rate=sample_rate, winsize=WINSIZE_s)
    .dropna()
)
# mean timestamp of every retained window
state_times = windows.groupby(windows).apply(lambda win: np.mean(win.index)).loc[states.index]
states.shape

(245, 326279)

## UMAP on temporal states

In [12]:
# running speed resampled onto the EEG samples, grouped into the same windows
aligned_running_windows = pd.concat(
    [
        running_speed.reindex(windows.index, method='nearest'),
        windows,
        invalid_times.reindex(windows.index, method='nearest'),
    ],
    axis=1,
)
mean_speed_by_win = aligned_running_windows.groupby('window').apply(
    lambda win: win.running_speed.mean()
)
v_min, v_max = mean_speed_by_win.min(), mean_speed_by_win.max()

In [13]:
# Median spectral differentiation per window, with 30 sub-states per window.
# Rejected windows are all-NaN and are removed by dropna(); the result is a
# Series indexed by window id.
spec_df = (
    aligned_windows
    .groupby('window')
    .apply(spectral_differentiation, sample_rate=sample_rate,
           winsize=WINSIZE_s, state_length=WINSIZE_s / 30)
    .dropna()
    .median(axis=1)
)
d_min, d_max = spec_df.min(), spec_df.max()

In [14]:
# Embed every window's spectral state vector in 2-D.
# fit_transform returns the embedding UMAP optimises during fitting
# (reducer.embedding_). Calling transform() on the same training data instead
# re-projects the points, which doubles the cost and only approximates it.
# random_state makes the embedding reproducible (umap-learn then disables
# parallelism).
# NOTE(review): states must not contain the 'epoch' column added in the
# "UMAP on channels" section; re-run the states cell first if it does.
reducer = umap.UMAP(
    n_neighbors=10,
    min_dist=0.01,
    random_state=42,
)
embedding = reducer.fit_transform(states.values)

In [17]:
# Per-window quantity used to colour the embedding:
# 'differentiation', 'running' or 'epoch'.
label = 'differentiation'

f, ax = plt.subplots(1, 1, figsize=(6, 4), tight_layout=True)
if label == 'running':
    # colour by the window's mean running speed; Normalize maps to [0, 1] and
    # does not divide by zero when all speeds are equal
    speed_norm = mpl.colors.Normalize(vmin=v_min, vmax=v_max)
    c = mean_speed_by_win.loc[states.index].map(lambda x: cm.Accent(speed_norm(x), 0.4))
    # ax= is required: the ScalarMappable belongs to no Axes, so matplotlib
    # cannot otherwise tell which Axes to take space from for the colorbar
    f.colorbar(cm.ScalarMappable(norm=speed_norm, cmap=cm.Accent), ax=ax, label='Velocity')
elif label == 'differentiation':
    # colour by the window's median spectral differentiation
    diff_norm = mpl.colors.Normalize(vmin=d_min, vmax=d_max)
    c = spec_df.loc[states.index].map(lambda x: cm.Reds(diff_norm(x), 0.8))
    f.colorbar(cm.ScalarMappable(norm=diff_norm, cmap=cm.Reds), ax=ax, label='Differentiation')
elif label == 'epoch':
    # colour by the anesthesia epoch each window falls in
    c = [epoch_cms[ep](0.7, 0.8) for ep in epochs.reindex(state_times, method='nearest')]
    label_patches = [patches.Patch(color=epoch_cms[ep](0.7, 0.8), label=ep) for ep in epochs.dtype.categories]
    ax.legend(handles=label_patches, loc=(1.02, 0))
else:
    raise ValueError("label must be 'running', 'differentiation' or 'epoch', got {0!r}".format(label))
ax.set_xlabel('umap 1')
ax.set_ylabel('umap 2')
ax.set_title('{0:.2f} s window, {1:s}'.format(WINSIZE_s, func.__name__), fontsize=10)

sc = ax.scatter(embedding[:, 0], embedding[:, 1], c=c);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## UMAP on channels

In [None]:
# Tag each state vector with its epoch. Nearest-neighbour lookup is safe because
# `epochs` marks both ends of every epoch. NOTE: this mutates `states` in place,
# so cells that use states.values must be re-run from the states cell first.
states['epoch'] = (
    epochs
    .reindex(state_times, method='nearest')
    .values
)

In [17]:
def plot_umap(df, axes, show_legend=False):
    """
    Embed the channels of one epoch with UMAP and scatter them on that epoch's axis.

    Intended use: ``states.groupby('epoch').apply(plot_umap, axes=axes)``.

    Parameters
    ----------
    df : pd.DataFrame
        Rows of ``states`` that belong to one epoch: one row per window, with
        spectral-state features in every column except ``'epoch'``.
    axes : sequence of matplotlib Axes
        One Axes per epoch category, in the order of ``epochs.cat.categories``.
    show_legend : bool, default False
        Draw a legend mapping colours to channel positions.

    Returns
    -------
    tuple
        ``(reducer, ax)``: the fitted UMAP reducer and the Axes drawn on.

    Raises
    ------
    ValueError
        If the feature count is not a multiple of the number of valid channels.
    """
    e = df.epoch.iloc[0]
    ax = axes[list(epochs.cat.categories).index(e)]
    n_channels = len(valid_channels)

    # The columns of `states` are positional spectral-state features, not channel
    # labels, so channels cannot be selected by label. Instead, split each window's
    # feature vector into one equal block per channel.
    # NOTE(review): assumes differentiation.spectral_states flattens its output
    # channel-major, as (n_channels, n_features_per_channel) -- TODO confirm.
    features = df.drop(columns='epoch').values
    n_windows, n_features = features.shape
    if n_features % n_channels:
        raise ValueError(
            'cannot split {0} state features evenly across {1} channels'.format(n_features, n_channels)
        )
    # -> (n_channels, n_windows * n_features_per_channel): one sample per channel
    channel_vectors = (
        features.reshape(n_windows, n_channels, -1)
        .transpose(1, 0, 2)
        .reshape(n_channels, -1)
    )

    reducer = umap.UMAP(
        n_neighbors=6,
        min_dist=0.1,
    )
    # fit_transform returns the embedding found during fitting; calling
    # transform() again on the training data would only re-approximate it
    embedding = reducer.fit_transform(channel_vectors)

    colors = [cm.Accent(i / n_channels, 0.9) for i in range(n_channels)]
    ax.scatter(embedding[:, 0], embedding[:, 1], c=colors)
    if show_legend:
        label_patches = [patches.Patch(color=col, label=i) for i, col in enumerate(colors)]
        ax.legend(handles=label_patches, loc=(1.02, 0))
    ax.set_xlabel('umap 1')
    ax.set_ylabel('umap 2')
    ax.set_title('{0:.2f} s window, {1:s}\n{2:s}'.format(WINSIZE_s, func.__name__, e), fontsize=10)
    return (reducer, ax)

In [31]:
n_epoch_panels = len(epochs.cat.categories)
f, axes = plt.subplots(1, n_epoch_panels, figsize=(3 * n_epoch_panels, 3), tight_layout=True)
# one UMAP panel per epoch; `ret` holds the (reducer, ax) pair returned for each epoch
ret = states.groupby('epoch').apply(plot_umap, axes=axes)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
# Ask the garbage collector to reclaim unreachable objects (the call returns how
# many it found). Objects that are still referenced, such as the large `states`
# frame, are not freed by this; `del` them first if memory is needed.
import gc
gc.collect()

31172