In [None]:
%%javascript
// Replace the output area's scroll check with a function that always answers
// "false", regardless of how many lines the output has.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}
// Stops auto-scrolling so entire output is visible: see https://stackoverflow.com/a/41646403

In [None]:
# Default parameter values. They will be overwritten by papermill notebook parameters.
# This cell must carry the tag "parameters" in its metadata.
from pathlib import Path
import pickle
import codecs

# Directory three levels above the current working directory. It is the base of the default
# CSV paths below and is appended to sys.path by the following cell.
innereye_path = Path.cwd().parent.parent.parent
# Training metrics CSV. With the empty default, `Path("").is_file()` is False, so every
# cell guarded by `train_metrics_csv.is_file()` is skipped.
train_metrics_csv = ""
# Validation and test metrics CSVs; the defaults point into Tests/ML/reports under innereye_path.
val_metrics_csv = innereye_path / 'Tests' / 'ML' / 'reports' / 'val_metrics_classification.csv'
test_metrics_csv = innereye_path / 'Tests' / 'ML' / 'reports' / 'test_metrics_classification.csv'
# Passed as `k` to the "best and worst performing" helpers further down.
number_best_and_worst_performing = 20
# Dataset CSV handed to get_unique_label_combinations and plot_k_best_and_worst_performing.
dataset_csv_path=innereye_path / 'Tests' / 'ML' / 'reports' / 'dataset.csv'
# Base64 text of a pickled config object; the following cell decodes and unpickles it.
# The empty default is only a placeholder and cannot be unpickled.
config= ""

In [None]:
import sys
print(f"Adding to path: {innereye_path}")
if str(innereye_path) not in sys.path:
    sys.path.append(str(innereye_path))

%matplotlib inline
import matplotlib.pyplot as plt

config = pickle.loads(codecs.decode(config.encode(), "base64"))

from InnerEye.ML.reports.notebook_report import print_header
from InnerEye.ML.reports.classification_report import plot_pr_and_roc_curves_from_csv, \
print_k_best_and_worst_performing, print_metrics, plot_k_best_and_worst_performing, \
ReportedMetrics, get_unique_label_combinations, get_metric

import warnings
warnings.filterwarnings("ignore")
plt.rcParams['figure.figsize'] = (20, 10)

#convert params to Path
train_metrics_csv = Path(train_metrics_csv)
val_metrics_csv = Path(val_metrics_csv)
test_metrics_csv = Path(test_metrics_csv)
dataset_csv_path = Path(dataset_csv_path)

In [None]:
label_names = config.class_names
unique_label_combinations = get_unique_label_combinations(dataset_csv_path, config)

# One single-label tuple per class (in class order), followed by the combinations returned by
# get_unique_label_combinations. dict.fromkeys drops duplicates while keeping first-seen order,
# so the report sections appear in the same order on every run. A plain set union would order
# the tuples arbitrarily, varying with string hash randomisation between kernel sessions.
all_label_combinations = list(dict.fromkeys(
    [(name,) for name in label_names]
    + [tuple(combination) for combination in unique_label_combinations]
))


# Train Metrics

In [None]:
# Training-set metrics. Nothing is printed when the training metrics CSV does not exist.
# The training CSV is supplied for both the val_metrics_csv and test_metrics_csv arguments.
if train_metrics_csv.is_file():
    metrics_inputs = dict(val_metrics_csv=train_metrics_csv, test_metrics_csv=train_metrics_csv)

    # One OptimalThreshold value per class, in the order of label_names.
    thresholds_per_label = [
        get_metric(metric=ReportedMetrics.OptimalThreshold, hues=[class_name], exclusive=False,
                   **metrics_inputs)
        for class_name in label_names
    ]

    # Metrics for each label combination, requested with exclusive=True.
    for combination in all_label_combinations:
        print_header(f"Class {'|'.join(combination)}", level=3)
        print_metrics(hues=combination,
                      all_hues=config.class_names,
                      thresholds=thresholds_per_label,
                      exclusive=True,
                      **metrics_inputs)

    # Metrics for each individual class, requested with exclusive=False.
    for class_name in label_names:
        print_header(f"Class {class_name} (Inclusive)", level=3)
        print_metrics(hues=[class_name],
                      all_hues=config.class_names,
                      thresholds=thresholds_per_label,
                      exclusive=False,
                      **metrics_inputs)


# Validation Metrics

