#### Deep learning on the Dutch Dyslexia Program (DDP) dataset

Inspired by: https://github.com/epodium/EEG_age_prediction

- (Optional) Load subject ages from an Excel file
- Load raw EEG data from .cnt files
- Apply a band-pass filter
- Detect potential bad channels and replace them using interpolation
- Detect potential bad epochs and remove them
- Save processed data (and metadata)

#### Imports

In [1]:
import os              
import sys
import glob
import numpy as np  
import pandas as pd 
import mne
from IPython.display import clear_output

# PATH_MAIN = os.path.join(main_path, 'researchdrive', 'ePodium (Projectfolder)')
# Root folder of the local DDP data copy. Use 'D:/' rather than 'D:': on Windows,
# os.path.join('D:', 'EEG Data') gives the drive-relative path 'D:EEG Data', which
# resolves against the current directory of drive D instead of the drive root.
PATH_MAIN = os.path.join('D:/', 'EEG Data', 'DDP Surfdrive')
PATH_METADATA = os.path.join(PATH_MAIN, 'metadata')
PATH_PROCESSED = os.path.join(PATH_MAIN, 'processed')

In [None]:
# from eegyolk initialization_functions

def load_dataset(folder_dataset, file_extension = '.bdf', preload=True, max_files_preloaded = 5):
    '''
    Load the EEG recordings found (recursively) under a folder.

    Intended for small datasets: when ``preload`` is True, loading stops after
    ``max_files_preloaded`` files. Otherwise use generator_load_dataset.

    Parameters
    ----------
    folder_dataset : str
        Top folder of the dataset; it is searched recursively.
    file_extension : str
        Either '.bdf' or '.cnt'.
    preload : bool
        Passed to the MNE reader; True loads the signal data into memory.
    max_files_preloaded : int
        Maximum number of files to load when ``preload`` is True.

    Returns
    -------
    eeg_dataset : list
        The MNE Raw objects that were loaded successfully.
    eeg_filenames : list of str
        The matching file names, with the extension removed.

    Raises
    ------
    ValueError
        If ``file_extension`` is not '.bdf' or '.cnt'.
    '''
    # Without this check an unsupported extension would leave `raw` unbound
    # (NameError) or silently re-append the previous file's Raw object.
    if file_extension not in ('.bdf', '.cnt'):
        raise ValueError(
            f"Unsupported file extension {file_extension!r}; expected '.bdf' or '.cnt'.")

    pattern = os.path.join(folder_dataset, '**/*' + file_extension)
    eeg_filepaths = glob.glob(pattern, recursive=True)
    eeg_dataset = []
    eeg_filenames = []
    eeg_filenames_failed_to_load = []

    for path in eeg_filepaths:
        filename = os.path.basename(path)
        # Strip only the trailing extension (str.replace would also remove it
        # from the middle of a name). Case-insensitive because glob matching
        # is case-insensitive on Windows.
        if filename.lower().endswith(file_extension.lower()):
            filename = filename[:-len(file_extension)]

        if file_extension == '.bdf':
            raw = mne.io.read_raw_bdf(path, preload=preload)
        else:  # '.cnt' files do not always load.
            try:
                raw = mne.io.read_raw_cnt(path, preload=preload)
            except Exception:
                eeg_filenames_failed_to_load.append(filename)
                print(f"File {filename} could not be loaded.")
                continue

        eeg_dataset.append(raw)
        eeg_filenames.append(filename)
        print(len(eeg_dataset), "EEG files loaded")
        if preload and len(eeg_dataset) >= max_files_preloaded:
            break

        clear_output(wait=True)

    print(len(eeg_dataset), "EEG files loaded")
    if eeg_filenames_failed_to_load:
        print(len(eeg_filenames_failed_to_load), "EEG files failed to load")

    return eeg_dataset, eeg_filenames

#### Load Dataset

In [3]:
# folder_ddp_dataset = os.path.join(data_path, "DDP Dataset") # folder in Surf research drive
# Local copy of the DDP data. Use "D:/" (not "D:") so that on Windows the path is
# anchored at the drive root rather than at the drive's current directory.
main_folder_ddp_dataset = os.path.join("D:/", "EEG Data", "DDP Surfdrive") # local folder

ddp_age_folders = ['5mnd mmn', '11mnd mmn', '17mnd mmn', '23mnd mmn',
                    '29mnd mmn', '35mnd mmn', '41mnd mmn', '47mnd mmn']

