In [1]:
import numpy as np
import os.path as op
import os
import sys

# Append the current working directory to the module search path; presumably
# this is what lets the project-local `util` package imported further down
# resolve when the notebook is launched from the analysis folder — TODO confirm.
cwd = os.getcwd()
sys.path.append(cwd)
print(sys.path)  # echo the resulting search path for debugging

from pprint import pformat
import argparse
# EEG utilities
import mne
from mne.preprocessing import ICA, create_eog_epochs
from pyprep.prep_pipeline import PrepPipeline
from autoreject import get_rejection_threshold, validation_curve
# BIDS utilities
from mne_bids import BIDSPath, read_raw_bids
from util.io.bids import DataSink

# constants
BIDS_ROOT = '../data/bids'  # root of the BIDS dataset (relative path)
DERIV_ROOT = op.join(BIDS_ROOT, 'derivatives')  # root handed to DataSink for derived outputs
ERP_PASSBAND = (0.1, 40)  # (low, high) edges passed to raw.filter; the high edge also bounds the PREP line-frequency list
TASK = 'pitch'  # BIDS task label
TMIN = -0.3  # epoch start passed to mne.Epochs; also the start of the baseline window
TMAX = 0.3  # epoch end passed to mne.Epochs

# Recording to process; both values are passed to BIDSPath as strings, and
# `sub` is additionally converted to int to seed the random number generators.
sub = '3'
run = '1'

# NOTE(review): the bare string below is a no-op expression statement, not a
# docstring; it reads like a leftover from a function that took `sub` as an
# argument — confirm and consider turning it into a comment.
'''
Parameters
----------
sub : str
    Subject ID as in BIDS dataset
'''
print('----------------- load data ------------------')
# Describe the EEG recording for this subject / task / run inside the BIDS dataset.
bids_path = BIDSPath(
    root = BIDS_ROOT,
    subject = sub,
    task = TASK,
    run = run,
    datatype = 'eeg'
    )
print(bids_path)
raw = read_raw_bids(bids_path, verbose = False)
# Derive the events array and the annotation-description -> event-code mapping;
# both are reused when the data are epoched later on.
events, event_ids = mne.events_from_annotations(raw)

print('----------------- Downsample ------------------') 
# Resample to 1250 Hz, passing `events` so that the returned events match the
# resampled data. Ideally we would downsample after epoching, but to make the
# job less resource intensive we're downsampling first.
raw, events = raw.resample(1250, events = events)

print('----------------- run PREP pipeline ------------------') # notch, exclude bad chans, and re-reference
# NOTE(review): the output captured with this cell stops inside this step.
# "Automatic origin fit: head of radius nan mm" is followed by
# "ValueError: array must not contain infs or NaNs" while bad channels are being
# interpolated. Presumably at least one channel treated as EEG has no valid
# montage position — TODO confirm and fix before relying on the later steps.
raw.load_data()
np.random.seed(int(sub))  # seed NumPy's global RNG from the subject ID
lf = raw.info['line_freq']
prep_params = {
    "ref_chs": "eeg",
    "reref_chs": "eeg",
    # Multiples of the line frequency strictly below the upper ERP band edge.
    # NOTE(review): whenever lf >= ERP_PASSBAND[1] (e.g. 50 or 60 Hz mains
    # against the 40 Hz edge) this array is empty — confirm that is intended.
    "line_freqs": np.arange(lf, ERP_PASSBAND[1], lf)
}
prep = PrepPipeline(
    raw,
    prep_params,
    raw.get_montage(),
    ransac = False,
    random_state = int(sub)  # same per-subject seed as above
    )
prep.fit()

print('----------------- Extract data from PREP ------------------')
# PREP exposes EEG and non-EEG channels as two separate objects; stack their
# data back into one array with the EEG rows first.
prep_eeg = prep.raw_eeg
prep_non_eeg = prep.raw_non_eeg
raw_data = np.vstack([prep_eeg.get_data(), prep_non_eeg.get_data()])

