# Run inference and plot ROC curves

## Run inference on test set

These inference results are used later to plot ROC curves.

In [None]:
import csv
from pathlib import Path
from typing import Union

import numpy as np
from PIL import Image
import tensorflow as tf

PathType = Union[str, Path]

In [None]:
def get_model(weights: Union[PathType, None] = None) -> tf.keras.Model:
    """Build the EfficientNetB4 binary classifier used for GBM vs PCNSL.

    The architecture mirrors the one used for training: an EfficientNetB4
    backbone (no top, 380x380x3 input), global average pooling, dropout, and a
    single sigmoid unit.

    Args:
        weights: Optional path to a checkpoint to load into the model. If None,
            the model is returned with freshly initialised weights.

    Returns:
        The (optionally weight-loaded) ``tf.keras.Model``.
    """
    tfkl = tf.keras.layers

    # This is from the tf.keras.applications.efficientnet implementation in version
    # 2.5.0 of tensorflow.
    DENSE_KERNEL_INITIALIZER = {
        "class_name": "VarianceScaling",
        "config": {"scale": 1.0 / 3.0, "mode": "fan_out", "distribution": "uniform"},
    }

    base_model = tf.keras.applications.EfficientNetB4(
        include_top=False,
        input_shape=(380, 380, 3),
        weights=None,
    )
    # Kept to mirror the training-time definition. Activity regularizers only add
    # loss terms, so this does not affect the outputs produced at inference.
    base_model.activity_regularizer = tf.keras.regularizers.l2(l=0.01)

    _x = tfkl.GlobalAveragePooling2D(name="avg_pool")(base_model.output)
    _x = tfkl.Dropout(0.5)(_x)
    _x = tfkl.Dense(
        1,
        activation="sigmoid",
        name="predictions",
        kernel_initializer=DENSE_KERNEL_INITIALIZER,
    )(_x)
    model = tf.keras.Model(inputs=base_model.input, outputs=_x)
    if weights is not None:
        model.load_weights(weights)
    return model

In [None]:
def load_image(path: PathType) -> np.ndarray:
    """Load and process an image in the same way that was done for training.

    The image is converted to 3-channel RGB and resized to 380x380 with Lanczos
    resampling (the input size used by ``get_model``). It is returned as a
    float32 array of shape (380, 380, 3) with values in [0, 255].

    Raises:
        ValueError: If the decoded RGB pixel data is not 8-bit.
    """
    # Use a context manager so the underlying file handle is always closed.
    with Image.open(path) as opened:
        img = opened.convert('RGB')
    img = img.resize(size=(380, 380), resample=Image.LANCZOS)
    arr = np.asarray(img)
    # Check explicitly rather than with ``assert``, which is stripped under -O.
    if arr.dtype != np.uint8:
        raise ValueError(f"Expected uint8 pixel data from {path}, got {arr.dtype}")
    return arr.astype(np.float32)

In [None]:
# Load all images in the testing set. Sort the paths so that the row order of `x`
# (and of the inference CSVs written below) is reproducible; glob order is arbitrary.
test_set_paths = sorted(Path("test-set").glob("*.png"))
if not test_set_paths:
    raise FileNotFoundError("No PNG images found in directory 'test-set'")
x = np.stack([load_image(p) for p in test_set_paths], axis=0)
x.shape

In [None]:
# Map the checkpoints to the directories where we will save outputs.
mapping = {
    "checkpoints/efficientnetb4_aug_none/ckpt_137_0.0000.hdf5": "outputs/efficientnetb4_aug_none/",
    "checkpoints/efficientnetb4_aug_base/ckpt_292_0.0000.hdf5": "outputs/efficientnetb4_aug_base/",
    "checkpoints/efficientnetb4_aug_base_and_noise/ckpt_238_0.0000.hdf5": "outputs/efficientnetb4_aug_base_and_noise/"
}

# The filenames are identical for every model, so compute them once.
filenames = [p.name for p in test_set_paths]

