In [None]:
import os
from pprint import pprint
import sys

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
from scipy import stats

project_root = '..'
sys.path.append(project_root)

from sleeprnn.common import viz, constants
from sleeprnn.helpers import reader, plotter, misc, performer
from sleeprnn.detection import metrics
from figs_thesis import fig_utils
from baselines_scripts.butils import get_partitions
from sleeprnn.detection.feeder_dataset import FeederDataset
from sklearn.linear_model import LinearRegression, HuberRegressor
from sleeprnn.data import utils

# Paths (relative to project_root) to stored experiment results and to the 2021 baseline comparison data.
RESULTS_PATH = os.path.join(project_root, 'results')
BASELINES_PATH = os.path.join(project_root, 'resources', 'comparison_data', 'baselines_2021')

# Notebook display setup: inline matplotlib figures; viz.notebook_full_width presumably widens the notebook cells.
%matplotlib inline
viz.notebook_full_width()

# Comparación E1 y E2 en MASS-SS2

In [None]:
def get_marks(dataset_name, expert, dataset=None):
    """Return one expert's annotations restricted to N2 pages, for every subject.

    Args:
        dataset_name: name passed to ``reader.load_dataset`` when ``dataset`` is None.
        expert: which expert's annotations to read (forwarded as ``which_expert``).
        dataset: optional already-loaded dataset. Passing it avoids reloading the same
            dataset once per expert.

    Returns:
        The result of ``dataset.get_stamps(pages_subset=constants.N2_RECORD, which_expert=expert)``
        (presumably one array of event stamps per subject -- TODO confirm).
    """
    if dataset is None:
        dataset = reader.load_dataset(dataset_name, verbose=False)
    return dataset.get_stamps(pages_subset=constants.N2_RECORD, which_expert=expert)


# Inter-expert agreement on MASS-SS2: expert 1 is passed as the reference events and
# expert 2 as the "detections". The dataset is loaded once and shared by both calls.
mass_ss_dataset = reader.load_dataset(constants.MASS_SS_NAME, verbose=False)
events_list = get_marks(constants.MASS_SS_NAME, 1, dataset=mass_ss_dataset)
detections_list = get_marks(constants.MASS_SS_NAME, 2, dataset=mass_ss_dataset)
performance = fig_utils.compute_fold_performance(events_list, detections_list, constants.MACRO_AVERAGE)
for metric_name, metric_value in performance.items():
    print('%s: %1.1f%%' % (metric_name.ljust(10), 100 * metric_value))

# Tabla de desempeño (by-fold)
Comparación con la literatura en desempeño in-dataset.
Todos los p-values de REDv2-CWT (test t de Welch sobre el F1-score por fold, contra REDv2-Time) son mayores que 0.05.
Todos los p-values de los baselines son menores que 0.001, excepto en INTA-UCH.
Print esperado:
```
Detector & F1-score (\%) & Recall (\%) & Precision (\%) & mIoU (\%) \\

MASS-SS2-E1SS (Fixed)
REDv2-Time & $81.0\pm 0.4$ & $83.7\pm 1.2$ & $79.3\pm 1.0$ & $84.8\pm 0.2$ \\
REDv2-CWT  & $80.7\pm 0.5$ & $83.1\pm 1.7$ & $79.4\pm 2.1$ & $84.5\pm 0.4$ \\
DOSED      & $78.0\pm 0.5$ & $77.7\pm 2.4$ & $79.8\pm 2.0$ & $75.3\pm 1.3$ \\
A7         & $69.7\pm 0.4$ & $82.7\pm 1.9$ & $61.2\pm 1.5$ & $74.9\pm 0.2$ \\

MASS-SS2-E2SS (Fixed)
REDv2-Time & $85.1\pm 0.5$ & $84.5\pm 1.0$ & $86.5\pm 1.8$ & $77.8\pm 0.2$ \\
REDv2-CWT  & $85.0\pm 0.4$ & $85.0\pm 0.8$ & $86.0\pm 1.4$ & $77.9\pm 0.2$ \\
DOSED      & $81.8\pm 0.7$ & $79.7\pm 1.4$ & $85.0\pm 1.4$ & $73.8\pm 0.5$ \\
A7         & $73.4\pm 0.1$ & $82.8\pm 0.0$ & $66.4\pm 0.2$ & $74.8\pm 0.0$ \\

MASS-SS2-KC (Fixed)
REDv2-Time & $83.3\pm 0.4$ & $82.4\pm 1.0$ & $85.0\pm 0.8$ & $90.5\pm 0.2$ \\
REDv2-CWT  & $83.4\pm 0.5$ & $82.1\pm 1.1$ & $85.7\pm 0.8$ & $90.3\pm 0.3$ \\
DOSED      & $77.5\pm 1.0$ & $76.5\pm 1.9$ & $79.5\pm 2.0$ & $72.2\pm 1.3$ \\
Spinky     & $65.7\pm 0.2$ & $65.0\pm 1.8$ & $67.7\pm 2.2$ & $42.3\pm 0.1$ \\

MASS-MODA (5CV)
REDv2-Time & $81.8\pm 1.4$ & $83.5\pm 2.5$ & $80.3\pm 2.3$ & $83.1\pm 0.5$ \\
REDv2-CWT  & $81.5\pm 1.2$ & $82.8\pm 2.5$ & $80.3\pm 2.4$ & $83.1\pm 0.6$ \\
DOSED      & $77.5\pm 1.7$ & $76.4\pm 2.8$ & $78.9\pm 3.0$ & $71.4\pm 1.1$ \\
A7         & $73.3\pm 1.9$ & $74.1\pm 2.1$ & $72.8\pm 3.6$ & $71.0\pm 0.9$ \\

INTA-UCH (5CV)
REDv2-Time & $83.2\pm 4.8$ & $85.0\pm 5.4$ & $82.7\pm 8.5$ & $75.8\pm 2.7$ \\
REDv2-CWT  & $83.2\pm 4.7$ & $85.2\pm 5.9$ & $82.7\pm 8.1$ & $76.0\pm 2.5$ \\
DOSED      & $77.2\pm 7.2$ & $78.0\pm 12.6$ & $79.7\pm 8.0$ & $68.7\pm 3.9$ \\
A7         & $77.6\pm 5.4$ & $78.2\pm 7.0$ & $79.9\pm 10.8$ & $70.3\pm 2.7$ \\

MASS-SS2-E1SS (5CV)
REDv2-Time & $80.8\pm 2.1$ & $84.4\pm 4.0$ & $78.9\pm 5.4$ & $84.4\pm 1.1$ \\
REDv2-CWT  & $80.8\pm 2.0$ & $84.9\pm 4.3$ & $78.5\pm 5.5$ & $84.3\pm 1.2$ \\
DOSED      & $76.8\pm 2.9$ & $79.7\pm 5.9$ & $77.5\pm 7.8$ & $74.7\pm 2.1$ \\
A7         & $73.0\pm 3.4$ & $80.1\pm 4.0$ & $68.1\pm 5.5$ & $73.9\pm 1.0$ \\

MASS-SS2-E2SS (5CV)
REDv2-Time & $86.1\pm 2.0$ & $87.0\pm 3.7$ & $86.0\pm 3.7$ & $78.5\pm 1.1$ \\
REDv2-CWT  & $86.0\pm 2.2$ & $87.2\pm 4.1$ & $85.7\pm 4.3$ & $78.5\pm 1.1$ \\
DOSED      & $82.5\pm 2.5$ & $84.0\pm 5.0$ & $82.5\pm 4.9$ & $73.1\pm 1.1$ \\
A7         & $74.9\pm 2.8$ & $81.5\pm 3.1$ & $70.0\pm 4.3$ & $74.7\pm 1.1$ \\

MASS-SS2-KC (5CV)
REDv2-Time & $83.6\pm 1.5$ & $85.2\pm 3.4$ & $82.9\pm 3.2$ & $90.5\pm 0.6$ \\
REDv2-CWT  & $83.8\pm 1.4$ & $85.0\pm 2.4$ & $83.3\pm 2.5$ & $90.4\pm 0.5$ \\
DOSED      & $77.5\pm 2.4$ & $79.2\pm 3.7$ & $76.5\pm 3.6$ & $72.3\pm 1.4$ \\
Spinky     & $63.1\pm 3.8$ & $61.6\pm 3.5$ & $65.6\pm 6.3$ & $41.2\pm 1.6$ \\
```

In [None]:
models = [constants.V2_TIME, constants.V2_CWT1D]
baselines_ss = ['dosed', 'a7']
baselines_kc = ['dosed', 'spinky']
print_model_names = {
    constants.V2_TIME: 'REDv2-Time',
    constants.V2_CWT1D: 'REDv2-CWT',
    'dosed': 'DOSED',
    'a7': 'A7',
    'spinky': 'Spinky'
}

