# -----------------
# BLAST-EEG
## Prepocessing and epoching

Author: Marine Thieux

Reviewer/Assistant engineer: Lucie Martinet

# -----------------


In [3]:
import os
import numpy as np
import mne
from datetime import  timezone, timedelta
import neo
import logging
import pandas as pd
from tkinter import filedialog # to select the files and directory paths

# to generate windows and explore the results easily
import matplotlib
matplotlib.use("Qt5Agg")
from PyQt5.QtWidgets  import QMessageBox  # in the pyqt4 tutorials

print(neo.__version__) # run with 0.13.1 
print(mne.__version__) # run with 1.7.0


0.13.1
1.7.0


### Parameters to personalise to read and write data

In [6]:
# -----------------
# Parameters
# -----------------
# Entry file
# Where to read data
file_path = filedialog.askopenfilename(title = "Select the file to process.")

# Results directory, including the preproc data
# Where to save the preproc results data

base_directory_path = filedialog.askdirectory(title = "Select the directory to store the results.")

# Print results for confirmation
print(f"Selected file path: {file_path}")
print(f"Selected base directory path: {base_directory_path}")

Chans    = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T3', 'C3', 'Cz', 'C4', 'T4', 'T5', 'P3', 
            'Pz', 'P4', 'T6', 'O1', 'O2', 'EMG1+', 'ECG1+']


Selected file path: 
Selected base directory path: 


### Import datas from .TRC (with neo)

In [None]:
def create_annotations(eeg_micromed) :
    # Triggers collect
    events_time_l = []
    description_l = []
    duration_l = []

    for e in eeg_micromed._events :
        ev_time = np.asarray(e.times.magnitude)
        events_time_l.append(ev_time- float(eeg_micromed._analogsignals[0].t_start)) #* sFreq) # To get sample number from time
        description_l.append( np.asarray(e.labels))
        duration_l.append(np.asarray(np.full(shape = len(ev_time), fill_value = 0.0)))

    events_time = np.concatenate(events_time_l)
    description = np.concatenate(description_l)
    duration = np.concatenate(duration_l)
    onset = events_time
    annotations = mne.Annotations(onset = onset,
                                  duration = duration, 
                                  description = description,
                                  orig_time = eeg_micromed.rec_datetime.replace(tzinfo = timezone.utc) + timedelta(seconds = float(eeg_micromed.analogsignals[0]._t_start)))
    return annotations



In [None]:
def convert_micromed_mneObject(filename): 

    # NEO import of raw data
    eeg_micromed = neo.MicromedIO(filename = filename).read_segment()
    eeg_micromed.annotate(material = "micromed")  
    # Get informations and data to convert it to MNE needs
    # Channels names and types
    chan = list()
    for cc in range(np.size(eeg_micromed.analogsignals[0].array_annotations["channel_names"])):
        chan.append(eeg_micromed.analogsignals[0].array_annotations["channel_names"][cc])

    # Sample frequency
    sFreq = eeg_micromed.analogsignals[0].sampling_rate 

    # Data
    convert_fac = 1. * eeg_micromed.analogsignals[0].units
    convert_fac.units = 'V' # to volts
    data = np.asarray(eeg_micromed.analogsignals[0].segment._analogsignals[0]).transpose()
 
    data = data * convert_fac.magnitude  

    raw_info = mne.create_info(chan, sFreq, ch_types="eeg")   
    raw_info.set_meas_date(eeg_micromed.rec_datetime.replace(tzinfo = timezone.utc) + timedelta(seconds = float(eeg_micromed.analogsignals[0]._t_start)) )
 
    raw = mne.io.RawArray(data, raw_info, verbose = 1)

    annotations = create_annotations(eeg_micromed)
    
    raw.set_annotations(annotations)

    return raw

In [None]:
raw = convert_micromed_mneObject(file_path)

### Channels / montage selection

In [None]:
# Get the list of channels that actually exist in the raw data
existing_channels = [chan for chan in Chans if chan in raw.ch_names]

