In [31]:
%matplotlib inline
import matplotlib.pyplot as plt

import seaborn as sns
sns.set_style("white")

import mne
from autoreject import LocalAutoRejectCV
mne.set_log_level('ERROR')

import pandas as pd
import scipy.stats as stats
from os import listdir
import numpy as np

loc_files="~/conincon_data/files/"
template = "{name}-{inst}.fif" # template to load files

df = pd.read_csv('~/conincon_data/lookUp.csv') # load csv with sentence info
event_ids = {"con/hc":201, "con/lc":200, "inc/lc":210, "inc/hc":211}

In [32]:
# Participants excluded from every analysis step.
bads = [
    "VLA21",  # not a native speaker of German
    "GMA05",  # rejected participant
    "RAC22",  # rejected participant
    "PBT16",  # rejected participant
    "HHA01",  # rejected participant
    "BSN17",  # rejected participant
]

In [33]:
# Participant names: the prefix before the first "-" of every raw file,
# excluding rejected participants. Sorted so the processing order does not
# depend on the (per-session randomised) iteration order of a set of strings.
names = sorted({fname.split("-")[0]
                for fname in listdir(loc_files)
                if "raw" in fname and fname.split("-")[0] not in bads})

In [48]:
def fix_events(events, srate=100, audio_delay=0.255, lookup=None, conditions=None):
    '''Move condition-trigger events to the onset of the critical word.

    For every event whose trigger code belongs to a condition, the preceding
    event is read as the item trigger. The item's word onset (seconds) is
    looked up in the sentence table. That onset plus the fixed delay between
    trigger and audio onset is added, in samples, to the event's sample index.

    Parameters
    ----------
    events : ndarray of int
        Events array as returned by ``mne.find_events`` (column 0 sample
        index, column 2 trigger code). Modified in place and returned.
    srate : float
        Sampling rate (Hz) used to convert seconds to samples; should match
        the recording the events come from.
    audio_delay : float
        Delay in seconds between trigger and audio onset (default 255 ms).
    lookup : pandas.DataFrame | None
        Table with ``label`` and ``onset`` columns; defaults to global ``df``.
    conditions : dict | None
        Mapping "congruency/cloze" -> trigger code; defaults to global
        ``event_ids``.

    Returns
    -------
    events : ndarray
        The same array with the condition events shifted to word onset.

    Raises
    ------
    ValueError
        If a condition trigger is the first event (no preceding item trigger)
        or if the item's label is not found in ``lookup``.
    '''
    if lookup is None:
        lookup = df
    if conditions is None:
        conditions = event_ids

    for cond, trg in conditions.items():
        congruency, cloze = cond.split("/")
        for ind in np.flatnonzero(events[:, 2] == trg):
            if ind == 0:
                # events[ind - 1] would silently wrap around to the last event.
                raise ValueError(
                    "Condition trigger {} at index 0 has no preceding item "
                    "trigger".format(trg))
            item = events[ind - 1, 2]  # item trigger precedes the condition trigger
            # Table labels use the item trigger minus one.
            label = "{}_{}_{}".format(congruency, cloze, item - 1)
            onsets = lookup.loc[lookup["label"] == label, "onset"].values
            if len(onsets) == 0:
                raise ValueError("No word onset found for label {!r}".format(label))
            # Convert the total delay to samples once, with rounding. Truncating
            # the onset and then adding a fractional delay into the integer
            # events array biased the shift by up to ~1.5 samples.
            events[ind, 0] += int(round((onsets[0] + audio_delay) * srate))
    return events
    