eval_configs = [
    dict(dataset_name=constants.MASS_SS_NAME, expert=1, strategy='fixed', seeds=11),
    dict(dataset_name=constants.MASS_SS_NAME, expert=2, strategy='fixed', seeds=11),
    dict(dataset_name=constants.MASS_KC_NAME, expert=1, strategy='fixed', seeds=11),
    dict(dataset_name=constants.MODA_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.INTA_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_SS_NAME, expert=2, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_KC_NAME, expert=1, strategy='5cv', seeds=3),
]
# Metrics reported per fold, in LaTeX-table column order.
table_metrics = ['F1-score', 'Recall', 'Precision', 'mIoU']
for config in eval_configs:
    print("\nLoading", config)
    dataset = reader.load_dataset(config["dataset_name"], verbose=False)
    baselines = baselines_ss if dataset.event_name == constants.SPINDLE else baselines_kc

    # Collect predictions as pred_dict[detector][fold][subject_id] -> predicted stamps.
    # RED outputs are reshaped into the same layout the baseline loader returns.
    pred_dict = {}
    for model_version in models:
        tmp_dict = fig_utils.get_red_predictions(model_version, config["strategy"], dataset, config["expert"], verbose=False)
        pred_dict[model_version] = {}
        for k in tmp_dict.keys():
            test_predictions = tmp_dict[k][constants.TEST_SUBSET]
            pred_dict[model_version][k] = dict(zip(test_predictions.all_ids, test_predictions.get_stamps()))
    for baseline_name in baselines:
        pred_dict[baseline_name] = fig_utils.get_baseline_predictions(baseline_name, config["strategy"], config["dataset_name"], config["expert"])

    # Measure performance by fold (MODA is micro-averaged, the other datasets macro-averaged).
    average_mode = constants.MICRO_AVERAGE if (config["dataset_name"] == constants.MODA_SS_NAME) else constants.MACRO_AVERAGE
    _, _, test_ids_list = get_partitions(dataset, config["strategy"], config["seeds"])
    fold_rows = []
    for model_name in pred_dict.keys():
        for k, subject_ids in enumerate(test_ids_list):
            feed_d = FeederDataset(dataset, subject_ids, constants.N2_RECORD, which_expert=config["expert"])
            events_list = feed_d.get_stamps()
            detections_list = [pred_dict[model_name][k][subject_id] for subject_id in subject_ids]
            performance = fig_utils.compute_fold_performance(events_list, detections_list, average_mode)
            fold_rows.append([model_name] + [performance[metric_name] for metric_name in table_metrics] + [k])
    table = pd.DataFrame(fold_rows, columns=['Detector'] + table_metrics + ['Fold'])

    print("By-fold statistics")
    by_detector = table.drop(columns=["Fold"]).groupby(by=["Detector"])
    metric_mean = by_detector.mean()
    metric_std = by_detector.std(ddof=0)
    # "\\%" / "\\\\" print a literal LaTeX "\%" / "\\" (a bare "\%" is an invalid escape sequence).
    print("Detector & F1-score (\\%) & Recall (\\%) & Precision (\\%) & mIoU (\\%) \\\\")
    for model_name in pred_dict.keys():
        formatted_cells = [
            fig_utils.format_metric(metric_mean.at[model_name, metric_name], metric_std.at[model_name, metric_name])
            for metric_name in table_metrics
        ]
        print("%s & %s \\\\" % (print_model_names[model_name].ljust(10), " & ".join(formatted_cells)))

    # Statistical tests: Welch's t-test (unequal variances) on by-fold F1-scores against the reference.
    reference_model_name = constants.V2_TIME
    print("P-value test against %s" % reference_model_name)
    reference_metrics = table.loc[table["Detector"] == reference_model_name, "F1-score"].values
    for model_name in pred_dict.keys():
        model_metrics = table.loc[table["Detector"] == model_name, "F1-score"].values
        pvalue = stats.ttest_ind(model_metrics, reference_metrics, equal_var=False)[1]
        print("%s: P %1.4f" % (print_model_names[model_name].ljust(10), pvalue))

# Dispersión por sujeto (by-subject)
Solo se usa 5CV por brevedad: ya se vio que el desempeño es similar al de la partición fija, y 5CV permite incluir a todos los sujetos de MASS-SS2.

Datos cuantitativos de dispersión entre sujetos (print esperado):
```
Dataset: MASS-SS2-E1SS
          F1-score    Recall  Precision      mIoU
Detector                                         
a7        5.034115  6.698362   8.493331  2.332143
dosed     5.065436  9.265257  13.893792  3.502857
v2_cwt1d  3.443896  7.987251   8.489591  2.374149
v2_time   3.575472  8.186777   8.515364  2.323281

Dataset: MASS-SS2-E2SS
          F1-score    Recall  Precision      mIoU
Detector                                         
a7        4.330717  5.803320   7.882029  1.664745
dosed     4.163007  7.523092   8.901198  2.268538
v2_cwt1d  3.483713  6.769644   6.658597  2.219977
v2_time   3.329545  6.479800   6.366437  2.186028

Dataset: MASS-SS2-KC
          F1-score    Recall  Precision      mIoU
Detector                                         
dosed     3.960219  6.606413   6.062948  1.775579
spinky    7.243040  6.883243  10.711108  2.448843
v2_cwt1d  3.070712  6.027732   5.488672  0.810547
v2_time   3.233140  6.605095   5.811940  0.966797

Dataset: MASS-MODA
           F1-score     Recall  Precision      mIoU
Detector                                           
a7        18.773054  22.023236  14.032634  6.618863
dosed     13.233049  14.972821  13.520275  3.533990
v2_cwt1d  10.929555  14.090439  10.262946  2.774643
v2_time   11.000109  14.058725   9.195598  2.723617

Dataset: INTA-UCH
          F1-score     Recall  Precision      mIoU
Detector                                          
a7        7.567395   9.741852  13.076641  4.286478
dosed     8.210749  13.977135   9.195159  4.280485
v2_cwt1d  5.872097   8.169621   9.838274  3.862264
v2_time   6.009940   7.387042  10.281951  4.194361
```

In [None]:
models = [constants.V2_TIME, constants.V2_CWT1D]
baselines_ss = ['dosed', 'a7']
baselines_kc = ['dosed', 'spinky']
print_model_names = {
    constants.V2_TIME: 'REDv2-Time',
    constants.V2_CWT1D: 'REDv2-CWT',
    'dosed': 'DOSED',
    'a7': 'A7',
    'spinky': 'Spinky'
}
print_dataset_names = {
    (constants.MASS_SS_NAME, 1): "MASS-SS2-E1SS",
    (constants.MASS_SS_NAME, 2): "MASS-SS2-E2SS",
    (constants.MASS_KC_NAME, 1): "MASS-SS2-KC",
    (constants.MODA_SS_NAME, 1): "MASS-MODA",
    (constants.INTA_SS_NAME, 1): "INTA-UCH",
}

eval_configs = [
    dict(dataset_name=constants.MASS_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_SS_NAME, expert=2, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_KC_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MODA_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.INTA_SS_NAME, expert=1, strategy='5cv', seeds=3),
]
# Column layout of the per-subject table; the metric columns sit between Detector and Subject.
subject_table_columns = ['Detector', 'F1-score', 'Recall', 'Precision', 'mIoU', 'Subject', 'Fold']
subject_metric_names = subject_table_columns[1:5]
dispersions_list = []
for config in eval_configs:
    print("\nLoading", config)
    dataset = reader.load_dataset(config["dataset_name"], verbose=False)
    if dataset.event_name == constants.SPINDLE:
        baselines = baselines_ss
    else:
        baselines = baselines_kc

    # pred_dict[detector][fold][subject_id] -> predicted stamps; RED outputs are
    # reshaped into the same nested layout the baseline loader returns.
    pred_dict = {}
    for model_version in models:
        red_folds = fig_utils.get_red_predictions(model_version, config["strategy"], dataset, config["expert"], verbose=False)
        pred_dict[model_version] = {}
        for k in red_folds.keys():
            test_result = red_folds[k][constants.TEST_SUBSET]
            pred_dict[model_version][k] = dict(zip(test_result.all_ids, test_result.get_stamps()))
    for baseline_name in baselines:
        pred_dict[baseline_name] = fig_utils.get_baseline_predictions(
            baseline_name, config["strategy"], config["dataset_name"], config["expert"])

    # Per-subject metrics. On MODA only subjects whose 'n_blocks' equals 10 are kept.
    _, _, test_ids_list = get_partitions(dataset, config["strategy"], config["seeds"])
    if config["dataset_name"] == constants.MODA_SS_NAME:
        valid_subjects = [sub_id for sub_id in dataset.all_ids if dataset.data[sub_id]['n_blocks'] == 10]
    else:
        valid_subjects = dataset.all_ids
    subject_rows = []
    for model_name in pred_dict.keys():
        for k, subject_ids in enumerate(test_ids_list):
            feed_d = FeederDataset(dataset, subject_ids, constants.N2_RECORD, which_expert=config["expert"])
            events_list = feed_d.get_stamps()
            detections_list = [pred_dict[model_name][k][subject_id] for subject_id in subject_ids]
            performance = fig_utils.compute_subject_performance(events_list, detections_list)
            for i, subject_id in enumerate(subject_ids):
                if subject_id not in valid_subjects:
                    continue
                subject_rows.append(
                    [model_name]
                    + [performance[metric_name][i] for metric_name in subject_metric_names]
                    + [subject_id, k])
    table = pd.DataFrame(subject_rows, columns=subject_table_columns)
    # Average each subject over the folds it appears in, then take the spread across subjects (in %).
    subject_means = table.groupby(["Detector", "Subject"]).mean().drop(columns=["Fold"])
    bysubject_dispersions = 100 * subject_means.groupby("Detector").std(ddof=0)
    dispersions_list.append(bysubject_dispersions)
print("Dispersions computed.")

