In [None]:
import re, datetime, operator, logging, sys
from collections import namedtuple
import os

import argparse
import glob
import math
import ntpath

import shutil
import urllib

from datetime import datetime
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
from mne.io import concatenate_raws, read_raw_edf
import xml.etree.ElementTree as ET



# Label of the signal whose samples are parsed as EDF+ annotations (TALs)
# instead of being rescaled like an ordinary signal.
EVENT_CHANNEL = 'EDF Annotations'


class EDFEndOfData(Exception):
    """Raised when a complete data record can no longer be read."""

def tal(tal_str):
    """Return a list with (onset, duration, annotation) tuples for an EDF+ TAL
    stream.

    Parameters
    ----------
    tal_str : str or bytes
        Content of one annotation signal. Bytes (what a file opened in binary
        mode yields) are decoded as UTF-8 before parsing.

    Returns
    -------
    list of (float, float, list of str)
        One tuple per TAL: the onset, the duration (0.0 when the TAL has
        none) and the annotation texts (an empty list when the TAL has none).
    """
    if isinstance(tal_str, (bytes, bytearray)):
        tal_str = bytes(tal_str).decode('utf-8', errors='replace')

    # Raw strings keep the regex escapes intact; the `re` module itself
    # interprets the hexadecimal separators (0x14, 0x15) and the 0x00 end.
    exp = (r'(?P<onset>[+\-]\d+(?:\.\d*)?)'
           r'(?:\x15(?P<duration>\d+(?:\.\d*)?))?'
           r'(\x14(?P<annotation>[^\x00]*))?'
           r'(?:\x14\x00)')

    def annotation_to_list(annotation):
        # Several annotations sharing one onset are separated by 0x14.
        return annotation.split('\x14') if annotation else []

    def parse(dic):
        return (
            float(dic['onset']),
            float(dic['duration']) if dic['duration'] else 0.,
            annotation_to_list(dic['annotation']))

    return [parse(m.groupdict()) for m in re.finditer(exp, tal_str)]


def edf_header(f):
    """Parse the EDF/EDF+ header found at the start of a file.

    Parameters
    ----------
    f : file-like object
        Must be positioned at offset 0. Its ``read`` may return either bytes
        (binary mode) or str; bytes are decoded so that every text field in
        the result is a ``str``.

    Returns
    -------
    dict
        Keys: ``local_subject_id``, ``local_recording_id``, ``date_time``,
        ``EDF+``, ``contiguous``, ``n_records``, ``record_length`` (seconds),
        ``n_channels`` and the per-channel entries ``label``,
        ``transducer_type``, ``units``, ``physical_min``, ``physical_max``,
        ``digital_min``, ``digital_max`` (NumPy arrays), ``prefiltering`` and
        ``n_samples_per_record``.

    Raises
    ------
    AssertionError
        If the file is not at offset 0 or the version field is not ``'0'``.
    ValueError
        If a numeric field cannot be parsed.
    """
    # Imported locally under a private name: at module level the name
    # `datetime` may be bound to the class rather than to the module.
    import datetime as _datetime

    def read_text(nbytes):
        chunk = f.read(nbytes)
        if isinstance(chunk, bytes):
            # latin-1 maps every byte, so a stray non-ASCII byte cannot raise.
            chunk = chunk.decode('latin-1')
        return chunk

    h = {}
    assert f.tell() == 0  # check file position
    assert read_text(8) == '0       '

    # recording info
    h['local_subject_id'] = read_text(80).strip()
    h['local_recording_id'] = read_text(80).strip()

    # parse timestamp
    (day, month, year) = [int(x) for x in re.findall(r'(\d+)', read_text(8))]
    (hour, minute, sec) = [int(x) for x in re.findall(r'(\d+)', read_text(8))]
    # EDF stores a two-digit year with 1985 as the clipping date:
    # 85-99 mean 1985-1999, 00-84 mean 2000-2084.
    year += 1900 if year >= 85 else 2000
    h['date_time'] = str(_datetime.datetime(year, month, day,
                                            hour, minute, sec))

    # misc
    int(read_text(8))  # header size in bytes; parsed only to validate it
    subtype = read_text(44)[:5]
    h['EDF+'] = subtype in ['EDF+C', 'EDF+D']
    h['contiguous'] = subtype != 'EDF+D'
    h['n_records'] = int(read_text(8))
    h['record_length'] = float(read_text(8))  # in seconds
    nchannels = h['n_channels'] = int(read_text(4))

    # read channel info (each field is stored for all channels in turn)
    channels = range(nchannels)
    h['label'] = [read_text(16).strip() for n in channels]
    h['transducer_type'] = [read_text(80).strip() for n in channels]
    h['units'] = [read_text(8).strip() for n in channels]
    h['physical_min'] = np.asarray([float(read_text(8)) for n in channels])
    h['physical_max'] = np.asarray([float(read_text(8)) for n in channels])
    h['digital_min'] = np.asarray([float(read_text(8)) for n in channels])
    h['digital_max'] = np.asarray([float(read_text(8)) for n in channels])
    h['prefiltering'] = [read_text(80).strip() for n in channels]
    h['n_samples_per_record'] = [int(read_text(8)) for n in channels]
    f.read(32 * nchannels)  # reserved

    return h


