In [16]:
import numpy as np
import os
import pandas as pd
import h5py
from scipy.io import loadmat
import socket
import mne

# Pick the data root for this machine: lab servers vs. the local laptop checkout.
# NOTE(review): the non-lab branch is an absolute per-user path; consider an
# environment-variable override so other machines can run this notebook.
hostname = socket.gethostname()
LAB_HOSTNAMES = {'syndrome', 'zod.psych.nyu.edu', 'zod'}
p = {}
if hostname in LAB_HOSTNAMES:
    p['datc'] = '/d/DATC/datc/MD_TMS_EEG'
else:
    p['datc'] = '/Users/mrugankdake/Documents/Clayspace/EEG_TMS/datc/MD_TMS_EEG'
p['data'] = os.path.join(p['datc'], 'data')
p['analysis'] = os.path.join(p['datc'], 'analysis')
p['EEGfiles'] = os.path.join(p['datc'], 'EEGfiles')
p['meta'] = os.path.join(p['analysis'], 'meta_analysis')
p['df_fname'] = os.path.join(p['meta'], 'calib_filtered.csv')
p['master_evoked'] = os.path.join(p['EEGfiles'], 'masterTFR_evoked.mat')
p['master_induced'] = os.path.join(p['EEGfiles'], 'masterTFR_induced.mat')
# Cache of the assembled training arrays. It is written with
# np.savez_compressed, which produces a zip (.npz) archive and appends '.npz'
# to any other filename, so the name must end in '.npz' for an
# os.path.exists() check on this path to find the file it wrote.
p['training_data'] = os.path.join(p['EEGfiles'], 'training_data.npz')

# Summary meta-data: maps each 'Subject ID' to its row of the summary table
# (if an ID repeats, the last row wins).
summary_df = pd.read_csv(os.path.join(p['analysis'], 'EEG_TMS_meta_Summary.csv'))
All_metadata = {row['Subject ID']: row for _, row in summary_df.iterrows()}

# Behavioral data, trimmed to the columns used in this analysis.
df_behav = pd.read_csv(p['df_fname'])
BEHAV_COLUMNS = [
    'subjID', 'day', 'gender', 'handedness',
    'hemistimulated', 'PT', 'StimIntensity',
    'rnum', 'tnum', 'istms', 'ispro', 'instimVF', 'TarX', 'TarY', 'isaccX',
    'isaccY', 'fsaccX', 'fsaccY', 'isacc_rt', 'fsacc_rt',
    'isacc_peakvel', 'fsacc_peakvel', 'trial_type',
    'TMS_condition', 'ierr', 'ferr',
    'igain', 'fgain', 'eccentricity', 'polang', 'ipea', 'fpea', 'iang',
    'fang', 'itheta', 'iradial', 'itangential', 'ftheta', 'fradial',
    'ftangential', 'TMS_time', 'ierr_threshold', 'ferr_threshold',
]
# .copy() makes df an independent frame, so later column assignments on it
# do not operate on a view of df_behav (avoids SettingWithCopyWarning).
df = df_behav[BEHAV_COLUMNS].copy()

In [2]:
# Subjects, days and task conditions to assemble into the training set.
sub_list = [1]
day_list = [1, 2, 3]
# Order matters: cond_idx selects the matching entry of each session's trl_idx cell array.
conditions = ['pin', 'pout', 'ain', 'aout']

# Per condition -> subject -> day containers filled by the loader below.
data_dict = {cond: {ss: {dd: [] for dd in day_list} for ss in sub_list} for cond in conditions}
trl_dict = {cond: {ss: {dd: [] for dd in day_list} for ss in sub_list} for cond in conditions}

subject_day_info = []
freq_band = (8, 12)    # inclusive frequency window averaged into one alpha-power value
time_band = (-1, 4.5)  # inclusive time window (s) kept from each TFR
ch_count = None
time_points = None
tr_count = 0
# Set from the first dataset loaded (or from the cache); None means "not yet seen".
chan_list = None
time_list = None
freq_list = None

tfr_type = 'evoked'


