This branch is the software release for the 2019 paper: https://www.nature.com/articles/s41598-019-47795-0

See LICENSE.txt

Copyright 2019 Massachusetts Institute of Technology

In [None]:
%reset -f
import os
from glob import glob
import scipy
import scipy.io
import numpy as np
import sklearn
import sklearn.metrics
import sklearn.preprocessing
import sklearn.linear_model
import matplotlib.pyplot as plt
%matplotlib inline
import getpass
import sklearn.datasets
import sklearn.ensemble
import sklearn.model_selection
import datetime
import h5py
import re
import hashlib
import sys
from importlib import reload

## Use sklearn for least squares with internal cross validation to reconstruct audio envelopes (b) from eeg (A)
#  according to Aw = b.
#
# For each experiment part, create an A matrix, then concatenate all the parts
# a1 a2 a3 a4 ... aN, audio
# e1 e2 e3 e4 ... eN, eeg
# Assume a 26-sample context (250 ms window)
# Use [e1 .. e26] to predict a1
# Therefore, drop the last (26 - 1) samples from the audio vector to create the y target vector
#
## About
#  Greg Ciccarelli
#  September 3, 2018

In [None]:
import nipype.pipeline.engine as pe
import nipype.interfaces.utility as niu
import nipype.interfaces.io as nio

In [None]:
# Login name of the account running the notebook.  It is not referenced again
# in the visible cells; presumably it feeds the redacted XXX_* path
# expressions -- TODO confirm.
user = getpass.getuser()

