# Imports

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os
from pprint import pprint
import sys
from joblib import delayed, Parallel

import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib import colors, gridspec
import numpy as np
from scipy.stats import gaussian_kde
from scipy.interpolate import interp1d
from sklearn.linear_model import LinearRegression, HuberRegressor
project_root = '..'
sys.path.append(project_root)

from sleeprnn.common import constants, pkeys, viz
from sleeprnn.common.optimal_thresholds import OPTIMAL_THR_FOR_CKPT_DICT
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection.feeder_dataset import FeederDataset
from sleeprnn.detection.postprocessor import PostProcessor
from sleeprnn.detection import metrics
from sleeprnn.helpers import reader, plotter, printer, misc, performer

# Folders under project_root for stored results and comparison data
# (neither constant appears to be referenced by the cells below).
RESULTS_PATH = os.path.join(project_root, 'results')
COMPARISON_PATH = os.path.join(project_root, 'resources', 'comparison_data')

# IPython magic: render matplotlib figures inline; the viz helper presumably widens the notebook display.
%matplotlib inline
viz.notebook_full_width()

# Load data

In [None]:
# List the checkpoints that have stored optimal thresholds, restricted to a date window.
date_window = (20200724, None)  # (start, end); None presumably leaves that side open
filter_dates = list(date_window)
printer.print_available_ckpt(OPTIMAL_THR_FOR_CKPT_DICT, filter_dates)

In [None]:
# Experiment configuration: dataset, expert annotations, seeds and task to analyse.
dataset_name = constants.MASS_SS_NAME
which_expert = 1
seed_id_list = list(range(4))
task_mode = constants.N2_RECORD
fs = 200
set_list = [constants.VAL_SUBSET, constants.TRAIN_SUBSET]

# Specify what to load: (checkpoint folder, display label, short code used in file names).
comparison_runs_list = [
    (
        '20201024_combi_completa_n2_train_mass_ss/v19_noisy_waves1_ab0.0_focal0.25-0.25', 
        'RED-CWT-Focal-Noisy-Waves', 'v19-focal-noisy-waves'),
    # ('20200724_reproduce_red_n2_train_mass_ss/v19_rep1', 'RED-CWT', 'v19'),
    #('20191227_bsf_10runs_e1_n2_train_mass_ss/v11', 'RED-Time', 'v11'),
    # ('20191227_bsf_10runs_e1_n2_train_mass_ss/v19', 'RED-CWT', 'v19'),
    #('20191227_bsf_10runs_e2_n2_train_mass_ss/v11', 'RED-Time', 'v11'),
    #('20191227_bsf_10runs_e2_n2_train_mass_ss/v19', 'RED-CWT', 'v19'),
    #('20191227_bsf_10runs_e1_n2_train_mass_kc/v11', 'RED-Time', 'v11'),
    #('20191227_bsf_10runs_e1_n2_train_mass_kc/v19', 'RED-CWT', 'v19'),
]
# Keep only runs whose folder mentions the selected dataset.
# NOTE(review): the expert filter is disabled, so runs trained on a different
# expert than `which_expert` are kept too. Stricter variant:
#     (dataset_name in t_folder) and ('e%d' % which_expert) in t_folder
comparison_runs_list = [run for run in comparison_runs_list if dataset_name in run[0]]
ckpt_folder_list = [t_folder for (t_folder, _, _) in comparison_runs_list]
# These two lookups collapse runs that share a label (or a folder) into one entry.
ckpt_folder_dict = {t_label: t_folder for (t_folder, t_label, _) in comparison_runs_list}
ckpt_label_dict = {t_folder: t_label for (t_folder, t_label, _) in comparison_runs_list}
ckpt_label_list = [t_label for (_, t_label, _) in comparison_runs_list]
ckpt_code_list = [t_code for (_, _, t_code) in comparison_runs_list]

# Load data
n_cases = len(comparison_runs_list)
dataset = reader.load_dataset(dataset_name, params={pkeys.FS: fs})
ids_dict = {
    constants.ALL_TRAIN_SUBSET: dataset.train_ids,
    constants.TEST_SUBSET: dataset.test_ids}
# Per-seed splits; later cells index them as ids_dict[seed_id][set_name].
ids_dict.update(misc.get_splits_dict(dataset, seed_id_list))
predictions_dict = {}
for ckpt_folder in ckpt_folder_list:
    predictions_dict[ckpt_folder] = reader.read_prediction_with_seeds(
        ckpt_folder, dataset_name, task_mode, seed_id_list, set_list=set_list, parent_dataset=dataset)
