In [None]:
pip install mne

In [None]:
import os
import glob
import numpy as np
import ntpath
import argparse
import shutil
import math
from datetime import datetime
from mne.io import read_raw_edf
from collections import namedtuple
import re, datetime, operator, logging, sys

In [None]:
# Channel label under which EDF+ files store their annotation (TAL) stream.
EVENT_CHANNEL = 'EDF Annotations'

# Logger for this module's diagnostic messages.
log = logging.getLogger(__name__)


class EDFEndOfData(Exception):
    '''Raised by the record readers when a data record cannot be read in full.'''


def tal(tal_str):
    '''Return a list with (onset, duration, annotation) tuples for an EDF+ TAL
    stream.

    Parameters
    ----------
    tal_str : str
        Text of an 'EDF Annotations' channel: a sequence of Time-stamped
        Annotation Lists (TALs). Each TAL is a signed onset, an optional
        duration introduced by byte 0x15, zero or more annotation texts each
        introduced by byte 0x14, and a closing 0x14 0x00 pair.

    Returns
    -------
    list of (float, float, list of str)
        One tuple per TAL: onset in seconds, duration in seconds (0.0 when
        absent) and the list of annotation texts (empty when there are none).

    Notes
    -----
    Each annotation text is returned as ``str(text.encode('utf-8'))``, i.e.
    wrapped in a ``b'...'`` bytes repr. ``get_sleep_stage_label`` expects that
    wrapper and strips it, so the format is kept for compatibility.
    '''
    # Raw strings: the regex needs ``\d`` / ``\-`` passed through to ``re``
    # (non-raw literals trigger invalid-escape warnings on recent Pythons).
    exp = (r'(?P<onset>[+\-]\d+(?:\.\d*)?)'
           r'(?:\x15(?P<duration>\d+(?:\.\d*)?))?'
           r'(\x14(?P<annotation>[^\x00]*))?'
           r'(?:\x14\x00)')

    def parse(dic):
        annotation = dic['annotation']
        # Split on the 0x14 separator *before* taking the bytes repr: in the
        # repr the separator is the four characters "\x14", so splitting the
        # repr never separates multiple annotations.
        texts = ([str(text.encode('utf-8')) for text in annotation.split('\x14')]
                 if annotation else [])
        return (
            float(dic['onset']),
            float(dic['duration']) if dic['duration'] else 0.,
            texts)

    return [parse(m.groupdict()) for m in re.finditer(exp, tal_str)]


