# NCH Dataset Preprocessor

## To Run:
- Set `PSG_DIR` to the directory that contains the NCH `.edf` recordings and their `.tsv` annotation files (each recording and its annotation file must sit in the same folder and share the same base name).
- Replace OUT_DIR with an existing directory where the processed files should be saved.


In [1]:
!pip install mne numpy pandas



In [2]:
import glob
import os
from datetime import datetime
from itertools import compress

import numpy as np
import pandas as pd
import mne
from mne import make_fixed_length_events
import xml.etree.ElementTree as ET

In [3]:
# Input directory holding the .edf recordings and their .tsv annotation files.
PSG_DIR = "./data/nch/"
# Output directory for the compressed .npz files written by preprocess().
OUT_DIR = './data/nch/preprocessed'

THRESHOLD = 3  # NOTE(review): not referenced in the visible code
NUM_WORKER = 8  # number of parallel jobs
FREQ = 128.0  # target sampling frequency (Hz) that epochs are resampled to
EPOCH_LENGTH = 30.0  # length of one fixed epoch, in seconds
SN = 3984  # NOTE(review): meaning not visible here; not referenced in the visible code

# Channels every study must provide (upper-cased names). The trailing index is
# the channel's position along axis 1 of the saved data array.
channels = [
    "EOG LOC-M2",  # 0
    "EOG ROC-M1",  # 1
    "EEG C3-M2",  # 2
    "EEG C4-M1",  # 3
    "ECG EKG2-EKG",  # 4

    "RESP PTAF",  # 5
    "RESP AIRFLOW",  # 6
    "RESP THORACIC",  # 7
    "RESP ABDOMINAL",  # 8
    "SPO2",  # 9
    "CAPNO",  # 10
]

# Annotation descriptions counted as apnea (event code 2).
APNEA_EVENT_DICT = {
    "Obstructive Apnea": 2,
    "Central Apnea": 2,
    "Mixed Apnea": 2,
    "apnea": 2,
    "obstructive apnea": 2,
    "central apnea": 2,
    "Apnea": 2,
}

# Annotation descriptions counted as hypopnea (event code 1).
HYPOPNEA_EVENT_DICT = {
    "Obstructive Hypopnea": 1,
    "Hypopnea": 1,
    "hypopnea": 1,
    "Mixed Hypopnea": 1,
    "Central Hypopnea": 1,
}

# Every respiratory ("positive") event: hypopneas first, then apneas. Derived
# from the two dictionaries above so the three can never drift apart.
POS_EVENT_DICT = {**HYPOPNEA_EVENT_DICT, **APNEA_EVENT_DICT}

# Sleep-stage descriptions mapped to event code 0.
NEG_EVENT_DICT = {
    'Sleep stage N1': 0,
    'Sleep stage N2': 0,
    'Sleep stage N3': 0,
    'Sleep stage R': 0,
}

# Wake-stage description; epochs overlapping it are dropped by preprocess().
WAKE_DICT = {
    "Sleep stage W": 10
}

In [7]:
def identity(df):
    """No-op annotation modifier: hand the given object back untouched.

    Serves as the default ``annotation_func`` when the annotation table
    should be used exactly as read from disk.
    """
    return df


def apnea2bad(df):
    """Return a copy of ``df`` where every string cell containing "pnea"
    (e.g. apnea / hypopnea labels) is replaced by ``'badevent'``.

    The input frame is not modified; a progress message is printed.
    """
    relabelled = df.replace(to_replace=r'.*pnea.*', value='badevent', regex=True)
    print("bad replaced!")
    return relabelled


def wake2bad(df):
    """Return a copy of ``df`` with every cell equal to "Sleep stage W"
    replaced by ``'badevent'``."""
    relabelled = df.replace(to_replace="Sleep stage W", value='badevent')
    return relabelled


def change_duration(df, label_dict=POS_EVENT_DICT, duration=EPOCH_LENGTH):
    """Overwrite the ``duration`` of every row whose ``description`` is a key
    of ``label_dict``.

    ``df`` is modified in place and also returned; a progress message is
    printed.
    """
    wanted = list(label_dict)
    if wanted:
        # One boolean mask covering all labels instead of one pass per label.
        hits = df.description.isin(wanted)
        df.loc[hits, 'duration'] = duration
    print("change duration!")
    return df

