In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import pyedflib

import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np
import ipywidgets as widgets
from matplotlib import gridspec

project_root = '..'
# Make the local `sleeprnn` package importable. Guard against duplicate entries
# when this cell is re-run in the same kernel.
if project_root not in sys.path:
    sys.path.append(project_root)

from sleeprnn.data.inta_ss import IntaSS, NAMES
from sleeprnn.data import utils
from sleeprnn.common import constants, pkeys

# Resolution (dots per inch) of the figures created below.
DPI = 200
# Named hex colours used by the plots.
CUSTOM_COLOR = {
    'red': '#c62828',
    'grey': '#455a64',
    'blue': '#0277bd',
    'green': '#43a047',
}

%matplotlib inline
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
# Instantiate the INTA spindle dataset, loading it from its saved checkpoint.
dataset = IntaSS(
    load_checkpoint=True)

Dataset inta_ss with 10 patients.
Loading from checkpoint... Loaded


In [67]:
# Load the raw spindle stamps of one subject and the length of the first
# channel of its recording.
subject_id = 5  # 1-based: NAMES[subject_id - 1] is the subject's file name

fs = dataset.fs
path_dataset = dataset.dataset_dir
subject_name = NAMES[subject_id - 1]
path_stamps = os.path.join(path_dataset, 'label', 'spindle', 'SS_%s.txt' % subject_name)
path_signals = os.path.join(path_dataset, 'register', '%s.rec' % subject_name)

with pyedflib.EdfReader(path_signals) as file:
    signal = file.readSignal(0)
    channel_name = file.getLabel(0)
    print('Reading', channel_name)
    signal_len = signal.shape[0]

# ndmin=2 keeps a (n_rows, n_columns) table even when the file holds a single
# stamp; otherwise np.loadtxt returns a 1-D array and data[:, -1] fails.
data = np.loadtxt(path_stamps, ndmin=2)
# Keep rows whose last column equals 1 (presumably the channel id of the
# channel read above -- TODO confirm against the label file format).
for_this_channel = (data[:, -1] == 1)
data = data[for_this_channel]
# Columns 0 and 1 are used below as [start, end] sample indices.
data = np.round(data).astype(np.int32)
print(signal_len)
print(data.shape)
print(data)

Reading F4-C4
5808200
(2027, 6)
[[ 682005  682544     -50     -50       1       1]
 [ 682017  682138     -50     -50       2       1]
 [ 682290  682582     -50     -50       2       1]
 ...
 [5206149 5206526     -50     -50       1       1]
 [5207455 5207659     -50     -50       0       1]
 [5208413 5208672     -50     -50       0       1]]


In [12]:
# Remove zero duration marks, and ensure that start time < end time.
# Works on a copy (the loaded array is not mutated) and also copes with the
# case where no stamp survives, which made np.stack([]) raise before.
is_inverted = data[:, 0] > data[:, 1]
is_zero_duration = data[:, 0] == data[:, 1]
# Report one message per affected stamp, in file order.
for inverted, zero_duration in zip(is_inverted, is_zero_duration):
    if inverted:
        print('End time < Start time fixed')
    elif zero_duration:
        print('Zero duration stamp found and removed')
ordered = data.copy()
# Sorting each [start, end] pair swaps the inverted ones; other columns are untouched.
ordered[:, :2] = np.sort(ordered[:, :2], axis=1)
data = ordered[~is_zero_duration]
print(data.shape)
print(data)

End time < Start time fixed
(2027, 6)
[[ 682005  682544     -50     -50       1       1]
 [ 682017  682138     -50     -50       2       1]
 [ 682290  682582     -50     -50       2       1]
 ...
 [5206149 5206526     -50     -50       1       1]
 [5207455 5207659     -50     -50       0       1]
 [5208413 5208672     -50     -50       0       1]]


In [13]:
# Remove stamps outside signal boundaries.
# A stamp is kept only if it lies inside [0, signal_len). The previous cell
# guarantees start < end, but the start is checked too so that a negative
# index cannot slip through. Also safe when no stamp survives.
inside = (data[:, 0] >= 0) & (data[:, 1] < signal_len)
n_outside = int(np.count_nonzero(~inside))
for i in range(n_outside):
    print('Stamp outside boundaries found and removed')
