In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import itertools
import seaborn as sns
from tueplots import figsizes, fontsizes, fonts, bundles, axes
import matplotlib as mpl

from histaug.analysis.bootstrap_augmentations import compare_bootstraps
from histaug.utils.display import RENAME_MODELS, RENAME_FEATURE_EXTRACTORS, FEATURE_EXTRACTOR_GROUPS

In [2]:
# Augmentation groups whose bootstrap results are compared against "none".
# (An earlier variant of this list additionally contained "Macenko_patchwise";
# it was immediately overwritten, so only the list below was ever used.)
augmentation_groups_to_compare = ["Macenko_slidewise", "simple_rotate", "all"]

# One frame per augmentation group, each tagged with the group it came from.
# `assign` returns a new frame, so the object handed back by
# `compare_bootstraps` is not modified in place.
dfs = [compare_bootstraps("none", aug).assign(augmentation=aug) for aug in augmentation_groups_to_compare]

# Stack all groups and index by (augmentation, model, feature_extractor);
# the "target" column is dropped, leaving the remaining value column(s).
df = (
    pd.concat(dfs)
    .reset_index()
    .set_index(["augmentation", "model", "feature_extractor"])
    .drop(columns=["target"])
)
df

[32m2023-11-12 17:37:21.981[0m | [34m[1mDEBUG   [0m | [36mhistaug.utils.caching[0m:[36mwrapper[0m:[36m20[0m - [34m[1mLoading bootstrapped_augmentations_none_vs_Macenko_slidewise from cache[0m
[32m2023-11-12 17:37:22.016[0m | [34m[1mDEBUG   [0m | [36mhistaug.utils.caching[0m:[36mwrapper[0m:[36m20[0m - [34m[1mLoading bootstrapped_augmentations_none_vs_simple_rotate from cache[0m
[32m2023-11-12 17:37:22.046[0m | [34m[1mDEBUG   [0m | [36mhistaug.utils.caching[0m:[36mwrapper[0m:[36m20[0m - [34m[1mLoading bootstrapped_augmentations_none_vs_all from cache[0m


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,auroc_diff
augmentation,model,feature_extractor,Unnamed: 3_level_1
Macenko_slidewise,Transformer,ctranspath,-0.038041
Macenko_slidewise,Transformer,ctranspath,0.048500
Macenko_slidewise,Transformer,ctranspath,-0.011456
Macenko_slidewise,Transformer,ctranspath,-0.048586
Macenko_slidewise,Transformer,ctranspath,-0.039406
...,...,...,...
all,AttentionMIL,vits,0.049296
all,AttentionMIL,vits,0.107455
all,AttentionMIL,vits,0.076722
all,AttentionMIL,vits,0.132783


In [218]:
from histaug.analysis.collect_results import load_aurocs
from histaug.analysis.collect_results import compute_results_table


def compute_overall_average(df):
    """Append an across-target ``"average"`` column group to ``df``.

    ``df`` must have column levels named ``"target"`` and ``"stats"``, with
    ``"stats"`` as the innermost level holding ``"mean"`` and ``"std"``
    entries. Two columns are added **in place**:

    * ``("average", "mean")``: sum of the per-target means divided by the
      number of targets.
    * ``("average", "std")``: square root of the summed squared per-target
      stds, divided by the number of targets.

    Returns the same (mutated) ``df`` object. Raises ``AssertionError`` if a
    target called ``"average"`` already exists.
    """
    target_names = df.columns.get_level_values("target").unique()
    assert "average" not in target_names
    target_count = len(target_names)

    # Move the innermost column level ("stats") into the index so that each
    # statistic can be selected by row and reduced across the target columns.
    stacked = df.stack()
    per_target_mean = stacked.query("stats == 'mean'").droplevel("stats")
    per_target_std = stacked.query("stats == 'std'").droplevel("stats")

    mean_of_means = per_target_mean.sum(axis="columns").divide(target_count)
    combined_std = per_target_std.pow(2).sum(axis="columns").pow(0.5).divide(target_count)

    df["average", "mean"] = mean_of_means
    df["average", "std"] = combined_std
    return df


def format_feature_extractor(new_name: str):
    r"""Return ``new_name`` wrapped in LaTeX ``\textbf{...}`` when it is a key
    of ``FEATURE_EXTRACTOR_GROUPS``; otherwise return it unchanged."""
    is_group_key = new_name in FEATURE_EXTRACTOR_GROUPS
    return f"\\textbf{{{new_name}}}" if is_group_key else new_name


# Display name of each row's feature extractor: the index value is renamed via
# RENAME_FEATURE_EXTRACTORS and then passed through format_feature_extractor.
extractor_keys = df.index.get_level_values("feature_extractor")
df["Feature extractor"] = extractor_keys.map(RENAME_FEATURE_EXTRACTORS).map(format_feature_extractor)

# Human-readable label for each augmentation group; keys that are not listed
# here are passed through unchanged.
augmentation_labels = {
    "Macenko_slidewise": "stain normalisation",
    "simple_rotate": "rotate/flip",
    "all": "all augmentations",
}
df["Augmentation"] = df.index.get_level_values("augmentation").map(
    lambda key: augmentation_labels.get(key, key)
)

# Two-panel figure:
#   left  - box plots of `auroc_diff` per augmentation group and feature
#           extractor (AttentionMIL rows only),
#   right - bar chart of the "average" mean/std per feature extractor for the
#           rows with augmentations == 'none'.
with plt.rc_context(
    {
        **axes.lines(),
        **bundles.tmlr2023(family="sans-serif"),
        **figsizes.cvpr2022_full(),
        "figure.dpi": 300,
    }
):
    fig = plt.figure()
    grid = fig.add_gridspec(1, 2, width_ratios=[75, 25])

    ######################
    # Left panel: effect of each augmentation group
    ax1 = fig.add_subplot(grid[0, 0])

    # Reference line at "no change". The handle is kept so that it can be told
    # apart from the boxplot's own Line2D objects further down.
    hline = ax1.axhline(0, color="black", linewidth=0.5)

    sns.boxplot(
        data=df.query("model == 'AttentionMIL'"),
        x="Augmentation",
        y="auroc_diff",
        hue="Feature extractor",
        hue_order=[format_feature_extractor(name) for name in RENAME_FEATURE_EXTRACTORS.values()],
        # Whiskers span the 2.5th-97.5th percentiles (central 95% of the
        # values); points outside the whiskers are not drawn.
        showfliers=False,
        whis=[2.5, 97.5],
        width=0.95,
        ax=ax1,
    )

    # The per-axes legend is replaced by a figure-level legend below.
    ax1.get_legend().remove()
    ax1.set_ylabel("Change in test AUROC")
    ax1.set_xlabel("")
    ax1.grid(axis="y", color="lightgrey")
    ax1.tick_params(axis="x", which="both", length=0)
    ax1.set_title("Effect of augmentation on downstream performance")

    ######################
    # Right panel: per-extractor results without augmentation
    ax2 = fig.add_subplot(grid[0, 1])

    extractor_to_group = {
        extractor: group for group, extractors in FEATURE_EXTRACTOR_GROUPS.items() for extractor in extractors
    }

    # NOTE(review): `df` is the bootstrap-difference frame built above, yet the
    # query below needs an `augmentations` level containing 'none', and
    # `load_aurocs` is imported in this cell but never called - confirm that
    # `df` is the intended input of `compute_results_table`.
    d = compute_results_table(df)
    d = compute_overall_average(d)["average"]
    # `assign` creates new frames instead of writing columns into frames that
    # were derived from a selection (which can trigger SettingWithCopyWarning).
    d = (
        d.assign(group=d.index.get_level_values("feature_extractor").map(extractor_to_group))
        .query("model == 'AttentionMIL' and augmentations=='none'")
        .assign(aug=lambda t: t.index.get_level_values("augmentations").map({"none": "no augmentation"}))
        .reset_index()
    )

    # Order the rows the way the extractors appear in RENAME_FEATURE_EXTRACTORS.
    d["feature_extractor"] = pd.Categorical(d["feature_extractor"], RENAME_FEATURE_EXTRACTORS.keys(), ordered=True)
    d = d.sort_values("feature_extractor")

    # Bar positions: unit spacing inside a group, an extra half unit of space
    # in front of the first bar of every group; the first bar sits at x = 0.
    steps = [
        1.5 if position == 0 else 1
        for extractors in FEATURE_EXTRACTOR_GROUPS.values()
        for position in range(len(extractors))
    ]
    xs = np.cumsum(steps) - 1.5
    assert len(xs) == len(d), "expected exactly one result row per feature extractor in FEATURE_EXTRACTOR_GROUPS"

    ax2.bar(
        x=xs,
        height=d["mean"],
        yerr=d["std"],
        width=1,
        color=sns.color_palette(),
    )
    ax2.set_xlabel("")
    # A single tick, centred under all bars, labels the whole panel.
    ax2.set_xticks([np.mean(xs)])
    ax2.set_xticklabels(["no augmentation"])
    ax2.tick_params(axis="x", which="both", length=0)
    # "\\ " yields a literal backslash followed by a space; the previous
    # spelling "\ " is an invalid escape sequence in a normal string literal.
    ax2.set_ylabel("AUROC deterioration vs.\\ best")
    ax2.yaxis.set_label_position("right")
    ax2.yaxis.tick_right()
    ax2.grid(axis="y", color="lightgrey")
    ax2.set_title("Relative performance comparison")

    # Small note at the top of the right panel, on a white background.
    ax2.text(
        0.5,
        0.965,
        "(lower is better)",
        horizontalalignment="center",
        verticalalignment="top",
        transform=ax2.transAxes,
        fontsize=7,
        bbox=dict(facecolor="white", alpha=1.0, edgecolor="none", pad=0.3, boxstyle="square"),
    )

    ######################
    # Figure-level legend, anchored just outside the figure's upper right
    # corner. Entries follow FEATURE_EXTRACTOR_GROUPS, with an empty entry
    # between consecutive groups as a visual separator.
    handles, labels = ax1.get_legend_handles_labels()
    labelled_handles = list(zip(handles, labels))
    legend_entries = []
    spacer_lines = []
    cursor = 0
    for extractors in FEATURE_EXTRACTOR_GROUPS.values():
        for _ in extractors:
            legend_entries.append(labelled_handles[cursor])
            cursor += 1
        spacer = ax1.plot([], [], "none", label="")[0]
        spacer_lines.append(spacer)
        legend_entries.append((spacer, ""))
    # Drop the separator that trails the last group.
    legend_handles, legend_labels = zip(*legend_entries[:-1])

    fig.legend(
        legend_handles,
        legend_labels,
        bbox_to_anchor=(1.01, 1.0),
        loc="upper left",
        ncol=1,
        borderaxespad=0.0,
        columnspacing=0.6,
        handletextpad=0.5,
        title="Feature extractor",
    )

    ######################
    # Reduce width of boxes. This hack is brought to you by
    # https://stackoverflow.com/questions/51105226/seaborn-boxplot-individual-box-spacing
    # and
    # https://stackoverflow.com/questions/36874697/how-to-edit-properties-of-whiskers-fliers-caps-etc-in-seaborn-boxplot/72333641#72333641
    # and a giant headache
    factor = 0.3  # fraction of each box's width that is removed (half per side)

    box_patches = [patch for patch in ax1.patches if isinstance(patch, mpl.patches.PathPatch)]
    if len(box_patches) == 0:  # in matplotlib older than 3.5, the boxes are stored in ax.artists
        box_patches = ax1.artists

    # Only the Line2D objects that belong to the boxplot: the zero line and the
    # legend spacers also live in ax1.lines and must not be counted.
    non_box_line_ids = {id(hline)} | {id(spacer) for spacer in spacer_lines}
    box_lines = [line for line in ax1.lines if id(line) not in non_box_line_ids]
    lines_per_boxplot = len(box_lines) // len(box_patches)

    for box_index, patch in enumerate(box_patches):
        vertices = patch.get_path().vertices
        artist_width = vertices[1, 0] - vertices[0, 0]
        shrink = artist_width * (factor / 2)
        # Pull the left-hand vertices (0, 3, 4) right and the right-hand
        # vertices (1, 2) left; the path's vertex array is edited in place.
        vertices[[0, 3, 4], 0] += shrink
        vertices[[1, 2], 0] -= shrink

        # Lines belonging to this box; among the non-empty horizontal ones,
        # the last is shortened to match the narrowed box.
        own_lines = box_lines[box_index * lines_per_boxplot : (box_index + 1) * lines_per_boxplot]
        horizontal_lines = [
            line
            for line in own_lines
            if len(line.get_path().vertices) != 0 and line.get_path().vertices[0, 1] == line.get_path().vertices[1, 1]
        ]
        line_vertices = horizontal_lines[-1].get_path().vertices
        line_vertices[0, 0] += shrink
        line_vertices[1, 0] -= shrink

    ax1.redraw_in_frame()

    # Dashed vertical line at x = 0.5 on the left panel.
    ax1.axvline(0.5, color="lightgrey", linewidth=0.5, linestyle="--")

    plt.show()