In [None]:
# Show the by-subject standard deviation table computed for each evaluation config.
for config_idx, config in enumerate(eval_configs):
    dataset_label = print_dataset_names[(config["dataset_name"], config["expert"])]
    print("\nDataset: %s" % dataset_label)
    print(dispersions_list[config_idx])

In [None]:
save_figure = False

letters = ['A', 'B', 'C', 'D', 'E', 'F']
# Same detector order in every panel so the bars line up across datasets.
detector_order = ["spinky", "a7", "dosed", constants.V2_CWT1D, constants.V2_TIME]
figure_formats = ("pdf", "png", "svg")


def _pad_dispersion_table(disp_table, order):
    """Reindex ``disp_table`` to ``order``, giving detectors it lacks an all-zero row.

    A7 is not evaluated on K-complexes and Spinky is not evaluated on spindles. Adding a zero
    (invisible) bar keeps the y-axis layout identical in every panel. This replaces
    ``DataFrame.append``, which was removed in pandas 2.0.
    """
    missing = [name for name in order if name not in disp_table.index]
    padded = disp_table.reindex(order)
    padded.loc[missing, :] = 0
    return padded


fig, axes = plt.subplots(1, 5, figsize=(8, 3), dpi=200, sharex=True)
for i, config in enumerate(eval_configs):
    disp_table_mod = _pad_dispersion_table(dispersions_list[i], detector_order)

    ax = disp_table_mod.plot.barh(y=["F1-score", "Recall", "Precision"], ax=axes[i], fontsize=8, legend=False)
    ax.set_title(print_dataset_names[(config["dataset_name"], config["expert"])], loc="left", fontsize=8)

    # Replace internal detector ids with their display names.
    ax.set_yticklabels([print_model_names[yt.get_text()] for yt in ax.get_yticklabels()])
    ax.set_ylabel("")
    # Raw string: "\s" and "\m" are invalid escape sequences in a normal literal.
    ax.set_xlabel(r"$\sigma_\mathrm{subjects}$ (%)", fontsize=8)
    ax.tick_params(labelsize=8)
    if i > 0:
        ax.set_yticklabels([])
    ax.set_xlim([0, 25])
    ax.set_xticks([0, 10, 20])
    ax.set_xticks([0, 2.5, 5, 7.5, 10, 12.5, 15, 17.5, 20, 22.5, 25], minor=True)
    ax.grid(axis="x", which="minor")
    ax.text(
        x=-0.01, y=1.15, fontsize=16, s=r"$\bf{%s}$" % letters[i],
        ha="left", transform=ax.transAxes)
plt.tight_layout()

# Every panel plots the same three metrics, so the first panel's handles serve as the figure legend.
lines, labels = axes[0].get_legend_handles_labels()
lg = fig.legend(
    lines, labels, fontsize=8, loc="lower center",
    bbox_to_anchor=(0.5, 0.02), ncol=3, frameon=False, handletextpad=0.5)

if save_figure:
    fname_prefix = "result_comparison_bysubject_std"
    for figure_format in figure_formats:
        plt.savefig(
            "%s.%s" % (fname_prefix, figure_format),
            bbox_extra_artists=(lg,), bbox_inches="tight", pad_inches=0.3)

plt.show()

# Efecto umbral probabilidad: curva PR y métricas vs umbral

In [None]:
models = [constants.V2_TIME, constants.V2_CWT1D]
baselines_ss = ['dosed', 'a7']
baselines_kc = ['dosed', 'spinky']
print_model_names = {
    constants.V2_TIME: 'REDv2-Time',
    constants.V2_CWT1D: 'REDv2-CWT',
    'dosed': 'DOSED',
    'a7': 'A7',
    'spinky': 'Spinky'
}
print_dataset_names = {
    (constants.MASS_SS_NAME, 1): "MASS-SS2-E1SS",
    (constants.MASS_SS_NAME, 2): "MASS-SS2-E2SS",
    (constants.MASS_KC_NAME, 1): "MASS-SS2-KC",
    (constants.MODA_SS_NAME, 1): "MASS-MODA",
    (constants.INTA_SS_NAME, 1): "INTA-UCH",
}

eval_configs = [
    dict(dataset_name=constants.MASS_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_SS_NAME, expert=2, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_KC_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MODA_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.INTA_SS_NAME, expert=1, strategy='5cv', seeds=3),
]
# Column layout of the by-fold table; metric columns sit between Detector and Fold.
fold_table_columns = ['Detector', 'F1-score', 'Recall', 'Precision', 'mIoU', 'Fold']
metrics_list = []
for config in eval_configs:
    dataset_name, expert, strategy = config["dataset_name"], config["expert"], config["strategy"]
    print("\nLoading", config)
    dataset = reader.load_dataset(dataset_name, verbose=False)
    if dataset.event_name == constants.SPINDLE:
        baselines = baselines_ss
    else:
        baselines = baselines_kc

    # pred_dict[detector][fold][subject_id] -> predicted stamps (RED outputs reshaped to the baseline layout).
    pred_dict = {}
    for model_version in models:
        red_folds = fig_utils.get_red_predictions(model_version, strategy, dataset, expert, verbose=False)
        pred_dict[model_version] = {
            k: dict(zip(red_folds[k][constants.TEST_SUBSET].all_ids, red_folds[k][constants.TEST_SUBSET].get_stamps()))
            for k in red_folds.keys()
        }
    pred_dict.update({
        baseline_name: fig_utils.get_baseline_predictions(baseline_name, strategy, dataset_name, expert)
        for baseline_name in baselines
    })

    # By-fold metrics; MODA is micro-averaged, the other datasets macro-averaged.
    if dataset_name == constants.MODA_SS_NAME:
        average_mode = constants.MICRO_AVERAGE
    else:
        average_mode = constants.MACRO_AVERAGE
    _, _, test_ids_list = get_partitions(dataset, strategy, config["seeds"])
    fold_rows = []
    for model_name in pred_dict:
        for k, subject_ids in enumerate(test_ids_list):
            expert_stamps = FeederDataset(dataset, subject_ids, constants.N2_RECORD, which_expert=expert).get_stamps()
            model_stamps = [pred_dict[model_name][k][subject_id] for subject_id in subject_ids]
            fold_performance = fig_utils.compute_fold_performance(expert_stamps, model_stamps, average_mode)
            fold_rows.append(
                [model_name]
                + [fold_performance[column] for column in fold_table_columns[1:-1]]
                + [k])
    table = pd.DataFrame(fold_rows, columns=fold_table_columns)
    print("By-fold statistics")
    # Mean over folds per detector; these operating points are drawn on the PR panels.
    metric_mean = table.groupby(by=["Detector"]).mean().drop(columns=["Fold"])
    metrics_list.append(metric_mean)
print("Metrics computed.")

In [None]:
# Compute change due to threshold: sweep the RED models' (adjusted) probability threshold
# and record every fold metric at each threshold value.
adjusted_thr_list = np.arange(0.05, 0.95 + 0.001, 0.05)
metrics_curve_list = []  # [loc in config][model_name][fold_id][metric_name][loc in thr]
for config in eval_configs:
    average_mode = constants.MICRO_AVERAGE if (config["dataset_name"] == constants.MODA_SS_NAME) else constants.MACRO_AVERAGE
    print("\nLoading", config)
    dataset = reader.load_dataset(config["dataset_name"], verbose=False)
    metrics_curve = {}
    for model_version in models:
        metrics_curve[model_version] = {}
        tmp_dict = fig_utils.get_red_predictions(model_version, config["strategy"], dataset, config["expert"], verbose=False)
        for k in tmp_dict.keys():
            fold_predictions = tmp_dict[k][constants.TEST_SUBSET]
            # Threshold selected for this fold. The sweep values are passed relative to it via
            # adjusted_by_threshold (presumably 0.5 maps to this optimum -- TODO confirm).
            optimal_thr = fold_predictions.probability_threshold
            feed_d = FeederDataset(dataset, fold_predictions.all_ids, constants.N2_RECORD, which_expert=config["expert"])
            events_list = feed_d.get_stamps()
            performance_per_thr = []
            for adjusted_thr in adjusted_thr_list:
                # set_probability_threshold changes what get_stamps() returns. The object is
                # discarded with tmp_dict, so the original threshold is not restored.
                fold_predictions.set_probability_threshold(adjusted_thr, adjusted_by_threshold=optimal_thr, verbose=False)
                detections_list = fold_predictions.get_stamps()
                performance_per_thr.append(
                    fig_utils.compute_fold_performance(events_list, detections_list, average_mode))
            # list of dict -> dict of arrays indexed by threshold position
            metrics_curve[model_version][k] = {
                metric_key: np.array([thr_performance[metric_key] for thr_performance in performance_per_thr])
                for metric_key in performance_per_thr[0].keys()
            }
    metrics_curve_list.append(metrics_curve)
print("Done.")

In [None]:
import pickle

# Checkpoint for the slow threshold sweep:
#   "save" -> write the freshly computed metrics_curve_list to disk,
#   "load" -> replace metrics_curve_list with the stored copy,
#   None   -> leave metrics_curve_list untouched.
pr_ckpt_action = "load"
fname = 'pr_curve_ckpt.pkl'
# Thresholds used as the x-axis of the threshold plots; must match those the checkpoint was computed with.
adjusted_thr_list = np.arange(0.05, 0.95 + 0.001, 0.05)
if pr_ckpt_action == "save":
    with open(fname, 'wb') as handle:
        pickle.dump(metrics_curve_list, handle, protocol=pickle.HIGHEST_PROTOCOL)