In [None]:
# Validation-set metrics. Nothing is printed when the validation metrics CSV does not exist.
# The validation CSV is supplied for both the val_metrics_csv and test_metrics_csv arguments.
if val_metrics_csv.is_file():
    metrics_inputs = dict(val_metrics_csv=val_metrics_csv, test_metrics_csv=val_metrics_csv)

    # One OptimalThreshold value per class, in the order of label_names.
    thresholds_per_label = [
        get_metric(metric=ReportedMetrics.OptimalThreshold, hues=[class_name], exclusive=False,
                   **metrics_inputs)
        for class_name in label_names
    ]

    # Metrics for each label combination, requested with exclusive=True.
    for combination in all_label_combinations:
        print_header(f"Class {'|'.join(combination)}", level=3)
        print_metrics(hues=combination,
                      all_hues=config.class_names,
                      thresholds=thresholds_per_label,
                      exclusive=True,
                      **metrics_inputs)

    # Metrics for each individual class, requested with exclusive=False.
    for class_name in label_names:
        print_header(f"Class {class_name} (Inclusive)", level=3)
        print_metrics(hues=[class_name],
                      all_hues=config.class_names,
                      thresholds=thresholds_per_label,
                      exclusive=False,
                      **metrics_inputs)


# Test Metrics

In [None]:
# Test-set metrics. Both the validation and the test metrics CSV must exist;
# otherwise nothing is printed.
have_val_and_test = val_metrics_csv.is_file() and test_metrics_csv.is_file()
if have_val_and_test:
    metrics_inputs = dict(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv)

    # One OptimalThreshold value per class, in the order of label_names.
    thresholds_per_label = [
        get_metric(metric=ReportedMetrics.OptimalThreshold, hues=[class_name], exclusive=False,
                   **metrics_inputs)
        for class_name in label_names
    ]

    # Metrics for each label combination, requested with exclusive=True.
    for combination in all_label_combinations:
        print_header(f"Class {'|'.join(combination)}", level=3)
        print_metrics(hues=combination,
                      all_hues=config.class_names,
                      thresholds=thresholds_per_label,
                      exclusive=True,
                      **metrics_inputs)

    # Metrics for each individual class, requested with exclusive=False.
    for class_name in label_names:
        print_header(f"Class {class_name} (Inclusive)", level=3)
        print_metrics(hues=[class_name],
                      all_hues=config.class_names,
                      thresholds=thresholds_per_label,
                      exclusive=False,
                      **metrics_inputs)


# ROC and PR curves
## Test set

In [None]:
# PR and ROC curves per class from the test metrics CSV; skipped when that file is absent.
if test_metrics_csv.is_file():
    for class_name in config.class_names:
        print_header(f"Class {class_name}", level=3)
        plot_pr_and_roc_curves_from_csv(test_metrics_csv, hue=class_name)

## Validation set

In [None]:
# PR and ROC curves per class from the validation metrics CSV; skipped when that file is absent.
if val_metrics_csv.is_file():
    for class_name in config.class_names:
        print_header(f"Class {class_name}", level=3)
        plot_pr_and_roc_curves_from_csv(val_metrics_csv, hue=class_name)

## Training set

In [None]:
# PR and ROC curves per class from the training metrics CSV; skipped when that file is absent.
if train_metrics_csv.is_file():
    for class_name in config.class_names:
        print_header(f"Class {class_name}", level=3)
        plot_pr_and_roc_curves_from_csv(train_metrics_csv, hue=class_name)

# Best and worst samples by ID

In [None]:
# For each class, print the k best and worst performing samples, with
# k = number_best_and_worst_performing. Requires both the validation and test metrics CSVs.
have_val_and_test = val_metrics_csv.is_file() and test_metrics_csv.is_file()
if have_val_and_test:
    for class_name in config.class_names:
        print_header(f"Class {class_name}", level=3)
        print_k_best_and_worst_performing(
            val_metrics_csv=val_metrics_csv,
            test_metrics_csv=test_metrics_csv,
            k=number_best_and_worst_performing,
            hue=class_name,
        )

# Plot best and worst sample images

In [None]:
# For each class, plot the k best and worst performing samples, with
# k = number_best_and_worst_performing. Requires the validation metrics CSV, the test metrics
# CSV and the dataset CSV to all exist.
required_files = (val_metrics_csv, test_metrics_csv, dataset_csv_path)
if all(required.is_file() for required in required_files):
    for class_name in config.class_names:
        print_header(f"Class {class_name}", level=3)
        plot_k_best_and_worst_performing(
            val_metrics_csv=val_metrics_csv,
            test_metrics_csv=test_metrics_csv,
            k=number_best_and_worst_performing,
            dataset_csv_path=dataset_csv_path,
            hue=class_name,
            config=config,
        )