# Imports

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os
from pprint import pprint
import sys

import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib import colors, gridspec
import numpy as np
from scipy.stats import gaussian_kde
from scipy.interpolate import interp1d

project_root = '..'
sys.path.append(project_root)

from sleeprnn.common import constants, pkeys, viz
from sleeprnn.common.optimal_thresholds import OPTIMAL_THR_FOR_CKPT_DICT
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection.feeder_dataset import FeederDataset
from sleeprnn.detection.postprocessor import PostProcessor
from sleeprnn.detection import metrics
from sleeprnn.helpers import reader, plotter, printer, misc, performer

# Project-relative locations of stored results and comparison data.
RESULTS_PATH = os.path.join(project_root, 'results')
COMPARISON_PATH = os.path.join(project_root, 'resources', 'comparison_data')

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Project viz helper; presumably widens the notebook layout (inferred from its name — confirm in sleeprnn.common.viz).
viz.notebook_full_width()

In [None]:
def models_signal_snapshot(
    t_central_sample,
    t_signal,
    t_events_1,
    t_events_2,
    t_dets,
    t_probas,
    t_opt_thr,
    t_models,
    ax,
    title='',
    show_stamp_edge=False,
    window_seconds=10,
    fs=200,
    max_voltage=100,
    ratio_signal_models=0.15,
    proba_downsampling=8,
):
    """Plot a signal window with expert stamps and one lane per model below it.

    The top part shows the signal clipped to +-max_voltage, with expert 1
    stamps shaded above zero and expert 2 stamps shaded below zero. Under the
    signal, each model gets a horizontal lane with its probability trace, its
    threshold line and its detections.

    Args:
        t_central_sample: sample index at the centre of the window.
        t_signal: 1D signal array sampled at `fs`.
        t_events_1, t_events_2: expert stamps as rows of [start, end] samples,
            or None to omit that expert (and its legend entry).
        t_dets: one array of detection stamps per model.
        t_probas: one probability array per model, with one value every
            `proba_downsampling` signal samples.
        t_opt_thr: one probability threshold per model.
        t_models: one dict per model with at least 'name' and 'color'.
        ax: matplotlib Axes to draw on.
        title: axes title.
        show_stamp_edge: draw a black edge around shaded stamps.
        window_seconds: window length in seconds.
        fs: sampling frequency of `t_signal` in Hz.
        max_voltage: clipping level of the signal; also scales the lanes.
        ratio_signal_models: lane height relative to the signal range
            (2 * max_voltage).
        proba_downsampling: number of signal samples per probability value.

    Returns:
        (ax, legend). The legend is returned so callers can pass it to
        ``savefig(bbox_extra_artists=...)``.
    """
    width_det = ratio_signal_models * 2 * max_voltage
    gap = 0.1 * width_det
    expert_1_color = viz.PALETTE['blue']
    expert_2_color = viz.PALETTE['purple']
    edge_kwargs = {'edgecolor': 'k'} if show_stamp_edge else {}

    start_sample = int(t_central_sample - fs * window_seconds / 2)
    end_sample = int(start_sample + window_seconds * fs)
    # Keep the window inside the recording. Otherwise a negative start makes the
    # slices below wrap around and plot the wrong (or no) data.
    start_sample = max(start_sample, 0)
    end_sample = min(end_sample, t_signal.size)
    t_start = start_sample / fs
    t_end = end_sample / fs

    filt_exp_1 = [] if t_events_1 is None else utils.filter_stamps(t_events_1, start_sample, end_sample)
    filt_exp_2 = [] if t_events_2 is None else utils.filter_stamps(t_events_2, start_sample, end_sample)
    filt_models = [utils.filter_stamps(this_det, start_sample, end_sample) for this_det in t_dets]

    # Signal. Only the window's time axis is built, not one for the whole recording.
    ax.plot(
        np.arange(start_sample, end_sample) / fs,
        np.clip(t_signal[start_sample:end_sample], -max_voltage, max_voltage),
        linewidth=0.9, color=viz.PALETTE['dark'], zorder=20)

    # Expert stamps. The zero-width dummy patch (in seconds, like the x axis)
    # only creates the legend entry; it is skipped for a missing expert.
    expert_lanes = [
        (t_events_1, filt_exp_1, expert_1_color, 0, 0.5 * max_voltage, 'E1'),
        (t_events_2, filt_exp_2, expert_2_color, -0.5 * max_voltage, 0, 'E2'),
    ]
    for lane_events, lane_stamps, lane_color, low, high, label in expert_lanes:
        if lane_events is None:
            continue
        ax.fill_between(
            [t_start, t_start], 0, 0, facecolor=lane_color, alpha=0.5, zorder=10, label=label)
        for s_stamp in lane_stamps:
            ax.fill_between(
                s_stamp / fs, low, high, facecolor=lane_color, alpha=0.5, zorder=10, **edge_kwargs)

    # Models: one lane per model, stacked downwards below the signal.
    p_start = start_sample // proba_downsampling
    p_end = end_sample // proba_downsampling
    bottom = -max_voltage
    for j_m, model in enumerate(t_models):
        bottom = bottom - width_det - 2 * gap
        top = bottom + 2 * gap + width_det
        thr_level = bottom + gap + t_opt_thr[j_m] * width_det
        proba_window = np.asarray(t_probas[j_m])[p_start:p_end]
        # Probability k is drawn at sample k * proba_downsampling, whatever the
        # window start. NOTE(review): assumes the probability grid starts at
        # sample 0 — confirm against the predictions object.
        proba_x = (p_start + np.arange(proba_window.size)) * proba_downsampling / fs
        ax.plot(
            proba_x, bottom + gap + proba_window * width_det,
            linewidth=1.5, color=model['color'], zorder=40, label=model['name'])
        ax.plot(
            [t_start, t_end], [thr_level, thr_level],
            linewidth=1.1, color=viz.GREY_COLORS[8], zorder=30)
        for lane_edge in (bottom, top):
            ax.plot(
                [t_start, t_end], [lane_edge, lane_edge],
                linewidth=1.1, color=viz.GREY_COLORS[8], zorder=5)
        for s_stamp in filt_models[j_m]:
            ax.fill_between(
                s_stamp / fs, bottom + gap, top - gap,
                facecolor=model['color'], alpha=0.5, zorder=10, **edge_kwargs)

    ax.set_title(title, fontsize=10)
    ax.set_ylim([bottom, max_voltage])
    ax.set_xlim([t_start, t_end])
    # Ticks at the height of the expert shading (+-0.5 * max_voltage).
    ax.set_yticks([-0.5 * max_voltage, 0, 0.5 * max_voltage])
    ax.set_xticks(t_start + np.arange(window_seconds), minor=True)
    ax.set_xticks([t_central_sample / fs])
    ax.set_xticklabels(['%1.1f [s]' % (t_central_sample / fs)])
    lg = ax.legend(
        loc='lower left', labelspacing=1.2,
        fontsize=9, frameon=False, bbox_to_anchor=(1, 0), ncol=1)
    return ax, lg

