# Initialization

In [9]:
import sys
sys.path.append("../src")
import os
import numpy as np
import pickle
import json
import pandas as pd
from grabbit import Layout
from mne import read_epochs, grand_average, write_evokeds, read_evokeds
from mne import pick_types, combine_evoked, set_log_level, grand_average
from mne.time_frequency import tfr_morlet, read_tfrs, write_tfrs
from mne.viz import plot_compare_evokeds
from utils import CH_NAMES, select_subjects, drop_bad_trials
from eeg_sensor_analysis import baseline_normalize, power_heatmap, add_events
import matplotlib.pyplot as plt
import seaborn as sns
from ipywidgets import interact, fixed

sns.set(style='whitegrid', font_scale=2)
colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3']
set_log_level('critical')

# load subjects to process
layout = Layout('../data', '../data/grabbit_config.json')
subjects = select_subjects(layout, 'eeg', exclude='eeg')

# Analysis Parameters
# condition order defines the order of per-condition evoked/TFR objects
# written below and the sign of the "i-c" differences plotted later
conditions = ['incongruent', 'congruent']
# crop windows (seconds), paired element-wise with epoch_types
epoch_types = ['stimulus', 'response']
epoch_times = [(-.5, 1.75), (-1, 1)]
# pre-stimulus baseline window (seconds)
baseline = (-.5, -.1)
# 30 log-spaced frequencies from 2 to 60 Hz with log-spaced wavelet cycles
# from 3 to 10; match Cohen, Donner 2013
frequencies = np.logspace(np.log10(2), np.log10(60), num=30)
n_cycles = np.logspace(np.log10(3), np.log10(10), num=30)

# load behavior ('n/a' cells become NaN) and keep only EEG-session rows
behavior = pd.read_csv('../data/derivatives/behavior/group_data.tsv',
                       na_values='n/a', sep='\t')
behavior = behavior[behavior.modality == 'eeg']

# make eeg_sensor derivative directory structure:
#   <pipeline_root>/<subject>/{evoked,tfr} for every subject plus 'group'.
# makedirs creates missing parents, and exist_ok keeps this idempotent when
# the cell is re-run.
pipeline_root = '../data/derivatives/eeg_sensor'
for subject in subjects + ['group']:
    for sub_dir in ('evoked', 'tfr'):
        os.makedirs(os.path.join(pipeline_root, subject, sub_dir),
                    exist_ok=True)

# ERPs

## Make Evoked Data

In [7]:
# Build per-subject evoked responses (and their trial standard errors) for
# each condition and epoch type, then the grand average and an
# across-subject standard error, writing everything to <pipeline_root>.
for epo_type, epo_times in zip(epoch_types, epoch_times):
    print(epo_type)

    # subject-level evoked responses, keyed by condition
    group = {c: [] for c in conditions}

    for subject in subjects:
        if subject == 'group':
            continue

        print(subject)

        # load subject epochs & behavior
        epo_file = layout.get(subject=subject,
                              derivative='eeg_preprocessing',
                              extensions='%s_cleaned-epo.fif' % epo_type)[0]
        epochs = read_epochs(epo_file.filename, verbose=False)
        sub_behavior = behavior[behavior.participant_id == subject]

        # crop filter period
        epochs.crop(epo_times[0], epo_times[1])

        # drop bad trials from epochs and behavior
        sub_behavior, epochs = drop_bad_trials(subject, sub_behavior,
                                               epochs, layout, epo_type)

        # add event labels
        epochs = add_events(epochs, sub_behavior)

        # interpolate bad channels (reset_bads=True clears info['bads'])
        epochs.interpolate_bads(reset_bads=True)

        # extract evoked and standard error, one per condition
        evos = [epochs[c].average() for c in conditions]
        evos_std = [epochs[c].standard_error() for c in conditions]

        # save evoked and standard error
        f = '%s/%s/evoked/%s_%s-ave.fif' % (pipeline_root, subject,
                                            subject, epo_type)
        write_evokeds(f, evos)
        f = '%s/%s/evoked/%s_%s_stderr-ave.fif' % (pipeline_root, subject,
                                                   subject, epo_type)
        write_evokeds(f, evos_std)

        # accumulate subject data for the group average
        for c, evo in zip(conditions, evos):
            group[c].append(evo)

    # group average and across-subject standard error
    evos = []
    evos_std = []
    for c in conditions:
        grand = grand_average(group[c])
        evos.append(grand)

        # stack subject data along a new leading axis and take the SE
        # across subjects. NOTE(review): np.std defaults to ddof=0
        # (population std); confirm whether ddof=1 is intended.
        subject_data = np.array([e.data for e in group[c]])
        std_err = grand.copy()
        std_err.data = (np.std(subject_data, axis=0) /
                        np.sqrt(subject_data.shape[0]))
        evos_std.append(std_err)

    # save group evoked and standard error
    f = '%s/group/evoked/group_%s-ave.fif' % (pipeline_root, epo_type)
    write_evokeds(f, evos)
    f = '%s/group/evoked/group_%s_stderr-ave.fif' % (pipeline_root,
                                                     epo_type)
    write_evokeds(f, evos_std)


