In [1]:
# Python 2/3 compatibility shims (no-ops on Python 3).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os
import sys
from pprint import pprint

import matplotlib.pyplot as plt
import mne
import numpy as np
import ipywidgets as widgets
import pandas as pd
import pyedflib

# Make the project package `sleeprnn` importable; assumes the notebook is run
# from a directory one level below the repository root (TODO confirm).
sys.path.append('..')

from sleeprnn.helpers import reader
from sleeprnn.data import utils
from sleeprnn.common import constants, pkeys, viz

# Presumably widens the notebook display area (see sleeprnn.common.viz).
viz.notebook_full_width()
%matplotlib inline

In [3]:
# Root of the CAP dataset. Set the CAP_DATA_PATH environment variable to point
# elsewhere instead of editing this machine-specific default.
CAP_DATA_PATH = os.environ.get(
    "CAP_DATA_PATH",
    "/home/ntapia/projects/sleep-rnn/resources/datasets/unlabeled_cap")
REC_PATH = os.path.join(CAP_DATA_PATH, "register")
STATE_PATH = os.path.join(CAP_DATA_PATH, "label/state")
# Subjects excluded from this analysis.
invalid_subjects = ['5-006', '5-027', '5-031', '5-033']

# Recordings are named "<subject_id> PSG.edf", where the subject id is the
# first 5 characters (e.g. "5-003").
subject_ids = [s[:5] for s in os.listdir(REC_PATH) if s.endswith(".edf")]
subject_ids = [s for s in subject_ids if s not in invalid_subjects]

# Exclude all subjects of group "1-". A prefix test matches only the leading
# group number rather than "1-" anywhere in the id.
subject_ids = [s for s in subject_ids if not s.startswith('1-')]

subject_ids.sort()
print("Number of subjects:", len(subject_ids))
pprint(subject_ids)

Number of subjects: 25
['2-002',
 '4-001',
 '4-003',
 '4-005',
 '5-003',
 '5-004',
 '5-005',
 '5-007',
 '5-009',
 '5-010',
 '5-011',
 '5-015',
 '5-016',
 '5-017',
 '5-018',
 '5-019',
 '5-020',
 '5-021',
 '5-023',
 '5-028',
 '5-030',
 '5-034',
 '5-036',
 '5-038',
 '5-039']


## MNE - PyEDFlib check

In [None]:
# Sanity check: read the same channel with pyedflib and with MNE and compare
# the samples returned by both libraries.
eeg_file_check = os.path.join(REC_PATH, "1-003 PSG.edf")
channel_check = 'C4-A1'


def _print_signal_summary(library_name, sampling_rate, samples):
    """Print the sampling rate and basic statistics of a 1-D signal."""
    print("\n" + library_name)
    print("fs", sampling_rate)
    print("signal", samples.size, samples.shape, samples.dtype,
          samples.min(), samples.max(), samples.mean(), samples.std())


# pyedflib: look up the channel by label and read its samples.
with pyedflib.EdfReader(eeg_file_check) as edf_reader:
    pyedflib_channel = edf_reader.getSignalLabels().index(channel_check)
    signal_pyedflib = edf_reader.readSignal(pyedflib_channel)
    fs_pyedflib = edf_reader.samplefrequency(pyedflib_channel)
_print_signal_summary("pyedflib", fs_pyedflib, signal_pyedflib)

# MNE: indexing the Raw object yields (data, times); data keeps a leading
# channel axis, which is squeezed away.
raw_data = mne.io.read_raw_edf(eeg_file_check, verbose=False)
mne_channel = raw_data.ch_names.index(channel_check)
signal_mne = np.squeeze(raw_data[mne_channel, :][0])
fs_mne = raw_data.info['sfreq']
# A standard deviation below 1e-3 means the samples are in volts: rescale to
# microvolts so both signals are comparable.
if signal_mne.std() < 0.001:
    signal_mne = 1e6 * signal_mne
_print_signal_summary("mne", fs_mne, signal_mne)

mse_signals = np.mean(np.square(signal_mne - signal_pyedflib))
print("\nMSE between signals:", mse_signals)

## EEG reading

