ECoG data pipeline for functional cortical mapping by Jay Jeschke and Daniel Maksumov.

In [None]:
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
%matplotlib
import numpy as np
import mne
from mne.io import read_raw_edf
import os
import os.path as op
from visbrain.gui import Brain
from visbrain.objects import SourceObj, BrainObj
from visbrain.io import download_file, path_to_visbrain_data
plt.ion()
plt.style.use('seaborn-white')
import sys
from scipy import signal
import math
import h5py
import warnings
import scipy.stats as stats
import scipy.signal as sig
from ecogMethods import *
import multiprocessing
from joblib import delayed, Parallel
import gc

import hdf5storage  # used for `options` below and, when from_mat is True, to load Labels.mat in the plotting cell
# hdf5storage options object (store_python_metadata=True).
# NOTE(review): `options` is not referenced in this part of the notebook — confirm it is used elsewhere.
options= hdf5storage.Options(store_python_metadata=True)
# Print numpy floats with at most 5 digits after the decimal point.
np.set_printoptions( precision=5)


SJ = 'NY723'  # subject ID

# the_misc lists non-electrode channels that are ignored when plotting and excluded from the
# electrode labels; check raw.ch_names and edit if needed.
the_misc = ['C175', 'C176', 'EKG1', 'EKG2', 'ECG1', 'ECG2', 'EKGL', 'EKGR',
            'ECGL', 'ECGR', 'LEKG', 'REKG', 'LECG', 'RECG', 'DC1', 'DC2', 'DC3',
            'DC4', 'DC5', 'DC6', 'DC7', 'DC8', 'DC9', 'DC10', 'DC11', 'DC12',
            "TRIG", "OSAT", 'PR', 'Pleth', 'STI 014']

# Root folder under which the resulting data folders are created.
home_path = os.getcwd()

# Folder containing the EDF file (example layout) and the EDF file name.
edf_path = op.join(home_path, "NY723_FM")
edf_file = SJ + '_FunctionalMapping_512.EDF'
create_dir = False

# Tasks that were run, in the same order as their blocks appear on the trigger channel.
tasks = ['PicN', 'VisRead', 'AudRep', 'AudN', 'SenComp']
clrs = ['xkcd:blue', 'xkcd:cyan', 'xkcd:green',
        'xkcd:magenta', 'xkcd:red', 'xkcd:yellow',
        'xkcd:black', 'xkcd:teal', 'xkcd:palegreen',
        'xkcd:coral']
# Plotting colours corresponding to the tasks above, one per task. The palette is cycled
# if there are more tasks than colours, so later cells never index past the end.
clrs = [clrs[k % len(clrs)] for k in range(len(tasks))]

# Bad channel detection presets: if automatic_bads is True, bad channels are detected
# automatically; if manual_bads is also True, the clickable channel plot is shown with the
# detected bads already selected.
automatic_bads = True
manual_bads = True

# If the analysis has already been done in MATLAB, set from_mat to True and skip down to the
# last 2 cells for plotting.
from_mat = False

matDir = ''  # MATLAB directory where the functional mapping data lives

# Names of the DC channels carrying the trigger and the microphone signal.
trigger_name = 'DC1'
mic_name = 'DC5'


In [None]:
# Load the EDF recording into an MNE Raw object with the data preloaded into memory;
# channels named in the_misc are passed to the reader as misc channels.
raw = read_raw_edf(
    op.join(edf_path, edf_file),
    misc=the_misc,
    preload=True,
)

In [None]:
# Sampling rate in samples per second. round() rather than a bare int() so that a header
# rate stored as e.g. 511.9999 is not truncated to 511.
srate = int(round(raw.info["sfreq"]))

# DC/ECG/EKG/SG channels present in this recording are treated as non-electrode channels too.
misc = [ch for ch in raw.ch_names if ch.startswith(('DC', 'ECG', 'EKG', 'SG'))]
# Only add names not already listed, so re-running this cell does not keep growing the_misc.
the_misc = the_misc + [ch for ch in misc if ch not in the_misc]

# Electrode labels: every channel of raw (in raw order) that is not a misc channel.
misc_set = set(the_misc)  # O(1) membership tests for the filter below
labels = [ch for ch in raw.ch_names if ch not in misc_set]

# Electrode indices 0 .. len(labels) - 1.
elecs = list(range(len(labels)))