print('Done!')

stimulus
sub-hc001
sub-hc002
sub-hc003
sub-hc004
sub-hc005
sub-hc006
sub-hc007
sub-hc008
sub-hc009
sub-hc010
sub-hc011
sub-hc012
sub-hc014
sub-hc015
sub-hc016
sub-hc017
sub-hc019
sub-hc020
sub-hc021
sub-hc022
sub-hc023
sub-hc024
sub-hc025
sub-hc026
sub-hc028
sub-hc029
sub-hc030
sub-hc031
sub-hc032
sub-hc033
sub-hc034
sub-hc035
sub-hc036
sub-hc037
sub-hc042
sub-hc044
sub-hc045
sub-pp001
sub-pp002
sub-pp003
sub-pp004
sub-pp005
sub-pp006
sub-pp007
sub-pp008
sub-pp009
sub-pp010
sub-pp011
sub-pp012
sub-pp013
sub-pp014
sub-pp015
sub-pp016
response
sub-hc001
sub-hc002
sub-hc003
sub-hc004
sub-hc005
sub-hc006
sub-hc007
sub-hc008
sub-hc009
sub-hc010
sub-hc011
sub-hc012
sub-hc014
sub-hc015
sub-hc016
sub-hc017
sub-hc019
sub-hc020
sub-hc021
sub-hc022
sub-hc023
sub-hc024
sub-hc025
sub-hc026
sub-hc028
sub-hc029
sub-hc030
sub-hc031
sub-hc032
sub-hc033
sub-hc034
sub-hc035
sub-hc036
sub-hc037
sub-hc042
sub-hc044
sub-hc045
sub-pp001
sub-pp002
sub-pp003
sub-pp004
sub-pp005
sub-pp006
sub-pp007
sub-pp008
su

## Visualize Evoked Responses 

### Plot ERP Waveforms

In [12]:
def plot_erps(subject, ch, behavior):
    """Plot per-condition ERP waveforms for one channel.

    Draws a stimulus-locked and a response-locked panel (one per entry of
    ``epoch_types``). Each panel shows each condition's mean waveform with
    a +/- standard-error band. The stimulus-locked panel also overlays
    response-time histograms on its lower edge.

    Parameters
    ----------
    subject : str
        Subject label, or ``'group'`` for the grand average (behavior is
        then pooled over all participants).
    ch : str
        Channel name to plot.
    behavior : pandas.DataFrame
        Trial-level behavior with the exclusion flag columns,
        ``participant_id``, ``trial_type`` and ``response_time``.
    """
    plt.close('all')
    _, axs = plt.subplots(1, 2, figsize=(24, 6), sharey=True)

    # keep trials with no exclusion flag set. Use a boolean mask rather than
    # np.where + .loc: np.where yields positions but .loc selects by label,
    # and the frame's index need not be 0..n-1 (e.g. after the modality
    # filter), which would select the wrong rows.
    exclusions = ['fast_rt', 'no_response', 'error', 'post_error']
    behavior = behavior[behavior[exclusions].sum(axis=1) == 0]
    if subject != 'group':
        behavior = behavior[behavior.participant_id == subject]

    for ax, epo_type in zip(axs, epoch_types):

        # load evoked with standard error
        evo_file = layout.get(subject=subject,
                              derivative='eeg_sensor',
                              extensions='%s-ave.fif' % epo_type)[0]
        evos = read_evokeds(evo_file.filename, verbose=False)
        evo_file = layout.get(subject=subject,
                              derivative='eeg_sensor',
                              extensions='%s_stderr-ave.fif' % epo_type)[0]
        evos_std = read_evokeds(evo_file.filename, verbose=False)

        for j, c in enumerate(conditions):
            evo = evos[j]
            evo_std = evos_std[j]

            # select out chosen channel
            evo.pick_channels([ch])
            evo_std.pick_channels([ch])

            # extract the data and standard error (V -> uV)
            times = evo.times
            data = evo.data.squeeze() * 1e6
            std_err = evo_std.data.squeeze() * 1e6

            # plot waveform with standard error shading; the explicit label
            # ties the legend entry to the line regardless of artist order
            ax.plot(times, data, color=colors[j], label=c)
            ax.fill_between(times, data - std_err, data + std_err,
                            alpha=0.5, color=colors[j])

        if epo_type == 'stimulus':
            # histogram RTs along the bottom of the stimulus-locked plot;
            # read the baseline once so every histogram shares it
            bottom = ax.get_ylim()[0]
            for j, c in enumerate(conditions):
                rt = behavior[behavior.trial_type == c].response_time
                ax.hist(rt.dropna(), color=colors[j], alpha=0.2,
                        density=True, bottom=bottom)

        # set time axis ticks
        if epo_type == 'stimulus':
            ax.set_xticks(np.arange(-.5, 1.8, .25))
            ax.set_xlim((-.5, 1.75))
            ax.set_ylabel(r'$\mu V$')
        else:
            ax.set_xticks(np.arange(-1, 1.1, .25))

        # plot flourishes
        ax.set_title('%s-locked' % epo_type)
        ax.axvline(0, color='k')
        ax.axhline(0, color='k')
        ax.set_xlabel('Time (s)')
        ax.legend(loc='best')

    plt.suptitle('%s %s ERPs' % (subject, ch), y=1.05)
    sns.despine()
    plt.show()

