In [None]:
# %matplotlib inline


# Download, Load, Preprocess and Save the TUH and NMT EEG Corpora

In this project, we investigate scaling laws for transfer learning in normal/abnormal classification on the TUH and NMT EEG corpora.


In [None]:
# Loading Libraries
from collections import Counter
import numpy as np
# import torch
import mne
import braindecode
from torch.utils.data import DataLoader

from braindecode.preprocessing import (
    preprocess, Preprocessor, create_fixed_length_windows, scale as multiply)
mne.set_log_level('ERROR')  # avoid messages everytime a window is extracted

from braindecode.datasets.tuh import TUHAbnormal
from nmt import NMT

We first download the datasets.
Uncomment the cell below to run the download.

In [None]:
# # download datasets
# ## TUH
# !rsync -auxvL --delete nedc-eeg@www.isip.piconepress.com:data/eeg/tuh_eeg_abnormal/ ~/scratch/tuab/ #you need the password see https://isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml
# ## NMT
# !wget https://chatbotmart.com/datasets/nmt_scalp_eeg_dataset.zip
# !unzip nmt_scalp_eeg_dataset.zip

We start by loading TUH and NMT datasets.

In [None]:
import os

# Raw (downloaded) and preprocessed locations of both corpora.
# The leading '~' is expanded here because neither `os.mkdir` nor `glob`
# expand it; an unexpanded path would point to a literal '~' directory.
TUH_PATH = os.path.expanduser('~/scratch/medical/eeg/tuab/tuab_org/')
TUH_PATH_pp = os.path.expanduser('~/scratch/medical/eeg/tuab/tuab_pp3')

NMT_PATH = os.path.expanduser('~/scratch/medical/eeg/NMT/nmt_scalp_eeg_dataset/')
NMT_PATH_pp = os.path.expanduser('~/scratch/medical/eeg/NMT/nmt_pp3')

N_JOBS = 4  # specify the number of jobs for loading and windowing

Select the dataset to preprocess (`'tuh'` or `'nmt'`).

In [None]:
dataset = 'tuh'  # one of 'tuh' or 'nmt'

# Both corpora are opened lazily (preload=False) with the same targets;
# `selected_ds` and `PATH_pp` are what the rest of the notebook uses.
if dataset == 'tuh':
    tuh_ds = TUHAbnormal(
        TUH_PATH,
        target_name=('pathological', 'age', 'gender'),
        # recording_ids=range(100),  # or None to load the whole dataset
        preload=False,
        n_jobs=N_JOBS,
    )
    selected_ds = tuh_ds
    PATH_pp = TUH_PATH_pp
elif dataset == 'nmt':
    nmt_ds = NMT(
        NMT_PATH,
        target_name=('pathological', 'age', 'gender'),
        # recording_ids=range(100, 200),  # or None to load the whole dataset
        preload=False,
        n_jobs=N_JOBS,
    )
    selected_ds = nmt_ds
    PATH_pp = NMT_PATH_pp
else:
    # Fail fast: otherwise `selected_ds` / `PATH_pp` would be undefined (or
    # stale from an earlier run) and later cells would fail confusingly.
    raise ValueError(f"dataset must be 'tuh' or 'nmt', got {dataset!r}")

In [None]:
# Metadata (description table) of the selected dataset.
selected_ds.description

Indexing the dataset gives `x` as an ndarray of shape (n_channels, 1) as well as
the targets [pathological, age, gender] of the subject. Let's look at the last example,
as it has more interesting age/gender labels (compare it to the last row of the dataframe above).



In [None]:
# Look at the last (input, target) pair of the dataset.
x, y = selected_ds[-1]
print(f'x: {x.shape}')
print(f'y: {y}')

In [None]:
# Raw object of the first recording (not yet preprocessed).
raw = selected_ds.datasets[0].raw


In [None]:
# Power spectral density of that unprocessed recording.
# NOTE(review): `plot_psd` is a legacy MNE API; recent MNE recommends
# `raw.compute_psd().plot()` instead.
raw.plot_psd()