class BaseEDFReader:
    """Sequential, record-by-record reader for an EDF/EDF+ file.

    Parameters
    ----------
    file : file-like object
        Open EDF file. ``read_header`` must be called before any record is
        read, because the records are decoded with the header information.
    """

    def __init__(self, file):
        self.file = file

    def read_header(self):
        """Parse the header and pre-compute the digital-to-physical scaling.

        Sets ``header``, ``dig_min``, ``phys_min`` and ``gain``.
        """
        self.header = h = edf_header(self.file)

        # calculate ranges for rescaling
        self.dig_min = h['digital_min']
        self.phys_min = h['physical_min']
        phys_range = h['physical_max'] - h['physical_min']
        dig_range = h['digital_max'] - h['digital_min']
        assert np.all(phys_range > 0)
        assert np.all(dig_range > 0)
        self.gain = phys_range / dig_range

    def read_raw_record(self):
        '''Read one data record and return a list with the raw bytes of each
        signal, in header order.

        Raises EDFEndOfData when the file ends before the record is complete.
        '''
        result = []
        for nsamp in self.header['n_samples_per_record']:
            nbytes = nsamp * 2  # every sample is stored as a 2-byte integer
            samples = self.file.read(nbytes)
            if len(samples) != nbytes:
                raise EDFEndOfData
            result.append(samples)
        return result

    def convert_record(self, raw_record):
        '''Convert a raw record to a (time, signals, events) tuple based on
        information in the header.

        ``time`` is the onset of the first TAL of the annotation signal (NaN
        when the record has no annotation), ``signals`` holds one float32
        array per ordinary signal and ``events`` the remaining TALs.
        '''
        h = self.header
        dig_min, phys_min, gain = self.dig_min, self.phys_min, self.gain
        time = float('nan')
        signals = []
        events = []
        for (i, samples) in enumerate(raw_record):
            if h['label'][i] == EVENT_CHANNEL:
                ann = tal(samples)
                # An annotation signal without any TAL leaves `time` at NaN
                # instead of failing on the missing first entry.
                if ann:
                    time = ann[0][0]
                    events.extend(ann[1:])
            else:
                # 2-byte little-endian integers; frombuffer replaces the
                # deprecated binary mode of np.fromstring, and astype()
                # returns a writable copy.
                dig = np.frombuffer(samples, dtype='<i2').astype(np.float32)
                phys = (dig - dig_min[i]) * gain[i] + phys_min[i]
                signals.append(phys)

        return time, signals, events

    def read_record(self):
        """Read the next record and return it converted (see convert_record)."""
        return self.convert_record(self.read_raw_record())

    def records(self):
        '''
        Record generator: yields converted records until the data runs out.
        '''
        try:
            while True:
                yield self.read_record()
        except EDFEndOfData:
            pass