data = data[inside]
print(data.shape)
print(data)

(2027, 6)
[[ 682005  682544     -50     -50       1       1]
 [ 682017  682138     -50     -50       2       1]
 [ 682290  682582     -50     -50       2       1]
 ...
 [5206149 5206526     -50     -50       1       1]
 [5207455 5207659     -50     -50       0       1]
 [5208413 5208672     -50     -50       0       1]]


In [14]:
# [start, end] sample indices of every stamp.
raw_stamps = data[:, [0, 1]]
print(raw_stamps)

print(raw_stamps.shape)
# Column 4 holds the validity level (0, 1 or 2 in this file); rows with any
# other value are not assigned to a group here.
valid = data[:, 4]
raw_stamps_0 = raw_stamps[valid == 0]
raw_stamps_1 = raw_stamps[valid == 1]
raw_stamps_2 = raw_stamps[valid == 2]
for validity, group_stamps in zip((0, 1, 2), (raw_stamps_0, raw_stamps_1, raw_stamps_2)):
    durations = (group_stamps[:, 1] - group_stamps[:, 0]) / fs
    if durations.size == 0:
        # .min()/.max() raise on an empty array.
        print('Valid %d' % validity, group_stamps.shape, 'No stamps')
        continue
    print('Valid %d' % validity, group_stamps.shape,
          'Min dur [s]', durations.min(), 'Max dur [s]', durations.max())

[[ 682005  682544]
 [ 682017  682138]
 [ 682290  682582]
 ...
 [5206149 5206526]
 [5207455 5207659]
 [5208413 5208672]]
(2027, 2)
Valid 0 (204, 2) Min dur [s] 0.155 Max dur [s] 4.275
Valid 1 (995, 2) Min dur [s] 0.055 Max dur [s] 6.775
Valid 2 (828, 2) Min dur [s] 0.395 Max dur [s] 4.605


In [15]:
# look for intersections between stamps of the same validity
def overlap_matrix(events, detections):
    """Binary overlap matrix between two sets of [start, end] intervals.

    Args:
        events: array of shape (n_events, 2+); columns 0 and 1 are start, end.
        detections: array of shape (n_detections, 2+), same layout.

    Returns:
        float64 array of shape (n_events, n_detections) with 1.0 where event i
        and detection j share at least one sample (intervals are inclusive on
        both ends), 0.0 otherwise.
    """
    # Column vectors for events, row vectors for detections -> broadcast to a grid.
    ev_start = events[:, 0][:, np.newaxis]
    ev_end = events[:, 1][:, np.newaxis]
    det_start = detections[:, 0][np.newaxis, :]
    det_end = detections[:, 1][np.newaxis, :]

    is_candidate = (det_start <= ev_end) & (det_end >= ev_start)
    # Number of shared samples; only positive values count as an overlap.
    shared = np.minimum(ev_end, det_end) - np.maximum(ev_start, det_start) + 1
    return (is_candidate & (shared > 0)).astype(np.float64)


def _report_self_overlaps(stamps, validity):
    """Histogram of how many *other* stamps of the same set each stamp overlaps.

    Prints the histogram under a 'Valid <validity>' header and returns
    (overlap matrix, per-stamp overlap counts, distinct counts, frequencies).
    """
    matrix = overlap_matrix(stamps, stamps)
    # Any stamp with start <= end overlaps itself (diagonal = 1), so discount it.
    per_stamp = matrix.sum(axis=1) - 1
    distinct, frequency = np.unique(per_stamp, return_counts=True)
    print('\nOverlaps for Valid %d' % validity)
    for n_hits, times in zip(distinct, frequency):
        print('%d overlaps: %d times' % (n_hits, times))
    return matrix, per_stamp, distinct, frequency


overlap_m, n_overlaps_0, values_0, counts_0 = _report_self_overlaps(raw_stamps_0, 0)
overlap_m, n_overlaps_1, values_1, counts_1 = _report_self_overlaps(raw_stamps_1, 1)
overlap_m, n_overlaps_2, values_2, counts_2 = _report_self_overlaps(raw_stamps_2, 2)

# Largest number of same-validity overlaps seen for a single stamp.
max_overlaps_0 = values_0.max()
max_overlaps_1 = values_1.max()
max_overlaps_2 = values_2.max()
max_overlaps = np.max([max_overlaps_0, max_overlaps_1, max_overlaps_2])