def load_study_chat(edf_path, annotation_path, annotation_func, preload=False, exclude=None, verbose='CRITICAL'):
    """Load one EDF recording and attach its annotations.

    Parameters
    ----------
    edf_path : str
        Path of the EDF file, passed to ``mne.io.read_raw_edf``.
    annotation_path : str
        Path of the tab-separated annotation file. After ``annotation_func``
        has been applied it must expose ``onset``, ``duration`` and
        ``description`` columns.
    annotation_func : callable
        Receives the annotation DataFrame and returns the (possibly modified)
        DataFrame to use, e.g. ``identity``.
    preload : bool
        Forwarded to ``mne.io.read_raw_edf``.
    exclude : list | None
        Channels to exclude when reading; ``None`` (the default) means an
        empty list.
    verbose : str | bool
        Forwarded to ``mne.io.read_raw_edf``.

    Returns
    -------
    The raw object returned by ``mne.io.read_raw_edf`` with annotations set
    and all channel names upper-cased.
    """
    # A None sentinel avoids sharing one mutable default list between calls.
    if exclude is None:
        exclude = []

    raw = mne.io.read_raw_edf(input_fname=edf_path, exclude=exclude, preload=preload, verbose=verbose)

    df = annotation_func(pd.read_csv(annotation_path, sep='\t'))
    # No orig_time is passed to the annotations.
    raw.set_annotations(mne.Annotations(df.onset, df.duration, df.description))

    # Upper-case channel names so they can be compared with the `channels` list.
    raw.rename_channels({name: name.upper() for name in raw.info['ch_names']})

    return raw

def _event_seconds(raw, event_id, sfreq):
    """Return the set of whole seconds covered by annotations listed in
    ``event_id``; an empty set when the recording has none of them."""
    try:
        events, _ = mne.events_from_annotations(raw, event_id=event_id, chunk_duration=1.0, verbose=None)
    except ValueError:
        # Raised when none of the requested descriptions is present.
        return set()
    return set((events[:, 0] / sfreq).astype(int))


