In [67]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from pprint import pprint

import numpy as np
import matplotlib.pyplot as plt
import pyedflib
import ipywidgets as widgets

# Make the parent of the current working directory (normally the folder above
# this notebook) importable so the `sleeprnn` package imported below is found.
# Guarded so that re-running this cell does not append duplicate entries.
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.append(project_root)

from sleeprnn.common import viz, constants
from sleeprnn.data.utils import PATH_DATA
from sleeprnn.data.mass_kc import PATH_MASS_RELATIVE, PATH_REC, PATH_MARKS

%matplotlib inline
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [68]:
# Select the MASS subject and build the paths to its PSG recording and to its
# annotation file (KComplexesE1).
subject_id = 1
file_rec = os.path.join(
    project_root, PATH_DATA, PATH_MASS_RELATIVE, PATH_REC,
    '01-02-%04d PSG.edf' % subject_id)
file_ss = os.path.join(
    project_root, PATH_DATA, PATH_MASS_RELATIVE, PATH_MARKS,
    '01-02-%04d KComplexesE1.edf' % subject_id)

# Unipolar channels, looked up in the EDF by labels of the form 'EEG <name>-CLE'.
chosen_unipolar = [
    'F3',
    'F4',
    'C3',
    'Cz',
    'C4',
    'P3',
    'P4'
]
# Bipolar derivations, computed as signal[first] - signal[second].
chosen_bipolar = [
    ('F3', 'C3'),
    ('F4', 'C4')
]

# Channel whose plotted row(s) get the annotated mark highlighted.
marked_ch = 'C3'

unipolar_signals = {}
bipolar_signals = {}
channel_fs = {}

with pyedflib.EdfReader(file_rec) as edf_reader:
    channel_names = edf_reader.getSignalLabels()
    for name in chosen_unipolar:
        format_name = 'EEG %s-CLE' % name
        channel_to_extract = channel_names.index(format_name)
        this_signal = edf_reader.readSignal(channel_to_extract)
        unipolar_signals[name] = this_signal
        channel_fs[name] = edf_reader.samplefrequency(channel_to_extract)

# Bipolar derivations subtract sample-by-sample and the marks below are
# converted to samples with a single rate, so all chosen channels must share
# one sampling frequency. This also fails loudly if no channel was chosen.
unique_fs = set(channel_fs.values())
if len(unique_fs) != 1:
    raise ValueError(
        'Expected exactly one sampling frequency across the chosen channels, '
        'got %s' % channel_fs)
fs = unique_fs.pop()

for name in chosen_bipolar:
    bipolar_signals[name] = unipolar_signals[name[0]] - unipolar_signals[name[1]]

with pyedflib.EdfReader(file_ss) as edf_reader:
    annotations = edf_reader.readAnnotations()
# annotations[0] holds the onsets and annotations[1] the durations.
onsets = np.array(annotations[0])
durations = np.array(annotations[1])
offsets = onsets + durations
marks_time = np.stack((onsets, offsets), axis=1)  # time-stamps, one row per mark
# Transforms to sample-stamps
ss_marks = np.round(marks_time * fs).astype(np.int32)