Overlaps for Valid 0
0 overlaps: 184 times
1 overlaps: 20 times

Overlaps for Valid 1
0 overlaps: 991 times
1 overlaps: 4 times

Overlaps for Valid 2
0 overlaps: 462 times
1 overlaps: 342 times
2 overlaps: 22 times
3 overlaps: 2 times


In [71]:
# Signal
signal_dict = {}
fs_dict = {}
with pyedflib.EdfReader(path_signals) as file:
    signal_names = file.getSignalLabels()
    for k, name in enumerate(signal_names):
        signal_dict[name] = file.readSignal(k)
        fs_dict[name] = file.getSampleFrequency(k)

In [145]:
# Channels drawn by plot_page, top to bottom: EEG derivations, then the rest.
eeg_names = ['F4-C4', 'C4-O2', 'F3-C3', 'C3-O1', 'C4-C3']
other_names = ['MOR', 'EMG', 'MOV SUP', 'MOV INF']
print(signal_names)
to_show_names = list(eeg_names)
to_show_names.extend(other_names)

['F4-C4', 'C4-O2', 'F3-C3', 'C3-O1', 'C4-C3', 'MOR', 'EMG', 'MOV SUP', 'MOV INF', 'ECG', 'DIAFRAGMA', 'RESP NASAL', 'RESP ABD', 'T* Axilar', 'PULSO', 'OXIGENO', 'POSICION']


In [94]:
# -*- coding: utf-8 -*-
# -*- mode: python -*-
# Adapted from mpl_toolkits.axes_grid1
# LICENSE: Python Software Foundation (http://docs.python.org/license.html)

from matplotlib.offsetbox import AnchoredOffsetbox
class AnchoredScaleBar(AnchoredOffsetbox):
    """Horizontal and/or vertical scale bar anchored inside an axes."""

    def __init__(self, transform, sizex=0, sizey=0, labelx=None, labely=None, loc=4,
                 pad=0.1, borderpad=0.1, sep=2, prop=None, barcolor="black", barwidth=None,
                 **kwargs):
        """
        Draw a horizontal and/or vertical bar with the size in data coordinates
        of the given axes. A label is drawn underneath (center-aligned).
        - transform : the coordinate frame (typically axes.transData)
        - sizex,sizey : width of x,y bar, in data units. 0 to omit
        - labelx,labely : labels for x,y bars; None to omit
        - loc : position in containing axes
        - pad, borderpad : padding, in fraction of the legend font size (or prop)
        - sep : separation between labels and bars in points.
        - barcolor : edge color of the bars
        - barwidth : line width of the bars (None uses the matplotlib default)
        - **kwargs : additional arguments passed to base class constructor
        """
        from matplotlib.patches import Rectangle
        from matplotlib.offsetbox import AuxTransformBox, VPacker, HPacker, TextArea
        bars = AuxTransformBox(transform)
        # Zero-height / zero-width rectangles render as plain line segments.
        if sizex:
            bars.add_artist(Rectangle((0, 0), sizex, 0, ec=barcolor, lw=barwidth, fc="none"))
        if sizey:
            bars.add_artist(Rectangle((0, 0), 0, sizey, ec=barcolor, lw=barwidth, fc="none"))

        if sizex and labelx:
            try:
                # Older matplotlib: keep the label tight against the bar.
                self.xlabel = TextArea(labelx, minimumdescent=False)
            except TypeError:
                # `minimumdescent` was deprecated in matplotlib 3.3 and removed in 3.5.
                self.xlabel = TextArea(labelx)
            bars = VPacker(children=[bars, self.xlabel], align="center", pad=0, sep=sep)
        if sizey and labely:
            self.ylabel = TextArea(labely)
            bars = HPacker(children=[self.ylabel, bars], align="center", pad=0, sep=sep)

        AnchoredOffsetbox.__init__(self, loc, pad=pad, borderpad=borderpad,
                                   child=bars, prop=prop, frameon=False, **kwargs)

        
