In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import mne
import os
import re
import torch
import pickle

from utils import *

In [2]:
# Maps each task name to the annotation code that marks it in the recordings
# and to the run numbers (two-digit strings) in which that task occurs.
# T0 appears in every task run; T1 and T2 change meaning depending on the run.
task_dict = {
    'Resting state': {
        'annotation': 'T0',
        'runs': ['03', '04', '05', '06', '07', '08', '09', '10', '11', '12', '13', '14'],
    },
    # NOTE: the comma between '07' and '11' matters -- without it Python
    # silently concatenates the adjacent literals into the single run '0711'.
    'Left fist, performed': {'annotation': 'T1', 'runs': ['03', '07', '11']},
    'Left fist, imagined': {'annotation': 'T1', 'runs': ['04', '08', '12']},
    'Both fists, performed': {'annotation': 'T1', 'runs': ['05', '09', '13']},
    'Both fists, imagined': {'annotation': 'T1', 'runs': ['06', '10', '14']},
    'Right fist, performed': {'annotation': 'T2', 'runs': ['03', '07', '11']},
    'Right fist, imagined': {'annotation': 'T2', 'runs': ['04', '08', '12']},
    'Both feet, performed': {'annotation': 'T2', 'runs': ['05', '09', '13']},
    'Both feet, imagined': {'annotation': 'T2', 'runs': ['06', '10', '14']},
}


# Reverse lookup: label_dict[annotation][run] -> task name.
# Runs 03..14 cycle through the same four task variants, so each annotation
# is described by its cycle of task names; T0 has a one-element cycle because
# it means the same thing in every run.
label_dict = {
    annotation: {
        '{:02d}'.format(run_number): cycle[(run_number - 3) % len(cycle)]
        for run_number in range(3, 15)
    }
    for annotation, cycle in (
        ('T0', ('Resting state',)),
        ('T1', (
            'Left fist, performed',
            'Left fist, imagined',
            'Both fists, performed',
            'Both fists, imagined',
        )),
        ('T2', (
            'Right fist, performed',
            'Right fist, imagined',
            'Both feet, performed',
            'Both feet, imagined',
        )),
    )
}

In [7]:
def get_raw(edf_file_path):
    """Load one EDF recording and return it as a preprocessed MNE Raw object.

    The steps run in this order:
      1. read the EDF file with the data preloaded;
      2. standardize channel names with ``mne.datasets.eegbci.standardize``
         and respell them so the 'standard_1020' montage can be applied;
      3. add an average-reference projection, set the montage, apply the
         projection;
      4. rename P7/P8 to T5/T6 (when present) and upper-case every name;
      5. keep the 19 channels listed below, in that order;
      6. resample to 256 Hz, band-pass 0.1-100 Hz, notch at 60 Hz.

    Parameters
    ----------
    edf_file_path : path to the ``.edf`` file to load.

    Returns
    -------
    The Raw object produced by ``mne.io.read_raw_edf``, modified in place by
    the steps above.
    """
    # Channels kept in the output, in their final order (19 names).
    kept_channels = [
        'FP1', 'FP2',
        'F7', 'F3', 'FZ', 'F4', 'F8',
        'T7', 'C3', 'CZ', 'C4', 'T8',
        'T5', 'P3', 'PZ', 'P4', 'T6',
        'O1', 'O2',
    ]

    recording = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
    mne.datasets.eegbci.standardize(recording)

    montage_1020 = mne.channels.make_standard_montage('standard_1020')

    # Strip trailing dots and upper-case, then restore the lower-case 'z'
    # and the 'Fp' prefix before the montage is set.
    recording.rename_channels({
        name: name.rstrip('.').upper().replace('Z', 'z').replace('FP', 'Fp')
        for name in recording.ch_names
    })

    # Average reference is added as a projection, the montage is attached,
    # and only then is the projection applied to the data.
    recording.set_eeg_reference(ref_channels='average', projection=True, verbose=False)
    recording.set_montage(montage_1020)
    recording.apply_proj(verbose=False)

    # Switch P7/P8 to the T5/T6 names used in the channel list above.
    for old_name, new_name in (('P7', 'T5'), ('P8', 'T6')):
        if old_name in recording.ch_names:
            recording.rename_channels({old_name: new_name})

    # Final naming pass: everything upper-case, no trailing dots.
    recording.rename_channels(
        {name: name.rstrip('.').upper() for name in recording.ch_names}
    )

    # Drop every other channel, then force the documented order.
    recording.pick_channels(kept_channels)
    recording.reorder_channels(kept_channels)

    # Resample first, then filter at the new rate.
    recording.resample(256)
    recording.filter(0.1, 100, verbose=False)
    recording.notch_filter(60, verbose=False)

    return recording