def load_edf(edffile):
    '''Load an EDF+ file.
  Very basic reader for EDF and EDF+ files. While BaseEDFReader does support
  exotic features like non-homogeneous sample rates and loading only parts of
  the stream, load_edf expects a single fixed sample rate for all channels and
  tries to load the whole file.
  Parameters
  ----------
  edffile : file-like object or string
    A path (str or os.PathLike) is opened in binary mode; anything else is
    used as an already open file.
  Returns
  -------
  Named tuple with the fields:
    X : NumPy array with shape p by n.
      Raw recording of n samples in p dimensions.
    sample_rate : float
      The sample rate of the recording. Note that mixed sample-rates are not
      supported.
    chan_lab : list of length p with strings
      The labels of the sensors used to record X.
    time : NumPy array with length n
      The time offset in the recording for each sample.
    annotations : a list with tuples
      EDF+ annotations are stored in (start, duration, description) tuples.
      start : float
        Indicates the start of the event in seconds.
      duration : float
        Indicates the duration of the event in seconds.
      description : list with strings
        Contains (multiple?) descriptions of the annotation event.
  Raises
  ------
  AssertionError
    If the non-annotation channels do not share one sample rate.
  '''
    if isinstance(edffile, (str, os.PathLike)):
        with open(edffile, 'rb') as f:
            return load_edf(f)  # convert filename to file

    reader = BaseEDFReader(edffile)
    reader.read_header()

    h = reader.header
    logging.getLogger(__name__).debug('EDF header: %s', h)

    # get sample rate info
    nsamp = np.unique(
        [n for (l, n) in zip(h['label'], h['n_samples_per_record'])
         if l != EVENT_CHANNEL])
    assert nsamp.size == 1, 'Multiple sample rates not supported!'
    samples_per_record = int(nsamp[0])
    sample_rate = float(samples_per_record) / h['record_length']

    rectime, X, annotations = zip(*reader.records())
    X = np.hstack(X)
    # Flatten the per-record event lists into a single list.
    annotations = [event for record_events in annotations
                   for event in record_events]
    chan_lab = [lab for lab in h['label'] if lab != EVENT_CHANNEL]

    # create timestamps
    if h['contiguous']:
        time = np.arange(X.shape[1]) / sample_rate
    else:
        # Discontinuous file: offset the within-record times by each
        # record's own start time.
        reclen = h['record_length']
        within_rec_time = np.linspace(0, reclen, samples_per_record,
                                      endpoint=False)
        time = np.hstack([t + within_rec_time for t in rectime])

    tup = namedtuple('EDF', 'X sample_rate chan_lab time annotations')
    return tup(X, sample_rate, chan_lab, time, annotations)


# Length of one scoring epoch in seconds.
EPOCH_SEC_SIZE = 30

# data on GNODE 25 DATE: 06-12-21 (ALL 329 files of SHHS1)


data_dir = '/scratch/SLEEP_data/shhs/polysomnography/edfs/shhs1'
ann_dir = '/scratch/SLEEP_data/shhs/polysomnography/annotations-events-profusion/shhs1'
output_dir = '/scratch/SLEEP_data/shhs/output'
# Channel notes: 'EEG (sec)' is C3-A2, 'EEG' is C4-A1.
select_ch = 'EEG C4-A1'

csv_path = '/scratch/SLEEP_data/selected_shhs1_files.txt'
# makedirs(exist_ok=True) also creates missing parent directories and does
# not fail when the directory already exists.
os.makedirs(output_dir, exist_ok=True)

# One recording id per line, no header row.
ids = pd.read_csv(csv_path, header=None)
ids = ids[0].values.tolist()

# Build both path lists from the same sorted id order so that entry k of
# edf_fnames and entry k of ann_fnames always belong to the same recording.
edf_fnames = np.asarray([os.path.join(data_dir, i + ".edf") for i in sorted(ids)])
ann_fnames = np.asarray([os.path.join(ann_dir, i + "-profusion.xml") for i in sorted(ids)])

# Convert each recording, skipping those whose .npz output already exists.
# Raw stage codes that are merged/renumbered: 4 (N4) joins N3, 5 (REM) becomes 4.
stage_remap = {4: 3, 5: 4}

