****Sample Notebook to get started with the EEG dataset****

In [2]:
!python3 -m pip install -r requirements.txt



In [3]:
# Importing libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mne
import os
from tqdm import tqdm
import configparser

In [4]:
# Make the dataset class
class BhutanDataSet():
    """
    Defines the BhutanDataSet class. This class is used to load the Bhutan EEG dataset and preprocess it for
    later use in the BENDR feature representation generation. It is loaded as a raw .edf file and the preprocessed 
    according to Makoto's pipeline.
    """
    def __init__(self, path, subject, session, event, event_id, tmin, tmax, baseline, filter):
        """
        Initializes the BhutanDataSet class.

        Args:
            path (str): The path to the raw .edf file.
            subject (int): The subject number.
            session (int): The session number.
            event (str): The event to be extracted.
            event_id (int): The event id.
            tmin (float): The minimum time to be extracted.
            tmax (float): The maximum time to be extracted.
            baseline (tuple): The baseline to be used.
            filter (tuple): The filter to be used.

        """
        self.path = path
        self.subject = subject
        self.session = session
        self.event = event
        self.event_id = event_id
        self.tmin = tmin
        self.tmax = tmax
        self.baseline = baseline
        self.filter = filter

        print("Loading the Bhutan EEG dataset for subject {} and session {}...".format(self.subject, self.session))

        self.raw, self.epochs, self.X, self.y = self.call()

    def call(self):
        """
        Calls the BhutanDataSet class.

        Returns:
            raw (mne.io.edf.edf.RawEDF): The raw .edf file.
            epochs (mne.epochs): The preprocessed epochs.
            X (np.array): The data.
            y (np.array): The labels.

        """
        raw = self.load_data()
        # self.visualise_data(raw)
        epochs = self.preprocess_data(raw)
        X, y = self.extract_data(epochs)
        return raw, epochs, X, y

    def load_data(self):
        """
        Loads the raw .edf file.

        Returns:
            raw (mne.io.edf.edf.RawEDF): The raw .edf file.

        """
        raw = mne.io.read_raw_edf(self.path, infer_types=True, verbose=True)
        # Print the number of channels and the channel names
        print("The number of channels is: ", len(raw.info['ch_names']))
        print("The channel names are: ", raw.info['ch_names'])
        raw.drop_channels(['R1', 'R2', 'TIP', 'GROUND', 'REF']) # BAD PRACTICE - REMOVE
        return raw
    
    def visualise_data(self, raw):
        """
        Visualises the raw .edf file.

        Args:
            raw (mne.io.edf.edf.RawEDF): The raw .edf file.

        """
        raw.plot()
    
    def preprocess_data(self, raw):
        """
        Preprocesses the raw .edf file.

        Args:
            raw (mne.io.edf.edf.RawEDF): The raw .edf file.

        Returns:
            epochs (mne.epochs): The preprocessed epochs.

        """
        # Set the montage
        montage = mne.channels.make_standard_montage('standard_1020')
        # for montage in mne.channels.get_builtin_montages():
        #     try:
        #         raw.set_montage(montage, match_case=False)
        #         print('Set montage to', montage)
        #         break
        #     except Exception as e:
        #         print(e, 'for', montage)
        raw.set_montage(montage, match_case=False)
        # Load data into memory
        raw.load_data()
        # Filter the data
        raw.filter(self.filter[0], self.filter[1], fir_design='firwin')
        # Extract events
        events, event_id = mne.events_from_annotations(raw)
        # Extract epochs
        # epochs = mne.Epochs(raw, events, event_id, tmin=self.tmin, tmax=self.tmax, baseline=self.baseline, preload=True)
        epochs = mne.Epochs(raw, events, baseline=self.baseline, preload=True)
        return epochs
    
    def extract_data(self, epochs):
        """
        Extracts the data from the epochs.

        Args:
            epochs (mne.epochs): The preprocessed epochs.

        Returns:
            X (np.array): The data.
            y (np.array): The labels.

        """
        X = epochs.get_data()
        y = epochs.events[:, -1]
        return X, y
    
class MMIDBDataSet():
    """
    Defines the MMIDBDataSet class. This class is used to load the MMIDB EEG dataset and preprocess it for
    later use in the BENDR feature representation generation. It is loaded as a raw .edf file and the preprocessed 
    according to Makoto's pipeline.

    """
    def __init__(self, path, subject, session, event, event_id, tmin, tmax, baseline, filter):
        """
        Initializes the MMIDBDataSet class.

        Args:
            path (str): The path to the raw .edf file.
            subject (int): The subject number.
            session (int): The session number.
            event (str): The event to be extracted.
            event_id (int): The event id.
            tmin (float): The minimum time to be extracted.
            tmax (float): The maximum time to be extracted.
            baseline (tuple): The baseline to be used.
            filter (tuple): The filter to be used.

        """
        self.path = path
        self.subject = subject
        self.session = session
        self.event = event
        self.event_id = event_id
        self.tmin = tmin
        self.tmax = tmax
        self.baseline = baseline
        self.filter = filter

        print("Loading the MMIDB EEG dataset for subject {} and session {}...".format(self.subject, self.session))

        self.raw, self.epochs, self.X, self.y = self.call()

    def call(self):
        """
        Calls the MMIDBDataSet class.

        Returns:
            raw (mne.io.edf.edf.RawEDF): The raw .edf file.
            epochs (mne.epochs): The preprocessed epochs.
            X (np.array): The data.
            y (np.array): The labels.

        """
        raw = self.load_data()
        # self.visualise_data(raw)
        epochs = self.preprocess_data(raw)
        X, y = self.extract_data(epochs)
        return raw, epochs, X, y

    def load_data(self):
        """
        Loads the raw .edf file.

        Returns:
            raw (mne.io.edf.edf.RawEDF): The raw .edf file.

        """
        raw = mne.io.read_raw_edf(self.path, infer_types=True, verbose=True)
        # Print the number of channels and the channel names
        print("The number of channels is: ", len(raw.info['ch_names']))
        print("The channel names are: ", raw.info['ch_names'])

        # Rename channelse, remove all .'s and spaces
        raw.rename_channels(lambda x: x.strip('.').replace(' ', ''))
        return raw
    
    def visualise_data(self, raw):
        """
        Visualises the raw .edf file.

        Args:
            raw (mne.io.edf.edf.RawEDF): The raw .edf file.

        """
        raw.plot()

    def preprocess_data(self, raw):
        """
        Preprocesses the raw .edf file.

        Args:
            raw (mne.io.edf.edf.RawEDF): The raw .edf file.

        Returns:
            epochs (mne.epochs): The preprocessed epochs.

        """
        # Set the montage
        montage = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(montage, match_case=False)
        # Load data into memory
        raw.load_data()
        # Filter the data
        raw.filter(self.filter[0], self.filter[1], fir_design='firwin')
        # Extract events
        events, event_id = mne.events_from_annotations(raw)
        # Extract epochs
        epochs = mne.Epochs(raw, events, event_id, tmin=self.tmin, tmax=self.tmax, baseline=self.baseline, preload=True)
        return epochs
    
    def extract_data(self, epochs):
        """
        Extracts the data from the epochs.

        Args:
            epochs (mne.epochs): The preprocessed epochs.

        Returns:
            X (np.array): The data.
            y (np.array): The labels.

        """
        X = epochs.get_data()
        y = epochs.events[:, -1]
        return X, y
    

    
