In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from pprint import pprint

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

project_root = os.path.abspath('..')
sys.path.append(project_root)

from sleeprnn.helpers.reader import load_dataset
from sleeprnn.data import utils, stamp_correction
from sleeprnn.detection import metrics
from sleeprnn.common import constants, pkeys, viz

# Root folder of the A7 grid-search outputs. The prediction-loading cell
# appends an 'e<which_expert>' sub-folder to this path before listing files.
# NOTE(review): absolute, machine-specific path -- adjust it before running the
# notebook on another machine.
GRID_PATH = os.path.join(
    '/home/ntapia/projects/repos/sleep-rnn/resources/datasets/',
    'output_mass_ss_second_grid_e1/output_mass_ss_second')


# A7 original method

## Load data

In [None]:
# --- Experiment configuration ---
dataset_name = constants.MASS_SS_NAME
which_expert = 1
task_mode = constants.N2_RECORD
id_try_list = np.arange(10)

# `fs` is the sampling frequency requested from the dataset loader, while
# `fs_predictions` is the one the stored predictions are expressed in (the
# prediction-loading cell rescales sample indices by fs / fs_predictions).
fs = 200
fs_predictions = 128
dataset_params = {pkeys.FS: fs}

# --- Load expert annotations ---
dataset = load_dataset(dataset_name, params=dataset_params)
all_train_ids, page_size = dataset.train_ids, dataset.page_size
print('Page size:', page_size)
print('All train ids', all_train_ids)

# Pages selected by `task_mode` for every training subject, keyed by subject id.
n2_dict = dict()
for subject_id in all_train_ids:
    n2_dict[subject_id] = dataset.get_subject_pages(
        subject_id, pages_subset=task_mode)

## Load grid search predictions (E1)

In [None]:
def _parse_setting(setting):
    """Parse a setting string into a ``{param_name: float_value}`` dict.

    A setting is a ``_``-separated sequence of ``name(value)`` tokens, e.g.
    ``"absSigPow(1.75)_relSigPow(1.6)_sigCov(1.8)_sigCorr(0.75)"``.
    """
    parsed = {}
    for token in setting.split("_"):
        pieces = token.split("(")
        # pieces[1] looks like "1.75)": drop the closing parenthesis.
        parsed[pieces[0]] = float(pieces[1][:-1])
    return parsed


def _is_allowed(value, valid_values):
    """Return True if `value` is numerically close to any entry of `valid_values`."""
    return any(np.isclose(value, candidate) for candidate in valid_values)


def filter_settings(full_settings, valid_absSigPow, valid_relSigPow, valid_sigCov, valid_sigCorr):
    """Keep the settings whose parameters belong to the requested values.

    Args:
        full_settings: iterable of setting strings (see `_parse_setting`).
        valid_absSigPow: collection of accepted 'absSigPow' values, or None
            to accept any value.
        valid_relSigPow: same, for 'relSigPow'.
        valid_sigCov: same, for 'sigCov'.
        valid_sigCorr: same, for 'sigCorr'.

    Returns:
        List with the settings (original strings, original order) that satisfy
        every non-None constraint.

    Values are compared with `np.isclose` (default tolerances) instead of exact
    float equality, so a requested value produced by arithmetic (e.g. `0.1 * 3`)
    still matches the setting written as "0.3".
    """
    constraints = (
        ('absSigPow', valid_absSigPow),
        ('relSigPow', valid_relSigPow),
        ('sigCov', valid_sigCov),
        ('sigCorr', valid_sigCorr),
    )
    # A None constraint means "do not filter on this parameter".
    active_constraints = [
        (name, valid_values)
        for name, valid_values in constraints
        if valid_values is not None
    ]
    filtered_settings = []
    for setting in full_settings:
        params = _parse_setting(setting)
        checks = [
            _is_allowed(params[name], valid_values)
            for name, valid_values in active_constraints
        ]
        if all(checks):
            filtered_settings.append(setting)
    return filtered_settings


def print_available_settings(full_settings):
    """Print, for each grid parameter, the sorted unique values found in `full_settings`.

    Args:
        full_settings: iterable of setting strings (see `_parse_setting`).
    """
    param_names = ('absSigPow', 'relSigPow', 'sigCov', 'sigCorr')
    parsed_settings = [_parse_setting(setting) for setting in full_settings]
    # Gather everything before printing so a malformed setting fails early,
    # before any partial output is produced.
    unique_values = [
        np.unique([parsed[name] for parsed in parsed_settings])
        for name in param_names
    ]
    for name, values in zip(param_names, unique_values):
        print("%s: %s" % (name, values))