for file_id, (edf_fname, ann_fname) in enumerate(zip(edf_fnames, ann_fnames)):
    # The output name is derived once so the skip check and the save agree.
    filename = os.path.basename(edf_fname).replace(".edf", ".npz")
    out_path = os.path.join(output_dir, filename)
    if os.path.exists(out_path):
        continue
    print(edf_fname)

    select_ch = 'EEG C4-A1'
    raw = read_raw_edf(edf_fname, preload=True, stim_channel=None, verbose=None)
    sampling_rate = raw.info['sfreq']
    ch_type = select_ch.split(" ")[0]    # selecting EEG out of 'EEG C4-A1'
    # Every channel whose name contains the type, e.g. both 'EEG' and 'EEG(sec)'.
    select_ch = sorted([s for s in raw.info["ch_names"] if ch_type in s])
    print(select_ch)
    # NOTE(review): the sampling rate is passed as `scalings`; confirm this is
    # the intended scaling argument for to_data_frame.
    raw_ch_df = raw.to_data_frame(scalings=sampling_rate)[select_ch]
    print(raw_ch_df.shape)

    # Read annotation and its header
    t = ET.parse(ann_fname)
    r = t.getroot()
    # r[4] is the fifth child of the XML root; presumably the list of
    # per-epoch sleep-stage codes (TODO confirm against the annotation schema).
    raw_labels = [int(stage.text) for stage in r[4]]
    if any(lbl > 5 for lbl in raw_labels):
        # some files may contain labels > 5 BUT not the selected ones.
        print("============================== Faulty file ==================")
        continue
    labels = np.asarray([stage_remap.get(lbl, lbl) for lbl in raw_labels])

    raw_ch = raw_ch_df.values
    print(raw_ch.shape)

    # Verify that we can split into 30-s epochs
    samples_per_epoch = EPOCH_SEC_SIZE * sampling_rate
    if len(raw_ch) % samples_per_epoch != 0:
        raise Exception("Something wrong")
    n_epochs = int(len(raw_ch) // samples_per_epoch)

    # Get epochs and their corresponding labels
    x = np.asarray(np.split(raw_ch, n_epochs)).astype(np.float32)
    y = labels.astype(np.int32)

    print(x.shape)
    print(y.shape)
    assert len(x) == len(y)

    # Select on sleep periods: keep w_edge_mins minutes (two 30-s epochs per
    # minute) around the first and last non-wake epoch.
    w_edge_mins = 30
    nw_idx = np.where(y != 0)[0]
    if nw_idx.size == 0:
        # Nothing but wake epochs: there is no sleep period to select.
        print("No non-wake epochs found, skipping: {}".format(edf_fname))
        continue
    start_idx = max(nw_idx[0] - (w_edge_mins * 2), 0)
    end_idx = min(nw_idx[-1] + (w_edge_mins * 2), len(y) - 1)
    select_idx = np.arange(start_idx, end_idx + 1)
    print("Data before selection: {}, {}".format(x.shape, y.shape))
    x = x[select_idx]
    y = y[select_idx]
    print("Data after selection: {}, {}".format(x.shape, y.shape))

    # Saving as numpy files
    save_dict = {
        "x": x,
        "y": y,
        "fs": sampling_rate
    }
    np.savez(out_path, **save_dict)
    print(" ---------- Done this file ---------")



In [64]:
# Exploratory cell: sets up the paths and windowing constants that the later
# cells define again.
SELECTED_SUBJECTS_PATH = './preprocess/shhs/selected_shhs1.txt'


# window_size * sfreq gives the number of samples per window, so window_size
# is presumably in seconds and sfreq in Hz (TODO confirm).
window_size = 30
sfreq = 100
window_size_samples = window_size*sfreq
# First column of the header-less file, as a plain Python list.
subject_ids = pd.read_csv(SELECTED_SUBJECTS_PATH, header=None)
subject_ids = subject_ids[0].values.tolist()
SHHS_PATH = '/scratch/shhs/edfs/shhs1'
SHHS_EVENTS_PATH = '/scratch/shhs/annotations-events-profusion'
SELECTED_SUBJECTS_PATH = './preprocess/shhs/selected_shhs1.txt'
# Two directory levels above SHHS_PATH, i.e. '/scratch/shhs/subjects_data'.
SHHS_SAVE_PATH = os.path.join(os.path.split(os.path.split(SHHS_PATH)[0])[0], 'subjects_data')

# NOTE: this re-read rebinds subject_ids to the whole DataFrame, replacing the
# list built above.
subject_ids = pd.read_csv(SELECTED_SUBJECTS_PATH, header=None)
subject_ids
import xml.etree.ElementTree as ET


<Element 'SleepStage' at 0x000002442C6B87C0>

In [87]:
['SaO2',
 'H.R.',
 'EEG(sec)',
 'ECG',
 'EMG',
 'EOG(L)',
 'EOG(R)',
 'EEG',
 'SOUND',
 'AIRFLOW',
 'THOR RES',
 'ABDO RES',
 'POSITION',
 'LIGHT',
 'NEW AIR',
 'OX stat']

In [22]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm

import mne
import xml.etree.ElementTree as ET
from braindecode.preprocessing.preprocess import preprocess, Preprocessor, zscore
from braindecode.preprocessing.windowers import create_windows_from_events
from braindecode.datasets import BaseConcatDataset, BaseDataset

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)