class EDFHeaderParser:
    '''Parser for the fixed-width ASCII header of EDF / EDF+ files.'''

    @staticmethod
    def parse_edf_header(f):
        '''Parse an EDF/EDF+ header and return its fields as a dict.

        Declared as a ``staticmethod`` so it works both as
        ``EDFHeaderParser.parse_edf_header(f)`` and on an instance
        (``EDFHeaderParser().parse_edf_header(f)``); without it the instance
        form would pass the parser object itself as ``f``.

        Parameters
        ----------
        f : file-like object
            Positioned at offset 0. ``read`` must return ``str``: the version
            field is compared against a ``str`` literal, so open the file in
            text mode.

        Returns
        -------
        dict
            Keys: ``local_subject_id``, ``local_recording_id``, ``date_time``
            (``'YYYY-MM-DD HH:MM:SS'``), ``EDF+``, ``contiguous``,
            ``n_records``, ``record_length`` (seconds), ``n_channels`` and the
            per-channel ``label``, ``transducer_type``, ``units``,
            ``physical_min``/``physical_max``/``digital_min``/``digital_max``
            (NumPy arrays), ``prefiltering`` and ``n_samples_per_record``.

        Raises
        ------
        AssertionError
            If ``f`` is not at offset 0 or the version field is not '0'.
        '''
        # Local import: at file level the bare name ``datetime`` is bound first
        # to the class and then to the module, so do not depend on it here.
        import datetime as _datetime

        h = {}
        assert f.tell() == 0  # check file position
        assert f.read(8) == '0       '  # version field

        # recording info
        h['local_subject_id'] = f.read(80).strip()
        h['local_recording_id'] = f.read(80).strip()

        # start date (dd.mm.yy) and start time (hh.mm.ss)
        (day, month, year) = [int(x) for x in re.findall(r'(\d+)', f.read(8))]
        (hour, minute, sec) = [int(x) for x in re.findall(r'(\d+)', f.read(8))]
        # EDF two-digit years use 1985 as clipping date:
        # 85-99 -> 1985-1999, 00-84 -> 2000-2084.
        year += 1900 if year >= 85 else 2000
        h['date_time'] = str(_datetime.datetime(year, month, day,
                                                hour, minute, sec))

        # misc
        f.read(8)  # number of header bytes (not used)
        subtype = f.read(44)[:5]
        h['EDF+'] = subtype in ['EDF+C', 'EDF+D']
        h['contiguous'] = subtype != 'EDF+D'
        h['n_records'] = int(f.read(8))
        h['record_length'] = float(f.read(8))  # in seconds
        nchannels = h['n_channels'] = int(f.read(4))

        # Channel info is stored field-by-field: all labels, then all
        # transducer types, and so on.
        channels = range(nchannels)
        h['label'] = [f.read(16).strip() for _ in channels]
        h['transducer_type'] = [f.read(80).strip() for _ in channels]
        h['units'] = [f.read(8).strip() for _ in channels]
        h['physical_min'] = np.asarray([float(f.read(8)) for _ in channels])
        h['physical_max'] = np.asarray([float(f.read(8)) for _ in channels])
        h['digital_min'] = np.asarray([float(f.read(8)) for _ in channels])
        h['digital_max'] = np.asarray([float(f.read(8)) for _ in channels])
        h['prefiltering'] = [f.read(80).strip() for _ in channels]
        h['n_samples_per_record'] = [int(f.read(8)) for _ in channels]
        f.read(32 * nchannels)  # reserved
        return h


class SleepEDFReader:
    '''Sequential reader for the data records of an EDF/EDF+ file.

    The header is parsed on construction; records are then read one at a
    time from the current position of ``file``.
    '''

    def __init__(self, file, header_parser=None):
        '''Parse the header of ``file`` and precompute rescaling parameters.

        Parameters
        ----------
        file : file-like object
            Positioned at the start of the EDF file.
        header_parser : object, optional
            Anything exposing a ``parse_edf_header(f)`` callable that returns
            the header dict. Defaults to the ``EDFHeaderParser`` class. It is
            resolved at call time instead of being instantiated once when this
            class is defined.
        '''
        self.file = file
        self.parser = EDFHeaderParser if header_parser is None else header_parser
        self.header = self.parser.parse_edf_header(self.file)
        self.digital_min, self.physical_min, self.gain = None, None, None
        self.get_ranges()

    def get_ranges(self):
        '''Compute per-channel offsets and gain for digital -> physical scaling.

        Sets ``digital_min``, ``physical_min`` and ``gain`` (plus the
        ``dig_min`` / ``phys_min`` aliases used by ``read_one_record``).
        '''
        self.digital_min = self.header['digital_min']
        self.physical_min = self.header['physical_min']
        dig_range = self.header['digital_max'] - self.digital_min
        phys_range = self.header['physical_max'] - self.physical_min
        # Check each range array separately; comparing a *list* of arrays
        # with 0 raises TypeError in Python 3.
        assert np.all(dig_range > 0)
        assert np.all(phys_range > 0)
        self.gain = phys_range / dig_range
        self.dig_min, self.phys_min = self.digital_min, self.physical_min

    def read_one_record(self):
        '''Read the next data record and return ``(time, signals, events)``.

        ``time`` is the onset of the record's first TAL (NaN when there is no
        annotation channel), ``signals`` is a list of float arrays in
        physical units (one per non-annotation channel) and ``events`` holds
        the remaining TAL tuples.

        Raises
        ------
        EDFEndOfData
            If fewer samples than the header announces remain in the file.
        '''
        raw_record = []
        for nsamp in self.header['n_samples_per_record']:
            samples = self.file.read(nsamp * 2)  # 2 bytes per sample
            if len(samples) != nsamp * 2:
                raise EDFEndOfData
            raw_record.append(samples)

        time = float('nan')
        signals = []
        events = []
        for (i, samples) in enumerate(raw_record):
            if self.header['label'][i] == EVENT_CHANNEL:
                ann = tal(samples)
                if ann:  # an all-padding annotation block yields no TALs
                    time = ann[0][0]
                    events.extend(ann[1:])
            else:
                # 2-byte little-endian integers
                dig = np.fromstring(samples, '<i2').astype(np.float32)
                phys = (dig - self.dig_min[i]) * self.gain[i] + self.phys_min[i]
                signals.append(phys)

        return time, signals, events

    def read_record(self):
        '''Alias of ``read_one_record`` (the name ``BaseEDFReader`` uses).'''
        return self.read_one_record()

    def retrieve_records(self):
        '''
        Record generator: yields ``(time, signals, events)`` until the data
        runs out.
        '''
        try:
            while True:
                yield self.read_one_record()
        except EDFEndOfData:
            pass

    # ``load_edf`` iterates with ``reader.records()``.
    records = retrieve_records


