# EEG preprocessing 
### Effect of prediction and attention on C1 component
Written by Maximilien Van Migem
Created on 22/11/2023

In [None]:
# Import some libraries
import os
import numpy as np
import mne
import pandas as pd

# IPython magic (only valid inside IPython/Jupyter): use the Qt backend so the
# interactive plot windows opened below appear as separate windows.
%matplotlib qt

# Subject selection and loading data
sub = 26
# Reference scheme: 'average' for a common average reference, or a list of
# channel names to reference to (here the two mastoids).
# reference = 'average'
reference = ['M1', 'M2']  # Mastoids

data_directory = 'C:/Users/mvmigem/Documents/data/project_1/'
rejected_channels_path = data_directory + 'rejected_channels.npy'
# ICA exclusions are kept in a separate bookkeeping file per reference scheme
# ('mastoid_rejected_ica.npy' / 'average_rejected_ica.npy'), so pick the one
# that matches `reference` instead of toggling commented-out lines.
ref_label = 'average' if reference == 'average' else 'mastoid'
rejected_ica_path = data_directory + f'{ref_label}_rejected_ica.npy'

# The bookkeeping file holds a pickled {f'subject_{n}': [...]} dict written with
# np.save; np.load hands it back wrapped in a 0-d object array, which .item()
# unwraps. Start from an empty dict the first time, when no file exists yet.
# (allow_pickle is acceptable only because this file is produced by this script.)
if os.path.exists(rejected_ica_path):
    rejected_ica = np.load(rejected_ica_path, allow_pickle=True).item()
else:
    rejected_ica = {}

current_file_path = data_directory + f'raw_data/sub_{sub}/eeg/main_{sub}.bdf'
current_behav_path = data_directory + f'raw_data/sub_{sub}/behav/predatt_participant_{sub}.csv'
behav_data = pd.read_csv(current_behav_path)

raw = mne.io.read_raw_bdf(current_file_path, preload=True)

# Give the external (EXG) electrodes descriptive names: four electrodes around
# the eyes and the two mastoids.
exg_names = {
    'EXG1': 'eye_above',
    'EXG2': 'eye_below',
    'EXG3': 'eye_left',
    'EXG4': 'eye_right',
    'EXG5': 'M1',
    'EXG6': 'M2',
}
raw.rename_channels(exg_names)

# EXG7 and EXG8 were not actually recorded (they are empty), so drop them.
raw.drop_channels(['EXG7', 'EXG8'])
print(raw.info['ch_names'])

# Tell MNE which channels are scalp ('eeg') data: the mastoids count as EEG so
# they can serve as reference, the four eye electrodes are EOG.
channel_types = {'M1': 'eeg', 'M2': 'eeg'}
channel_types.update(
    {name: 'eog' for name in ('eye_above', 'eye_below', 'eye_left', 'eye_right')}
)
raw.set_channel_types(channel_types)

print(raw.info)

# Re-reference the EEG channels to `reference` (the mastoid pair or the common
# average, as chosen at the top of the notebook).
raw.set_eeg_reference(ref_channels=reference)

# The mastoid channels are not analysed themselves, so remove them once the
# reference has been applied.
mastoid_channels = ['M1', 'M2']
raw.drop_channels(mastoid_channels)

# Attach the standard BioSemi 64-channel electrode positions.
montage = mne.channels.make_standard_montage('biosemi64')

# Rename the first 64 recorded channels, position by position, to the
# montage's channel names so that set_montage can match them.
# NOTE(review): assumes the first 64 channels are the scalp electrodes in the
# same order as montage.ch_names — confirm for this cap/recording setup.
rename_channels = {
    recorded: standard
    for recorded, standard in zip(raw.info['ch_names'][:64], montage.ch_names)
}
raw.rename_channels(rename_channels)
raw.set_montage(montage)

# Bad spans can also be marked by hand in the interactive raw.plot() browser
# by pressing 'a' (useful e.g. for a mis-started task). Work on a copy so
# `raw` itself stays unannotated.
raw_annot = raw.copy()

# Detect blinks automatically: find_eog_events filters the EOG and picks
# peaks that are likely blinks. Its `thresh` argument can be tuned; the
# default is used here.
eog_events = mne.preprocessing.find_eog_events(raw_annot)
sfreq_annot = raw_annot.info["sfreq"]
n_blinks = len(eog_events)

# Mark each blink as a 0.5 s 'bad blink' span that starts 0.25 s before the
# detected event, so the span sits around it.
blink_annot = mne.Annotations(
    onset=eog_events[:, 0] / sfreq_annot - 0.25,
    duration=[0.5] * n_blinks,
    description=["bad blink"] * n_blinks,
    orig_time=raw_annot.info["meas_date"],
)
raw_annot.set_annotations(blink_annot)