# useful for viz
iou_hist_bins = np.linspace(0, 1, 21)
iou_curve_axis = misc.custom_linspace(0.05, 0.95, 0.05)
dataset_tokens = dataset_name.split('_')
result_id = '%s-%s-E%d-%s' % (
    dataset_tokens[0].upper(),
    dataset_tokens[1].upper(),
    which_expert,
    task_mode.upper())
expert_data_dict = reader.load_ss_expert_performance()
exp_keys = list(expert_data_dict.keys())
print('\nAvailable data:')
pprint(exp_keys)
model_names = ckpt_label_list
code_names = ckpt_code_list
# Build the model descriptors straight from the run tuples. Going through
# ckpt_folder_dict (keyed by label) would map every run that shares a label
# to the same, last-seen checkpoint folder.
models = [
    {'name': t_label, 'ckpt': t_folder, 'code_name': t_code}
    for (t_folder, t_label, t_code) in comparison_runs_list]

In [None]:
# Output switches consulted by later cells; output files go under folder_name.
save_figs, save_txt = False, False
folder_name = 'cheating_%s_e%d' % (dataset_name, which_expert)
if not os.path.isdir(folder_name):
    os.makedirs(folder_name)

## Per-subject precision-recall: seed-optimal vs. per-subject oracle ("cheating") threshold

In [None]:
# Per-subject precision/recall scatter at IoU = iou_to_show, one figure per
# model and evaluation mode:
#   * normal:   each seed uses its threshold from OPTIMAL_THR_FOR_CKPT_DICT.
#   * cheating: each subject uses the grid threshold maximising its own AF1
#               against the expert labels (an optimistic oracle bound).
# Per-subject lines and the mean +- std summary are printed and, when save_txt
# is set, also written to <folder_name>/cheating_metrics_xval.txt.
seeds_to_show = seed_id_list
n_seeds = len(seeds_to_show)
set_list = ['val']
iou_to_show = 0.2
dpi = 200
res_thr = 0.01
start_thr = 0.1
end_thr = 0.95
iou_idx = misc.closest_index(iou_to_show, iou_curve_axis)
# Colours are looked up by the seed's position in seeds_to_show (k_ax below).
# The validation palette cycles, so more than four seeds no longer raise KeyError.
val_palette = [viz.PALETTE['red'], viz.PALETTE['blue'], viz.PALETTE['green'], viz.PALETTE['dark']]
n_color_slots = max(n_seeds, len(val_palette))
color_dict = {
    'train': {k: viz.GREY_COLORS[4] for k in range(n_color_slots)},
    'val': {k: val_palette[k % len(val_palette)] for k in range(n_color_slots)},
}
markersize_model = 6
axis_markers = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
# round() instead of floor(): float error in (end - start) / res must not drop the last grid point.
n_thr = int(np.round((end_thr - start_thr) / res_thr)) + 1
thr_list = np.round(start_thr + res_thr * np.arange(n_thr), 2)

SUBJECT_FMT = "S%02d (seed%d) F1: %1.1f -- Precision %1.1f -- Recall %1.1f  (thr %1.2f)"
AVG_FMT = "Avg F1: %1.1f \u00B1 %1.1f -- Precision %1.1f \u00B1 %1.1f -- Recall %1.1f \u00B1 %1.1f"


def _report(message, txt_file=None):
    """Print ``message`` and, when ``txt_file`` is an open file, write it there as well."""
    print(message)
    if txt_file is not None:
        print(message, file=txt_file)


def _precision_recall_f1(single_events, single_detections):
    """Return (precision, recall, f1) arrays over ``iou_curve_axis`` for one subject.

    F1 is set to 0 where precision + recall == 0 instead of becoming NaN,
    so it cannot poison the averages printed below.
    """
    precision = np.asarray(metrics.metric_vs_iou(
        single_events, single_detections, iou_curve_axis, metric_name=constants.PRECISION), dtype=float)
    recall = np.asarray(metrics.metric_vs_iou(
        single_events, single_detections, iou_curve_axis, metric_name=constants.RECALL), dtype=float)
    denominator = precision + recall
    f1 = np.divide(2 * precision * recall, denominator,
                   out=np.zeros_like(denominator), where=denominator > 0)
    return precision, recall, f1


