# Imports

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os
from pprint import pprint
import sys

import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib import colors, gridspec
import numpy as np
from scipy.stats import gaussian_kde
from scipy.interpolate import interp1d
from matplotlib.lines import Line2D

# The notebook is expected to run one directory below the project root; adding
# that root to sys.path is what makes the local `sleeprnn` package importable.
project_root = '..'
sys.path.append(project_root)

# Project-local modules (these imports must stay after the sys.path tweak above).
from sleeprnn.common import constants, pkeys, viz
from sleeprnn.common.optimal_thresholds import OPTIMAL_THR_FOR_CKPT_DICT
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection.feeder_dataset import FeederDataset
from sleeprnn.detection.postprocessor import PostProcessor
from sleeprnn.detection import metrics
from sleeprnn.helpers import reader, plotter, printer, misc, performer

# Result / comparison-data directories, relative to the project root.
RESULTS_PATH = os.path.join(project_root, 'results')
COMPARISON_PATH = os.path.join(project_root, 'resources', 'comparison_data')

# Render matplotlib figures inline; viz.notebook_full_width() presumably widens
# the notebook cells to the full browser width (project helper, not shown here).
%matplotlib inline
viz.notebook_full_width()

# Load data

In [None]:
# Date filter handed to the checkpoint listing; presumably [from_date, to_date]
# with None meaning "unbounded" — TODO confirm against printer.print_available_ckpt.
filter_dates = [20200606, None]
printer.print_available_ckpt(
    OPTIMAL_THR_FOR_CKPT_DICT,
    filter_dates)

In [None]:
# ---- Configuration of what to analyse ----
dataset_name = constants.MASS_SS_NAME
fs = 200
which_expert = 1
task_mode = constants.N2_RECORD
seed_id_list = list(range(2))
set_list = [constants.VAL_SUBSET, constants.TRAIN_SUBSET]

# Runs to compare: (checkpoint folder, plot label).
# "\\lambda" is escaped on purpose: the resulting label still contains the literal
# mathtext command \lambda, but without relying on Python's invalid-escape
# fallback (which emits a DeprecationWarning / SyntaxWarning on Python 3.12+).
comparison_runs_list = [
    ('20191227_bsf_10runs_e1_n2_train_mass_ss/v11', 'Regular\nCrossEntropy'),
    ('20200606_var_reg_first_grid_n2_train_mass_ss/v11_reg_-6.0', 'Weights\n+ $\\lambda$ = 1e-6'),
    ('20200606_var_reg_first_grid_n2_train_mass_ss/v11_reg_-1.0', 'Weights\n+ $\\lambda$ = 0.1'),
    ('20200606_var_reg_first_grid_n2_train_mass_ss/v11_reg_0.0', 'Weights\n+ $\\lambda$ = 1'),
    ('20200606_var_reg_first_grid_n2_train_mass_ss/v11_reg_1.0', 'Weights\n+ $\\lambda$ = 10'),
]
# Keep only runs whose folder name mentions the selected dataset.
comparison_runs_list = [
    (t_folder, t_label) for (t_folder, t_label) in comparison_runs_list if dataset_name in t_folder
]
# Every later cell indexes per-run results; fail early with a clear message.
assert comparison_runs_list, 'No comparison runs match dataset %s' % dataset_name
ckpt_folder_list = [t_folder for (t_folder, t_label) in comparison_runs_list]
ckpt_folder_dict = {t_label: t_folder for (t_folder, t_label) in comparison_runs_list}
ckpt_label_dict = {t_folder: t_label for (t_folder, t_label) in comparison_runs_list}

# ---- Load dataset, subject splits and per-seed predictions ----
n_cases = len(comparison_runs_list)
dataset = reader.load_dataset(dataset_name, params={pkeys.FS: fs})
# Keys: the full train/test subject lists plus one split dict per seed id.
ids_dict = {
    constants.ALL_TRAIN_SUBSET: dataset.train_ids,
    constants.TEST_SUBSET: dataset.test_ids}
ids_dict.update(misc.get_splits_dict(dataset, seed_id_list))
predictions_dict = {}
for ckpt_folder in ckpt_folder_list:
    predictions_dict[ckpt_folder] = reader.read_prediction_with_seeds(
        ckpt_folder, dataset_name, task_mode, seed_id_list, set_list=set_list, parent_dataset=dataset)

