# Imports

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os
from pprint import pprint
import sys

import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib import colors, gridspec
import numpy as np
from scipy.stats import gaussian_kde
from scipy.interpolate import interp1d
from matplotlib.lines import Line2D

# '..' is assumed to be the project root (the notebook presumably runs from a
# subfolder of it); appending it to sys.path makes the `sleeprnn` imports below resolve.
project_root = '..'
sys.path.append(project_root)

from sleeprnn.common import constants, pkeys, viz
from sleeprnn.common.optimal_thresholds import OPTIMAL_THR_FOR_CKPT_DICT
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection.feeder_dataset import FeederDataset
from sleeprnn.detection.postprocessor import PostProcessor
from sleeprnn.detection import metrics
from sleeprnn.helpers import reader, plotter, printer, misc, performer

# Folders for stored results and comparison data, relative to project_root.
RESULTS_PATH = os.path.join(project_root, 'results')
COMPARISON_PATH = os.path.join(project_root, 'resources', 'comparison_data')

# Render figures inline; the viz helper presumably widens the notebook to the
# full browser width (inferred from its name — confirm in sleeprnn.common.viz).
%matplotlib inline
viz.notebook_full_width()

# Load data

In [None]:
# Print the checkpoints that have stored optimal thresholds, restricted by date.
# NOTE(review): the pair looks like [first_date, last_date] in YYYYMMDD form with
# None meaning "no bound" — confirm against printer.print_available_ckpt.
first_date, last_date = 20200606, None
filter_dates = [first_date, last_date]
printer.print_available_ckpt(OPTIMAL_THR_FOR_CKPT_DICT, filter_dates)

In [None]:
dataset_name = constants.MASS_SS_NAME
fs = 200
which_expert = 1
task_mode = constants.N2_RECORD
seed_id_list = list(range(2))
set_list = [constants.VAL_SUBSET, constants.TRAIN_SUBSET]

# Specify what to load: (checkpoint folder, plot label).
# The labels use matplotlib mathtext, so the backslash in \lambda is escaped:
# a bare '\l' in a normal string literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning on Python >= 3.12). '\n' is a real newline.
comparison_runs_list = [
    ('20191227_bsf_10runs_e1_n2_train_mass_ss/v11', 'Regular\nCrossEntropy'),
    ('20200606_var_reg_first_grid_n2_train_mass_ss/v11_reg_-6.0', 'Weights\n+ $\\lambda$ = 1e-6'),
    ('20200606_var_reg_first_grid_n2_train_mass_ss/v11_reg_-1.0', 'Weights\n+ $\\lambda$ = 0.1'),
    ('20200606_var_reg_first_grid_n2_train_mass_ss/v11_reg_0.0', 'Weights\n+ $\\lambda$ = 1'),
    ('20200606_var_reg_first_grid_n2_train_mass_ss/v11_reg_1.0', 'Weights\n+ $\\lambda$ = 10'),
]
# Keep only the runs whose folder name contains the selected dataset name.
comparison_runs_list = [
    (t_folder, t_label) for (t_folder, t_label) in comparison_runs_list if dataset_name in t_folder
]
if not comparison_runs_list:
    # Fail here with a clear message instead of with empty figures/indexing errors later.
    raise ValueError('No comparison run matches dataset %s' % dataset_name)
ckpt_folder_list = [t_folder for (t_folder, _) in comparison_runs_list]
ckpt_folder_dict = {t_label: t_folder for (t_folder, t_label) in comparison_runs_list}
ckpt_label_dict = {t_folder: t_label for (t_folder, t_label) in comparison_runs_list}

# Load data
n_cases = len(comparison_runs_list)
dataset = reader.load_dataset(dataset_name, params={pkeys.FS: fs})
ids_dict = {
    constants.ALL_TRAIN_SUBSET: dataset.train_ids,
    constants.TEST_SUBSET: dataset.test_ids}
# Per-seed splits; later cells index them as ids_dict[seed][subset_name].
ids_dict.update(misc.get_splits_dict(dataset, seed_id_list))
predictions_dict = {}
for ckpt_folder in ckpt_folder_list:
    predictions_dict[ckpt_folder] = reader.read_prediction_with_seeds(
        ckpt_folder, dataset_name, task_mode, seed_id_list, set_list=set_list, parent_dataset=dataset)
# useful for viz
iou_hist_bins = np.linspace(0, 1, 21)
iou_curve_axis = misc.custom_linspace(0.05, 0.95, 0.05)
# Short identifier such as "<DB>-<SUBSET>-E<expert>-<TASK>" built from the dataset name parts.
dataset_parts = dataset_name.split('_')
result_id = '%s-%s-E%d-%s' % (
    dataset_parts[0].upper(),
    dataset_parts[1].upper(),
    which_expert,
    task_mode.upper())