def normalize(x, minimum=-0.00125, maximum=0.00125):
    """Linearly rescale ``x`` so that ``minimum`` maps to -1 and ``maximum`` to +1.

    Works on scalars and on anything supporting element-wise arithmetic
    (e.g. NumPy arrays). Values outside ``[minimum, maximum]`` are not
    clipped; they simply land outside ``[-1, 1]``.
    """
    span = maximum - minimum
    # Position of x inside the interval, 0 at ``minimum`` and 1 at ``maximum``.
    unit_position = (x - minimum) / span
    return unit_position * 2 - 1

In [8]:
# Root folders of the raw recordings and of the labelled output.
# Both values end with a slash so that plain string concatenation onto them
# keeps producing valid paths.
DATA_PATH_RAW = '/home/williamtheodor/Documents/DL for EEG Classification/data/eegmmidb (raw)/files/'
DATA_PATH_LABELLED = '/home/williamtheodor/Documents/DL for EEG Classification/data/eegmmidb (labelled)/'


# Create one output folder per task label. ``exist_ok=True`` makes the cell
# safe to re-run and avoids the check-then-create race of a separate
# ``os.path.exists`` test.
for label in task_dict:
    os.makedirs(os.path.join(DATA_PATH_LABELLED, label), exist_ok=True)

##

In [9]:
# Shape of every saved sample: NUMBER_CHANNELS rows of NUMBER_SAMPLES values.
# The first NUMBER_CHANNELS - 1 rows hold the window's channel data; the last
# row is a constant channel filled with -1.
NUMBER_SAMPLES = 1024
NUMBER_CHANNELS = 20

patients_to_exclude = ['S088', 'S089', 'S090', 'S092', 'S104', 'S106']

# Sorted listings make the processing order reproducible across machines.
for patient_folder in tqdm(sorted(os.listdir(DATA_PATH_RAW))):
    patient_path = os.path.join(DATA_PATH_RAW, patient_folder)

    # Skip excluded patients and any stray non-directory entry in the raw
    # folder (listing such an entry would raise NotADirectoryError).
    if patient_folder in patients_to_exclude or not os.path.isdir(patient_path):
        continue

    for file in sorted(os.listdir(patient_path)):
        # Characters 5-6 of the file name are used as the two-digit run
        # number; runs 01 and 02 are not processed.
        run = file[5:7]
        if not file.endswith('.edf') or run in ['01', '02']:
            continue

        file_path = os.path.join(patient_path, file)

        raw = get_raw(file_path)
        annotations = get_annotations(file_path)

        # Iterated below as: annotation code -> iterable of windows, each
        # window exposing ``.copy().get_data()``.
        annotation_dict = get_window_dict(raw, annotations)

        for annotation, windows in annotation_dict.items():
            # ``window`` (not ``raw``) so the full recording is not shadowed.
            for window_index, window in enumerate(windows):
                x = np.zeros((1, NUMBER_CHANNELS, NUMBER_SAMPLES))

                # Keep the first NUMBER_SAMPLES samples of the window. The
                # reshape raises ValueError if the window is shorter than
                # that or does not have NUMBER_CHANNELS - 1 channels.
                x[:, :NUMBER_CHANNELS - 1, :] = (
                    window.copy().get_data()[:, :NUMBER_SAMPLES]
                    .reshape(1, NUMBER_CHANNELS - 1, NUMBER_SAMPLES)
                )

                # normalize the data
                x = normalize(x)

                # The extra last channel is written after normalization so
                # it stays exactly -1.
                x[:, NUMBER_CHANNELS - 1, :] = -1
                x = torch.from_numpy(x).float()

                label = label_dict[annotation][run]

                # The window index is part of the file name: without it every
                # window of the same annotation in the same run would be
                # written to the same path and only the last one would survive.
                pickle_name = '{}_{}_{:03d}_{}.pkl'.format(
                    file[0:7], annotation, window_index, label
                )
                pickle_path = os.path.join(DATA_PATH_LABELLED, label, pickle_name)
                with open(pickle_path, 'wb') as handle:
                    pickle.dump(x, handle, protocol=pickle.HIGHEST_PROTOCOL)
                    

  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
  raw = mne.io.read_raw_edf(edf_file_path, verbose=False, preload=True)
100%|██████████| 109/109 [05:45<00:00,  3.17s/it]