# Results for every age group, keyed by age-group folder name. Previously each
# iteration overwrote the previous one, so only the last group was kept.
epod_raw_by_age = {}
epod_filenames_by_age = {}
epod_raw_preload_by_age = {}
epod_filenames_preload_by_age = {}

for ddp_age_group in ddp_age_folders:
    age_group_folder_location = os.path.join(main_folder_ddp_dataset, ddp_age_group)
    # All files of this age group, without loading signal data into memory (preload=False).
    epod_raw, epod_filenames = load_dataset(age_group_folder_location, file_extension = '.cnt', preload=False)
    # Fully load the first few files (capped by load_dataset's max_files_preloaded).
    epod_raw_preload, epod_filenames_preload = load_dataset(age_group_folder_location, file_extension = '.cnt')

    epod_raw_by_age[ddp_age_group] = epod_raw
    epod_filenames_by_age[ddp_age_group] = epod_filenames
    epod_raw_preload_by_age[ddp_age_group] = epod_raw_preload
    epod_filenames_preload_by_age[ddp_age_group] = epod_filenames_preload

# For compatibility with later cells, epod_raw, epod_filenames, epod_raw_preload
# and epod_filenames_preload still refer to the last age group processed.


177 EEG files loaded


#### Preprocessing function: read, filter and clean a single .cnt file

- Adapted from https://github.com/epodium/EEG_age_prediction/blob/main/Notebooks/Deep%20learning%20EEG_dataset%20preprocessing_DL.ipynb

In [None]:
# Predefined EEG channel sets for read_cnt_file. The 62-channel set consists of
# the 30-channel set (same order) followed by 32 additional electrodes.
_CHANNEL_SET_30 = ['O2', 'O1', 'OZ', 'PZ', 'P4', 'CP4', 'P8', 'C4', 'TP8', 'T8', 'P7',
                   'P3', 'CP3', 'CPZ', 'CZ', 'FC4', 'FT8', 'TP7', 'C3', 'FCZ', 'FZ',
                   'F4', 'F8', 'T7', 'FT7', 'FC3', 'F3', 'FP2', 'F7', 'FP1']
_CHANNEL_SET_62_EXTRA = ['AFZ', 'PO3', 'P1', 'POZ', 'P2', 'PO4', 'CP2', 'P6', 'M1', 'CP6',
                         'C6', 'PO8', 'PO7', 'P5', 'CP5', 'CP1', 'C1', 'C2', 'FC2', 'FC6',
                         'C5', 'FC1', 'F2', 'F6', 'FC5', 'F1', 'AF4', 'AF8', 'F5', 'AF7',
                         'AF3', 'FPZ']


def _resolve_channel_set(channel_set):
    """Return a fresh list of channel names for `channel_set`, or None for an unknown set name.

    A string selects a predefined set ("30" or "62"); any other iterable is
    taken as an explicit list of channel names.
    """
    if isinstance(channel_set, str):
        if channel_set == "30":
            return list(_CHANNEL_SET_30)
        if channel_set == "62":
            return _CHANNEL_SET_30 + _CHANNEL_SET_62_EXTRA
        return None
    return list(channel_set)


def _make_epochs(data_raw, events, event_dict, picks, tmin, tmax, baseline):
    """Cut `data_raw` into preloaded epochs around all annotated events (repeated events merged)."""
    return mne.Epochs(data_raw, events=events, event_id=event_dict,
                      tmin=tmin, tmax=tmax, proj=True, picks=picks,
                      baseline=baseline, preload=True, event_repeated='merge', verbose=False)


