In [1]:
from glob import glob
import os
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('QtAgg')
from mne.preprocessing import ICA
import mne
import numpy as np
import pandas as pd
import pyprep
import pyxdf
from important_files.utils import *
from scipy.signal import welch
import warnings
import json
warnings.filterwarnings("ignore")


## HELPERS ##
def annotate_blinks(
    raw: mne.io.Raw, ch_name: list[str] = ["E25", "E8"]
) -> mne.Annotations:
    """Annotate the blinks in the EEG signal.

    Blink events are located with ``mne.preprocessing.create_eog_epochs`` and
    converted into annotations labelled ``"blink"``.

    Args:
        raw (mne.io.Raw): The raw EEG data in mne format.
        ch_name (list[str]): The channels to use as EOG surrogates. Default is
                            ["E25", "E8"], the frontal channels of an EGI
                            system. Pick the most frontal channels (just
                            above the eyes); on a 10-20 cap that would be
                            e.g. ["Fp1", "Fp2"].

    Returns:
        mne.Annotations: The annotations object containing the blink events.
            ``mne.annotations_from_events`` produces zero-duration
            annotations, so each entry marks a blink event sample rather than
            a time span. The object is empty when no blink epoch survives.
    """
    eog_epochs = mne.preprocessing.create_eog_epochs(raw, ch_name=ch_name)
    meas_date = raw.info["meas_date"]

    # Without this guard ``events[0, 2]`` below raises IndexError when no
    # blink epoch is left.
    if len(eog_epochs.events) == 0:
        return mne.Annotations(
            onset=[], duration=[], description=[], orig_time=meas_date
        )

    blink_annotations = mne.annotations_from_events(
        eog_epochs.events,
        raw.info["sfreq"],
        event_desc={eog_epochs.events[0, 2]: "blink"},
        # Event samples are counted from the start of the acquisition. With an
        # ``orig_time`` the onsets are expressed relative to it already;
        # without one they must be made relative to the first sample of `raw`.
        first_samp=0 if meas_date is not None else raw.first_samp,
        # Same time origin as the recording, so the result can be added to
        # other annotations created from `raw`.
        orig_time=meas_date,
    )
    return blink_annotations

def annotate_muscle(
    raw: mne.io.Raw,
    threshold: float = 3,
    min_length_good: float = 0.1,
    filter_freq: tuple[float, float] = (95, 120),
) -> mne.Annotations:
    """Annotate muscle artifacts in the EEG signal.

    Thin wrapper around ``mne.preprocessing.annotate_muscle_zscore`` applied
    to the EEG channels; the per-sample z-scores it also returns are
    discarded.

    Args:
        raw (mne.io.Raw): The raw EEG data in mne format.
        threshold (float): Z-score threshold passed to MNE. Default is 3;
                            this needs to be calibrated for the entire
                            dataset.
        min_length_good (float): Forwarded as ``min_length_good``.
                            Default is 0.1.
        filter_freq (tuple[float, float]): Forwarded as ``filter_freq``, the
                            band (Hz) used to compute the score. Default is
                            (95, 120).

    Returns:
        mne.Annotations: The annotations marking the detected muscle activity.
    """
    muscle_annotations, _ = mne.preprocessing.annotate_muscle_zscore(
        raw,
        threshold=threshold,
        ch_type='eeg',
        min_length_good=min_length_good,
        filter_freq=filter_freq,
    )

    return muscle_annotations


In [2]:


# Recording to process. The default is the original local file; set the
# CUNY_XDF_FILE environment variable to run the notebook on another recording
# (or another machine) without editing this cell.
xdf_filename = os.environ.get(
    'CUNY_XDF_FILE',
    '/Users/bryan.gonzalez/CUNY_subs/sub-P5197505/sub-P5197505_ses-S001_task-CUNY_run-001_mobi.xdf',
)

In [3]:
# Load the EEG samples belonging to the 'RestingState' event of the recording.
df = get_event_data(event='RestingState',
                    df=import_eeg_data(xdf_filename),
                    stim_df=import_stim_data(xdf_filename))

