# Imports

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import json
import os
import sys

import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np
from matplotlib import gridspec
import matplotlib.image as mpimg

project_root = '..'
sys.path.append(project_root)

from sleeprnn.data.inta_ss import IntaSS, NAMES
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection import metrics
from sleeprnn.helpers import reader, misc, plotter
from sleeprnn.common import constants, pkeys, viz

%matplotlib inline

viz.notebook_full_width()

# Load data

In [None]:
# Channel name used later to index the signal dictionary
# (presumably the channel the spindle marks refer to -- confirm)
marked_channel = 'F4-C4'

# Dataset handle restored from its checkpoint, plus the values read from it
dataset = IntaSS(load_checkpoint=True)
fs = dataset.fs
dataset_name = dataset.dataset_name

# Choose subject

In [None]:
# Order: (from worst to best)
# 11 TAGO [x] | 269 conflict pages
# 08 CAPO [x] | 82
# 02 ALUR [...] | 314
# 06 BTOL08 | 9
# 04 BRCA | 479
# 09 CRCA | 308
# 10 ESCI | 69
# 05 BRLO | 156
# 07 BTOL09 | 22
# 03 BECA | 3
# 01 ADGU | 232


def _report_group_sizes(title, groups):
    """Print how many overlap groups exist for each group size.

    Returns the list of group sizes together with the unique sizes and
    their counts, so the caller can keep them as notebook variables.
    """
    sizes = [len(single_group) for single_group in groups]
    values, counts = np.unique(sizes, return_counts=True)
    print(title)
    for value, count in zip(values, counts):
        print('%d marks: %d times' % (value, count))
    return sizes, values, counts


# Subject to inspect (1-based; NAMES is indexed with subject_id - 1)
subject_id = 2

print('Loading S%02d' % subject_id)
subject_name = NAMES[subject_id - 1]
path_stamps = os.path.join(
    dataset.dataset_dir, 'label/spindle', 'SS_%s.txt' % subject_name)
path_signals = os.path.join(
    dataset.dataset_dir, 'register', '%s.rec' % subject_name)

# Read every channel of the recording; only the EEG channels get filtered
signal_dict = reader.read_signals_from_edf(path_signals)
signal_names = list(signal_dict.keys())
eeg_names = misc.get_inta_eeg_names(signal_names)
to_show_names = eeg_names + misc.get_inta_eog_emg_names(signal_names)
for single_name in eeg_names:
    this_signal = signal_dict[single_name]
    print('Filtering %s channel' % single_name)
    this_signal = utils.broad_filter(this_signal, fs)
    signal_dict[single_name] = this_signal

# Two sets of raw marks (reported below as V1 and V2) and their durations
raw_stamps_1, raw_stamps_2 = reader.load_raw_inta_stamps(
    path_stamps, path_signals, min_samples=20, chn_idx=0)
durations_1 = (raw_stamps_1[:, 1] - raw_stamps_1[:, 0]) / fs
durations_2 = (raw_stamps_2[:, 1] - raw_stamps_2[:, 0]) / fs
print('V1', raw_stamps_1.shape, 'Min dur [s]', durations_1.min(), 'Max dur [s]', durations_1.max())
print('V2', raw_stamps_2.shape, 'Min dur [s]', durations_2.min(), 'Max dur [s]', durations_2.max())

# Group the marks of each set that overlap marks of the same set
overlap_m = utils.get_overlap_matrix(raw_stamps_1, raw_stamps_1)
groups_overlap_1 = utils.overlapping_groups(overlap_m)
overlap_m = utils.get_overlap_matrix(raw_stamps_2, raw_stamps_2)
groups_overlap_2 = utils.overlapping_groups(overlap_m)

n_overlaps_1, values_1, counts_1 = _report_group_sizes(
    '\nSize of overlapping groups for Valid 1', groups_overlap_1)
n_overlaps_2, values_2, counts_2 = _report_group_sizes(
    '\nSize of overlapping groups for Valid 2', groups_overlap_2)
max_overlaps = max(values_1.max(), values_2.max()) - 1