In [None]:
# Enter bad electrodes manually, or leave bad_elecs empty and select from the plot
# To select/unselect bad channels in the plot, click the label of the channel on the left of the plot
# we suggest you do manually select bad channels along and not rely soley on the suggested bad channels
bad_elecs = [] 




%matplotlib
if not bad_elecs:
    if automatic_bads:
        data,times = raw.copy()[:len(elecs),0:30*srate]## gets data from first 30 seconds
        # bad channels are suggested based on deviation from average channel amplitude and 
        # spectral power between 55-65 Hz
        # if too many good channels are being rejected increase the threshold
        suggested_bads= detect_bads(data, low_bound=55, up_bound=65,thresh=.9)
        raw.info['bads']=[raw.ch_names[i] for i in suggested_bads]
        
        
        if manual_bads:
            raw.plot(block = True, duration=5, scalings = dict(eeg=10e-5,misc=10e-3))
            bad_elecs=[raw.ch_names.index(i) for i in raw.info['bads']]
        else:
            bad_elecs=[raw.ch_names.index(i) for i in raw.info['bads']]
    else:
        raw.plot(block = True, duration=5, scalings = dict(eeg=10e-5,misc=10e-2))
        bad_elecs=[raw.ch_names.index(i) for i in raw.info['bads']]
    plt.pause(.0001)
print('Bad electrodes: ',np.sort(bad_elecs))

In [None]:
trigger_data, trigger_times = raw.copy().pick_channels([trigger_name])[:,:]
mic_channel = raw.copy().pick_channels([mic_name])[:,:][0]

clips_indices = extract_blocks(data=trigger_data, times=trigger_times, subj=SJ, tasks=tasks, srate=srate, 
                                      blockMin=180 ,gap=20, 
                                      trigger_len=1.5, thresh='')

#plotting will show each task/block in a unique color 
#for visual confirmation of start/stop indices
# If plotting triggers, make true:

plot_trigger = True

if plot_trigger:
    %matplotlib

    for i,ind in enumerate(clips_indices):
        start= clips_indices[i][0]
        end = clips_indices[i][1]
        t, = plt.gca().plot(np.arange(start,end),trigger_data[0][start:end],lw=2,c=clrs[i])
        plt.pause(.00001)
    plt.show()

#for optimizing the block marker locater: gap is the minimum time gap between blocks, 
#thresh is the threshold that a trigger must exceed above the zero-meaned baseline to be counted

In [None]:
# Create an "events" object for each task and save its onsets, event names and bad-event
# flags to <ANdir>/events.h5 (overwriting any existing file).

# Trigger channel read once, as a (1, n_times) array plus its time vector.
trig_data, trig_times = raw[[raw.ch_names.index(trigger_name)], :]

for cnt, task in enumerate(tasks):
    x = get_subj_globals(SJ, task, root_path=home_path, from_mat=from_mat, matDir=matDir)
    time_i, time_f = clips_indices[cnt][0], clips_indices[cnt][1]
    # Trigger samples (and matching times) belonging to this task's block only.
    data = trig_data[:, time_i:time_f]
    times = trig_times[time_i:time_f]

    # You can optimise the task trigger finder by adjusting eventMin (the minimum event
    # length; presumably in samples, so srate = 1 s) and the trigger threshold.
    task_events = extract_task_events(data, times, task, x.subj, srate=srate, start=time_i,
                                      eventMin=srate, thresh='', practiceTrials=3)
    # Expose the result as the global <task>_events (e.g. PicN_events) for later cells,
    # without building and exec-ing source code.
    globals()[task + "_events"] = task_events
    warnings.warn(f'Saving and overwriting data for:  {x.subj} {task} \n')

    # Build all arrays before opening the file, so a failure here does not leave a
    # truncated events.h5 behind.
    onset = np.asarray(task_events.onset)
    size = onset.shape
    event_names = np.asarray([n.encode("ascii", "ignore") for n in task_events.event])
    badevent = np.asarray(task_events.badevent)

    TE = op.join(x.ANdir, 'events.h5')
    with h5py.File(TE, 'w') as hf:
        grp = hf.create_group('Events')
        grp.create_dataset('onset', size, dtype=int, data=onset)
        grp.create_dataset('event', size, dtype='S100', data=event_names)
        grp.create_dataset('badevent', size, dtype=int, data=badevent)

# TODO: include a way to manually input bad trials.