# Downsampling parameters (logic -> https://mne.tools/stable/auto_tutorials/preprocessing/30_filtering_resampling.html#best-practices).
# The data are decimated by `decim` when epoching, so low-pass filter here at
# a third of the post-decimation sampling rate to avoid aliasing.
current_sfreq = raw.info['sfreq']
desired_sfreq = 256  # Hz
decim = np.round(current_sfreq / desired_sfreq).astype(int)
obtained_sfreq = current_sfreq / decim
lowpass_freq = obtained_sfreq / 3.

# 50 Hz line-noise notch, then a 0.1 Hz high-pass / `lowpass_freq` low-pass.
# A single copy keeps `raw_annot` untouched; both filters then operate in
# place on that copy (no second full copy of the data is needed).
raw_filtered = (
    raw_annot.copy()
    .notch_filter(freqs=50, fir_design='firwin')
    .filter(l_freq=0.1, h_freq=lowpass_freq)
)


# Inspect the spectrum and traces; click channels in the browser to mark them
# bad. block=True halts execution here until the browser window is closed.
raw_filtered.plot_psd()
raw_filtered.plot(n_channels=64, block=True)

# Record this subject's bad channels in the bookkeeping file: a pickled
# {f'subject_{n}': [channel names]} dict written with np.save (.item() unwraps
# the 0-d object array np.load returns). Start fresh if the file is missing.
if os.path.exists(rejected_channels_path):
    rejected_channels = np.load(rejected_channels_path, allow_pickle=True).item()
else:
    rejected_channels = {}
# Store a copy, not the live info['bads'] list, so the dict only changes when
# it is explicitly updated.
rejected_channels[f'subject_{sub}'] = list(raw_filtered.info['bads'])
np.save(rejected_channels_path, rejected_channels)

# Interpolate the bad channels. reset_bads=False keeps them listed in
# info['bads'], so MNE's default channel picks skip them when fitting ICA.
interp_filt_raw = raw_filtered.copy().interpolate_bads(reset_bads=False)

# Select the trigger-code dictionary for this participant's start-position
# group. Codes are 10 * position + sequence and keys are
# 'pos<position>/seq<sequence>', so MNE's '/'-tag selection (e.g.
# epochs['seq1']) works across positions; 99 is the trial-start trigger.
# Code 11 is 'pos1/seq1' (it was previously mislabelled 'pos1/seq', which
# excluded it from 'seq1' selections in older epoch files).
# NOTE(review): if a behavioural file contains both groups, the first branch
# wins — confirm each participant belongs to a single group.
start_positions = behav_data['start_position']
if start_positions.isin([0, 2]).any():
    event_id = {
        'start_trial': 99, 'pos1/seq1': 11, 'pos1/seq3': 13,
        'pos2/seq2': 22, 'pos2/seq4': 24,
        'pos3/seq1': 31, 'pos3/seq3': 33,
        'pos4/seq2': 42, 'pos4/seq4': 44,
    }
elif start_positions.isin([1, 3]).any():
    event_id = {
        'start_trial': 99, 'pos1/seq2': 12, 'pos1/seq4': 14,
        'pos2/seq1': 21, 'pos2/seq3': 23,
        'pos3/seq2': 32, 'pos3/seq4': 34,
        'pos4/seq1': 41, 'pos4/seq3': 43,
    }
else:
    # Without this, event_id would be undefined and epoching would fail later
    # with a confusing NameError.
    raise ValueError(
        f'subject {sub}: no start_position in {{0, 1, 2, 3}} found in the '
        'behavioural file; cannot choose an event_id mapping'
    )
events = mne.find_events(interp_filt_raw)

# To check the paradigm's event timecourse, run:
# mne.viz.plot_events(events, sfreq=interp_filt_raw.info['sfreq'], event_id=event_id)

# Mark long pauses between consecutive triggers (e.g. calibration breaks) as
# 'bad_calibration_gap' so that reject_by_annotation can skip them.
threshold_ms = 1000  # gaps longer than this many milliseconds are annotated
sfreq = interp_filt_raw.info['sfreq']
threshold_samples = int(threshold_ms / 1000 * sfreq)

# Each gap span starts 0.52 s after the trigger that opens the gap (presumably
# to keep the 0.5 s post-stimulus epoch of that last trial clean) and ends one
# sample before the next trigger so that trial's trigger is not covered.
post_event_pad_s = 0.52

event_times = events[:, 0]  # sample index of each event (includes first_samp)
gap_idx = np.flatnonzero(np.diff(event_times) > threshold_samples)

