In [1]:
# Setup cell: standard-library, third-party and project imports plus notebook display options.
# NOTE(review): os, pprint, json and pkeys are not used in the cells visible here — confirm
# they are needed elsewhere before removing them.
import os
from pprint import pprint
import sys
import json

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np

# Make the project package `sleeprnn` importable; presumably it lives in the parent
# directory of this notebook — TODO confirm against the repository layout.
sys.path.append('..')

from sleeprnn.common import viz, constants, pkeys
from sleeprnn.data import utils
from sleeprnn.helpers import reader

# Project display helper; judging by its name it widens the notebook cells — verify in viz.
viz.notebook_full_width()
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [2]:
# Load the dataset identified by constants.MODA_SS_NAME through the project reader
# (its log output reports it as "moda_ss").
dataset = reader.load_dataset(
    constants.MODA_SS_NAME,
)

Dataset moda_ss with 180 patients.
Loading from checkpoint... Loaded
Global STD: None


# Visualize single subject

In [None]:
# Position of the subject to inspect within `dataset.all_ids`; change it to browse other subjects.
subject_loc = 55
subject_id = dataset.all_ids[subject_loc]
print("Subject id %s" % subject_id)
# Signal of the whole record, with normalize_clip disabled (presumably the raw amplitude scale).
signal = dataset.get_subject_signal(subject_id, normalize_clip=False)
# Marks and page indices for the N2_RECORD page subset.
marks = dataset.get_subject_stamps(subject_id, pages_subset=constants.N2_RECORD)
n2_pages = dataset.get_subject_pages(subject_id, pages_subset=constants.N2_RECORD)
print("Signal", signal.shape, signal.dtype)
print("Marks", marks.shape, marks.dtype)
print("N2 pages", n2_pages.shape, n2_pages.dtype)

In [None]:
def draw_signal(n2_page_loc, mark_level=-50, y_lim=(-150, 150)):
    """Plot one N2 page of the selected subject's signal with its marks overlaid.

    Reads the notebook globals ``dataset``, ``subject_id``, ``signal``, ``marks``
    and ``n2_pages`` set in the "Visualize single subject" cell.

    Args:
        n2_page_loc: 0-based position within ``n2_pages`` (the k-th N2 page of the
            subject), not the page index within the whole record.
        mark_level: y-coordinate at which each mark is drawn as a horizontal bar.
        y_lim: (bottom, top) limits of the y-axis.

    Returns:
        None. The figure is displayed with ``plt.show()``; nothing is returned so
        that ``widgets.interact`` does not echo a value under the plot.
    """
    _, ax = plt.subplots(1, 1, figsize=(12, 2), dpi=140)

    # Map the position among N2 pages to the page index in the record, then to
    # the sample span [start_sample, end_sample).
    n2_page_id = n2_pages[n2_page_loc]
    start_sample = int(n2_page_id * dataset.page_size)
    end_sample = int(start_sample + dataset.page_size)

    page_signal = signal[start_sample:end_sample]
    page_label = utils.filter_stamps(marks, start_sample=start_sample, end_sample=end_sample)

    # Sample indices divided by the sampling frequency give time from record start
    # (seconds, assuming dataset.fs is in Hz).
    time_axis = np.arange(start_sample, end_sample) / dataset.fs

    ax.plot(time_axis, page_signal, linewidth=0.7, color=viz.PALETTE['blue'])
    for m in page_label:
        # Clip any part of a mark outside the page so the bar stays within the
        # plotted time range (last plotted sample is end_sample - 1).
        m = np.clip(m, a_min=start_sample, a_max=end_sample - 1)
        m = m / dataset.fs
        ax.plot(m, [mark_level, mark_level], linewidth=4, color=viz.PALETTE['red'], alpha=0.5)

    ax.set_title(" N2 page %d (page in record: %d) (Subject %s, Phase %d, NBlocks %d)" % (
        n2_page_loc, n2_page_id, subject_id, dataset.data[subject_id]['phase'], dataset.data[subject_id]['n_blocks']))
    ax.set_xlabel("Time [s]")

    ax.set_ylim(*y_lim)
    ax.set_xlim([time_axis[0], time_axis[-1]])

    plt.show()

In [None]:
init_page_loc = 0

# With no N2 pages the slider max would be -1; ipywidgets rejects max < min with an
# opaque traitlets error, so fail early with a readable message instead.
if n2_pages.size == 0:
    raise ValueError("Subject %s has no N2 pages to display" % subject_id)

style = {'description_width': 'initial'}
layout = widgets.Layout(width='1000px')
# `interact` builds widgets from the wrapped function's parameters; the lambda
# exposes only `n2_page_loc`, so no other draw_signal parameters become widgets.
widgets.interact(
    lambda n2_page_loc: draw_signal(n2_page_loc),
    n2_page_loc=widgets.IntSlider(
        min=0, max=n2_pages.size - 1, step=1, value=init_page_loc,
        continuous_update=False,
        style=style,
        layout=layout
    ));

# Cross-validation scheme

In [3]:
n_folds = 5
fold_id = 0
seed = 0

train_ids, val_ids, test_ids = dataset.cv_split(n_folds, fold_id, seed=seed)
# Check overlap: the splits are disjoint exactly when the number of unique ids
# equals the sum of the split sizes.
n_unique = np.unique(np.concatenate([train_ids, val_ids, test_ids])).size
print("train %d, val %d, test %d, total unique %d" % (
    train_ids.size, val_ids.size, test_ids.size, n_unique
))
assert n_unique == train_ids.size + val_ids.size + test_ids.size, "Train/val/test splits overlap"

train 144, val 0, test 36, total unique 180
