In [10]:
import mne
import os
import glob
import numpy as np
import pandas as pd
from scipy.io import savemat
import matplotlib.pyplot as plt
from utils import *

In [11]:
#======================================================================================
#                        INITIALIZE DIRECTORIES
#======================================================================================
# Data locations. Each can be overridden with an environment variable so the notebook
# runs on machines other than the original author's; the defaults are the original paths.
root_dir = os.environ.get(
    "AM_EEG_RAW_DIR",
    "/Users/cindyzhang/Documents/M2/Audiomotor_Piano/AM-EEG/data_raw")           # raw .bdf files
output_base = os.environ.get(
    "AM_EEG_OUT_DIR",
    "/Users/cindyzhang/Documents/M2/Audiomotor_Piano/AM-EEG/data_preprocessed")  # preprocessed .mat files and records

plot = False      # if True, open MNE plot windows while processing
FS_ORIG = 2048    # Hz, sampling rate of the raw recordings

# Full cohort: [f'{n:02d}' for n in range(1, 21)]  ->  ['01', '02', ..., '20']
subjects_to_process = ['00']


In [12]:
#======================================================================================
#                       PREPROCESSING PARAMETERS
#======================================================================================
# --- Notch filtering ---
notch_applied = False
freq_notch = 50                         # Hz

# --- Band-pass filtering ---
bpf_applied = True
freq_low, freq_high = 0.01, 8           # Hz
bandpass = f'{freq_low}-{freq_high}'    # text label of the pass band, e.g. '0.01-8'
ftype = 'butter'
order = 3

# --- Spherical interpolation ---
int_applied = False
interpolation = 'spline'

# --- Re-referencing ---
reref_applied = True
reref_type = 'Mastoids'  # 'Mastoids' (average of mastoid electrodes) or 'Average'

# --- Downsampling: keep the original rate when disabled ---
down_applied = True
downfreq = 64 if down_applied else FS_ORIG   # Hz


In [13]:
#======================================================================================
#                       LOOP THROUGH SUBJECTS
#======================================================================================
# For each selected subject: read the .bdf, locate the trial-start triggers, crop the
# pre-training blocks and, for the blocks listed in `names_to_process`, select channels,
# downsample and save them as .mat files under output_base/<idx>/.

# Crop length of each pre-training block, in seconds after its trial-start trigger.
LISTEN_PRE_DURATION = 665
MOTOR_PRE_DURATION = 600
ERROR_PRE_DURATION = 600

# Channel layout used when saving: the first N_KEEP channels are kept; the first N_SCALP
# of them go to 'trial_data' and the rest to 'trial_mastoids'.
# NOTE(review): assumes channels 64-65 are the two mastoids — confirm against the montage.
N_SCALP = 64
N_KEEP = 66

# sorted(): glob order is filesystem-dependent and `idx` names the output folder, so
# sorting keeps the file -> folder mapping reproducible across runs and machines.
files = sorted(glob.glob(os.path.join(root_dir, '**', '*.bdf'), recursive=True))