# Input locations.
SHHS_PATH = '/scratch/shhs/edfs/shhs1'
SHHS_EVENTS_PATH = '/scratch/shhs/annotations-events-profusion'
SELECTED_SUBJECTS_PATH = './preprocess/shhs/selected_shhs1.txt'
# Output directory: 'subjects_data' two levels above SHHS_PATH.
SHHS_SAVE_PATH = os.path.join(os.path.dirname(os.path.dirname(SHHS_PATH)), 'subjects_data')

if not os.path.exists(SHHS_SAVE_PATH):
    os.makedirs(SHHS_SAVE_PATH, exist_ok=True)


# Windowing: window_size * sfreq samples per window.
sfreq = 100
window_size = 30
window_size_samples = window_size * sfreq

# First column of the header-less subject list.
subject_ids = pd.read_csv(SELECTED_SUBJECTS_PATH, header=None)[0].values.tolist()
raw_paths = [os.path.join(SHHS_PATH, f'{subject}.edf') for subject in subject_ids]
# NOTE(review): annotation paths are built from SHHS_PATH while
# SHHS_EVENTS_PATH is defined but never used; confirm which directory
# actually holds the '-profusion.xml' files.
ann_paths = [os.path.join(SHHS_PATH, f'{subject}-profusion.xml') for subject in subject_ids]

# Annotation description -> integer class label.
label_mapping = {
    "Sleep stage W": 0,
    "Sleep stage N1": 1,
    "Sleep stage N2": 2,
    "Sleep stage N3": 3,
    "Sleep stage R": 4,
}
# Output key -> channel names saved under that key.
channel_mapping = {
    'eeg': ['EEG', 'EEG(sec)'],
    'ecg': ['ECG'],
    'eog': ['EOG(L)', 'EOG(R)'],
    'emg': ['EMG'],
    'emog': ['EOG(L)', 'EOG(R)', 'EMG'],
    'sound': ['SOUND'],
}