In [None]:
# Target sampling frequency (Hz) that every CAP recording is resampled to.
fs = 200
# Candidate EEG derivations, tried in order of preference. A 1-tuple names a
# ready-made channel; a 2-tuple (active, reference) means the derivation is
# computed as active minus reference. ("C4", "A1") was needed for subject
# 5-033, which is currently listed in `invalid_subjects`.
CHANNELS_TO_TRY = (
    ("C4-A1",),
    ("C4A1",),
    ("C4", "A1"),
    ("C3-A2",),
    ("C3A2",),
    ("C3", "A2"),
)
signal_dict = {}
starting_times_dict = {}
for subject_id in subject_ids:
    path_eeg_file = os.path.join(REC_PATH, "%s PSG.edf" % subject_id)
    raw_data = mne.io.read_raw_edf(path_eeg_file, verbose=False)

    # Clock time at which the recording starts; the hypnogram cell converts
    # the absolute "hh:mm:ss" event times of the state files relative to it.
    meas_date = raw_data.info['meas_date']
    if meas_date is None:
        raise ValueError("Subject %s: EDF header has no measurement date" % subject_id)
    recording_start_time_hh_mm_ss = meas_date.strftime("%H:%M:%S")
    starting_times_dict[subject_id] = recording_start_time_hh_mm_ss
    fs_old = raw_data.info['sfreq']
    channel_names = raw_data.ch_names

    # Use the first candidate whose channels all exist in this recording.
    # Fail loudly if none does, instead of falling through with the last one.
    for channel in CHANNELS_TO_TRY:
        if all(chn in channel_names for chn in channel):
            break
    else:
        raise ValueError(
            "Subject %s: none of the candidate channels %s found in %s"
            % (subject_id, CHANNELS_TO_TRY, channel_names))

    channel_index = channel_names.index(channel[0])
    signal, _ = raw_data[channel_index, :]
    if len(channel) == 2:
        # Build the derivation as active minus reference.
        channel_index_2 = channel_names.index(channel[1])
        reference_signal, _ = raw_data[channel_index_2, :]
        signal = signal - reference_signal
        print('Subject %s | Channel extracted: %s minus %s at %s Hz' % (subject_id, channel_names[channel_index], channel_names[channel_index_2], fs_old))
    else:
        print('Subject %s | Channel extracted: %s at %s Hz' % (subject_id, channel_names[channel_index], fs_old))
    # Indexing a single channel returns a 2-D array with one row; keep the row.
    signal = signal[0, :]
    if signal.std() < 0.001:
        # Signal is in volts, transform to microvolts for simplicity
        signal = 1e6 * signal
    # The sampling frequency is already an integer in CAP
    fs_old_round = int(np.round(fs_old))
    # Broad bandpass filter to signal
    signal = utils.broad_filter(signal, fs_old_round)
    # Now resample to the required frequency
    if fs != fs_old_round:
        print('Resampling from %d Hz to required %d Hz' % (fs_old_round, fs))
        signal = utils.resample_signal(signal, fs_old=fs_old_round, fs_new=fs)
    else:
        print('Signal already at required %d Hz' % fs)
    signal = signal.astype(np.float32)
    # No robust normalization nor clipping for now
    signal_dict[subject_id] = signal

## Hypnogram reading

In [None]:
def absolute_to_relative_time(abs_time_str, start_time_str):
    """Return the seconds elapsed from ``start_time_str`` to ``abs_time_str``.

    Both arguments are clock times formatted as "HH:MM:SS". If
    ``abs_time_str`` is earlier in the day than ``start_time_str``, the
    interval is taken to cross midnight, so the result is always an int in
    [0, 86400).
    """
    clock_format = '%H:%M:%S'
    start_clock = datetime.strptime(start_time_str, clock_format)
    event_clock = datetime.strptime(abs_time_str, clock_format)
    seconds_per_day = 24 * 60 * 60
    # Wrap negative differences past midnight back into a single day.
    return int((event_clock - start_clock).total_seconds()) % seconds_per_day


def get_skiprows_cap_states(states_file_path):
    """Count the preamble lines that precede the table header of a CAP state file.

    The header row is the first line containing 'Sleep Stage'. The returned
    count is meant for ``pd.read_csv(..., skiprows=...)``, so that the header
    row becomes the column row.

    Args:
        states_file_path: path to a CAP "<subject> Base.txt" state file.

    Returns:
        Zero-based index of the header line (number of lines before it).

    Raises:
        ValueError: if no line contains 'Sleep Stage', rather than returning a
            count that would make the subsequent CSV parse empty.
    """
    with open(states_file_path, 'r') as states_file:
        # Stream the file; only the lines up to the header are needed.
        for line_number, line in enumerate(states_file):
            if 'Sleep Stage' in line:
                return line_number
    raise ValueError(
        "No 'Sleep Stage' header line found in %s" % states_file_path)