# Load data

In [None]:
# List the checkpoints registered in OPTIMAL_THR_FOR_CKPT_DICT.
# NOTE(review): filter_dates looks like [from, to] with None meaning "no bound" — confirm in printer.
filter_dates = [20191220, None]
printer.print_available_ckpt(
    OPTIMAL_THR_FOR_CKPT_DICT,
    filter_dates,
)

In [None]:
dataset_name = constants.MASS_SS_NAME
fs = 200
which_expert = 1
task_mode = constants.N2_RECORD
seed_id_list = list(range(2))
set_list = [constants.VAL_SUBSET, constants.TRAIN_SUBSET]

# Runs to compare, as (checkpoint folder, display label) pairs.
comparison_runs_list = [
    ('20191227_bsf_10runs_e1_n2_train_mass_ss/v11', 'xEnt-0.5'),  # <- yes
    ('20200502_grid_losses_2020_n2_train_mass_ss/v11_cross_entropy_loss_bNone_eNone_gNone_pi0.01', 'xEnt-0.01'),  # <- yes
    ('20200502_grid_losses_2020_n2_train_mass_ss/v11_cross_entropy_smoothing_loss_bNone_e0.1_gNone_pi0.1', 'xEntSmooth-0.1'),  # <- yes
    ('20200502_grid_losses_2020_n2_train_mass_ss/v11_cross_entropy_smoothing_clip_loss_bNone_e0.1_gNone_pi0.1', 'xEntSmoothClip-0.1'),
]
# Keep only runs whose folder name mentions the selected dataset.
comparison_runs_list = [run for run in comparison_runs_list if dataset_name in run[0]]
# Lookups between folders and labels, in both directions.
ckpt_folder_list = [folder for folder, _ in comparison_runs_list]
ckpt_folder_dict = {label: folder for folder, label in comparison_runs_list}
ckpt_label_dict = dict(comparison_runs_list)