if gap_idx.size:
    start_samples = event_times[gap_idx] + post_event_pad_s * sfreq
    end_samples = event_times[gap_idx + 1] - 1
    # Annotations without orig_time are relative to the first sample of the
    # data, whereas event sample indices include raw.first_samp.
    gap_onsets = (start_samples - interp_filt_raw.first_samp) / sfreq
    gap_durations = (end_samples - start_samples) / sfreq
    gap_annotations = mne.Annotations(
        onset=gap_onsets,
        duration=gap_durations,
        description=['bad_calibration_gap'] * gap_idx.size,
    )
    # NOTE(review): set_annotations *replaces* existing annotations, so the
    # 'bad blink' spans added earlier are dropped whenever a gap is found (and
    # kept when none is). Use interp_filt_raw.annotations + gap_annotations if
    # the blinks should survive — confirm which is intended for the ICA fit.
    interp_filt_raw.set_annotations(gap_annotations)
else:
    print("No gaps exceeding the threshold were found.")


# ICA for artifact removal. n_components=0.99 keeps the smallest number of
# components explaining 99% of the variance; decim=2 fits on every 2nd sample;
# spans annotated as 'bad...' are skipped during fitting.
# A fixed random_state makes the decomposition, and therefore the component
# indices chosen for exclusion in the next cell, reproducible across re-runs.
# (Exclusion lists chosen before the seed was fixed refer to a different fit.)
ica = mne.preprocessing.ICA(n_components=0.99, random_state=97)
ica.fit(interp_filt_raw, decim=2, verbose='error', reject_by_annotation=True)
ica.plot_components()

interp_filt_raw.plot(events=events, n_channels=64)


In [None]:
# Components chosen for rejection after inspecting the ICA plots.
# NOTE(review): these indices only refer to the ICA fit that was inspected;
# refitting without a fixed seed reorders/changes components.
exclude_ica = [4, 14, 19]

# Record them per subject in the ICA bookkeeping file.
rejected_ica[f'subject_{sub}'] = exclude_ica
np.save(rejected_ica_path, rejected_ica)

# Remove the excluded components from the continuous data (in place).
ica.exclude = exclude_ica
ica.apply(interp_filt_raw)
# NOTE(review): channels still listed in info['bads'] were excluded from the
# ICA fit, so ica.apply() presumably leaves those (interpolated) channels
# uncleaned — confirm whether they should be re-interpolated before saving.

# Output names follow the reference scheme so that average- and
# mastoid-referenced runs never overwrite each other's files.
ref_label = 'average' if reference == 'average' else 'mastoid'
preproc_dir = data_directory + 'preprocessed/'

# Save the cleaned continuous data.
# NOTE(review): '<label>_raw' mirrors the existing 'mastoid_raw' folder;
# confirm the folder name used for average-referenced data.
interp_filt_raw.save(
    preproc_dir + f'{ref_label}_raw/main_clean_{ref_label}ref_{sub:02}-raw.fif',
    overwrite=True,
)

# Epoch -0.5..0.5 s around every event in `event_id`, baseline-correct on the
# pre-stimulus interval and decimate by `decim` (target ~256 Hz).
# reject_by_annotation=False: annotated spans do not drop epochs here.
epochs_stimlock = mne.Epochs(
    interp_filt_raw, events, event_id=event_id,
    tmin=-0.5, tmax=0.5, proj=False, baseline=(None, 0), decim=decim,
    detrend=None, verbose=True, reject_by_annotation=False, preload=True,
)

# Save the processed epochs.
epochs_stimlock.save(
    preproc_dir + f'{ref_label}_ref/unpaired/main_clean_{ref_label}ref_{sub:02}-epo.fif',
    overwrite=True,
)

In [None]:
# Refit ICA on the current `interp_filt_raw`, e.g. to re-inspect components.
# NOTE(review): if run after the cleaning cell, the data have already had
# components removed by ica.apply(); this also overwrites `ica`, so exclusion
# indices must be re-chosen for this new fit.
# Same settings and fixed seed as the first fit, for reproducible indices.
ica = mne.preprocessing.ICA(n_components=0.99, random_state=97)
ica.fit(interp_filt_raw, decim=2, verbose='error', reject_by_annotation=True)
ica.plot_components()

In [None]:
# Re-display the component topographies of the current `ica` object.
ica.plot_components()

In [None]:
# Persist the current bad-channel list for this subject, then interpolate
# those channels on a fresh copy and clear info['bads'] (reset_bads=True).
current_bads = interp_filt_raw.info['bads']
rejected_channels.update({f'subject_{sub}': current_bads})
np.save(rejected_channels_path, rejected_channels)
interp_filt_raw = interp_filt_raw.copy().interpolate_bads(reset_bads=True)

In [None]:
# Re-save this subject's ICA exclusion list (same two lines as in the cleaning
# cell) — presumably for after `exclude_ica` has been edited; confirm intent.
rejected_ica[f'subject_{sub}'] = exclude_ica
np.save(rejected_ica_path, rejected_ica)