# If there are any matching channels, pick them
if existing_channels:
    raw.pick(existing_channels)
else:
    print("None of the channels in the 'Chans' list are present in the raw data.")
chtype = {'EMG1+': 'emg', 'ECG1+': 'ecg'}
raw.set_channel_types(chtype)

raw.set_montage("standard_1020") # Apply a template montage 
raw.set_eeg_reference('average') # Set the reference

# delete bad channels manually 
scalings = dict(eeg = 50e-6)
raw.plot(duration = 30, scalings = scalings)

### Filter the data

In [None]:
# High pass and low pass filtering
raw.filter(l_freq = 1., h_freq = 70, h_trans_bandwidth = 'auto', l_trans_bandwidth = 'auto', filter_length = 'auto', phase = 'zero')

In [None]:
# Ask if we apply the notch filter depending on the existance of 50Hz pic in the psd
raw.compute_psd().plot()
qm = QMessageBox()

qm.setText("Do you want to apply a notch filter based on 50Hz?")
qm.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
ret = 0
ret = qm.exec()

print("Type ret : ", type(ret), ret)
if ret == QMessageBox.StandardButton.Yes: 
    raw.notch_filter(np.arange(50, int(raw.info['sfreq']/2)-1, 50), filter_length = 'auto', phase = 'zero') # notch
    print("Notch filter applied")


In [None]:
# Downsampling at 256Hz
orig_raw = raw.copy()
raw = raw.resample(sfreq = 256)
raw.plot(duration = 30, scalings = scalings)

### Export preprocessed raw to .fif (optional)

In [None]:
# Optionnal step
qm = QMessageBox()

qm.setText("Do you want to save your work in a .fif file ?")
qm.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
save_bool = 0
save_bool = qm.exec()
if save_bool == "y" :
    # Extract the EPI number and file name from the input file path
    parts = file_path.split('/')
    epi_number = parts[-2]  
    filename = parts[-1] 

    # Construct the new file path
    new_filename = filename.replace('.TRC', '.fif')
    new_file_path = os.path.join(base_directory_path, epi_number, new_filename).replace("\\", "/") # replace for those who work on windows

    # Ensure the target directory exists
    os.makedirs(os.path.dirname(new_file_path), exist_ok=True)

    # Save the preprocessed raw data
    raw.save(new_file_path, overwrite=True)

    print(f"Preprocessed data saved to: {new_file_path}")

### Epochs 

In [None]:
# Define the function to assign event IDs based on the first letter of the description
def assign_event_id(description):
    first_letter = description[0]
    if first_letter == "b":
        return 77
    elif first_letter == "a" or first_letter == "c" or first_letter == "e" :
        return 66
    elif first_letter == "d" or first_letter == "f" :
        return 77
    elif description in ["101", "102", "103", "104", "105"]:
        return 100
    else:
        return None  # Ignore other annotations

In [None]:
# Ensure correct pairing of start and end events
def make_paired_events (all_ied_events) :
    paired_events = []
    current_start = None

    for event_time, event_type in all_ied_events:
        if event_type == 'start':
            # Set current start when a "start" event is encountered
            if current_start is None:
                current_start = event_time
            else:
                # If we encounter another "start" without an "end", we ignore the previous one
                print(f"Warning: Unmatched start event at {current_start} ignored.")
                current_start = event_time  # Reset with the new start event
        elif event_type == 'end':
            if current_start is not None:
                # Pair this "end" with the current "start"
                if event_time > current_start:
                    paired_events.append((current_start, event_time))
                    current_start = None  # Reset the current start after a successful pair
                else:
                    # If the "end" comes before the "start", ignore it (shouldn't happen with sorted events)
                    print(f"Warning: End event at {event_time} comes before the start event.")
            else:
                # Ignore "end" events that have no matching "start"
                print(f"Warning: Unmatched end event at {event_time} ignored.")
    return paired_events