# Keep the time stamps aside and estimate the sampling rate from their mean
# spacing, then strip the column so that only EEG channels remain in `df`.
TS = df.lsl_time_stamp.values
sfreq = 1 / df.lsl_time_stamp.diff().mean()
df = df.drop(columns=['lsl_time_stamp'])  # rebinding keeps the cell re-runnable

ch_names = [f"E{i+1}" for i in range(df.shape[1])]
info = mne.create_info(ch_names, sfreq=sfreq, ch_types='eeg')

raw = mne.io.RawArray(df.T * 1e-6, info=info) # multiplying by 1e-6 converts to volts

# Create a Cz reference: an all-zero channel appended to the recording
cz_info = mne.create_info(["Cz"], raw.info['sfreq'], ch_types='eeg')
cz = mne.io.RawArray(np.zeros((1, raw.n_times)), cz_info)
raw.add_channels([cz], force_update_info=True)

# Apply a montage
montage = mne.channels.make_standard_montage('GSN-HydroCel-129')
raw.set_montage(montage, on_missing='ignore')

#raw.crop(tmin=0, tmax=5)

prep_params = {
        "ref_chs": "eeg",
        "reref_chs": "eeg",
        # 60 Hz and its harmonics below the Nyquist frequency
        "line_freqs": np.arange(60, raw.info["sfreq"] / 2, 60),
    }
# these params set up the robust reference  - i.e. median of all channels and interpolate bad channels
# random_state pins the random sampling done by RANSAC so that the detected
# bad channels are the same on every run.
prep = pyprep.PrepPipeline(
    raw,
    montage=montage,
    channel_wise=True,
    prep_params=prep_params,
    random_state=42,
)
prep_output = prep.fit()
raw_cleaned = prep_output.raw_eeg

Creating RawArray with float64 data, n_channels=128, n_times=299989
    Range : 0 ... 299988 =      0.000 ...   299.998 secs
Ready.
Creating RawArray with float64 data, n_channels=1, n_times=299989
    Range : 0 ... 299988 =      0.000 ...   299.998 secs