In [69]:
def show_segment(ss_to_show, context=6, dy=180, mark_half_height=50):
    """Plot every chosen channel in a window centred on one annotated mark.

    Bipolar derivations are stacked at the top, followed by the unipolar
    channels. Each consecutive trace is shifted down by ``dy``. The annotated
    interval is shaded on the row(s) involving ``marked_ch``.

    Relies on notebook globals defined in the loading cell (``ss_marks``,
    ``fs``, ``chosen_bipolar``, ``chosen_unipolar``, ``bipolar_signals``,
    ``unipolar_signals``, ``marked_ch``) plus the ``viz``/``constants``
    modules.

    Args:
        ss_to_show: Row index into ``ss_marks`` selecting the mark to show.
        context: Window length in seconds around the mark centre.
        dy: Vertical spacing between consecutive traces, in signal units.
        mark_half_height: Half height of the shaded mark box, in signal units.
    """
    chosen_mark = ss_marks[ss_to_show, :]
    center_sample = np.mean(chosen_mark)
    start_sample = int(center_sample - fs*context//2)
    end_sample = int(start_sample + fs*context)
    # Recording time (seconds) of each sample in the window.
    time_axis = np.arange(start_sample, end_sample) / fs
    mark_time = chosen_mark / fs

    fig, ax = plt.subplots(1, 1, figsize=(6, 8), dpi=120)
    small_font = 7

    def _shade_mark(offset):
        # Grey box over the mark interval, centred on the trace baseline.
        ax.fill_between(
            mark_time, mark_half_height + offset, -mark_half_height + offset,
            facecolor=viz.PALETTE[constants.GREY], alpha=0.4, label='Mark')

    start_bipolar = 0
    bipolar_offsets = [start_bipolar - k*dy for k in range(len(chosen_bipolar))]
    for name, offset in zip(chosen_bipolar, bipolar_offsets):
        this_segment = bipolar_signals[name][start_sample:end_sample]
        ax.plot(time_axis, this_segment + offset, linewidth=1, color=viz.PALETTE[constants.DARK])
        if marked_ch in name:
            _shade_mark(offset)

    # Unipolar rows continue exactly one `dy` below the last bipolar row.
    # Computed from the list length so it also works with no bipolar channels.
    start_unipolar = start_bipolar - len(chosen_bipolar)*dy
    unipolar_offsets = [start_unipolar - k*dy for k in range(len(chosen_unipolar))]
    for name, offset in zip(chosen_unipolar, unipolar_offsets):
        this_segment = unipolar_signals[name][start_sample:end_sample]
        # Names containing 'F' are drawn blue, names containing 'P' red,
        # everything else dark.
        if 'F' in name:
            this_color = constants.BLUE
        elif 'P' in name:
            this_color = constants.RED
        else:
            this_color = constants.DARK
        ax.plot(time_axis, this_segment + offset, linewidth=1, color=viz.PALETTE[this_color])
        if name == marked_ch:
            _shade_mark(offset)

    ax.set_xticks(np.arange(time_axis[0], time_axis[-1]+1, 1))
    ax.set_xticks(np.arange(time_axis[0], time_axis[-1], 0.5), minor=True)
    ax.set_yticks(bipolar_offsets + unipolar_offsets)
    chosen_bipolar_format = ['%s-%s' % (name[0], name[1]) for name in chosen_bipolar]
    chosen_unipolar_format = ['%s-CLE' % name for name in chosen_unipolar]
    ax.set_yticklabels(chosen_bipolar_format + chosen_unipolar_format)
    ax.set_xlabel('Time [s]', fontsize=small_font)
    # Pass visibility positionally: the `b=` keyword was deprecated in
    # Matplotlib 3.5 in favour of `visible` and no longer works on recent versions.
    ax.grid(True, axis='x', which='minor')
    ax.tick_params(labelsize=small_font, labelbottom=True, labeltop=False, bottom=True, top=True)
    plt.show()

In [70]:
# Interactive browser over the annotated marks: the slider selects a row of
# `ss_marks` and redraws only when the handle is released.
style = {'description_width': 'initial'}
layout = widgets.Layout(width='1000px')
last_mark_index = ss_marks.shape[0] - 1
mark_slider = widgets.IntSlider(
    min=0,
    max=last_mark_index,
    step=1,
    value=0,
    continuous_update=False,
    style=style,
    layout=layout,
)


def _plot_selected_mark(ss_id):
    """Widget callback: draw the segment around the selected mark index."""
    return show_segment(ss_id)


widgets.interact(_plot_selected_mark, ss_id=mark_slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='ss_id', layout=Layout(width='10…