In [5]:
# imports
import mne
import os
import os.path as op
from mne.preprocessing import ICA
from mne.channels import make_standard_montage
from autoreject import AutoReject, compute_thresholds
import numpy as np
import gdown
import logging
import matplotlib.pyplot as plt


### Load the data

In [None]:

# Location of the recording to preprocess
folderpath = r'../data/pilot_data/'
filename = 'sub-01_ses-001_raw.edf'
filepath = op.join(folderpath, filename)

if filename.endswith('.edf'):
    # Load the raw EDF recording; the 'Trigger' channel is read as the stim channel
    raw = mne.io.read_raw_edf(filepath, preload=True, stim_channel='Trigger')

    ## Drop channels that are not scalp EEG
    channel_names = raw.info['ch_names']
    print("Channel Names:", channel_names)
    # Channels listed as non-EEG for this recording setup
    non_eeg_channels = ['EEG X1:ECG-Pz', 'EEG X2:-Pz', 'EEG X3:-Pz', 'CM', 'EEG A1-Pz', 'EEG A2-Pz']
    # Only drop the ones actually present, so absent names do not cause an error
    existing_non_eeg = [ch for ch in non_eeg_channels if ch in channel_names]
    print("Non-EEG Channels to Exclude:", existing_non_eeg)
    raw.drop_channels(existing_non_eeg)
    print("Channels after exclusion:", raw.info['ch_names'])

    ## Rename 'EEG <name>-Pz' channels to bare electrode names so they can match the montage
    eeg_channels = [ch for ch in raw.info['ch_names'] if 'EEG' in ch]
    print("EEG Channels Before Renaming:", eeg_channels)
    rename_mapping = {ch: ch.replace('EEG ', '').replace('-Pz', '') for ch in eeg_channels}
    print("Rename Mapping:", rename_mapping)
    raw.rename_channels(rename_mapping)
    eeg_names = set(rename_mapping.values())
    print("EEG Channels After Renaming:", [ch for ch in raw.info['ch_names'] if ch in eeg_names])

    ## Set all channel types in one call: renamed EEG -> 'eeg', trigger -> 'stim', anything left -> 'misc'
    channel_types = {}
    for ch in raw.info['ch_names']:
        if ch in eeg_names:
            channel_types[ch] = 'eeg'
        elif 'Trigger' in ch:
            channel_types[ch] = 'stim'
        else:
            channel_types[ch] = 'misc'
    raw.set_channel_types(channel_types)
    print("Channel Types:", raw.get_channel_types())

    ## Attach standard 10-20 electrode positions
    montage = make_standard_montage('standard_1020')
    raw.set_montage(montage)
    print(raw.info['dig'])

elif filename.endswith('.fif'):
    # FIF input is loaded directly so that `raw` is defined for the cells below.
    # NOTE(review): assumes channel names/types/montage were already set when this FIF was written — confirm
    raw = mne.io.read_raw_fif(filepath, preload=True)

else:
    raise ValueError(f"Unsupported file format (expected .edf or .fif): {filename}")
    

### Clean the data

In [None]:
## Filter
# Band-pass 1-100 Hz (a 1 Hz high-pass is also what ICA needs: it is sensitive to slow drifts)
raw_filtered = raw.copy().filter(l_freq=1, h_freq=100, fir_design='firwin')
# Notch out 60 Hz line noise and its first harmonic (120 Hz); modifies raw_filtered in place
raw_filtered.notch_filter(freqs=[60, 120], fir_design='firwin')

## Dummy segmentation for autoreject
# Cut the *filtered* recording into consecutive 1 s epochs. Autoreject and the ICA fitted on
# these epochs must see the same filtered data that the ICA is later applied to.
events_for_autoreject = mne.make_fixed_length_events(raw_filtered, duration=1)
epochs_for_autoreject = mne.Epochs(raw_filtered, events_for_autoreject, tmin=0, tmax=1,
                                   baseline=None, detrend=0, preload=True)

## Autoreject
# Seeded for reproducibility, matching the ICA random_state below
ar = AutoReject(random_state=0)
ar.fit(epochs_for_autoreject)

# Reject log: which 1 s epochs autoreject flags as bad (used to select clean data for ICA)
ar_log = ar.get_reject_log(epochs_for_autoreject)
print(ar_log)



In [None]:
## Fit ICA only on the 1 s epochs that autoreject did not flag as bad
kept_epochs = epochs_for_autoreject[~ar_log.bad_epochs]
ica = ICA(n_components=0.95, random_state=0)
ica.fit(kept_epochs, decim=3)
# Inspect component time courses (on the filtered recording) and scalp maps to pick artifacts
ica.plot_sources(raw_filtered, show=True)
ica.plot_components(title='ICA Components', show=True);