elif pr_ckpt_action == "load":
    # pickle.load can execute arbitrary code: only load checkpoints produced by this notebook.
    with open(fname, 'rb') as handle:
        metrics_curve_list = pickle.load(handle)
    # One entry per evaluation config is expected; catch stale checkpoints early.
    if len(metrics_curve_list) != len(eval_configs):
        raise ValueError(
            "Checkpoint %s has %d entries but eval_configs has %d; recompute and save it."
            % (fname, len(metrics_curve_list), len(eval_configs)))
elif pr_ckpt_action is not None:
    # A typo would otherwise silently skip the checkpoint step.
    raise ValueError("Unknown pr_ckpt_action %r; expected 'save', 'load' or None." % (pr_ckpt_action,))

In [None]:
save_figure = False
markersize = 5
red_alpha = 0.6
pr_alpha_curve = 0.5
thr_alpha_curve = 1.0
baseline_color = viz.GREY_COLORS[9]
letters = ['A', 'B', 'C', 'D', 'E']
letters2 = ['F', 'G', 'H', 'I', 'J']
model_specs = {
    constants.V2_TIME: dict(marker='o', color=viz.PALETTE['blue']),
    constants.V2_CWT1D: dict(marker='o', color=viz.PALETTE['red']),
    'dosed': dict(marker='s', color=baseline_color),
    'a7': dict(marker='^', color=baseline_color),
    'spinky': dict(marker='v', color=baseline_color),
}
# Fixed (recall, precision) operating points; these values are hard-coded, not computed here.
spindle_net = dict(metrics={"F1-score": .83, "Recall": .852, "Precision": .81}, marker='<', color=baseline_color)
dkl_kc = dict(metrics={"F1-score": .78, "Recall": .80, "Precision": .77}, marker='>', color=baseline_color)
# Each fixed point is drawn only on the (dataset, expert) panel listed here.
literature_points = {
    (constants.MASS_SS_NAME, 2): ("SpindleNet", spindle_net),
    (constants.MASS_KC_NAME, 1): ("DKL-KC", dkl_kc),
}
thr_metric_names = ["F1-score", "Recall", "Precision", "mIoU"]
figure_formats = ("pdf", "png", "svg")


def _plot_operating_point(ax, recall, precision, marker, color, label):
    """Draw a single (recall, precision) marker with the shared PR-panel styling."""
    ax.plot(
        recall,
        precision,
        linestyle="None",
        alpha=red_alpha,
        marker=marker,
        markersize=markersize,
        markeredgewidth=0.0, zorder=20,
        color=color,
        label=label)


def _legend_entries(axes_row, labels):
    """Return (handles, labels) for ``labels``, taken from the artists drawn on ``axes_row``.

    Labels that were not drawn on any axis are skipped instead of raising KeyError.
    """
    handles_by_label = {}
    for row_ax in axes_row:
        row_handles, row_labels = row_ax.get_legend_handles_labels()
        handles_by_label.update(zip(row_labels, row_handles))
    kept_labels = [lbl for lbl in labels if lbl in handles_by_label]
    return [handles_by_label[lbl] for lbl in kept_labels], kept_labels


fig, axes = plt.subplots(2, 5, figsize=(8, 4), dpi=200)
for i, config in enumerate(eval_configs):
    dataset_key = (config["dataset_name"], config["expert"])
    # [loc in config][model_name][fold_id][metric_name][loc in thr]
    pr_curve_data = metrics_curve_list[i]

    # PR PLOT: mean by-fold operating point per detector, plus fold-averaged PR curves of the RED models.
    ax = axes[0, i]
    metric_dict = metrics_list[i].to_dict('index')
    for model_name, model_metrics in metric_dict.items():
        _plot_operating_point(
            ax, model_metrics["Recall"], model_metrics["Precision"],
            model_specs[model_name]["marker"], model_specs[model_name]["color"],
            print_model_names[model_name])
    ax.set_title(print_dataset_names[dataset_key], loc="left", fontsize=8)
    plotter.format_precision_recall_plot_simple(
        ax, axis_range=(0.5, 1), show_quadrants=False, show_grid=True,
        axis_markers=np.arange(0.5, 1 + 0.001, 0.5), minor_axis_markers=np.arange(0.5, 1 + 0.001, 0.1))
    ax.tick_params(labelsize=8)
    ax.set_xlabel("Recall", fontsize=8)
    if i == 0:
        ax.set_ylabel("Precision", fontsize=8)
    else:
        ax.set_yticks([])
    # Get labels closer to axis
    ax.xaxis.labelpad = -8
    ax.yaxis.labelpad = -8
    for model_name, fold_curves in pr_curve_data.items():
        fold_ids = sorted(fold_curves.keys())
        mean_recall_curve, mean_precision_curve = plotter.average_curves(
            [fold_curves[k]["Recall"] for k in fold_ids],
            [fold_curves[k]["Precision"] for k in fold_ids])
        ax.plot(
            mean_recall_curve, mean_precision_curve,
            linewidth=1.0, color=model_specs[model_name]["color"], zorder=10, alpha=pr_alpha_curve)
    if dataset_key in literature_points:
        point_label, point_spec = literature_points[dataset_key]
        _plot_operating_point(
            ax, point_spec["metrics"]["Recall"], point_spec["metrics"]["Precision"],
            point_spec["marker"], point_spec["color"], point_label)
    ax.text(
        x=-0.01, y=1.2, fontsize=16, s=r"$\bf{%s}$" % letters[i],
        ha="left", transform=ax.transAxes)

    # CHANGE DUE TO THR: REDv2-Time metrics (mean +- std across folds) vs. adjusted probability threshold.
    ax = axes[1, i]
    fold_curves = pr_curve_data[constants.V2_TIME]
    fold_ids = sorted(fold_curves.keys())
    for metric_name in thr_metric_names:
        fold_values = np.stack([fold_curves[k][metric_name] for k in fold_ids], axis=0)
        mean_curve = fold_values.mean(axis=0)
        std_curve = fold_values.std(axis=0)
        mean_line, = ax.plot(
            adjusted_thr_list, mean_curve,
            linewidth=1.0, zorder=10, alpha=thr_alpha_curve, label=metric_name)
        # Tie the band color to its line explicitly, rather than relying on the line and
        # patch color cyclers staying in step.
        ax.fill_between(
            adjusted_thr_list, mean_curve + std_curve, mean_curve - std_curve,
            facecolor=mean_line.get_color(), linewidth=1.0, zorder=10, alpha=0.2)
    ax.axvline(0.5, color="k", linewidth=1.5, zorder=30)
    ax.tick_params(labelsize=8)
    ax.set_xlabel("Umbral prob.", fontsize=8)
    ax.set_xlim([0, 1])
    ax.set_ylim([0.5, 1.0])
    ax.set_xticks([0, 1.0])
    ax.set_xticks(np.arange(0, 1 + 0.001, 0.1), minor=True)
    ax.set_yticks([0.5, 1.0])
    ax.set_yticks(np.arange(0.5, 1 + 0.001, 0.1), minor=True)
    ax.grid(which="minor")
    if i == 0:
        ax.set_ylabel("Métrica", fontsize=8)
    else:
        ax.set_yticks([])
    ax.xaxis.labelpad = -8
    ax.yaxis.labelpad = -8
    ax.set_aspect(2)
    ax.text(
        x=-0.01, y=1.05, fontsize=16, s=r"$\bf{%s}$" % letters2[i],
        ha="left", transform=ax.transAxes)

plt.tight_layout()

# Legend of detectors (top row).
lines, labels = _legend_entries(
    axes[0, :], ["REDv2-Time", "REDv2-CWT", "DOSED", "A7", "Spinky", "SpindleNet", "DKL-KC"])
lg1 = fig.legend(
    lines, labels, fontsize=7, loc="lower center",
    bbox_to_anchor=(0.5, 0.49), ncol=len(labels), frameon=False, handletextpad=0.5)

# Legend of metrics (bottom row).
lines, labels = _legend_entries(axes[1, :], thr_metric_names)
lg2 = fig.legend(
    lines, labels, fontsize=7, loc="lower center",
    bbox_to_anchor=(0.5, 0.02), ncol=len(labels), frameon=False, handletextpad=0.5)

if save_figure:
    fname_prefix = "result_comparison_pr_thr"
    for figure_format in figure_formats:
        plt.savefig(
            "%s.%s" % (fname_prefix, figure_format),
            bbox_extra_artists=(lg1, lg2), bbox_inches="tight", pad_inches=0.3)

plt.show()

# Efecto umbral IoU: F1 vs IoU, Histograma IoU

In [None]:
models = [constants.V2_TIME, constants.V2_CWT1D]
baselines_ss = ['dosed', 'a7']
baselines_kc = ['dosed', 'spinky']
print_model_names = {
    constants.V2_TIME: 'REDv2-Time',
    constants.V2_CWT1D: 'REDv2-CWT',
    'dosed': 'DOSED',
    'a7': 'A7',
    'spinky': 'Spinky'
}
print_dataset_names = {
    (constants.MASS_SS_NAME, 1): "MASS-SS2-E1SS",
    (constants.MASS_SS_NAME, 2): "MASS-SS2-E2SS",
    (constants.MASS_KC_NAME, 1): "MASS-SS2-KC",
    (constants.MODA_SS_NAME, 1): "MASS-MODA",
    (constants.INTA_SS_NAME, 1): "INTA-UCH",
}