# Load the dataset, its subject splits and the predictions of every selected run.
n_cases = len(comparison_runs_list)
dataset = reader.load_dataset(dataset_name, params={pkeys.FS: fs})
ids_dict = {
    constants.ALL_TRAIN_SUBSET: dataset.train_ids,
    constants.TEST_SUBSET: dataset.test_ids,
}
ids_dict.update(misc.get_splits_dict(dataset, seed_id_list))
predictions_dict = {
    folder: reader.read_prediction_with_seeds(
        folder, dataset_name, task_mode, seed_id_list,
        set_list=set_list, parent_dataset=dataset, verbose=False)
    for folder in ckpt_folder_list
}
# IoU axes kept for later visualization.
iou_hist_bins = np.linspace(0, 1, 21)
iou_curve_axis = misc.custom_linspace(0.05, 0.95, 0.05)
# Identifier such as "<DATASET>-<PART>-E<expert>-<TASK>".
result_id = '-'.join([
    dataset_name.split('_')[0].upper(),
    dataset_name.split('_')[1].upper(),
    'E%d' % which_expert,
    task_mode.upper(),
])

## Utilities: models to compare

In [None]:
# Models shown in the comparison. Each entry of `model_names` must be a label
# from `comparison_runs_list`. `code_names` are short ids used in output
# folder/file names and in the `matching_data_all` keys.
model_names = ['xEnt-0.5', 'xEnt-0.01', 'xEntSmooth-0.1', 'xEntSmoothClip-0.1']
code_names = ['xent', 'xentB', 'xentBS', 'xentBSC']
model_colors = [viz.PALETTE['red'], viz.PALETTE['green'], viz.PALETTE['grey'], viz.PALETTE['cyan']]
# Alternative configurations kept for reference:
# model_names = ['RED-Time', 'ATT4', 'ATT1']
# code_names = ['v11', 'att4', 'att1']
# model_names = ['RED-CWT', 'RED-Time', 'Time+CWT']
# code_names = ['v19', 'v11', 'v35']

# zip() would silently drop models if these lists fell out of sync. Extra
# colors are allowed, so a shorter model list can reuse the same palette.
if len(code_names) != len(model_names):
    raise ValueError('code_names (%d) and model_names (%d) must have the same length'
                     % (len(code_names), len(model_names)))
if len(model_colors) < len(model_names):
    raise ValueError('Need at least %d model_colors, got %d' % (len(model_names), len(model_colors)))
missing_model_names = [name for name in model_names if name not in ckpt_folder_dict]
if missing_model_names:
    raise KeyError('Models %s are not in comparison_runs_list for dataset %s'
                   % (missing_model_names, dataset_name))

models = [
    {'name': name, 'ckpt': ckpt_folder_dict[name], 'code_name': code_name, 'color': color}
    for name, code_name, color in zip(model_names, code_names, model_colors)
]

# Visualization

In [None]:
chosen_seed = 1
set_name = 'val'