In [None]:
## Identify bad components
# Indices chosen by visual inspection of the ICA plots above (here IC001-IC003);
# update this list for every recording.
bad_components_indices = [1, 2, 3]
ica.exclude = bad_components_indices
print("Bad components:", bad_components_indices)

## Reconstruct data without the excluded components
# ica.apply modifies its argument in place, hence the explicit copy of the filtered data
raw_corrected = raw_filtered.copy()
ica.apply(raw_corrected)

# TODO: re-run autoreject on the ICA-corrected data. AutoReject works on mne.Epochs, not Raw,
# so raw_corrected would first need to be epoched (as epochs_for_autoreject was above).

### Inspect the data

In [None]:
# Per-channel PSD of the unfiltered recording, as a baseline for comparison
raw.plot_psd(show=True, average=False)

In [None]:
# Per-channel PSD after the band-pass and notch filtering
raw_filtered.plot_psd(show=True, average=False)

In [None]:
# Per-channel PSD after removing the excluded ICA components
raw_corrected.plot_psd(show=True, average=False)

In [None]:
# Save the ICA-cleaned recording next to the input as "<input stem>_preprocessed.fif".
# op.splitext (instead of str.replace('.edf', ...)) guarantees a new name even for .fif input,
# so overwrite=True can never clobber the original recording.
preprocessed_path = op.join(folderpath, op.splitext(filename)[0] + '_preprocessed.fif')
raw_corrected.save(preprocessed_path, overwrite=True)

# Identify Task Periods

In [19]:

# Reload the preprocessed recording and locate each task block from the trigger channel
folderpath = r'../data/pilot_data/'
filename = 'sub-01_ses-001_raw_preprocessed.fif'
filepath = op.join(folderpath, filename)
raw = mne.io.read_raw_fif(filepath, preload=True)
events = mne.find_events(raw, stim_channel='Trigger', min_duration=0.001, consecutive=False)
print("Events:", events)
sfreq = raw.info['sfreq']

# Trigger codes as used by this analysis (TODO confirm against the task protocol)
REST_START, REST_END = 6, 7
TASK_START, TASK_END = 8, 9
# LandoitC has no triggers of its own; it is placed by offsets (in minutes).
# NOTE(review): a 1-minute start offset was described in earlier notes, but 14 min is what has been used — confirm.
LANDOITC_START_OFFSET_MIN = 14  # after the end of the GoNoGo block
LANDOITC_END_OFFSET_MIN = 1     # before the second Rest start

# Start/end sample index of each task period (filled in below)
task_periods = {
    'Rest_GoNoGo': {'start': None, 'end': None},
    'GoNoGo': {'start': None, 'end': None},
    'LandoitC': {'start': None, 'end': None},
    'MentalImagery': {'start': None, 'end': None}
}


def minutes_to_samples(minutes, sfreq):
    """Convert a duration in minutes to an integer number of samples at ``sfreq`` Hz."""
    return int(minutes * 60 * sfreq)


def set_period(name, start, end):
    """Record the sample bounds of task period ``name`` and log them.

    Raises ValueError if the period is empty or reversed (end <= start).
    """
    if end <= start:
        raise ValueError(f"{name}: end sample {end} is not after start sample {start}.")
    task_periods[name]['start'] = start
    task_periods[name]['end'] = end
    logging.info(f"{name}: Start={start}, End={end}")


# Map each trigger code to the list of sample indices at which it occurs (in event order)
trigger_dict = {}
for sample, _, trigger in events:
    trigger_dict.setdefault(trigger, []).append(sample)

# Rest_GoNoGo: first Rest start (6) to first Rest end (7)
if REST_START in trigger_dict and REST_END in trigger_dict:
    set_period('Rest_GoNoGo', trigger_dict[REST_START][0], trigger_dict[REST_END][0])
else:
    logging.error("Rest (GoNoGo) triggers 6 and/or 7 not found.")
    raise ValueError("Missing Rest (GoNoGo) triggers.")

# GoNoGo: first task start (8) to first task end (9)
if TASK_START in trigger_dict and TASK_END in trigger_dict:
    set_period('GoNoGo', trigger_dict[TASK_START][0], trigger_dict[TASK_END][0])
else:
    logging.error("GoNoGo triggers 8 and/or 9 not found.")
    raise ValueError("Missing GoNoGo triggers.")

# MentalImagery: last task start (8) to last task end (9).
# Assumes triggers 8 and 9 occur (at least) twice: first for GoNoGo, last for Mental Imagery.
if len(trigger_dict.get(TASK_START, [])) >= 2 and len(trigger_dict.get(TASK_END, [])) >= 2:
    set_period('MentalImagery', trigger_dict[TASK_START][-1], trigger_dict[TASK_END][-1])
else:
    logging.error("Mental Imagery triggers 8 and/or 9 not found.")
    raise ValueError("Missing Mental Imagery triggers.")