class SHHSSleepStaging(BaseDataset):
    """One recording: EDF signals annotated with the stages of an XML file.

    Parameters
    ----------
    raw_path : str
        Path of the EDF recording.
    ann_path : str
        Path of the XML file holding the per-epoch sleep-stage codes.
    channels : list of str | None
        Channel names to keep; None keeps every channel.
    preload : bool
        Forwarded to ``mne.io.read_raw_edf``.
    crop_wake_mins : float
        Minutes kept before the first and after the last sleep epoch;
        0 disables this cropping.
    crop : tuple | None
        Optional arguments passed on to ``raw.crop``.

    Raises
    ------
    Exception
        If ``raw_path`` or ``ann_path`` is missing.
    """

    # Stage code found in the XML -> annotation description (codes 3 and 4
    # are both reported as N3).
    _STAGE_NAMES = {
        0: "Sleep stage W",
        1: "Sleep stage N1",
        2: "Sleep stage N2",
        3: "Sleep stage N3",
        4: "Sleep stage N3",
        5: "Sleep stage R",
    }

    def __init__(
        self,
        raw_path=None,
        ann_path=None,
        channels=None,
        preload=False,
        crop_wake_mins=30,
        crop=None,
    ):
        if (raw_path is None) or (ann_path is None):
            raise Exception("Please provide paths for raw and annotations file!")

        raw, desc = self._load_raw(
            raw_path,
            ann_path,
            preload=preload,
            crop_wake_mins=crop_wake_mins,
            crop=crop,
            channels=channels,
        )
        super().__init__(raw, desc)

    @staticmethod
    def read_annotations(ann_fname):
        """Build ``mne.Annotations`` from the stage codes of an XML file.

        Each code occupies one window of ``window_size`` seconds. A code
        outside 0-5 is reported and skipped; the onset of every other epoch
        is still derived from its position in the file, so later epochs are
        not shifted.
        """
        root = ET.parse(ann_fname).getroot()

        labels, onsets = [], []
        # root[4] is the fifth child of the XML root; presumably the list of
        # per-epoch stage codes (TODO confirm against the annotation schema).
        for i, stage in enumerate(root[4]):
            name = SHHSSleepStaging._STAGE_NAMES.get(int(stage.text))
            if name is None:
                print( "============================== Faulty file =============================")
                continue
            labels.append(name)
            onsets.append(window_size * i)

        durations = np.repeat(window_size, len(labels))
        return mne.Annotations(np.asarray(onsets), durations, np.asarray(labels))

    def _load_raw(
        self,
        raw_fname,
        ann_fname,
        preload,
        crop_wake_mins,
        crop,
        channels=None,
    ):
        """Read, annotate, resample and crop one recording.

        Returns the ``(raw, desc)`` pair, where ``desc`` is a pandas Series
        holding the ``subject_id`` parsed from the file name.
        """
        import re

        raw = mne.io.read_raw_edf(raw_fname, preload=preload)
        if channels is not None:
            raw.pick(channels)
        annots = self.read_annotations(ann_fname)
        raw.set_annotations(annots, emit_warning=False)
        # resample() needs the data in memory; load_data() does nothing when
        # the recording was already preloaded.
        raw.load_data()
        raw.resample(sfreq, npad="auto")

        if crop_wake_mins > 0:
            # Find first and last sleep stages
            mask = [x[-1] in ["1", "2", "3", "R"] for x in annots.description]
            sleep_event_inds = np.where(mask)[0]

            # Crop raw (skipped when the hypnogram holds no sleep epoch)
            if sleep_event_inds.size:
                tmin = annots[int(sleep_event_inds[0])]["onset"] - crop_wake_mins * 60
                tmax = annots[int(sleep_event_inds[-1])]["onset"] + crop_wake_mins * 60
                raw.crop(tmin=max(tmin, raw.times[0]), tmax=min(tmax, raw.times[-1]))

        if crop is not None:
            raw.crop(*crop)

        # The subject id is the run of digits that ends the file stem, e.g.
        # 'shhs1-200001.edf' -> 200001 and 'SN001.edf' -> 1.
        stem = os.path.splitext(os.path.basename(raw_fname))[0]
        id_match = re.search(r"(\d+)$", stem)
        if id_match is None:
            raise ValueError(f"Cannot parse a subject id from {raw_fname!r}")
        subj_nb = int(id_match.group(1))
        desc = pd.Series({"subject_id": subj_nb,}, name="")
        return raw, desc

    
def __get_epochs(windows_subject):
    """Stack every window of ``windows_subject.windows`` along a new first axis.

    Returns an array of shape (num_epochs, num_channels, num_sample_points).
    """
    return np.stack(list(windows_subject.windows), axis=0)

def __get_channels(raw, ann):
    """Window and z-score one recording separately for every channel group.

    Parameters
    ----------
    raw : str
        Path of the EDF recording.
    ann : str
        Path of the matching annotation file.

    Returns
    -------
    (dict, subject_id)
        The dict maps every key of ``channel_mapping`` to the stacked windows
        of that channel group and also holds ``y``, ``subject_id`` and
        ``epoch_length``; the subject id is returned a second time on its own.
    """
    channels_data = dict()
    # Unpack the (key, channel names) pairs; indexing channel_mapping with
    # the whole pair would raise KeyError.
    for ch, ch_names in channel_mapping.items():
        shhs_subject = SHHSSleepStaging(raw_path=raw, ann_path=ann, channels=ch_names, preload=True)
        # create_windows_from_events iterates over `.datasets`, so the single
        # recording is wrapped in a concat dataset first.
        shhs_windows = create_windows_from_events(
                                BaseConcatDataset([shhs_subject]),
                                window_size_samples=window_size_samples,
                                window_stride_samples=window_size_samples,
                                preload=True,
                                mapping=label_mapping,
                            )
        preprocess(shhs_windows, [Preprocessor(zscore)])
        shhs_windows_subject = shhs_windows.datasets[0]
        channels_data[ch] = __get_epochs(shhs_windows_subject)
    # Labels and metadata are taken from the last group; every group is
    # windowed from the same recording and annotations.
    subject_id = shhs_windows_subject.description['subject_id']
    channels_data['y'] = shhs_windows_subject.y
    channels_data['subject_id'] = subject_id
    channels_data['epoch_length'] = len(shhs_windows_subject)
    return channels_data, subject_id