# Subjects of the chosen split and their expert annotations. Expert 2 is only
# loaded for MASS-SS; other datasets get one None placeholder per subject.
val_ids = ids_dict[chosen_seed][set_name]
events_1 = FeederDataset(dataset, val_ids, task_mode, which_expert=1).get_stamps()
events_2 = (
    FeederDataset(dataset, val_ids, task_mode, which_expert=2).get_stamps()
    if dataset_name == constants.MASS_SS_NAME
    else [None] * len(val_ids)
)
# Reference annotations (expert `which_expert`) and the signals, loaded with normalize_clip=False.
subset_data = FeederDataset(dataset, val_ids, task_mode, which_expert=which_expert)
events = subset_data.get_stamps()
signals = subset_data.get_signals(normalize_clip=False)

# For each model, apply its optimal threshold for the chosen seed, then collect
# the resulting detections and probabilities for the chosen subset.
cmp_opt_thr, cmp_preds, cmp_dets, cmp_probas = [], [], [], []
for model in models:
    ckpt = model['ckpt']
    opt_thr = OPTIMAL_THR_FOR_CKPT_DICT[ckpt][chosen_seed]
    preds = predictions_dict[ckpt][chosen_seed][set_name]
    preds.set_probability_threshold(opt_thr)
    cmp_opt_thr.append(opt_thr)
    cmp_preds.append(preds)
    cmp_dets.append(preds.get_stamps())
    cmp_probas.append(preds.get_probabilities())

# Per subject, match each model against the expert and every pair of models
# against each other. metrics.matching takes (reference, detections), so keys
# read "<detections>_vs_<reference>". Lists are (re)created on the first subject.
matching_data_all = {}
for i, single_id in enumerate(val_ids):
    first_subject = i == 0
    for j, s_dets in enumerate(cmp_dets):
        s_key = '%s_vs_exp' % models[j]['code_name']
        if first_subject:
            matching_data_all[s_key] = []
        s_iou_matching, s_idx_matching = metrics.matching(events[i], s_dets[i])
        matching_data_all[s_key].append({'iou': s_iou_matching, 'idx': s_idx_matching})
    # Unordered model pairs: j_1 is the reference, j_2 the detections.
    for j_1 in range(len(cmp_dets)):
        for j_2 in range(j_1 + 1, len(cmp_dets)):
            s_key = '%s_vs_%s' % (models[j_2]['code_name'], models[j_1]['code_name'])
            if first_subject:
                matching_data_all[s_key] = []
            s_iou_matching, s_idx_matching = metrics.matching(cmp_dets[j_1][i], cmp_dets[j_2][i])
            matching_data_all[s_key].append({'iou': s_iou_matching, 'idx': s_idx_matching})
print("Done")

# Discrepancies between models

In [None]:
mod1_idx = 0
mod2_idx = 1
dpi = 200
figsize = (7, 4)
save_figs = False
save_txt = False
# -------------------------
# For each validation subject, find the detections made by only one of the two
# models. Split them using the expert annotations:
# - an exclusive detection that matches an expert event is a TP for its model
#   and an FN for the other;
# - one that matches no expert event is an FP for its model and a TN for the other.


def _split_exclusive_detections(dets_a, dets_b, ref_events):
    """Split the detections of model A that model B does not match.

    Returns two index arrays into `dets_a`. The first holds the exclusive
    detections that match an event of `ref_events`. The second holds the
    exclusive detections that match none, in ascending order.
    """
    # metrics.matching(reference, detections) gives, per reference, the index
    # of its matched detection, or -1 when it is unmatched.
    _, a_to_b = metrics.matching(dets_a, dets_b)
    exclusive_idx = np.where(a_to_b == -1)[0]
    _, event_to_exclusive = metrics.matching(ref_events, dets_a[exclusive_idx])
    hit_subset = event_to_exclusive[event_to_exclusive != -1]
    miss_mask = np.ones(exclusive_idx.size, dtype=bool)
    miss_mask[hit_subset] = False
    return exclusive_idx[hit_subset], exclusive_idx[miss_mask]