def preprocess(path, annotation_modifier, out_dir):
    """Cut one study into fixed-length epochs, label them and save them.

    Parameters
    ----------
    path : sequence
        ``(edf_path, annotation_path)`` of the study.
    annotation_modifier : callable
        Applied to the annotation DataFrame before use (see
        ``load_study_chat``).
    out_dir : str
        Directory that receives the ``.npz`` file; created if missing.

    Returns
    -------
    int
        Number of epochs saved, or 0 when the study is skipped (required
        channels missing, no apnea/hypopnea annotation at all, or epoching
        failed with an ``AssertionError``).

    Notes
    -----
    The saved archive holds ``data`` (epochs x channels x samples, channels in
    the order of ``channels``), ``labels_apnea`` and ``labels_hypopnea``
    (seconds of the respective event inside each epoch). Epochs overlapping a
    wake annotation are dropped. The file is named
    ``<edf basename>_<total apnea seconds>_<total hypopnea seconds>``, where
    the totals are counted over all epochs, including dropped wake epochs.

    NOTE(review): a study without any "Sleep stage W" annotation still raises
    ``ValueError`` here (unchanged behaviour) - confirm whether such studies
    should be skipped instead.
    """
    print(path)
    edf_path, annotation_path = path[0], path[1]
    study_name = os.path.basename(edf_path)
    raw = load_study_chat(edf_path, annotation_path, annotation_modifier, verbose=True)

    ### Channel Check ###
    channel_present = [name in raw.ch_names for name in channels]
    if not all(channel_present):
        print(channel_present)
        print("study " + study_name + " skipped since insufficient channels")
        return 0

    # Skip studies that carry no respiratory event at all.
    try:
        mne.events_from_annotations(raw, event_id=POS_EVENT_DICT, chunk_duration=1.0, verbose=None)
    except ValueError as e:
        print(str(e))
        print("No Chunk found!")
        return 0

    print(str(datetime.now().time().strftime("%H:%M:%S")) + ' --- Processing %s' % study_name)

    # Original sampling rate: all event sample indices below refer to it.
    sfreq = raw.info['sfreq']

    # Seconds (since recording start) that are covered by each event type.
    apnea_events_set = _event_seconds(raw, APNEA_EVENT_DICT, sfreq)
    hypopnea_events_set = _event_seconds(raw, HYPOPNEA_EVENT_DICT, sfreq)
    wake_events, _ = mne.events_from_annotations(raw, event_id=WAKE_DICT, chunk_duration=1.0, verbose=None)
    wake_events_set = set((wake_events[:, 0] / sfreq).astype(int))

    ### Epoching ###
    tmax = EPOCH_LENGTH - 1. / sfreq  # last sample of an epoch, end-exclusive

    raw = raw.pick_channels(channels, ordered=True)
    fixed_events = make_fixed_length_events(raw, id=0, duration=EPOCH_LENGTH, overlap=0.)

    try:
        # preload=True already loads the data, no separate load_data() needed.
        epochs = mne.Epochs(raw, fixed_events, event_id=[0], tmin=0, tmax=tmax, baseline=None, preload=True,
                            proj=False, verbose=None)
    except AssertionError:
        print("study " + study_name + " skipped since epoching failed")
        return 0

    if sfreq != FREQ:
        epochs = epochs.resample(FREQ, npad='auto', n_jobs=NUM_WORKER, verbose=None)
    data = epochs.get_data()

    ### Labelling ###
    starts = (epochs.events[:, 0] / sfreq).astype(int)
    epoch_seconds = int(EPOCH_LENGTH)

    labels_apnea = []
    labels_hypopnea = []
    labels_not_awake = []
    for seq in range(data.shape[0]):
        epoch_set = set(range(starts[seq], starts[seq] + epoch_seconds))
        labels_apnea.append(len(apnea_events_set.intersection(epoch_set)))
        labels_hypopnea.append(len(hypopnea_events_set.intersection(epoch_set)))
        labels_not_awake.append(len(wake_events_set.intersection(epoch_set)) == 0)

    # Totals are taken before the wake epochs are removed.
    total_apnea_event_second = sum(labels_apnea)
    total_hypopnea_event_second = sum(labels_hypopnea)

    ### Drop epochs that overlap wake ###
    data = data[np.array(labels_not_awake, dtype=bool), :, :]
    labels_apnea = list(compress(labels_apnea, labels_not_awake))
    labels_hypopnea = list(compress(labels_hypopnea, labels_not_awake))

    # The ordered channel pick above already leaves `data` with exactly the
    # channels of `channels`, in that order, so no per-channel copy is needed.

    ### Save ###
    os.makedirs(out_dir, exist_ok=True)
    out_name = study_name + "_" + str(total_apnea_event_second) + "_" + str(total_hypopnea_event_second)
    # os.path.join keeps the path valid on every OS (no hard-coded backslash).
    np.savez_compressed(os.path.join(out_dir, out_name),
                        data=data, labels_apnea=labels_apnea, labels_hypopnea=labels_hypopnea)

    return data.shape[0]

# Append MNE log output to log.txt instead of overwriting it.
mne.set_log_file('log.txt', overwrite=False)

# os.path.join works whether or not PSG_DIR ends with a path separator;
# sorting gives a reproducible processing order.
edf_files = sorted(glob.glob(os.path.join(PSG_DIR, "*.edf")))
for edf_file in edf_files:
    # The annotation file shares the recording's base name, with a .tsv
    # extension. Only the extension is swapped, so a ".edf" occurring
    # elsewhere in the path is left alone.
    annot_file = os.path.splitext(edf_file)[0] + ".tsv"
    if not os.path.isfile(annot_file):
        # One missing annotation file should not abort the whole batch.
        print("study " + os.path.basename(edf_file) + " skipped since annotation file is missing")
        continue
    preprocess((edf_file, annot_file), identity, OUT_DIR)

('./data/nch\\10000_17728.edf', './data/nch\\10000_17728.tsv')
23:39:12 --- Processing 10000_17728.edf


[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  10 tasks      | elapsed:    3.5s
[Parallel(n_jobs=8)]: Done 272 tasks      | elapsed:    3.7s
[Parallel(n_jobs=8)]: Done 10512 tasks      | elapsed:    8.1s
[Parallel(n_jobs=8)]: Done 13046 out of 13046 | elapsed:    8.9s finished
  data = epochs.get_data()