In [None]:
def big_node(dct_params):
    """Main processing function:  performs audio reconstruction from eeg.

    Description
    -----------
    Unpack parameters
    Preproc data and create reconstruction algorithm
    Train algorithm and predict on test data
    Save out all parameters and predictions

    Parameters
    ----------
    dct_params : dict
        Run configuration.  Required keys: 'file_path_name_eeg',
        'file_path_name_audio', 'train', 'test', 'alpha_array', 'zscore',
        'collect', 'file_path_name_util', 'idx_ch', 'num_context',
        'save_flag', 'file_path_bciaud_shared', 'timestamp_time',
        'file_path_gridhome', 'model_lsq', 'windsor_flag', 'loss_type',
        'cv_type'.
        Optional keys: 'aur_flag', 'aur_thresh'.  They are read for backward
        compatibility but are not used by the processing below.

    Returns
    -------
    None
        Results are printed and, when ``save_flag`` is truthy, written to a
        ``.mat`` checkpoint file.

    Raises
    ------
    KeyError
        If a required key is missing from ``dct_params``.
    ValueError
        If ``model_lsq`` or ``zscore`` is not a supported name, or if
        ``cv_type`` / ``loss_type`` is not supported while the model is
        'RidgeCV' (the only model that consumes them).

    Notes
    -----
    This function is wrapped in a nipype ``Function`` node elsewhere in this
    file.  nipype runs such functions from their source text in a separate
    namespace, so every import has to stay inside the body.
    """
    import os
    from glob import glob
    import scipy
    import scipy.io
    import scipy.signal
    import scipy.stats  # pearsonr lives here; must be imported explicitly
    import numpy as np
    import sklearn
    import sklearn.metrics
    import sklearn.preprocessing
    import sklearn.linear_model
    import sklearn.datasets
    import sklearn.ensemble
    import sklearn.model_selection
    import datetime
    import h5py
    import re
    import hashlib
    import sys
    from importlib import reload

    #----------------------------------------------------
    #  Unpack Parameters, preproc, create reconstruction alg.
    #----------------------------------------------------
    file_path_name_eeg = dct_params['file_path_name_eeg']
    file_path_name_audio = dct_params['file_path_name_audio']
    train = dct_params['train']
    test = dct_params['test']
    alpha_array = dct_params['alpha_array']
    zscore = dct_params['zscore']
    collect = dct_params['collect']
    file_path_name_util = dct_params['file_path_name_util']
    idx_ch = dct_params['idx_ch']
    num_context = dct_params['num_context']
    save_flag = dct_params['save_flag']
    file_path_bciaud_shared = dct_params['file_path_bciaud_shared']
    timestamp_time = dct_params['timestamp_time']
    file_path_gridhome = dct_params['file_path_gridhome']
    model_lsq = dct_params['model_lsq']
    windsor_flag = dct_params['windsor_flag']
    # No caller in this file supplies the aur_* keys and nothing below uses
    # them, so a missing key must not abort the run.
    aur_flag = dct_params.get('aur_flag')
    aur_thresh = dct_params.get('aur_thresh')
    loss_type = dct_params['loss_type']
    cv_type = dct_params['cv_type']

    # Load the helper module by path so the node is self-contained.
    util_dir, util_name = os.path.split(file_path_name_util)
    sys.path.append(util_dir)
    module = __import__(util_name)
    reload(module)
    load_data = getattr(module, 'load_data')
    cat_part = getattr(module, 'cat_part')
    make_conv = getattr(module, 'make_conv')

    audio_ht, eeg_ht, audio_unatt_ht = load_data(file_path_name_audio, file_path_name_eeg)

    ## On the fly preprocessing
    if windsor_flag:
        # Thresholds come from the training parts only, but the clipping is
        # applied in place to the whole loaded array (train and test).
        prctile = np.nanpercentile(eeg_ht[train], [0.1, 1, 5, 95, 99, 99.9])
        eeg_ht[eeg_ht < prctile[0]] = prctile[0]
        eeg_ht[eeg_ht > prctile[-1]] = prctile[-1]

    # Training design matrix / targets (plus per-row group labels) and the
    # held-out test part, both built by the helper module.
    X_ht, y_ht, z_ht, groups_ht = cat_part(eeg_ht[train], audio_ht[train], audio_unatt_ht[train],
                                           idx_ch=idx_ch, num_context=num_context)
    print(groups_ht)

    Xi, yi, zi = make_conv(eeg_ht[test], audio_ht[test], audio_unatt_ht[test],
                           idx_ch=idx_ch, num_context=num_context)

    # cv_gen stays None for an unrecognized cv_type; this is rejected below
    # only if the chosen model actually consumes it.
    cv_gen = None
    if cv_type == 'sample':
        cv_gen = 3
        print('-- sample --')
    elif cv_type == 'group3fold':
        cv_gen = sklearn.model_selection.GroupKFold(n_splits=3).split(X_ht, y_ht, groups=groups_ht)
        print('-- part --')

    scoring = None
    if loss_type == 'corr':
        def score_func(y, y_pred):
            return scipy.stats.pearsonr(np.ravel(y), np.ravel(y_pred))[0]

        scoring = sklearn.metrics.make_scorer(score_func, greater_is_better=True, needs_proba=False, needs_threshold=False)
        print('-- corr --')

    elif loss_type == 'mse':
        print('-- mse --')

    if model_lsq == 'RidgeCV':
        # RidgeCV is the only model that uses cv_gen and scoring, so validate
        # them here instead of failing later with a NameError.
        if cv_gen is None:
            raise ValueError("Unknown cv_type %r; expected 'sample' or 'group3fold'" % (cv_type,))
        if loss_type not in ('corr', 'mse'):
            raise ValueError("Unknown loss_type %r; expected 'corr' or 'mse'" % (loss_type,))
        clf = sklearn.linear_model.RidgeCV(alphas=alpha_array, fit_intercept=True,
                                           normalize=False, scoring=scoring, cv=cv_gen,
                                           gcv_mode=None, store_cv_values=False)
    elif model_lsq == 'LassoCV':
        # NOTE: LassoCV keeps its fixed 3-fold split (cv=3) and ignores
        # cv_type / loss_type, exactly as before.
        clf = sklearn.linear_model.LassoCV(eps=0.001, alphas=alpha_array,
                                           fit_intercept=True, normalize=False, precompute='auto',
                                           max_iter=1000, tol=0.0001, copy_X=True, cv=3,
                                           verbose=False, n_jobs=None, positive=False,
                                           random_state=None, selection='cyclic')

    elif model_lsq == 'RANSAC':
        clf = sklearn.linear_model.RANSACRegressor(base_estimator=None, min_samples=None,
                                                   residual_threshold=None, is_data_valid=None,
                                                   is_model_valid=None, max_trials=100, max_skips=np.inf,
                                                   stop_n_inliers=np.inf, stop_score=np.inf, stop_probability=0.99,
                                                   loss='absolute_loss', random_state=None)
    else:
        raise ValueError("Unknown model_lsq %r; expected 'RidgeCV', 'LassoCV' or 'RANSAC'" % (model_lsq,))

    if zscore == 'RobustScaler':
        scaler_cls = sklearn.preprocessing.RobustScaler
    elif zscore == 'StandardScaler':
        scaler_cls = sklearn.preprocessing.StandardScaler
    else:
        raise ValueError("Unknown zscore %r; expected 'RobustScaler' or 'StandardScaler'" % (zscore,))
    # Independent scalers for EEG features, attended and unattended audio;
    # all three are fitted on training data only.
    pp_X = scaler_cls()
    pp_Y = scaler_cls()
    pp_Z = scaler_cls()
    trX_ht = pp_X.fit_transform(X_ht)
    trY_ht = pp_Y.fit_transform(y_ht)
    trZ_ht = pp_Z.fit_transform(z_ht)

    #----------------------------------------------------
    #  Fit and predict audio and EEG
    #----------------------------------------------------
    t_start = datetime.datetime.now()
    print(trY_ht.shape)
    clf.fit(trX_ht, trY_ht)
    t_end = datetime.datetime.now()
    print('- lsq time -')
    print(t_end - t_start)

    y_hat = clf.predict(pp_X.transform(Xi))

    #----------------------------------------------------
    #  Evaluate and save out
    #----------------------------------------------------
    # Scale the test targets once and reuse them for both the correlations
    # and the saved traces.
    example_te_y = np.ravel(pp_Y.transform(yi))
    example_te_yhat = np.ravel(y_hat)
    example_te_unatt = np.ravel(pp_Z.transform(zi))

    rho_att = scipy.stats.pearsonr(example_te_yhat, example_te_y)[0]
    rho_unatt = scipy.stats.pearsonr(example_te_yhat, example_te_unatt)[0]

    #------------------------------------------------------------------------
    if collect in ('cocoha', 'Columbia'):
        session_str = '0000'
    else:
        session_str = re.search(r'BCIHearing_Subj_\d+_([\d_]+)', file_path_name_eeg).group(1)

    def _scaler_center(scaler):
        # RobustScaler stores its offset in center_, StandardScaler in mean_.
        center = getattr(scaler, 'center_', None)
        if center is None:
            center = getattr(scaler, 'mean_', None)
        return center

    # RANSACRegressor has no alpha_ and keeps its coefficients on the wrapped
    # estimator_; fall back gracefully so every model can be saved.
    clf_alpha = getattr(clf, 'alpha_', None)
    clf_coef = getattr(clf, 'coef_', None)
    if clf_coef is None and hasattr(clf, 'estimator_'):
        clf_coef = getattr(clf.estimator_, 'coef_', None)

    # Key order matters: it determines the parameter hash computed below.
    dct_all = {
        'envTestAtt': example_te_y[None, :],  # output api compatible
        'envHatAtt': example_te_yhat[None, :],  # output api compatible
        'envTestUna': example_te_unatt[None, :],  # output api compatible
        'yTestAtt': example_te_y,
        'yTestHat': example_te_yhat,
        'yTestUna': example_te_unatt,
        'subjID': re.search(r'Subj_(\d+)_', file_path_name_audio).group(1),  # output api compatible
        'alpha': clf_alpha,
        'rho_att': rho_att,
        'rho_unatt': rho_unatt,
        'file_path_name_net': 'LLlsq_grid_sklearn.ipynb',
        'pp_X_center': _scaler_center(pp_X),
        'pp_X_scale': pp_X.scale_,
        'pp_Y_center': _scaler_center(pp_Y),
        'pp_Y_scale': pp_Y.scale_,
        'pp_Z_center': _scaler_center(pp_Z),
        'pp_Z_scale': pp_Z.scale_,
        'clf_coef_': clf_coef,
    }
    dct_all = {**dct_all, **dct_params}

    if save_flag:

        timestamp = '%s_%s' % (timestamp_time,
                               hashlib.md5((('').join(file_path_name_audio+file_path_name_eeg)).encode('utf')).hexdigest())

        file_path_save = XXX_file_path_save_with_timestamp

        # Several nodes may target the same directory concurrently;
        # exist_ok avoids the check-then-create race.
        os.makedirs(file_path_save, exist_ok=True)

        # Hash the plain-Python scalar / list parameters into a short stamp.
        # Exact type matches are deliberate: bools and numpy scalars stay
        # excluded so stamps remain identical to earlier runs.
        hash_parts = []
        for key, val in dct_all.items():
            if type(val) is str:
                hash_parts.append(key + val)
            elif type(val) in (float, int):
                hash_parts.append(key + str(val))
            elif type(val) is list and val:  # empty lists contribute nothing
                if type(val[0]) is str:
                    hash_parts.append(key + ','.join(val))
                elif type(val[0]) in (float, int):
                    hash_parts.append(key + ','.join([str(i) for i in val]))
        hexstamp = hashlib.md5(''.join(hash_parts).encode('utf')).hexdigest()

        now_str = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
        # Report the file that is actually written (.mat).
        file_path_name_checkpoint = os.path.join(file_path_save,
                                                 'checkpoint_eeg2env_%s_%s.mat'
                                                 % (hexstamp, now_str))

        print(file_path_name_checkpoint)
        # Replace all None elements of dict with NaN before saving to avoid save fail.
        dct_all = {key: (np.nan if val is None else val) for key, val in dct_all.items()}
        scipy.io.savemat(file_path_name_checkpoint, dct_all)