# LandoitC: from a fixed offset after the GoNoGo end until a fixed offset before the
# last Rest start (6), which precedes the Mental Imagery task.
rest_starts = trigger_dict.get(REST_START, [])
if len(rest_starts) < 2:
    logging.error("Second Rest start trigger (6) for Mental Imagery task not found.")
    raise ValueError("Missing Mental Imagery Rest start trigger.")
landoitc_start = task_periods['GoNoGo']['end'] + minutes_to_samples(LANDOITC_START_OFFSET_MIN, sfreq)
landoitc_end = rest_starts[-1] - minutes_to_samples(LANDOITC_END_OFFSET_MIN, sfreq)
set_period('LandoitC', landoitc_start, landoitc_end)


def extract_period(name):
    """Return a copy of ``raw`` cropped to task period ``name``.

    Event sample numbers from mne.find_events include ``raw.first_samp``, whereas
    ``crop`` takes times relative to the first sample of the data, so the offset
    is removed before converting samples to seconds.
    """
    period = task_periods[name]
    tmin = (period['start'] - raw.first_samp) / sfreq
    tmax = (period['end'] - raw.first_samp) / sfreq
    return raw.copy().crop(tmin=tmin, tmax=tmax)


rest_raw = extract_period('Rest_GoNoGo')
gonogo_raw = extract_period('GoNoGo')
landoitc_raw = extract_period('LandoitC')
mentalimagery_raw = extract_period('MentalImagery')

# Save each period as a raw FIF file next to the input data.
# NOTE(review): MNE may warn that these names do not follow its *raw.fif naming convention;
# the names are kept unchanged for downstream compatibility.
pilot_data_folder = folderpath
rest_save_path = op.join(pilot_data_folder, 'Rest.fif')
gonogo_save_path = op.join(pilot_data_folder, 'GonoGo.fif')
landoitc_save_path = op.join(pilot_data_folder, 'LandoitC.fif')
mentalimagery_save_path = op.join(pilot_data_folder, 'MentalImagery.fif')

for label, segment, save_path in [
    ('Rest', rest_raw, rest_save_path),
    ('GoNoGo', gonogo_raw, gonogo_save_path),
    ('LandoitC', landoitc_raw, landoitc_save_path),
    ('MentalImagery', mentalimagery_raw, mentalimagery_save_path),
]:
    segment.save(save_path, overwrite=True)
    logging.info(f"Saved {label} period as raw FIF file: {save_path}")



Opening raw data file ../data/pilot_data/sub-01_ses-001_raw_preprocessed.fif...


    Range : 0 ... 1135799 =      0.000 ...  3785.997 secs
Ready.
Reading 0 ... 1135799  =      0.000 ...  3785.997 secs...


  raw = mne.io.read_raw_fif(filepath, preload=True)


1509 events found on stim channel Trigger
Event IDs: [  1   2   3   4   6   7   8   9  10  11  13  14  15  16  17  18  21  22
  23  24  25  26  27  28  30  31  32  34  37  38  39  40  41  44  45  46
  47  48  50  51  52  54  64 128]
Events: [[  45028       0       6]
 [  81028       0       7]
 [  81636       0       2]
 ...
 [1109650       0       3]
 [1118947       0       2]
 [1128239       0       9]]
Overwriting existing file.
Writing d:\Yann\scrs\neurotheque_pilots\..\data\pilot_data\Rest.fif
Closing d:\Yann\scrs\neurotheque_pilots\..\data\pilot_data\Rest.fif


  rest_raw.save(rest_save_path, overwrite=True)


[done]
Overwriting existing file.
Writing d:\Yann\scrs\neurotheque_pilots\..\data\pilot_data\GonoGo.fif


  gonogo_raw.save(gonogo_save_path, overwrite=True)


Closing d:\Yann\scrs\neurotheque_pilots\..\data\pilot_data\GonoGo.fif
[done]
Overwriting existing file.
Writing d:\Yann\scrs\neurotheque_pilots\..\data\pilot_data\LandoitC.fif
Closing d:\Yann\scrs\neurotheque_pilots\..\data\pilot_data\LandoitC.fif


  landoitc_raw.save(landoitc_save_path, overwrite=True)


[done]
Overwriting existing file.
Writing d:\Yann\scrs\neurotheque_pilots\..\data\pilot_data\MentalImagery.fif
Closing d:\Yann\scrs\neurotheque_pilots\..\data\pilot_data\MentalImagery.fif


  mentalimagery_raw.save(mentalimagery_save_path, overwrite=True)


[done]


In [None]:
# Browse the extracted LandoitC segment interactively (e.g. to mark bad channels)
landoitc_raw.plot(show=True)

<mne_qt_browser._pg_figure.MNEQtBrowser at 0x20d2a30d510>

Channels marked as bad:
none