name_1 = models[mod1_idx]['name']
name_2 = models[mod2_idx]['name']
prefix_str = 'diffs_%s_%s_seed%d' % (models[mod1_idx]['code_name'], models[mod2_idx]['code_name'], chosen_seed)
if save_figs or save_txt:
    os.makedirs(prefix_str, exist_ok=True)
txt_file = open(os.path.join(prefix_str, '%s.txt' % prefix_str), 'w') if save_txt else None
# Every report line goes to stdout and, when enabled, to the text report.
streams = [sys.stdout] if txt_file is None else [sys.stdout, txt_file]
fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
try:
    for stream in streams:
        print('Database: %s, Expert: %d' % (dataset_name, which_expert), file=stream)
        print('Differences for Seed %d' % chosen_seed, file=stream)
    for i, single_id in enumerate(val_ids):
        s_dets_1 = cmp_dets[mod1_idx][i]
        s_dets_2 = cmp_dets[mod2_idx][i]
        s_events = events[i]
        s_signal = signals[i]
        m1_tp_m2_fn_idx, m1_fp_m2_tn_idx = _split_exclusive_detections(s_dets_1, s_dets_2, s_events)
        m1_fn_m2_tp_idx, m1_tn_m2_fp_idx = _split_exclusive_detections(s_dets_2, s_dets_1, s_events)
        total_diffs = (m1_tp_m2_fn_idx.size + m1_fp_m2_tn_idx.size
                       + m1_fn_m2_tp_idx.size + m1_tn_m2_fp_idx.size)
        for stream in streams:
            print("", file=stream)
            print('Subject %02d (%d annotations)' % (single_id, s_events.shape[0]), file=stream)
            print("Discrepancy; TP; FP", file=stream)
            print('%s yes | %s no; %d; %d' % (
                name_1, name_2, m1_tp_m2_fn_idx.size, m1_fp_m2_tn_idx.size), file=stream)
            print('%s no | %s yes; %d; %d' % (
                name_1, name_2, m1_fn_m2_tp_idx.size, m1_tn_m2_fp_idx.size), file=stream)
            print("Total discrepancies: %d" % total_diffs, file=stream)
        if not save_figs:
            continue
        # (detections giving the window centre, their indices, title tag, file tag)
        snapshot_groups = [
            (s_dets_1, m1_tp_m2_fn_idx, '%s TP - %s FN', 'tp_fn'),
            (s_dets_1, m1_fp_m2_tn_idx, '%s FP - %s TN', 'fp_tn'),
            (s_dets_2, m1_fn_m2_tp_idx, '%s FN - %s TP', 'fn_tp'),
            (s_dets_2, m1_tn_m2_fp_idx, '%s TN - %s FP', 'tn_fp'),
        ]
        for group_dets, group_idx, title_tag, file_tag in snapshot_groups:
            for this_idx in group_idx:
                title = (title_tag % (name_1, name_2)) + ' (S%02d-%s-Seed%d)' % (
                    single_id, set_name.upper(), chosen_seed)
                fname = '%s_%s_s%02d_idx%d.png' % (prefix_str, file_tag, single_id, this_idx)
                ax.clear()
                ax, lg = models_signal_snapshot(
                    group_dets[this_idx].mean(),
                    s_signal,
                    events_1[i],
                    events_2[i],
                    [md[i] for md in cmp_dets],
                    [mp[i] for mp in cmp_probas],
                    cmp_opt_thr,
                    models,
                    ax,
                    title=title)
                fig.savefig(os.path.join(prefix_str, fname), dpi=dpi,
                            bbox_extra_artists=(lg,), bbox_inches="tight", pad_inches=0.01)
finally:
    # Always release the figure and the report file, even if a subject fails.
    plt.close('all')
    if txt_file is not None:
        txt_file.close()
print("Done")

# Errors of one model

In [None]:
mod1_idx = 0