# Create configuration parameters

In [None]:
# Select the LSQ model
# big_node branches on exactly these three names:
# 'RidgeCV', 'LassoCV' and 'RANSAC'.

model_lsq = 'RidgeCV'
#model_lsq = 'LassoCV'
#model_lsq = 'RANSAC'

In [None]:
# Get data

# One timestamp for the whole run; it is handed to every node and used in
# the save-directory name.
timestamp_time = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
# :9 same AM or PM run

# Recording system and data collection to process.
modality = 'neuroscan'
#modality = 'dsi'
collect = 'LL_HowTo_0DegreesSeparation'

# EEG channel indices passed to the helper functions as idx_ch.
if modality == 'neuroscan':
    idx_ch = np.arange(64)
    # Subsample the wet eeg to match the dry eeg lead configuration
    #idx_ch = [45, 25,  7,  9, 11, 29, 49, 27,  0,  2, 23, 43, 60, 62,  5, 13, 51, 31] # full drop M1 and M2  (32, 42 zero based indexing)
elif modality == 'dsi':
    # All 20 channels except M1 (8) and M2 (17).
    idx_ch = np.delete(np.arange(20), [8, 17])
    #idx_ch = np.arange(20) #keep mastoid

subj_folder_list = XXX_list_of_subject_folders
file_path_name_audio_list = XXX_list_of_audio_files
file_path_name_eeg_list = XXX_list_of_eeg_files