class BaseEDFReader:
    '''Reader for EDF/EDF+ files with an explicit header/record API.

    Call ``read_header()`` once, then ``read_record()`` repeatedly or iterate
    over ``records()``.
    '''

    def __init__(self, file):
        # file-like object positioned at the start of the EDF file
        self.file = file

    def read_header(self):
        '''Parse the header and precompute per-channel rescaling parameters.

        Sets ``header``, ``dig_min``, ``phys_min`` and ``gain``.
        '''
        # There is no module-level ``parse_edf_header``; it is defined on
        # EDFHeaderParser and does not need an instance.
        self.header = EDFHeaderParser.parse_edf_header(self.file)

        # calculate ranges for rescaling
        self.dig_min = self.header['digital_min']
        self.phys_min = self.header['physical_min']
        phys_range = self.header['physical_max'] - self.phys_min
        dig_range = self.header['digital_max'] - self.dig_min
        assert np.all(phys_range > 0)
        assert np.all(dig_range > 0)
        self.gain = phys_range / dig_range

    def read_raw_record(self):
        '''Read one data record and return a list with the raw samples of
        each channel (``2 * n_samples`` characters/bytes per channel).

        Raises
        ------
        EDFEndOfData
            If the file ends before the record is complete.
        '''
        result = []
        for nsamp in self.header['n_samples_per_record']:
            samples = self.file.read(nsamp * 2)
            if len(samples) != nsamp * 2:
                raise EDFEndOfData
            result.append(samples)
        return result

    def convert_record(self, raw_record):
        '''Convert a raw record to a (time, signals, events) tuple based on
        information in the header.

        ``time`` is the onset of the first TAL in the annotation channel (NaN
        if there is none), ``signals`` lists one float array per signal
        channel in physical units, and ``events`` holds the remaining TALs.
        '''
        dig_min, phys_min, gain = self.dig_min, self.phys_min, self.gain
        time = float('nan')
        signals = []
        events = []
        for (i, samples) in enumerate(raw_record):
            if self.header['label'][i] == EVENT_CHANNEL:
                ann = tal(samples)
                if ann:  # an all-padding annotation block yields no TALs
                    time = ann[0][0]
                    events.extend(ann[1:])
            else:
                # 2-byte little-endian integers
                dig = np.fromstring(samples, '<i2').astype(np.float32)
                phys = (dig - dig_min[i]) * gain[i] + phys_min[i]
                signals.append(phys)

        return time, signals, events

    def read_record(self):
        '''Read and convert the next record; see ``convert_record``.'''
        return self.convert_record(self.read_raw_record())

    def records(self):
        '''
        Record generator: yields converted records until the data runs out.
        '''
        try:
            while True:
                yield self.read_record()
        except EDFEndOfData:
            pass