max_voltage = 100
iou_low_thr = 0.5
delta_iou = 0.01
n_samples = 4
dpi = 200
figsize = (7, 4)
save_figs = True
save_txt = True
random_seed = 0  # set to None for a different random draw on every run
# -------------------------
# For each validation subject, sample up to `n_samples` of each error type of
# one model:
# - false negatives: unmatched expert events;
# - low-IoU matches;
# - false positives: unmatched detections.
# Counts are reported and a snapshot is optionally saved for each sampled error.
if delta_iou <= 0:
    raise ValueError('delta_iou must be positive, got %s' % delta_iou)
rng = np.random.RandomState(random_seed)


def _sample_centers(centers, n, random_state):
    """Draw up to `n` distinct centres without replacement (all of them if fewer exist)."""
    return random_state.choice(centers, size=min(n, centers.size), replace=False)


name_1 = models[mod1_idx]['name']
prefix_str = 'errors_%s_seed%d' % (models[mod1_idx]['code_name'], chosen_seed)
if save_figs or save_txt:
    os.makedirs(prefix_str, exist_ok=True)
txt_file = open(os.path.join(prefix_str, '%s.txt' % prefix_str), 'w') if save_txt else None
# Every report line goes to stdout and, when enabled, to the text report.
streams = [sys.stdout] if txt_file is None else [sys.stdout, txt_file]
fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
try:
    for stream in streams:
        print('Database: %s, Expert: %d' % (dataset_name, which_expert), file=stream)
        print('Errors for %s and Seed %d' % (name_1, chosen_seed), file=stream)
    for i, single_id in enumerate(val_ids):
        s_dets_1 = cmp_dets[mod1_idx][i]
        s_events = events[i]
        s_signal = signals[i]
        iou_matching, idx_matching = metrics.matching(s_events, s_dets_1)
        is_matched = idx_matching != -1
        # False negatives: expert events without a matching detection.
        fn_central_samples = s_events[~is_matched].mean(axis=1).astype(np.int32)
        # Low-IoU matches: relax the threshold until enough matches qualify.
        # IoU never exceeds 1, so once the threshold passes 1 every match is
        # already counted and the loop must stop (the old loop spun forever here).
        this_iou_low_thr = iou_low_thr
        low_iou_idx = np.where(is_matched & (iou_matching < this_iou_low_thr))[0]
        while low_iou_idx.size < n_samples and this_iou_low_thr <= 1:
            this_iou_low_thr = this_iou_low_thr + delta_iou
            low_iou_idx = np.where(is_matched & (iou_matching < this_iou_low_thr))[0]
        low_iou_central_samples = s_events[low_iou_idx].mean(axis=1).astype(np.int32)
        # False positives: detections no expert event was matched to.
        fp_idx = np.setdiff1d(np.arange(s_dets_1.shape[0]), idx_matching)
        fp_central_samples = s_dets_1[fp_idx].mean(axis=1).astype(np.int32)
        for stream in streams:
            print("", file=stream)
            print('Subject %02d (%d annotations)' % (single_id, s_events.shape[0]), file=stream)
            print('FN: %d' % fn_central_samples.size, file=stream)
            print('IoU < %1.2f: %d' % (this_iou_low_thr, low_iou_central_samples.size), file=stream)
            print('FP %d' % fp_central_samples.size, file=stream)
        fn_chosen_centers = _sample_centers(fn_central_samples, n_samples, rng)
        low_iou_chosen_centers = _sample_centers(low_iou_central_samples, n_samples, rng)
        fp_chosen_centers = _sample_centers(fp_central_samples, n_samples, rng)
        if not save_figs:
            continue
        # Iterating over labelled groups (instead of `in` membership tests)
        # means coinciding centres no longer need to be rejected.
        snapshot_groups = [
            (fn_chosen_centers, '%s FN' % name_1, 'fn'),
            (low_iou_chosen_centers, '%s IoU < %1.2f' % (name_1, this_iou_low_thr), 'low_iou'),
            (fp_chosen_centers, '%s FP' % name_1, 'fp'),
        ]
        for group_centers, title_tag, file_tag in snapshot_groups:
            for this_center in group_centers:
                title = '%s (S%02d-%s-Seed%d)' % (title_tag, single_id, set_name.upper(), chosen_seed)
                fname = '%s_%s_s%02d_idx%d.png' % (prefix_str, file_tag, single_id, this_center)
                ax.clear()
                ax, lg = models_signal_snapshot(
                    this_center,
                    s_signal,
                    events_1[i],
                    events_2[i],
                    [md[i] for md in cmp_dets],
                    [mp[i] for mp in cmp_probas],
                    cmp_opt_thr,
                    models,
                    ax,
                    title=title, max_voltage=max_voltage)
                fig.savefig(os.path.join(prefix_str, fname), dpi=dpi,
                            bbox_extra_artists=(lg,), bbox_inches="tight", pad_inches=0.01)