# zip() has no length, so the total is passed explicitly to let tqdm show
# progress against the number of recordings.
for raw, ann in tqdm(zip(raw_paths, ann_paths), total=len(raw_paths), desc="SHHS dataset preprocessing ..."):

    channels_data, sub_id = __get_channels(raw, ann)
    # One .npz archive per subject, named after the zero-padded subject id.
    subjects_save_path = os.path.join(SHHS_SAVE_PATH, f"{sub_id:03d}.npz")
    np.savez(subjects_save_path, **channels_data)
    

In [None]:

# Input locations.
SHHS_PATH = '/scratch/shhs/edfs/shhs1'
SHHS_EVENTS_PATH = '/scratch/shhs/annotations-events-profusion'
SELECTED_SUBJECTS_PATH = './preprocess/shhs/selected_shhs1.txt'


# Windowing: window_size * sfreq samples per window.
sfreq = 100
window_size = 30
window_size_samples = window_size * sfreq

# First column of the header-less subject list.
subject_ids = pd.read_csv(SELECTED_SUBJECTS_PATH, header=None)[0].values.tolist()


# Annotation description -> integer class label.
mapping = {
    "Sleep stage W": 0,
    "Sleep stage N1": 1,
    "Sleep stage N2": 2,
    "Sleep stage N3": 3,
    "Sleep stage R": 4,
}


class SHHSSleepStaging(BaseConcatDataset):
    """Concatenation of the recordings found in one directory.

    A subject is included when ``SN<id>.edf`` (id zero-padded to three
    digits) exists in ``shhs_path``; its annotations are read from
    ``SN<id>_sleepscoring.edf`` in the same directory.

    Parameters
    ----------
    shhs_path : str
        Directory holding the recordings.
    subject_ids : iterable of int | None
        Subjects to look for; None means 1 to 154.
    preload : bool
        Forwarded to ``mne.io.read_raw_edf``.
    crop_wake_mins : float
        Minutes kept before the first and after the last sleep epoch;
        0 disables this cropping.
    crop : tuple | None
        Optional arguments passed on to ``raw.crop``.

    Raises
    ------
    Exception
        If ``shhs_path`` is missing.
    """

    def __init__(
        self,
        shhs_path=None,
        subject_ids=None,
        preload=False,
        crop_wake_mins=0,
        crop=None,
    ):
        if subject_ids is None:
            subject_ids = range(1, 155)
        if shhs_path is None:
            raise Exception("Please provide path")

        # raw_files holds the signal files, edf_files the annotation files.
        self.raw_files, self.edf_files = [], []
        self._fetch_data(subject_ids, shhs_path)

        all_base_ds = list()
        for raw_fname, ann_fname in zip(self.raw_files, self.edf_files):
            raw, desc = self._load_raw(
                raw_fname,
                ann_fname,
                preload=preload,
                crop_wake_mins=crop_wake_mins,
                crop=crop
            )
            base_ds = BaseDataset(raw, desc)
            all_base_ds.append(base_ds)
        super().__init__(all_base_ds)

    def _fetch_data(
        self,
        subject_ids,
        shhs_path,
    ):
        """Collect the signal/annotation paths of the subjects present on disk."""
        # A set makes the per-subject membership test O(1).
        shhs_files = set(os.listdir(shhs_path))
        for subject in subject_ids:
            current_file = f"SN{subject:03d}.edf"
            if current_file in shhs_files:
                self.raw_files.append(os.path.join(shhs_path, current_file))
                self.edf_files.append(os.path.join(shhs_path, f"SN{subject:03d}_sleepscoring.edf"))

    @staticmethod
    def _load_raw(
        raw_fname,
        ann_fname,
        preload,
        crop_wake_mins,
        crop,
    ):
        """Read, annotate, resample and crop one recording.

        Returns the ``(raw, desc)`` pair, where ``desc`` is a pandas Series
        holding the ``subject_id`` taken from characters 2-4 of the file name.
        """
        raw = mne.io.read_raw_edf(raw_fname, preload=preload)
        annots = mne.read_annotations(ann_fname)
        raw.set_annotations(annots, emit_warning=False)
        # resample() needs the data in memory; load_data() does nothing when
        # the recording was already preloaded.
        raw.load_data()
        raw.resample(sfreq, npad="auto")

        if crop_wake_mins > 0:
            # Find first and last sleep stages
            mask = [x[-1] in ["1", "2", "3", "R"] for x in annots.description]
            sleep_event_inds = np.where(mask)[0]

            # Crop raw (skipped when the hypnogram holds no sleep epoch)
            if sleep_event_inds.size:
                tmin = annots[int(sleep_event_inds[0])]["onset"] - crop_wake_mins * 60
                tmax = annots[int(sleep_event_inds[-1])]["onset"] + crop_wake_mins * 60
                raw.crop(tmin=max(tmin, raw.times[0]), tmax=min(tmax, raw.times[-1]))

        if crop is not None:
            raw.crop(*crop)

        raw_basename = os.path.basename(raw_fname)
        subj_nb = int(raw_basename[2:5])
        desc = pd.Series({"subject_id": subj_nb,}, name="")
        return raw, desc

    