# ---- Helpers for visualisation ----
iou_hist_bins = np.linspace(0, 1, 21)
iou_curve_axis = misc.custom_linspace(0.05, 0.95, 0.05)
# Identifier of the form "<PART0>-<PART1>-E<expert>-<TASK>" built from the config.
dataset_name_parts = dataset_name.split('_')
result_id = '%s-%s-E%d-%s' % (
    dataset_name_parts[0].upper(),
    dataset_name_parts[1].upper(),
    which_expert,
    task_mode.upper())
expert_data_dict = reader.load_ss_expert_performance()
exp_keys = list(expert_data_dict.keys())
print('\nAvailable data:')
pprint(exp_keys)

# Prediction Variability

In [None]:
# Per checkpoint, collect (over the chosen seeds/subset) the distribution of:
#   n2_diff  : |p[t+1] - p[t]| inside N2 pages,
#   det_diff : |p[t+1] - p[t]| inside each detection (edges trimmed),
#   det_mean : mean probability of each detection,
#   det_std  : std of the probability within each detection.
seeds_to_show = [0, 1]
set_name = constants.VAL_SUBSET
global_thr = 0.35
model_colors = [viz.PALETTE['red'], viz.PALETTE['blue'], viz.PALETTE['green'], viz.PALETTE['purple'], viz.PALETTE['cyan']]
# Detection stamps are divided by this factor before slicing the probability
# sequence, i.e. probabilities are indexed 8x more coarsely than the stamps.
# Assumes the model's output downsampling is 8 — TODO confirm against checkpoint params.
proba_down_factor = 8
# Page length in stamp samples (hard-coded as 4000 in the original analysis).
page_size = 4000
page_size_down = page_size // proba_down_factor

n2_diff = {}
det_diff = {}
det_mean = {}
det_std = {}
for ckpt_folder in ckpt_folder_list:
    print(ckpt_label_dict[ckpt_folder], OPTIMAL_THR_FOR_CKPT_DICT[ckpt_folder])
    n2_diff[ckpt_folder] = []
    det_diff[ckpt_folder] = []
    det_mean[ckpt_folder] = []
    det_std[ckpt_folder] = []
    for k in seeds_to_show:
        subset_ids = ids_dict[k][set_name]
        t_preds = predictions_dict[ckpt_folder][k][set_name]
        # NOTE: this mutates the shared prediction objects stored in predictions_dict.
        t_preds.set_probability_threshold(global_thr)
        t_dets = t_preds.get_stamps()
        t_probas = t_preds.get_probabilities()
        t_pages = t_preds.get_pages(pages_subset=constants.N2_RECORD)
        for i in range(len(subset_ids)):
            s_pages = t_pages[i]
            s_dets = t_dets[i]
            s_probas = t_probas[i]

            # Probability changes restricted to N2 pages. Presumably extract_pages
            # returns a (n_pages, page_size_down) array, so np.diff runs within pages.
            s_probas_n2 = utils.extract_pages(s_probas, s_pages, page_size_down)
            n2_diff[ckpt_folder].append(np.abs(np.diff(s_probas_n2)).flatten())

            # Probability segment under each detection, with its first and last
            # samples dropped so the rising/falling edges are excluded.
            s_dets_down = np.round(s_dets / proba_down_factor).astype(np.int32)
            s_probas_dets = [s_probas[t0:tf][1:-1] for (t0, tf) in s_dets_down]

            if s_probas_dets:
                det_diff[ckpt_folder].append(
                    np.concatenate([np.abs(np.diff(segment)) for segment in s_probas_dets]))
            else:
                # Subject without detections: np.concatenate([]) would raise ValueError.
                det_diff[ckpt_folder].append(np.array([]))
            det_mean[ckpt_folder].append(np.array([np.mean(segment) for segment in s_probas_dets]))
            det_std[ckpt_folder].append(np.array([np.std(segment) for segment in s_probas_dets]))

    n2_diff[ckpt_folder] = np.concatenate(n2_diff[ckpt_folder])
    det_diff[ckpt_folder] = np.concatenate(det_diff[ckpt_folder])
    det_mean[ckpt_folder] = np.concatenate(det_mean[ckpt_folder])
    det_std[ckpt_folder] = np.concatenate(det_std[ckpt_folder])