def read_cnt_file(file,
                  label_group,
                  event_idx = (2, 3, 4, 5, 12, 13, 14, 15),
                  channel_set = "30",
                  tmin = -0.2,
                  tmax = 0.8,
                  lpass = 0.5,
                  hpass = 40,
                  threshold = 5,
                  max_bad_fraction = 0.2,
                  max_bad_channels = 2):
    """ Read a cnt file and run a band-pass filter on it.
    Then detect and correct (interpolate) bad channels and remove bad epochs.
    Return the resulting epochs as arrays.

    Args:
    --------
    file: str
        Name of file to import.
    label_group: int
        Unique ID of specific group (must be >0). Added to each event id to form the labels.
    event_idx: iterable of int
        Event ids to extract; ids absent from the file's annotations are skipped.
    channel_set: str or list of str
        Select among pre-defined channel sets ("30" or "62"), or give explicit channel names.
    tmin, tmax: float
        Epoch window in seconds relative to the event.
    lpass, hpass: float
        Lower and upper cut-off (Hz) of the band-pass filter. Despite the names,
        `lpass` is the low cut-off and `hpass` the high cut-off.
    threshold, max_bad_fraction:
        Passed to helper_functions.select_bad_epochs.
    max_bad_channels: int
        If more bad channels than this are found, the file is rejected.

    Returns:
    --------
    signal_collection: numpy.ndarray
        Cleaned epochs of all selected events, concatenated along axis 0
        (as returned by mne Epochs.get_data()).
    label_collection: list of int
        One label (event_id + label_group) per epoch.
    channel_names_collection: list of list of str
        Channel names of the epochs, one entry per processed event id.
    All three are None if the file could not be loaded, the channel set is
    unknown, or too many bad channels were found.
    """
    channel_set = _resolve_channel_set(channel_set)
    if channel_set is None:
        # Previously execution continued with the string itself as channel set,
        # giving substring matching and a wrong channel count.
        print("Predefined channel set given by 'channel_set' not known...")
        return None, None, None

    # Import file
    try:
        data_raw = mne.io.read_raw_cnt(file, eog='auto', preload=True, verbose=False)
    except ValueError:
        print("ValueError")
        print("Could not load file:", file)
        return None, None, None

    # Band-pass filter between lpass and hpass Hz (defaults 0.5-40 Hz; was 0.5 to 30 Hz in Stober 2016).
    data_raw.filter(lpass, hpass, fir_design='firwin')

    # Get events from annotations in the data
    events_from_annot, event_dict = mne.events_from_annotations(data_raw)

    # Baseline from the first instant of the epoch to t = 0
    baseline = (None, 0)

    # Exclude every channel outside the selected set, except the EOG channels.
    channels_exclude = [x for x in data_raw.ch_names
                        if x not in channel_set and x not in ['HEOG', 'VEOG']]

    # EEG channel selection does not depend on the event id; compute it once.
    # The exclude list is explicit, so later changes to info['bads'] do not alter it.
    picks = mne.pick_types(data_raw.info, meg=False, eeg=True, stim=False, eog=False,
                           include=channel_set, exclude=channels_exclude)

    signal_blocks = []
    label_collection = []
    channel_names_collection = []

    for event_id in event_idx:
        if str(event_id) not in event_dict:
            continue

        epochs = _make_epochs(data_raw, events_from_annot, event_dict, picks,
                              tmin, tmax, baseline)

        # Detect potential bad channels and epochs.
        # NOTE(review): helper_functions is not imported in the visible part of this
        # file; it must be available in the notebook namespace.
        bad_channels, bad_epochs = helper_functions.select_bad_epochs(epochs,
                                                                      event_id,
                                                                      threshold = threshold,
                                                                      max_bad_fraction = max_bad_fraction)

        # Interpolate bad channels
        # ------------------------------------------------------------------
        if len(bad_channels) > 0:
            if len(bad_channels) > max_bad_channels:
                print(20*'--')
                print("Found too many bad channels (" + str(len(bad_channels)) + ")")
                return None, None, None

            montage = mne.channels.make_standard_montage('standard_1020')
            # Upper-case montage names so they match the upper-case channel names used here.
            montage.ch_names = [ch_name.upper() for ch_name in montage.ch_names]
            data_raw.set_montage(montage)

            # MARK: Think about using all channels before removing (62 -> 30), to enable for better interpolation

            # Mark bad channels, re-epoch so the marks are carried into the epochs,
            # then interpolate them using functionality of 'mne'. Repeated events are
            # merged here as well, matching the first epoching above.
            data_raw.info['bads'] = bad_channels
            epochs = _make_epochs(data_raw, events_from_annot, event_dict, picks,
                                  tmin, tmax, baseline)
            epochs.interpolate_bads()

        # Get signals as array and add to total collection
        channel_names_collection.append(epochs.ch_names)
        signals_cleaned = epochs[str(event_id)].drop(bad_epochs).get_data()
        signal_blocks.append(signals_cleaned)
        label_collection += [event_id + label_group] * signals_cleaned.shape[0]

    if signal_blocks:
        # Concatenate once instead of growing a pre-shaped buffer whose hard-coded
        # sample count only matched one sampling rate / epoch window.
        signal_collection = np.concatenate(signal_blocks, axis=0)
    else:
        sfreq = data_raw.info['sfreq']
        n_times = int(round(tmax * sfreq)) - int(round(tmin * sfreq)) + 1
        signal_collection = np.zeros((0, len(channel_set), n_times))

    return signal_collection, label_collection, channel_names_collection