# Create info object for post-PREP data
print('Create info object for post-PREP data')
new_ch_names = [*prep_eeg.info['ch_names'], *prep_non_eeg.info['ch_names']]
# Reorder (not rename) the channels of the original recording so that its info
# lists them in the same order as the rows of raw_data, then reuse that info.
raw = raw.reorder_channels(new_ch_names)
raw_info = raw.info

# Combine post-prep data and new info
print('Create new raw object')
raw = mne.io.RawArray(raw_data, raw_info)  # replaces the original raw object

print('----------------- Filter ------------------') 
# Band-pass the continuous data between the two ERP_PASSBAND edges.
raw = raw.filter(l_freq=ERP_PASSBAND[0], h_freq=ERP_PASSBAND[1])

print('----------------- re-reference eye electrodes to become bipolar EOG ------------------')
def reref(dat):
    """Overwrite row 0 of ``dat`` with (row 1 - row 0) and return ``dat``.

    Parameters
    ----------
    dat : array supporting ``dat[i, :]`` indexing
        Modified in place; only row 0 is written, row 1 is left untouched.

    Returns
    -------
    The same ``dat`` object that was passed in.
    """
    first_row, second_row = dat[0, :], dat[1, :]
    # The difference is evaluated into a fresh array before being written back.
    dat[0, :] = second_row - first_row
    return dat
# Pair each eye electrode with a frontal scalp electrode and hand both to
# `reref`, which overwrites the first row it receives with (second - first).
# NOTE(review): this relies on apply_function passing the rows in the order
# given in `picks` (eye electrode first) — confirm for the installed MNE version.
for eye_ch, frontal_ch in (('leog', 'Fp2'), ('reog', 'Fp1')):
    raw = raw.apply_function(
        reref,
        picks=[eye_ch, frontal_ch],
        channel_wise=False,
    )
# Mark both eye electrodes as EOG channels.
raw = raw.set_channel_types({'leog': 'eog', 'reog': 'eog'})

## cut the filtered continuous data into epochs for ERP analysis
# (ICA is fitted on these epochs in the next step)
print('----------------- Epoch data for ERP analysis ------------------')
epochs = mne.Epochs(
    raw,
    events,
    event_id=event_ids,
    tmin=TMIN,
    tmax=TMAX,
    baseline=None,  # baseline correction is applied later, after ICA
    preload=True,
)

print('----------------- Run ICA ------------------')
# Fit a 15-component ICA on the EEG and EOG channels of the epochs; the fixed
# random_state is passed so repeated runs use the same seed.
ica = ICA(n_components = 15, random_state = 0)
ica.fit(epochs, picks = ['eeg', 'eog'])

print('----------------- Apply ICA ------------------')
# Ask MNE which components look EOG-related and mark them for exclusion.
# eog_scores is not used anywhere else in this script.
eog_indices, eog_scores = ica.find_bads_eog(epochs, threshold = 1.96)
ica.exclude = eog_indices
ica.apply(epochs) # transforms in place

if ica.exclude: # if we found any bad components
    fig_ica_removed = ica.plot_components(ica.exclude)

# The EOG channels are not needed past this point; remove both in one call.
epochs = epochs.drop_channels(['leog', 'reog'])

print('----------------- Baseline correct ------------------')
# Baseline window runs from the epoch start up to the event itself.
baseline_window = (TMIN, 0.)
epochs = epochs.apply_baseline(baseline_window)

print('----------------- Reject bad trials ------------------')
# Let autoreject estimate the rejection threshold, then drop epochs exceeding it.
thres = get_rejection_threshold(epochs)
print(thres)
epochs.drop_bad(reject=thres)

