In [1]:
import argparse
import glob
import math
import ntpath
import os
import shutil
import pyedflib
import numpy as np
import pandas as pd

from sleepstage import stage_dict
from logger import get_logger



In [2]:

# Hand-maintained lookup from the dataset's annotation strings to integer
# class ids; it has to be revised whenever a different dataset is used.
ann2label = dict(
    [
        ("Sleep stage W", 0),
        ("Sleep stage 1", 1),
        ("Sleep stage 2", 2),
        ("Sleep stage 3", 3),
        ("Sleep stage 4", 3),  # shares the id of stage 3, following the AASM Manual
        ("Sleep stage R", 4),
        ("Sleep stage ?", 6),
        ("Movement time", 5),
    ]
)


In [21]:


# Switch to the repository root so that the relative "./data/..." defaults used
# by the argument parser resolve correctly.
# The root used to be fixed to one developer's absolute path; it can now be
# overridden with the HD_EEG_ROOT environment variable and falls back to the
# original location when that variable is unset.
PROJECT_ROOT = os.environ.get(
    "HD_EEG_ROOT", "/home/gs/code/high-density-eeg-analysis/"
)
os.chdir(PROJECT_ROOT)



In [6]:
parser = argparse.ArgumentParser()


In [7]:

# CLI option naming the directory that holds the Sleep-EDF recordings.
# NOTE: registering the same option twice raises an argparse conflict, so this
# cell needs a freshly created `parser` before it is re-run.
data_dir_option = dict(
    type=str,
    default="./data/sleepedf/sleep-cassette",
    help="File path to the Sleep-EDF dataset.",
)
# Kept as the final expression so the cell still echoes the created action.
parser.add_argument("--data_dir", **data_dir_option)


_StoreAction(option_strings=['--data_dir'], dest='data_dir', nargs=None, const=None, default='./data/sleepedf/sleep-cassette', type=<class 'str'>, choices=None, help='File path to the Sleep-EDF dataset.', metavar=None)

In [8]:

# Remaining CLI options: where results go, which channel is extracted, and the
# name of the log file.
output_dir_option = dict(
    type=str,
    default="./data/sleepedf/sleep-cassette/eeg_fpz_cz",
    help="Directory where to save outputs.",
)
select_ch_option = dict(
    type=str,
    default="EEG Fpz-Cz",
    help="Name of the channel in the dataset.",
)
log_file_option = dict(
    type=str,
    default="info_ch_extract.log",
    help="Log file.",
)

parser.add_argument("--output_dir", **output_dir_option)
parser.add_argument("--select_ch", **select_ch_option)
# Kept as the final expression so the cell still echoes the created action.
parser.add_argument("--log_file", **log_file_option)



_StoreAction(option_strings=['--log_file'], dest='log_file', nargs=None, const=None, default='info_ch_extract.log', type=<class 'str'>, choices=None, help='Log file.', metavar=None)

In [11]:
args = parser.parse_args([])
args

Namespace(data_dir='./data/sleepedf/sleep-cassette', log_file='info_ch_extract.log', output_dir='./data/sleepedf/sleep-cassette/eeg_fpz_cz', select_ch='EEG Fpz-Cz')

In [12]:

# Output dir: always start from an empty directory. Anything left over from a
# previous run is deleted before the directory is recreated.
if os.path.exists(args.output_dir):
    shutil.rmtree(args.output_dir)
os.makedirs(args.output_dir)

# Place the log file inside the output directory. This rewrites args.log_file
# in place, so running the cell a second time prefixes the directory again.
args.log_file = os.path.join(args.output_dir, args.log_file)


In [17]:
args.log_file

'./data/sleepedf/sleep-cassette/eeg_fpz_cz/info_ch_extract.log'

In [None]:

# Create logger
logger = get_logger(args.log_file, level="info")

# Select channel
select_ch = args.select_ch


In [18]:
select_ch

'EEG Fpz-Cz'

In [22]:

# Read raw and annotation from EDF files.
# Both listings are sorted; later cells pair them by position (index i in one
# array with index i in the other).
psg_pattern = os.path.join(args.data_dir, "*PSG.edf")
ann_pattern = os.path.join(args.data_dir, "*Hypnogram.edf")
psg_fnames = np.asarray(sorted(glob.glob(psg_pattern)))
ann_fnames = np.asarray(sorted(glob.glob(ann_pattern)))