In [None]:
# Hypnogram event label used for N2 epochs in the CAP state files.
n2_id = 'SLEEP-S2'
# The CAP hypnogram is scored in 30 s epochs, while the signals are analysed
# in 20 s pages of `page_size` samples.
original_page_duration = 30
page_duration = 20
page_size = int(fs * page_duration)

total_30s_pages = 0
total_20s_pages = 0

n2_pages_dict = {}
for subject_id in subject_ids:
    print("")
    signal_length = signal_dict[subject_id].size
    print("Signal length", signal_length, "Total 30s pages", signal_length / (fs * original_page_duration))

    # Read the state table, skipping the preamble before its header row.
    path_states_file = os.path.join(STATE_PATH, "%s Base.txt" % subject_id)
    starting_time = starting_times_dict[subject_id]
    skiprows = get_skiprows_cap_states(path_states_file)
    states_df = pd.read_csv(path_states_file, skiprows=skiprows, sep='\t')
    states_df = states_df.dropna()
    column_names = states_df.columns.values
    # Columns are matched by substring (presumably because the exact header
    # text varies between files).
    duration_col_name = column_names[[('Duration' in s) for s in column_names]][0]
    time_hhmmss_col_name = column_names[[('hh:mm:ss' in s) for s in column_names]][0]
    states_df['Time [s]'] = states_df[time_hhmmss_col_name].apply(lambda x: absolute_to_relative_time(x, starting_time))

    # Keep N2 events lasting exactly one 30 s epoch.
    n2_stages = states_df.loc[states_df['Event'] == n2_id]
    n2_stages = n2_stages.loc[n2_stages[duration_col_name] == original_page_duration]

    # 30 s epoch index of each N2 event. The integer cast truncates, so an
    # onset that is not an exact multiple of 30 s from the recording start is
    # assigned to the epoch containing it; the printed value helps check that.
    n2_pages_float = n2_stages["Time [s]"].values / original_page_duration
    n2_pages_original = n2_pages_float.astype(np.int32)
    if n2_pages_original.size > 0:
        print("First page before rounding", n2_pages_float[0])
        print("Max N2 page", n2_pages_original.max())
    print('Original N2 pages: %d' % n2_pages_original.size)
    onsets_original = n2_pages_original * original_page_duration
    offsets_original = (n2_pages_original + 1) * original_page_duration

    # Mark every 20 s page that overlaps at least one N2 epoch. Intervals are
    # half-open (strict inequalities), so a page that only touches an epoch
    # boundary is not marked. Rows: 20 s pages; columns: 30 s N2 epochs.
    total_pages = int(np.ceil(signal_length / page_size))
    page_onsets = np.arange(total_pages) * page_duration
    page_offsets = page_onsets + page_duration
    overlaps = (
        (page_onsets[:, np.newaxis] < offsets_original[np.newaxis, :])
        & (onsets_original[np.newaxis, :] < page_offsets[:, np.newaxis]))
    n2_pages = np.where(overlaps.any(axis=1))[0]

    # Drop the first, last and second-to-last page of the whole recording if
    # they were selected (the last page may be incomplete, since total_pages
    # rounds up).
    last_page = total_pages - 1
    n2_pages = n2_pages[
        (n2_pages != 0)
        & (n2_pages != last_page)
        & (n2_pages != last_page - 1)]
    n2_pages = n2_pages.astype(np.int16)
    print("Subject %s - Total N2 pages %d" % (subject_id, n2_pages.size))
    n2_pages_dict[subject_id] = n2_pages
    total_30s_pages += n2_pages_original.size
    total_20s_pages += n2_pages.size
print("Total 30s pages:", total_30s_pages)
print("Total 20s pages:", total_20s_pages)

In [None]:
# Subject shown by the interactive page viewer below. Use the last subject so
# the index is always valid (the list printed above has 25 entries, 0-24).
subject_id = subject_ids[-1]