In [None]:
# Load predictions
pred_folder = os.path.join(GRID_PATH, 'e%d' % which_expert)
print('Loading predictions from %s' % pred_folder, flush=True)
pred_files = os.listdir(pred_folder)

# pred_dict[setting][subject_id] -> post-processed marks restricted to the
# pages stored in n2_dict. visited_settings keeps the settings in the order
# in which they are first encountered.
pred_dict = {}
visited_settings = []
for file in pred_files:
    # The file name is '_'-separated: field 3 carries the subject id after a
    # one-character prefix, and the remaining fields (minus the 4-character
    # extension) form the setting string.
    name_fields = file.split('_')
    subject_id = int(name_fields[3][1:])
    setting = '_'.join(name_fields[4:])[:-4]
    if setting not in visited_settings:
        visited_settings.append(setting)
        pred_dict[setting] = {}
    # Sample marks
    filepath = os.path.join(pred_folder, file)
    pred_data = pd.read_csv(filepath, sep='\t')
    # We subtract 1 to translate from matlab to numpy indexing system
    start_samples = pred_data.start_sample.values - 1
    end_samples = pred_data.end_sample.values - 1
    pred_marks = np.stack([start_samples, end_samples], axis=1)
    # Transform to correct sampling frequency
    pred_marks = pred_marks * fs / fs_predictions
    pred_marks = pred_marks.astype(np.int32)
    # Valid subset of marks
    pred_marks_n2 = utils.extract_pages_for_stamps(
        pred_marks, n2_dict[subject_id], page_size)
    # Postprocessing: merge close marks, then keep only valid durations
    pred_marks_n2 = stamp_correction.combine_close_stamps(
        pred_marks_n2, fs, min_separation=0.3)
    pred_marks_n2 = stamp_correction.filter_duration_stamps(
        pred_marks_n2, fs, min_duration=0.3, max_duration=3.0)
    # Save marks for evaluation
    pred_dict[setting][subject_id] = pred_marks_n2
print("Done.")

In [None]:
# Size of the loaded grid, followed by the distinct values of each parameter.
n_settings = len(visited_settings)
print("Total settings: %d" % n_settings)
print_available_settings(visited_settings)

## Evaluate performance 

In [None]:
# IoU threshold at which precision and recall are evaluated.
iou_thr = 0.2

# perf_dict[setting][subject_id] -> {'precision': ..., 'recall': ...}
perf_dict = dict((name, {}) for name in visited_settings)
for subject_id in all_train_ids:
    # Expert marks are fetched once per subject and reused for every setting.
    expert_marks = dataset.get_subject_stamps(
        subject_id, which_expert=which_expert, pages_subset=task_mode)
    for setting in visited_settings:
        pred_marks_n2 = pred_dict[setting][subject_id]
        # Compare: metric_vs_iou takes a list of thresholds, so the single
        # requested value is read back from position 0.
        this_precision = metrics.metric_vs_iou(
            expert_marks, pred_marks_n2, [iou_thr],
            metric_name=constants.PRECISION)[0]
        this_recall = metrics.metric_vs_iou(
            expert_marks, pred_marks_n2, [iou_thr],
            metric_name=constants.RECALL)[0]
        tmp_results_dict = dict(precision=this_precision, recall=this_recall)
        perf_dict[setting][subject_id] = tmp_results_dict
print("Done.")

## Visualize PR of val subjects

In [None]:
# Show all settings
subject_to_show = 19

# Grid points to highlight; every other visited setting is drawn in grey.
settings_to_show = filter_settings(
    visited_settings,
    valid_absSigPow=[1.75],
    valid_relSigPow=[1.6],
    valid_sigCov=[1.8],
    valid_sigCorr=[0.75]
)
print("Plotting %d settings" % len(settings_to_show))
print_available_settings(settings_to_show)

# Precision/recall of the chosen subject for the highlighted settings...
subject_recall_list = [
    perf_dict[setting][subject_to_show]['recall']
    for setting in settings_to_show]
subject_precision_list = [
    perf_dict[setting][subject_to_show]['precision']
    for setting in settings_to_show]
# ...and for the remaining ones.
rest_recall_list = []
rest_precision_list = []
for setting in visited_settings:
    if setting in settings_to_show:
        continue
    rest_recall_list.append(perf_dict[setting][subject_to_show]['recall'])
    rest_precision_list.append(perf_dict[setting][subject_to_show]['precision'])

fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=100)
point_style = dict(marker='o', markersize=5, linestyle='None')
if len(rest_recall_list) > 0:
    ax.plot(
        rest_recall_list, rest_precision_list,
        color=viz.GREY_COLORS[4], alpha=0.4, zorder=2,
        label="Not selected", **point_style)
ax.plot(
    subject_recall_list, subject_precision_list,
    color=viz.PALETTE['blue'], alpha=0.6, zorder=10,
    label="Selected", **point_style)
# Diagonal where precision equals recall, drawn behind the points.
ax.plot([0, 1], [0, 1], zorder=1, linewidth=1, color=viz.GREY_COLORS[4])

# Axes cosmetics: unit square with ticks every 0.1.
ax.set_title('A7 - Subject %02d' % subject_to_show, fontsize=10)
ax.set_xlabel('Recall (IoU>%1.1f)' % iou_thr, fontsize=10)
ax.set_ylabel('Precision (IoU>%1.1f)' % iou_thr, fontsize=10)
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_xticks(np.arange(11) / 10)
ax.set_yticks(np.arange(11) / 10)
ax.tick_params(labelsize=8)
ax.grid()
ax.set_aspect('equal')
ax.legend(loc="lower left", fontsize=10)

plt.tight_layout()
plt.show()

## PR of val subjects for a single setting

In [None]:
# setting to show
valid_absSigPow = 1.75
valid_relSigPow = 1.6
valid_sigCov = 1.8
valid_sigCorr = 0.75

# Filter settings: exactly one grid point is expected to match; the first
# match is used.
matching_settings = filter_settings(
    visited_settings,
    valid_absSigPow=[valid_absSigPow],
    valid_relSigPow=[valid_relSigPow],
    valid_sigCov=[valid_sigCov],
    valid_sigCorr=[valid_sigCorr]
)
if not matching_settings:
    # Same exception type as the bare `[0]` lookup, but with the requested
    # grid point spelled out instead of "list index out of range".
    raise IndexError(
        "No loaded setting matches absSigPow=%s, relSigPow=%s, sigCov=%s, sigCorr=%s" % (
            valid_absSigPow, valid_relSigPow, valid_sigCov, valid_sigCorr))
setting_to_show = matching_settings[0]

fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=100)

# One labelled point per subject at (recall, precision).
all_recall = []
all_precision = []
for subject_to_show in all_train_ids:
    subject_recall = perf_dict[setting_to_show][subject_to_show]['recall']
    subject_precision = perf_dict[setting_to_show][subject_to_show]['precision']
    all_recall.append(subject_recall)
    all_precision.append(subject_precision)
    ax.plot(
        subject_recall,
        subject_precision,
        color=viz.PALETTE['blue'],
        marker='o', markersize=6, linestyle='None', zorder=10)
    # The annotation text is documented as a string, so convert the id explicitly.
    ax.annotate(
        str(subject_to_show), (subject_recall, subject_precision),
        horizontalalignment="center", verticalalignment="center", fontsize=4, color="w", zorder=20)

# Across-subject summary, reported in percent as mean +/- std in the title.
mean_recall = np.mean(all_recall)
mean_precision = np.mean(all_precision)
std_recall = np.std(all_recall)
std_precision = np.std(all_precision)
perf_string = "P: %1.1f\u00B1%1.1f, R: %1.1f\u00B1%1.1f" % (
    100 * mean_precision, 100 * std_precision,
    100 * mean_recall, 100 * std_recall,
)
# Diagonal where precision equals recall, drawn behind the points.
ax.plot([0, 1], [0, 1], zorder=1, linewidth=1, color=viz.GREY_COLORS[4])
# Mean operating point, drawn on top of the per-subject markers.
ax.plot(
    mean_recall, mean_precision,
    marker='o', markersize=3, linestyle="None",
    color=viz.GREY_COLORS[6], zorder=30
)
ax.set_title('A7 Validation\n%s\n%s' % (setting_to_show, perf_string), fontsize=7)
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_yticks(np.arange(11) / 10)
ax.set_xticks(np.arange(11) / 10)
ax.tick_params(labelsize=8)
ax.grid()
ax.set_ylabel('Precision (IoU>%1.1f)' % iou_thr, fontsize=10)
ax.set_xlabel('Recall (IoU>%1.1f)' % iou_thr, fontsize=10)
ax.set_aspect('equal')

plt.tight_layout()

# plt.savefig("pr_a7_val_%s.png" % setting_to_show, dpi=200, bbox_inches="tight", pad_inches=0.01)

plt.show()