In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import itertools
import seaborn as sns
from tueplots import figsizes, fontsizes, fonts, bundles, axes

from histaug.analysis.bootstrap import compare_bootstraps
from histaug.analysis.collect_results import load_results
from histaug.utils.display import RENAME_MODELS, RENAME_FEATURE_EXTRACTORS, FEATURE_EXTRACTOR_GROUPS
from histaug.utils import rc_context, savefig

In [7]:
# Number of bootstrap resamples drawn for every experimental configuration.
N_BOOTSTRAPS_PER_CONFIG = 25

# Paired bootstrap comparison of the "magnification" factor, "low" vs "high".
# The resulting frame is indexed by (augmentations, feature_extractor, model,
# target) and carries an `auroc_diff` column (see the displayed output).
# NOTE(review): sign convention of `auroc_diff` (high - low or low - high) is
# defined inside `compare_bootstraps` — confirm before interpreting plots.
results = load_results()
df = compare_bootstraps(
    results,
    "magnification",
    "low",
    "high",
    n_bootstraps_per_config=N_BOOTSTRAPS_PER_CONFIG,
)
df

[32m2024-05-30 12:48:03.751[0m | [34m[1mDEBUG   [0m | [36mhistaug.utils.caching[0m:[36mwrapper[0m:[36m20[0m - [34m[1mLoading results from cache[0m
  results = process_map(fn, configs, max_workers=n_workers, tqdm_class=tqdm, desc="Computing results")
Computing results: 100%|██████████| 1890/1890 [00:25<00:00, 73.20it/s] 


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,train_dataset,test_dataset,auroc_diff
augmentations,feature_extractor,model,target,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Macenko_patchwise,bt,AttentionMIL,BRAF,tcga_crc_BRAF,cptac_crc_BRAF,-0.015097
Macenko_patchwise,bt,AttentionMIL,BRAF,tcga_crc_BRAF,cptac_crc_BRAF,-0.121981
Macenko_patchwise,bt,AttentionMIL,BRAF,tcga_crc_BRAF,cptac_crc_BRAF,0.056111
Macenko_patchwise,bt,AttentionMIL,BRAF,tcga_crc_BRAF,cptac_crc_BRAF,-0.009058
Macenko_patchwise,bt,AttentionMIL,BRAF,tcga_crc_BRAF,cptac_crc_BRAF,0.033058
...,...,...,...,...,...,...
none,vits,AttentionMIL,subtype,tcga_brca_subtype,cptac_brca_subtype,0.031617
none,vits,AttentionMIL,subtype,tcga_brca_subtype,cptac_brca_subtype,0.039148
none,vits,AttentionMIL,subtype,tcga_brca_subtype,cptac_brca_subtype,0.027897
none,vits,AttentionMIL,subtype,tcga_brca_subtype,cptac_brca_subtype,0.035023


In [14]:
def format_feature_extractor(new_name: str):
    r"""Format a (renamed) feature-extractor label for LaTeX output.

    A name that is also a key of ``FEATURE_EXTRACTOR_GROUPS`` is wrapped in
    ``\textbf{...}`` so it renders in bold. Every other name is returned
    unchanged.

    Args:
        new_name: Human-readable feature extractor name, i.e. a value of
            ``RENAME_FEATURE_EXTRACTORS``.

    Returns:
        The LaTeX-bold name, or ``new_name`` itself.
    """
    if new_name not in FEATURE_EXTRACTOR_GROUPS:
        return new_name
    return f"\\textbf{{{new_name}}}"


# Display labels for the `augmentations` index level. They are defined once so
# that the data labels and the boxplot `hue_order` cannot drift apart. Seaborn
# silently drops any hue level that is missing from `hue_order`.
AUGMENTATION_LABELS = {
    "none": "without augmentation",
    "Macenko_patchwise": "with stain normalisation",
}

df["Feature extractor"] = (
    df.index.get_level_values("feature_extractor").map(RENAME_FEATURE_EXTRACTORS).map(format_feature_extractor)
)
df["Augmentation"] = df.index.get_level_values("augmentations").map(lambda x: AUGMENTATION_LABELS.get(x, x))

model = "AttentionMIL"

# x-axis category order: feature extractors grouped as in FEATURE_EXTRACTOR_GROUPS.
# A whitespace-only placeholder category is inserted between consecutive groups
# to create a visual gap. The placeholders are distinct strings (" ", "  ", ...)
# so seaborn treats each one as its own empty category.
x_order = []
for group_idx, extractors in enumerate(FEATURE_EXTRACTOR_GROUPS.values()):
    if group_idx != 0:
        x_order.append(" " * group_idx)
    x_order.extend(format_feature_extractor(RENAME_FEATURE_EXTRACTORS[extractor]) for extractor in extractors)

# Real categories get a tick and label. Placeholder slots get a dashed separator.
tick_positions = [pos for pos, label in enumerate(x_order) if label.strip()]
tick_labels = [x_order[pos] for pos in tick_positions]
gap_positions = [pos for pos, label in enumerate(x_order) if not label.strip()]

with rc_context("half", journal=True):
    plt.figure()

    # thin reference line at y=0 (no change in AUROC)
    plt.axhline(0, color="black", linewidth=0.5)

    # Whiskers span the 2.5th-97.5th percentiles of the bootstrapped AUROC
    # differences (a 95% percentile interval). Fliers are hidden, so the
    # whiskers are the only indicator of spread.
    sns.boxplot(
        data=df.query("model == @model"),
        x="Feature extractor",
        y="auroc_diff",
        hue="Augmentation",
        hue_order=list(AUGMENTATION_LABELS.values()),
        order=x_order,
        showfliers=False,
        whis=[2.5, 97.5],
    )

    plt.legend(loc="upper left", ncol=2)
    plt.ylabel("Change in test AUROC")
    plt.xlabel("")

    for pos in gap_positions:
        plt.axvline(x=pos, linestyle="--", color="grey", linewidth=0.5)

    # Label only the real feature extractors. Labels are passed explicitly so
    # they do not depend on the tick formatter seaborn installed.
    plt.xticks(tick_positions, tick_labels, rotation=45, ha="right")

    # light grey horizontal grid
    plt.grid(axis="y", color="lightgrey")
    # fixed symmetric y-limits; whiskers beyond +/-0.42 would be clipped
    plt.ylim(-0.42, 0.42)
    savefig(f"bootstrap_magnifications_low_vs_high_{model}", journal=True)
    plt.show()

In [9]:
from functools import partial

# Per (augmentation, model, feature extractor): mean bootstrapped AUROC change
# and its 2.5th / 97.5th percentiles (95% percentile interval), pooled over all
# targets and train/test dataset pairs present in `df`.
summary_stats = (
    df.reset_index()
    .groupby(["augmentations", "model", "feature_extractor"])["auroc_diff"]
    .agg(
        mean="mean",
        ci_lo=partial(pd.Series.quantile, q=0.025),
        ci_hi=partial(pd.Series.quantile, q=0.975),
    )
)

# One LaTeX math cell per group: "$<signed mean>\ [<lo>, <hi>]$".
ci_cells = (
    "$"
    + summary_stats["mean"].map("{:+.3f}".format)
    + "\\ ["
    + summary_stats["ci_lo"].map("{:.3f}".format)
    + ", "
    + summary_stats["ci_hi"].map("{:.3f}".format)
    + "]$"
)

# One column per augmentation, one row per (model, feature extractor).
table = ci_cells.unstack("augmentations").reset_index()

# Sort rows in the canonical order given by the rename dictionaries. Ranks are
# precomputed once instead of calling list(...).index per element. An unknown
# key still raises, as in the original implementation.
sort_rank = {
    "model": {key: rank for rank, key in enumerate(RENAME_MODELS)},
    "feature_extractor": {key: rank for rank, key in enumerate(RENAME_FEATURE_EXTRACTORS)},
}
table = table.sort_values(
    by=["model", "feature_extractor"],
    key=lambda column: column.map(lambda value: sort_rank[column.name][value]),
)
table["model"] = table["model"].map(RENAME_MODELS)
table["feature_extractor"] = table["feature_extractor"].map(RENAME_FEATURE_EXTRACTORS)
table = table.set_index(["model", "feature_extractor"])
table.index.names = ["Model", "Feature extractor"]

# Column selection, order and display names all come from this single mapping.
cols = {"none": "Original", "Macenko_patchwise": "Macenko"}
d = table[list(cols)]
# A plain list assignment (rather than .rename) also drops the "augmentations"
# columns name, which keeps it out of the LaTeX header.
d.columns = [cols[col] for col in d.columns]
print(d.to_latex(column_format="ll|cc", escape=False))

\begin{tabular}{ll|cc}
\toprule
 &  & Original & Macenko \\
Model & Feature extractor &  &  \\
\midrule
\multirow[t]{14}{*}{AttMIL} & Swin & $+0.028\ [-0.149, 0.251]$ & $+0.022\ [-0.182, 0.258]$ \\
 & CTransPath & $-0.022\ [-0.191, 0.107]$ & $-0.014\ [-0.153, 0.098]$ \\
 & ViT-S & $+0.010\ [-0.171, 0.190]$ & $-0.025\ [-0.208, 0.177]$ \\
 & Lunit-DINO & $-0.005\ [-0.188, 0.137]$ & $-0.009\ [-0.146, 0.138]$ \\
 & ViT-B & $+0.012\ [-0.203, 0.198]$ & $+0.003\ [-0.226, 0.199]$ \\
 & Phikon-S & $+0.046\ [-0.128, 0.251]$ & $+0.012\ [-0.132, 0.228]$ \\
 & Phikon-T & $+0.027\ [-0.150, 0.189]$ & $+0.011\ [-0.141, 0.172]$ \\
 & ViT-L & $+0.020\ [-0.190, 0.287]$ & $+0.031\ [-0.146, 0.302]$ \\
 & UNI & $-0.000\ [-0.128, 0.147]$ & $-0.009\ [-0.160, 0.113]$ \\
 & ResNet-50 & $+0.014\ [-0.201, 0.256]$ & $-0.004\ [-0.212, 0.247]$ \\
 & RetCCL & $-0.064\ [-0.236, 0.076]$ & $-0.047\ [-0.194, 0.103]$ \\
 & Lunit-BT & $-0.050\ [-0.356, 0.287]$ & $-0.064\ [-0.335, 0.193]$ \\
 & Lunit-SwAV & $+0.011\ [-0.139