def _best_subject_threshold(prediction_obj, single_id, single_events, thr_grid):
    """Return the value of ``thr_grid`` that maximises this subject's AF1.

    Thresholds not below (subject's peak probability - 0.05) are skipped. If
    none remain, the lowest grid value is used instead of failing on an empty
    argmax. ``prediction_obj`` is left set to the last threshold tried.
    """
    subject_proba = prediction_obj.get_subject_probabilities(single_id)
    candidates = thr_grid[thr_grid < subject_proba.max() - 0.05]
    if candidates.size == 0:
        candidates = thr_grid[:1]
    detections_at_thr = []
    for thr in candidates:
        prediction_obj.set_probability_threshold(thr)
        detections_at_thr.append(prediction_obj.get_subject_stamps(single_id))
    af1_list = Parallel(n_jobs=-1)(
        delayed(metrics.average_metric)(single_events, single_prediction, verbose=False)
        for single_prediction in detections_at_thr)
    return candidates[int(np.argmax(af1_list))]


if save_txt or save_figs:
    os.makedirs(folder_name, exist_ok=True)
txt_file = open(os.path.join(folder_name, 'cheating_metrics_xval.txt'), 'w') if save_txt else None
try:  # finally-block closes the report file even if the loop raises
    _report('Thr grid search: %1.2f:%1.2f:%1.2f' % (start_thr, res_thr, end_thr), txt_file)
    for model in models:
        for cheating in [False, True]:
            mode_tag = ' (Cheating)' if cheating else ''
            title = '%s%s\n(%s) - ValSet IoU>%1.1f' % (model['name'], mode_tag, result_id, iou_to_show)
            _report('\n%s' % title, txt_file)
            fig, ax = plt.subplots(1, 1, figsize=(3, 3), dpi=viz.DPI if dpi is None else dpi)
            store_f1, store_precision, store_recall = [], [], []
            for k_ax, seed_id in enumerate(seeds_to_show):
                # ---------------- Compute performance
                pre_vs_iou_subject_dict = {}
                rec_vs_iou_subject_dict = {}
                for set_name in set_list:
                    # Expert labels for this seed's split
                    data_inference = FeederDataset(
                        dataset, ids_dict[seed_id][set_name], task_mode, which_expert)
                    this_ids = data_inference.get_ids()
                    this_events_list = data_inference.get_stamps()
                    prediction_obj = predictions_dict[model['ckpt']][seed_id][set_name]
                    if not cheating:
                        seed_thr = OPTIMAL_THR_FOR_CKPT_DICT[model['ckpt']][seed_id]
                        prediction_obj.set_probability_threshold(seed_thr)
                        this_detections_list = prediction_obj.get_stamps()
                    for i, single_id in enumerate(this_ids):
                        single_events = this_events_list[i]
                        if cheating:
                            used_thr = _best_subject_threshold(
                                prediction_obj, single_id, single_events, thr_list)
                            prediction_obj.set_probability_threshold(used_thr)
                            single_detections = prediction_obj.get_subject_stamps(single_id)
                        else:
                            used_thr = seed_thr
                            single_detections = this_detections_list[i]
                        this_precision, this_recall, this_f1 = _precision_recall_f1(
                            single_events, single_detections)
                        _report(SUBJECT_FMT % (
                            single_id,
                            seed_id,
                            100 * this_f1[iou_idx],
                            100 * this_precision[iou_idx],
                            100 * this_recall[iou_idx],
                            used_thr), txt_file)
                        store_f1.append(100 * this_f1[iou_idx])
                        store_precision.append(100 * this_precision[iou_idx])
                        store_recall.append(100 * this_recall[iou_idx])
                        pre_vs_iou_subject_dict[single_id] = this_precision
                        rec_vs_iou_subject_dict[single_id] = this_recall

                # -------------------- Plot: one marker per subject, legend entry once per seed
                for set_name in set_list:
                    for i, single_id in enumerate(ids_dict[seed_id][set_name]):
                        ax.plot(
                            rec_vs_iou_subject_dict[single_id][iou_idx],
                            pre_vs_iou_subject_dict[single_id][iou_idx],
                            color=color_dict[set_name][k_ax], marker='o',
                            markersize=markersize_model,
                            label='Seed %d' % seed_id if i == 0 else None,
                            linestyle='None')
            _report(AVG_FMT % (
                np.mean(store_f1), np.std(store_f1),
                np.mean(store_precision), np.std(store_precision),
                np.mean(store_recall), np.std(store_recall)), txt_file)
            ax.plot([0, 1], [0, 1], zorder=1, linewidth=1, color=viz.GREY_COLORS[4])
            ax.set_title(title, fontsize=8)
            ax.set_xlim([0, 1])
            ax.set_ylim([0, 1])
            ax.set_yticks(axis_markers)
            ax.set_xticks(axis_markers)
            ax.tick_params(labelsize=8)
            ax.set_ylabel('Precision', fontsize=8)
            ax.set_xlabel('Recall', fontsize=8)
            ax.set_aspect('equal')
            ax.legend(loc='lower left', fontsize=8)
            plt.tight_layout()
            if save_figs:
                fname_template = "pr_%s_seeds_cheating.png" if cheating else "pr_%s_seeds.png"
                plt.savefig(os.path.join(folder_name, fname_template % model['code_name']),
                            dpi=200, bbox_inches="tight", pad_inches=0.01)
            plt.show()