for idx, file in enumerate(files):

    # Subject ID = last two characters of the file-name stem (e.g. 'sub_00.bdf' -> '00').
    # Using the basename avoids being fooled by dots in directory names.
    subject_ID = os.path.splitext(os.path.basename(file))[0][-2:]
    if subject_ID not in subjects_to_process:
        continue

    print("Currently processing ", file)
    df_pre = pd.DataFrame()  # one-row record of the preprocessing applied to this subject

    output_dir = os.path.join(output_base, str(idx))
    os.makedirs(output_dir, exist_ok=True)

    #======================================================================================
    #                        READ EEG FILES
    #======================================================================================
    # preload=False: samples are read from disk only when needed (see load_data() below).
    raw = mne.io.read_raw_bdf(file, eog=None, misc=None, stim_channel='Status',
                              infer_types=False, preload=False, verbose=None)

    # Check metadata
    n_time_samps = raw.n_times
    time_secs = raw.times
    ch_names = raw.ch_names
    n_chan = len(ch_names)
    print('The raw data object has {} time samples and {} channels.'.format(n_time_samps, n_chan))
    print('The last time sample is at {} seconds.'.format(time_secs[-1]))
    print('The first few channel names are {}.'.format(', '.join(ch_names[:3])))
    print('bad channels:', raw.info['bads'])  # chs marked "bad" during acquisition
    print(raw.info['sfreq'], 'Hz')            # sampling frequency
    print(raw.info['description'], '\n')      # miscellaneous acquisition info
    print(raw.info)

    if plot:
        raw.plot(start=100, duration=10)

    #======================================================================================
    #                       FIND TRIGGERS
    #======================================================================================
    # The stim channel is called 'Status'.
    # NOTE: the default shortest_event=2 raised an exception here (cause not confirmed).
    events = mne.find_events(raw, stim_channel='Status', shortest_event=1)

    # TODO: import the bad triggers noted during the experiment from the corresponding csv
    #       and drop them from `events` before sorting.

    # Cleaned triggers per channel; trial start and end are on channel 5, filtered for
    # a duration of 10+ min.
    events_2, events_3, events_4, events_5, trial_starts = sort_events(events)

    #======================================================================================
    #                        CROPPING FILES TO THE TRIAL
    #======================================================================================
    print("trial starts: ", trial_starts)
    if len(trial_starts) < 3:
        raise ValueError(
            f"Subject {subject_ID}: expected at least 3 trial starts "
            f"(listen, motor, error), found {len(trial_starts)}")

    # find_events returns absolute sample numbers (they include raw.first_samp), whereas
    # crop() takes seconds relative to the start of `raw`.
    sfreq = raw.info['sfreq']
    listen_pre_start = (trial_starts[0][0] - raw.first_samp) / sfreq
    listen_pre_end = listen_pre_start + LISTEN_PRE_DURATION

    motor_pre_start = (trial_starts[1][0] - raw.first_samp) / sfreq
    motor_pre_end = motor_pre_start + MOTOR_PRE_DURATION

    error_pre_start = (trial_starts[2][0] - raw.first_samp) / sfreq
    error_pre_end = error_pre_start + ERROR_PRE_DURATION

    # TODO: add the post-training blocks; approximate crop for the training block?

    eeg_listen_pre = raw.copy().crop(tmin=listen_pre_start, tmax=listen_pre_end)
    eeg_motor_pre = raw.copy().crop(tmin=motor_pre_start, tmax=motor_pre_end)
    eeg_error_pre = raw.copy().crop(tmin=error_pre_start, tmax=error_pre_end)

    # Name -> crop; the name is also the stem of the saved .mat file.
    crops = {
        'eeg_listen_pre': eeg_listen_pre,
        'eeg_motor_pre': eeg_motor_pre,
        'eeg_error_pre': eeg_error_pre,
    }
    # Crops to preprocess and save, e.g. ['eeg_listen_pre']; empty for now, so nothing is
    # processed or saved yet.
    names_to_process = []

    for name in names_to_process:
        # Work on a copy so the named crops keep every channel (incl. 'Status') for the
        # exploration cells that run find_events on them.
        eeg = crops[name].copy()

        #======================================================================================
        #                       FILTERING
        #======================================================================================

        ## -------------
        ## Select channels
        ## -------------
        eeg_channels = ch_names[0:N_KEEP]
        eeg.pick(eeg_channels)
        # resample() needs the data in memory; `raw` was opened with preload=False.
        # Loading after pick() reads only the kept channels.
        eeg.load_data()
        if plot:
            eeg.plot(start=100, duration=10, n_channels=len(eeg.ch_names))

        # TODO: band-pass filtering (bpf_applied), interpolation (int_applied) and
        #       re-referencing (reref_applied) are configured above but not implemented yet.

        ## -------------
        ## Downsampling
        ## -------------
        df_pre['down_applied'] = [down_applied]
        df_pre['downfreq'] = [downfreq]
        if down_applied:
            eeg.resample(sfreq=downfreq)
            print(eeg.info)
            if plot:
                eeg.plot()

        #======================================================================================
        #                       SAVING CROPPED FILES
        #======================================================================================
        eeg_tosave = eeg.get_data()
        savemat(os.path.join(output_dir, f'{name}_{subject_ID}.mat'),
                {'trial_data': eeg_tosave[:N_SCALP, :],
                 'trial_mastoids': eeg_tosave[N_SCALP:, :]})

    ## -------------
    ## Save preprocessing stages
    ## -------------
    df_pre.to_csv(os.path.join(output_dir, f"preprocess_record_{subject_ID}.csv"), index=False)


