In [20]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import pyedflib

import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np
import ipywidgets as widgets
from matplotlib import gridspec

project_root = '..'
sys.path.append(project_root)

from sleeprnn.data.inta_ss import IntaSS, NAMES
from sleeprnn.data import utils
from sleeprnn.common import constants, pkeys

DPI = 200
CUSTOM_COLOR = {'red': '#c62828', 'grey': '#455a64', 'blue': '#0277bd', 'green': '#43a047'} 

%matplotlib inline
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
# Load the INTA spindle dataset; load_checkpoint=True presumably reuses a cached checkpoint (see cell output) — TODO confirm.
dataset = IntaSS(load_checkpoint=True)

Dataset inta_ss with 10 patients.
Loading from checkpoint... Loaded


In [31]:
# Load the raw spindle marks of one subject (1-based id into NAMES) and the
# length of channel 0 of its recording.
subject_id = 1

fs = dataset.fs
path_dataset = dataset.dataset_dir
subject_name = NAMES[subject_id - 1]
path_stamps = os.path.join(path_dataset, 'label/spindle', 'SS_%s.txt' % subject_name)
path_signals = os.path.join(path_dataset, 'register', '%s.rec' % subject_name)

with pyedflib.EdfReader(path_signals) as edf_file:
    signal = edf_file.readSignal(0)
    signal_len = signal.shape[0]
# ndmin=2 keeps the (n_marks, n_columns) layout even if the file holds a single
# mark; otherwise loadtxt returns a 1-D row and data[:, -1] fails.
data = np.loadtxt(path_stamps, ndmin=2)
# Keep only marks whose last column equals 1 (presumably the channel id — TODO confirm).
for_this_channel = (data[:, -1] == 1)
data = data[for_this_channel]
data = np.round(data).astype(np.int32)
print(signal_len)
print(data)

5808200
[[  86721   86908     -50     -50       0       1]
 [ 133096  133204     -50     -50       0       1]
 [ 166983  167084     -50     -50       0       1]
 ...
 [5738701 5738801     -50     -50       0       1]
 [5765348 5765608     -50     -50       0       1]
 [5798352 5798524     -50     -50       0       1]]


In [32]:
# Remove zero-duration marks and ensure that start time < end time:
# marks annotated backwards get their two endpoints swapped.
new_data = []
for row in data:
    t_start, t_end = row[0], row[1]
    if t_start > t_end:
        print('End time < Start time fixed')
        # Work on a copy so the swap does not write back into `data` through a view.
        fixed_row = row.copy()
        fixed_row[0], fixed_row[1] = t_end, t_start
        new_data.append(fixed_row)
    elif t_start < t_end:
        new_data.append(row)
    else:  # Zero duration (equality)
        print('Zero duration stamp found and removed')
# np.stack raises on an empty list; fall back to an empty (0, n_columns) array.
data = np.stack(new_data, axis=0) if new_data else data[:0].copy()
print(data)

End time < Start time fixed
End time < Start time fixed
End time < Start time fixed
End time < Start time fixed
End time < Start time fixed
End time < Start time fixed
Zero duration stamp found and removed
End time < Start time fixed
End time < Start time fixed
End time < Start time fixed
Zero duration stamp found and removed
End time < Start time fixed
Zero duration stamp found and removed
End time < Start time fixed
End time < Start time fixed
End time < Start time fixed
[[  86721   86908     -50     -50       0       1]
 [ 133096  133204     -50     -50       0       1]
 [ 166983  167084     -50     -50       0       1]
 ...
 [5738701 5738801     -50     -50       0       1]
 [5765348 5765608     -50     -50       0       1]
 [5798352 5798524     -50     -50       0       1]]


In [58]:
# Remove stamps outside the signal: valid sample indices are 0 .. signal_len - 1,
# so both a negative start and an end >= signal_len put a stamp out of bounds.
inside_signal = (data[:, 0] >= 0) & (data[:, 1] < signal_len)
n_outside = int(np.count_nonzero(~inside_signal))
if n_outside > 0:
    print('%d stamp(s) outside boundaries found and removed' % n_outside)
# Boolean indexing also copes with every stamp being removed
# (np.stack on an empty list would raise).
data = data[inside_signal]
print(data)

[[  86721   86908     -50     -50       0       1]
 [ 133096  133204     -50     -50       0       1]
 [ 166983  167084     -50     -50       0       1]
 ...
 [5738701 5738801     -50     -50       0       1]
 [5765348 5765608     -50     -50       0       1]
 [5798352 5798524     -50     -50       0       1]]


In [59]:
# Keep only the [start, end] sample columns and split the marks by the value in column 4.
raw_stamps = data[:, [0, 1]]
print(raw_stamps)

print(raw_stamps.shape)
valid = data[:, 4]
raw_stamps_0 = raw_stamps[valid == 0]
raw_stamps_1 = raw_stamps[valid == 1]
raw_stamps_2 = raw_stamps[valid == 2]

