In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pyedflib
import os
import sys
sys.path.append("..")
from sleeprnn.data import utils, stamp_correction
from sleeprnn.common import viz
%matplotlib inline
viz.notebook_full_width()

In [None]:
# Recording selection. The sampling rate (fs) is not fixed here: it is read
# from the EDF header when the signal is loaded.
subject_id = 19
mass_folder = "/home/ntapia/projects/repos/sleep-rnn/resources/datasets/mass2/"
# File names relative to mass_folder, in order: PSG recording, spindle marks,
# K-complex marks, and the "state" labels used for the hypnogram.
eeg_file, ss_file, kc_file, hypno_file = (
    name_template % subject_id
    for name_template in (
        "register/01-02-00%02d PSG.edf",
        "label/spindle/01-02-00%02d Spindles_E1.edf",
        "label/kcomplex/01-02-00%02d KComplexes_E1.edf",
        "label/state/01-02-00%02d Base.edf",
    )
)

In [None]:
# Load the "EEG C3-CLE" channel of the selected recording.
with pyedflib.EdfReader(os.path.join(mass_folder, eeg_file)) as file:
    channel_names = file.getSignalLabels()
    # .index() raises ValueError if the recording has no such channel.
    channel_to_extract = channel_names.index("EEG C3-CLE")
    signal = file.readSignal(channel_to_extract)
    fs = file.samplefrequency(channel_to_extract)
    print('Channel extracted: %s (%s Hz)' % (file.getLabel(channel_to_extract), fs))
    # Later cells use fs as an integer (window sizes, slice bounds). Round to
    # the nearest integer: a bare int() truncates, so a header value such as
    # 255.99 would become 255 instead of 256.
    fs = int(round(fs))
# Filtering and resampling are currently disabled:
# signal = utils.broad_filter(signal, fs_old)
# signal = utils.resample_signal(signal, fs_old=fs_old, fs_new=fs)
signal = signal.astype(np.float32)

print("Signal", signal.shape, signal.dtype)

In [None]:
def _load_marks(edf_path, fs):
    """Read the annotations of an EDF file as [onset, offset] marks in samples.

    Args:
        edf_path: Path of the EDF annotation file.
        fs: Sampling frequency used to convert time-stamps to sample indices.

    Returns:
        int32 array with one row per annotation and two columns: the onset and
        the offset (onset + duration), each multiplied by fs and rounded to
        the nearest sample.
    """
    with pyedflib.EdfReader(edf_path) as file:
        annotations = file.readAnnotations()
    onsets = np.array(annotations[0])
    durations = np.array(annotations[1])
    marks_time = np.stack((onsets, onsets + durations), axis=1)  # time-stamps
    return np.round(marks_time * fs).astype(np.int32)


# Load annotations -- SS
# The duration limits are only referenced by the disabled corrections below.
min_ss_duration = 0.3
max_ss_duration = 3.0
marks_ss = _load_marks(os.path.join(mass_folder, ss_file), fs)
# marks = stamp_correction.combine_close_stamps(marks, fs, min_ss_duration)
# marks_ss = stamp_correction.filter_duration_stamps(marks, fs, min_ss_duration, max_ss_duration)

# Load annotations -- KC
min_kc_duration = 0.2
max_kc_duration = None
marks_kc = _load_marks(os.path.join(mass_folder, kc_file), fs)
# marks_kc = stamp_correction.filter_duration_stamps(marks, fs, min_kc_duration, max_kc_duration)

print("Marks SS", marks_ss.shape, marks_ss.dtype)
print("Marks KC", marks_kc.shape, marks_kc.dtype)

# Check Signal and Marks

In [None]:
# Visual check: draw up to ten 15 s windows, each located around a randomly
# chosen K-complex (KC) mark, with spindle (SS) marks shaded in blue and KC
# marks shaded in red.
mark_correction = -0.9  # offset [s] added to every mark before drawing
mark_linewidth = 1  # NOTE(review): unused; the axvline calls below hard-code 0.8
window_size = int(fs * 15)
# Without KC marks there is nothing to center a window on, so draw nothing
# instead of letting np.random.choice fail on an empty population.
n_windows = 10 if marks_kc.shape[0] > 0 else 0

