# Imports

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os
from pprint import pprint
import sys

import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib import colors, gridspec
import numpy as np
from scipy.stats import gaussian_kde
from scipy.interpolate import interp1d

# Add the parent directory to the module search path so that the `sleeprnn`
# imports that follow can be resolved relative to the current working directory.
project_root = '..'
sys.path.append(project_root)

from sleeprnn.common import constants, pkeys, viz
from sleeprnn.common.optimal_thresholds import OPTIMAL_THR_FOR_CKPT_DICT
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection.feeder_dataset import FeederDataset
from sleeprnn.detection.postprocessor import PostProcessor
from sleeprnn.detection import metrics
from sleeprnn.helpers import reader, plotter, printer, misc, performer

# Folders, expressed relative to the project root, for results and for the
# comparison data kept under `resources`.
RESULTS_PATH = os.path.join(project_root, 'results')
COMPARISON_PATH = os.path.join(project_root, 'resources', 'comparison_data')

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# NOTE(review): presumably widens the notebook layout to the full browser
# width -- confirm in sleeprnn.common.viz.
viz.notebook_full_width()

# Load data

In [None]:
# List the checkpoints that have an entry in OPTIMAL_THR_FOR_CKPT_DICT.
# NOTE(review): filter_dates is presumably a [start, end] range in YYYYMMDD
# form where None leaves that side open -- confirm in printer.print_available_ckpt.
filter_dates = [20191220, None]
printer.print_available_ckpt(OPTIMAL_THR_FOR_CKPT_DICT, filter_dates)

In [None]:
# Experiment selection: dataset, expert annotation, task mode, seeds and subsets.
dataset_name = constants.MASS_SS_NAME
fs = 200
which_expert = 1
task_mode = constants.N2_RECORD
seed_id_list = list(range(4))
set_list = [constants.VAL_SUBSET]

# Candidate runs, each given as a (checkpoint folder, display label) pair.
comparison_runs_list = [
    ('20191227_bsf_10runs_e1_n2_train_mass_ss/v11', 'RED-Time'),
    ('20191227_bsf_10runs_e1_n2_train_mass_kc/v11', 'RED-Time'),
    ('20200407_attention_grid_n2_train_mass_ss/att01_d_512_h_8_fc_0', 'ATT01'),
    ('20200409_attention_grid_n2_train_mass_kc/att01_d_512_h_8_fc_0', 'ATT01'),
    ('20200408_att03_grid_n2_train_mass_ss/att03_d_512_h_8_fc_0', 'ATT03'),
    ('20200409_att03_grid_n2_train_mass_kc/att03_d_512_h_8_fc_0', 'ATT03'),
    # ('20200410_att04_task_pe_grid_n2_train_mass_ss/att04_pe_10000', 'ATT04'),
    ('20200409_att04_head_grid_n2_train_mass_ss/att04_h_08', 'ATT04'),
    ('20200410_att04_task_pe_grid_n2_train_mass_kc/att04_pe_10000', 'ATT04'),
]
# Keep only the runs whose folder name contains the selected dataset name.
comparison_runs_list = [
    (folder, label)
    for (folder, label) in comparison_runs_list
    if dataset_name in folder
]
n_cases = len(comparison_runs_list)
ckpt_folder_list = [folder for (folder, _) in comparison_runs_list]
# Folder -> label lookup built directly from the (folder, label) pairs.
ckpt_label_dict = dict(comparison_runs_list)

# Load the dataset and collect the record ids of every subset of interest.
dataset = reader.load_dataset(dataset_name, params={pkeys.FS: fs})
ids_dict = {
    constants.ALL_TRAIN_SUBSET: dataset.train_ids,
    constants.TEST_SUBSET: dataset.test_ids,
}
ids_dict.update(misc.get_splits_dict(dataset, seed_id_list))

# Read the stored predictions of every selected checkpoint, for all seeds.
predictions_dict = {}
for ckpt_folder in ckpt_folder_list:
    predictions_dict[ckpt_folder] = reader.read_prediction_with_seeds(
        ckpt_folder,
        dataset_name,
        task_mode,
        seed_id_list,
        set_list=set_list,
        parent_dataset=dataset,
    )

In [None]:
# Axes used by the performance computations: 21 histogram edges over [0, 1]
# and the IoU thresholds produced by misc.custom_linspace(0.05, 0.95, 0.05).
iou_hist_bins = np.linspace(0, 1, 21)
iou_curve_axis = misc.custom_linspace(0.05, 0.95, 0.05)

# Short identifier of the experiment, built from the first two '_'-separated
# tokens of the dataset name, the expert index and the task mode.
dataset_tokens = dataset_name.split('_')
result_id = '%s-%s-E%d-%s' % (
    dataset_tokens[0].upper(),
    dataset_tokens[1].upper(),
    which_expert,
    task_mode.upper(),
)

# Performance