interact(plot_erps, subject=['group'] + subjects, ch=CH_NAMES,
         behavior=fixed(behavior));

### Plot Topomaps 

In [13]:
def plot_topomap(subject, epo_type, time, col_limit):
    """Plot ERP scalp topographies for each condition and their difference.

    Parameters
    ----------
    subject : str
        Subject label, or ``'group'`` for the grand average.
    epo_type : str
        ``'stimulus'`` or ``'response'``; selects which evoked file to load.
    time : float
        Time point (seconds) at which the topographies are drawn.
    col_limit : float
        Symmetric color limit: the color range is [-col_limit, col_limit].
    """
    plt.close('all')
    evo_file = layout.get(subject=subject,
                          derivative='eeg_sensor',
                          extensions='%s-ave.fif' % epo_type)[0]
    evokeds = read_evokeds(evo_file.filename, verbose=False)

    # one axis per condition plus one for the difference
    _, axs = plt.subplots(1, 3, figsize=(24, 6))
    topo_kwargs = dict(times=time, colorbar=True, show=False,
                       vmin=-col_limit, vmax=col_limit)

    for ax, evo, c in zip(axs, evokeds, conditions):
        evo.plot_topomap(axes=ax, **topo_kwargs)
        ax.set_xlabel(c)

    # first minus second condition (incongruent - congruent)
    diff = combine_evoked(evokeds, weights=[1, -1])
    diff.plot_topomap(axes=axs[2], **topo_kwargs)
    axs[2].set_xlabel('i-c')

    plt.show()

interact(plot_topomap, subject=['group'] + subjects,
         epo_type=['response', 'stimulus'],
         time=np.arange(-1, 1.75, .01), col_limit=np.arange(.5, 5, .5));


# TFR Power

## Make TFR Power

### Compute Raw TFR Power

In [15]:
# For every epoch type and subject: compute single-trial Morlet power per
# condition after removing that condition's evoked response, crop off the
# filter buffer, and cache the result as <subject>_<epo_type>_raw-tfr.h5.
for epo_type, epo_times in zip(epoch_types, epoch_times):
    print(epo_type)
    tmin, tmax = epo_times

    for subject in subjects:
        # 'group' has no epochs of its own
        if subject == 'group':
            continue
        print(subject)

        # subject behavior, then the cleaned epochs
        sub_behavior = behavior[behavior.participant_id == subject]
        cleaned = layout.get(subject=subject,
                             derivative='eeg_preprocessing',
                             extensions='%s_cleaned-epo.fif' % epo_type)
        epochs = read_epochs(cleaned[0].filename, verbose=False)

        # repair bad channels before trial rejection and event labelling
        epochs.interpolate_bads(reset_bads=True)
        sub_behavior, epochs = drop_bad_trials(subject, sub_behavior,
                                               epochs, layout, epo_type)
        epochs = add_events(epochs, sub_behavior)

        cond_powers = []
        for cond in conditions:
            # epochs[cond] is a selection, so subtracting the evoked
            # response here leaves `epochs` itself untouched
            induced = epochs[cond]
            induced.subtract_evoked()

            cond_power = tfr_morlet(induced, freqs=frequencies,
                                    n_cycles=n_cycles, decim=5,
                                    return_itc=False, average=False,
                                    n_jobs=5)
            cond_power.crop(tmin, tmax)
            cond_powers.append(cond_power)

        out_tmpl = '../data/derivatives/eeg_sensor/%s/tfr/%s_%s_raw-tfr.h5'
        write_tfrs(out_tmpl % (subject, subject, epo_type), cond_powers,
                   overwrite=True)