expert_data_dict = reader.load_ss_expert_performance()
exp_keys = list(expert_data_dict.keys())
print('\nAvailable data:')
pprint(exp_keys)

# Prediction Variability

In [None]:
seeds_to_show = [0, 1]
set_name = constants.VAL_SUBSET
global_thr = 0.35
model_colors = [
    viz.PALETTE['red'],
    viz.PALETTE['blue'],
    viz.PALETTE['green'],
    viz.PALETTE['purple'],
    viz.PALETTE['cyan']
]
# The probability trace is indexed at 1/8 of the signal rate (detection stamps,
# given in signal samples, are divided by 8 to index it), and N2 pages are
# extracted with 4000 // 8 probability values each.
# NOTE(review): presumably 4000 signal samples = one 20 s page at 200 Hz — confirm.
proba_downsampling = 8
page_size_samples = 4000


def _probas_inside_detections(s_probas, s_dets, downsampling):
    """Return the probability values inside each detection, minus the first and last value.

    ``s_dets`` holds [start, end] stamps in signal samples; they are rounded to the
    probability resolution given by ``downsampling``. Segments left empty after the
    trim are dropped, so their mean/std are never computed on an empty array
    (which used to produce NaN plus a RuntimeWarning).
    """
    stamps_down = np.round(s_dets / downsampling).astype(np.int32)
    segments = (s_probas[t0:tf][1:-1] for (t0, tf) in stamps_down)
    return [segment for segment in segments if segment.size > 0]


def _concat_or_empty(arrays):
    """np.concatenate that returns an empty array instead of raising on an empty list."""
    return np.concatenate(arrays) if len(arrays) > 0 else np.zeros(0)