Currently processing  /Users/cindyzhang/Documents/M2/Audiomotor_Piano/AM-EEG/data_raw/sub_00.bdf
Extracting EDF parameters from /Users/cindyzhang/Documents/M2/Audiomotor_Piano/AM-EEG/data_raw/sub_00.bdf...
BDF file detected
Setting channel info structure...
Creating raw.info structure...
the (cropped) sample data object has 4270080 time samples and 80 channels.
The last time sample is at 2084.99951171875 seconds.
The first few channel names are Fp1, AF7, AF3.
bad channels: []
2048.0 Hz
None 

<Info | 8 non-empty values
 bads: []
 ch_names: Fp1, AF7, AF3, F1, F3, F5, F7, FT7, FC5, FC3, FC1, C1, C3, C5, ...
 chs: 79 EEG, 1 Stimulus
 custom_ref_applied: False
 highpass: 0.0 Hz
 lowpass: 417.0 Hz
 meas_date: 2024-03-15 17:38:06 UTC
 nchan: 80
 projs: []
 sfreq: 2048.0 Hz
 subject_info: 1 item (dict)
>
Trigger channel has a non-zero initial value of 130816 (consider using initial_event=True to detect this event)
21987 events found
Event IDs: [65282 65284 65286 65288 65296]
trial starts:  [[

EXPLORATION BELOW — NOT PART OF THE ACTUAL PIPELINE
Note the different `clean_triggers` thresholds used in each cell: keystroke triggers are longer than mode-switch or trial triggers.

In [14]:
# Inspect the trial-start events found for the last subject handled by the processing loop.
# NOTE(review): rows look like find_events-style [sample, previous value, event id] — confirm in utils.sort_events.
trial_starts

array([[  46413,   65280,   65296],
       [1555179,   65280,   65296],
       [2989776,   65280,   65296]])

In [6]:
# Quick look at every raw trigger on the full recording (exploration only).
# Previously this cell was a bare string literal, i.e. a no-op whose shown output could not
# be reproduced; the browser window is now gated by the `plot` flag like the main loop.
raw_events = mne.find_events(raw, shortest_event=1)
if plot:
    raw.plot(events=raw_events)

Trigger channel has a non-zero initial value of 130816 (consider using initial_event=True to detect this event)
21987 events found
Event IDs: [65282 65284 65286 65288 65296]


In [15]:
# Listening block: clean the raw triggers (threshold=1000) and overlay them on the crop.
# NOTE(review): uses find_events' default shortest_event, unlike the error-block cell — confirm intended.
listen_raw_events = mne.find_events(eeg_listen_pre)
listen_events = clean_triggers(listen_raw_events, threshold=1000)
eeg_listen_pre.plot(events=listen_events)

Trigger channel has a non-zero initial value of 65296 (consider using initial_event=True to detect this event)
19417 events found
Event IDs: [65282 65296]
Using pyopengl with version 3.1.6


<mne_qt_browser._pg_figure.MNEQtBrowser at 0x12ff200d0>

Channels marked as bad:
none


In [16]:
# Motor block: cleaned with a larger threshold (1500) than the listening block.
motor_events = clean_triggers(mne.find_events(eeg_motor_pre), threshold=1500)
# A mid-cell bare `motor_events.shape` is never displayed; print it explicitly instead.
print('cleaned motor events shape:', motor_events.shape)
eeg_motor_pre.plot(events=motor_events)

Trigger channel has a non-zero initial value of 65296 (consider using initial_event=True to detect this event)
695 events found
Event IDs: [65282 65296]
Using pyopengl with version 3.1.6


<mne_qt_browser._pg_figure.MNEQtBrowser at 0x285c4f9a0>

Channels marked as bad:
none


In [17]:
# Error block: shortest_event=1 accepts events lasting a single sample (find_events' default minimum is 2).
error_raw_events = mne.find_events(eeg_error_pre, shortest_event=1)
error_events = clean_triggers(error_raw_events, threshold=1500)
eeg_error_pre.plot(events=error_events)

Trigger channel has a non-zero initial value of 65296 (consider using initial_event=True to detect this event)
1795 events found
Event IDs: [65282 65284 65286 65288 65296]
Using pyopengl with version 3.1.6


<mne_qt_browser._pg_figure.MNEQtBrowser at 0x2868065f0>

Channels marked as bad:
none