def __get_epochs(windows_subject):
    """Collect all windows of one subject into a single stacked array.

    The result has shape (num_epochs, num_channels, num_sample_points).
    """
    per_window = [window for window in windows_subject.windows]
    return np.stack(per_window, axis=0)


# Load every available recording, cut it into non-overlapping windows and
# z-score the windows.
shhs_dataset = SHHSSleepStaging(shhs_path=SHHS_PATH)
shhs_windows_dataset = create_windows_from_events(
    shhs_dataset,
    window_size_samples=window_size_samples,
    window_stride_samples=window_size_samples,
    preload=False,
    mapping=mapping,
)
preprocess(shhs_windows_dataset, [Preprocessor(zscore)])

# Output directory: 'subjects_data' next to the directory SHHS_PATH points to.
SHHS_SAVE_PATH = os.path.join(os.path.dirname(SHHS_PATH), 'subjects_data')
if not os.path.exists(SHHS_SAVE_PATH):
    os.makedirs(SHHS_SAVE_PATH, exist_ok=True)

for windows_subject in tqdm(shhs_windows_dataset.datasets, desc="SHHS dataset preprocessing ..."):
    shhs_subject_data = __get_epochs(windows_subject)
    subject_id = windows_subject.description['subject_id']

    subjects_save_path = os.path.join(SHHS_SAVE_PATH, f"{subject_id:03d}.npz")
    # NOTE(review): the fixed channel slices assume the channel order
    # 4 x EEG, EMG, 2 x EOG, then ECG; confirm against raw.ch_names.
    subject_arrays = {
        "eeg": shhs_subject_data[:, :4],
        "emg": shhs_subject_data[:, 4:5],
        "eog": shhs_subject_data[:, 5:7],
        "emog": shhs_subject_data[:, 4:7],
        "ecg": shhs_subject_data[:, 7:],
        "y": windows_subject.y,
        "subject_id": subject_id,
        "epoch_length": len(windows_subject),
    }
    np.savez(subjects_save_path, **subject_arrays)
    

In [88]:
raw.ch_names

['SaO2',
 'H.R.',
 'EEG(sec)',
 'ECG',
 'EMG',
 'EOG(L)',
 'EOG(R)',
 'EEG',
 'SOUND',
 'AIRFLOW',
 'THOR RES',
 'ABDO RES',
 'POSITION',
 'LIGHT',
 'NEW AIR',
 'OX stat']