In [None]:
import mne
from mne.preprocessing import ICA
import pycsd.pycsd as pycsd
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from scipy import stats,optimize
import pickle
from utils import *
%matplotlib inline
%matplotlib qt

# Montage files ship inside the MNE package (mne/channels/data/montages). Locate them relative
# to the imported mne instead of hard-coding one machine's Python 3.7 site-packages path.
montage_path = os.path.join(os.path.dirname(mne.__file__), 'channels', 'data', 'montages')
# Root directory of the study data; `path_data` comes from `utils` (star import above).
data_path = path_data

In [None]:
def LoadEEGData(vpn, montage_path):
    """Load one subject's BioSemi recording and prepare its channels.

    Reads '<path_eeg_subj>subj<vpn>.bdf'. ``path_eeg_subj`` is a notebook-level global that
    must be defined before this is called. The function then:

    * strips the 2-character prefix from every channel name except the last one;
    * keeps the first 64 + 10 channels plus the last channel;
    * applies the BioSemi 64 montage;
    * types the external channels as EMG/EOG;
    * builds bipolar EMG/EOG derivations;
    * re-references the EEG to PO9/FT9 and drops those two channels;
    * puts the channels into a fixed order.

    Parameters
    ----------
    vpn : subject identifier used in the file name.
    montage_path : directory passed to ``mne.channels.read_montage``.

    Returns
    -------
    The preloaded, re-referenced ``Raw`` object.
    """
    raw = mne.io.read_raw_edf('{}subj{}.bdf'.format(path_eeg_subj, vpn), preload=True)

    # Strip the 2-character prefix from all channel names; the last channel (the trigger
    # channel, referenced as 'STI 014' below) keeps its name.
    raw.rename_channels({chName: chName[2:] for chName in raw.ch_names[:-1]})
    # Keep the 64 EEG + 10 external channels and the final trigger channel; drop the rest.
    raw.drop_channels(raw.ch_names[74:-1])

    # Standard BioSemi 64-channel montage.
    montage = mne.channels.read_montage(kind='biosemi64', path=montage_path)
    raw.info['description'] = 'BioSemi/64'
    raw.set_montage(montage)

    channelTypes = dict.fromkeys(('EMG1a', 'EMG1b', 'FT10', 'PO10'), 'emg')
    channelTypes.update(dict.fromkeys(('HeRe', 'HeLi', 'VeUp', 'VeDo'), 'eog'))
    channelTypes['STI 014'] = 'resp'
    raw.set_channel_types(mapping=channelTypes)

    # Bipolar derivations as (anode, cathode, new channel name).
    bipolarPairs = [('EMG1b', 'EMG1a', 'EMGleft'),
                    ('PO10', 'FT10', 'EMGright'),
                    ('HeRe', 'HeLi', 'EOGx'),
                    ('VeUp', 'VeDo', 'EOGy')]
    anodes, cathodes, newNames = (list(column) for column in zip(*bipolarPairs))
    mne.set_bipolar_reference(raw, anode=anodes, cathode=cathodes, ch_name=newNames, copy=False)

    refChannels = ['PO9', 'FT9']
    raw.set_eeg_reference(ref_channels=refChannels)
    raw.drop_channels(refChannels)

    eegOrder = ['Fpz', 'Fp1', 'Fp2',
                'AFz', 'AF3', 'AF4', 'AF7', 'AF8',
                'Fz', 'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8',
                'FCz', 'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6', 'FT7', 'FT8',
                'Cz', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'T7', 'T8',
                'CPz', 'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6', 'TP7', 'TP8',
                'Pz', 'P1', 'P2', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8', 'P9', 'P10',
                'POz', 'PO3', 'PO4', 'PO7', 'PO8',
                'Oz', 'O1', 'O2',
                'Iz']
    auxiliaryOrder = ['EOGx', 'EOGy', 'EMGleft', 'EMGright', 'STI 014']
    raw.reorder_channels(eegOrder + auxiliaryOrder)

    return raw



def FilterEEGData(vpn, raw, l_freq=1/7, h_freq=100, notch_freqs=(50,), overwrite=False):
    """Band-pass and notch filter `raw`, save it, and return it.

    Parameters
    ----------
    vpn : subject identifier used in the output file name.
    raw : preloaded mne Raw (e.g. from LoadEEGData).
    l_freq, h_freq : float
        Band-pass edges in Hz. The defaults keep the previous 1/7 Hz high-pass and
        100 Hz low-pass.
    notch_freqs : sequence of float
        Line-noise frequencies to notch out (default: 50 Hz).
    overwrite : bool
        Forwarded to ``raw.save``. MNE refuses to overwrite an existing file by default,
        so re-running this step on an existing output needs ``overwrite=True``.

    Returns
    -------
    The same ``raw`` object, filtered by MNE's ``filter`` / ``notch_filter``.

    Side effect: writes '<path_eeg_subj>subj<vpn>_filtered_raw.fif'. ``path_eeg_subj`` is a
    notebook-level global.
    """
    raw.filter(l_freq=l_freq, h_freq=h_freq)
    raw.notch_filter(list(notch_freqs), filter_length='auto', phase='zero', fir_design='firwin')
    raw.save('{}subj{}_filtered_raw.fif'.format(path_eeg_subj, vpn), overwrite=overwrite)
    return raw



# Trigger codes written by the experiment, by name. They are not needed for event detection;
# they serve as the reference for the codes used in this notebook (e.g. 16 = 'EncStim',
# 36 = 'DecStim'; 10/20 and 28/42 are the period start/stop triggers in GetPeriodData).
TRIGGER_EVENT_ID = {'EncStart':10 ,'EncMaskInit':12, 'EncEmpty':14, 'EncStim':16, 'EncMaskEnd':18,
                    '2backStart':20, '2backNumber':22, '2backResp':24, '2backBlank':26,
                    'DecStart':28, 'DecBlockStart':30, 'DecMaskInit':32, 'DecEmpty': 34, 'DecStim':36,'DecMaskEnd':40,
                    'RecStart':42, 'RecMaskInit':44, 'RecEmpty':46, 'RecStim': 48,
                    'PauseStart':52, 'CalibStart':54,
                    'FixBreak':60, 'FixWaitFail':70, 'DecF':62, 'DecJ':64, 'RecF':66, 'RecJ':68}


def FindTriggers(raw, stim_channel='STI 014'):
    """Find trigger onsets in the stimulus channel of `raw`.

    Parameters
    ----------
    raw : mne Raw containing the trigger channel.
    stim_channel : str
        Name of the trigger channel (default 'STI 014').

    Returns
    -------
    ndarray
        The events array returned by ``mne.find_events``.
    """
    return mne.find_events(raw, stim_channel=stim_channel, uint_cast=True,
                           consecutive=True, min_duration=0.001)



def GetPeriodOnsets(onsetTrigger, offsetTrigger, subjectDetails, events, numRunsWithPractice=11):
    """Return the sample indices at which each valid run's period starts and ends.

    Parameters
    ----------
    onsetTrigger, offsetTrigger : int
        Trigger codes marking the start and end of the period (e.g. 10 and 20 for encoding).
    subjectDetails : dict
        Must contain 'invalidRuns' (1-based run numbers to discard; list or array) and
        'validRuns' (used only to sanity-check how many runs remain).
    events : array-like, shape (n_events, 3)
        Events array; column 0 is the sample index and column 2 the trigger code.
    numRunsWithPractice : int
        If exactly this many onsets (or offsets) are found, the first one is treated as the
        practice run and removed (default 11, as before).

    Returns
    -------
    onsets, offsets : ndarray
        Sample indices of the kept period onsets and offsets.
    """
    events = np.asarray(events)
    onsets = events[events[:, 2] == onsetTrigger, 0]
    offsets = events[events[:, 2] == offsetTrigger, 0]

    # Chop off onsets/offsets of the practice run.
    if len(onsets) == numRunsWithPractice:
        onsets = np.delete(onsets, 0)
        print('NOTIFICATION: practice run onsets deleted')
    if len(offsets) == numRunsWithPractice:
        offsets = np.delete(offsets, 0)
        print('NOTIFICATION: practice run offsets deleted')

    # Delete invalid runs if applicable. np.asarray accepts a plain list as well as an array
    # (list - 1 used to raise a TypeError). Run numbers are 1-based, hence the - 1.
    invalidIdx = np.asarray(subjectDetails['invalidRuns'], dtype=int) - 1
    if invalidIdx.size > 0:
        onsets = np.delete(onsets, invalidIdx)
        offsets = np.delete(offsets, invalidIdx)
        print('NOTIFICATION: invalid run onsets/offsets were deleted')

    numValid = len(subjectDetails['validRuns'])
    if len(onsets) != numValid or len(offsets) != numValid:
        print('PROBLEM: found {} onsets and {} offsets for triggers {}/{}, but {} valid runs. '
              'Check whats going on!'.format(len(onsets), len(offsets), onsetTrigger,
                                             offsetTrigger, numValid))

    return onsets, offsets



def GetPeriodData(vpn, periodLabel, raw, subjectDetails, events, samplingRate=1024, timeAddConstant=7):
    """Extract all valid runs of one experimental period, concatenate them, and save them.

    Each run spans from its start trigger to its stop trigger (found by GetPeriodOnsets).
    It is padded by `timeAddConstant` seconds on both sides and clipped to the recording.

    Parameters
    ----------
    vpn : subject identifier used in the output file names.
    periodLabel : {'Encoding', 'Decision'}
        Selects the trigger pair: Encoding uses 10 -> 20, Decision uses 28 -> 42.
    raw : mne Raw holding the continuous (filtered) recording.
    subjectDetails : dict passed on to GetPeriodOnsets ('invalidRuns', 'validRuns').
    events : events array for `raw` (e.g. from FindTriggers).
    samplingRate : float
        Sampling rate in Hz, used to convert event samples to seconds. It should match
        raw.info['sfreq'].
    timeAddConstant : float
        Padding in seconds added before each onset and after each offset.

    Returns
    -------
    mne Raw containing the concatenated runs.

    Side effects (``path_eeg_subj`` is a notebook-level global):
    writes '<path_eeg_subj>subj<vpn>_<periodLabel>_filtered_raw.fif' and
    '<path_eeg_subj>subj<vpn>_<periodLabel>-eve.fif'. The events file holds `events` plus one
    event with code 1 at each padded run offset. Previously the events file was named with the
    global `subjectID` and always '_Decision', so the Encoding call wrote under the Decision name.

    Raises
    ------
    ValueError
        If `periodLabel` is unknown or no runs are found.
    """
    periodTriggers = {'Encoding': (10, 20), 'Decision': (28, 42)}
    if periodLabel not in periodTriggers:
        raise ValueError('periodLabel must be one of {}, got {!r}'.format(sorted(periodTriggers), periodLabel))
    startTrigger, stopTrigger = periodTriggers[periodLabel]

    onsetSamples, offsetSamples = GetPeriodOnsets(startTrigger, stopTrigger, subjectDetails, events)
    numRuns = len(onsetSamples)
    if numRuns == 0:
        raise ValueError('No {} runs found for subject {}'.format(periodLabel, vpn))

    # Event samples count from raw.first_samp, while crop() expects seconds from the start of
    # `raw`, so convert and pad accordingly.
    firstSamp = raw.first_samp
    periodOnsets = (onsetSamples - firstSamp) / samplingRate - timeAddConstant
    periodOffsets = (offsetSamples - firstSamp) / samplingRate + timeAddConstant
    lastTime = raw.times[-1]

    dataEEG = None
    for runIdx in range(numRuns):
        # Clip the padded window to the recording so runs near its edges do not make crop() raise.
        tmin = max(periodOnsets[runIdx], 0.0)
        tmax = min(periodOffsets[runIdx], lastTime)
        dataRun = raw.copy().crop(tmin=tmin, tmax=tmax)
        print(dataRun)
        if dataEEG is None:
            dataEEG = dataRun
        else:
            # Append explicitly. The former bare `except:` would silently replace all
            # previously collected runs if append ever failed.
            dataEEG.append(raws=dataRun, preload=None)

    dataEEG.save('{}subj{}_{}_filtered_raw.fif'.format(path_eeg_subj, vpn, periodLabel))

    # Save events: one extra event (code 1) at each padded run offset, expressed in the same
    # sample frame as `events`. Round instead of truncating the float sample positions.
    newSamples = np.rint(offsetSamples + timeAddConstant * samplingRate).astype('int64')
    eventsNew = np.column_stack([newSamples,
                                 np.zeros(numRuns, dtype='int64'),
                                 np.ones(numRuns, dtype='int64')])
    eventsUpdated = np.vstack([events, eventsNew])
    eventsUpdated = eventsUpdated[eventsUpdated[:, 0].argsort()]
    mne.write_events('{}subj{}_{}-eve.fif'.format(path_eeg_subj, vpn, periodLabel), eventsUpdated)
    return dataEEG



def GetSubjectDetails(subjID):
    """Load the pickled subject-details dict for subject `subjID`.

    Reads '<data_path>subj<subjID>/Termin2/behavioral/subject<subjID>_details.pickle'.
    ``data_path`` is a notebook-level global. Previously the path used the global `subjectID`
    and ignored the `subjID` argument.

    Note: ``pickle.load`` can execute arbitrary code, so only load trusted files.
    """
    fname = '{}subj{}/Termin2/behavioral/subject{}_details.pickle'.format(data_path, subjID, subjID)
    with open(fname, 'rb') as pickle_in:
        subjectDetails = pickle.load(pickle_in)
    return subjectDetails



def CheckEpochs(dataEEG, events, eventID, tMin=-1, tMax=2, baseLine=(-.5, 0)):
    """Epoch `dataEEG` around one trigger code and open an interactive plot for inspection.

    Parameters
    ----------
    dataEEG : mne Raw to epoch.
    events : events array for `dataEEG` (e.g. from FindTriggers).
    eventID : trigger code to epoch around.
    tMin, tMax : epoch window in seconds relative to the trigger.
    baseLine : baseline interval passed to ``mne.Epochs``.

    Returns
    -------
    The preloaded epochs, restricted to EEG, stim, EOG and EMG channels.
    """
    epochs = mne.Epochs(
        dataEEG, events,
        event_id=[eventID],
        tmin=tMin, tmax=tMax,
        proj=True,
        picks=None,
        baseline=baseLine,
        preload=True,
        reject=None,
    )
    epochs.pick_types(meg=False, eeg=True, stim=True, eog=True, emg=True)

    # Interactive Qt backend for the epochs browser (same as the `%matplotlib qt` magic).
    get_ipython().run_line_magic('matplotlib', 'qt')
    scalings = {'mag': 1e-12, 'grad': 4e-11, 'eeg': 150e-6, 'eog': 25e-5, 'ecg': 5e-4,
                'emg': 1e-3, 'ref_meg': 1e-12, 'misc': 1e-3, 'stim': 1, 'resp': 1,
                'chpi': 1e-4, 'whitened': 10.}
    epochs.plot(n_channels=68, title=None, scalings=scalings)
    return epochs



def GetFixationBreakTrials(dataBehavEncoding, dataBehavDecision):
    """Return fixation-break times for memory-mode trials with an early fixation break.

    A trial is selected when the encoding table's 'fixationBreakTarget' is below 3 and the
    decision table's 'mode' equals 'memory'. The two tables are combined row-wise by index.

    Returns
    -------
    pandas.Series
        'timeFixationBreak' - 'timeStimulusOnset' (in ms) for the selected trials, indexed by
        trial number starting at 1 (encoding-table index + 1).
    """
    isEarlyBreak = dataBehavEncoding['fixationBreakTarget'] < 3
    isMemoryTrial = dataBehavDecision['mode'] == 'memory'
    selected = dataBehavEncoding.loc[isEarlyBreak & isMemoryTrial]
    breakTimes = selected['timeFixationBreak'] - selected['timeStimulusOnset']
    breakTimes.index = breakTimes.index + 1
    return breakTimes



def FitIca(data, numComponents=20, method='fastica', decim=3, random_state=1337, reject=None):
    """Fit an ICA decomposition to `data` and return the fitted ICA object.

    Parameters
    ----------
    data : mne Raw (or Epochs) to decompose.
    numComponents : number of ICA components.
    method : ICA algorithm name passed to ``mne.preprocessing.ICA``.
    decim : fit on every `decim`-th sample. We need sufficient statistics, not all time
        points, so this saves time.
    random_state : seed, so the decomposition is reproducible.
    reject : dict | None
        Peak-to-peak rejection thresholds passed to ``ICA.fit``. ``None`` (default) keeps the
        previous thresholds ``dict(mag=5e-12, grad=4000e-13)``. Those keys only name MEG
        channel types, so on EEG-only data they presumably reject nothing. Pass e.g.
        ``dict(eeg=...)`` to reject EEG segments instead.

    Returns
    -------
    The fitted ``ICA`` instance.
    """
    if reject is None:
        reject = dict(mag=5e-12, grad=4000e-13)
    ica = ICA(n_components=numComponents, method=method, random_state=random_state)
    ica.fit(data, decim=decim, reject=reject)
    return ica



def FindEOGArtifacts(ica, data, n_max_eog=None):
    """Find ICA components that correlate with the EOG channels of `data` and plot them.

    Parameters
    ----------
    ica : fitted ICA (e.g. from FitIca).
    data : mne Raw the ICA was fitted on. It is used for both the EOG correlation and the
        averaged EOG epochs.
    n_max_eog : int | None
        If given, keep at most this many of the detected components, in the order returned
        by ``ica.find_bads_eog``. ``None`` (default) keeps all of them, as before.

    Returns
    -------
    list
        Indices of the EOG-related components (previously nothing was returned).
    """
    # Static inline backend for the score/source plots (same as the `%matplotlib inline` magic).
    get_ipython().run_line_magic('matplotlib', 'inline')
    eog_inds, scores = ica.find_bads_eog(data)
    if n_max_eog is not None:
        eog_inds = eog_inds[:n_max_eog]
    ica.plot_scores(scores, exclude=eog_inds, labels='eog')
    # Build the EOG epochs from `data`, not from the notebook-global `raw` as before, so the
    # averaged blink comes from the same period the ICA was fitted on.
    eog_average = mne.preprocessing.create_eog_epochs(data, reject=None, picks=None).average()
    ica.plot_sources(eog_average, exclude=eog_inds)  # look at source time course
    return eog_inds

    

def SaveSubjectDetails(subjDet, subjID=None):
    """Pickle `subjDet` to the subject's details file.

    Writes '<data_path>subj<subjID>/Termin2/behavioral/subject<subjID>_details.pickle'. This is
    the same path format GetSubjectDetails reads from. ``data_path`` is a notebook-level global.

    Parameters
    ----------
    subjDet : dict
        Subject details to save. Previously this argument was ignored and the global
        ``subjectDetails`` was written instead.
    subjID : optional
        Subject ID used in the path. Defaults to the notebook-global ``subjectID``, which was
        the previous implicit behavior.
    """
    if subjID is None:
        subjID = subjectID
    fname = '{}subj{}/Termin2/behavioral/subject{}_details.pickle'.format(data_path, subjID, subjID)
    with open(fname, 'wb') as pickle_out:
        pickle.dump(subjDet, pickle_out)

    
    
def FinishPreprocessing(data, ica, subjectID, periodLabel):
    """Interpolate bad channels, remove excluded ICA components, and save the result.

    Loads the subject details (via GetSubjectDetails) and uses two entries for `periodLabel`:
    'preprocessing<periodLabel>_interpolatedChannels' lists the channels to mark bad and
    interpolate, and 'preprocessing<periodLabel>_excludedComponentNumbers' lists the ICA
    components to remove (skipped when empty). `data` is processed by MNE's interpolate_bads
    and ica.apply, saved to '<path_eeg_subj>subj<subjectID>_<periodLabel>_preprocessed_raw.fif'
    (``path_eeg_subj`` is a notebook-level global), and returned.

    TODO: excluded trials ('preprocessing<periodLabel>_excludedTrials') are not applied yet.
    When adding that, also delete the trials from the behavioral data file.
    """
    details = GetSubjectDetails(subjectID)
    channelsKey = 'preprocessing{}_interpolatedChannels'.format(periodLabel)
    componentsKey = 'preprocessing{}_excludedComponentNumbers'.format(periodLabel)

    # Mark the listed channels as bad, then interpolate them (reset_bads clears the marks).
    data.info['bads'] = details[channelsKey]
    data.interpolate_bads(reset_bads=True, mode='accurate', origin=(0.0, 0.0, 0.04), verbose=None)

    excludedComponents = details[componentsKey]
    if len(excludedComponents) > 0:
        ica.apply(data, include=None, exclude=excludedComponents,
                  n_pca_components=None, start=None, stop=None)

    data.save('{}subj{}_{}_preprocessed_raw.fif'.format(path_eeg_subj, subjectID, periodLabel))
    return data

# Load relevant data
1. Specify the subject, then load their subject details and behavioral data
2. Load the raw EEG data, filter it, and save the filtered data

In [None]:
# Choose the subject to preprocess. `subjIDs` and `data_path` come from utils / the config cell.
subjIdx = 39

subjectID = subjIDs[subjIdx]
subjectDetails = GetSubjectDetails(subjectID)

# Per-subject EEG and behavioral folders of the second session ('Termin2').
path_eeg_subj = '{}/subj{}/Termin2/eeg/'.format(data_path, subjectID)
path_behav_subj = '{}/subj{}/Termin2/behavioral/'.format(data_path, subjectID)

dataBehavEncoding = pd.read_csv('{}dataEncoding'.format(path_behav_subj), sep=',', header=0)
dataBehavDecision = pd.read_csv('{}dataDecision'.format(path_behav_subj), sep=',', header=0)

In [None]:
# Load the raw BioSemi recording and filter it. FilterEEGData also saves the filtered data.
raw = FilterEEGData(subjectID, LoadEEGData(subjectID, montage_path))

# Encoding
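1. Extract the encoding-period runs from the filtered recording
2. Epoch around the stimulus trigger and inspect the epochs
3. Fit ICA and identify artifact (e.g. EOG) components
4. Record interpolated channels and excluded components, then finalize preprocessing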

In [None]:
# Reload the filtered continuous data written by FilterEEGData and detect its triggers.
raw = mne.io.read_raw_fif('{}subj{}_filtered_raw.fif'.format(path_eeg_subj, subjectID),
                          allow_maxshield=False, preload=True, verbose=None)
events = FindTriggers(raw)

# Cut out and concatenate the encoding runs. GetPeriodData also saves them to disk.
dataEEGEnc = GetPeriodData(subjectID, 'Encoding', raw, subjectDetails, events)
# Alternatively, reload previously extracted encoding data:
# dataEEGEnc = mne.io.read_raw_fif('{}subj{}_{}_filtered_raw.fif'.format(path_eeg_subj, subjectID, 'Encoding'), allow_maxshield=False, preload=True, verbose=None)

In [None]:
# Epoch the encoding data around the stimulus trigger (16 = 'EncStim') for visual inspection.
eventsEnc = FindTriggers(dataEEGEnc)
epochsEnc = CheckEpochs(dataEEGEnc, eventsEnc, eventID=16)

In [None]:
numComponents=20
ica = FitIca(dataEEGEnc,numComponents)

%matplotlib qt
ica.plot_components()

FindEOGArtifacts(ica,dataEEGEnc)

%matplotlib qt
ica.plot_properties(epochsEnc, picks=np.arange(20))
ica.plot_components()

In [None]:
# Record the manual preprocessing decisions for the encoding period, after inspecting the plots above.
# Reload the details from disk first (presumably so the latest saved state is updated).
subjectDetails = GetSubjectDetails(subjectID)

# Fill these in by hand. Trial numbers look 1-based; the `- 1` presumably converts them to
# 0-based indices (TODO confirm). Excluded trials are not applied anywhere in this notebook yet.
subjectDetails['preprocessingEncoding_excludedTrials'] = np.array([])-1
# ICA component indices to remove (applied in FinishPreprocessing if non-empty).
subjectDetails['preprocessingEncoding_excludedComponentNumbers'] = []
# Channel names to interpolate (applied in FinishPreprocessing).
subjectDetails['preprocessingEncoding_interpolatedChannels'] = []

SaveSubjectDetails(subjectDetails)

# Interpolate channels, remove components, and save the preprocessed encoding data.
dataEEGEnc_pp = FinishPreprocessing(dataEEGEnc, ica, subjectID, 'Encoding')

# Decision
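1. Extract the decision-period runs from the filtered recording
2. Epoch around the stimulus trigger and inspect the epochs
3. Fit ICA and identify artifact (e.g. EOG) components
4. Record interpolated channels and excluded components, then finalize preprocessing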

In [None]:
# Decision period: re-derive the subject ID and the triggers of the continuous data, then
# cut out and concatenate the decision runs (GetPeriodData also saves them to disk).
subjectID = subjIDs[subjIdx]
events = FindTriggers(raw)
dataEEGDec = GetPeriodData(vpn=subjectID, periodLabel='Decision', raw=raw,
                           subjectDetails=subjectDetails, events=events)

In [None]:
# Epoch the decision data around the stimulus trigger (36 = 'DecStim') for visual inspection.
eventsDec = FindTriggers(dataEEGDec)
epochsDec = CheckEpochs(dataEEGDec, eventsDec, eventID=36, tMax=2)

In [None]:
# Fit ICA on the decision-period data and inspect the components.
numComponents=20
ica = FitIca(dataEEGDec,numComponents)

# Correlate the components with the EOG channels and plot their scores and sources.
FindEOGArtifacts(ica,dataEEGDec)

ica.plot_components()
# Interactive (Qt) backend for the component property plots.
%matplotlib qt
ica.plot_properties(epochsDec, picks=np.arange(numComponents))

In [None]:
# Record the manual preprocessing decisions for the decision period, after inspecting the plots above.
# Reload the details from disk first (presumably so the encoding entries saved above are kept).
subjectDetails = GetSubjectDetails(subjectID)

# Fill these in by hand. Trial numbers look 1-based; the `- 1` presumably converts them to
# 0-based indices (TODO confirm). Excluded trials are not applied anywhere in this notebook yet.
subjectDetails['preprocessingDecision_excludedTrials'] = np.array([])-1
# ICA component indices to remove (applied in FinishPreprocessing if non-empty).
subjectDetails['preprocessingDecision_excludedComponentNumbers'] = []
# Channel names to interpolate (applied in FinishPreprocessing).
subjectDetails['preprocessingDecision_interpolatedChannels'] = []

SaveSubjectDetails(subjectDetails)

# Interpolate channels, remove components, and save the preprocessed decision data.
dataEEGDec_pp = FinishPreprocessing(dataEEGDec, ica, subjectID, 'Decision')