def get_epoch(name, times=None):
    '''Load, filter and epoch the raw recording of one participant.

    Parameters
    ----------
    name : str
        Participant identifier; the raw file is ``<name>-raw.fif`` in
        ``loc_files``.
    times : dict | None
        Epoch window with keys ``tmin`` and/or ``tmax`` in seconds relative
        to the word-onset events. Missing keys default to ``tmin=-.3`` and
        ``tmax=1.3``.

    Returns
    -------
    epochs : mne.Epochs
        Preloaded EEG epochs, filtered 0.3-20 Hz, baseline-corrected from the
        epoch start to time 0.
    '''
    # Previously `times` was accepted but ignored (the window was hard-coded).
    window = dict(tmin=-.3, tmax=1.3)
    window.update(times or {})

    fname = loc_files + template.format(name=name, inst='raw')
    with mne.io.read_raw_fif(fname) as raw:
        events = mne.find_events(raw, min_duration=0, shortest_event=0)
        # Convert word onsets with the recording's real sampling rate instead
        # of fix_events' 100 Hz default.
        events = fix_events(events, srate=raw.info['sfreq'])

        raw.load_data()
        raw.filter(.3, 20, n_jobs=12, phase="zero", filter_length='auto',
                   l_trans_bandwidth='auto', h_trans_bandwidth='auto')
        picks = mne.pick_types(raw.info, eeg=True, stim=False)

        # preload=True copies the data before the raw file is closed.
        return mne.Epochs(raw, events, event_ids, preload=True,
                          baseline=(None, 0), picks=picks,
                          tmin=window['tmin'], tmax=window['tmax'])


In [49]:
# Load the autoreject models computed earlier (indexed by participant name).
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
import pickle

with open(loc_files + 'ar.pckl', 'rb') as f:
    autorejs = pickle.load(f)

In [50]:
# Load each participant's ICA ("<name>-ica.fif"), keyed by participant name.
# Uses the shared file-name template and the public read_ica entry point.
icas = {name: mne.preprocessing.read_ica(
            loc_files + template.format(name=name, inst='ica'))
        for name in names}

In [51]:
def clean_epochs(epochs, name, on_ica=False):
    '''Apply the participant's ICA and autoreject model to epoched data.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs of participant ``name``; loaded into memory if not already.
    name : str
        Participant identifier, used as key into ``icas`` and ``autorejs``.
    on_ica : bool
        If True, return the ICA source time courses (without the EOG
        components) instead of sensor data, with all channels typed 'eeg'.

    Returns
    -------
    epochs : mne.Epochs
        Cleaned epochs in sensor space, or in ICA-source space if ``on_ica``.
    '''
    ica = icas[name]
    eog_inds = ica.labels_["eog"]

    # Remove the components stored under the ICA's "eog" label (eye artefacts).
    epochs = ica.apply(epochs.load_data(), exclude=eog_inds)

    # Reject trials and interpolate bad channels with the pre-fitted autoreject model.
    epochs = autorejs[name].transform(epochs)

    if on_ica:
        # ICA sources are named with a zero-padded 3-digit index ('ICA000',
        # 'ICA001', ...). '{:03d}' also stays correct for index >= 100, where
        # the former 'ICA0{:02d}' produced e.g. 'ICA0100'.
        epochs = ica.get_sources(epochs).drop_channels(
            ['ICA{:03d}'.format(ii) for ii in eog_inds])
        # Type every source channel as EEG so it is handled like sensor data.
        epochs.set_channel_types({ch: 'eeg' for ch in epochs.ch_names})

    return epochs

In [56]:
def saving_epochs(name, on_ica=False):
    '''Build, clean and save the epochs of one participant.

    The epochs are created with ``get_epoch``, then cleaned with ICA and
    autoreject via ``clean_epochs``. They are written to ``loc_files`` as
    ``<name>-epo.fif``, or ``<name>-on_ica-epo.fif`` when ``on_ica`` is True.
    '''
    suffix = "-on_ica" if on_ica else ""
    cleaned = clean_epochs(get_epoch(name), name, on_ica=on_ica)
    out_name = "{name}{OIs}-{inst}.fif".format(name=name, OIs=suffix, inst="epo")
    cleaned.save(loc_files + out_name)

In [57]:
%%capture
for name in names:
    [saving_epochs(name, on_ica=item) for item in [True, False]]