finally:
    if txt_file is not None:
        txt_file.close()

# Metrics as a function of the detection threshold

In [None]:
# Sweep the probability threshold for every model / seed / subject and record
# AF1, F1 / precision / recall at IoU = iou_to_show, and the number of
# detections. Stored as grid_result[ckpt_folder][seed_id][set_name][subject_id][metric].
seeds_to_show = seed_id_list
res_thr = 0.01
start_thr = 0.1
end_thr = 0.95
set_list = ['val']
iou_to_show = 0.2

# round() instead of floor(): float error in (end - start) / res must not drop the last grid point.
n_thr = int(np.round((end_thr - start_thr) / res_thr)) + 1
thr_list = np.round(start_thr + res_thr * np.arange(n_thr), 2)
grid_result = {}
for model in models:
    ckpt_folder = model['ckpt']
    grid_result[ckpt_folder] = {}
    for seed_id in seeds_to_show:
        print("Processing Seed %d" % seed_id, flush=True)
        grid_result[ckpt_folder][seed_id] = {}
        for set_name in set_list:
            print("Processing set %s" % set_name, flush=True)
            set_result = {}
            grid_result[ckpt_folder][seed_id][set_name] = set_result

            data_inference = FeederDataset(dataset, ids_dict[seed_id][set_name], task_mode, which_expert)
            this_events_list = data_inference.get_stamps()
            this_ids = data_inference.get_ids()
            prediction_obj = predictions_dict[ckpt_folder][seed_id][set_name]

            # Only thresholds below (highest probability in the set - 0.05) are swept.
            # If none qualify, keep the lowest grid value so downstream argmax calls
            # never see an empty array.
            max_valid = np.max([t_proba.max() for t_proba in prediction_obj.get_probabilities()]) - 0.05
            thr_list_seed = thr_list[thr_list < max_valid]
            if thr_list_seed.size == 0:
                thr_list_seed = thr_list[:1]
            print("Evaluating %d thresholds" % len(thr_list_seed), flush=True)

            # detections_at_thr[k][i]: detections of subject i at threshold thr_list_seed[k]
            detections_at_thr = []
            for thr in thr_list_seed:
                prediction_obj.set_probability_threshold(thr)
                detections_at_thr.append(prediction_obj.get_stamps())

            for i, single_events in enumerate(this_events_list):
                single_detections_at_thr = [dets_at_single_thr[i] for dets_at_single_thr in detections_at_thr]

                af1_thr_list = Parallel(n_jobs=-1)(
                    delayed(metrics.average_metric)(single_events, single_prediction, verbose=False)
                    for single_prediction in single_detections_at_thr)
                prec_thr_list = Parallel(n_jobs=-1)(
                    delayed(metrics.metric_vs_iou)(
                        single_events, single_prediction, [iou_to_show], metric_name=constants.PRECISION)
                    for single_prediction in single_detections_at_thr)
                prec_thr_list = np.array([p[0] for p in prec_thr_list], dtype=float)
                rec_thr_list = Parallel(n_jobs=-1)(
                    delayed(metrics.metric_vs_iou)(
                        single_events, single_prediction, [iou_to_show], metric_name=constants.RECALL)
                    for single_prediction in single_detections_at_thr)
                rec_thr_list = np.array([p[0] for p in rec_thr_list], dtype=float)
                # F1 := 0 where precision + recall == 0. A NaN here would be picked
                # by np.argmax as the "best" threshold.
                pr_sum = prec_thr_list + rec_thr_list
                f1_thr_list = np.divide(2 * prec_thr_list * rec_thr_list, pr_sum,
                                        out=np.zeros_like(pr_sum), where=pr_sum > 0)

                set_result[this_ids[i]] = {
                    'thr': np.array(thr_list_seed),
                    'af1': np.array(af1_thr_list),
                    'n_dets': np.array([p.shape[0] for p in single_detections_at_thr]),
                    'f1': f1_thr_list,
                    'precision': prec_thr_list,
                    'recall': rec_thr_list,
                }
