In [None]:
# Setup cell: standard/third-party imports, make the project root importable,
# import the project modules and define the data paths used by this notebook.
import os
from pprint import pprint
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats

# The project root is taken to be the parent of the working directory
# (presumably the notebook lives one level below the root — confirm).
project_root = '..'
sys.path.append(project_root)

# Project-local imports; they must come after the sys.path change above.
from sleeprnn.common import viz, constants
from sleeprnn.helpers import reader
from sleeprnn.detection import metrics
from figs_thesis import fig_utils
from baselines_scripts.butils import get_partitions
from sleeprnn.detection.feeder_dataset import FeederDataset

# Paths relative to the project root.
RESULTS_PATH = os.path.join(project_root, 'results')
BASELINES_PATH = os.path.join(project_root, 'resources', 'comparison_data', 'baselines_2021')

# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline
# Project helper; judging by its name it widens the notebook cells — TODO confirm.
viz.notebook_full_width()

# Tabla de desempeño (by-fold)
Comparación con la literatura en desempeño in-dataset.
Todos los p-values de REDv2-CWT (t-test de Welch sobre el F1-score por fold, contra REDv2-Time) son mayores a 0.05.
Todos los p-values de los baselines son menores a 0.001, excepto en INTA-UCH.
Print esperado:
```
Detector & F1-score (\%) & Recall (\%) & Precision (\%) & mIoU (\%) \\

MASS-SS2-E1SS (Fixed)
REDv2-Time & $81.0\pm 0.4$ & $83.7\pm 1.2$ & $79.3\pm 1.0$ & $84.8\pm 0.2$ \\
REDv2-CWT  & $80.7\pm 0.5$ & $83.1\pm 1.7$ & $79.4\pm 2.1$ & $84.5\pm 0.4$ \\
DOSED      & $78.0\pm 0.5$ & $77.7\pm 2.4$ & $79.8\pm 2.0$ & $75.3\pm 1.3$ \\
A7         & $69.7\pm 0.4$ & $82.7\pm 1.9$ & $61.2\pm 1.5$ & $74.9\pm 0.2$ \\

MASS-SS2-E2SS (Fixed)
REDv2-Time & $85.1\pm 0.5$ & $84.5\pm 1.0$ & $86.5\pm 1.8$ & $77.8\pm 0.2$ \\
REDv2-CWT  & $85.0\pm 0.4$ & $85.0\pm 0.8$ & $86.0\pm 1.4$ & $77.9\pm 0.2$ \\
DOSED      & $81.8\pm 0.7$ & $79.7\pm 1.4$ & $85.0\pm 1.4$ & $73.8\pm 0.5$ \\
A7         & $73.4\pm 0.1$ & $82.8\pm 0.0$ & $66.4\pm 0.2$ & $74.8\pm 0.0$ \\

MASS-SS2-KC (Fixed)
REDv2-Time & $83.3\pm 0.4$ & $82.4\pm 1.0$ & $85.0\pm 0.8$ & $90.5\pm 0.2$ \\
REDv2-CWT  & $83.4\pm 0.5$ & $82.1\pm 1.1$ & $85.7\pm 0.8$ & $90.3\pm 0.3$ \\
DOSED      & $77.5\pm 1.0$ & $76.5\pm 1.9$ & $79.5\pm 2.0$ & $72.2\pm 1.3$ \\
Spinky     & $65.7\pm 0.2$ & $65.0\pm 1.8$ & $67.7\pm 2.2$ & $42.3\pm 0.1$ \\

MASS-MODA (5CV)
REDv2-Time & $81.8\pm 1.4$ & $83.5\pm 2.5$ & $80.3\pm 2.3$ & $83.1\pm 0.5$ \\
REDv2-CWT  & $81.5\pm 1.2$ & $82.8\pm 2.5$ & $80.3\pm 2.4$ & $83.1\pm 0.6$ \\
DOSED      & $77.5\pm 1.7$ & $76.4\pm 2.8$ & $78.9\pm 3.0$ & $71.4\pm 1.1$ \\
A7         & $73.3\pm 1.9$ & $74.1\pm 2.1$ & $72.8\pm 3.6$ & $71.0\pm 0.9$ \\

INTA-UCH (5CV)
REDv2-Time & $83.2\pm 4.8$ & $85.0\pm 5.4$ & $82.7\pm 8.5$ & $75.8\pm 2.7$ \\
REDv2-CWT  & $83.2\pm 4.7$ & $85.2\pm 5.9$ & $82.7\pm 8.1$ & $76.0\pm 2.5$ \\
DOSED      & $77.2\pm 7.2$ & $78.0\pm 12.6$ & $79.7\pm 8.0$ & $68.7\pm 3.9$ \\
A7         & $77.6\pm 5.4$ & $78.2\pm 7.0$ & $79.9\pm 10.8$ & $70.3\pm 2.7$ \\

MASS-SS2-E1SS (5CV)
REDv2-Time & $80.8\pm 2.1$ & $84.4\pm 4.0$ & $78.9\pm 5.4$ & $84.4\pm 1.1$ \\
REDv2-CWT  & $80.8\pm 2.0$ & $84.9\pm 4.3$ & $78.5\pm 5.5$ & $84.3\pm 1.2$ \\
DOSED      & $76.8\pm 2.9$ & $79.7\pm 5.9$ & $77.5\pm 7.8$ & $74.7\pm 2.1$ \\
A7         & $73.0\pm 3.4$ & $80.1\pm 4.0$ & $68.1\pm 5.5$ & $73.9\pm 1.0$ \\

MASS-SS2-E2SS (5CV)
REDv2-Time & $86.1\pm 2.0$ & $87.0\pm 3.7$ & $86.0\pm 3.7$ & $78.5\pm 1.1$ \\
REDv2-CWT  & $86.0\pm 2.2$ & $87.2\pm 4.1$ & $85.7\pm 4.3$ & $78.5\pm 1.1$ \\
DOSED      & $82.5\pm 2.5$ & $84.0\pm 5.0$ & $82.5\pm 4.9$ & $73.1\pm 1.1$ \\
A7         & $74.9\pm 2.8$ & $81.5\pm 3.1$ & $70.0\pm 4.3$ & $74.7\pm 1.1$ \\

MASS-SS2-KC (5CV)
REDv2-Time & $83.6\pm 1.5$ & $85.2\pm 3.4$ & $82.9\pm 3.2$ & $90.5\pm 0.6$ \\
REDv2-CWT  & $83.8\pm 1.4$ & $85.0\pm 2.4$ & $83.3\pm 2.5$ & $90.4\pm 0.5$ \\
DOSED      & $77.5\pm 2.4$ & $79.2\pm 3.7$ & $76.5\pm 3.6$ & $72.3\pm 1.4$ \\
Spinky     & $63.1\pm 3.8$ & $61.6\pm 3.5$ & $65.6\pm 6.3$ & $41.2\pm 1.6$ \\
```

In [None]:
# Detectors to compare: the two REDv2 variants plus the baselines that apply to
# each event type (spindle baselines vs. K-complex baselines).
models = [constants.V2_TIME, constants.V2_CWT1D]
baselines_ss = ['dosed', 'a7']
baselines_kc = ['dosed', 'spinky']
print_model_names = {
    constants.V2_TIME: 'REDv2-Time',
    constants.V2_CWT1D: 'REDv2-CWT',
    'dosed': 'DOSED',
    'a7': 'A7',
    'spinky': 'Spinky'
}
# Metrics reported per fold, in table-column order.
metric_names = ['F1-score', 'Recall', 'Precision', 'mIoU']

# Each entry: dataset, expert whose annotations are the ground truth, partition
# strategy and the `seeds` argument passed to get_partitions.
eval_configs = [
    dict(dataset_name=constants.MASS_SS_NAME, expert=1, strategy='fixed', seeds=11),
    dict(dataset_name=constants.MASS_SS_NAME, expert=2, strategy='fixed', seeds=11),
    dict(dataset_name=constants.MASS_KC_NAME, expert=1, strategy='fixed', seeds=11),
    dict(dataset_name=constants.MODA_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.INTA_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_SS_NAME, expert=2, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_KC_NAME, expert=1, strategy='5cv', seeds=3),
]
for config in eval_configs:
    print("\nLoading", config)
    dataset = reader.load_dataset(config["dataset_name"], verbose=False)
    baselines = baselines_ss if dataset.event_name == constants.SPINDLE else baselines_kc

    # Collect predictions as pred_dict[detector][fold][subject_id] -> predicted stamps
    pred_dict = {}
    for model_version in models:
        tmp_dict = fig_utils.get_red_predictions(model_version, config["strategy"], dataset, config["expert"], verbose=False)
        # Keep only the test-set predictions, in the same format as the baselines
        pred_dict[model_version] = {}
        for k in tmp_dict.keys():
            fold_subjects = tmp_dict[k][constants.TEST_SUBSET].all_ids
            fold_predictions = tmp_dict[k][constants.TEST_SUBSET].get_stamps()
            pred_dict[model_version][k] = {s: pred for s, pred in zip(fold_subjects, fold_predictions)}
    for baseline_name in baselines:
        pred_dict[baseline_name] = fig_utils.get_baseline_predictions(baseline_name, config["strategy"], config["dataset_name"], config["expert"])

    # Measure performance by fold (micro-averaged for MODA, macro-averaged otherwise)
    average_mode = constants.MICRO_AVERAGE if (config["dataset_name"] == constants.MODA_SS_NAME) else constants.MACRO_AVERAGE
    _, _, test_ids_list = get_partitions(dataset, config["strategy"], config["seeds"])
    # Ground-truth events of a test fold do not depend on the detector, so load
    # them once per fold instead of once per (detector, fold) pair.
    events_by_fold = [
        FeederDataset(dataset, subject_ids, constants.N2_RECORD, which_expert=config["expert"]).get_stamps()
        for subject_ids in test_ids_list
    ]
    table = {'Detector': [], 'F1-score': [], 'Recall': [], 'Precision': [], 'mIoU': [], 'Fold': []}
    for model_name in pred_dict.keys():
        for k, subject_ids in enumerate(test_ids_list):
            detections_list = [pred_dict[model_name][k][subject_id] for subject_id in subject_ids]
            performance = fig_utils.compute_fold_performance(events_by_fold[k], detections_list, average_mode)
            table['Detector'].append(model_name)
            for metric in metric_names:
                table[metric].append(performance[metric])
            table['Fold'].append(k)
    table = pd.DataFrame.from_dict(table)

    # Mean and population std (ddof=0) of each metric across folds
    print("By-fold statistics")
    by_detector = table.drop(columns=["Fold"]).groupby("Detector")
    metric_mean = by_detector.mean()
    metric_std = by_detector.std(ddof=0)
    # Raw strings: "\%" is not a valid Python escape sequence (SyntaxWarning in 3.12+)
    print(r"Detector & F1-score (\%) & Recall (\%) & Precision (\%) & mIoU (\%) \\")
    for model_name in pred_dict.keys():
        formatted = [
            fig_utils.format_metric(metric_mean.at[model_name, metric], metric_std.at[model_name, metric])
            for metric in metric_names
        ]
        print(r"%s & %s & %s & %s & %s \\" % (print_model_names[model_name].ljust(10), *formatted))

    # Statistical tests: Welch's t-test (unequal variances) on by-fold F1-scores
    reference_model_name = constants.V2_TIME
    reference_metrics = table.loc[table["Detector"] == reference_model_name, "F1-score"].values
    print("P-value test against %s" % reference_model_name)
    for model_name in pred_dict.keys():
        model_metrics = table.loc[table["Detector"] == model_name, "F1-score"].values
        pvalue = stats.ttest_ind(model_metrics, reference_metrics, equal_var=False)[1]
        print("%s: P %1.4f" % (print_model_names[model_name].ljust(10), pvalue))

# Dispersión entre sujetos (by-subject)
Se muestra solamente 5CV por brevedad: ya se vio que el desempeño es similar entre ambas estrategias de partición, y 5CV permite evaluar a todos los sujetos de MASS-SS2.

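Datos cuantitativos de dispersión entre sujetos. Cada valor es la desviación estándar poblacional (ddof=0) entre sujetos, en puntos porcentuales, de la métrica de cada sujeto promediada sobre todas las particiones en que aparece en test. En MASS-MODA solo se consideran los sujetos con `n_blocks == 10`. Print esperado: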
```
Dataset: MASS-SS2-E1SS
          F1-score    Recall  Precision      mIoU
Detector                                         
a7        5.034115  6.698362   8.493331  2.332143
dosed     5.065436  9.265257  13.893792  3.502857
v2_cwt1d  3.443896  7.987251   8.489591  2.374149
v2_time   3.575472  8.186777   8.515364  2.323281

Dataset: MASS-SS2-E2SS
          F1-score    Recall  Precision      mIoU
Detector                                         
a7        4.330717  5.803320   7.882029  1.664745
dosed     4.163007  7.523092   8.901198  2.268538
v2_cwt1d  3.483713  6.769644   6.658597  2.219977
v2_time   3.329545  6.479800   6.366437  2.186028

Dataset: MASS-SS2-KC
          F1-score    Recall  Precision      mIoU
Detector                                         
dosed     3.960219  6.606413   6.062948  1.775579
spinky    7.243040  6.883243  10.711108  2.448843
v2_cwt1d  3.070712  6.027732   5.488672  0.810547
v2_time   3.233140  6.605095   5.811940  0.966797

Dataset: MASS-MODA
           F1-score     Recall  Precision      mIoU
Detector                                           
a7        18.773054  22.023236  14.032634  6.618863
dosed     13.233049  14.972821  13.520275  3.533990
v2_cwt1d  10.929555  14.090439  10.262946  2.774643
v2_time   11.000109  14.058725   9.195598  2.723617

Dataset: INTA-UCH
          F1-score     Recall  Precision      mIoU
Detector                                          
a7        7.567395   9.741852  13.076641  4.286478
dosed     8.210749  13.977135   9.195159  4.280485
v2_cwt1d  5.872097   8.169621   9.838274  3.862264
v2_time   6.009940   7.387042  10.281951  4.194361
```

In [None]:
# Detectors and display names (redefined so this cell can run on its own).
models = [constants.V2_TIME, constants.V2_CWT1D]
baselines_ss = ['dosed', 'a7']
baselines_kc = ['dosed', 'spinky']
print_model_names = {
    constants.V2_TIME: 'REDv2-Time',
    constants.V2_CWT1D: 'REDv2-CWT',
    'dosed': 'DOSED',
    'a7': 'A7',
    'spinky': 'Spinky'
}
# Display name of each (dataset, expert) pair.
print_dataset_names = {
    (constants.MASS_SS_NAME, 1): "MASS-SS2-E1SS",
    (constants.MASS_SS_NAME, 2): "MASS-SS2-E2SS",
    (constants.MASS_KC_NAME, 1): "MASS-SS2-KC",
    (constants.MODA_SS_NAME, 1): "MASS-MODA",
    (constants.INTA_SS_NAME, 1): "INTA-UCH",
}
# Metrics collected per subject, in table-column order.
metric_names = ['F1-score', 'Recall', 'Precision', 'mIoU']

# Only 5CV here, so every subject appears in some test fold.
eval_configs = [
    dict(dataset_name=constants.MASS_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_SS_NAME, expert=2, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_KC_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MODA_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.INTA_SS_NAME, expert=1, strategy='5cv', seeds=3),
]
# One by-subject dispersion table per entry of eval_configs (same order).
dispersions_list = []
for config in eval_configs:
    print("\nLoading", config)
    dataset = reader.load_dataset(config["dataset_name"], verbose=False)
    baselines = baselines_ss if dataset.event_name == constants.SPINDLE else baselines_kc

    # Collect predictions as pred_dict[detector][fold][subject_id] -> predicted stamps
    pred_dict = {}
    for model_version in models:
        tmp_dict = fig_utils.get_red_predictions(model_version, config["strategy"], dataset, config["expert"], verbose=False)
        # Keep only the test-set predictions, in the same format as the baselines
        pred_dict[model_version] = {}
        for k in tmp_dict.keys():
            fold_subjects = tmp_dict[k][constants.TEST_SUBSET].all_ids
            fold_predictions = tmp_dict[k][constants.TEST_SUBSET].get_stamps()
            pred_dict[model_version][k] = {s: pred for s, pred in zip(fold_subjects, fold_predictions)}
    for baseline_name in baselines:
        pred_dict[baseline_name] = fig_utils.get_baseline_predictions(baseline_name, config["strategy"], config["dataset_name"], config["expert"])

    # Measure performance by subject
    _, _, test_ids_list = get_partitions(dataset, config["strategy"], config["seeds"])
    table = {'Detector': [], 'F1-score': [], 'Recall': [], 'Precision': [], 'mIoU': [], 'Subject': [], 'Fold': []}
    # MODA: keep only subjects with n_blocks == 10. A set gives O(1) membership tests.
    if config["dataset_name"] == constants.MODA_SS_NAME:
        valid_subjects = {sub_id for sub_id in dataset.all_ids if dataset.data[sub_id]['n_blocks'] == 10}
    else:
        valid_subjects = set(dataset.all_ids)
    # Ground-truth events of a test fold do not depend on the detector, so load
    # them once per fold instead of once per (detector, fold) pair.
    events_by_fold = [
        FeederDataset(dataset, subject_ids, constants.N2_RECORD, which_expert=config["expert"]).get_stamps()
        for subject_ids in test_ids_list
    ]
    for model_name in pred_dict.keys():
        for k, subject_ids in enumerate(test_ids_list):
            detections_list = [pred_dict[model_name][k][subject_id] for subject_id in subject_ids]
            performance = fig_utils.compute_subject_performance(events_by_fold[k], detections_list)
            for i, subject_id in enumerate(subject_ids):
                if subject_id not in valid_subjects:
                    continue
                table['Detector'].append(model_name)
                for metric in metric_names:
                    table[metric].append(performance[metric][i])
                table['Subject'].append(subject_id)
                table['Fold'].append(k)
    table = pd.DataFrame.from_dict(table)
    # Average each subject's metrics over the folds where it was in test, then
    # take the population std (ddof=0) across subjects, in percentage points.
    per_subject_mean = table.groupby(["Detector", "Subject"]).mean().drop(columns=["Fold"])
    bysubject_dispersions = 100 * per_subject_mean.groupby("Detector").std(ddof=0)
    dispersions_list.append(bysubject_dispersions)
print("Dispersions computed.")

In [None]:
# Show the by-subject dispersion table of every evaluated dataset, in config order.
for idx in range(len(eval_configs)):
    cfg = eval_configs[idx]
    dataset_label = print_dataset_names[(cfg["dataset_name"], cfg["expert"])]
    print("\nDataset: %s" % dataset_label)
    print(dispersions_list[idx])

In [None]:
save_figure = False  # set to True to export the figure as PDF, PNG and SVG

letters = ['A', 'B', 'C', 'D', 'E', 'F']
# Row order shared by all panels (barh draws the first entry at the bottom).
# Each dataset lacks one detector (A7 is not evaluated on K-complexes, Spinky is
# not evaluated on spindles). reindex inserts it as a zero-valued placeholder row
# so every panel has the same rows. This replaces DataFrame.append, which was
# removed in pandas 2.0.
detector_order = ["spinky", "a7", "dosed", constants.V2_CWT1D, constants.V2_TIME]
fig, axes = plt.subplots(1, 5, figsize=(8, 3), dpi=200, sharex=True)
for i, config in enumerate(eval_configs):
    disp_table_mod = dispersions_list[i].reindex(detector_order, fill_value=0)

    ax = disp_table_mod.plot.barh(y=["F1-score", "Recall", "Precision"], ax=axes[i], fontsize=8, legend=False)
    ax.set_title(print_dataset_names[(config["dataset_name"], config["expert"])], loc="left", fontsize=8)

    # Only the leftmost panel shows the detector names
    if i == 0:
        ax.set_yticklabels([print_model_names[yt.get_text()] for yt in ax.get_yticklabels()])
    else:
        ax.set_yticklabels([])
    ax.set_ylabel("")
    # Raw string: "\s" and "\m" are not valid Python escape sequences
    ax.set_xlabel(r"$\sigma_\mathrm{subjects}$", fontsize=8)
    ax.tick_params(labelsize=8)
    ax.set_xlim([0, 25])
    ax.set_xticks([0, 10, 20])
    ax.set_xticks([0, 2.5, 5, 7.5, 10, 12.5, 15, 17.5, 20, 22.5, 25], minor=True)
    ax.grid(axis="x", which="minor")
    # Bold panel letter above the top-left corner of the axes
    ax.text(
        x=-0.01, y=1.15, fontsize=16, s=r"$\bf{%s}$" % letters[i],
        ha="left", transform=ax.transAxes)
plt.tight_layout()

# Shared legend, built from the first panel (all panels plot the same metrics)
lines, labels = axes[0].get_legend_handles_labels()
lg = fig.legend(
    lines, labels, fontsize=8, loc="lower center",
    bbox_to_anchor=(0.5, 0.02), ncol=3, frameon=False, handletextpad=0.5)

if save_figure:
    fname_prefix = "result_comparison_bysubject_std"
    for extension in ("pdf", "png", "svg"):
        # bbox_extra_artists keeps the figure-level legend inside the tight bounding box
        fig.savefig(
            "%s.%s" % (fname_prefix, extension),
            bbox_extra_artists=(lg,), bbox_inches="tight", pad_inches=0.3)

plt.show()