n2_diff = {}
det_diff = {}
det_mean = {}
det_std = {}
for ckpt_folder in ckpt_folder_list:
    print(ckpt_label_dict[ckpt_folder], OPTIMAL_THR_FOR_CKPT_DICT[ckpt_folder])
    n2_list, diff_list, mean_list, std_list = [], [], [], []
    for k in seeds_to_show:
        val_ids = ids_dict[k][set_name]
        t_preds = predictions_dict[ckpt_folder][k][set_name]
        # NOTE: this sets the threshold on the shared object stored in
        # predictions_dict, so later uses of it see global_thr too (presumably
        # get_stamps() depends on it — the optimal threshold printed above is not used).
        t_preds.set_probability_threshold(global_thr)
        t_dets = t_preds.get_stamps()
        t_probas = t_preds.get_probabilities()
        t_pages = t_preds.get_pages(pages_subset=constants.N2_RECORD)
        for i in range(len(val_ids)):
            s_probas = t_probas[i]
            s_probas_n2 = utils.extract_pages(
                s_probas, t_pages[i], page_size_samples // proba_downsampling)
            # |Δp| between consecutive probability values inside each N2 page.
            n2_list.append(np.abs(np.diff(s_probas_n2)).flatten())

            segments = _probas_inside_detections(s_probas, t_dets[i], proba_downsampling)
            diff_list.extend(np.abs(np.diff(segment)) for segment in segments)
            mean_list.extend(np.mean(segment) for segment in segments)
            std_list.extend(np.std(segment) for segment in segments)

    # Empty-safe: a subject/model with no detections no longer crashes np.concatenate.
    n2_diff[ckpt_folder] = _concat_or_empty(n2_list)
    det_diff[ckpt_folder] = _concat_or_empty(diff_list)
    det_mean[ckpt_folder] = np.array(mean_list, dtype=np.float64)
    det_std[ckpt_folder] = np.array(std_list, dtype=np.float64)

In [None]:
n_cases = len(ckpt_folder_list)
bins = np.linspace(0, 1, 41)

# One column per statistic:
# (values per checkpoint, title, x-label, x-limits, log-scaled counts)
hist_columns = [
    (n2_diff, 'Overall variability', r'$\Delta p$', [0, 0.6], True),
    (det_diff, 'Detection variability', r'$\Delta p$', [0, 0.6], True),
    (det_mean, 'Detection mean', r'$\mu(p)$', [global_thr, 1], False),
    (det_std, 'Detection std', r'$\sigma(p)$', [0, 0.2], False),
]
mean_column = 2  # this column also gets a dashed line at each model's average value
legend_column = len(hist_columns) - 1  # legend placed outside, right of the last column

# squeeze=False keeps `axes` 2-D even when only one model is compared.
fig, axes = plt.subplots(
    n_cases, len(hist_columns), figsize=(8, 1 * n_cases), dpi=200, squeeze=False)

for col, (values_dict, col_title, x_label, x_lim, log_counts) in enumerate(hist_columns):
    axes[0, col].set_title(col_title, fontsize=9)
    max_count = 0
    for i, ckpt_folder in enumerate(ckpt_folder_list):
        ax = axes[i, col]
        values = np.asarray(values_dict[ckpt_folder])
        counts, _, _ = ax.hist(
            values, bins=bins, label=ckpt_label_dict[ckpt_folder],
            facecolor=model_colors[i % len(model_colors)])
        max_count = max(max_count, np.max(counts))
        if log_counts:
            ax.set_yscale('log')
        ax.set_xlim(x_lim)
        ax.tick_params(labelsize=7)
        if col == mean_column:
            finite_values = values[np.isfinite(values)]
            if finite_values.size > 0:
                # axvline spans the whole axis height whatever the final y-limits are
                # (a fixed [0, 1000] segment falls short when counts exceed 1000).
                ax.axvline(np.mean(finite_values), linestyle='--', color='k')
        if col == legend_column:
            ax.legend(loc='upper left', bbox_to_anchor=(1.01, 1), frameon=False, handlelength=1, fontsize=8)
        if i < n_cases - 1:
            ax.set_xticks([])
    # Same y-range for every model in a column so the histograms are comparable.
    y_lim = [0.5, 10 * max_count] if log_counts else [0, max_count + 10]
    for ax in axes[:, col]:
        ax.set_ylim(y_lim)
    axes[-1, col].set_xlabel(x_label, fontsize=9)

plt.tight_layout()
plt.show()

# Penalization functions

In [None]:
# Candidate penalties for a probability change Δp in [0, 1]:
#   C1 = Δp^2,  C2 = 1 - 4 (Δp - 0.5)^2,  C3 = min(2 Δp, 2 (1 - Δp)).
delta_p = np.linspace(0, 1)
cost_1 = delta_p ** 2
cost_2 = 1 - 4 * ((delta_p - 0.5) ** 2)
cost_3 = np.minimum(2 * delta_p, 2 * (1 - delta_p))

penalty_panels = [
    (r'$C_1(\Delta p)$', cost_1),
    (r'$C_2(\Delta p)$', cost_2),
    (r'$C_3(\Delta p)$', cost_3),
]
fig, axes = plt.subplots(1, len(penalty_panels), figsize=(6, 2), dpi=200, sharex=True, sharey=True)
for ax, (panel_title, panel_cost) in zip(axes, penalty_panels):
    ax.set_title(panel_title, fontsize=9)
    ax.plot(delta_p, panel_cost)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, None])
    ax.tick_params(labelsize=7)
    ax.set_xlabel(r'$\Delta p$', fontsize=9)

plt.tight_layout()
plt.show()

# Visualization on samples

In [None]:
def _var_viz_draw_strip(ax, x_start, x_end, bottom, top, ref_level, label, width_det):
    """Draw the frame of one horizontal strip below the signal.

    Draws a grey reference line at ``ref_level``, grey border lines at ``bottom``
    and ``top`` spanning [x_start, x_end] (seconds), and ``label`` in the strip's
    top-left corner.
    """
    frame_color = viz.GREY_COLORS[8]
    ax.plot([x_start, x_end], [ref_level, ref_level], linewidth=1.1, color=frame_color, zorder=30)
    for border in (bottom, top):
        ax.plot([x_start, x_end], [border, border], linewidth=1.1, color=frame_color, zorder=5)
    ax.annotate(
        label,
        (x_start + 0.1, top - 0.1 * width_det),
        xycoords='data', fontsize=8,
        horizontalalignment='left', verticalalignment='top')