def add_scalebar(ax, matchx=True, matchy=True, hidex=True, hidey=True, **kwargs):
    """Attach an AnchoredScaleBar to *ax* and optionally hide its axes.

    - ax : the axes to decorate
    - matchx, matchy : when True, size each bar to the spacing between the
      first two major ticks of that axis and label it with that spacing;
      when False, pass sizex / sizey (and labelx / labely) via kwargs
    - hidex, hidey : when True, hide the corresponding axis of *ax*; if both
      are True the axes frame is hidden as well
    - **kwargs : forwarded to AnchoredScaleBar
    Returns the created scale bar artist.
    """
    def _tick_spacing(axis):
        # Distance between the first two major ticks, or False with fewer than two.
        locs = axis.get_majorticklocs()
        if len(locs) <= 1:
            return False
        return locs[1] - locs[0]

    for enabled, axis, size_key, label_key in (
            (matchx, ax.xaxis, 'sizex', 'labelx'),
            (matchy, ax.yaxis, 'sizey', 'labely')):
        if enabled:
            kwargs[size_key] = _tick_spacing(axis)
            kwargs[label_key] = str(kwargs[size_key])

    scalebar = AnchoredScaleBar(ax.transData, **kwargs)
    ax.add_artist(scalebar)

    if hidex:
        ax.xaxis.set_visible(False)
    if hidey:
        ax.yaxis.set_visible(False)
    if hidex and hidey:
        ax.set_frame_on(False)

    return scalebar

In [149]:
# Show a certain page

# Fill colour of the V0/V1/V2 validity tracks.
stamps_color = '#B71C1C'
# Page indices and expert stamps of the selected subject, as provided by the dataset.
this_pages = dataset.get_subject_pages(subject_id=subject_id)
this_stamps = dataset.get_subject_stamps(subject_id=subject_id)


def filter_stamps(stamps, single_page, page_size):
    """Return the stamps that touch a given page.

    A stamp is kept when its start page or its end page equals ``single_page``,
    or when it starts before and ends after that page (a stamp longer than a
    page that covers it entirely, which the previous version missed).

    Args:
        stamps: 2-D integer array; columns 0 and 1 are [start, end] sample indices.
        single_page: page number, in units of ``page_size`` samples.
        page_size: number of samples per page.

    Returns:
        list of the selected rows (views into ``stamps``), in original order.
    """
    stamps = np.asarray(stamps)
    if stamps.shape[0] == 0:
        return []
    start_page = stamps[:, 0] // page_size
    end_page = stamps[:, 1] // page_size
    touches_page = (
        (start_page == single_page)
        | (end_page == single_page)
        | ((start_page < single_page) & (end_page > single_page))
    )
    return [stamps[i, :] for i in np.flatnonzero(touches_page)]


def _draw_validity_track(ax, segment_stamps, label, delta_y=0.1):
    """Shade each stamp of one validity level as a full-height band on ``ax``.

    Prints '<label> stamps:' followed by each stamp's [start, end] in seconds.
    """
    print('%s stamps:' % label)
    for this_stamp in segment_stamps:
        # Convert with the dataset rate instead of a hard-coded 200.
        print(this_stamp / fs)
        ax.fill_between(
            this_stamp / fs, 1 + delta_y, -delta_y,
            facecolor=stamps_color, alpha=0.2,
            edgecolor='k', linewidth=1.5,
        )
    ax.set_yticks([])
    ax.set_ylim([-delta_y, 1 + delta_y])
    ax.set_ylabel(label, fontsize=8)


