ECoG data pipeline for functional cortical mapping by Jay Jeschke and Daniel Maksumov.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import mne
from mne.io import read_raw_edf
import os
import os.path as op
plt.ion()  # interactive mode so figures update without blocking the notebook
# The 'seaborn-white' style was renamed 'seaborn-v0_8-white' in matplotlib 3.6
# and the old name was removed in 3.8; use whichever this installation provides.
plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available
              else 'seaborn-white')
import sys
from scipy import signal
import math
import h5py
import warnings
import scipy.stats as stats
import scipy.signal as sig
from ecogMethods import *
import hdf5storage

np.set_printoptions(precision=5)

# ---------------------------------------------------------------- subject --
SJ = 'NY723'  # "Subject ID"

# Channels that can be ignored when plotting; they are passed as `misc` when the
# EDF is read. Check raw.ch_names and edit this list if needed.
the_misc = [
    'C175', 'C176',
    'EKG1', 'EKG2', 'ECG1', 'ECG2', 'EKGL', 'EKGR',
    'ECGL', 'ECGR', 'LEKG', 'REKG', 'LECG', 'RECG',
    'DC1', 'DC2', 'DC3', 'DC4', 'DC5', 'DC6',
    'DC7', 'DC8', 'DC9', 'DC10', 'DC11', 'DC12',
    "TRIG", "OSAT", 'PR', 'Pleth', 'STI 014',
]

# ------------------------------------------------------------------ paths --
home_path = os.getcwd()          # root folder for the resulting data folders
edf_path = op.join(home_path)    # folder where the EDF lives (here: home_path itself)
edf_file = f'{SJ}_FunctionalMapping_512.EDF'  # name of the EDF file
create_dir = False

# ------------------------------------------------------------------ tasks --
tasks = ['PicN', 'VisRead', 'AudRep', 'AudN', 'SenComp']
# One matplotlib colour code per task, in the same order as `tasks`.
clrs = 'bcgmrykw'[:len(tasks)]

# --------------------------------------------------- bad-channel presets --
# automatic_bads: suggest bad channels automatically.
# manual_bads: additionally show the clickable channel plot, with the
#              suggested bads already selected.
automatic_bads = True
manual_bads = True

# ----------------------------------------------------------- MATLAB input --
# If the analysis was already done in MATLAB, set from_mat to True and skip
# down to the last 2 cells for plotting.
from_mat = False
matDir = ''  # MATLAB directory where the functional mapping data lives

# ------------------------------------------------------ trigger & mic DCs --
trigger_name = 'DC1'  # DC channel carrying the trigger
mic_name = 'DC2'      # DC channel carrying the microphone




In [None]:
# Read the EDF recording fully into memory as an MNE Raw object; the channels
# named in `the_misc` are designated as misc channels.
raw = read_raw_edf(
    op.join(edf_path, edf_file),
    misc=the_misc,
    preload=True,
)

In [None]:
# Sampling frequency or rate
srate = int(raw.info["sfreq"])

# The ECoG channels come first; find the first channel that is not ECoG (the
# first DC1 / ECG* / EKG* channel) and treat everything before it as ECoG.
try:
    final_index = next(idx for idx, name in enumerate(raw.ch_names)
                       if name.startswith(('DC1', 'ECG', 'EKG')))
except StopIteration:
    # None of the default marker channels exist: ask the user which label is the
    # first non-ECoG channel and use its position as final_index.
    print(raw.ch_names)
    while True:
        first_excluded = input('Pick label of first channel to exclude after ECoG channels:  ').strip()
        if first_excluded in raw.ch_names:
            final_index = raw.ch_names.index(first_excluded)
            break
        print(f'{first_excluded!r} is not in raw.ch_names, please try again.')

# Creates list of electrode names
labels = raw.ch_names[:final_index]

# Creates list of electrode indices
elecs = list(range(final_index))

In [None]:
# Enter bad electrodes manually, or leave bad_elecs empty and select from the plot
# To select/unselect bad channels in the plot, click the label of the channel on the left of the plot
# we suggest you do manually select bad channels along and not rely soley on the suggested bad channels
bad_elecs = [] 