# Duration range per validity group, in seconds.
validity_groups = (('Valid 0', raw_stamps_0), ('Valid 1', raw_stamps_1), ('Valid 2', raw_stamps_2))
for validity_label, stamps_group in validity_groups:
    durations_s = (stamps_group[:, 1] - stamps_group[:, 0]) / fs
    if durations_s.size == 0:
        # .min()/.max() raise ValueError on an empty array.
        print(validity_label, stamps_group.shape, 'No stamps')
        continue
    print(validity_label, stamps_group.shape, 'Min dur [s]', durations_s.min(), 'Max dur [s]', durations_s.max())

[[  86721   86908]
 [ 133096  133204]
 [ 166983  167084]
 ...
 [5738701 5738801]
 [5765348 5765608]
 [5798352 5798524]]
(5854, 2)
Valid 0 (1302, 2) Min dur [s] 0.03 Max dur [s] 4.955
Valid 1 (2659, 2) Min dur [s] 0.3 Max dur [s] 8.45
Valid 2 (1893, 2) Min dur [s] 0.025 Max dur [s] 13.395


In [105]:
# look for intersections between stamps of the same validity
def overlap_matrix(events, detections):
    """Binary intersection matrix between two sets of inclusive intervals.

    Args:
        events: array of shape (n_events, >=2); columns 0 and 1 are the
            inclusive [start, end] of each interval.
        detections: array of shape (n_detections, >=2), same layout.

    Returns:
        float64 array of shape (n_events, n_detections) whose entry [i, j] is
        1.0 when event i and detection j share at least one sample (touching
        endpoints count), and 0.0 otherwise.
    """
    # Broadcast events down the rows and detections across the columns.
    event_starts = events[:, 0][:, np.newaxis]
    event_ends = events[:, 1][:, np.newaxis]
    det_starts = detections[:, 0][np.newaxis, :]
    det_ends = detections[:, 1][np.newaxis, :]
    intersects = (det_starts <= event_ends) & (det_ends >= event_starts)
    return intersects.astype(np.float64)


def summarize_self_overlaps(stamps, label):
    """Count, for each stamp, how many other stamps of the same group it intersects.

    Prints a header "Overlaps for <label>" followed by one
    "<k> overlaps: <count> times" line per distinct overlap count.

    Returns:
        (n_overlaps, values, counts): per-stamp overlap counts and their
        np.unique histogram (values sorted ascending).
    """
    self_overlaps = overlap_matrix(stamps, stamps)
    # Every stamp intersects itself (the diagonal), so discount one per row.
    n_overlaps = self_overlaps.sum(axis=1) - 1
    values, counts = np.unique(n_overlaps, return_counts=True)
    print('\nOverlaps for %s' % label)
    for value, count in zip(values, counts):
        print('%d overlaps: %d times' % (value, count))
    return n_overlaps, values, counts


def max_or_zero(values):
    """Largest entry of `values`, or 0 when it is empty (ndarray.max() raises on empty input)."""
    return values.max() if values.size else 0


n_overlaps_0, values_0, counts_0 = summarize_self_overlaps(raw_stamps_0, 'Valid 0')
n_overlaps_1, values_1, counts_1 = summarize_self_overlaps(raw_stamps_1, 'Valid 1')
n_overlaps_2, values_2, counts_2 = summarize_self_overlaps(raw_stamps_2, 'Valid 2')

max_overlaps_0 = max_or_zero(values_0)
max_overlaps_1 = max_or_zero(values_1)
max_overlaps_2 = max_or_zero(values_2)
# Largest overlap count over all groups; plot_page uses it to set box transparency.
max_overlaps = np.max([max_overlaps_0, max_overlaps_1, max_overlaps_2])


Overlaps for Valid 0
0 overlaps: 817 times
1 overlaps: 366 times
2 overlaps: 92 times
3 overlaps: 19 times
4 overlaps: 3 times
5 overlaps: 5 times

Overlaps for Valid 1
0 overlaps: 2654 times
1 overlaps: 4 times
2 overlaps: 1 times

Overlaps for Valid 2
0 overlaps: 1050 times
1 overlaps: 490 times
2 overlaps: 225 times
3 overlaps: 82 times
4 overlaps: 37 times
5 overlaps: 6 times
7 overlaps: 2 times
10 overlaps: 1 times


In [135]:
# Show a certain page: fetch, for the chosen subject, the data that plot_page reads.

# Page numbers of the subject; plot_page turns them into sample offsets via dataset.page_size.
this_pages = dataset.get_subject_pages(subject_id=subject_id)
# Subject signal; normalize_clip=False presumably skips normalization/clipping — TODO confirm.
this_signal = dataset.get_subject_signal(subject_id=subject_id, normalize_clip=False)
# Stamps from the dataset API; plot_page draws them as shaded boxes over the signal panel.
this_stamps = dataset.get_subject_stamps(subject_id=subject_id)