In [None]:
def create_paired_events (paired_events, raw):
    if paired_events:
        start_ied_events, end_ied_events = zip(*paired_events)
        start_ied_events = list(start_ied_events)
        end_ied_events = list(end_ied_events)

        # Calculate durations
        ied_durations = [end - start for start, end in zip(start_ied_events, end_ied_events)]

        # Create new annotations that span from IED_start to IED_end
        ied_annotation = mne.Annotations(
            onset = start_ied_events,  # Start times
            duration = ied_durations,  # Durations calculated from start and end times
            description = ["IED"] * len(start_ied_events), # Description for the events
            orig_time = raw.info["meas_date"]
        )

        # Plot events with the event dict
        fig = mne.viz.plot_events(all_events, raw.info['sfreq'], event_id = event_dict, on_missing = 'ignore')

        # Print the new IED annotations for verification
        print(ied_annotation)

        # Set the new annotations while keeping the existing ones
        raw.set_annotations(raw.annotations + ied_annotation)

        # Drop bad channels as needed
        raw.drop_channels(raw.info['bads'])
    else:
        print("No valid IED annotations were created.")
    return raw

In [None]:
# Read triggers as events / annotations 
print(len(raw.annotations))
mne.count_annotations(raw.annotations)


In [None]:
# Apply the function to create the custom mapping
custom_mapping = {annot['description']: assign_event_id(annot['description']) for annot in raw.annotations}

# Remove entries with None values (ignored annotations)
custom_mapping = {k: v for k, v in custom_mapping.items() if v is not None}

In [None]:
# Extract events from annotations using the custom mapping
all_events, all_event_id = mne.events_from_annotations(raw, event_id = custom_mapping) 

In [None]:
# Define event IDs for "4 Letters" and IED
stim_id = {"4 Letters": 100,}
ied_id = { "start": 66, "end": 77,}

event_dict = stim_id.copy()
event_dict.update(ied_id)

In [None]:
# Extract start and end times based on event IDs
start_ied_events = all_events[all_events[:, 2] == ied_id["start"]][:, 0] / raw.info['sfreq']
end_ied_events = all_events[all_events[:, 2] == ied_id["end"]][:, 0] / raw.info['sfreq']

# Combine start and end events into one list for ordered processing
all_ied_events = sorted(
    [(event, 'start') for event in start_ied_events] + [(event, 'end') for event in end_ied_events]
)

# Ensure correct pairing of start and end events
paired_events = make_paired_events(all_ied_events)

# Create annotations based on the paired events
raw = create_paired_events (paired_events, raw)

### Mark and drop bad segments

In [None]:
# Open with "a" then "enter" and drag 
scalings = dict(eeg = 50e-6)
raw.plot(duration = 20, scalings = scalings)

In [None]:
# Function to check if IED is within the epoch
def is_ied_within_epoch(epoch_start, epoch_end, ied_annotations):
    for annot in ied_annotations:
        ied_start = annot['onset']
        ied_end = ied_start + annot['duration']
        if (ied_start < epoch_end) and (ied_end > epoch_start):
            return True
    return False

In [None]:
# Custom mapping of annotations to event IDs
custom_mapping = {
    "101": 100, 
    "102": 100, 
    "103": 100, 
    "104": 100, 
    "105": 100, 
    "IED": 88
}

In [None]:
# Extract events from annotations
all_events, all_event_id = mne.events_from_annotations(raw, event_id=custom_mapping)

In [None]:
# Define event IDs for "4 Letters" and IED
stim_id = {"4 Letters": 100}
ied_id = {"IED": 88}

# Define epochs parameters
tmin = -1  # start of the epoch (before the "4 Letters" event)
tmax = 1   # end of the epoch (after the "4 Letters" event)

# Create epochs based on "4 Letters" events
epochs = mne.Epochs(raw, events = all_events, event_id = stim_id, tmin = tmin, tmax = tmax, baseline = (-0.2,0), preload = True, reject_by_annotation = ['bad'])