for _ in range(n_windows):
    random_loc = np.random.choice(marks_kc.shape[0])  # 676
    random_center = marks_kc[random_loc].mean()
    # Start on a whole second, about half a window before the mark center.
    start_sample = int((random_center - (window_size // 2)) / fs) * fs
    # Keep the window inside the recording. A negative start would be taken by
    # the slice below as an offset from the END of the signal, and an end past
    # the last sample would leave fewer signal values than time values.
    start_sample = min(max(start_sample, 0), max(signal.size - window_size, 0))
    end_sample = min(start_sample + window_size, signal.size)

    segment_time = np.arange(start_sample, end_sample) / fs
    segment_signal = signal[start_sample:end_sample]
    # Marks are converted from samples to seconds to share the time axis.
    segment_marks_ss = utils.filter_stamps(marks_ss, start_sample, end_sample) / fs
    segment_marks_kc = utils.filter_stamps(marks_kc, start_sample, end_sample) / fs

    fig, ax = plt.subplots(1, 1, dpi=80, figsize=(10, 2))
    ax.plot(segment_time, segment_signal, linewidth=0.8, color=viz.GREY_COLORS[8])
    # SS marks are drawn first, then KC marks.
    for segment_marks, color in (
        (segment_marks_ss, viz.PALETTE["blue"]),
        (segment_marks_kc, viz.PALETTE["red"]),
    ):
        for mark in segment_marks:
            mark = mark + mark_correction
            ax.axvline(mark[0], linewidth=0.8, color=color)
            ax.axvline(mark[1], linewidth=0.8, color=color)
            ax.fill_between(mark, -150, 150, facecolor=color, alpha=0.5)
    ax.set_ylim([-150, 150])
    ax.set_yticks([-50, 0, 50])
    ax.set_xlabel("Time [s]")
    ax.set_ylabel(r"EEG [$\mu$V]")
    ax.set_xticks(np.arange(start_sample / fs, end_sample / fs, 0.5), minor=True)
    ax.set_xlim([start_sample / fs, end_sample / fs])
    ax.set_title("Subject %02d, KC Mark %d" % (subject_id, random_loc))

    plt.tight_layout()
    plt.show()

# Compare with old files

In [None]:
# For every subject, compare the current dataset export ("new", mass2/) with
# the previous one ("old", mass_old/): recording durations, spindle marks and
# the signal itself.
# NOTE: this loop rebinds notebook-level names that earlier cells also define
# (subject_id, mass_folder, eeg_file, ss_file, kc_file, signal, start_sample,
# ...). Re-run the configuration and loading cells before reusing them.

# Fixed excerpt drawn for every subject.
# Assumes every recording has at least start_sample + n_samples samples.
n_samples = 100
start_sample = 1000000

for subject_id in range(1, 20):

    print("\n\nSubject %d" % subject_id)

    # New
    mass_folder = "/home/ntapia/projects/repos/sleep-rnn/resources/datasets/mass2/"
    eeg_file = "register/01-02-00%02d PSG.edf" % subject_id
    ss_file = "label/spindle/01-02-00%02d Spindles_E1.edf" % subject_id
    kc_file = "label/kcomplex/01-02-00%02d KComplexes_E1.edf" % subject_id
    with pyedflib.EdfReader(os.path.join(mass_folder, eeg_file)) as file:
        channel_names = file.getSignalLabels()
        channel_to_extract = channel_names.index("EEG C3-CLE")
        signal = file.readSignal(channel_to_extract)
        fs_new = file.samplefrequency(channel_to_extract)
        print('Channel extracted: %s (%s Hz)' % (file.getLabel(channel_to_extract), fs_new))
    signal_new = signal.astype(np.float32)
    with pyedflib.EdfReader(os.path.join(mass_folder, ss_file)) as file:
        annotations = file.readAnnotations()
    onsets = np.array(annotations[0])
    durations = np.array(annotations[1])
    offsets = onsets + durations
    marks_ss_new = np.stack((onsets, offsets), axis=1)  # time-stamps

    # Old (note the different annotation file names)
    mass_folder = "/home/ntapia/projects/repos/sleep-rnn/resources/datasets/mass_old/"
    eeg_file = "register/01-02-00%02d PSG.edf" % subject_id
    ss_file = "label/spindle/01-02-00%02d SpindleE1.edf" % subject_id
    kc_file = "label/kcomplex/01-02-00%02d KComplexesE1.edf" % subject_id
    with pyedflib.EdfReader(os.path.join(mass_folder, eeg_file)) as file:
        channel_names = file.getSignalLabels()
        channel_to_extract = channel_names.index("EEG C3-CLE")
        signal = file.readSignal(channel_to_extract)
        fs_old = file.samplefrequency(channel_to_extract)
        print('Channel extracted: %s (%s Hz)' % (file.getLabel(channel_to_extract), fs_old))
    signal_old = signal.astype(np.float32)
    with pyedflib.EdfReader(os.path.join(mass_folder, ss_file)) as file:
        annotations = file.readAnnotations()
    onsets = np.array(annotations[0])
    durations = np.array(annotations[1])
    offsets = onsets + durations
    marks_ss_old = np.stack((onsets, offsets), axis=1)  # time-stamps

    # Duration and sampling-rate summary
    print("Total duration Old [s]", signal_old.size / fs_old)
    print("Total duration New [s]", signal_new.size / fs_new)
    print("Ratio sizes Old:New", signal_old.size / signal_new.size)
    print("Ratio fs Old:New", fs_old / fs_new)
    half_duration_difference = ((signal_old.size / fs_old) - (signal_new.size / fs_new)) / 2
    print("Half duration difference [s]", half_duration_difference)

    # Same sample range of both raw signals, each on its own time axis.
    fig, ax = plt.subplots(1, 1, dpi=80, figsize=(8, 2))
    ax.plot(
        np.arange(start_sample, start_sample+n_samples)/fs_old,
        signal_old[start_sample:start_sample+n_samples], linewidth=0.8, label="old")
    ax.plot(
        np.arange(start_sample, start_sample+n_samples)/fs_new,
        signal_new[start_sample:start_sample+n_samples], linewidth=0.8, label="new")
    ax.legend(loc="upper right")
    plt.tight_layout()
    plt.show()

    # Comparison of marks
    marks_ss_old_fix = stamp_correction.filter_duration_stamps(marks_ss_old * fs_old, fs_old, None, 5) / fs_old
    n_marks_old = marks_ss_old_fix.shape[0]
    n_marks_new = marks_ss_new.shape[0]
    print("Number of marks old_fix:", n_marks_old)
    print("Number of marks new:", n_marks_new)

    # Onsets are paired by position, which only makes sense when both sets have
    # the same number of marks. With different counts the subtraction would
    # raise (or silently broadcast if one set has a single mark).
    if n_marks_old == n_marks_new:
        difference_in_onset = marks_ss_old_fix[:, 0] - marks_ss_new[:, 0]
        print("Mean Gap: % 1.4f [s]" % np.mean(difference_in_onset))
        fig, ax = plt.subplots(1, 1)
        ax.hist(difference_in_onset)
        ax.set_title("Onset difference Old_fix - New [s] - Subject %02d" % subject_id)
        plt.show()
    else:
        print("Skipping onset comparison: number of marks differs between old_fix and new")

    # Comparison of signals
    signal_old_resample = utils.resample_signal_linear(signal_old, fs_old, fs_new)

    fig, ax = plt.subplots(1, 1, dpi=80, figsize=(8, 2))
    ax.plot(
        np.arange(start_sample, start_sample+n_samples)/fs_new,
        signal_old_resample[start_sample:start_sample+n_samples], linewidth=0.8, label="old_resample")
    ax.plot(
        np.arange(start_sample, start_sample+n_samples)/fs_new,
        signal_new[start_sample:start_sample+n_samples], linewidth=0.8, label="new")
    ax.legend()
    plt.tight_layout()
    plt.show()

    # Sample-wise error over the common length of both signals.
    n_samples_to_mse = min(signal_old_resample.size, signal_new.size)
    error = signal_old_resample[:n_samples_to_mse] - signal_new[:n_samples_to_mse]
    rmse = np.sqrt(np.mean(error ** 2))
    print("RMSE between old_resample and new:", rmse)
    print("|Error| range", np.abs(error).min(), np.abs(error).max())

    cases_loc = np.where((np.abs(error) > 1) & (np.abs(error) < 1000000))[0]
    print("Number of error cases greater than 1: %d (%1.4f %%)" % (len(cases_loc), 100 * len(cases_loc) / n_samples_to_mse))
    # Zero-pad the end of both signals so a 2 s margin after a late mismatch
    # still falls inside the arrays.
    signal_old_resample = np.concatenate([signal_old_resample, np.zeros(int(fs_new * 10))])
    signal_new = np.concatenate([signal_new, np.zeros(int(fs_new * 10))])

    # np.random.choice raises on an empty array, so only draw examples when at
    # least one mismatch exists.
    if len(cases_loc) == 0:
        print("No mismatch cases to visualize")
    n_cases_to_show = 5 if len(cases_loc) > 0 else 0
    for _ in range(n_cases_to_show):
        single_case = np.random.choice(cases_loc)
        case_error = error[single_case]
        # Clamp at 0: a negative start would be taken by the slices below as an
        # offset from the end of the array and break the plot.
        case_start = max(single_case - int(fs_new * 2), 0)
        case_end = single_case + int(fs_new * 2)
        fig, ax = plt.subplots(1, 1, dpi=80, figsize=(9, 2))
        ax.plot(
            np.arange(case_start, case_end)/fs_new,
            signal_old_resample[case_start:case_end], linewidth=0.8, label="old_resample")
        ax.plot(
            np.arange(case_start, case_end)/fs_new,
            signal_new[case_start:case_end], linewidth=0.8, label="new")
        ax.legend()
        ax.set_title("Case visualization of signal mismatch (Error %1.6f)" % case_error)
        plt.tight_layout()
        plt.show()

# Check hypnogram alignment

In [None]:
# Load hypnogram