def _session_fname(ss, dd, suffix):
    """Return <EEGfiles>/subXX/dayYY/subXX_dayYY_<suffix> for subject ss, day dd."""
    return os.path.join(p['EEGfiles'], f'sub{ss:02}', f'day{dd:02}',
                        f'sub{ss:02}_day{dd:02}_{suffix}')


def _read_channel_labels(h5file, cond):
    """Decode the cell array POW.<cond>.label of an open h5py file into a list of str."""
    labels = []
    for ref in h5file['POW'][cond]['label'][0]:
        # Each cell is an HDF5 object reference to a char array read as one
        # character code per row.
        labels.append(''.join(chr(c[0]) for c in h5file[ref]))
    return labels


if os.path.exists(p['training_data']):
    # Reuse the arrays cached by the else-branch below.
    with np.load(p['training_data']) as cached:
        data_matrix = cached['data_matrix']
        # Stored as a string array; convert back to a list like the computed path.
        chan_list = cached['chan_list'].tolist()
        trl_matrix = cached['trl_matrix']
        time_list = cached['time_list']
        freq_list = cached['freq_list']
else:
    for cond_idx, cond in enumerate(conditions):
        for ss in sub_list:
            for dd in day_list:
                tfr_fname = _session_fname(ss, dd, f'TFR_{tfr_type}.mat')
                trl_idx_fname = _session_fname(ss, dd, 'trl_idx.mat')
                flag_data_fname = _session_fname(ss, dd, 'flagdata.mat')

                # Trial ids for this condition, minus the trials flagged for removal.
                flag_data = loadmat(flag_data_fname)['trls_to_remove'][0]
                trl_idx = loadmat(trl_idx_fname)['trl_idx'][0][0][cond_idx]
                trl_idx = np.asarray(trl_idx).T[0]
                trl_idx = [trl for trl in trl_idx if trl not in flag_data]
                trl_dict[cond][ss][dd] = trl_idx

                with h5py.File(tfr_fname, 'r') as f:
                    pow_cond = f['POW'][cond]
                    powspctrm = np.array(pow_cond['powspctrm'])
                    ch_labels = _read_channel_labels(f, cond)
                    time = np.array(pow_cond['time'])
                    freqs = np.array(pow_cond['freq'])

                # The first dataset loaded fixes the channel order for all others.
                if chan_list is None:
                    chan_list = ch_labels
                # Gather channels into chan_list order so that slot i holds the data
                # of channel chan_list[i]. Raises ValueError if a channel is missing.
                channel_indices = [ch_labels.index(ch) for ch in chan_list]
                # powspctrm axes as indexed here: (time, freq, channel, trial).
                powspctrm = powspctrm[:, :, channel_indices, :]

                time_band_indices = np.where((time >= time_band[0]) & (time <= time_band[1]))[0]
                freq_band_indices = np.where((freqs >= freq_band[0]) & (freqs <= freq_band[1]))[0]
                if time_list is None:
                    time_list = time[time_band_indices]
                    freq_list = freqs[freq_band_indices]

                # Keep the time window, average over the frequency band
                # -> (time, channel, trial), then reorder to (trial, channel, time).
                powspctrm = powspctrm[time_band_indices, :, :, :]
                powspctrm_avg = np.mean(powspctrm[:, freq_band_indices, :, :], axis=1)
                powspctrm_avg = np.transpose(powspctrm_avg, (2, 1, 0))

                # Each kept trial id must pair with exactly one trial of power data;
                # otherwise the trl_matrix fill below would misalign or silently broadcast.
                if len(trl_idx) != powspctrm_avg.shape[0]:
                    raise ValueError(
                        f'{cond} sub{ss:02} day{dd:02}: {len(trl_idx)} trial ids '
                        f'but {powspctrm_avg.shape[0]} trials of TFR data')

                data_dict[cond][ss][dd] = powspctrm_avg

                if ch_count is None:
                    ch_count = powspctrm_avg.shape[1]
                if time_points is None:
                    time_points = powspctrm_avg.shape[2]
                tr_count += powspctrm_avg.shape[0]

    # Assemble everything into one array. Each (condition, subject, day) slice
    # writes its trials at its own offset along the trial axis (after all
    # previously written sessions); the rest of that slice stays zero-filled.
    # NOTE(review): this makes the arrays len(conditions)*len(sub_list)*len(day_list)
    # times larger than the data they hold; a flat (trial, channel, time) array
    # with per-trial condition/subject/day labels would be far smaller.
    data_matrix = np.zeros((len(conditions), len(sub_list), len(day_list), tr_count, ch_count, time_points))
    trl_matrix = np.zeros((len(conditions), len(sub_list), len(day_list), tr_count))
    current_trial_index = 0

    for cond_idx, cond in enumerate(conditions):
        for ss_idx, ss in enumerate(sub_list):
            for dd_idx, dd in enumerate(day_list):
                sess_data = data_dict[cond][ss][dd]
                num_trials = sess_data.shape[0]
                trial_slice = slice(current_trial_index, current_trial_index + num_trials)
                data_matrix[cond_idx, ss_idx, dd_idx, trial_slice, :, :] = sess_data
                trl_matrix[cond_idx, ss_idx, dd_idx, trial_slice] = trl_dict[cond][ss][dd]
                current_trial_index += num_trials

    # Write through an open file handle: given a path that does not end in
    # '.npz', np.savez_compressed appends '.npz', and the os.path.exists()
    # cache check above would never see the file. np.load recognises the zip
    # archive from the file contents, so it reads back from this exact path.
    with open(p['training_data'], 'wb') as fh:
        np.savez_compressed(fh, data_matrix=data_matrix, chan_list=chan_list,
                            trl_matrix=trl_matrix, time_list=time_list, freq_list=freq_list)