def load_edf(edffile):
    '''Load an EDF+ file.

    Very basic reader for EDF and EDF+ files. While BaseEDFReader does support
    exotic features like non-homogeneous sample rates and loading only parts of
    the stream, load_edf expects a single fixed sample rate for all channels and
    tries to load the whole file.

    Parameters
    ----------
    edffile : file-like object or path (str, bytes or os.PathLike)

    Returns
    -------
    Named tuple ``EDF`` with the fields:
        X : NumPy array with shape p by n.
            Raw recording of n samples in p dimensions.
        sample_rate : float
            The sample rate of the recording. Note that mixed sample-rates are
            not supported.
        chan_lab : list of length p with strings
            The labels of the sensors used to record X.
        time : NumPy array with length n
            The time offset in the recording for each sample.
        annotations : a list with tuples
            EDF+ annotations are stored in (start, duration, description)
            tuples.
            start : float
                Indicates the start of the event in seconds.
            duration : float
                Indicates the duration of the event in seconds.
            description : list with strings
                Contains (multiple?) descriptions of the annotation event.

    Raises
    ------
    AssertionError
        If the signal channels do not share one sample rate.
    '''
    # ``basestring`` does not exist in Python 3.
    if isinstance(edffile, (str, bytes, os.PathLike)):
        # NOTE(review): EDFHeaderParser compares the version field against a
        # ``str`` literal, so a file opened in binary mode fails that check;
        # confirm which file mode the reader stack is meant to support.
        with open(edffile, 'rb') as f:
            return load_edf(f)  # convert filename to file

    reader = SleepEDFReader(edffile)

    log.debug('EDF header: %s', reader.header)

    # All signal channels must share one samples-per-record value.
    nsamp = np.unique(
        [n for (l, n) in zip(reader.header['label'], reader.header['n_samples_per_record'])
         if l != EVENT_CHANNEL])
    assert nsamp.size == 1, 'Multiple sample rates not supported!'
    samples_per_record = int(nsamp[0])
    sample_rate = float(samples_per_record) / reader.header['record_length']

    rectime, X, annotations = zip(*reader.records())
    X = np.hstack(X)
    # Concatenate the per-record annotation lists (``reduce`` is not a
    # builtin in Python 3).
    annotations = [a for record_anns in annotations for a in record_anns]
    chan_lab = [lab for lab in reader.header['label'] if lab != EVENT_CHANNEL]

    # create timestamps
    if reader.header['contiguous']:
        time = np.arange(X.shape[1]) / sample_rate
    else:
        reclen = reader.header['record_length']
        # linspace needs a scalar sample count, not the 1-element array.
        within_rec_time = np.linspace(0, reclen, samples_per_record, endpoint=False)
        time = np.hstack([t + within_rec_time for t in rectime])

    tup = namedtuple('EDF', 'X sample_rate chan_lab time annotations')
    return tup(X, sample_rate, chan_lab, time, annotations)

In [None]:
# Constants for sleep stages
W = 0           # Wakefulness stage
N1 = 1          # Non-REM stage 1
N2 = 2          # Non-REM stage 2
N3 = 3          # Non-REM stage 3
REM = 4         # Rapid Eye Movement (REM) stage
UNKNOWN = 5     # Unknown or unspecified stage

EPOCH_SEC_SIZE = 30  # Duration of each sleep stage epoch in seconds

# Hypnogram annotation text -> stage label. Stages 3 and 4 are merged into N3.
# "Sleep stage ?" (not scored) and "Movement time" map to UNKNOWN, so
# load_data removes those segments instead of labelling them as wake.
ANN2LABEL = {
    "Sleep stage W": W,
    "Sleep stage 1": N1,
    "Sleep stage 2": N2,
    "Sleep stage 3": N3,
    "Sleep stage 4": N3,
    "Sleep stage R": REM,
    "Sleep stage ?": UNKNOWN,
    "Movement time": UNKNOWN,
}


def get_sleep_stage_label(ann_str):
    '''Map a hypnogram annotation string to a sleep stage label.

    Parameters
    ----------
    ann_str : str
        Annotation text, optionally wrapped in a ``b'...'`` bytes repr (the
        form produced by ``tal``). The wrapper is removed before lookup.

    Returns
    -------
    int
        One of W, N1, N2, N3, REM or UNKNOWN (for unlisted annotations).
    '''
    # Strip the bytes-repr wrapper only when it is actually present.
    if ann_str.startswith("b'") and ann_str.endswith("'"):
        ann_str = ann_str[2:-1]
    return ANN2LABEL.get(ann_str, UNKNOWN)