Next, we will perform some preprocessing steps. First, we select recordings
based on their duration, keeping only those that are at least five minutes
long. Data is not loaded here.
Then we apply some basic preprocessing steps to both datasets.

In [None]:
def select_by_duration(ds, tmin=0, tmax=None):
    """Keep only the recordings whose duration lies within [tmin, tmax].

    Parameters
    ----------
    ds : BaseConcatDataset
        Dataset whose ``datasets`` each expose an MNE ``raw``; the duration is
        computed from ``raw.n_times`` and ``raw.info['sfreq']``.
    tmin : float
        Minimum duration in seconds (inclusive).
    tmax : float | None
        Maximum duration in seconds (inclusive); ``None`` means no upper bound.

    Returns
    -------
    BaseConcatDataset
        The subset of ``ds`` containing only the selected recordings.

    Raises
    ------
    ValueError
        If no recording satisfies the duration constraint.
    """
    if tmax is None:
        tmax = np.inf
    # duration in seconds = number of samples / sampling frequency
    split_ids = [
        d_i for d_i, d in enumerate(ds.datasets)
        if tmin <= d.raw.n_times / d.raw.info['sfreq'] <= tmax
    ]
    if not split_ids:
        # An empty selection cannot be turned into a dataset by `ds.split`;
        # report the actual cause instead of an opaque downstream error.
        raise ValueError(
            f'No recording has a duration between {tmin} s and {tmax} s.')
    return ds.split(split_ids)['0']

In [None]:
# Keep recordings that are at least five minutes long (no upper bound).
tmin, tmax = 5 * 60, None

n_before = len(selected_ds)
print(n_before)
selected_ds = select_by_duration(selected_ds, tmin=tmin, tmax=tmax)
print(len(selected_ds))