finally:
    # Always release the figure and the report file, even if a subject fails.
    plt.close('all')
    if txt_file is not None:
        txt_file.close()
print("Done")

# Random matchings

In [None]:
mod1_idx = 0

max_voltage = 100
n_samples = 12
dpi = 200
figsize = (7, 4)
save_figs = True
save_txt = True
random_seed = 0  # set to None for a different random draw on every run
# -------------------------
# For each validation subject, sample up to `n_samples` expert events that the
# model matched. Their count is reported and a snapshot is optionally saved.
rng = np.random.RandomState(random_seed)
name_1 = models[mod1_idx]['name']
prefix_str = 'matchings_%s_seed%d' % (models[mod1_idx]['code_name'], chosen_seed)
if save_figs or save_txt:
    os.makedirs(prefix_str, exist_ok=True)
txt_file = open(os.path.join(prefix_str, '%s.txt' % prefix_str), 'w') if save_txt else None
# Every report line goes to stdout and, when enabled, to the text report.
streams = [sys.stdout] if txt_file is None else [sys.stdout, txt_file]
fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
try:
    for stream in streams:
        print('Database: %s, Expert: %d' % (dataset_name, which_expert), file=stream)
        print('Matchings for %s and Seed %d' % (name_1, chosen_seed), file=stream)
    for i, single_id in enumerate(val_ids):
        s_dets_1 = cmp_dets[mod1_idx][i]
        s_events = events[i]
        s_signal = signals[i]
        iou_matching, idx_matching = metrics.matching(s_events, s_dets_1)
        valid_central_samples = s_events[idx_matching != -1].mean(axis=1).astype(np.int32)
        for stream in streams:
            print("", file=stream)
            print('Subject %02d (%d annotations)' % (single_id, s_events.shape[0]), file=stream)
            print('Matchings: %d' % valid_central_samples.size, file=stream)
        # Sample without replacement; take every match when fewer than n_samples exist.
        valid_chosen_centers = rng.choice(
            valid_central_samples, size=min(n_samples, valid_central_samples.size), replace=False)
        if not save_figs:
            continue
        title = '%s Match (S%02d-%s-Seed%d)' % (name_1, single_id, set_name.upper(), chosen_seed)
        for this_center in valid_chosen_centers:
            fname = '%s_match_s%02d_idx%d.png' % (prefix_str, single_id, this_center)
            ax.clear()
            ax, lg = models_signal_snapshot(
                this_center,
                s_signal,
                events_1[i],
                events_2[i],
                [md[i] for md in cmp_dets],
                [mp[i] for mp in cmp_probas],
                cmp_opt_thr,
                models,
                ax,
                title=title, max_voltage=max_voltage)
            fig.savefig(os.path.join(prefix_str, fname), dpi=dpi,
                        bbox_extra_artists=(lg,), bbox_inches="tight", pad_inches=0.01)
finally:
    # Always release the figure and the report file, even if a subject fails.
    plt.close('all')
    if txt_file is not None:
        txt_file.close()
print("Done")