# Function to load data from PSG and annotation files
def load_data(psg_fname, ann_fname, select_ch):
    '''Load one recording as EPOCH_SEC_SIZE-second epochs of a single channel.

    Parameters
    ----------
    psg_fname : str
        Path to the PSG (polysomnography) EDF file.
    ann_fname : str
        Path to the matching hypnogram EDF file.
    select_ch : str
        Name of the channel to extract (e.g. "EEG Fpz-Cz").

    Returns
    -------
    x : np.ndarray (float32)
        Epoched channel data; ``x[i]`` holds EPOCH_SEC_SIZE * sampling_rate
        samples.
    y : np.ndarray (int32)
        One stage label per epoch.
    sampling_rate : float
        Sampling rate reported by MNE for the PSG file.
    h_raw, h_ann : dict
        Parsed EDF headers of the PSG and hypnogram files.

    Raises
    ------
    AssertionError
        If the PSG and hypnogram start times differ.
    ValueError
        If an annotation or the selected data is not a whole number of
        epochs, data and label counts disagree, or no non-wake epoch exists.
    '''
    # At file level ``datetime`` ends up bound to the *module*
    # (``import datetime`` follows ``from datetime import datetime``), so the
    # class is imported locally for ``strptime``.
    from datetime import datetime as _datetime

    # Wake epochs kept on each side of the sleep period: 60 epochs (30 min).
    w_edge_epochs = EPOCH_SEC_SIZE * 2

    # Read PSG (Polysomnography) data from EDF file
    raw = read_raw_edf(psg_fname, preload=True, stim_channel=None)
    sampling_rate = raw.info['sfreq']
    samples_per_epoch = EPOCH_SEC_SIZE * sampling_rate

    # Extract the selected channel's data. ``scaling_time`` is not accepted by
    # recent MNE releases; only the channel column is used, so it is dropped.
    raw_ch_df = raw.to_data_frame()[select_ch].to_frame()

    # latin-1 maps every byte to exactly one character and newline='' turns
    # off newline translation, so ``read(n)`` in the EDF reader consumes
    # exactly n bytes of the file.
    with open(psg_fname, 'r', encoding='latin-1', newline='') as f:
        reader_raw = BaseEDFReader(f)
        reader_raw.read_header()
        h_raw = reader_raw.header

    raw_start_dt = _datetime.strptime(h_raw['date_time'], "%Y-%m-%d %H:%M:%S")

    # Read annotation and its header from Hypnogram file
    with open(ann_fname, 'r', encoding='latin-1', newline='') as f:
        reader_ann = BaseEDFReader(f)
        reader_ann.read_header()
        h_ann = reader_ann.header
        _, _, ann = zip(*reader_ann.records())

    ann_start_dt = _datetime.strptime(h_ann['date_time'], "%Y-%m-%d %H:%M:%S")

    # Ensure that PSG and annotation files start at the same time
    assert raw_start_dt == ann_start_dt, "PSG and hypnogram start times differ"

    remove_idx = []     # sample indices of UNKNOWN segments
    labels = []         # per-epoch labels, one array per annotation
    label_idx = []      # sample indices covered by labelled annotations

    # Annotations of every data record (not only the first) are used.
    for record_anns in ann:
        for onset_sec, duration_sec, ann_char in record_anns:
            label = get_sleep_stage_label("".join(ann_char))
            idx = int(onset_sec * sampling_rate) + np.arange(duration_sec * sampling_rate, dtype=int)
            if label != UNKNOWN:
                if duration_sec % EPOCH_SEC_SIZE != 0:
                    raise ValueError("annotation at %s s lasts %s s, not a multiple of %d s"
                                     % (onset_sec, duration_sec, EPOCH_SEC_SIZE))
                duration_epoch = int(duration_sec / EPOCH_SEC_SIZE)
                labels.append(np.ones(duration_epoch, dtype=int) * label)
                label_idx.append(idx)
            else:
                remove_idx.append(idx)

    if not labels:
        raise ValueError("no sleep-stage annotations found in %s" % ann_fname)
    labels = np.hstack(labels)

    # Keep samples that are not in an UNKNOWN segment ...
    all_idx = np.arange(len(raw_ch_df))
    if remove_idx:
        select_idx = np.setdiff1d(all_idx, np.hstack(remove_idx))
    else:
        select_idx = all_idx

    # ... and that are covered by a labelled annotation.
    label_idx = np.hstack(label_idx)
    select_idx = np.intersect1d(select_idx, label_idx)

    # Labelled samples without data: drop the matching number of labels.
    # NOTE(review): trimming from the end assumes those samples are trailing.
    if len(label_idx) > len(select_idx):
        extra_idx = np.setdiff1d(label_idx, select_idx)
        n_label_trims = int(math.ceil(len(extra_idx) / samples_per_epoch))
        if n_label_trims != 0:
            labels = labels[:-n_label_trims]

    # Select and preprocess the relevant PSG data
    raw_ch = raw_ch_df.values[select_idx]

    # Ensure data is evenly divisible into epochs
    if len(raw_ch) % samples_per_epoch != 0:
        raise ValueError("%d selected samples is not a multiple of %s samples per epoch"
                         % (len(raw_ch), samples_per_epoch))
    n_epochs = int(len(raw_ch) // samples_per_epoch)

    # Split data into epochs and convert labels to int32
    x = np.asarray(np.split(raw_ch, n_epochs)).astype(np.float32)
    y = labels.astype(np.int32)
    if len(x) != len(y):
        raise ValueError("%d epochs of data but %d labels" % (len(x), len(y)))

    # Keep the sleep period plus w_edge_epochs of wake on either side.
    nw_idx = np.where(y != W)[0]
    if nw_idx.size == 0:
        raise ValueError("no non-wake epochs in %s" % ann_fname)
    start_idx = max(nw_idx[0] - w_edge_epochs, 0)
    end_idx = min(nw_idx[-1] + w_edge_epochs, len(y) - 1)
    keep_idx = np.arange(start_idx, end_idx + 1)

    return x[keep_idx], y[keep_idx], sampling_rate, h_raw, h_ann


In [None]:
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, default="data_edf_20",
                    help="File path to the PSG and annotation files.")