In [None]:
def plot_hist_column(axes, col, values_by_ckpt, ckpt_folders, labels, face_colors, bins,
                     title, xlabel, xlim, log_scale=False, show_legend=False):
    """Draw one histogram per checkpoint down column ``col`` of a 2-D ``axes`` grid.

    Row ``i`` shows ``values_by_ckpt[ckpt_folders[i]]`` labelled ``labels[ckpt_folders[i]]``
    and filled with ``face_colors[i]``. All rows of the column share the same y-limits
    so bar heights are comparable across checkpoints: ``[0.5, 10 * max_count]`` on a
    log y-axis, ``[0, max_count + 10]`` otherwise. Only the bottom row keeps x ticks
    and receives ``xlabel``.

    Args:
        axes: 2-D array of Axes with one row per checkpoint.
        col: Column index to draw into.
        values_by_ckpt: Dict mapping checkpoint folder -> 1-D array of values.
        ckpt_folders: Checkpoint folders, in row order.
        labels: Dict mapping checkpoint folder -> legend label.
        face_colors: Bar colour per row.
        bins: Histogram bin edges.
        title: Title placed on the top axis of the column.
        xlabel: Label placed on the bottom axis of the column.
        xlim: x-limits applied to every axis in the column.
        log_scale: Use a logarithmic y-axis.
        show_legend: Draw each row's legend just outside the right edge of its axis.
    """
    n_rows = len(ckpt_folders)
    max_counts = []
    axes[0, col].set_title(title, fontsize=9)
    for row, ckpt_folder in enumerate(ckpt_folders):
        ax = axes[row, col]
        counts, _, _ = ax.hist(
            values_by_ckpt[ckpt_folder], bins=bins, label=labels[ckpt_folder], facecolor=face_colors[row])
        max_counts.append(np.max(counts))
        if log_scale:
            ax.set_yscale('log')
        ax.set_xlim(xlim)
        ax.tick_params(labelsize=7)
        if show_legend:
            ax.legend(loc='upper left', bbox_to_anchor=(1.01, 1), frameon=False, handlelength=1, fontsize=8)
        if row < n_rows - 1:
            ax.set_xticks([])
    max_count = np.max(max_counts)
    ylim = [0.5, 10 * max_count] if log_scale else [0, max_count + 10]
    for ax in axes[:, col]:
        ax.set_ylim(ylim)
    axes[-1, col].set_xlabel(xlabel, fontsize=9)


n_cases = len(ckpt_folder_list)
bins = np.linspace(0, 1, 41)

# squeeze=False keeps `axes` 2-D even when a single checkpoint is plotted,
# so axes[row, col] indexing works for any n_cases.
fig, axes = plt.subplots(n_cases, 4, figsize=(8, 1*n_cases), dpi=120, squeeze=False)

# One entry per column: (values, title, xlabel, xlim, log_scale, show_legend).
column_specs = [
    (n2_diff, 'Overall variability', r'$\Delta p$', [0, 0.6], True, False),
    (det_diff, 'Detection variability', r'$\Delta p$', [0, 0.6], True, False),
    (det_mean, 'Detection mean', r'$\mu(p)$', [global_thr, 1], False, False),
    (det_std, 'Detection std', r'$\sigma(p)$', [0, 0.2], False, True),
]
for col, (values, title, xlabel, xlim, log_scale, show_legend) in enumerate(column_specs):
    plot_hist_column(
        axes, col, values, ckpt_folder_list, ckpt_label_dict, model_colors, bins,
        title, xlabel, xlim, log_scale=log_scale, show_legend=show_legend)
plt.tight_layout()
plt.show()

# Penalization functions

In [None]:
# Candidate penalty curves over delta_p in [0, 1]:
#   C_1 = dp^2,   C_2 = 1 - 4 (dp - 0.5)^2,   C_3 = min(2 dp, 2 (1 - dp)).
delta_p = np.linspace(0, 1)
cost_1 = delta_p ** 2
cost_2 = 1 - 4 * (delta_p - 0.5) ** 2
cost_3_a = 2 * delta_p
cost_3_b = 2 * (1 - delta_p)
cost_3 = np.minimum(cost_3_a, cost_3_b)

fig, axes = plt.subplots(1, 3, figsize=(6, 2), dpi=120, sharex=True, sharey=True)
panels = [
    (r'$C_1(\Delta p)$', cost_1),
    (r'$C_2(\Delta p)$', cost_2),
    (r'$C_3(\Delta p)$', cost_3),
]
# Every panel gets identical axis styling; only the title and curve differ.
for ax, (panel_title, panel_cost) in zip(axes, panels):
    ax.set_title(panel_title, fontsize=9)
    ax.plot(delta_p, panel_cost)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, None])
    ax.tick_params(labelsize=7)
    ax.set_xlabel(r'$\Delta p$', fontsize=9)

plt.tight_layout()
plt.show()