# Subject 3 does not use the dataset page list: every whole page of the
# marked channel except the first and the last one is taken instead.
if subject_id == 3:
    this_pages = np.arange(1, signal_dict[marked_channel].size // dataset.page_size - 1)
else:
    this_pages = dataset.get_subject_pages(subject_id=subject_id)
print('This pages', this_pages.shape)

# Conflicts

In [None]:
# Select marks without doubt
#
# A mark is accepted automatically when its overlap group has a single mark,
# or when the group has exactly two marks whose IoU is at least
# `iou_to_accept` (the second one is then kept). Every other group is stored
# as a conflict to be solved by visual inspection.
iou_to_accept = 0.8


def _auto_accept_mark(stamps, group):
    """Return the mark of `group` that can be accepted without review.

    Args:
        stamps: array of marks, one row per mark.
        group: indices (rows of `stamps`) forming one overlap group.

    Returns:
        The accepted row of `stamps`, or None when the group needs review.
    """
    if len(group) == 1:
        return stamps[group[0], :]
    if len(group) == 2:
        first_mark = stamps[group[0], :]
        second_mark = stamps[group[1], :]
        # Nearly identical marks: keep the newer (second) one
        if metrics.get_iou(first_mark, second_mark) >= iou_to_accept:
            return second_mark
    return None


groups_in_doubt_v1_list = []
groups_in_doubt_v2_list = []
marks_without_doubt = []
overlap_between_1_and_2 = utils.get_overlap_matrix(raw_stamps_1, raw_stamps_2)

# V2 groups are judged on their own
for single_group in groups_overlap_2:
    accepted_mark = _auto_accept_mark(raw_stamps_2, single_group)
    if accepted_mark is None:
        groups_in_doubt_v2_list.append(single_group)
    else:
        marks_without_doubt.append(accepted_mark)

# V1 groups are only considered when at least one of their marks does not
# overlap any V2 mark; groups entirely overlapping V2 are neither accepted
# nor listed as conflicts here.
for single_group in groups_overlap_1:
    overlaps_per_mark = overlap_between_1_and_2[single_group, :].sum(axis=1)
    if np.all(overlaps_per_mark):
        continue
    accepted_mark = _auto_accept_mark(raw_stamps_1, single_group)
    if accepted_mark is None:
        groups_in_doubt_v1_list.append(single_group)
    else:
        marks_without_doubt.append(accepted_mark)

if marks_without_doubt:
    marks_without_doubt = np.stack(marks_without_doubt, axis=0)
    # Reorder whole rows (by start, then end). Sorting each column on its own
    # would pair starts with the wrong ends whenever two accepted marks are
    # nested or overlapping.
    row_order = np.lexsort((marks_without_doubt[:, 1], marks_without_doubt[:, 0]))
    marks_without_doubt = marks_without_doubt[row_order]
else:
    # Nothing was accepted: keep the (n_marks, 2) layout with zero rows
    marks_without_doubt = np.zeros((0, 2), dtype=raw_stamps_1.dtype)
print('Marks automatically added:', marks_without_doubt.shape)
print('Remaining conflicts:')
print('    V1: %d' % len(groups_in_doubt_v1_list))
print('    V2: %d' % len(groups_in_doubt_v2_list))

In [None]:
# Set to True to print one line per conflicting group
show_complete_conflict_detail = False


def _conflict_page_location(group_stamps):
    """Return the page position (multiple of 0.5) used to display a group.

    The centre of the group is located inside its page and the position is
    shifted by -0.5, 0 or +0.5 pages depending on where the centre falls.
    The result is clamped to 0 because a negative position would make the
    plotter index the signals with negative sample numbers.
    """
    center_group = (group_stamps.min() + group_stamps.max()) / 2
    integer_page = int(center_group / dataset.page_size)
    decimal_part = np.round(2 * (center_group % dataset.page_size) / dataset.page_size) / 2 - 0.5
    return max(integer_page + decimal_part, 0.0)


conflict_pages = []

if show_complete_conflict_detail:
    print('Conflict detail')
# V1 conflicts are processed (and printed) before V2 conflicts
conflict_sources = (
    ('V1', raw_stamps_1, groups_in_doubt_v1_list),
    ('V2', raw_stamps_2, groups_in_doubt_v2_list),
)
for source_label, source_stamps, source_groups in conflict_sources:
    for single_group in source_groups:
        group_stamps = source_stamps[single_group, :]
        page_location = _conflict_page_location(group_stamps)
        conflict_pages.append(page_location)
        if show_complete_conflict_detail:
            print('%s - Group of size %d at page %1.1f'
                  % (source_label, group_stamps.shape[0], page_location))
# Several groups may share a page position; keep each position once, sorted
conflict_pages = np.unique(conflict_pages)

print('')
print('Number of pages with conflict %d' % conflict_pages.size)

In [None]:
# Add available final versions of marks
# The revision file is looked up by name inside the local 'mark_files' folder.
string_to_search = 'Revision_SS_%s.txt' % NAMES[subject_id-1]
marks_folder = 'mark_files'
# A missing folder is treated the same as a folder without revision files
available_files = os.listdir(marks_folder) if os.path.isdir(marks_folder) else []
# os.listdir returns names in arbitrary order; sort so that the file chosen
# below does not depend on the file system.
res = sorted(f for f in available_files if string_to_search in f)
print('Files found for "%s":' % string_to_search)
print(res)
if res:
    # ndmin=2 keeps a single-row file as a (1, n_columns) array
    loaded_marks = np.loadtxt(os.path.join(marks_folder, res[0]), ndmin=2)
    if loaded_marks.size:
        # Only the first two columns are used
        this_final_marks = loaded_marks[:, [0, 1]]
    else:
        this_final_marks = np.empty((0, 2))
else:
    # No revision available: empty array with the same two-column layout
    this_final_marks = np.empty((0, 2))

# Plotter functions

In [None]:
def _draw_candidate_marks(
        ax, page_stamps, all_stamps, groups, y_top, dy, id_offset, label):
    """Draw every overlap group touched by `page_stamps` as stacked red lines.

    Args:
        ax: matplotlib axes to draw on.
        page_stamps: marks that fall in the displayed page.
        all_stamps: complete array of marks that `groups` indexes.
        groups: overlap groups, each one a sequence of rows of `all_stamps`.
        y_top: vertical position of the first mark of each group.
        dy: vertical distance between consecutive marks of a group.
        id_offset: value added to the row index to build the printed id.
        label: legend label for the first drawn mark (None for no label).

    Returns:
        Tuple (size of the largest drawn group, number of marks drawn).
    """
    group_of_stamp = {
        int(stamp_idx): group_idx
        for group_idx, group in enumerate(groups)
        for stamp_idx in group}
    # Recover the groups of the marks in the page by matching rows exactly
    shown_groups = set()
    for this_stamp in page_stamps:
        matches = np.flatnonzero(np.all(all_stamps == this_stamp, axis=1))
        for match_idx in matches:
            if int(match_idx) in group_of_stamp:
                shown_groups.add(group_of_stamp[int(match_idx)])
    max_size_shown = 0
    n_drawn = 0
    for group_idx in sorted(shown_groups):
        members = groups[group_idx]
        max_size_shown = max(max_size_shown, len(members))
        # The whole group is drawn, even the marks outside `page_stamps`
        for j, stamp_idx in enumerate(members):
            single_stamp = all_stamps[stamp_idx]
            y_level = y_top - j * dy
            ax.plot(
                single_stamp / fs, [y_level, y_level],
                color=viz.PALETTE['red'], linewidth=1.5, label=label)
            ax.annotate(
                str(id_offset + int(stamp_idx)),
                (single_stamp[1] / fs + 0.05, y_level - 10), fontsize=7)
            label = None  # Only the first line gets a legend entry
            n_drawn += 1
    return max_size_shown, n_drawn


def plot_page_conflict(conflict_idx, ax, show_final=False):
    """Plot one conflicting page: candidate marks, signals and accepted marks.

    From top to bottom the axes show the V1 candidate groups, the V2
    candidate groups and one trace per name in `to_show_names`. Marks accepted
    automatically are drawn in green below the marked channel and, optionally,
    the marks of the expert revision file are shaded over it.

    The function reads the notebook globals `conflict_pages`, `dataset`, `fs`,
    `marks_without_doubt`, `raw_stamps_1`, `raw_stamps_2`, `groups_overlap_1`,
    `groups_overlap_2`, `this_final_marks`, `signal_dict`, `to_show_names`,
    `marked_channel`, `subject_id` and `NAMES`.

    Args:
        conflict_idx: 1-based position of the page inside `conflict_pages`.
        ax: matplotlib axes to draw on (it is not cleared here).
        show_final: whether to shade the marks of `this_final_marks`.

    Returns:
        The axes that were passed in.
    """
    microvolt_per_second = 200  # Aspect ratio
    dy_valid = 40
    y_max = 150
    y_sep = 300
    page_chosen = conflict_pages[conflict_idx - 1]
    page_start = page_chosen * dataset.page_size
    page_end = page_start + dataset.page_size
    segment_stamps = utils.filter_stamps(marks_without_doubt, page_start, page_end)
    segment_stamps_valid_1 = utils.filter_stamps(raw_stamps_1, page_start, page_end)
    segment_stamps_valid_2 = utils.filter_stamps(raw_stamps_2, page_start, page_end)
    # Without a revision file there is nothing to filter
    if show_final and np.size(this_final_marks) > 0:
        segment_stamps_final = utils.filter_stamps(this_final_marks, page_start, page_end)
    else:
        segment_stamps_final = []
    time_axis = np.arange(page_start, page_end) / fs
    x_ticks = np.arange(time_axis[0], time_axis[-1] + 1, 1)
    # Show valid 1 (ids are printed as 10000 + row index)
    valid_1_start = -100
    max_size_1, n_drawn_1 = _draw_candidate_marks(
        ax, segment_stamps_valid_1, raw_stamps_1, groups_overlap_1,
        valid_1_start, dy_valid, 10000, 'Candidate mark')
    valid_1_center = valid_1_start - (max_size_1 // 2) * dy_valid
    # Show valid 2 (ids are printed as 20000 + row index), below valid 1
    valid_2_start = -max_size_1 * dy_valid - 200
    max_size_2, n_drawn_2 = _draw_candidate_marks(
        ax, segment_stamps_valid_2, raw_stamps_2, groups_overlap_2,
        valid_2_start, dy_valid, 20000,
        None if n_drawn_1 else 'Candidate mark')
    valid_2_center = valid_2_start - (max_size_2 // 2) * dy_valid
    shown_valid = (n_drawn_1 + n_drawn_2) > 0
    # Signal
    start_signal_plot = valid_2_start - max_size_2 * dy_valid - y_sep
    segment_start = int(page_chosen * dataset.page_duration * fs)
    segment_end = int(segment_start + dataset.page_duration * fs)
    segment_time_axis = np.arange(segment_start, segment_end) / fs
    # Fallback so the marks can still be drawn if the marked channel is not
    # among the displayed names
    stamp_center = start_signal_plot
    for k, name in enumerate(to_show_names):
        baseline = start_signal_plot - y_sep * k
        if name == marked_channel:
            stamp_center = baseline
        segment_signal = signal_dict[name][segment_start:segment_end]
        ax.plot(
            segment_time_axis, baseline + segment_signal,
            linewidth=0.7, color=viz.PALETTE['grey'])
    plotter.add_scalebar(
        ax, matchx=False, matchy=False, hidex=False, hidey=False, sizex=1, sizey=100,
        labelx='1 s', labely='100 uV', loc=1)
    expert_shown = False
    for expert_stamp in segment_stamps:
        label = None if expert_shown else 'Accepted mark (automatic)'
        ax.plot(
            expert_stamp / fs, [stamp_center - 50, stamp_center - 50],
            color=viz.PALETTE['green'], linewidth=2, label=label)
        expert_shown = True
    expert_manual_shown = False
    for final_stamp in segment_stamps_final:
        label = None if expert_manual_shown else 'Expert Final Version'
        ax.fill_between(
            final_stamp / fs, 100 + stamp_center, -100 + stamp_center,
            facecolor=viz.PALETTE['grey'], alpha=0.4, label=label, edgecolor='k')
        expert_manual_shown = True
    ticks_valid = [valid_1_center, valid_2_center]
    ticks_signal = [start_signal_plot - y_sep * k for k in range(len(to_show_names))]
    ticklabels_valid = ['V1', 'V2']
    total_ticks = ticks_valid + ticks_signal
    # The last two displayed names are relabelled as 'MOR' and 'EMG'
    total_ticklabels = ticklabels_valid + to_show_names[:-2] + ['MOR', 'EMG']
    ax.set_yticks(total_ticks)
    ax.set_yticklabels(total_ticklabels)
    ax.set_xlim([time_axis[0], time_axis[-1]])
    ax.set_ylim([-y_max - 30 + ticks_signal[-1], 100])
    ax.set_title('Subject %d (%s INTA). Page in record: %1.1f. (intervals of 0.5s are shown as a vertical grid).' 
                 % (subject_id, NAMES[subject_id-1], page_chosen), fontsize=10, y=1.05)
    ax.set_xticks(x_ticks)
    ax.set_xticks(np.arange(time_axis[0], time_axis[-1], 0.5), minor=True)
    # The flag is passed positionally: its keyword was renamed from `b` to
    # `visible` in Matplotlib 3.5, so neither name works on every version.
    ax.grid(True, axis='x', which='minor')
    ax.tick_params(labelsize=7.5, labelbottom=True, labeltop=True, bottom=True, top=True)
    ax.set_aspect(1 / microvolt_per_second)
    ax.set_xlabel('Time [s]', fontsize=8)
    if expert_shown or shown_valid:
        lg = ax.legend(loc='lower left', fontsize=8)
        # `legendHandles` was renamed to `legend_handles` in Matplotlib 3.7
        legend_handles = getattr(lg, 'legend_handles', None)
        if legend_handles is None:
            legend_handles = lg.legendHandles
        for lh in legend_handles:
            lh.set_alpha(1.0)
    plt.tight_layout()
    return ax

# Visual validation

In [None]:
# First conflict (1-based) to render; earlier ones are skipped
start_from_conflict = 127

# One PDF per conflicting page is written inside a per-subject folder
folder_name = '%s_conflicts' % NAMES[subject_id - 1]
os.makedirs(folder_name, exist_ok=True)
n_conflicts = conflict_pages.size
print('Total conflicting pages: %d' % n_conflicts)

# A single figure is reused: the axes are cleared before every page
fig, ax = plt.subplots(figsize=(12, 1 + len(to_show_names)), dpi=180)
for conflict_id in range(start_from_conflict, n_conflicts + 1):
    fname = os.path.join(folder_name, 'conflict_%03d.pdf' % conflict_id)
    ax.clear()
    ax = plot_page_conflict(conflict_id, ax)
    fig.savefig(fname, dpi=200, bbox_inches="tight", pad_inches=0.02)
plt.close('all')

# Verify Correction Transcription

In [None]:
# Range of conflicts (1-based, inclusive) to render together with the expert
# final marks. Set optional_end_conflict to None to go up to the last one.
start_from_conflict = 127
optional_end_conflict = 127 + 7

# Computed before it is used, so this cell does not depend on a value left
# behind by another cell.
n_conflicts = conflict_pages.size
if optional_end_conflict is None:
    optional_end_conflict = n_conflicts
# Never go past the last conflict page (the plotter would fail to index it)
last_conflict = min(optional_end_conflict, n_conflicts)

folder_name = '%s_conflicts_final' % NAMES[subject_id - 1]
os.makedirs(folder_name, exist_ok=True)
print('Total conflicting pages: %d' % n_conflicts)
# A single figure is reused: the axes are cleared before every page
fig, ax = plt.subplots(1, 1, figsize=(12, 1+len(to_show_names)), dpi=180)
for conflict_id in range(start_from_conflict, last_conflict + 1):
    fname = os.path.join(folder_name, 'conflict_%03d.pdf' % conflict_id)
    ax.clear()
    ax = plot_page_conflict(conflict_id, ax, show_final=True)
    plt.savefig(fname, dpi=200, bbox_inches="tight", pad_inches=0.02)
plt.close('all')