parser.add_argument("--output_dir", type=str, default="data_edf_20_npz/fpzcz",
                    help="Directory where to save numpy files outputs.")
parser.add_argument("--select_ch", type=str, default="EEG Fpz-Cz",
                    help="The selected channel")
# parse_known_args ignores the "-f <kernel.json>" argument that Jupyter/Colab
# passes to the kernel (parse_args exits with "unrecognized arguments"),
# while still honouring these options when run as a script.
args, _ = parser.parse_known_args()

# Start from an empty output directory (previous contents are deleted).
if os.path.exists(args.output_dir):
    shutil.rmtree(args.output_dir)
os.makedirs(args.output_dir)

select_ch = args.select_ch
psg_fnames = sorted(glob.glob(os.path.join(args.data_dir, "*PSG.edf")))
ann_fnames = sorted(glob.glob(os.path.join(args.data_dir, "*Hypnogram.edf")))
# PSG and hypnogram files are paired by their sorted position.
if len(psg_fnames) != len(ann_fnames):
    raise ValueError("found %d PSG files but %d hypnogram files in %s"
                     % (len(psg_fnames), len(ann_fnames), args.data_dir))

for psg_fname, ann_fname in zip(psg_fnames, ann_fnames):
    x, y, fs, h_raw, h_ann = load_data(psg_fname, ann_fname, select_ch)
    filename = ntpath.basename(psg_fname).replace("-PSG.edf", ".npz")
    save_dict = {
        "x": x,
        "y": y,
        "fs": fs,
        "ch_label": select_ch,
        "header_raw": h_raw,
        "header_annotation": h_ann,
    }
    np.savez(os.path.join(args.output_dir, filename), **save_dict)



usage: colab_kernel_launcher.py [-h] [--data_dir DATA_DIR]
                                [--output_dir OUTPUT_DIR]
                                [--select_ch SELECT_CH]
colab_kernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-58857c0e-8fcc-41b0-a58b-bbd2a9e63af2.json


SystemExit: ignored