In [23]:
psg_fnames

array(['./data/sleepedf/sleep-cassette/SC4001E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4002E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4011E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4012E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4021E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4022E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4031E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4032E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4052E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4061E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4062E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4072E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4082E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4092E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4101E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4122E0-PSG.edf',
       './data/sleepedf/sleep-cassette/SC4131E0-PSG.edf',
       './data

In [24]:
args.data_dir

'./data/sleepedf/sleep-cassette'

In [13]:
i = 1

In [25]:
# Announce which signal / annotation pair (index i) is about to be processed.
psg_path, ann_path = psg_fnames[i], ann_fnames[i]
logger.info("Loading ...")
logger.info(f"Signal file: {psg_path}")
logger.info(f"Annotation file: {ann_path}")


[36m[INFO ][0m Loading ...
[36m[INFO ][0m Signal file: ./data/sleepedf/sleep-cassette/SC4002E0-PSG.edf
[36m[INFO ][0m Annotation file: ./data/sleepedf/sleep-cassette/SC4002EC-Hypnogram.edf


In [26]:

# Open the signal file and its hypnogram. Both readers are left open because
# the following cells keep reading from them.
psg_f = pyedflib.EdfReader(psg_fnames[i])
ann_f = pyedflib.EdfReader(ann_fnames[i])

# The two files must describe the same recording start.
psg_start = psg_f.getStartdatetime()
ann_start = ann_f.getStartdatetime()
assert psg_start == ann_start
start_datetime = psg_start
logger.info(f"Start datetime: {str(start_datetime)}")


[36m[INFO ][0m Start datetime: 1989-04-25 14:50:00


In [27]:

file_duration = psg_f.getFileDuration()
logger.info(f"File duration: {file_duration} sec")

# The data-record duration is used as the epoch duration, except for files
# stored with 60-second records (SC4362F0-PSG.edf, SC4362FC-Hypnogram.edf),
# where it is halved.
record_sec = psg_f.datarecord_duration
has_60s_records = record_sec == 60
epoch_duration = record_sec / 2 if has_60s_records else record_sec
changed_note = " (changed from 60 sec)" if has_60s_records else ""
logger.info("Epoch duration: {} sec{}".format(epoch_duration, changed_note))



[36m[INFO ][0m File duration: 84900.0 sec
[36m[INFO ][0m Epoch duration: 30.0 sec


In [28]:

# Extract signal from the selected channel
ch_names = psg_f.getSignalLabels()
ch_samples = psg_f.getNSamples()

# Index of the first channel whose label equals select_ch, or -1 if none does.
select_ch_idx = next(
    (idx for idx in range(psg_f.signals_in_file) if ch_names[idx] == select_ch),
    -1,
)
if select_ch_idx == -1:
    raise Exception("Channel not found.")
sampling_rate = psg_f.getSampleFrequency(select_ch_idx)

# Number of samples that make up one epoch of the selected channel.
n_epoch_samples = int(epoch_duration * sampling_rate)

In [31]:
signals = psg_f.readSignal(select_ch_idx).reshape(-1, n_epoch_samples)
signals.shape

(2830, 3000)

In [32]:

logger.info("Select channel: {}".format(select_ch))
logger.info("Select channel samples: {}".format(ch_samples[select_ch_idx]))
logger.info("Sample rate: {}".format(sampling_rate))


[36m[INFO ][0m Select channel: EEG Fpz-Cz
[36m[INFO ][0m Select channel samples: 8490000
[36m[INFO ][0m Sample rate: 100.0


In [33]:

# Sanity check: the reshaped signal must hold one row per expected epoch.
# Files with 60-second data records (SC4362F0-PSG.edf, SC4362FC-Hypnogram.edf)
# had their epoch duration halved, so each record counts as two epochs.
epochs_per_record = 2 if psg_f.datarecord_duration == 60 else 1
n_epochs = psg_f.datarecords_in_file * epochs_per_record
assert len(signals) == n_epochs, f"signal: {signals.shape} != {n_epochs}"


In [46]:
# Generate labels from onset and duration annotation
# ann_onsets: list of onset time in seconds
# ann_durations: list of duration time in seconds
# ann_stages: list of sleep stages

labels = []
total_duration = 0
ann_onsets, ann_durations, ann_stages = ann_f.readAnnotations()

ann_onsets, ann_durations, ann_stages

(array([    0., 26070., 26160., 26670., 26940., 28080., 28110., 28230.,
        28260., 28290., 28320., 28350., 29250., 29310., 29550., 29580.,
        29730., 29820., 30030., 30630., 31110., 31140., 31260., 31290.,
        31320., 31440., 32550., 32790., 32820., 32880., 32970., 33000.,
        33090., 33270., 33300., 33360., 33390., 33420., 33450., 33480.,
        33510., 33570., 33870., 33900., 34020., 34290., 34320., 34350.,
        34380., 34410., 34470., 35790., 35880., 35940., 36150., 36240.,
        36540., 36570., 36630., 36900., 36960., 37110., 37140., 38160.,
        38190., 38220., 38310., 38340., 38430., 38490., 38820., 38850.,
        38910., 39300., 39870., 39900., 39990., 40020., 40050., 40080.,
        40170., 40200., 40230., 40440., 40920., 40980., 41010., 41280.,
        41310., 41370., 41970., 42000., 42210., 42300., 43290., 43320.,
        43830., 44010., 44130., 44220., 44460., 44490., 44610., 44670.,
        44850., 44880., 44910., 44940., 45390., 45450., 45630., 

In [47]:
# Expand every (onset, duration, stage) annotation into one label per epoch.
#
# The accumulators are initialised here, in the same cell as the loop, so the
# cell is idempotent: re-running it no longer appends to an already filled
# `labels` list or trips the onset check on a stale `total_duration`.
labels = []
total_duration = 0
for a in range(len(ann_stages)):
    onset_sec = int(ann_onsets[a])
    duration_sec = int(ann_durations[a])
    ann_str = "".join(ann_stages[a])

    # Sanity check: annotations must be contiguous, each one starting exactly
    # where the previous one ended.
    assert onset_sec == total_duration

    # Get label value
    label = ann2label[ann_str]

    # Compute # of epoch for this stage; a stage must span whole epochs.
    if duration_sec % epoch_duration != 0:
        logger.info(f"Something wrong: {duration_sec} {epoch_duration}")
        raise Exception(f"Something wrong: {duration_sec} {epoch_duration}")
    duration_epoch = int(duration_sec / epoch_duration)

    # Generate sleep stage labels: the same label repeated for every epoch.
    label_epoch = np.full(duration_epoch, label, dtype=int)
    labels.append(label_epoch)

    total_duration += duration_sec

    logger.info(
        "Include onset:{}, duration:{}, label:{} ({})".format(
            onset_sec, duration_sec, label, ann_str
        )
    )

[36m[INFO ][0m Include onset:0, duration:26070, label:0 (Sleep stage W)
[36m[INFO ][0m Include onset:26070, duration:90, label:1 (Sleep stage 1)
[36m[INFO ][0m Include onset:26160, duration:510, label:2 (Sleep stage 2)
[36m[INFO ][0m Include onset:26670, duration:270, label:3 (Sleep stage 3)
[36m[INFO ][0m Include onset:26940, duration:1140, label:3 (Sleep stage 4)
[36m[INFO ][0m Include onset:28080, duration:30, label:5 (Movement time)
[36m[INFO ][0m Include onset:28110, duration:120, label:3 (Sleep stage 3)
[36m[INFO ][0m Include onset:28230, duration:30, label:3 (Sleep stage 4)
[36m[INFO ][0m Include onset:28260, duration:30, label:3 (Sleep stage 3)
[36m[INFO ][0m Include onset:28290, duration:30, label:3 (Sleep stage 4)
[36m[INFO ][0m Include onset:28320, duration:30, label:3 (Sleep stage 3)
[36m[INFO ][0m Include onset:28350, duration:900, label:3 (Sleep stage 4)
[36m[INFO ][0m Include onset:29250, duration:60, label:3 (Sleep stage 3)
[36m[INFO ][0m Inc

In [48]:
labels

[array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [49]:
labels = np.hstack(labels)

labels
signals

array([[-48.93919414,  -8.71575092, -32.54432234, ...,  27.94358974,
        -28.26739927, -14.52014652],
       [-54.84542125, -67.37069597, -54.84542125, ..., -15.43663004,
        -16.76043956, -10.03956044],
       [ -2.60586081,  -1.18021978,   4.52234432, ...,   8.29010989,
          3.5040293 ,  18.88058608],
       ...,
       [ 18.47326007,  10.12307692,  51.46666667, ...,   2.99487179,
         -3.92967033,  -8.1047619 ],
       [  3.3003663 ,   5.64249084,   1.26373626, ...,  52.99413919,
         56.86373626,  38.12673993],
       [ 34.46080586,  37.51575092,  33.13699634, ...,  27.23076923,
         51.26300366,  62.76996337]])

In [None]:

# Remove annotations that are longer than the recorded signals
labels = labels[: len(signals)]

# Get epochs and their corresponding labels
x = signals.astype(np.float32)
y = labels.astype(np.int32)

# Select only sleep periods: keep everything from the first to the last
# non-wake epoch, padded on both sides and clamped to the valid index range.
w_edge_mins = 30
edge_epochs = w_edge_mins * 2  # presumably 2 epochs per minute (30 s epochs) - confirm
nw_idx = np.where(y != stage_dict["W"])[0]
start_idx = max(nw_idx[0] - edge_epochs, 0)
end_idx = min(nw_idx[-1] + edge_epochs, len(y) - 1)
select_idx = np.arange(start_idx, end_idx + 1)

shape_report = "{}, {}"
logger.info("Data before selection: " + shape_report.format(x.shape, y.shape))
x = x[select_idx]
y = y[select_idx]
logger.info("Data after selection: " + shape_report.format(x.shape, y.shape))

In [40]:
x, y

(array([[ 9.81758213e+00,  1.52146521e+01,  1.57238092e+01, ...,
          6.16498184e+01,  5.99186821e+01,  5.77802200e+01],
        [ 5.27904778e+01,  4.51531143e+01,  4.06725273e+01, ...,
          8.74131851e+01,  8.45619049e+01,  7.44805832e+01],
        [ 7.09164810e+01,  6.80652008e+01,  6.71487198e+01, ...,
          5.45216103e+01,  4.76989021e+01,  5.13648338e+01],
        ...,
        [ 2.42776566e+01,  2.76380959e+01,  3.91450539e+01, ...,
         -1.18725271e+01,  5.50915778e-01,  1.20578756e+01],
        [ 1.77604389e+01,  4.72600746e+00,  4.17582430e-02, ...,
          2.43794880e+01,  1.36556780e+00,  1.92879124e+01],
        [-7.18827820e+00,  1.53164835e+01,  1.05304031e+01, ...,
          7.80446854e+01,  9.03663025e+01,  8.67003632e+01]], dtype=float32),
 array([0, 0, 0, ..., 0, 0, 0], dtype=int32))

In [None]:


def _process_recording(psg_path, ann_path, select_ch, output_dir, logger):
    """Convert one PSG/hypnogram pair into an ``.npz`` file of labelled epochs.

    Args:
        psg_path: Path of the ``*PSG.edf`` signal file.
        ann_path: Path of the matching ``*Hypnogram.edf`` annotation file.
        select_ch: Label of the channel to extract.
        output_dir: Directory the ``.npz`` file is written to.
        logger: Logger used for progress messages (only ``.info`` is used).

    Raises:
        Exception: If the channel is missing, an annotation does not span a
            whole number of epochs, or the recording has no non-wake epoch.
        AssertionError: If the two files disagree on the start time, the epoch
            count is unexpected, or the annotations are not contiguous.
        KeyError: If an annotation string is not present in ``ann2label``.
    """
    logger.info("Loading ...")
    logger.info("Signal file: {}".format(psg_path))
    logger.info("Annotation file: {}".format(ann_path))

    # Everything that needs the files is read inside the `with` block so both
    # EDF handles are released afterwards, even when a check below fails.
    with pyedflib.EdfReader(psg_path) as psg_f, pyedflib.EdfReader(ann_path) as ann_f:
        assert psg_f.getStartdatetime() == ann_f.getStartdatetime()
        start_datetime = psg_f.getStartdatetime()
        logger.info("Start datetime: {}".format(str(start_datetime)))

        file_duration = psg_f.getFileDuration()
        logger.info("File duration: {} sec".format(file_duration))
        epoch_duration = psg_f.datarecord_duration
        # Fix problems of SC4362F0-PSG.edf, SC4362FC-Hypnogram.edf, which use
        # 60-second data records: treat every record as two epochs.
        has_60s_records = psg_f.datarecord_duration == 60
        if has_60s_records:
            epoch_duration = epoch_duration / 2
            logger.info(
                "Epoch duration: {} sec (changed from 60 sec)".format(epoch_duration)
            )
        else:
            logger.info("Epoch duration: {} sec".format(epoch_duration))

        # Extract signal from the selected channel
        ch_names = psg_f.getSignalLabels()
        ch_samples = psg_f.getNSamples()
        select_ch_idx = -1
        for s in range(psg_f.signals_in_file):
            if ch_names[s] == select_ch:
                select_ch_idx = s
                break
        if select_ch_idx == -1:
            raise Exception("Channel not found.")
        sampling_rate = psg_f.getSampleFrequency(select_ch_idx)

        # One row per epoch.
        n_epoch_samples = int(epoch_duration * sampling_rate)
        signals = psg_f.readSignal(select_ch_idx).reshape(-1, n_epoch_samples)

        logger.info("Select channel: {}".format(select_ch))
        logger.info("Select channel samples: {}".format(ch_samples[select_ch_idx]))
        logger.info("Sample rate: {}".format(sampling_rate))

        # Sanity check
        n_epochs = psg_f.datarecords_in_file
        if has_60s_records:
            n_epochs = n_epochs * 2
        assert len(signals) == n_epochs, f"signal: {signals.shape} != {n_epochs}"

        ann_onsets, ann_durations, ann_stages = ann_f.readAnnotations()

    # Generate labels from onset and duration annotation
    labels = []
    total_duration = 0
    for a in range(len(ann_stages)):
        onset_sec = int(ann_onsets[a])
        duration_sec = int(ann_durations[a])
        ann_str = "".join(ann_stages[a])

        # Sanity check: annotations must be contiguous.
        assert onset_sec == total_duration

        # Get label value
        label = ann2label[ann_str]

        # Compute # of epoch for this stage
        if duration_sec % epoch_duration != 0:
            logger.info(f"Something wrong: {duration_sec} {epoch_duration}")
            raise Exception(f"Something wrong: {duration_sec} {epoch_duration}")
        duration_epoch = int(duration_sec / epoch_duration)

        # Generate sleep stage labels. The builtin `int` is used because the
        # `np.int` alias no longer exists in current NumPy releases.
        label_epoch = np.ones(duration_epoch, dtype=int) * label
        labels.append(label_epoch)

        total_duration += duration_sec

        logger.info(
            "Include onset:{}, duration:{}, label:{} ({})".format(
                onset_sec, duration_sec, label, ann_str
            )
        )
    labels = np.hstack(labels)

    # Remove annotations that are longer than the recorded signals
    labels = labels[: len(signals)]

    # Get epochs and their corresponding labels
    x = signals.astype(np.float32)
    y = labels.astype(np.int32)

    # Select only sleep periods: from the first to the last non-wake epoch,
    # padded by w_edge_mins * 2 epochs on each side and clamped to valid indices.
    w_edge_mins = 30
    nw_idx = np.where(y != stage_dict["W"])[0]
    if len(nw_idx) == 0:
        # Indexing nw_idx[0] below would otherwise fail with an opaque IndexError.
        raise Exception("No non-wake epoch found in {}".format(ann_path))
    start_idx = nw_idx[0] - (w_edge_mins * 2)
    end_idx = nw_idx[-1] + (w_edge_mins * 2)
    if start_idx < 0:
        start_idx = 0
    if end_idx >= len(y):
        end_idx = len(y) - 1
    select_idx = np.arange(start_idx, end_idx + 1)
    logger.info("Data before selection: {}, {}".format(x.shape, y.shape))
    x = x[select_idx]
    y = y[select_idx]
    logger.info("Data after selection: {}, {}".format(x.shape, y.shape))

    # Remove movement and unknown
    move_idx = np.where(y == stage_dict["MOVE"])[0]
    unk_idx = np.where(y == stage_dict["UNK"])[0]
    if len(move_idx) > 0 or len(unk_idx) > 0:
        remove_idx = np.union1d(move_idx, unk_idx)
        logger.info("Remove irrelavant stages")
        logger.info("  Movement: ({}) {}".format(len(move_idx), move_idx))
        logger.info("  Unknown: ({}) {}".format(len(unk_idx), unk_idx))
        logger.info("  Remove: ({}) {}".format(len(remove_idx), remove_idx))
        logger.info("  Data before removal: {}, {}".format(x.shape, y.shape))
        select_idx = np.setdiff1d(np.arange(len(x)), remove_idx)
        x = x[select_idx]
        y = y[select_idx]
        logger.info("  Data after removal: {}, {}".format(x.shape, y.shape))

    # Save
    filename = ntpath.basename(psg_path).replace("-PSG.edf", ".npz")
    save_dict = {
        "x": x,
        "y": y,
        "fs": sampling_rate,
        "ch_label": select_ch,
        "start_datetime": start_datetime,
        "file_duration": file_duration,
        "epoch_duration": epoch_duration,
        "n_all_epochs": n_epochs,
        "n_epochs": len(x),
    }
    np.savez(os.path.join(output_dir, filename), **save_dict)


def main(argv=None):
    """Extract one EEG channel from every Sleep-EDF recording into ``.npz`` files.

    The output directory is wiped and recreated, a log file is created inside
    it, and every ``*PSG.edf`` file in ``--data_dir`` is processed together
    with the ``*Hypnogram.edf`` file at the same position in sorted order.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv``; pass ``[]`` to use the defaults,
            e.g. when calling from a notebook.

    Raises:
        Exception: If the numbers of signal and annotation files differ, or a
            recording cannot be processed (see ``_process_recording``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./data/sleepedf/sleep-cassette",
        help="File path to the Sleep-EDF dataset.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./data/sleepedf/sleep-cassette/eeg_fpz_cz",
        help="Directory where to save outputs.",
    )
    parser.add_argument(
        "--select_ch",
        type=str,
        default="EEG Fpz-Cz",
        help="Name of the channel in the dataset.",
    )
    parser.add_argument(
        "--log_file", type=str, default="info_ch_extract.log", help="Log file."
    )
    args = parser.parse_args(argv)

    # Output dir: always start from an empty directory.
    if os.path.exists(args.output_dir):
        shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir)

    args.log_file = os.path.join(args.output_dir, args.log_file)

    # Create logger
    logger = get_logger(args.log_file, level="info")

    # Read raw and annotation from EDF files; sorted so the two lists line up.
    psg_fnames = sorted(glob.glob(os.path.join(args.data_dir, "*PSG.edf")))
    ann_fnames = sorted(glob.glob(os.path.join(args.data_dir, "*Hypnogram.edf")))
    if len(psg_fnames) != len(ann_fnames):
        # Positional pairing would silently mismatch recordings otherwise.
        raise Exception(
            "Found {} signal files but {} annotation files in {}".format(
                len(psg_fnames), len(ann_fnames), args.data_dir
            )
        )

    # Process every recording exactly once. The previous `while i == 1` loop
    # never changed `i`, so it could not terminate and only touched one file.
    for psg_path, ann_path in zip(psg_fnames, ann_fnames):
        _process_recording(psg_path, ann_path, args.select_ch, args.output_dir, logger)
        logger.info("\n=======================================\n")

# Script entry point: run the whole preparation when executed directly.
# NOTE(review): inside a Jupyter kernel __name__ is also "__main__", and main()
# lets argparse read sys.argv, which there presumably carries the kernel's own
# flags - confirm this guard is only meant for the exported .py script.
if __name__ == "__main__":
    main()