for path_list in (file_path_name_audio_list, file_path_name_eeg_list):
    print(path_list)

In [None]:
# Establish preproc and regularization params
save_flag = True

# Context window length in samples; 26 was the earlier setting.
#num_context = 26
num_context = 51

# Regularization grid: 10 log-spaced values from 1e1 to 1e10.
alpha_array = np.logspace(start=1, stop=10, num=10)

# Scaler name understood by big_node, and whether to clip EEG outliers.
zscore = 'RobustScaler'
windsor_flag = False

# Model-selection score: Pearson correlation ('corr') or the estimator default ('mse').
#loss_type = 'mse'
loss_type = 'corr'

# Cross-validation scheme used for the RidgeCV model.
cv_type = 'group3fold'
#cv_type = 'sample'

In [None]:
# Echo the regularization grid in the notebook output.
alpha_array

In [None]:
# Get helper file
file_path_name_util = XXX_path_to_lsq_grid

# Make the helper's directory importable, then import it by bare module name
# and reload it so edits to the helper are picked up without a kernel restart.
util_dir, util_name = os.path.split(file_path_name_util)
sys.path.append(util_dir)
module = __import__(util_name)
reload(module)

# Pull the three helper functions used by this notebook into the namespace.
load_data = module.load_data
cat_part = module.cat_part
make_conv = module.make_conv

# Create cross val folds

In [None]:
# Build one [train, test, audio_path, eeg_path] entry per held-out part.
eval_list = []
for file_path_name_audio, file_path_name_eeg in zip(file_path_name_audio_list, file_path_name_eeg_list):
    print(file_path_name_audio)
    audio, eeg, audio_unatt = load_data(file_path_name_audio, file_path_name_eeg)

    # exhaustive
    full_set = audio.shape[0]
    #full_set = 10 # debug, 4

    for test in [33]: #range(full_set):
        # Every part index except the held-out one, in ascending order ...
        train = [part for part in range(full_set) if part != test]
        # ... then shuffled (unseeded); slice the permutation to train on fewer parts.
        shuffled = np.random.permutation(len(train))
        train = np.asarray(train)[shuffled].tolist()
        eval_list.append([train, test, file_path_name_audio, file_path_name_eeg])