def draw_single_page(page_index_to_show, dpi, ylim=150, subject_id=subject_id):
    """Plot one N2 page of a CAP subject, its 11-16 Hz band, and its spectrum.

    Args:
        page_index_to_show: index into ``n2_pages_dict[subject_id]``, i.e. the
            k-th N2 page of the subject (not the page number in the record).
        dpi: figure resolution.
        ylim: symmetric y-axis limit for both signal traces, in the same units
            as the stored signal.
        subject_id: subject to show. The default is bound to the module-level
            ``subject_id`` when this function is defined, so later cells that
            reuse ``subject_id`` as a loop variable do not change what the
            widget displays.
    """
    page = n2_pages_dict[subject_id][page_index_to_show]
    start_sample = int(page * page_size)
    end_sample = start_sample + page_size
    segment_signal = signal_dict[subject_id][start_sample:end_sample]

    # Top: broadband page; bottom: the same page band-passed to 11-16 Hz.
    fig, axes = plt.subplots(2, 1, figsize=(12, 4), dpi=dpi)
    axes[0].plot(segment_signal, linewidth=0.6)
    axes[0].set_ylim([-ylim, ylim])
    sigma_signal = utils.broad_filter(segment_signal, fs, lowcut=11, highcut=16)
    axes[1].plot(sigma_signal, linewidth=0.6)
    axes[1].set_ylim([-ylim, ylim])
    axes[0].set_title("Subject %s, page in record %d" % (subject_id, page))
    plt.tight_layout()
    plt.show()

    # Power spectrum of the broadband page, displayed up to 30 Hz.
    power, freq = utils.power_spectrum(segment_signal, fs)
    fig, ax = plt.subplots(1, 1, figsize=(4, 3), dpi=dpi)
    ax.plot(freq, power, linewidth=0.7)
    ax.set_xlim([0, 30])
    ax.set_xlabel("Frequency (Hz)")
    plt.show()

In [None]:
# Initial slider position: index into the shown subject's list of N2 pages.
page_index_to_show = 90


dpi = 120
n_n2_pages = n2_pages_dict[subject_id].size
print('Total N2 pages: %d' % n_n2_pages)


def _show_page(page_id):
    """Redraw the selected N2 page (dpi is read when the slider moves)."""
    draw_single_page(page_id, dpi=dpi)


page_slider = widgets.IntSlider(
    min=0,
    max=n_n2_pages - 1,
    step=1,
    value=page_index_to_show,
    continuous_update=False,
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='1000px'),
)
widgets.interact(_show_page, page_id=page_slider);

# Mean spectrum

In [None]:
# Average power spectrum over the N2 pages of each CAP subject.
mean_power_dict = {}
for subject_id, subject_signal in signal_dict.items():
    page_spectra = []
    for page in n2_pages_dict[subject_id]:
        first_sample = int(page * page_size)
        page_signal = subject_signal[first_sample:first_sample + page_size]
        power, freq = utils.power_spectrum(page_signal, fs)
        page_spectra.append(power)
    # `freq` (from the last page) is reused as the frequency axis by the
    # plotting cells below; presumably identical for every page since all
    # segments have page_size samples.
    mean_power_dict[subject_id] = np.mean(np.stack(page_spectra, axis=0), axis=0)

In [None]:
# Mean N2 spectrum of every CAP subject between 0.5 and 30 Hz. One subject is
# highlighted in red; all others are drawn as faint black lines.
fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=120)
min_freq = 0.5
max_freq = 30
plot_indices = np.where((freq >= min_freq) & (freq <= max_freq))[0]
band_freq = freq[plot_indices]
for subject_id in subject_ids:
    band_power = mean_power_dict[subject_id][plot_indices]
    if subject_id != subject_ids[24]:
        ax.plot(band_freq, band_power, linewidth=0.7, color="k", alpha=0.2)
    else:
        ax.plot(band_freq, band_power, linewidth=0.7, color="r", alpha=1, zorder=20, label=subject_id)
ax.set_xlim([0, 30])
ax.set_yticks([])
ax.legend()
ax.set_title("Mean N2 spectrum of CAP subjects (n=%d)" % len(subject_ids))
ax.set_xlabel("Frequency (Hz)")
plt.show()

# Comparison of spectrum with MASS

In [None]:
# Load the MASS dataset through the project reader; its N2 spectra serve as
# the reference for the CAP comparison below.
mass = reader.load_dataset(constants.MASS_SS_NAME)