eval_configs = [
    dict(dataset_name=constants.MASS_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_SS_NAME, expert=2, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_KC_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MODA_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.INTA_SS_NAME, expert=1, strategy='5cv', seeds=3),
]

# IoU thresholds at which F1 is evaluated, and bin edges for the IoU histogram.
iou_curve_axis = np.arange(0.05, 0.95 + 0.001, 0.05)
iou_hist_bins = np.linspace(0, 1, 21)
# Columns averaged over folds; the first two hold one array per fold.
iou_mean_columns = ['F1-score_vs_iou', 'IoU_hist', 'mIoU']


def _fold_average(fold_values):
    """Element-wise mean over folds of equally-shaped per-fold values (arrays or scalars)."""
    return np.mean(np.stack([np.asarray(value) for value in fold_values], axis=0), axis=0)


metrics_list = []
for config in eval_configs:
    print("\nLoading", config)
    dataset = reader.load_dataset(config["dataset_name"], verbose=False)
    baselines = baselines_ss if dataset.event_name == constants.SPINDLE else baselines_kc

    # Collect predictions as pred_dict[detector][fold][subject_id] -> predicted stamps.
    pred_dict = {}
    for model_version in models:
        tmp_dict = fig_utils.get_red_predictions(model_version, config["strategy"], dataset, config["expert"], verbose=False)
        pred_dict[model_version] = {}
        for k in tmp_dict.keys():
            test_predictions = tmp_dict[k][constants.TEST_SUBSET]
            pred_dict[model_version][k] = dict(zip(test_predictions.all_ids, test_predictions.get_stamps()))
    for baseline_name in baselines:
        pred_dict[baseline_name] = fig_utils.get_baseline_predictions(baseline_name, config["strategy"], config["dataset_name"], config["expert"])

    # Measure F1 vs IoU threshold and the IoU histogram by fold.
    average_mode = constants.MICRO_AVERAGE if (config["dataset_name"] == constants.MODA_SS_NAME) else constants.MACRO_AVERAGE
    _, _, test_ids_list = get_partitions(dataset, config["strategy"], config["seeds"])
    table = {'Detector': [], 'F1-score_vs_iou': [], 'IoU_hist': [], 'mIoU': [], 'Fold': []}
    for model_name in pred_dict.keys():
        for k, subject_ids in enumerate(test_ids_list):
            feed_d = FeederDataset(dataset, subject_ids, constants.N2_RECORD, which_expert=config["expert"])
            events_list = feed_d.get_stamps()
            detections_list = [pred_dict[model_name][k][subject_id] for subject_id in subject_ids]
            performance = fig_utils.compute_fold_performance_vs_iou(
                events_list, detections_list, average_mode, iou_curve_axis)
            iou_mean, iou_hist_values = fig_utils.compute_iou_histogram(
                performance['nonzero_IoU'], average_mode, iou_hist_bins)
            table['Detector'].append(model_name)
            table['F1-score_vs_iou'].append(performance['F1-score_vs_iou'])
            table['IoU_hist'].append(iou_hist_values)
            table['mIoU'].append(iou_mean)
            table['Fold'].append(k)
    table = pd.DataFrame.from_dict(table)
    print("By-fold statistics")
    # Average every column over folds, per detector. This is done explicitly because
    # ``groupby(...).apply(np.mean)`` relied on pandas < 2.0 behavior:
    # - silently dropping the string column;
    # - treating ``mean(axis=None)`` as a column-wise mean.
    # Both changed in pandas 2.0.
    detector_names = []
    fold_means = {column: [] for column in iou_mean_columns}
    for detector_name, detector_rows in table.groupby("Detector"):
        detector_names.append(detector_name)
        for column in iou_mean_columns:
            fold_means[column].append(_fold_average(detector_rows[column]))
    detector_index = pd.Index(detector_names, name="Detector")
    metric_mean = pd.DataFrame({
        # dtype=object keeps one array per cell instead of expanding into a 2-D block.
        'F1-score_vs_iou': pd.Series(fold_means['F1-score_vs_iou'], index=detector_index, dtype=object),
        'IoU_hist': pd.Series(fold_means['IoU_hist'], index=detector_index, dtype=object),
        'mIoU': pd.Series(fold_means['mIoU'], index=detector_index, dtype=float),
    })
    metrics_list.append(metric_mean)
print("Metrics computed.")

In [None]:
# Figure: top row = F1-score vs IoU threshold for each evaluation config;
# bottom row = per-detector histograms of non-zero pair IoU with the mIoU
# drawn as a vertical tick. Reads iou_curve_axis, iou_hist_bins, eval_configs,
# print_* dicts and metrics_list from the previous cell.
save_figure = False
markersize = 4
f1_alpha_curve = 1.0
iou_thr_reported = 0.2  # Only referenced by the commented-out axvline below
f1_markers_iou = [0.2, 0.4, 0.6, 0.8]  # IoU thresholds where curve markers are drawn

# Indices into iou_curve_axis for the marker thresholds (presumably the
# closest grid point; see misc.closest_index). Passed to `markevery`.
idx_markers_iou = [
    misc.closest_index(single_marker, iou_curve_axis) 
    for single_marker in f1_markers_iou]

baseline_color = viz.GREY_COLORS[8]
letters = ['A', 'B', 'C', 'D', 'E']  # Panel letters, top row
letters2 = ['F', 'G', 'H', 'I', 'J']  # Panel letters, bottom row
model_specs = {
    constants.V2_TIME: dict(marker='o', color=viz.PALETTE['blue']),
    constants.V2_CWT1D: dict(marker='o', color=viz.PALETTE['red']),
    'dosed': dict(marker='s', color=baseline_color),
    'a7': dict(marker='^', color=baseline_color),
    'spinky': dict(marker='v', color=baseline_color),
}
# Fixed reference metrics; not used anywhere in this cell.
spindle_net = dict(metrics={"F1-score": .83, "Recall": .852, "Precision": .81}, marker='<', color=baseline_color)
dkl_kc = dict(metrics={"F1-score": .78, "Recall": .80, "Precision": .77}, marker='>', color=baseline_color)

fig, axes = plt.subplots(2, 5, figsize=(8, 4.5), dpi=200)
for i, config in enumerate(eval_configs):
    # F1-score vs IoU
    ax = axes[0, i]
    # {detector: {"F1-score_vs_iou": array, "IoU_hist": array, "mIoU": float}}
    metric_dict = metrics_list[i].to_dict('index')
    for model_name in metric_dict.keys():
        ax.plot(
            iou_curve_axis, 
            metric_dict[model_name]["F1-score_vs_iou"],
            linewidth=1,
            marker=model_specs[model_name]["marker"],
            markersize=markersize,
            alpha=f1_alpha_curve,
            markeredgewidth=0.0, zorder=20,
            color=model_specs[model_name]["color"],
            label=print_model_names[model_name],
            markevery=idx_markers_iou)
    ax.set_title(print_dataset_names[(config["dataset_name"], config["expert"])], loc="left", fontsize=8)
    # ax.axvline(iou_thr_reported, color="k", linewidth=1.5, zorder=5, alpha=0.5)
    ax.tick_params(labelsize=8)
    ax.set_xlabel("Umbral IoU", fontsize=8)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    # Major ticks only at the ends; minor ticks every 0.1 drive the grid.
    ax.set_xticks([0, 1.0])
    ax.set_xticks(np.arange(0, 1 + 0.001, 0.1), minor=True)
    ax.set_yticks([0, 1.0])
    ax.set_yticks(np.arange(0, 1 + 0.001, 0.1), minor=True)
    ax.grid(which="minor")
    if i == 0:
        ax.set_ylabel("F1-score", fontsize=8) 
    else:
        # Removes the major y ticks only; minor ticks (and grid) remain.
        ax.set_yticks([])
    ax.xaxis.labelpad = -8
    ax.yaxis.labelpad = -8
    ax.set_aspect("equal")
    
    ax.text(
        x=-0.01, y=1.2, fontsize=16, s=r"$\bf{%s}$" % letters[i],
        ha="left", transform=ax.transAxes)
    
    # IoU Hist
    ax = axes[1, i]
    # Common height scale: the tallest bin across detectors fills 1/1.3 of a strip.
    max_value = 0
    for model_name in metric_dict.keys():
        max_value = max(metric_dict[model_name]["IoU_hist"].max(), max_value)
    max_value = max_value * 1.3
    # Each detector gets a horizontal strip of height y_sep, stacked top-down.
    n_cases = len(metric_dict.keys())
    y_sep = 1 / n_cases
    this_center = 1 - y_sep
    
    model_names = list(metric_dict.keys())
    reference_order = [constants.V2_TIME, constants.V2_CWT1D, "dosed", "a7", "spinky"]
    model_names_sorted = [n for n in reference_order if n in model_names]
    
    for i_offset, model_name in enumerate(model_names_sorted):
        x, y = plotter.piecewise_constant_histogram(
            iou_hist_bins, metric_dict[model_name]["IoU_hist"])
        y = y_sep * y / max_value
        # Vertical black tick at the detector's mIoU
        ax.plot(
            [metric_dict[model_name]["mIoU"], metric_dict[model_name]["mIoU"]], 
            [this_center, this_center + 0.8*y_sep],
            linewidth=1.5, color="k", zorder=25, label='mIoU')
        ax.fill_between(
            x, this_center + y, this_center,
            edgecolor=model_specs[model_name]["color"], linewidth=1,
            facecolor=viz.GREY_COLORS[3], zorder=20)
        # Detector marker at the left of its strip (identifies the row)
        ax.plot(
            0.05, this_center + 0.2*y_sep, 
            markersize=markersize, c=model_specs[model_name]["color"], zorder=15, 
            marker=model_specs[model_name]["marker"], linestyle="None")
        this_center = this_center - y_sep
        if i_offset == 0:
            # Legend is built after the first strip, when the only labelled
            # artist is one mIoU tick, so it shows a single "mIoU" entry.
            lg = ax.legend(loc='upper left', fontsize=8, frameon=False, bbox_to_anchor=(0, 1.05))
            
    ax.tick_params(labelsize=8)
    ax.set_xlabel("IoU de par", fontsize=8)
    ax.set_xlim([0, 1])
    ax.set_xticks([0, 1.0])
    ax.set_xticks(np.arange(0, 1 + 0.001, 0.1), minor=True)
    ax.xaxis.labelpad = -8
    if i == 0:
        ax.set_ylabel("Densidad", fontsize=8)
    ax.set_ylim([0, 1])
    ax.set_yticks([])
    ax.grid(axis="x", which="minor")
    ax.set_aspect("equal")
    
    ax.text(
        x=-0.01, y=1.05, fontsize=16, s=r"$\bf{%s}$" % letters2[i],
        ha="left", transform=ax.transAxes)
    