print("Done.")

In [None]:
# One figure per model. Rows are seeds; columns are the subjects of the sets in
# set_list. Each panel plots the grid_result metrics against the threshold.
# Black dots mark the AF1 maximum, the F1 maximum and the threshold where recall
# and precision are closest; the dashed line sits at the AF1-optimal threshold.
for model in models:
    print("Showing model %s in %s" % (model['name'], model['ckpt']))
    # Index by THIS model's checkpoint. The previous cell's loop variable
    # `ckpt_folder` always refers to the last model.
    model_results = grid_result[model['ckpt']]
    n_seeds = len(seeds_to_show)
    subjects_per_seed = [
        sum(len(ids_dict[seed_id][set_name]) for set_name in set_list) for seed_id in seeds_to_show]
    n_cols = max(subjects_per_seed + [1])
    # squeeze=False keeps `axes` 2-D even for a single seed or a single subject.
    fig, axes = plt.subplots(
        n_seeds, n_cols, figsize=(3 * n_cols, 3 * n_seeds), dpi=200, squeeze=False)
    legends = []
    for i, seed_id in enumerate(seeds_to_show):
        j = 0
        for set_name in set_list:
            for single_id in ids_dict[seed_id][set_name]:
                results = model_results[seed_id][set_name][single_id]
                ax = axes[i, j]
                j += 1
                thr = results['thr']

                ax.plot(thr, results['af1'], label="AF1")
                ax.plot(thr, results['f1'], label="F1-%1.1f" % iou_to_show)
                ax.plot(thr, results['recall'], label="Recall-%1.1f" % iou_to_show)
                ax.plot(thr, results['precision'], label="Precision-%1.1f" % iou_to_show)
                # max(..., 1) avoids 0/0 when no threshold produced any detection.
                ax.plot(thr, results['n_dets'] / max(results['n_dets'].max(), 1),
                        label="Detections (% of max)")

                # NaNs must not win argmax/argmin (np.argmax returns the first NaN).
                max_af1 = np.argmax(np.nan_to_num(results['af1']))
                max_f1 = np.argmax(np.nan_to_num(results['f1']))
                gap = np.abs(results['recall'] - results['precision'])
                min_gap = np.argmin(np.where(np.isnan(gap), np.inf, gap))
                ax.plot(thr[max_af1], results['af1'][max_af1], marker='o', markersize=4, color="k", zorder=30)
                ax.plot(thr[max_f1], results['f1'][max_f1], marker='o', markersize=4, color="k", zorder=30)
                ax.plot(thr[min_gap], results['recall'][min_gap], marker='o', markersize=4, color="k", zorder=30)
                ax.axvline(thr[max_af1], linestyle="--", color="k")

                ax.tick_params(labelsize=8)
                ax.set_ylim([0, 1])
                ax.set_xlim([0, 1])
                ax.set_title("Seed %d - %s - S%02d" % (seed_id, set_name, single_id), fontsize=10)
                ax.set_xlabel("Threshold", fontsize=8)
        # Hide panels left empty when this seed has fewer subjects than n_cols.
        for unused_ax in axes[i, j:]:
            unused_ax.axis('off')
        if j > 0:
            # One legend per row, placed outside the right-most filled panel.
            legends.append(axes[i, j - 1].legend(
                loc="upper left", bbox_to_anchor=(1, 1), frameon=False, fontsize=8))
    plt.tight_layout()
    # Honour the save_figs switch and output folder, like the PR-curve cell.
    if save_figs:
        os.makedirs(folder_name, exist_ok=True)
        plt.savefig(os.path.join(folder_name, "%s_thr_effect.pdf" % model['code_name']),
                    bbox_extra_artists=tuple(legends), bbox_inches="tight")
    plt.show()