# Run inference on all images in the test set, for each model.
for checkpoint, output_path in mapping.items():
    output = Path(output_path) / "inference.csv"
    print(f"++ Running inference using {checkpoint} and saving to {output}")
    output.parent.mkdir(parents=True, exist_ok=True)

    model = get_model(weights=checkpoint)
    # One sigmoid output per image; flatten to a 1-D array of probabilities.
    y_probs = model.predict(x, batch_size=8, verbose=True).flatten()
    # Release this model before building the next one (Keras recommends
    # clear_session when creating many models in a loop).
    del model
    tf.keras.backend.clear_session()

    # The sigmoid output is written as prob_pcnsl, so > 0.5 means "pcnsl".
    y_preds_str = np.where(y_probs > 0.5, "pcnsl", "gbm")
    # Write results to CSV: a header row, then one row per image.
    rows = [("filename", "prediction", "prob_gbm", "prob_pcnsl")]
    rows.extend(zip(filenames, y_preds_str, 1 - y_probs, y_probs))
    with open(output, "w", newline="") as csvfile:
        csv.writer(csvfile).writerows(rows)

## Plot ROC curves

The following resources were used to write the code below:
- [Scikit-learn example of plotting ROC](https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html#sphx-glr-auto-examples-model-selection-plot-roc-py)
- [StackOverflow post about AUC confidence intervals](https://stackoverflow.com/a/19132400/5666087)

We use bootstrapping (repeatedly resampling the test set with replacement and recomputing the AUC) to estimate a 95% confidence interval around our AUROC (a.k.a. AUC).

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import metrics

In [None]:
def auc_bootstrap(y_true, y_pred, n_bootstraps, seed=None):
    """Return sorted bootstrap estimates of the ROC AUC.

    Each bootstrap draws ``len(y_pred)`` indices with replacement and computes
    ``metrics.roc_auc_score`` on that resample. Resamples that contain only one
    class (where ROC AUC is undefined) are skipped. The result can therefore
    hold fewer than ``n_bootstraps`` values.

    Args:
        y_true: Binary ground-truth labels (array-like, e.g. a pandas Series).
        y_pred: Scores for the positive class, same length as ``y_true``.
        n_bootstraps: Number of resamples to draw.
        seed: Seed for ``np.random.RandomState`` (for reproducibility).

    Returns:
        A sorted 1-D float array of the bootstrapped AUC values.

    Raises:
        ValueError: If the inputs differ in length, or ``y_true`` does not
            contain two classes.
    """
    # With help from https://stackoverflow.com/a/19132400/5666087
    # Convert to plain arrays so the integer indexing below is positional; indexing
    # a pandas Series with a filename index by integers relies on a deprecated
    # positional fallback.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError(
            f"y_true and y_pred differ in length: {y_true.shape[0]} != {y_pred.shape[0]}")
    if np.unique(y_true).size < 2:
        raise ValueError("y_true must contain two classes to compute ROC AUC")

    n_samples = y_pred.shape[0]
    prng = np.random.RandomState(seed)
    bootstrapped_aucs = []
    for i in range(n_bootstraps):
        indices = prng.randint(0, n_samples, n_samples)
        # Skip single-class resamples instead of leaving an uninitialised slot
        # in the result (which would corrupt the mean and the percentiles).
        if np.unique(y_true[indices]).size < 2:
            continue
        bootstrapped_aucs.append(
            metrics.roc_auc_score(y_true[indices], y_pred[indices]))
        print(f"{round((i + 1) / n_bootstraps * 100, 2)} % completed bootstrapping", end="\r")
    print()
    return np.sort(np.asarray(bootstrapped_aucs, dtype=float))

def plot_roc(y_true, y_pred, positive_class, n_bootstraps=10000, seed=None):
    """Plot the ROC curve and compute a bootstrapped AUC with a 95% CI.

    Args:
        y_true: Binary ground-truth labels (1 = ``positive_class``).
        y_pred: Scores for the positive class.
        positive_class: Display name of the positive class, used in the title.
        n_bootstraps: Number of bootstrap resamples passed to ``auc_bootstrap``.
        seed: Seed passed to ``auc_bootstrap``.

    Returns:
        ``(fig, roc_auc, confidence_95)``. ``roc_auc`` is the mean of the
        bootstrapped AUCs, and ``confidence_95`` is a ``(lower, upper)`` pair
        taken at the 2.5th and 97.5th positions of the sorted bootstrap values.
    """
    # Use the function's own argument; previously the body read an undefined
    # name and silently picked up the caller's global variable instead.
    fpr, tpr, _ = metrics.roc_curve(y_true=y_true, y_score=y_pred)

    aucs = auc_bootstrap(y_true, y_pred, n_bootstraps=n_bootstraps, seed=seed)
    roc_auc = aucs.mean()
    # auc_bootstrap returns sorted values, so these indices are empirical percentiles.
    confidence_95 = aucs[int(0.025 * aucs.shape[0])], aucs[int(0.975 * aucs.shape[0])]

    fig = plt.figure()
    lw = 2
    plt.plot(fpr, tpr, color='black', lw=lw)
    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC for GBM vs PCNSL ({positive_class} is positive class)')

    print(f"ROC curve (area = {roc_auc:0.02f})")
    print(f"95% CI = {confidence_95[0]:0.2f} - {confidence_95[1]:0.2f}")
    print(n_bootstraps, "bootstraps")

    return fig, roc_auc, confidence_95

In [None]:
N_BOOTSTRAPS = 10000
SEED = 42

# Ground-truth labels indexed by filename. They are the same for every model,
# so read the spreadsheet once.
ground_truth = pd.read_excel("ground-truth.xlsx", sheet_name=2, index_col='filename')

# For each analysis: (value in the "class" column, display name of that class,
# display name of the other class). The names are ordered as label 0 then label 1
# when passed to classification_report.
POSITIVE_CLASSES = (
    ("pcnsl", "PCNSL", "GBM"),
    ("gbm", "GBM", "PCNSL"),
)

for inference_output_dir in mapping.values():
    inference_output_dir = Path(inference_output_dir)
    prediction_file = inference_output_dir / "inference.csv"
    print("\n++ Calculating metrics for", prediction_file)

    df_probs = pd.read_csv(prediction_file, index_col="filename")
    # Assignment aligns on the filename index. Any prediction without a label
    # would get NaN and silently count as the negative class, so fail loudly.
    df_probs.loc[:, "class"] = ground_truth.loc[:, "class"]
    missing = df_probs.index[df_probs.loc[:, "class"].isna()]
    if len(missing) > 0:
        raise ValueError(f"No ground-truth label for: {list(missing)}")

    for label, name, other_name in POSITIVE_CLASSES:
        print(f"++ {name} == 1")
        y_true = (df_probs.loc[:, 'class'] == label).astype(int)
        y_score = df_probs.loc[:, f'prob_{label}']
        fig, roc_auc, confidence_95 = plot_roc(
            y_true, y_score, name, n_bootstraps=N_BOOTSTRAPS, seed=SEED)
        fig.savefig(inference_output_dir / f"{label}_roc_curve.pdf")
        with (inference_output_dir / f"{label}_metrics.txt").open("w") as f:
            # Header goes into the metrics file (previously the PCNSL header
            # was printed to stdout by mistake).
            print(f"{name} results", file=f)
            print(f"ROC AUC = {roc_auc}", file=f)
            print(f"95% CI = {confidence_95[0]:0.2f} - {confidence_95[1]:0.2f}", file=f)
            print(f"Using {N_BOOTSTRAPS:,d} bootstraps", file=f)
            print(file=f)
            print(metrics.classification_report(
                y_true, y_score > 0.5, target_names=[other_name, name]), file=f)