plt.tight_layout()

# Get legend methods
# Shared figure legend: collect one handle per label from the top row (later
# panels overwrite earlier ones with the same label). Every label in `labels`
# must appear in at least one top-row panel, otherwise this raises KeyError.
labels_to_lines_dict = {}
for ax in axes[0, :]:
    t_lines, t_labels = ax.get_legend_handles_labels()
    for lbl, lin in zip(t_labels, t_lines):
        labels_to_lines_dict[lbl] = lin
labels = ["REDv2-Time", "REDv2-CWT", "DOSED", "A7", "Spinky"]
lines = [labels_to_lines_dict[lbl] for lbl in labels]
lg1 = fig.legend(
    lines, labels, fontsize=7, loc="lower center",
    bbox_to_anchor=(0.5, 0.49), ncol=len(labels), frameon=False, handletextpad=0.5)

if save_figure:
    # Save figure (lg1 passed so the tight bounding box includes the figure legend)
    fname_prefix = "result_comparison_f1_iou"
    plt.savefig("%s.pdf" % fname_prefix, bbox_extra_artists=(lg1,), bbox_inches="tight", pad_inches=0.4)
    plt.savefig("%s.png" % fname_prefix, bbox_extra_artists=(lg1,), bbox_inches="tight", pad_inches=0.4)
    plt.savefig("%s.svg" % fname_prefix, bbox_extra_artists=(lg1,), bbox_inches="tight", pad_inches=0.4)

plt.show()

# Parameters: By-event overlap
[5CV only, MODA y MASS-KC, by-event all-in] Matches individuales: duración real vs predicha (ajuste lineal y R2), duración real vs IoU. ¿Scatter o hist2D?

In [None]:
def get_durations(events_list, detections_list):
    """Return the durations of every matched (event, detection) pair.

    Events and detections are matched subject by subject with
    ``metrics.matching_with_list``. Only events whose matching index is not
    -1 (i.e. that were matched to a detection) contribute a pair.

    Args:
        events_list: list with one array of expert marks per subject; each row
            is ``[start, end]`` with an inclusive end (hence the ``+ 1``).
        detections_list: list aligned with ``events_list`` holding the
            predicted marks in the same format.

    Returns:
        Tuple ``(durations_real, durations_pred)`` of 1D arrays, aligned
        pairwise, in the same units as the marks. Both arrays are empty when
        no pair was matched in any subject.
    """
    _, idx_matching_list = metrics.matching_with_list(events_list, detections_list)
    durations_real_list = []
    durations_pred_list = []
    for events, detections, idx_matching in zip(events_list, detections_list, idx_matching_list):
        # Nothing can be matched if either side is empty for this subject.
        if events.size == 0 or detections.size == 0:
            continue
        # idx_matching[e] is the matched detection index of event e, or -1.
        valid_event_locs = np.where(idx_matching != -1)[0]
        if valid_event_locs.size == 0:
            continue
        events_m = events[valid_event_locs]
        detections_m = detections[idx_matching[valid_event_locs]]
        durations_real_list.append(events_m[:, 1] - events_m[:, 0] + 1)
        durations_pred_list.append(detections_m[:, 1] - detections_m[:, 0] + 1)
    if not durations_real_list:
        # np.concatenate([]) raises ValueError; return empty arrays so a fold
        # or detector without any match does not abort the whole analysis.
        return np.zeros(0, dtype=np.int64), np.zeros(0, dtype=np.int64)
    return np.concatenate(durations_real_list), np.concatenate(durations_pred_list)

In [None]:
# Duration of matched (expert, predicted) pairs, in seconds, for every
# detector, pooled over all 5CV test folds, for MODA and MASS-KC.
models = [constants.V2_TIME, constants.V2_CWT1D]
baselines_ss = ['dosed', 'a7']
baselines_kc = ['dosed', 'spinky']
print_model_names = {
    constants.V2_TIME: 'REDv2-Time',
    constants.V2_CWT1D: 'REDv2-CWT',
    'dosed': 'DOSED',
    'a7': 'A7',
    'spinky': 'Spinky'
}
print_dataset_names = {
    (constants.MASS_SS_NAME, 1): "MASS-SS2-E1SS",
    (constants.MASS_SS_NAME, 2): "MASS-SS2-E2SS",
    (constants.MASS_KC_NAME, 1): "MASS-SS2-KC",
    (constants.MODA_SS_NAME, 1): "MASS-MODA",
    (constants.INTA_SS_NAME, 1): "INTA-UCH",
}

eval_configs = [
    dict(dataset_name=constants.MODA_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_KC_NAME, expert=1, strategy='5cv', seeds=3),
]

# One dict per config with aligned lists 'Detector', 'Duration_real' and
# 'Duration_pred' (one array of pooled pair durations per detector).
metrics_list = []
for config in eval_configs:
    print("\nLoading", config)
    dataset = reader.load_dataset(config["dataset_name"], verbose=False)
    baselines = baselines_ss if dataset.event_name == constants.SPINDLE else baselines_kc
    
    # Collect predictions as {model: {fold: {subject_id: marks}}}
    pred_dict = {}
    for model_version in models:
        tmp_dict = fig_utils.get_red_predictions(model_version, config["strategy"], dataset, config["expert"], verbose=False)
        # Retrieve only predictions, same format as baselines
        pred_dict[model_version] = {}
        for k in tmp_dict.keys():
            fold_subjects = tmp_dict[k][constants.TEST_SUBSET].all_ids
            fold_predictions = tmp_dict[k][constants.TEST_SUBSET].get_stamps()
            pred_dict[model_version][k] = {s: pred for s, pred in zip(fold_subjects, fold_predictions)}
    for baseline_name in baselines:
        pred_dict[baseline_name] = fig_utils.get_baseline_predictions(baseline_name, config["strategy"], config["dataset_name"], config["expert"])
    # print("Loaded models:", pred_dict.keys())
    
    # Retrieve matchings (it does not matter macro or micro because all events are grouped together)
    _, _, test_ids_list = get_partitions(dataset, config["strategy"], config["seeds"])
    n_folds = len(test_ids_list)
    table = {'Detector': [], 'Duration_real': [], 'Duration_pred': []}
    for model_name in pred_dict.keys():
        durations_real_list = []
        durations_pred_list = []
        for k in range(n_folds):
            subject_ids = test_ids_list[k]
            # Expert marks restricted to N2 pages for this fold's test subjects
            feed_d = FeederDataset(dataset, subject_ids, constants.N2_RECORD, which_expert=config["expert"])
            events_list = feed_d.get_stamps()
            detections_list = [pred_dict[model_name][k][subject_id] for subject_id in subject_ids]
            durations_real, durations_pred = get_durations(events_list, detections_list)
            durations_real_list.append(durations_real)
            durations_pred_list.append(durations_pred)
        # Pool all folds and convert from samples to seconds
        durations_real_list = np.concatenate(durations_real_list).astype(np.float32) / dataset.fs
        durations_pred_list = np.concatenate(durations_pred_list).astype(np.float32) / dataset.fs
        table['Detector'].append(model_name)
        table['Duration_real'].append(durations_real_list)
        table['Duration_pred'].append(durations_pred_list)
    metrics_list.append(table)
print("Metrics computed.")

In [None]:
# Figure: real vs predicted duration of matched pairs, one row per evaluation
# config and one column per detector, with the identity line and a linear fit.
save_figure = False

use_hist = True  # True: 2D histogram of (real, predicted) pairs; False: scatter
hist_temp_res = 0.05  # Bin width (s) of the 2D duration histogram
scatter_alpha = 0.01
baseline_color = viz.GREY_COLORS[8]
model_specs = {
    constants.V2_TIME: dict(marker='o', color=viz.PALETTE['blue']),
    constants.V2_CWT1D: dict(marker='o', color=viz.PALETTE['red']),
    'dosed': dict(marker='s', color=baseline_color),
    'a7': dict(marker='^', color=baseline_color),
    'spinky': dict(marker='v', color=baseline_color),
}