In [18]:
# Preview the first rows of the trimmed behavioral table.
n_preview = 10
df.head(n_preview)

Unnamed: 0,subjID,day,gender,handedness,hemistimulated,PT,StimIntensity,rnum,tnum,istms,...,fang,itheta,iradial,itangential,ftheta,fradial,ftangential,TMS_time,ierr_threshold,ferr_threshold
0,1,1,M,Right,Left,40.0,44,1,1,1,...,-1.108366,9.776067,1.173321,0.202163,-105.188438,-0.068791,-0.253396,middle,6.0,5.858632
1,1,1,M,Right,Left,40.0,44,1,2,1,...,-3.242797,-124.583117,-0.566011,-0.820997,-124.583117,-0.566011,-0.820997,middle,6.0,5.858632
2,1,1,M,Right,Left,40.0,44,1,3,1,...,0.834452,12.652869,0.933824,0.209639,12.652869,0.933824,0.209639,middle,6.0,5.858632
3,1,1,M,Right,Left,40.0,44,1,4,1,...,-0.896185,-26.814327,0.4639,-0.234479,-26.814327,0.4639,-0.234479,middle,6.0,5.858632
4,1,1,M,Right,Left,40.0,44,1,5,1,...,1.87387,172.002941,-2.743788,0.385471,172.002941,-2.743788,0.385471,middle,6.0,5.858632
5,1,1,M,Right,Left,40.0,44,1,7,1,...,-3.550709,-68.316834,0.353586,-0.889282,-68.316834,0.353586,-0.889282,middle,6.0,5.858632
6,1,1,M,Right,Left,40.0,44,1,8,1,...,5.190948,97.792269,-0.184098,1.3453,97.792269,-0.184098,1.3453,middle,6.0,5.858632
7,1,1,M,Right,Left,40.0,44,1,9,1,...,0.513874,169.886892,-0.713735,0.127304,169.886892,-0.713735,0.127304,middle,6.0,5.858632
8,1,1,M,Right,Left,40.0,44,1,10,1,...,-0.492228,-144.740081,-0.158479,-0.112043,-144.740081,-0.158479,-0.112043,middle,6.0,5.858632
9,1,1,M,Right,Left,40.0,44,1,11,1,...,0.113584,171.598527,-0.195169,0.028825,171.598527,-0.195169,0.028825,middle,6.0,5.858632