if not bad_elecs:
    if automatic_bads:
        data,times = raw.copy()[:len(elecs),0:30*srate]## gets data from first 30 seconds
        # bad channels are suggested based on deviation from average channel amplitude and 
        # spectral power between 55-65 Hz
        # if too many good channels are being rejected increase the threshold
        suggested_bads= detect_bads(data, low_bound=55, up_bound=65,thresh=.9)
        raw.info['bads']=[raw.ch_names[i] for i in suggested_bads]
        
        
        if manual_bads:
            %matplotlib qt5
            raw.plot(block = True, duration=5, scalings = dict(eeg=10e-5,misc=10e-2))
            bad_elecs=[raw.ch_names.index(i) for i in raw.info['bads']]
        else:
            bad_elecs=[raw.ch_names.index(i) for i in raw.info['bads']]
    else:
        %matplotlib qt5
        raw.plot(block = True, duration=5, scalings = dict(eeg=10e-5,misc=10e-2))
        bad_elecs=[raw.ch_names.index(i) for i in raw.info['bads']]
print('Bad electrodes: ',bad_elecs)

In [None]:
# Create the per-task global settings for every task (presumably what
# get_subj_globals reads back in later cells), all sharing the same electrode
# list and bad electrodes.
for task in tasks:
    create_subj_globals(
        SJ, task,
        srate, srate,
        elecs, bad_elecs,
        TANK=[],
        root_path=home_path,
    )

In [None]:
trigger_data, trigger_times = raw.copy().pick_channels([trigger_name])[:,:]
mic_channel = raw.copy().pick_channels([mic_name])[:,:][0]

clips_indices = extract_blocks(data=trigger_data, times=trigger_times, subj=SJ, tasks=tasks, srate=srate, 
                                      blockMin=180 ,eventMin=.8,gap=30, 
                                      trigger_len=1.5, thresh='')

#plotting will show each task/block in a unique color 
#for visual confirmation of start/stop indices
# If plotting triggers, make true:

plot_trigger = True

if plot_trigger:
    %matplotlib qt5
    
    for i,ind in enumerate(clips_indices):
        start= clips_indices[i][0]
        end = clips_indices[i][1]
        t, = plt.gca().plot(np.arange(start,end),trigger_data[0][start:end],lw=2,c=clrs[i]) 
        
#for optimizing the block marker locater: gap is the minimum time gap between blocks, 
#thresh is the threshold that a trigger must exceed above the zero-meaned baseline to be counted

In [None]:
# If you prefer, you can enter the block indices manually in the following format:
# clips_indices = [[block 1 start index, block 1 stop index], [block 2 start index, block 2 stop index], ...]
# Indices are sample numbers: multiply the clock time (in seconds) by the sampling rate (srate).
# e.g. clips_indices[0][0] = 180000



In [None]:
# Saving data specific to each task into that task's DTdir.

# Extract the full data array once rather than once per task (same values as
# the per-iteration raw[:, :][0]).
all_data = raw[:, :][0]

# Channel names as fixed-width byte strings for HDF5. np.bytes_ is the type
# that the np.string_ alias (removed in NumPy 2.0) referred to.
labels = np.asarray([np.bytes_(ch) for ch in labels])

for cnt, task in enumerate(tasks):
    x = get_subj_globals(SJ, task, root_path=home_path, from_mat=from_mat, matDir=matDir)
    time_i = clips_indices[cnt][0]
    time_f = clips_indices[cnt][1]

    trigger = trigger_data[:, time_i:time_f]
    mic = mic_channel[:, time_i:time_f]
    # MNE returns samples in volts; scale by 1e6 to store microvolts.
    gdat = all_data[x.elecs, time_i:time_f] * 10**6

    save_h5(op.join(x.DTdir, 'gdat.h5'), "gdat", gdat)
    save_h5(op.join(x.DTdir, 'labels.h5'), "labels", labels)
    save_h5(op.join(x.DTdir, 'trigger.h5'), "trigger", trigger)
    save_h5(op.join(x.DTdir, 'mic.h5'), "mic", mic)

del all_data  # release the full-recording copy

In [None]:
#creating "events" objects for each task