# Grid derived from the data instead of hard-coding 2x4; figsize keeps the
# same per-panel size as the original (8, 4.7) for 2 rows x 4 columns.
n_rows = len(eval_configs)
n_cols = max(len(config_metrics['Detector']) for config_metrics in metrics_list)

fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(2 * n_cols, 2.35 * n_rows), dpi=200,
    sharex=True, sharey=True, squeeze=False)
for i, config in enumerate(eval_configs):
    
    max_dur = 2  # Axis limit (s), shared by all configs
    # Linear-fit line drawn between these x-values (0.3 and 1.7 for max_dur = 2)
    reg_x_min, reg_x_max = 0.15 * max_dur, 0.85 * max_dur
    
    x_bins = np.arange(0, max_dur + 0.001, hist_temp_res)
    y_bins = np.arange(0, max_dur + 0.001, hist_temp_res)
    x_centers = (x_bins[:-1] + x_bins[1:]) / 2
    y_centers = (y_bins[:-1] + y_bins[1:]) / 2
    # Bin-center grid: hist2d below re-bins one weighted point per bin, so the
    # drawn image is the precomputed density histogram.
    xv, yv = np.meshgrid(x_centers, y_centers)
    
    # Duration scatter
    axx = axes[i, :]
    metric_dict = metrics_list[i]
    n_models = len(metric_dict['Detector'])
    for j in range(n_models):
        model_name = metric_dict['Detector'][j]
        ax = axx[j]
        if use_hist:
            hist, _, _ = np.histogram2d(
                metric_dict['Duration_real'][j], metric_dict['Duration_pred'][j], 
                bins=[x_bins, y_bins], density=True)
            # histogram2d returns (x, y)-indexed counts; meshgrid is (y, x), hence the transpose.
            ax.hist2d(
                xv.flatten(), yv.flatten(), bins=[x_bins, y_bins], weights=np.transpose(hist).flatten(), cmap='Blues')
        else:
            ax.plot(
                metric_dict['Duration_real'][j], metric_dict['Duration_pred'][j], 
                linestyle='None', marker='o', markersize=3, alpha=scatter_alpha, markeredgewidth=0.0)
        dataset_str = print_dataset_names[(config["dataset_name"], config["expert"])]
        model_str = print_model_names[model_name]
        ax.set_title('%s, %s' % (model_str, dataset_str), loc="left", fontsize=8)
        ax.tick_params(labelsize=8)
        ax.set_xlim([0, max_dur])
        ax.set_ylim([0, max_dur])
        ax.set_xticks([0, max_dur])
        ax.set_yticks([0, max_dur])
        ax.set_xticks(np.arange(0, max_dur + 0.001, 0.5), minor=True)
        ax.set_yticks(np.arange(0, max_dur + 0.001, 0.5), minor=True)
        ax.set_aspect('equal')
        ax.xaxis.labelpad = -7
        ax.yaxis.labelpad = -7
        ax.grid(which="minor")
        # Identity line (perfect duration agreement)
        ax.plot([0, max_dur], [0, max_dur], color=viz.GREY_COLORS[4], linewidth=0.7, zorder=5)
        fig_utils.linear_regression(metric_dict['Duration_real'][j], metric_dict['Duration_pred'][j], reg_x_min, reg_x_max, ax)
        print(dataset_str, "max",metric_dict['Duration_real'][j].max(), "prct", np.percentile(metric_dict['Duration_real'][j], 98))
        if i == n_rows - 1:
            # x-label only on the bottom row (axes share x)
            ax.set_xlabel("Duración real (s)", fontsize=8)
        # Panel letters run row-major: A-D on the first row, E-H on the second, ...
        ax.text(
            x=-0.01, y=1.15, fontsize=16, s=r"$\bf{%s}$" % chr(ord('A') + i * n_cols + j),
            ha="left", transform=ax.transAxes)
    # Hide panels left empty when this config has fewer detectors than columns
    for j in range(n_models, n_cols):
        axx[j].axis('off')
            
    axes[i, 0].set_ylabel("Duración predicha (s)", fontsize=8)
        
plt.tight_layout()

if save_figure:
    # Save figure
    fname_prefix = "result_comparison_durations"
    plt.savefig("%s.pdf" % fname_prefix, bbox_inches="tight", pad_inches=0.01)
    plt.savefig("%s.png" % fname_prefix, bbox_inches="tight", pad_inches=0.01)
    plt.savefig("%s.svg" % fname_prefix, bbox_inches="tight", pad_inches=0.01)

plt.show()

# Parameters: By-subject parameters
[5CV only, MODA y MASS-KC, by-subject, pintar por fase] Parámetro experto vs modelo, mostrando ajuste lineal y R2: duración promedio, densidad promedio, amplitud PP promedio, y PR (maxSigma/broadNoDelta) promedio (only SS).


```
Skipped subject, moda_ss a7 Events shape (2, 2) Detections shape (0, 2)
Skipped subject, moda_ss a7 Events shape (2, 2) Detections shape (0, 2)
Skipped subject, moda_ss a7 Events shape (2, 2) Detections shape (0, 2)
```

In [None]:
# By-subject parameters (5CV): for every detector and valid test subject,
# compare mean duration, density and mean peak-to-peak amplitude of the expert
# marks ("_real") with those of the detector's predictions ("_pred").
models = [constants.V2_TIME, constants.V2_CWT1D]
baselines_ss = ['dosed', 'a7']
baselines_kc = ['dosed', 'spinky']
print_model_names = {
    constants.V2_TIME: 'REDv2-Time',
    constants.V2_CWT1D: 'REDv2-CWT',
    'dosed': 'DOSED',
    'a7': 'A7',
    'spinky': 'Spinky'
}
print_dataset_names = {
    (constants.MASS_SS_NAME, 1): "MASS-SS2-E1SS",
    (constants.MASS_SS_NAME, 2): "MASS-SS2-E2SS",
    (constants.MASS_KC_NAME, 1): "MASS-SS2-KC",
    (constants.MODA_SS_NAME, 1): "MASS-MODA",
    (constants.INTA_SS_NAME, 1): "INTA-UCH",
}

eval_configs = [
    dict(dataset_name=constants.MODA_SS_NAME, expert=1, strategy='5cv', seeds=3),
    dict(dataset_name=constants.MASS_KC_NAME, expert=1, strategy='5cv', seeds=3),
]

# One dict per config: 'Detector' is a list of names and every other key holds
# one per-subject list per detector, index-aligned across keys.
metrics_list = []
for config in eval_configs:
    print("\nLoading", config)
    dataset = reader.load_dataset(config["dataset_name"], verbose=False)
    baselines = baselines_ss if dataset.event_name == constants.SPINDLE else baselines_kc
    
    # Collect predictions as {model: {fold: {subject_id: marks}}}
    pred_dict = {}
    for model_version in models:
        tmp_dict = fig_utils.get_red_predictions(model_version, config["strategy"], dataset, config["expert"], verbose=False)
        # Retrieve only predictions, same format as baselines
        pred_dict[model_version] = {}
        for k in tmp_dict.keys():
            fold_subjects = tmp_dict[k][constants.TEST_SUBSET].all_ids
            fold_predictions = tmp_dict[k][constants.TEST_SUBSET].get_stamps()
            pred_dict[model_version][k] = {s: pred for s, pred in zip(fold_subjects, fold_predictions)}
    for baseline_name in baselines:
        pred_dict[baseline_name] = fig_utils.get_baseline_predictions(baseline_name, config["strategy"], config["dataset_name"], config["expert"])
    
    # Retrieve by subject parameters (MODA only if 10 blocks)
    if config["dataset_name"] == constants.MODA_SS_NAME:
        valid_subjects = [sub_id for sub_id in dataset.all_ids if dataset.data[sub_id]['n_blocks'] == 10]
        phase_subjects = {sub_id: dataset.data[sub_id]['phase'] for sub_id in valid_subjects}
        stat_spindle = True
    else:
        valid_subjects = dataset.all_ids
        phase_subjects = {sub_id: 1 for sub_id in valid_subjects}
        stat_spindle = False
    _, _, test_ids_list = get_partitions(dataset, config["strategy"], config["seeds"])
    n_folds = len(test_ids_list)
    table = {
        'Detector': [], 
        'Phase': [],
        'Duration_mean_real': [], 
        'Duration_mean_pred': [], 
        'Density_real': [],
        'Density_pred': [],
        'AmplitudePP_mean_real': [],
        'AmplitudePP_mean_pred': [],
    }
    for model_name in pred_dict.keys():
        tmp_table = {
            'Phase': [],
            'Duration_mean_real': [], 
            'Duration_mean_pred': [], 
            'Density_real': [],
            'Density_pred': [],
            'AmplitudePP_mean_real': [],
            'AmplitudePP_mean_pred': [],
        }
        for k in range(n_folds):
            subject_ids = test_ids_list[k]
            feed_d = FeederDataset(dataset, subject_ids, constants.N2_RECORD, which_expert=config["expert"])
            events_list = feed_d.get_stamps()
            detections_list = [pred_dict[model_name][k][subject_id] for subject_id in subject_ids]
            for i_sub, subject_id in enumerate(subject_ids):
                if subject_id not in valid_subjects:
                    continue
                events = events_list[i_sub]
                detections = detections_list[i_sub]
                if events.size * detections.size == 0:
                    print("Skipped subject,", config["dataset_name"], model_name, "Events shape", events.shape, "Detections shape", detections.shape)
                    continue
                # Record the phase only for subjects that are kept, so 'Phase'
                # stays index-aligned with the parameter lists below (it was
                # previously appended before the skip above, misaligning them).
                tmp_table['Phase'].append(phase_subjects[subject_id])
                
                # Duration (s): marks are inclusive [start, end] in samples
                duration_real = np.mean((events[:, 1] - events[:, 0] + 1) / dataset.fs)
                duration_pred = np.mean((detections[:, 1] - detections[:, 0] + 1) / dataset.fs)
                tmp_table['Duration_mean_real'].append(duration_real)
                tmp_table['Duration_mean_pred'].append(duration_pred)
                
                # Density: events per minute of N2 sleep
                n2_pages = dataset.get_subject_pages(subject_id, pages_subset=constants.N2_RECORD)
                n2_minutes = n2_pages.size * dataset.page_duration / 60
                density_real = events.shape[0] / n2_minutes
                density_pred = detections.shape[0] / n2_minutes
                tmp_table['Density_real'].append(density_real)
                tmp_table['Density_pred'].append(density_pred)
                
                # Amplitude: mean peak-to-peak of the filtered signal inside
                # each mark (9-17 Hz band for spindles, 7 Hz lowpass for KCs)
                signal = dataset.get_subject_signal(subject_id, normalize_clip=False)
                if stat_spindle:
                    filt_signal = utils.broad_filter(signal, dataset.fs, lowcut=9, highcut=17)
                else:
                    filt_signal = utils.filter_iir_lowpass(signal, dataset.fs, highcut=7)
                signal_events = [filt_signal[e[0]:e[1]+1] for e in events]
                signal_detections = [filt_signal[e[0]:e[1]+1] for e in detections]
                amplitude_real = np.mean([(s.max()-s.min()) for s in signal_events])
                amplitude_pred = np.mean([(s.max()-s.min()) for s in signal_detections])
                tmp_table['AmplitudePP_mean_real'].append(amplitude_real)
                tmp_table['AmplitudePP_mean_pred'].append(amplitude_pred)
                
        table['Detector'].append(model_name)
        for key in tmp_table.keys():
            table[key].append(tmp_table[key])
    metrics_list.append(table)