In [None]:
# Parameter dict for a single fold, for running big_node directly (outside
# the nipype workflow) while debugging.
idx_b = 0
dct_params = {
    'train': eval_list[idx_b][0],
    'test': eval_list[idx_b][1],
    'file_path_name_audio': eval_list[idx_b][2],
    'file_path_name_eeg': eval_list[idx_b][3],
    'alpha_array': alpha_array,
    'zscore': zscore,
    'collect': collect,
    'file_path_name_util': file_path_name_util,
    'idx_ch': idx_ch,
    'num_context': num_context,
    'save_flag': save_flag,
    'file_path_bciaud_shared': file_path_bciaud_shared,
    'timestamp_time': timestamp_time,
    'file_path_gridhome': file_path_gridhome,
    'model_lsq': model_lsq,
    'windsor_flag': windsor_flag,
    'loss_type': loss_type,
    # big_node indexes the next three keys; they were missing here, so the
    # debug call below would have raised KeyError.
    'cv_type': cv_type,
    # big_node reads aur_* but never uses them; these are inert placeholders.
    'aur_flag': False,
    'aur_thresh': None,
}

# Debug
#big_node(dct_params)

In [None]:
# Number of eval_list entries (folds) to turn into workflow nodes.
n_splits = len(eval_list)
# Override: only the first fold is submitted.
n_splits = 1

In [None]:
# Build a nipype workflow with one independent big_node per fold.
wf = pe.Workflow(name="wf")
for idx_b in range(n_splits):
    #if np.mod(idx_b, 1) == 0:
    #    print(idx_b)

    timestamp = '%s_%s' % (timestamp_time,
                                 hashlib.md5((('').join(eval_list[idx_b][2]+eval_list[idx_b][3])).encode('utf')).hexdigest())

    file_path_save = XXX_file_path_save_with_timestamp

    # Create the file_path_save here to avoid race conditions in the workflow.
    # exist_ok makes the creation itself race-free as well.
    os.makedirs(file_path_save, exist_ok=True)

    node_big = pe.Node(niu.Function(input_names=['dct_params'],
                                    output_names=['outputs'],
                                    function=big_node),
                                    name='big_node_%03d' % idx_b)

    dct_params = {
        'train': eval_list[idx_b][0],
        'test': eval_list[idx_b][1],
        'file_path_name_audio': eval_list[idx_b][2],
        'file_path_name_eeg': eval_list[idx_b][3],
        'alpha_array': alpha_array,
        'zscore': zscore,
        'collect': collect,
        'file_path_name_util': file_path_name_util,
        'idx_ch': idx_ch,
        'num_context': num_context,
        'save_flag': save_flag,
        'file_path_bciaud_shared': file_path_bciaud_shared,
        'timestamp_time': timestamp_time,
        'file_path_gridhome': file_path_gridhome,
        'model_lsq': model_lsq,
        'windsor_flag': windsor_flag,
        'loss_type': loss_type,
        'cv_type': cv_type,
        # big_node indexes these two keys although it never uses the values;
        # without them every node fails with KeyError.  Inert placeholders.
        'aur_flag': False,
        'aur_thresh': None,
    }

    node_big.inputs.dct_params = dct_params
    wf.add_nodes([node_big])

# Run the code

In [None]:
# nipype execution settings for the workflow.
# crashdump_dir belongs to the 'execution' section; the key was previously
# misspelled as 'exeuction', so the setting never reached nipype.
wf.config['execution']['crashdump_dir'] = XXX_path_to_crashdumpdir
wf.base_dir = XXX_path_to_base_dir

wf.config['execution']['parameterize_dirs'] = False
wf.config['execution']['poll_sleep_duration'] = 10
wf.config['execution']['job_finished_timeout'] = 30

In [None]:
# Show the run timestamp that is passed to every node as 'timestamp_time'.
print(timestamp_time)

In [None]:
# True: run the workflow with nipype's default plugin; False: submit via SLURM.
run_local_flag = True
#run_local_flag = False

In [None]:
# Execute the workflow: through the 'SLURM' plugin with the given sbatch
# arguments when not running locally, otherwise with nipype's default plugin.
if not run_local_flag:
    eg = wf.run('SLURM', plugin_args={'sbatch_args': '--constraint=xeon-e5 --mem=15G'})
else:
    eg = wf.run()

In [None]:
# Marks the end of the notebook run.
print('done')