print('Done!')

stimulus
sub-hc001
sub-hc002
sub-hc003
sub-hc004
sub-hc005
sub-hc006
sub-hc007
sub-hc008
sub-hc009
sub-hc010
sub-hc011
sub-hc012
sub-hc014
sub-hc015
sub-hc016
sub-hc017
sub-hc019
sub-hc020
sub-hc021
sub-hc022
sub-hc023
sub-hc024
sub-hc025
sub-hc026
sub-hc028
sub-hc029
sub-hc030
sub-hc031
sub-hc032
sub-hc033
sub-hc034
sub-hc035
sub-hc036
sub-hc037
sub-hc042
sub-hc044
sub-hc045
sub-pp001
sub-pp002
sub-pp003
sub-pp004
sub-pp005
sub-pp006
sub-pp007
sub-pp008
sub-pp009
sub-pp010
sub-pp011
sub-pp012
sub-pp013
sub-pp014
sub-pp015
sub-pp016
response
sub-hc001
sub-hc002
sub-hc003
sub-hc004
sub-hc005
sub-hc006
sub-hc007
sub-hc008
sub-hc009
sub-hc010
sub-hc011
sub-hc012
sub-hc014
sub-hc015
sub-hc016
sub-hc017
sub-hc019
sub-hc020
sub-hc021
sub-hc022
sub-hc023
sub-hc024
sub-hc025
sub-hc026
sub-hc028
sub-hc029
sub-hc030
sub-hc031
sub-hc032
sub-hc033
sub-hc034
sub-hc035
sub-hc036
sub-hc037
sub-hc042
sub-hc044
sub-hc045
sub-pp001
sub-pp002
sub-pp003
sub-pp004
sub-pp005
sub-pp006
sub-pp007
sub-pp008
su

### Baseline Normalize Power

In [None]:
# Baseline-normalize raw TFR power for every subject and epoch type.
# Stimulus-locked power starts from the pre-stimulus `baseline` window.
# Whatever baseline_normalize returns as its second value (presumably the
# baseline power itself - TODO confirm) is cached per subject/condition and
# reused for the response-locked power. 'stimulus' must therefore come
# before 'response' in epoch_types.
baselines = {}

for epo_type, epo_times in zip(epoch_types, epoch_times):
    print(epo_type)

    group = {c: [] for c in conditions}
    for subject in subjects:
        if subject == 'group':
            continue
        print(subject)

        f = '../data/derivatives/eeg_sensor/%s/tfr/%s_%s_raw-tfr.h5'
        tfrs = read_tfrs(f % (subject, subject, epo_type))

        if epo_type == 'stimulus':
            baselines[subject] = {c: baseline for c in conditions}
        elif subject not in baselines:
            raise RuntimeError('no stimulus baseline cached for %s; process '
                               'stimulus epochs before %s epochs'
                               % (subject, epo_type))

        norm_tfrs = []
        for i, c in enumerate(conditions):
            # distinct local name so the global `baseline` analysis
            # parameter is not overwritten by the returned baseline
            tfr, cond_baseline = baseline_normalize(tfrs[i],
                                                    baselines[subject][c])
            norm_tfrs.append(tfr)
            group[c].append(tfr)
            baselines[subject][c] = cond_baseline

        f = '../data/derivatives/eeg_sensor/%s/tfr/%s_%s_norm-tfr.h5'
        write_tfrs(f % (subject, subject, epo_type), norm_tfrs,
                   overwrite=True)

    group_tfrs = [grand_average(group[c]) for c in conditions]
    f = '../data/derivatives/eeg_sensor/group/tfr/group_%s_norm-tfr.h5'
    write_tfrs(f % (epo_type), group_tfrs, overwrite=True)

print('Done!')

stimulus
sub-hc001
sub-hc002
sub-hc003
sub-hc004
sub-hc005
sub-hc006
sub-hc007


## Visualize TFR Power

### TFR Power Heatmaps