In [None]:
# Compute the validation performance of every checkpoint and print a summary
# (mean and standard deviation of each metric array, in percent).
data_dict = {}
for ckpt_folder in ckpt_folder_list:
    print(ckpt_label_dict[ckpt_folder])
    t_data_dict = performer.performance_vs_iou_with_seeds(
        dataset,
        predictions_dict[ckpt_folder],
        OPTIMAL_THR_FOR_CKPT_DICT[ckpt_folder],
        iou_curve_axis,
        iou_hist_bins,
        task_mode,
        which_expert,
        set_name=constants.VAL_SUBSET
    )

    # Mean performance
    af1_values = t_data_dict[constants.MEAN_AF1]
    af1_mean, af1_std = 100 * af1_values.mean(), 100 * af1_values.std()
    print('Val AF1: %1.2f +- %1.2f' % (af1_mean, af1_std))

    iou_values = t_data_dict[constants.MEAN_IOU]
    iou_mean, iou_std = 100 * iou_values.mean(), 100 * iou_values.std()
    print('Val Mean IoU at TP: %1.2f +- %1.2f' % (iou_mean, iou_std))

    data_dict[ckpt_folder] = t_data_dict
    print("")

In [None]:
# Plot the mean F1-score vs IoU-threshold curve of every loaded checkpoint and
# print its performance at a single IoU threshold.

# Per-curve style. Every list is sized from the number of loaded checkpoints so
# that adding or removing entries in comparison_runs_list cannot raise an
# IndexError in the plotting loop. With four checkpoints the values are exactly
# the previous hard-coded ones; with more, the palette is cycled.
n_curves = len(ckpt_folder_list)
base_color_list = [viz.PALETTE['red'], viz.PALETTE['green'], viz.PALETTE['grey'], viz.PALETTE['blue']]
color_list = [base_color_list[i % len(base_color_list)] for i in range(n_curves)]
marker_list = n_curves * ['o']
alpha_line_list = n_curves * [1]
# The first curve gets the highest z-order so it is drawn above the others.
zorder_list = [30] + max(n_curves - 1, 0) * [20]
# Positions (in ckpt_folder_list) of curves to leave out of the figure.
idx_to_remove = []
# idx_to_remove = [0]

# Plot f1 vs iou specs
smaller_plot = False
external_legend = True
show_seed_std = False
print_formatted_table = True
zoom_plot = False
zoom_f1 = [0.5, 0.85]
alpha_seed_std = 0.4
alpha_expert = 0.5
iou_thr_to_show = 0.2
figsize = (4, 4)
# The fold count follows the number of seeds instead of being hard-coded.
# NOTE(review): assumes one cross-validation fold per entry of seed_id_list.
title = '%d-fold cross-validation (%s)' % (len(seed_id_list), result_id)

# -------------------- P L O T ----------------------
print('Database: %s, Expert: %d' % (dataset_name, which_expert))
print('IoU to show: %1.1f' % iou_thr_to_show)
this_dpi = 100 if smaller_plot else viz.DPI
fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=this_dpi)

for i, ckpt_folder in enumerate(ckpt_folder_list):
    if i in idx_to_remove:
        continue
    this_label = ckpt_label_dict[ckpt_folder]
    this_marker = marker_list[i]
    this_alpha = alpha_line_list[i]
    this_color = color_list[i]
    this_zorder = zorder_list[i]
    model_data_dict = data_dict[ckpt_folder]
    # Mean and standard deviation along the first axis of the F1-vs-IoU array.
    mean_f1_vs_iou = model_data_dict[constants.F1_VS_IOU].mean(axis=0)
    std_f1_vs_iou = model_data_dict[constants.F1_VS_IOU].std(axis=0)
    ax.plot(
        iou_curve_axis, mean_f1_vs_iou, linewidth=viz.LINEWIDTH, zorder=this_zorder, label=this_label,
        markersize=viz.MARKERSIZE, markevery=(1, 2),
        marker=this_marker, color=this_color, alpha=this_alpha
    )
    if show_seed_std:
        # Shaded band of +- one standard deviation around the mean curve.
        ax.fill_between(
            iou_curve_axis, mean_f1_vs_iou - std_f1_vs_iou, mean_f1_vs_iou + std_f1_vs_iou,
            alpha=alpha_seed_std, facecolor=this_color, zorder=this_zorder)
    printer.print_performance_at_iou(model_data_dict, iou_thr_to_show, this_label)

if print_formatted_table:
    print("")
    # The table lists every checkpoint, including those hidden from the figure
    # through idx_to_remove; the header is printed with the first row only.
    for i, ckpt_folder in enumerate(ckpt_folder_list):
        printer.print_formatted_performance_at_iou(
            data_dict[ckpt_folder],
            iou_thr_to_show,
            ckpt_label_dict[ckpt_folder],
            print_header=(i == 0))

ax.set_title(title, fontsize=viz.FONTSIZE_TITLE, loc='center')
ax = plotter.format_metric_vs_iou_plot(ax, 'F1-score', iou_thr_to_show)
lg = plotter.format_legend(ax, external_legend)
if zoom_plot:
    ax.set_ylim(zoom_f1)
lg = plotter.set_legend_color(lg)
ax = plotter.set_axis_color(ax)
plt.show()