In [None]:
def epoch_ied(epochs) :
    # Initialize lists to store epochs with and without "IED"
    epochs_with_ied = []
    epochs_without_ied = []

    # Initialize lists to store metadata for each epoch
    metadata_with_ied = []
    metadata_without_ied = []


    # Iterate through epochs to check for "IED" annotations
    for epoch_idx, epoch in enumerate(epochs):
        # Get the start and end times of the current epoch in seconds
        epoch_start = epochs[epoch_idx].events[0, 0] / raw.info['sfreq'] + tmin
        epoch_end = epoch_start + (tmax - tmin)
        
        # Check if any part of the IED event overlaps with the epoch time range
        ied_annotations_within_epoch = [annot for annot in raw.annotations if annot['description'] == 'IED']
        
        if is_ied_within_epoch(epoch_start, epoch_end, ied_annotations_within_epoch):
            #epochs_with_ied.append(epoch_idx)
            metadata_with_ied.append({'description': 'With IED'})
        else:
            #epochs_without_ied.append(epoch_idx)
            metadata_without_ied.append({'description': 'Without IED'})
    return metadata_with_ied, metadata_without_ied

metadata_with_ied, metadata_without_ied = epoch_ied(epochs)

In [None]:
# Create metadata DataFrame
metadata = pd.DataFrame(metadata_with_ied + metadata_without_ied)

# Add metadata to epochs
epochs.metadata = metadata

# Select epochs with and without IED
epochs_with_ied = epochs[metadata['description'] == 'With IED']
epochs_without_ied = epochs[metadata['description'] == 'Without IED']

In [None]:
# Total number of epochs before dropping (initial number of epochs created)
total_epochs_created = len(epochs.drop_log)

# Number of epochs dropped due to bad annotations
bad_dropped_epochs = [log for log in epochs.drop_log if 'BAD_' in log]
bad_dropped_count = len(bad_dropped_epochs)

# Number of epochs kept (total created - bad dropped)
epochs_kept_count = total_epochs_created - bad_dropped_count

# Calculate the proportion of dropped epochs relative to the epochs kept
proportion_dropped_bad = bad_dropped_count / total_epochs_created

# Output the results
print(f"Total epochs created: {total_epochs_created}")
print(f"Number of epochs dropped due to bad annotations: {bad_dropped_count}")
print(f"Proportion of epochs dropped due to bad annotations: {proportion_dropped_bad:.2%}")

In [None]:
# Here reject bad epochs manually
epochs.plot(n_epochs = 10, events = all_events, event_id = event_dict, scalings = scalings)

### Export epoched data

In [None]:
def save_dataframe_as_csv(dataframe, base_path, base_filename, suffix):
    try:
        # Construct the full file path
        csv_filename = f"{base_filename}{suffix}.csv"
        file_path = os.path.join(base_path, csv_filename)
        
        # Ensure the target directory exists
        os.makedirs(base_path, exist_ok=True)
        
        # Save the dataframe to CSV
        dataframe.to_data_frame().to_csv(file_path)
        logging.info(f"Dataframe saved to: {file_path}")
    except Exception as e:
        logging.error(f"Failed to save dataframe {suffix}: {e}")

In [None]:
# Set up logging
logging.basicConfig(level = logging.INFO)

# Extract the EPI number and file name from the input file path
parts = file_path.split('/')
epi_number = parts[-2]  
filename = parts[-1]

# Construct the base file path
base_filename = filename.replace('.fif', '')  # Removing the extension for flexibility

#############PATH FILE#################
base_path = os.path.join(base_directory_path, epi_number)

# Save dataframe to CSV
save_dataframe_as_csv(epochs, base_path, base_filename, '_epochs')
save_dataframe_as_csv(epochs_with_ied, base_path, base_filename, '_epochs_with_IED')
save_dataframe_as_csv(epochs_without_ied, base_path, base_filename, '_epochs_without_IED')

# Save the epochs to a .fif file
epochs.save(f'{base_path}/{base_filename}_epochs.fif', overwrite = True)
epochs_with_ied.save(f'{base_path}/{base_filename}_epochs_with_IED.fif', overwrite = True)
epochs_without_ied.save(f'{base_path}/{base_filename}_epochs_without_IED.fif', overwrite = True)