print("Metrics computed.")

In [None]:
# By-subject figures: one 3x4 figure per evaluation config. Rows are the
# parameters (mean duration, density, mean peak-to-peak amplitude), columns the
# detectors; each panel plots expert value (x) vs detector value (y) per
# subject, with the identity line and a linear fit.
save_figure = False

scatter_alpha = 0.5
markersize = 3
baseline_color = viz.GREY_COLORS[8]
letters = ['A', 'B', 'C', 'D']  # Panel letters, duration row
letters2 = ['E', 'F', 'G', 'H']  # Panel letters, density row
letters3 = ['I', 'J', 'K', 'L']  # Panel letters, amplitude row
model_specs = {
    constants.V2_TIME: dict(marker='o', color=viz.PALETTE['blue']),
    constants.V2_CWT1D: dict(marker='o', color=viz.PALETTE['red']),
    'dosed': dict(marker='s', color=baseline_color),
    'a7': dict(marker='^', color=baseline_color),
    'spinky': dict(marker='v', color=baseline_color),
}
print_name = {
    'Duration_mean': 'Duración', 'Density': 'Densidad', 'AmplitudePP_mean': 'Amplitud PP'
}
# Raw string: "\m" is an invalid Python escape sequence (warning on recent
# Python versions); the resulting text is unchanged.
units = {
    'Duration_mean': 's', 'Density': 'epm', 'AmplitudePP_mean': r'$\mu$V'
}
# Decimals for np.around of the axis limits (-1 rounds to tens)
decimals = {
    'Duration_mean': 1, 'Density': 0, 'AmplitudePP_mean': -1
}
# Minor-tick (grid) spacing per parameter
resolutions = {
    'Duration_mean': 0.1, 'Density': 1, 'AmplitudePP_mean': 10
}
for i_config, config in enumerate(eval_configs):
    metric_dict = metrics_list[i_config]
    n_models = len(metric_dict['Detector'])
    dataset_str = print_dataset_names[(config["dataset_name"], config["expert"])]
    print("Processing", dataset_str)
    fig, axes = plt.subplots(3, 4, figsize=(8, 7), dpi=200)
    for i, param_name in enumerate(['Duration_mean', 'Density', 'AmplitudePP_mean']):
        # Find range shared by all detectors: data range plus a 10% margin,
        # clipped at 0 and rounded. Infinite sentinels so any data value
        # (e.g. amplitudes above 1000 uV) updates them.
        min_val = np.inf
        max_val = -np.inf
        for j in range(n_models):
            x_data = np.array(metric_dict['%s_real' % param_name][j])
            y_data = np.array(metric_dict['%s_pred' % param_name][j])
            joint_data = np.concatenate([x_data, y_data])
            min_val = min(min_val, joint_data.min())
            max_val = max(max_val, joint_data.max())
        range_width = max_val - min_val
        min_val = max(0, min_val - 0.1 * range_width)
        max_val = max_val + 0.1 * range_width
        min_val = np.around(min_val, decimals=decimals[param_name])
        max_val = np.around(max_val, decimals=decimals[param_name])
        # Finer grid where the default resolution gives too few grid lines
        if config["dataset_name"] == constants.MODA_SS_NAME and param_name == 'AmplitudePP_mean':
            this_resolution = resolutions[param_name] / 2
        elif config["dataset_name"] == constants.MASS_KC_NAME and param_name == 'Density':
            this_resolution = resolutions[param_name] / 2
        else:
            this_resolution = resolutions[param_name]
        minor_ticks = np.arange(min_val, max_val + 0.001, this_resolution)
        major_ticks = [min_val, max_val]
        
        for j in range(n_models):
            ax = axes[i, j]
            model_name = metric_dict['Detector'][j]
            x_data = np.array(metric_dict['%s_real' % param_name][j])
            y_data = np.array(metric_dict['%s_pred' % param_name][j])
            ax.plot(
                x_data, y_data, linestyle="None", marker='o', 
                markersize=markersize, alpha=scatter_alpha, markeredgewidth=0.0, color=viz.PALETTE['blue'])

            model_str = print_model_names[model_name]
            ax.set_title('%s' % model_str, loc="left", fontsize=8)
            ax.tick_params(labelsize=8)
            ax.set_aspect('equal')
            ax.set_xlim([min_val, max_val])
            ax.set_ylim([min_val, max_val])
            # Identity line (perfect agreement)
            ax.plot([min_val, max_val], [min_val, max_val], color=viz.GREY_COLORS[4], linewidth=0.7, zorder=5)
            ax.set_xlabel("%s real (%s)" % (print_name[param_name], units[param_name]), fontsize=8)
            ax.set_xticks(major_ticks)
            ax.set_xticks(minor_ticks, minor=True)
            ax.set_yticks(major_ticks)
            ax.set_yticks(minor_ticks, minor=True)
            ax.grid(which="minor")
            if j == 0:
                ax.set_ylabel("%s pred. (%s)" % (print_name[param_name], units[param_name]), fontsize=8)
            else:
                ax.set_yticklabels([])
            
            # Per-panel label offsets tuned by hand for each dataset/parameter
            ax.xaxis.labelpad = -7
            if config["dataset_name"] == constants.MODA_SS_NAME and param_name == 'Duration_mean':
                ax.yaxis.labelpad = -9.5
            elif config["dataset_name"] == constants.MODA_SS_NAME and param_name == 'AmplitudePP_mean':
                ax.yaxis.labelpad = -9
            elif config["dataset_name"] == constants.MASS_KC_NAME and param_name == 'AmplitudePP_mean':
                ax.yaxis.labelpad = -12
            elif config["dataset_name"] == constants.MASS_KC_NAME and param_name == 'Density':
                ax.yaxis.labelpad = -1
            elif config["dataset_name"] == constants.MASS_KC_NAME and param_name == 'Duration_mean':
                ax.yaxis.labelpad = -8.5
            else:
                ax.yaxis.labelpad = -7
            new_range = max_val - min_val
            fig_utils.linear_regression(
                x_data, y_data, min_val + 0.1*new_range, max_val-0.1*new_range, ax,
                frameon=False, fontsize=8, loc="lower right", bbox_to_anchor=(1.05, -0.05)
            )
            letters_selected = [letters, letters2, letters3][i]
            ax.text(
                x=-0.01, y=1.15, fontsize=16, s=r"$\bf{%s}$" % letters_selected[j],
                ha="left", transform=ax.transAxes)
            
    plt.tight_layout()
    
    if save_figure:
        # Save figure
        fname_prefix = "result_comparison_bysubject_%s" % config["dataset_name"]
        plt.savefig("%s.pdf" % fname_prefix, bbox_inches="tight", pad_inches=0.01)
        plt.savefig("%s.png" % fname_prefix, bbox_inches="tight", pad_inches=0.01)
        plt.savefig("%s.svg" % fname_prefix, bbox_inches="tight", pad_inches=0.01)

    plt.show()

# Subgroup performance analysis
[5CV only, MODA y MASS-KC, by-fold ambos micro] Desempeño en subgrupos: intervalo de duración, intervalo de amplitud; only SS: intervalo de PR (maxSigma/broadNoDelta), frecuencia bajo o sobre 13 Hz (medida por cruces por cero).