def load_config(path='config.ini'):
    """
    Loads the configuration file.

    Returns:
        config (configparser.ConfigParser): The configuration file.

    """
    config = configparser.ConfigParser()
    config.read(path)
    return config

In [5]:
# Load the configuration file
config = load_config()

# Load the Bhutan EEG dataset
bhutan = BhutanDataSet('/Users/benjaminfazal/Downloads/Bhutan Data/test_file.edf', subject=int(config['Bhutan']['subject']), 
                        session=int(config['Bhutan']['session']), event=config['Bhutan']['event'], 
                        event_id=int(config['Bhutan']['event_id']), tmin=int(config['Bhutan']['tmin']), 
                        tmax=int(config['Bhutan']['tmax']), baseline=None, 
                        filter=(float(config['Bhutan']['filter1']), float(config['Bhutan']['filter2'])))

# investigate x and y
print(bhutan.X.shape)

Loading the Bhutan EEG dataset for subject 1 and session 1...
Extracting EDF parameters from /Users/benjaminfazal/Downloads/Bhutan Data/test_file.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
The number of channels is:  32
The channel names are:  ['FP1', 'FP2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T7', 'T8', 'TP7', 'TP8', 'P7', 'P8', 'F9', 'F10', 'T9', 'T10', 'P9', 'P10', 'Fz', 'Cz', 'Pz', 'R1', 'R2', 'TIP', 'GROUND', 'REF']
Reading 0 ... 686591  =      0.000 ...  2681.996 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 40.00 

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s


Used Annotations descriptions: ['Eye blinking', 'Eye movement left-right', 'Eyes closed', 'Hyperventilation', 'Jaw clenching']
Not setting metadata
5 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 5 events and 180 original time points ...
0 bad epochs dropped
(5, 27, 180)


[Parallel(n_jobs=1)]: Done  27 out of  27 | elapsed:    0.2s finished
  X = epochs.get_data()


In [6]:
# Load the configuration file
config = load_config()

# Load the MMIDB EEG dataset
mmidb = MMIDBDataSet('/Users/benjaminfazal/Downloads/MMIDB/S013R01.edf', subject=int(config['MMIDB']['subject']), 
                        session=int(config['MMIDB']['session']), event=config['MMIDB']['event'], 
                        event_id=int(config['MMIDB']['event_id']), tmin=int(config['MMIDB']['tmin']), 
                        tmax=int(config['MMIDB']['tmax']), baseline=None, 
                        filter=(float(config['MMIDB']['filter1']), float(config['MMIDB']['filter2'])))

# investigate x and y
print(mmidb.X.shape)

Loading the MMIDB EEG dataset for subject 1 and session 1...
Extracting EDF parameters from /Users/benjaminfazal/Downloads/MMIDB/S013R01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
The number of channels is:  64
The channel names are:  ['Fc5.', 'Fc3.', 'Fc1.', 'Fcz.', 'Fc2.', 'Fc4.', 'Fc6.', 'C5..', 'C3..', 'C1..', 'Cz..', 'C2..', 'C4..', 'C6..', 'Cp5.', 'Cp3.', 'Cp1.', 'Cpz.', 'Cp2.', 'Cp4.', 'Cp6.', 'Fp1.', 'Fpz.', 'Fp2.', 'Af7.', 'Af3.', 'Afz.', 'Af4.', 'Af8.', 'F7..', 'F5..', 'F3..', 'F1..', 'Fz..', 'F2..', 'F4..', 'F6..', 'F8..', 'Ft7.', 'Ft8.', 'T7..', 'T8..', 'T9..', 'T10.', 'Tp7.', 'Tp8.', 'P7..', 'P5..', 'P3..', 'P1..', 'Pz..', 'P2..', 'P4..', 'P6..', 'P8..', 'Po7.', 'Po3.', 'Poz.', 'Po4.', 'Po8.', 'O1..', 'Oz..', 'O2..', 'Iz..']
Reading 0 ... 9759  =      0.000 ...    60.994 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-p

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  64 out of  64 | elapsed:    0.0s finished
  X = epochs.get_data()