def filter_stamps(stamps, single_page, page_size):
    """Return the stamps that have at least one sample inside page `single_page`.

    Page k covers samples [k * page_size, (k + 1) * page_size). A stamp
    [start, end] is kept when start_page <= single_page <= end_page. This also
    keeps stamps that begin before the page and end after it, which a
    "start or end inside the page" test would miss. Assumes start <= end,
    which the cleaning cells above enforce for the raw marks.

    Args:
        stamps: array of shape (n_stamps, >=2); columns 0 and 1 are the
            [start, end] sample indices.
        single_page: page number to select.
        page_size: number of samples per page.

    Returns:
        List of the matching rows (1-D arrays), in input order.
    """
    start_pages = stamps[:, 0] // page_size
    end_pages = stamps[:, 1] // page_size
    touches_page = (start_pages <= single_page) & (end_pages >= single_page)
    return list(stamps[touches_page])


def _plot_stamp_track(ax, stamps, label, time_axis, alpha, delta_y=0.1):
    """Draw one annotation row (boxes only, no ticks) beneath the signal panel.

    Args:
        ax: matplotlib axes to draw on.
        stamps: iterable of [start, end] sample indices (as returned by filter_stamps).
        label: y-axis label of the track.
        time_axis: time values [s] of the displayed page; fixes the x-limits.
        alpha: face transparency of each box, so overlapping boxes look darker.
        delta_y: vertical margin around the unit-height boxes.
    """
    for this_stamp in stamps:
        ax.fill_between(
            this_stamp / fs, 1 + delta_y, -delta_y,
            facecolor=CUSTOM_COLOR['blue'], alpha=alpha,
            edgecolor='k', linewidth=1.5,
        )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_ylim([-delta_y, 1 + delta_y])
    ax.set_xlim([time_axis[0], time_axis[-1]])
    ax.set_ylabel(label, fontsize=10)


def plot_page(page_idx):
    """Plot one page of the subject's signal with its stamps and the raw marks per validity.

    Top panel: signal with the dataset stamps shaded. Below it, one track each
    for the raw marks with validity 2, 1 and 0.

    Args:
        page_idx: 1-based position within this_pages (the slider is 1-based).
    """
    microvolt_per_second = 100  # Aspect ratio

    fig = plt.figure(figsize=(10, 6), dpi=100)
    gs = gridspec.GridSpec(4, 1, height_ratios=[4, 1, 1, 1])

    page_chosen = this_pages[page_idx - 1]
    page_start = page_chosen * dataset.page_size
    page_end = page_start + dataset.page_size

    segment_signal = this_signal[page_start:page_end]
    segment_stamps = filter_stamps(this_stamps, page_chosen, dataset.page_size)
    time_axis = np.arange(page_start, page_end) / fs

    # Signal panel
    y_max = 250
    ax = fig.add_subplot(gs[0])
    ax.plot(
        time_axis, segment_signal,
        linewidth=1, color=CUSTOM_COLOR['grey'])
    for expert_stamp in segment_stamps:
        ax.fill_between(
            expert_stamp / fs, y_max, -y_max,
            facecolor=CUSTOM_COLOR['blue'], alpha=0.3,
            edgecolor='k', linewidth=1.5,
        )
    ax.set_yticks([])
    ax.set_xlim([time_axis[0], time_axis[-1]])
    ax.set_ylim([-y_max, y_max])
    ax.set_title('Subject %d (%s INTA). Page in record: %d. (intervals of 0.5s are shown).'
                 % (subject_id, NAMES[subject_id-1], page_chosen), fontsize=10)
    # Major ticks every 5 s across the whole page; minor grid every 0.5 s.
    page_duration_s = dataset.page_size / fs
    n_major_ticks = int(page_duration_s // 5) + 1
    ax.set_xticks(time_axis[0] + 5 * np.arange(n_major_ticks))
    ax.set_xticks(np.arange(time_axis[0], time_axis[-1], 0.5), minor=True)
    # The `b=` keyword was removed in Matplotlib 3.7 (renamed `visible`);
    # passing the flag positionally works on old and new versions.
    ax.grid(True, axis='x', which='minor')
    ax.tick_params(labelsize=8.5)
    ax.set_aspect(1 / microvolt_per_second)

    # 1/max_overlaps would divide by zero when no raw marks overlap at all.
    track_alpha = 1.0 / max(float(max_overlaps), 1.0)
    tracks = (
        ('Valid 2', raw_stamps_2),
        ('Valid 1', raw_stamps_1),
        ('Valid 0', raw_stamps_0),
    )
    for gs_idx, (track_label, track_stamps) in enumerate(tracks, start=1):
        page_track_stamps = filter_stamps(track_stamps, page_chosen, dataset.page_size)
        _plot_stamp_track(
            fig.add_subplot(gs[gs_idx]), page_track_stamps, track_label,
            time_axis, track_alpha)

    plt.tight_layout()
    plt.show()

In [136]:
# Interactive page browser; the slider is 1-based, matching plot_page's page_idx.
page_slider = widgets.IntSlider(
    min=1, max=this_pages.shape[0], step=1, value=1,
    continuous_update=False)
widgets.interact(plot_page, page_idx=page_slider);

interactive(children=(IntSlider(value=1, continuous_update=False, description='page_idx', max=1450, min=1), Ou…