In [None]:
# Average N2 power spectrum of each MASS training subject, using the
# normalized and clipped signal (normalize_clip=True).
# NOTE(review): pages are sliced with the CAP `page_size` and spectra use the
# CAP `fs`; this assumes MASS is loaded at the same sampling rate and page
# duration — TODO confirm against the dataset's own settings.
mass_mean_power_dict = {}
for subject_id in mass.train_ids:
    n2_pages_of_subject = mass.get_subject_pages(subject_id, pages_subset=constants.N2_RECORD)
    print(subject_id, n2_pages_of_subject.size)
    signal_of_subject = mass.get_subject_signal(subject_id, normalize_clip=True)
    page_spectra = []
    for page in n2_pages_of_subject:
        first_sample = int(page * page_size)
        page_signal = signal_of_subject[first_sample:first_sample + page_size]
        power, freq = utils.power_spectrum(page_signal, fs)
        page_spectra.append(power)
    mass_mean_power_dict[subject_id] = np.mean(np.stack(page_spectra, axis=0), axis=0)

In [None]:
# Mean N2 spectrum of every MASS training subject. Note this plot starts at
# 2 Hz, unlike the 0.5 Hz lower bound used for the CAP plots.
fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=120)
min_freq = 2
max_freq = 30
plot_indices = np.where((freq >= min_freq) & (freq <= max_freq))[0]
for subject_id in mass_mean_power_dict.keys():
    ax.plot(freq[plot_indices], mass_mean_power_dict[subject_id][plot_indices], linewidth=0.7, color="k", alpha=0.2)
# Axis settings do not depend on the subject, so apply them once.
ax.set_xlim([0, 30])
ax.set_yticks([])
ax.set_title("Mean N2 spectrum of MASS subjects (n=%d)" % len(mass_mean_power_dict.keys()))
ax.set_xlabel("Frequency (Hz)")
plt.show()

In [None]:
# Side-by-side mean N2 spectra of CAP (left) and MASS (right) on shared axes.
# The MASS spectra come from signals loaded with normalize_clip=True, so the
# CAP spectra are divided by mass.global_std to bring them to a comparable
# scale (presumably the same factor used by that normalization).
# NOTE(review): if utils.power_spectrum returns power (amplitude squared), the
# matching scale factor would be mass.global_std ** 2 — TODO confirm.

fig, axes = plt.subplots(1, 2, figsize=(12, 4), dpi=200, sharey=True, sharex=True)
min_freq = 0.5
max_freq = 30
plot_indices = np.where((freq >= min_freq) & (freq <= max_freq))[0]

# cap
ax = axes[0]
for subject_id in subject_ids:
    ax.plot(freq[plot_indices], mean_power_dict[subject_id][plot_indices] / mass.global_std, linewidth=0.7, color="k", alpha=0.2)
# Loop-invariant axis settings, applied once after plotting.
ax.set_xlim([0, 30])
ax.set_yticklabels([])
ax.set_title("Mean N2 spectrum of CAP subjects (n=%d)" % len(subject_ids))
ax.set_xlabel("Frequency (Hz)")
ax.grid()

# mass
ax = axes[1]
for subject_id in mass_mean_power_dict.keys():
    ax.plot(freq[plot_indices], mass_mean_power_dict[subject_id][plot_indices], linewidth=0.7, color="k", alpha=0.2)
ax.set_xlim([0, 30])
ax.set_title("Mean N2 spectrum of MASS subjects (n=%d)" % len(mass_mean_power_dict.keys()))
ax.set_xlabel("Frequency (Hz)")
ax.grid()

plt.tight_layout()
plt.show()

# Save CAP *.mat files

In [None]:
import scipy.io

In [None]:
# Export each CAP subject to a .mat file containing the signal, clipped to
# +/- 10 MASS global standard deviations, and its 0-based N2 page indices.
output_dir = "cap_subset_mat"
os.makedirs(output_dir, exist_ok=True)
clip_value = 10 * mass.global_std
for subject_id in subject_ids:
    clipped_signal = np.clip(signal_dict[subject_id], a_min=-clip_value, a_max=clip_value)
    fname = os.path.join(output_dir, "cap_s%s_fs_%s.mat" % (subject_id, fs))
    print(clipped_signal.min(), clipped_signal.max())
    scipy.io.savemat(fname, {"signal": clipped_signal, "n2_pages_from_zero": n2_pages_dict[subject_id]})