def plot_page(page_idx):
    """Plot one page of the selected subject.

    The top three rows show the V0, V1 and V2 stamps. The bottom panel stacks
    every channel in ``to_show_names`` and shades the expert stamps around the
    F4-C4 trace.

    Args:
        page_idx: 1-based position in ``this_pages`` (matches the slider).

    Raises:
        ValueError: if ``page_idx`` is outside [1, len(this_pages)]; previously
            0 silently wrapped around to the last page.
    """
    n_pages = len(this_pages)
    if not 1 <= page_idx <= n_pages:
        raise ValueError('page_idx must be in [1, %d], got %d' % (n_pages, page_idx))

    microvolt_per_second = 80  # Aspect ratio of the signal panel
    n_channels = len(to_show_names)

    fig = plt.figure(figsize=(12, n_channels + 5), dpi=DPI)
    gs = gridspec.GridSpec(4, 1, height_ratios=[1, 1, 1, 4 * n_channels])

    page_chosen = this_pages[page_idx - 1]
    page_start = page_chosen * dataset.page_size
    page_end = page_start + dataset.page_size
    time_axis = np.arange(page_start, page_end) / fs

    # Major ticks every 5 s over the whole page, minor ticks every 0.5 s.
    page_duration = dataset.page_size / fs
    major_ticks = time_axis[0] + np.arange(0, page_duration + 1e-6, 5)
    minor_ticks = np.arange(time_axis[0], time_axis[-1], 0.5)

    tracks = [
        ('V0', filter_stamps(raw_stamps_0, page_chosen, dataset.page_size)),
        ('V1', filter_stamps(raw_stamps_1, page_chosen, dataset.page_size)),
        ('V2', filter_stamps(raw_stamps_2, page_chosen, dataset.page_size)),
    ]
    for gs_idx, (label, segment) in enumerate(tracks):
        ax = fig.add_subplot(gs[gs_idx])
        _draw_validity_track(ax, segment, label)
        if gs_idx == len(tracks) - 1:
            # Only the bottom track carries a time axis.
            ax.set_xticks(minor_ticks, minor=True)
            ax.set_xticks(major_ticks)
            ax.set_xlabel('Time [s]', fontsize=8)
            ax.tick_params(labelsize=8.5)
        else:
            ax.set_xticks([])
        # Limits after ticks: set_xticks may expand the view limits.
        ax.set_xlim([time_axis[0], time_axis[-1]])

    # Signal panel: one trace per channel, shifted down by y_sep microvolts.
    segment_stamps = filter_stamps(this_stamps, page_chosen, dataset.page_size)
    y_max = 150
    y_sep = 120
    ax = fig.add_subplot(gs[len(tracks)])
    offsets = [-y_sep * k for k in range(n_channels)]
    for offset, name in zip(offsets, to_show_names):
        segment_signal = signal_dict[name][page_start:page_end]
        ax.plot(time_axis, offset + segment_signal, linewidth=1, color=CUSTOM_COLOR['grey'])

    # Expert stamps are drawn around the F4-C4 trace; fall back to the top
    # trace when F4-C4 is not displayed (previously an UnboundLocalError).
    if 'F4-C4' in to_show_names:
        stamp_center = offsets[to_show_names.index('F4-C4')]
    else:
        stamp_center = 0

    # Scale bars are sized from the automatic ticks, before custom ticks are set.
    add_scalebar(ax, matchx=True, matchy=True, hidex=False, hidey=False)
    for expert_stamp in segment_stamps:
        ax.fill_between(
            expert_stamp / fs, 50 + stamp_center, -50 + stamp_center,
            facecolor=CUSTOM_COLOR['red'], alpha=0.2,
            edgecolor='k', linewidth=1.5,
        )
    ax.set_yticks(offsets)
    ax.set_yticklabels(to_show_names)
    ax.set_xlim([time_axis[0], time_axis[-1]])
    ax.set_ylim([-y_max - y_sep * (n_channels - 1) - 30, y_max])
    ax.set_title('Subject %d (%s INTA). Page in record: %d. (intervals of 0.5s are shown).'
                 % (subject_id, NAMES[subject_id - 1], page_chosen), fontsize=10)
    ax.set_xticks(major_ticks)
    ax.set_xticks(minor_ticks, minor=True)
    # Positional flag: the `b=` keyword was deprecated in matplotlib 3.5 and later removed.
    ax.grid(True, axis='x', which='minor')
    ax.tick_params(labelsize=8.5)
    ax.set_aspect(1 / microvolt_per_second)
    ax.set_xlabel('Time [s]', fontsize=8)

    plt.tight_layout()
    plt.show()

In [150]:
def _show_page(page_idx):
    # Look plot_page up at call time so that re-defining it takes effect
    # without re-running this cell.
    return plot_page(page_idx)


# 1-based page selector; plots only when the slider is released.
page_slider = widgets.IntSlider(
    min=1, max=this_pages.shape[0], step=1, value=359,
    continuous_update=False)
widgets.interact(_show_page, page_idx=page_slider);

interactive(children=(IntSlider(value=359, continuous_update=False, description='page_idx', max=1450, min=1), …