In [None]:
#Creating Plot of all good electrodes for all tasks
%matplotlib 
transparency = .2
clrs_r = clrs[::-1]
g = get_subj_globals(SJ,tasks[0], root_path=home_path, from_mat=from_mat, matDir=matDir)
es = [i for i in g.elecs if i not in g.bad_elecs]

line_plot=True


# Which electrodes to plot
all_elecs = True
if all_elecs:
    e_init = 0
    e_final = es[-1]
else:
    e_init = 102
    e_final = 108 ## exclusive
    
pElecs= [i for i in es if i in np.arange(e_init,e_final+1)] #good electrodes for plotting

# Grid dimensions
length= 20 # number of columns in grid plot
height = math.ceil(len(pElecs)/length)

## leave option to leave out axis markers

with warnings.catch_warnings():
    warnings.simplefilter("ignore") ## this is here to supress unecessary warning
    if line_plot:
        fig,axs = plt.subplots(nrows=height,ncols=length, sharex=True, sharey=True)
    
    for i, task in enumerate(tasks[::-1]):
        
        
        params = Params('','','','','','','','') #creating params class object with default parameters
        params.en = 4000 if clrs_r[i] in clrs_r[:2] else 2000 # amount of time in miliseconds plotted, in this case
        #400 for the first 2 tasks and 2000 for the remaining tasks
        
        if from_mat==True:
            print(from_mat)
            labels= hdf5storage.loadmat(op.join(g.DTdir,"Labels.mat"))
            labels=[i[0] for i in labels['Labels'][0]]

        else:
            labels = load_h5(op.join(g.DTdir,"labels.h5"), "labels")
            labels = [ch.decode() for ch in labels] #turns label array into list of strings
            
        for ind,j in enumerate(pElecs):       

            #analysis steps
            t = plot_single(g.subj,task,j,params, root_path = home_path, f1=70,f2=150,gdat='',raw=0, matDir=False)
            t = my_conv(t,100)
            m = np.mean(t, axis = 0)
      
        
            if line_plot:
                #plotting 1 task at 1 electrode at a time
                v = np.arange(params.st, params.en)
                col=ind-(int(np.floor(ind/length)))*length

                ax1=plt.subplot(height,length,ind+1)
                ax1.text(0,141,labels[j], fontsize=6)

    #             standard error mean plotting
                if params.shade_plot:
                    sem = stats.sem(t, axis = 0)
                    ax1.fill_between(v, m-sem, m+sem, alpha = transparency, color = clrs_r[i])

                ax1.plot(v, m, color = clrs_r[i], scaley=False)

                plt.ylim(-10,175)

                if col!=0:
                    plt.setp(ax1.get_yticklabels(), visible=False)
                else:
                    plt.setp(ax1.get_yticklabels(), fontsize=8)

                if (ind+20)<len(pElecs):
                    plt.setp(ax1.get_xticklabels(), visible=False)
                else:
                    plt.setp(ax1.get_xticklabels(), fontsize=8)
                fig.canvas.update()
                fig.canvas.flush_events()
   
    
    if line_plot:
        plt.subplots_adjust(left=.03,right=.93,
                        top= .95,bottom=.05, hspace=.2,
                        wspace=.13)
        l1 = plt.plot(1, color=clrs_r[0][5:])[0]
        l2 = plt.plot(1,  color=clrs_r[1][5:])[0]
        l3 = plt.plot(1,  color=clrs_r[2][5:])[0]
        l4 = plt.plot(1, color=clrs_r[3][5:])[0]
        l4 = plt.plot(1, color=clrs_r[4][5:])[0]

        fig.legend([l1, l2, l3, l4],     # The line objects
                   labels=tasks[::-1],   # The labels for each line
                   loc="center right",   # Position of legend
                   borderaxespad=0.1,    # Small spacing around legend box
                   title="Tasks",  # Title for the legend
                   prop={'size': 6}
                   )
        empties=len(np.ndarray.flatten(axs))-len(pElecs)
        if len(pElecs)<=length:
            for i in np.arange(empties):
                fig.delaxes(axs[-(i+1)])
        else:
              for i in np.arange(empties):
                fig.delaxes(axs[height-1][-(i+1)])
    
if line_plot:
    mng = plt.get_current_fig_manager()
    mng.window.showMaximized() #maximize the figure
    plt.show()