def var_viz_signal_snapshot(
    t_central_sample,
    t_signal,
    t_events_1,
    t_events_2,
    t_proba,
    t_opt_thr,
    ax,
    title='',
    show_stamp_edge=True,
    window_seconds=10,
    fs=200,
    max_voltage=100,
    ratio_signal_models=0.15,
    lag_range=(1, 2, 4, 6),
    proba_downsampling=8,
):
    """Draw a signal window with expert marks, the model probability and lagged |Δp| strips.

    The top part shows the signal clipped to ±max_voltage. Expert 1 stamps are
    shaded above zero and expert 2 stamps below zero. Below the signal, one strip
    shows the model probability with ``t_opt_thr`` as reference line. Then one
    strip per lag shows |p[t + lag] - p[t]| with 0.5 as reference line.

    Args:
        t_central_sample: sample index (at ``fs``) the window is centred on. The
            window is shifted when needed so that it stays inside ``t_signal``.
        t_signal: 1-D signal sampled at ``fs``.
        t_events_1, t_events_2: arrays of [start, end] stamps in samples, one per
            expert, or None to draw no stamps for that expert.
        t_proba: model probability, assumed to have one value per
            ``proba_downsampling`` signal samples — TODO confirm for other models.
        t_opt_thr: threshold drawn as reference line in the probability strip.
        ax: matplotlib Axes to draw on.
        title: axes title.
        show_stamp_edge: draw black edges around expert stamps.
        window_seconds: window length in seconds.
        fs: signal sampling rate in Hz.
        max_voltage: clipping level of the signal; also sets the y-ticks.
        ratio_signal_models: strip height as a fraction of the 2 * max_voltage signal range.
        lag_range: positive lags, in probability samples, for the |Δp| strips.
        proba_downsampling: number of signal samples per probability value.

    Returns:
        (ax, legend). The legend is returned so callers can pass it to
        ``savefig(bbox_extra_artists=...)``.

    Raises:
        ValueError: if a lag in ``lag_range`` is smaller than 1.
    """
    if any(j_lag < 1 for j_lag in lag_range):
        raise ValueError('lag_range must contain positive lags, got %r' % (lag_range,))

    width_det = ratio_signal_models * 2 * max_voltage  # height of a strip's plotting area
    gap = 0.1 * width_det  # padding above and below a strip's plotting area
    expert_1_color = viz.PALETTE['blue']
    expert_2_color = viz.PALETTE['purple']

    # Window bounds in samples. The start is clamped so that windows around events
    # close to either end of the record stay inside the signal. A negative start
    # used to wrap around and break the x/y lengths. The start is then aligned to
    # the probability grid, so each probability value is drawn at the centre of
    # the samples it covers.
    window_samples = int(window_seconds * fs)
    start_sample = int(t_central_sample - fs * window_seconds / 2)
    start_sample = min(max(start_sample, 0), max(t_signal.size - window_samples, 0))
    start_sample -= start_sample % proba_downsampling
    end_sample = start_sample + window_samples
    x_start, x_end = start_sample / fs, end_sample / fs

    # Signal.
    signal_segment = t_signal[start_sample:end_sample]
    signal_time = (start_sample + np.arange(signal_segment.size)) / fs
    ax.plot(
        signal_time,
        np.clip(signal_segment, -max_voltage, max_voltage),
        linewidth=0.9, color=viz.PALETTE['dark'], zorder=20)

    # Expert stamps: expert 1 above zero, expert 2 below zero.
    edge_kwargs = {'edgecolor': 'k'} if show_stamp_edge else {}
    expert_layers = [
        (t_events_1, expert_1_color, 0, 0.5 * max_voltage, 'E1'),
        (t_events_2, expert_2_color, -0.5 * max_voltage, 0, 'E2'),
    ]
    for events, color, y_low, y_high, label in expert_layers:
        # Zero-height dummy patch (in seconds, like the x-axis) so that each expert
        # always gets a legend entry.
        ax.fill_between([x_start, x_start], 0, 0, facecolor=color, alpha=0.5, zorder=10, label=label)
        if events is None:
            continue
        for s_stamp in utils.filter_stamps(events, start_sample, end_sample):
            ax.fill_between(
                s_stamp / fs, y_low, y_high, facecolor=color, alpha=0.5, zorder=10, **edge_kwargs)

    # Model probability strip, directly below the signal.
    top = -max_voltage
    bottom = top - width_det - 2 * gap
    segment_proba = np.asarray(
        t_proba[start_sample // proba_downsampling:end_sample // proba_downsampling])
    proba_time = (
        start_sample + proba_downsampling // 2
        + proba_downsampling * np.arange(segment_proba.size)) / fs
    ax.plot(
        proba_time,
        bottom + gap + segment_proba * width_det,
        linewidth=1.1, color=viz.PALETTE['red'], zorder=40)
    _var_viz_draw_strip(
        ax, x_start, x_end, bottom, top, bottom + gap + t_opt_thr * width_det, 'Model', width_det)

    # One strip per lag with |p[t + lag] - p[t]|, each value drawn roughly midway
    # between the two probability values it compares.
    for j_lag in lag_range:
        top = bottom
        bottom = top - width_det - 2 * gap
        delta_p = np.abs(segment_proba[j_lag:] - segment_proba[:-j_lag])
        skip_first = j_lag // 2
        lag_time = proba_time[skip_first:skip_first + delta_p.size]
        ax.plot(
            lag_time,
            bottom + gap + delta_p * width_det,
            linewidth=1.1, color=viz.PALETTE['cyan'], zorder=40)
        lag_in_seconds = j_lag * proba_downsampling / fs
        _var_viz_draw_strip(
            ax, x_start, x_end, bottom, top, bottom + gap + 0.5 * width_det,
            'Lag %1.2f[s]' % lag_in_seconds, width_det)

    ax.set_title(title, fontsize=10)
    ax.set_ylim([bottom, max_voltage])
    ax.set_xlim([x_start, x_end])
    # Ticks at the outer edges of the expert bands (±max_voltage / 2).
    ax.set_yticks([-0.5 * max_voltage, 0, 0.5 * max_voltage])
    ax.set_xticks(x_start + np.arange(window_seconds), minor=True)
    ax.set_xticks([t_central_sample / fs])
    ax.set_xticklabels(['%1.1f [s]' % (t_central_sample / fs)])
    lg = ax.legend(
        loc='upper left', labelspacing=1.2,
        fontsize=9, frameon=False, ncol=2)
    return ax, lg

In [None]:
chosen_seed = 1
set_name = constants.VAL_SUBSET
val_ids = ids_dict[chosen_seed][set_name]

# Expert 1 stamps are always loaded. This notebook loads a second expert only for
# MASS-SS; otherwise each subject gets a None placeholder (no stamps drawn).
events_1 = FeederDataset(dataset, val_ids, task_mode, which_expert=1).get_stamps()
has_second_expert = dataset_name == constants.MASS_SS_NAME
events_2 = (
    FeederDataset(dataset, val_ids, task_mode, which_expert=2).get_stamps()
    if has_second_expert
    else [None] * len(val_ids)
)

# Stamps of the expert used for evaluation, plus the raw (non-normalized) signals.
subset_data = FeederDataset(dataset, val_ids, task_mode, which_expert=which_expert)
events = subset_data.get_stamps()
signals = subset_data.get_signals(normalize_clip=False)

In [None]:
mod_idx = 1
n_samples = 10
dpi = 200
figsize = (7, 4)
save_figs = True
random_seed = 0  # makes the randomly chosen windows reproducible across runs
# -------------------------
ckpt_folder = ckpt_folder_list[mod_idx]
prefix_str = 'var_viz_seed%d' % chosen_seed
os.makedirs(prefix_str, exist_ok=True)
print('Database: %s, Expert: %d' % (dataset_name, which_expert))
print('Variability visualization for Seed %d' % chosen_seed)
center_rng = np.random.RandomState(random_seed)
opt_thr = OPTIMAL_THR_FOR_CKPT_DICT[ckpt_folder][chosen_seed]
fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
t_preds = predictions_dict[ckpt_folder][chosen_seed][set_name]
t_proba = t_preds.get_probabilities()
for i, single_id in enumerate(val_ids):
    s_proba = t_proba[i]
    s_events = events[i]
    s_signal = signals[i]
    print("")
    print('Subject %02d (%d annotations)' % (single_id, s_events.shape[0]))
    # Windows are centred on annotated events (midpoint of each [start, end] stamp).
    valid_central_samples = s_events.mean(axis=1).astype(np.int32)
    # choice(..., replace=False) raises if asked for more items than available,
    # so subjects with fewer than n_samples annotations use all of them.
    n_chosen = min(n_samples, valid_central_samples.size)
    valid_chosen_centers = center_rng.choice(valid_central_samples, size=n_chosen, replace=False)
    if not save_figs:
        continue
    title = 'Prediction Variability (S%02d-%s-Seed%d)' % (single_id, set_name.upper(), chosen_seed)
    for this_center in valid_chosen_centers:
        fname = '%s_s%02d_idx%d.png' % (prefix_str, single_id, this_center)
        ax.clear()
        # Draw variability
        ax, lg = var_viz_signal_snapshot(
            this_center,
            s_signal,
            events_1[i],
            events_2[i],
            s_proba,
            opt_thr,
            ax,
            title=title,
            show_stamp_edge=True,
            window_seconds=10,
            fs=fs,  # same sampling rate the dataset was loaded with
            max_voltage=100,
            ratio_signal_models=0.15
        )
        fig.savefig(
            os.path.join(prefix_str, fname),
            dpi=dpi, bbox_extra_artists=(lg,), bbox_inches="tight", pad_inches=0.01)
plt.close(fig)
print("Done")