print('----------------- Save ------------------')
# DataSink is a project helper from util.io.bids; presumably it builds
# BIDS-style output paths under DERIV_ROOT for the 'erp' pipeline — verify
# against its implementation.
sink = DataSink(DERIV_ROOT, 'erp')
erp_fpath = sink.get_path(
    subject = sub,
    task = TASK,
    run = run,
    desc = 'forERP',
    suffix = 'epo',
    extension = 'fif.gz'
)
print(f'Saving epochs for ERP analysis to: {erp_fpath}')
epochs.save(erp_fpath, overwrite = True)  # overwrite any file left by an earlier run

print('----------------- generate a report ------------------')
report = mne.Report(verbose=True)
# Render every epochs file found next to the one that was just saved.
report.parse_folder(op.dirname(erp_fpath), pattern='*epo.fif.gz', render_bem=False)
if ica.exclude:  # only add an ICA section when components were actually removed
    fig_ica_removed = ica.plot_components(ica.exclude, show=False)
    report.add_figure(
        fig_ica_removed,
        title='Removed ICA Components',
        section='ICA',
    )
# Channels recorded by PREP in `noisy_channels_original`, pretty-printed with
# one HTML line break per pformat line.
bads = prep.noisy_channels_original
html = '\n'.join(f'<br/>{line}' for line in pformat(bads).splitlines())
report.add_html(html, title='Interpolated Channels', section='channels')
# This script only produces ERP epochs, so their info is reported once under
# the ERP title; the former duplicate entry titled "Epochs Info (FFR)" showed
# this same ERP info under a misleading label and has been removed.
report.add_html(epochs.info._repr_html_(), title='Epochs Info (ERP)', section='info')
report.save(op.join(sink.deriv_root, f'sub-{sub}.html'), overwrite=True)

['/project/hcn1/.conda/envs/mne/lib/python311.zip', '/project/hcn1/.conda/envs/mne/lib/python3.11', '/project/hcn1/.conda/envs/mne/lib/python3.11/lib-dynload', '', '/project/hcn1/.conda/envs/mne/lib/python3.11/site-packages', '/project2/hcn1/pitch_tracking_attention/analysis']
----------------- load data ------------------
../data/bids/sub-3/eeg/sub-3_task-pitch_run-1_eeg.vhdr
Used Annotations descriptions: ['11', '12', '13', '21', '22', '23', '31', '32', '33']


  raw = read_raw_bids(bids_path, verbose = False)
  raw = read_raw_bids(bids_path, verbose = False)
['Aux1']
Consider setting the channel types to be of EEG/sEEG/ECoG/DBS/fNIRS using inst.set_channel_types before calling inst.set_montage, or omit these channels when creating your montage.
  raw = read_raw_bids(bids_path, verbose = False)


----------------- Downsample ------------------
----------------- run PREP pipeline ------------------
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 3961 samples (3.301 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    3.5s


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 3961 samples (3.301 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    4.1s


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


2024-03-02 15:16:31,606 - pyprep.reference - INFO - Bad channels: {'bad_by_nan': [], 'bad_by_flat': ['Cz'], 'bad_by_deviation': [], 'bad_by_hf_noise': [], 'bad_by_correlation': [], 'bad_by_SNR': [], 'bad_by_dropout': [], 'bad_by_ransac': [], 'bad_all': ['Cz']}


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


2024-03-02 15:19:57,254 - pyprep.reference - INFO - Bad channels: {'bad_by_nan': [], 'bad_by_flat': ['Cz'], 'bad_by_deviation': [], 'bad_by_hf_noise': ['PO8', 'O2'], 'bad_by_correlation': [], 'bad_by_SNR': [], 'bad_by_dropout': [], 'bad_by_ransac': [], 'bad_all': ['PO8', 'O2', 'Cz']}


Setting channel interpolation method to {'eeg': 'spline'}.
Interpolating bad channels.
    Automatic origin fit: head of radius nan mm
Computing interpolation matrix from 61 sensor positions


ValueError: array must not contain infs or NaNs