In [2]:
def plot_tfr_heatmap(subject, ch, lim, behavior):
    """Plot baseline-normalized TFR power heatmaps for one channel.

    Rows follow ``epoch_types``. The columns show each condition, then the
    first-minus-second condition difference ("i-c"). Stimulus-locked
    panels also receive response-time distributions for overlay.

    Parameters
    ----------
    subject : str
        Subject label, or ``'group'`` for the grand average.
    ch : str
        Channel name to plot.
    lim : float
        Passed through to ``power_heatmap`` (presumably a symmetric color
        limit - TODO confirm against its definition).
    behavior : pandas.DataFrame
        Trial-level behavior with the exclusion flag columns,
        ``participant_id``, ``trial_type`` and ``response_time``.
    """
    # keep trials with no exclusion flag set. Use a boolean mask rather than
    # np.where + .loc: np.where yields positions but .loc selects by label,
    # and the frame's index need not be 0..n-1 (e.g. after the modality
    # filter), which would select the wrong rows.
    exclusions = ['fast_rt', 'no_response', 'error', 'post_error']
    sub_behavior = behavior[behavior[exclusions].sum(axis=1) == 0]
    if subject != 'group':
        sub_behavior = sub_behavior[sub_behavior.participant_id == subject]
    sns.set(style='white', font_scale=2)
    plt.close('all')

    _, axs = plt.subplots(2, 3, figsize=(24, 16))

    for i, epo_type in enumerate(epoch_types):

        f = '../data/derivatives/eeg_sensor/%s/tfr/%s_%s_norm-tfr.h5'
        tfrs = read_tfrs(f % (subject, subject, epo_type))

        for j, c in enumerate(conditions):
            power = tfrs[j]
            power.pick_channels([ch])
            ax = axs[i, j]
            if epo_type == 'stimulus':
                rts = [sub_behavior[sub_behavior.trial_type == c].response_time]
                ax = power_heatmap(power, ax, lim, rts, [colors[j]])
            else:
                ax = power_heatmap(power, ax, lim)
            ax.set_title('%s %s-locked' % (c, epo_type))

        # first minus second condition (incongruent - congruent)
        ax = axs[i, 2]
        power = tfrs[0] - tfrs[1]
        power.pick_channels([ch])
        if epo_type == 'stimulus':
            rts = [sub_behavior[sub_behavior.trial_type == c].response_time
                   for c in conditions]
            ax = power_heatmap(power, ax, lim, rts, [colors[0], colors[1]])
        else:
            ax = power_heatmap(power, ax, lim)
        ax.set_title('i-c %s-locked' % epo_type)

    plt.tight_layout()
    plt.subplots_adjust(top=.92)
    plt.suptitle('%s %s TFR Heatmaps' % (subject, ch), fontsize=24)
    plt.show()

interact(plot_tfr_heatmap, subject=['group'] + subjects, ch=CH_NAMES,
         lim=(.5, 4, .5), behavior=fixed(behavior));

### TFR Power Topomaps

In [3]:
def plot_topomap(subject, epo_type, time, fmin, fmax, col_limit):
    """Plot power topographies per condition and their difference.

    Power is averaged over ``[time, time + .005]`` and the frequency band
    ``[fmin, fmax]`` of the baseline-normalized TFRs.

    Parameters
    ----------
    subject : str
        Subject label, or ``'group'`` for the grand average.
    epo_type : str
        ``'stimulus'`` or ``'response'``; selects which TFR file to load.
    time : float
        Start of the (5 ms) time window, in seconds.
    fmin, fmax : float
        Frequency band edges; swapped if given in the wrong order.
    col_limit : float
        Symmetric color limit: the color range is [-col_limit, col_limit].
    """
    plt.close('all')
    fname = '../data/derivatives/eeg_sensor/%s/tfr/%s_%s_norm-tfr.h5'
    tfrs = read_tfrs(fname % (subject, subject, epo_type))

    # the two frequency sliders are independent, so tolerate fmin > fmax
    fmin, fmax = min(fmin, fmax), max(fmin, fmax)
    topo_kwargs = dict(tmin=time, tmax=time + .005, fmin=fmin, fmax=fmax,
                       colorbar=True, show=False,
                       vmin=-col_limit, vmax=col_limit)

    _, axs = plt.subplots(1, 3, figsize=(24, 6))

    for ax, tfr, c in zip(axs, tfrs, conditions):
        tfr.plot_topomap(axes=ax, **topo_kwargs)
        ax.set_xlabel(c)

    # first minus second condition (incongruent - congruent)
    diff = tfrs[0] - tfrs[1]
    diff.plot_topomap(axes=axs[2], **topo_kwargs)
    axs[2].set_xlabel('i-c')
    plt.suptitle('%.2f s' % time)

    plt.show()

interact(plot_topomap, subject=['group'] + subjects,
         epo_type=['response', 'stimulus'],
         time=np.arange(-1, 1.75, .01),
         fmin=frequencies, fmax=frequencies,
         col_limit=np.arange(.5, 5, .5));