Ready.
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 3301 samples (3.301 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.6s
[Parallel(n_jobs=1)]: Done 129 out of 129 | elapsed:    1.1s finished


Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 3301 samples (3.301 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 129 out of 129 | elapsed:    0.9s finished


Removed notch frequencies (Hz):
     60.00 : 7482 windows
    120.00 : 7482 windows
    180.00 : 7482 windows
    239.00 : 7482 windows
    240.00 : 7482 windows
    241.00 : 7482 windows
    299.00 : 7482 windows
    300.00 : 7482 windows
    301.00 : 7482 windows
    359.00 : 7482 windows
    360.00 : 7482 windows
    361.00 : 7482 windows
    419.00 : 7482 windows
    420.00 : 7482 windows
    421.00 : 7482 windows
    479.00 : 7482 windows
    480.00 : 7482 windows
    481.00 : 7482 windows
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 3301 samples (3.301 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.6s
[Parallel(n_jobs=1)]: Done 129 out of 129 | elapsed:    1.0s finished


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Executing RANSAC
This may take a while, so be patient...
Finding optimal chunk size : 25
Total # of chunks: 5
Current chunk:
1
2
3
4
5

RANSAC done!


2025-06-17 15:42:21,017 - pyprep.reference - INFO - Bad channels: {'bad_by_nan': [], 'bad_by_flat': ['E127', 'Cz'], 'bad_by_deviation': ['E92'], 'bad_by_hf_noise': ['E55', 'E69', 'E73', 'E80'], 'bad_by_correlation': ['E92', 'E95'], 'bad_by_SNR': [], 'bad_by_dropout': [], 'bad_by_ransac': [np.str_('E55'), np.str_('E80'), np.str_('E126')], 'bad_all': [np.str_('E126'), 'E95', 'E80', 'E127', 'E69', 'E73', 'E92', 'E55', 'Cz']}


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Executing RANSAC
This may take a while, so be patient...
Finding optimal chunk size : 32
Total # of chunks: 4
Current chunk:
1
2
3
Finding optimal chunk size : 25
Total # of chunks: 5
Current chunk:
1
2
3
4
5

RANSAC done!


2025-06-17 15:44:15,785 - pyprep.reference - INFO - Bad channels: {'bad_by_nan': [], 'bad_by_flat': ['Cz', 'E127'], 'bad_by_deviation': ['E92'], 'bad_by_hf_noise': ['E73', 'E69'], 'bad_by_correlation': ['E95', 'E92'], 'bad_by_SNR': [], 'bad_by_dropout': [], 'bad_by_ransac': [np.str_('E126'), np.str_('E128')], 'bad_all': [np.str_('E126'), 'E95', np.str_('E128'), 'E127', 'E69', 'E73', 'E92', 'Cz']}


Setting channel interpolation method to {'eeg': 'spline'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 94.0 mm
Computing interpolation matrix from 121 sensor positions
Interpolating 8 sensors


2025-06-17 15:44:16,442 - pyprep.reference - INFO - Iterations: 1


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Executing RANSAC
This may take a while, so be patient...
Finding optimal chunk size : 42
Total # of chunks: 3
Current chunk:
1
2
3

RANSAC done!


2025-06-17 15:46:44,575 - pyprep.reference - INFO - Bad channels: {'bad_by_nan': [], 'bad_by_flat': ['Cz', 'E127'], 'bad_by_deviation': ['E92'], 'bad_by_hf_noise': ['E73', 'E69'], 'bad_by_correlation': ['E95', 'E92'], 'bad_by_SNR': [], 'bad_by_dropout': [], 'bad_by_ransac': [np.str_('E126'), np.str_('E128')], 'bad_all': [np.str_('E126'), 'E95', np.str_('E128'), 'E127', 'E69', 'E73', 'E92', 'Cz']}


Setting channel interpolation method to {'eeg': 'spline'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 94.0 mm
Computing interpolation matrix from 121 sensor positions
Interpolating 8 sensors


2025-06-17 15:46:45,157 - pyprep.reference - INFO - Iterations: 2


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Executing RANSAC
This may take a while, so be patient...
Finding optimal chunk size : 42
Total # of chunks: 3
Current chunk:
1
2
3

RANSAC done!


2025-06-17 15:48:06,195 - pyprep.reference - INFO - Bad channels: {'bad_by_nan': [], 'bad_by_flat': ['Cz', 'E127'], 'bad_by_deviation': ['E92'], 'bad_by_hf_noise': ['E73', 'E69'], 'bad_by_correlation': ['E95', 'E92'], 'bad_by_SNR': [], 'bad_by_dropout': [], 'bad_by_ransac': [np.str_('E126'), np.str_('E128')], 'bad_all': [np.str_('E126'), 'E95', np.str_('E128'), 'E127', 'E69', 'E73', 'E92', 'Cz']}
2025-06-17 15:48:06,196 - pyprep.reference - INFO - Robust reference done


Setting channel interpolation method to {'eeg': 'spline'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 94.0 mm
Computing interpolation matrix from 121 sensor positions
Interpolating 8 sensors
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 3301 samples (3.301 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 129 out of 129 | elapsed:    0.9s finished


Executing RANSAC
This may take a while, so be patient...
Finding optimal chunk size : 64
Total # of chunks: 2
Current chunk:
1
2

RANSAC done!
Found 3 uniquely bad channels:

0 by NaN: []

0 by flat: []

1 by deviation: ['E92']

0 by HF noise: []

2 by correlation: ['E92', 'E95']

0 by SNR: []

0 by dropout: []

1 by RANSAC: [np.str_('E127')]

Setting channel interpolation method to {'eeg': 'spline'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 94.0 mm
Computing interpolation matrix from 125 sensor positions
Interpolating 4 sensors
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 129 out of 129 | elapsed:    0.8s finished


Executing RANSAC
This may take a while, so be patient...
Finding optimal chunk size : 65
Total # of chunks: 2
Current chunk:
1
2

RANSAC done!


In [None]:
subject = xdf_filename.split('sub-')[-1].split('_')[0]

# All derived files are written next to the source recording.
subject_dir = os.path.dirname(xdf_filename)
save_path = os.path.join(subject_dir, f'sub-{subject}_ses-S001_task-CUNY_run-001_eeg_clean.fif')
vars_path = os.path.join(subject_dir, f'sub-{subject}_ses-S001_task-CUNY_run-001_eeg_clean_vars.json')

# The pipeline is only (re-)run while fewer than three .fif files exist in the
# subject folder.
rerun_prep = len(glob(os.path.join(subject_dir, '*.fif'))) < 3

if rerun_prep:
    # Same load + PREP steps as the previous cell.
    df = get_event_data(event='RestingState',
                        df=import_eeg_data(xdf_filename),
                        stim_df=import_stim_data(xdf_filename))

    ch_names = [f"E{i+1}" for i in range(df.shape[1] - 1)]
    info = mne.create_info(ch_names,
                        sfreq=1/df.lsl_time_stamp.diff().mean(),
                        ch_types='eeg')

    TS = df.lsl_time_stamp.values
    df = df.drop(columns=['lsl_time_stamp'])

    raw = mne.io.RawArray(df.T * 1e-6, info=info) # multiplying by 1e-6 converts to volts

    # Create a Cz reference
    value = np.zeros((1, raw.n_times))
    info = mne.create_info(["Cz"], raw.info['sfreq'], ch_types='eeg')
    cz = mne.io.RawArray(value, info)
    raw.add_channels([cz], force_update_info=True)

    # Apply a montage
    montage = mne.channels.make_standard_montage('GSN-HydroCel-129')
    raw.set_montage(montage, on_missing='ignore')

    #raw.crop(tmin=0, tmax=5)

    prep_params = {
            "ref_chs": "eeg",
            "reref_chs": "eeg",
            "line_freqs": np.arange(60, raw.info["sfreq"] / 2, 60),
        }
    # these params set up the robust reference  - i.e. median of all channels and interpolate bad channels
    # random_state pins the random sampling done by RANSAC for reproducibility.
    prep = pyprep.PrepPipeline(
        raw,
        montage=montage,
        channel_wise=True,
        prep_params=prep_params,
        random_state=42,
    )
    prep_output = prep.fit()
    raw_cleaned = prep_output.raw_eeg

    raw_cleaned.save(save_path, overwrite=True)

# NOTE: when the pipeline is skipped, `raw_cleaned` and `prep` are the objects
# produced by the previous cell. The file at `save_path` is not read back: the
# last cell of this notebook overwrites that same path with the filtered,
# ICA-cleaned data, so loading it here would feed already-processed data into
# the ICA cells below.

# Bad-channel bookkeeping. Built outside the `if` so that `vars` always exists
# as a dict: a later cell stores 'percent_good' in it and would otherwise hit
# the builtin `vars` function. (The name shadows that builtin; it is kept
# because the later cell refers to it.)
vars = {}
print(f"Bad channels before robust reference: {prep.noisy_channels_original['bad_all']}")
vars['bad_channels_before'] = prep.noisy_channels_original['bad_all']
print(f"Interpolated channels: {prep.interpolated_channels}")
vars['interpolated_channels'] = prep.interpolated_channels
print(f"Bad channels after interpolation: {prep.still_noisy_channels}")
vars['bad_channels_after'] = prep.still_noisy_channels

if rerun_prep:
    # save the vars dictionary to a file
    with open(vars_path, 'w') as f:
        json.dump(vars, f, indent=4)


In [None]:
# Inspect the `lsl_time_stamp` values that were set aside when the data was loaded.
TS

In [None]:



# Detect blink and muscle artifacts and attach them to the recording, keeping
# whatever annotations it already carries.
blink_annotations = annotate_blinks(raw_cleaned, ch_name=["E25", "E8"])

muscle_annotations = annotate_muscle(raw_cleaned)

all_annotations = blink_annotations + muscle_annotations + raw_cleaned.annotations
raw_cleaned.set_annotations(all_annotations)

# Binary array with one entry per sample: 1 = covered by an annotation
binary_mask = np.zeros(len(raw_cleaned.times), dtype=int)
sampling_rate = raw_cleaned.info['sfreq']

# Iterate over annotations
for annot in raw_cleaned.annotations:
    onset_sample = int(annot['onset'] * sampling_rate)
    duration_sample = int(annot['duration'] * sampling_rate)
    # Clamp at 0: a negative index would wrap around and flag samples at the
    # end of the recording instead.
    start = max(onset_sample, 0)
    stop = max(onset_sample + duration_sample, 0)
    binary_mask[start:stop] = 1

percent_good = 1 - np.sum(binary_mask) / len(binary_mask)
print(f'Percent Good Data: {percent_good * 100:.2f}%')

# `vars` is only a dict if the PREP cell above created it; otherwise the name
# resolves to the builtin `vars` function and item assignment raises TypeError.
if not isinstance(vars, dict):
    vars = {}
vars['percent_good'] = percent_good * 100

In [None]:
# Re-inspect the `lsl_time_stamp` values after the annotation step.
TS

# Artifact removal with ICA

In [None]:
# set notch filter (60 Hz); note that this modifies raw_cleaned in place
raw_cleaned.notch_filter(60)
# set bandpass filter
raw_cleaned.filter(l_freq=1.0, h_freq=50.0) # only keeping frequencies between 1-50 Hz
# play around with this number to get components
# that seem to represent the actual brain activations well
num_components = .95
# random_state makes the decomposition reproducible, so the component indices
# picked by hand in the cells below keep referring to the same components when
# the notebook is re-run on the same data.
ica = ICA(n_components=num_components, method='picard', random_state=97)
ica.fit(raw_cleaned)

In [None]:
# Plot the ICA components so they can be inspected visually before choosing
# which ones to exclude.
ica.plot_components( title='ICA Components')
# Disabled interactive selection: it would read a comma-separated list of
# component indices from stdin into `exclude`.
#print("Select components to exclude (e.g., 0, 1, 2) and press Enter:")
#exclude = input().split(',')
#exclude = [int(i.strip()) for i in exclude if i.strip().isdigit()]

In [None]:
# Plot the ICA source time courses computed from the cleaned recording.
ica.plot_sources(raw_cleaned)

In [None]:
# Detailed property plots for three hand-picked candidate components.
ica.plot_properties(raw_cleaned, picks=[0,4,13]) # This exact component number probably won't work if you recompute ICA


In [None]:
# NOTE(review): this preview excludes component 1 as well, whereas the cell
# that applies the ICA sets ica.exclude = [0, 4, 11, 13] — confirm which list
# is intended.
ica.plot_overlay(raw_cleaned, exclude=[0,1, 4,11, 13]) # see what the data would look like if we removed the component



In [None]:
# Hand-picked components to drop; the indices refer to the ICA fit above.
artifact_components = [0, 4, 11, 13]
ica.exclude = artifact_components
# Remove the excluded components from the recording (modifies raw_cleaned).
ica.apply(raw_cleaned)

In [None]:
# Drop the blink / muscle annotations so they are not drawn on the report figure.
artifact_labels = ('blink', 'BAD_muscle')
artifact_idx = [
    i
    for i, desc in enumerate(raw_cleaned.annotations.description)
    if desc in artifact_labels
]
raw_cleaned.annotations.delete(artifact_idx)

fig = raw_cleaned.plot(show_scrollbars=False,
                        show_scalebars=False,events=None, start=0,
                        duration=200,n_channels=50, scalings=.35e-4, color='k', title='EEG Data after ICA')

# savefig does not create missing folders; make sure the target exists first.
os.makedirs('./report_images', exist_ok=True)
fig.savefig(f'./report_images/{subject}_cleaned_eeg.png', dpi=300, bbox_inches='tight')

In [None]:
# Power spectral density of every channel up to 50 Hz, saved for the report.
fig = raw_cleaned.plot_psd(fmax=50, average=False, show=True)
# savefig does not create missing folders; make sure the target exists first.
os.makedirs('./report_images', exist_ok=True)
fig.savefig(f'./report_images/{subject}_cleaned_eeg_psd.png', dpi=300, bbox_inches='tight')

In [None]:
# Open the default browser plot of the cleaned recording for a final visual check.
raw_cleaned.plot()

In [None]:
# Write the cleaned recording next to the source .xdf file, replacing any
# file already stored under the same name.
recording_dir = xdf_filename.rpartition('/')[0]
clean_name = f'sub-{subject}_ses-S001_task-CUNY_run-001_eeg_clean.fif'
save_path = recording_dir + '/' + clean_name
raw_cleaned.save(save_path, overwrite=True)