In [None]:
# TUH-specific channel handling.
# Short (target) channel names, kept in alphabetical order.
short_ch_names = sorted([
    'A1', 'A2',
    'FP1', 'FP2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
    'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'FZ', 'CZ', 'PZ'])

# TUH names the same electrodes 'EEG <name>-REF' ('ar' recordings) or
# 'EEG <name>-LE' ('le' recordings). Deriving the long names from the short
# ones keeps the pairs aligned by construction; the resulting lists are also
# alphabetical because no short name is a prefix of another.
ar_ch_names = [f'EEG {name}-REF' for name in short_ch_names]
le_ch_names = [f'EEG {name}-LE' for name in short_ch_names]
assert len(short_ch_names) == len(ar_ch_names) == len(le_ch_names)

# long name -> short name, one mapping per reference type
ar_ch_mapping = dict(zip(ar_ch_names, short_ch_names))
le_ch_mapping = dict(zip(le_ch_names, short_ch_names))
ch_mapping = {'ar': ar_ch_mapping, 'le': le_ch_mapping}


def select_by_channels(ds, ch_mapping):
    """Keep only the recordings that contain every channel of ``ch_mapping``.

    The reference type of a recording is read from its first channel name:
    a '-REF' suffix selects ``ch_mapping['ar']``, anything else
    ``ch_mapping['le']``.

    Parameters
    ----------
    ds : BaseConcatDataset
        Dataset whose ``datasets`` each expose an MNE ``raw``.
    ch_mapping : dict
        ``{'ar': {long_name: short_name}, 'le': {long_name: short_name}}``;
        only the keys (long names) are used here.

    Returns
    -------
    BaseConcatDataset
        The subset of ``ds`` containing only the selected recordings.

    Raises
    ------
    ValueError
        If no recording contains all required channels.
    """
    split_ids = []
    for i, d in enumerate(ds.datasets):
        # channels of this recoding
        rec_ch_names = d.raw.ch_names
        ref = 'ar' if rec_ch_names[0].endswith('-REF') else 'le'
        # include the recording only if all wanted channels are present
        if set(ch_mapping[ref]).issubset(rec_ch_names):
            split_ids.append(i)
    if not split_ids:
        # An empty selection cannot be turned into a dataset by `ds.split`.
        raise ValueError('No recording contains all required channels.')
    return ds.split(split_ids)['0']

# Channel-based selection only applies to TUH, whose channel names carry a
# reference suffix ('-REF' / '-LE'); NMT recordings are not filtered here.
if dataset=='tuh':
    selected_ds = select_by_channels(selected_ds, ch_mapping)

In [None]:
import asrpy

def custom_crop(raw, tmin=0.0, tmax=None, include_tmax=True):
    """Crop ``raw`` in place to [tmin, tmax], clamping tmax to the recording end.

    By default mne's ``crop`` fails if ``tmax`` is larger than the recording
    duration, so ``tmax`` is limited to the time of the last sample; the crop
    can therefore be shorter than requested for short recordings.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to crop (modified in place).
    tmin : float
        Start time in seconds.
    tmax : float | None
        End time in seconds; ``None`` means "up to the end of the recording".
    include_tmax : bool
        Forwarded to ``raw.crop``.
    """
    # time (s) of the last sample of the recording
    last_time = (raw.n_times - 1) / raw.info['sfreq']
    # min(last_time, None) would raise TypeError, so handle None explicitly
    tmax = last_time if tmax is None else min(last_time, tmax)
    raw.crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax)

def custom_rename_channels(raw, mapping):
    """Rename TUH channels to short names using the reference-specific mapping.

    TUH channel names depend on the referencing (le: 'EEG 01-LE',
    ar: 'EEG 01-REF'), and mne fails if the mapping contains keys that are
    not channels of ``raw``, so the mapping matching this recording's
    reference is chosen from its first channel name. Recordings whose first
    channel name does not contain 'EEG' are left untouched (just for TUH).

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose channels are renamed in place.
    mapping : dict
        ``{'le': {...}, 'ar': {...}}`` mapping long names to short names.
    """
    first_ch = raw.ch_names[0]
    if 'EEG' not in first_ch:
        return
    suffix = first_ch.split('-')[-1].lower()
    assert suffix in ['le', 'ref'], 'unexpected referencing'
    raw.rename_channels(mapping['le' if suffix == 'le' else 'ar'])

def custom_reset_date(raw):
    """Print the measurement date of ``raw``, then anonymize it in place.

    Works around the error: info["meas_date"] seconds must be between
    "(-2147483648, 0)" and "(2147483647, 0)".
    """
    original_date = raw.info["meas_date"]
    print(original_date)
    raw.anonymize()

def apply_asr(raw):
    """High-pass filter ``raw`` at 1 Hz and clean it with ASR, in place.

    The cleaned signal is written back into ``raw`` itself: rebinding the
    local name ``raw`` to the output of ``asr.transform`` would leave the
    caller's object unchanged, i.e. the ASR result would be discarded.
    Errors are reported and swallowed (best effort), so a recording on which
    ASR fails keeps whatever processing already happened (e.g. the filter).

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to clean; it is loaded into memory first.
    """
    try:
        raw.load_data()
        # 1 Hz high-pass only (h_freq=None: no low-pass is applied here)
        raw.filter(l_freq=1., h_freq=None, fir_design='firwin',
                   skip_by_annotation='edge')
        # run ASR, calibrated on this recording
        asr = asrpy.ASR(sfreq=raw.info["sfreq"], cutoff=5)
        asr.fit(raw.copy())
        cleaned = asr.transform(raw.copy())
        # accept either a Raw-like result or a plain array of EEG data
        if hasattr(cleaned, 'get_data'):
            cleaned_data = np.asarray(cleaned.get_data(picks='eeg'))
        else:
            cleaned_data = np.asarray(cleaned)
        # copy the cleaned EEG data back into `raw` (shapes must match)
        raw.apply_function(lambda _: cleaned_data, picks='eeg',
                           channel_wise=False)
    except Exception as err:  # best effort: keep going on failure
        print(f'Could not apply the ASR: {err}')
 
# one recording assumed to be equal to one subject (TUAB has multiple sessions per subject)
def normalize_one_recording_channel_wise(clean_eeg_data):
    """Clip every channel to mean +/- 2 std, then z-score it, in place.

    Statistics are computed per channel over the whole recording. The
    number of channels is taken from the array itself rather than from a
    global channel list.

    Parameters
    ----------
    clean_eeg_data : np.ndarray
        2-D float array with one row per channel and time along axis 1
        (presumably (n_channels, n_times), as passed by braindecode's
        ``Preprocessor(..., apply_on_array=True)`` — TODO confirm).

    Returns
    -------
    np.ndarray
        ``clean_eeg_data`` itself, modified in place. Channels that are
        constant after clipping (zero std) become all zeros instead of NaN.
    """
    data = clean_eeg_data

    # 1) clip each channel to 2 stds around its own mean
    means = data.mean(axis=1, keepdims=True)
    stds = data.std(axis=1, keepdims=True)
    np.clip(data, means - 2 * stds, means + 2 * stds, out=data)

    # 2) z-score each channel using the statistics of the clipped signal
    clipped_means = data.mean(axis=1, keepdims=True)
    clipped_stds = data.std(axis=1, keepdims=True)
    # flat channels: divide by 1 instead of 0 so they end up as zeros, not NaN
    safe_stds = np.where(clipped_stds > 0, clipped_stds, 1.0)
    data -= clipped_means
    data /= safe_stds
    return data

# Crop window (seconds) and target sampling frequency (Hz) of the pipeline.
# NOTE: `tmin` / `tmax` reuse the names of the duration-selection cell.
tmin = 1 * 60
tmax = 6 * 60
sfreq = 100

preprocessors = [
    # keep seconds 60-360 of each recording (shorter ones end earlier)
    Preprocessor(custom_crop, tmin=tmin, tmax=tmax, include_tmax=False,
                 apply_on_array=False),
    # TUH: rename 'EEG XX-REF' / 'EEG XX-LE' channels to their short names
    Preprocessor(custom_rename_channels, mapping=ch_mapping,
                 apply_on_array=False),
    # keep the common channel set, in a fixed order
    Preprocessor('pick_channels', ch_names=short_ch_names, ordered=True),
    # MNE stores EEG in volts; scale to microvolts. The same factor is used
    # for both corpora (TODO confirm NMT does not need a different one).
    Preprocessor(multiply, factor=1e6, apply_on_array=True),
    Preprocessor(custom_reset_date, apply_on_array=False),
    # clip extreme amplitudes (presumably in microvolts after the scaling)
    Preprocessor(np.clip, a_min=-800, a_max=800, apply_on_array=True),
    Preprocessor('resample', sfreq=sfreq),
    Preprocessor(apply_asr, apply_on_array=False),
    Preprocessor('set_eeg_reference', ref_channels='average', ch_type='eeg'),
    Preprocessor(normalize_one_recording_channel_wise, apply_on_array=True),
]

In [None]:
## creating ds_pp folder
import os
import shutil

def create_pp_folder(dir_name):
    """Create an empty output directory, deleting any existing one first.

    WARNING: if ``dir_name`` already exists it is removed recursively,
    together with all of its contents.

    Parameters
    ----------
    dir_name : str | os.PathLike
        Target directory. A leading '~' is expanded to the user's home
        directory (``os.mkdir`` does not do this), and missing parent
        directories are created.
    """
    dir_name = os.path.expanduser(dir_name)
    if os.path.exists(dir_name):
        shutil.rmtree(dir_name)
    # makedirs also creates missing parents, unlike os.mkdir
    os.makedirs(dir_name)

# Wipe and recreate the output folder of the selected dataset.
create_pp_folder(PATH_pp)

In [None]:
# Apply the preprocessing pipeline to the selected recordings and save the
# processed recordings under PATH_pp (overwrite=True allows replacing
# previously saved files).
selected_preproc = preprocess(
    concat_ds=selected_ds,
    preprocessors=preprocessors,
    n_jobs=N_JOBS,
    save_dir=PATH_pp,
    overwrite=True,
)

Our datasets are preprocessed and saved to the given directories. Now we can move on to the training notebook.

Visualization

In [None]:
# Reached only if `preprocess` did not raise; ASR failures inside
# `apply_asr` are caught and only printed, so check the log above as well.
print(f'dataset: {dataset} preprocessing was successful')

In [None]:
# Pick one preprocessed recording to inspect (index 90 is arbitrary and must
# be smaller than the number of preprocessed recordings).
raw = selected_preproc.datasets[90].raw

In [None]:
# Metadata of the inspected recording.
selected_preproc.datasets[90].description

In [None]:
np.max(raw.get_data())  # largest value of the preprocessed recording

In [None]:
np.min(raw.get_data())  # smallest value of the preprocessed recording

In [None]:
# Browse the traces; 'auto' derives display scalings from the data itself.
raw.plot(scalings='auto')

In [None]:
# PSD of the preprocessed recording up to 50 Hz (the Nyquist frequency after
# resampling to sfreq=100).
raw.plot_psd(fmax=50)

In [None]:
# Value range and traces of another preprocessed recording.
# The data were z-scored (unitless), so let MNE derive the display scalings
# from the data ('auto'), as in the plot above.
raw2 = selected_preproc.datasets[5].raw
print(raw2.get_data().max())
print(raw2.get_data().min())
raw2.plot(scalings='auto')

In [None]:
# Value range and traces of a third preprocessed recording.
# The data were z-scored (unitless), so let MNE derive the display scalings
# from the data ('auto').
raw3 = selected_preproc.datasets[92].raw
print(raw3.get_data().max())
print(raw3.get_data().min())
raw3.plot(scalings='auto')

In [None]:
from braindecode.datautil.serialization import load_concat_dataset
from braindecode.datasets import BaseConcatDataset


def _load_preprocessed(path):
    """Lazily load the first 800 preprocessed recordings saved under ``path``,
    with 'gender' as the target."""
    return load_concat_dataset(
        path,
        preload=False,
        target_name=['gender'],
        ids_to_load=range(800),
    )


# Combine the preprocessed TUH and NMT recordings into a single dataset.
tuh_part = _load_preprocessed(TUH_PATH_pp)
ds_all2 = _load_preprocessed(NMT_PATH_pp)  # NMT part
ds_all = BaseConcatDataset([tuh_part, ds_all2])

In [None]:
# An (input, target) pair from near the end of the combined dataset.
x, y = ds_all[-50]
print(f'x: {x.shape}')
print(f'y: {y}')

In [None]:
# Metadata of the combined dataset, with gender encoded numerically:
# 'M' -> 0, 'F' -> 1, any other value -> 10 (used as an "unknown" marker).
# NOTE(review): a later cell overwrites `gender_bool` with the opposite
# polarity ('M' -> True, 'F' -> False); settle on one encoding.
df = ds_all.description
df = df.assign(gender_bool=df['gender'].map(lambda x: 0 if x == 'M' else 1 if x == 'F' else 10))


In [None]:
# Define a function that converts 'M' and 'F' to True and False
def convert_gender(gender):
    """Map a gender code to a boolean: 'M' -> True, 'F' -> False.

    Any other value (e.g. a missing/NaN entry) is reported and mapped to
    ``None``. The value is converted with ``str`` before being put into the
    message, so non-string values such as NaN do not raise ``TypeError``.
    """
    if gender == 'M':
        return True
    if gender == 'F':
        return False
    print("Invalid gender: " + str(gender))
    return None

# Apply the function to the column and assign it to a new column
# Apply the function to the column and assign it to a new column
# (this replaces the numeric `gender_bool` column created in the previous
# cell, with 'M' -> True and 'F' -> False).
df = df.assign(gender_bool=df['gender'].map(convert_gender))

# Display the first five rows of the dataframe
df.head()


In [None]:
# Number of rows per 'gender' value in the combined description.
df.value_counts('gender')