# Full trigger channel and its sample times (seconds).
trig_data, trig_times = raw.get_data(picks=[trigger_name], return_times=True)

for cnt, task in enumerate(tasks):
    x = get_subj_globals(SJ, task, root_path=home_path, from_mat=from_mat, matDir=matDir)
    time_i = clips_indices[cnt][0]
    time_f = clips_indices[cnt][1]
    # Trigger samples and times for this task's block only; get_data avoids
    # copying the whole recording to read one channel.
    data, times = raw.get_data(picks=[trigger_name], start=time_i, stop=time_f,
                               return_times=True)

    # You can optimize the task trigger finder by adjusting eventMin
    # (event minimum length) and the trigger threshold.
    task_events = extract_task_events(data, times, task, x.subj, srate=srate,
                                      start=time_i, eventMin=srate)
    # Keep exposing the result as a `<task>_events` global (e.g. PicN_events)
    # without building code strings for exec().
    globals()[task + '_events'] = task_events
    warnings.warn(f'Saving and overwriting data for:  {x.subj} {task} \n')

    TE = op.join(x.ANdir, 'events.h5')
    with h5py.File(TE, 'w') as hf:
        grp = hf.create_group('Events')
        # All three datasets take the shape of the onset array.
        size = np.shape(task_events.onset)
        grp.create_dataset('onset', size, dtype=int,
                           data=np.asarray(task_events.onset))
        ascii_events = [n.encode("ascii", "ignore") for n in task_events.event]
        grp.create_dataset('event', size, dtype='S100',
                           data=np.asarray(ascii_events))
        grp.create_dataset('badevent', size, dtype=int,
                           data=np.asarray(task_events.badevent))
    

In [None]:
#Common average reference applied to data each task
for i, task in enumerate(tasks):
    g = get_subj_globals(SJ,task, root_path=home_path, from_mat=from_mat, matDir=matDir)
    warnings.warn(f'Saving and overwriting data for:  {x.subj} {task} \n')
    create_CAR(SJ, task, x.bad_elecs, home_path) # saves car data for each task
    

In [None]:
#Creating Plot of all good electrodes for all tasks

%matplotlib qt5
transparency = .2
clrs_r = clrs[::-1]

g = get_subj_globals(SJ,tasks[0], root_path=home_path, from_mat=from_mat, matDir=matDir)
es = [i for i in g.elecs if i not in g.bad_elecs]

# Which electrodes to plot
all_elecs = True
if all_elecs:
    e_init = 1
    e_final = len(es)
else:
    e_init = 132
    e_final = 135 ## exclusive

# Grid dimensions
length= 20 # number of columns in grid plot
height = math.ceil(len(es)/length)

## leave option to leave out axis markers
for i, task in enumerate(tasks[::-1]):
    
    if from_mat==True:
        print(from_mat)
        labels= hdf5storage.loadmat(op.join(g.DTdir,"Labels.mat"))
        labels=[i[0] for i in labels['Labels'][0]]

    else:
        labels = load_h5(op.join(g.DTdir,"labels.h5"), "labels")
        labels = [ch.decode() for ch in labels] #turns label array into list of strings
    
        car_data = load_h5(op.join(g.DTdir,"car_data.h5"), "car_data") ## takes about 100 seconds to load
    
    params = Params('','','','','','','','','')
    params.scale = 0.8
    params.en = 4000 if clrs[i] in clrs[:2] else 2000 # Shows first 4 seconds of the last 2 tasks

    for j in range(e_init, e_final):       
    
        
        t = plot_single(g.subj,task,es[j-1],params, root_path = home_path, f1=70,f2=150,raw=0,gdat=car_data,db=0, matDir=False)
        t = my_conv(t,100)
        m = np.mean(t, axis = 0)

        v = np.arange(params.st, params.en)
        plt.subplot(height,length,j-(e_init-1))
        plt.title(labels[es[j-1]],pad = -10)

        if params.shade_plot:
            sem = stats.sem(t, axis = 0)
            plt.fill_between(v, m-sem, m+sem, alpha = transparency, color = clrs_r[i])

        plt.plot(v, m, color = clrs_r[i])
        plt.ylim(-10,175)


