In [95]:
import wandb
import pandas as pd
from loguru import logger
from tqdm import tqdm
from itertools import product
import numpy as np
from pathlib import Path
import functools
import json
from typing import Literal

from histaug.utils import RunningStats, cached_df

# wandb public API client; load_aurocs uses it to list the runs of the "histaug" project.
api = wandb.Api()

# Index levels of the per-run results frame (see load_aurocs).
INDEX_COLS = [
    "target",
    "train_dataset",
    "test_dataset",
    "model",
    "feature_extractor",
    "augmentations",
    "seed",
]

# Display names substituted into the LaTeX tables (see results_to_latex).
RENAME_MODELS = dict(
    AttentionMIL="AttMIL",
    MeanAveragePooling="Mean pool",
    Transformer="Transformer",
)
RENAME_FEATURE_EXTRACTORS = dict(
    bt="Lunit-BT",
    swav="Lunit-SwAV",
    dino_p16="Lunit-DINO",
    ctranspath="CTransPath",
    owkin="Phikon",
    resnet50="ResNet-50",
    retccl="RetCCL",
    swin="Swin",
    vit="ViT",
)

# NOTE(review): not referenced in the cells shown here; presumably the results output
# directory (hard-coded absolute path) — confirm and consider making it configurable.
RESULTS_DIR = Path("/app/results")

# Collect results from `wandb`

In [46]:
def filter_runs(runs, filters: dict):
    """Return, as a list, the runs whose attributes equal every value in ``filters``.

    A run lacking one of the filtered attributes is treated as having the value ``None``.
    """

    def _matches(run):
        return all(getattr(run, attr, None) == expected for attr, expected in filters.items())

    return list(filter(_matches, runs))


def summarize_run(run):
    """Summarize one wandb run as a flat dict of its configuration and AUROCs.

    Train and validation AUROCs are read from the epoch with the lowest validation loss.
    The test AUROC is the ``"best"`` entry of ``run.summary``.

    NOTE(review): the train/val metrics and the test metric may therefore come from different
    checkpoints. Confirm that the test AUROC is only logged for the selected checkpoint.
    Otherwise ``"best"`` amounts to model selection on the test set.

    Args:
        run: a wandb run exposing ``config`` (nested dict), ``summary`` (mapping) and
            ``history()`` (a DataFrame with an ``epoch`` column and one column per metric).

    Returns:
        dict with the index fields (target, train_dataset, test_dataset, model,
        feature_extractor, augmentations, seed), followed by train_auroc, val_auroc
        and test_auroc.
    """
    config = run.config
    column = config["dataset"]["targets"][0]["column"]

    # Metrics may be logged on different steps within an epoch. first() takes the first
    # non-null value of every column per epoch, merging them into one row per epoch.
    history = run.history().groupby("epoch").first()
    # Defensive: keep only rows with a valid epoch (groupby already drops NaN keys by default).
    history = history[~history.index.isna()]
    # A stable sort makes ties in validation loss resolve to the earliest epoch
    # (the groupby result is sorted by epoch). The default quicksort gives no such guarantee.
    best = history.sort_values("val/loss", ascending=True, kind="stable", na_position="last").iloc[0]

    return dict(
        target=column,
        train_dataset=config["dataset"]["name"],
        test_dataset=config["test"]["dataset"]["name"],
        model=config["model"]["_target_"].split(".")[-1],
        feature_extractor=config["settings"]["feature_extractor"],
        augmentations=config["dataset"]["augmentations"]["name"],
        seed=config["seed"],
        train_auroc=best[f"train/{column}/auroc"],
        val_auroc=best[f"val/{column}/auroc"],
        test_auroc=run.summary[f"test/{column}/auroc"]["best"],
    )


@cached_df(lambda: "aurocs")
def load_aurocs():
    """Build one row per finished run of the wandb project "histaug".

    Rows come from ``summarize_run`` and are indexed by ``INDEX_COLS``, sorted by that index.
    The result is memoized by ``cached_df`` under the key ``"aurocs"``.
    """
    finished = filter_runs(list(api.runs("histaug")), {"state": "finished"})
    summaries = [summarize_run(run) for run in tqdm(finished, desc="Loading run data")]
    return pd.DataFrame(summaries).set_index(INDEX_COLS).sort_index()


df = load_aurocs()
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,train_auroc,val_auroc,test_auroc
target,train_dataset,test_dataset,model,feature_extractor,augmentations,seed,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,AttentionMIL,bt,Macenko_patchwise,0,0.795646,0.819653,0.776128
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,AttentionMIL,bt,Macenko_patchwise,1,0.828331,0.805249,0.760137
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,AttentionMIL,bt,Macenko_patchwise,2,0.795870,0.829321,0.743575
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,AttentionMIL,bt,Macenko_patchwise,3,0.810821,0.811168,0.762421
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,AttentionMIL,bt,Macenko_patchwise,4,0.823637,0.814917,0.766990
...,...,...,...,...,...,...,...,...,...
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,vit,simple_rotate,0,0.837830,0.805028,0.778463
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,vit,simple_rotate,1,0.775820,0.762783,0.768838
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,vit,simple_rotate,2,0.890516,0.767776,0.738792
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,vit,simple_rotate,3,0.843990,0.789167,0.706103


# Compare no augmentation ('none') vs. Macenko_patchwise (or Macenko_slidewise)

In [47]:
# Compare test AUROCs of a stain-augmentation variant against the un-augmented baseline ("none").
# Set COMPARE_AUGMENTATION to e.g. "Macenko_slidewise" to compare that variant instead.
COMPARE_AUGMENTATION = "Macenko_patchwise"
# Aggregation keys: every run index level except the augmentation and the seed.
GROUP_COLS = [col for col in INDEX_COLS if col not in ("augmentations", "seed")]


def mean_std_over_seeds(aurocs: pd.Series) -> pd.DataFrame:
    """Return the mean and std of a per-run series across seeds, for each GROUP_COLS group."""
    return aurocs.groupby(level=GROUP_COLS).agg(["mean", "std"])


macenko = df.query("augmentations == @COMPARE_AUGMENTATION")["test_auroc"].droplevel("augmentations")
orig = df.query("augmentations == 'none'")["test_auroc"].droplevel("augmentations")
# Per-seed paired difference. A run present in only one arm yields NaN, which mean/std skip.
diff_stats = mean_std_over_seeds(macenko - orig)
o = mean_std_over_seeds(orig)
d = pd.concat({"test_auroc_diff": diff_stats, "test_auroc_orig": o}, axis=1)
d.query("model == 'Transformer'")

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,test_auroc_diff,test_auroc_diff,test_auroc_orig,test_auroc_orig
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,mean,std,mean,std
target,train_dataset,test_dataset,model,feature_extractor,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,Transformer,bt,0.011136,0.036495,0.762364,0.025135
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,Transformer,ctranspath,-0.035294,0.051737,0.797487,0.014686
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,Transformer,dino_p16,0.017247,0.036132,0.748544,0.025807
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,Transformer,owkin,-0.005882,0.034366,0.734951,0.051818
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,Transformer,resnet50,-0.071673,0.089917,0.708738,0.079017
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,Transformer,retccl,-0.016562,0.074875,0.773158,0.054389
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,Transformer,swav,0.026271,0.059711,0.774872,0.059484
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,Transformer,swin,-0.070246,0.070006,0.697087,0.024496
CDH1,tcga_brca_CDH1,cptac_brca_CDH1,Transformer,vit,-0.022559,0.026442,0.709252,0.019774
PIK3CA,tcga_brca_PIK3CA,cptac_brca_PIK3CA,Transformer,bt,0.0458,0.044454,0.533486,0.045062


# What is the best feature extractor?

In [104]:
# Augmentation whose runs the feature-extractor ranking below is computed for
# (e.g. "none", "Macenko_patchwise", "Macenko_slidewise").
# compute_results_table also uses this global as its cache key.
augmentation = "Macenko_patchwise"

# `@augmentation` lets query resolve the variable itself instead of splicing it into the string.
test_aurocs = df.query("augmentations == @augmentation").droplevel("augmentations")["test_auroc"]
test_aurocs

target   train_dataset      test_dataset        model         feature_extractor  seed
CDH1     tcga_brca_CDH1     cptac_brca_CDH1     AttentionMIL  bt                 0       0.776128
                                                                                 1       0.760137
                                                                                 2       0.743575
                                                                                 3       0.762421
                                                                                 4       0.766990
                                                                                           ...   
subtype  tcga_brca_subtype  cptac_brca_subtype  Transformer   vit                0       0.731312
                                                                                 1       0.722702
                                                                                 2       0.657148
                                

In [105]:
def compute_average_offset_from_best(sub_df):
    """Compute each feature extractor's average test-AUROC offset from the best one.

    Every combination that picks one available seed per feature extractor is enumerated
    (a Cartesian product over seeds). In each combination, a feature extractor's offset is
    ``max(aurocs) - its_auroc``, so the winner of that combination scores 0. The offsets are
    accumulated in one ``RunningStats`` per feature extractor.

    Args:
        sub_df: long-format frame with ``seed``, ``feature_extractor`` and ``test_auroc``
            columns for a single (target, model) pair. Each (seed, feature_extractor) pair
            must appear at most once; ``DataFrame.pivot`` raises otherwise.

    Returns:
        dict mapping each feature extractor to its ``RunningStats.compute()`` result.
    """
    pivot_data = sub_df.pivot(index="seed", columns="feature_extractor", values="test_auroc")
    feature_extractors = pivot_data.columns.values
    # A seed missing for some feature extractor shows up as NaN in the pivot. This is possible
    # because load_aurocs keeps only finished runs. Drop NaNs per feature extractor: a NaN
    # makes max() NaN and would poison the stats of *every* extractor in that combination.
    aurocs_per_fe = [pivot_data[fe].dropna().to_numpy() for fe in feature_extractors]
    n_combinations = int(np.prod([len(values) for values in aurocs_per_fe]))
    stats_by_feature_extractor = {fe: RunningStats() for fe in feature_extractors}

    for auroc_values in tqdm(product(*aurocs_per_fe), total=n_combinations):
        aurocs = np.asarray(auroc_values)
        diffs = aurocs.max() - aurocs
        for fe, diff in zip(feature_extractors, diffs):
            stats_by_feature_extractor[fe].update(diff)

    return {fe: stats.compute() for fe, stats in stats_by_feature_extractor.items()}


# NOTE(review): the cache key is built from the *global* `augmentation`, not from the
# `test_aurocs` argument. Calling this with a series filtered for another augmentation
# than the global presumably returns a stale cached table — confirm cached_df semantics.
@cached_df(lambda *args, **kwargs: f"offsets_from_best_{augmentation}")
def compute_results_table(test_aurocs: pd.Series):
    """Compute average offsets from the best feature extractor for each (target, model) pair.

    As a side effect, prints one line per pair listing the feature extractors ranked by their
    (mean, std) offset.

    Args:
        test_aurocs: per-run test AUROCs whose index includes ``target``, ``model``,
            ``feature_extractor`` and ``seed`` levels.

    Returns:
        DataFrame indexed by (``model``, ``feature_extractor``) with column levels
        (``target``, ``stats``). ``stats`` holds the fields of ``RunningStats.compute()``
        (``mean`` and ``std`` in the rendered output).
    """
    long_df = test_aurocs.reset_index()

    results = {}
    for target, model in long_df[["target", "model"]].drop_duplicates().itertuples(index=False):
        sub_data = long_df[(long_df["target"] == target) & (long_df["model"] == model)]
        stats_by_fe = compute_average_offset_from_best(sub_data)
        results[(target, model)] = stats_by_fe
        ranked = sorted(stats_by_fe.items(), key=lambda item: item[1])
        print(
            f"{target:10s} {model:20s}:",
            ", ".join(f"{fe}={mean:.2f}+-{std:.2f}" for fe, (mean, std) in ranked),
        )

    # Flatten to long records and pivot once. This avoids the double DataFrame.stack(), which
    # emits a pandas FutureWarning about the upcoming stack implementation.
    records = [
        {"model": model, "feature_extractor": fe, "target": target, **stats._asdict()}
        for (target, model), stats_by_fe in results.items()
        for fe, stats in stats_by_fe.items()
    ]
    table = (
        pd.DataFrame(records)
        .pivot_table(index=["model", "feature_extractor"], columns="target")
        .rename_axis(columns=["stats", "target"])
        .reorder_levels(["target", "stats"], axis="columns")
        .sort_index(axis="columns")
    )
    return table


r = compute_results_table(test_aurocs)
r

  0%|          | 0/1953125 [00:00<?, ?it/s]

100%|██████████| 1953125/1953125 [00:58<00:00, 33421.73it/s]


CDH1       AttentionMIL        : dino_p16=0.01+-0.01, ctranspath=0.02+-0.03, swav=0.03+-0.03, retccl=0.03+-0.02, bt=0.04+-0.01, owkin=0.04+-0.01, vit=0.12+-0.03, resnet50=0.18+-0.04, swin=0.20+-0.04


100%|██████████| 1953125/1953125 [00:57<00:00, 33914.96it/s]


CDH1       MeanAveragePooling  : ctranspath=0.01+-0.01, owkin=0.01+-0.01, dino_p16=0.02+-0.01, retccl=0.02+-0.01, swav=0.02+-0.01, bt=0.02+-0.01, vit=0.09+-0.01, resnet50=0.11+-0.04, swin=0.12+-0.02


100%|██████████| 1953125/1953125 [00:56<00:00, 34436.59it/s]


CDH1       Transformer         : swav=0.01+-0.01, bt=0.04+-0.03, dino_p16=0.05+-0.04, ctranspath=0.05+-0.04, retccl=0.06+-0.04, owkin=0.09+-0.03, vit=0.13+-0.03, resnet50=0.18+-0.10, swin=0.19+-0.05


100%|██████████| 1953125/1953125 [00:56<00:00, 34283.89it/s]


PIK3CA     AttentionMIL        : resnet50=0.01+-0.01, dino_p16=0.02+-0.02, retccl=0.02+-0.02, ctranspath=0.04+-0.01, vit=0.04+-0.04, swin=0.05+-0.03, swav=0.05+-0.03, owkin=0.05+-0.02, bt=0.13+-0.02


100%|██████████| 1953125/1953125 [00:55<00:00, 34945.58it/s]


PIK3CA     MeanAveragePooling  : ctranspath=0.01+-0.01, swin=0.01+-0.01, resnet50=0.03+-0.01, retccl=0.03+-0.01, vit=0.03+-0.01, dino_p16=0.03+-0.02, bt=0.05+-0.01, owkin=0.08+-0.04, swav=0.11+-0.01


100%|██████████| 1953125/1953125 [00:56<00:00, 34715.57it/s]


PIK3CA     Transformer         : dino_p16=0.03+-0.03, resnet50=0.04+-0.05, ctranspath=0.06+-0.04, retccl=0.06+-0.03, owkin=0.07+-0.04, bt=0.07+-0.04, vit=0.08+-0.03, swav=0.08+-0.05, swin=0.09+-0.04


100%|██████████| 1953125/1953125 [00:56<00:00, 34523.42it/s]


TP53       AttentionMIL        : bt=0.02+-0.02, ctranspath=0.03+-0.02, retccl=0.03+-0.04, swav=0.04+-0.02, dino_p16=0.04+-0.03, owkin=0.09+-0.04, vit=0.12+-0.02, resnet50=0.17+-0.05, swin=0.24+-0.03


100%|██████████| 1953125/1953125 [00:57<00:00, 33702.29it/s]


TP53       MeanAveragePooling  : ctranspath=0.02+-0.02, bt=0.03+-0.03, swav=0.04+-0.02, dino_p16=0.04+-0.04, retccl=0.05+-0.03, vit=0.07+-0.03, resnet50=0.07+-0.03, swin=0.11+-0.04, owkin=0.11+-0.03


100%|██████████| 1953125/1953125 [00:57<00:00, 33916.58it/s]


TP53       Transformer         : ctranspath=0.02+-0.02, dino_p16=0.02+-0.02, swav=0.03+-0.02, retccl=0.04+-0.04, bt=0.05+-0.03, owkin=0.07+-0.03, vit=0.18+-0.03, swin=0.20+-0.05, resnet50=0.23+-0.04


100%|██████████| 1953125/1953125 [01:01<00:00, 31932.59it/s]


subtype    AttentionMIL        : ctranspath=0.00+-0.00, dino_p16=0.04+-0.03, swav=0.05+-0.03, swin=0.08+-0.03, retccl=0.09+-0.04, vit=0.11+-0.04, owkin=0.11+-0.03, bt=0.11+-0.04, resnet50=0.17+-0.04


100%|██████████| 1953125/1953125 [00:59<00:00, 32910.00it/s]


subtype    MeanAveragePooling  : ctranspath=0.00+-0.00, retccl=0.01+-0.00, vit=0.03+-0.00, dino_p16=0.05+-0.01, bt=0.06+-0.03, swav=0.06+-0.00, swin=0.06+-0.01, resnet50=0.08+-0.00, owkin=0.11+-0.01


100%|██████████| 1953125/1953125 [00:56<00:00, 34306.67it/s]


subtype    Transformer         : ctranspath=0.01+-0.02, bt=0.03+-0.03, dino_p16=0.04+-0.03, retccl=0.06+-0.03, swav=0.06+-0.03, owkin=0.08+-0.03, swin=0.11+-0.04, vit=0.11+-0.03, resnet50=0.16+-0.05


  r = r.stack().stack().apply(pd.Series)


Unnamed: 0_level_0,target,CDH1,CDH1,PIK3CA,PIK3CA,TP53,TP53,subtype,subtype
Unnamed: 0_level_1,stats,mean,std,mean,std,mean,std,mean,std
model,feature_extractor,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
AttentionMIL,bt,0.040381,0.014883,0.131119,0.023538,0.023166,0.023604,0.113091,0.040531
AttentionMIL,ctranspath,0.021078,0.029036,0.036662,0.013538,0.028048,0.023851,0.00135,0.003717
AttentionMIL,dino_p16,0.009999,0.011961,0.016805,0.015374,0.042863,0.026436,0.044301,0.032044
AttentionMIL,owkin,0.044607,0.013955,0.052319,0.018709,0.088795,0.035863,0.109806,0.030453
AttentionMIL,resnet50,0.183557,0.036119,0.013376,0.011095,0.167975,0.054579,0.17176,0.035164
AttentionMIL,retccl,0.034899,0.020714,0.023462,0.018311,0.032341,0.038449,0.085243,0.039041
AttentionMIL,swav,0.026332,0.027243,0.051405,0.026158,0.036241,0.017693,0.054317,0.028419
AttentionMIL,swin,0.196406,0.042599,0.04509,0.029057,0.237223,0.028644,0.07968,0.030228
AttentionMIL,vit,0.11531,0.031582,0.041605,0.040126,0.122184,0.017218,0.105321,0.039103
MeanAveragePooling,bt,0.021193,0.008257,0.050165,0.013826,0.031026,0.027218,0.056,0.034604


In [106]:
# Compute the overall mean and std across targets. Only the per-target columns are used, so
# re-running this cell overwrites the "average" columns instead of folding them into the result.
per_target = r.loc[:, r.columns.get_level_values("target") != "average"]
target_means = per_target.xs("mean", axis="columns", level="stats")
target_stds = per_target.xs("std", axis="columns", level="stats")
n_targets = target_means.shape[1]
overall_mean = target_means.sum(axis="columns").divide(n_targets)
# Treating per-target estimates as independent: std of their average = sqrt(sum(std_i^2)) / n.
overall_std = target_stds.pow(2).sum(axis="columns").pow(0.5).divide(n_targets)
r["average", "mean"] = overall_mean
r["average", "std"] = overall_std
r

Unnamed: 0_level_0,target,CDH1,CDH1,PIK3CA,PIK3CA,TP53,TP53,subtype,subtype,average,average
Unnamed: 0_level_1,stats,mean,std,mean,std,mean,std,mean,std,mean,std
model,feature_extractor,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
AttentionMIL,bt,0.040381,0.014883,0.131119,0.023538,0.023166,0.023604,0.113091,0.040531,0.076939,0.013637
AttentionMIL,ctranspath,0.021078,0.029036,0.036662,0.013538,0.028048,0.023851,0.00135,0.003717,0.021784,0.010028
AttentionMIL,dino_p16,0.009999,0.011961,0.016805,0.015374,0.042863,0.026436,0.044301,0.032044,0.028492,0.01147
AttentionMIL,owkin,0.044607,0.013955,0.052319,0.018709,0.088795,0.035863,0.109806,0.030453,0.073882,0.01313
AttentionMIL,resnet50,0.183557,0.036119,0.013376,0.011095,0.167975,0.054579,0.17176,0.035164,0.134167,0.01878
AttentionMIL,retccl,0.034899,0.020714,0.023462,0.018311,0.032341,0.038449,0.085243,0.039041,0.043986,0.015344
AttentionMIL,swav,0.026332,0.027243,0.051405,0.026158,0.036241,0.017693,0.054317,0.028419,0.042074,0.012617
AttentionMIL,swin,0.196406,0.042599,0.04509,0.029057,0.237223,0.028644,0.07968,0.030228,0.1396,0.01657
AttentionMIL,vit,0.11531,0.031582,0.041605,0.040126,0.122184,0.017218,0.105321,0.039103,0.096105,0.016645
MeanAveragePooling,bt,0.021193,0.008257,0.050165,0.013826,0.031026,0.027218,0.056,0.034604,0.039596,0.01172


In [107]:
def results_to_latex(r: pd.DataFrame, goal: Literal["min", "max"] = "min") -> str:
    # Format for appearance
    r = r.unstack("model")
    means = r.stack("stats").query("stats == 'mean'").droplevel("stats")
    stds = r.stack("stats").query("stats == 'std'").droplevel("stats")
    formatted = means.map(lambda x: f"{x:.2f}") + " \\pm " + stds.map(lambda x: f"{x:.2f}")

    # Make best model bold
    best_mask = means == getattr(means, goal)(axis="index")
    formatted[best_mask] = "\\mathbf{" + formatted[best_mask] + "}"
    formatted = "$" + formatted + "$"
    formatted = formatted.stack("model")
    formatted = formatted.swaplevel("feature_extractor", "model").sort_index()

    # Check if we have average column
    targets = formatted.columns.get_level_values("target").unique()
    has_average = "average" in targets
    if has_average:
        targets = targets.drop("average")

    # Rename
    formatted = formatted.rename(RENAME_MODELS, level="model")
    formatted = formatted.rename(RENAME_FEATURE_EXTRACTORS, level="feature_extractor")
    formatted.index.names = ["Model", "Feature extractor"]
    formatted.columns.names = ["Target"]
    return formatted.to_latex(
        column_format="l" * len(formatted.index.levels) + "|" + "c" * len(targets) + ("|c" if has_average else ""),
        escape=False,
    )


# Offsets from the best feature extractor: lower is better, so the default goal="min" applies.
print(results_to_latex(r))

\begin{tabular}{ll|cccc|c}
\toprule
 & Target & CDH1 & PIK3CA & TP53 & subtype & average \\
Model & Feature extractor &  &  &  &  &  \\
\midrule
\multirow[t]{9}{*}{AttMIL} & BT & $0.04 \pm 0.01$ & $0.13 \pm 0.02$ & $\mathbf{0.02 \pm 0.02}$ & $0.11 \pm 0.04$ & $0.08 \pm 0.01$ \\
 & CTransPath & $0.02 \pm 0.03$ & $0.04 \pm 0.01$ & $0.03 \pm 0.02$ & $\mathbf{0.00 \pm 0.00}$ & $\mathbf{0.02 \pm 0.01}$ \\
 & Dino & $\mathbf{0.01 \pm 0.01}$ & $0.02 \pm 0.02$ & $0.04 \pm 0.03$ & $0.04 \pm 0.03$ & $0.03 \pm 0.01$ \\
 & Owkin & $0.04 \pm 0.01$ & $0.05 \pm 0.02$ & $0.09 \pm 0.04$ & $0.11 \pm 0.03$ & $0.07 \pm 0.01$ \\
 & ResNet-50 & $0.18 \pm 0.04$ & $\mathbf{0.01 \pm 0.01}$ & $0.17 \pm 0.05$ & $0.17 \pm 0.04$ & $0.13 \pm 0.02$ \\
 & RetCCL & $0.03 \pm 0.02$ & $0.02 \pm 0.02$ & $0.03 \pm 0.04$ & $0.09 \pm 0.04$ & $0.04 \pm 0.02$ \\
 & SwAV & $0.03 \pm 0.03$ & $0.05 \pm 0.03$ & $0.04 \pm 0.02$ & $0.05 \pm 0.03$ & $0.04 \pm 0.01$ \\
 & Swin & $0.20 \pm 0.04$ & $0.05 \pm 0.03$ & $0.24 \pm 0.03$ & $

In [108]:
# Aggregate test AUROCs over seeds (pooling train/test datasets per target). The result has
# mean/std per (target, model, feature extractor), laid out as
# (model, feature_extractor) x (target, stats) for results_to_latex.
t = (
    test_aurocs.groupby(level=["target", "model", "feature_extractor"])
    .agg(["mean", "std"])
    .rename_axis(columns="stats")
    .unstack("target")
    .swaplevel("stats", "target", axis="columns")
    .sort_index(axis="columns")
)
# Raw AUROCs: higher is better, so the best feature extractor is the maximum.
print(results_to_latex(t, goal="max"))

\begin{tabular}{ll|cccc}
\toprule
 & Target & CDH1 & PIK3CA & TP53 & subtype \\
Model & Feature extractor &  &  &  &  \\
\midrule
\multirow[t]{9}{*}{AttMIL} & BT & $0.76 \pm 0.01$ & $0.51 \pm 0.02$ & $\mathbf{0.79 \pm 0.03}$ & $0.70 \pm 0.03$ \\
 & CTransPath & $0.78 \pm 0.04$ & $0.60 \pm 0.01$ & $0.78 \pm 0.02$ & $\mathbf{0.81 \pm 0.03}$ \\
 & Dino & $\mathbf{0.79 \pm 0.01}$ & $0.62 \pm 0.02$ & $0.77 \pm 0.03$ & $0.77 \pm 0.02$ \\
 & Owkin & $0.76 \pm 0.01$ & $0.59 \pm 0.02$ & $0.72 \pm 0.04$ & $0.70 \pm 0.02$ \\
 & ResNet-50 & $0.62 \pm 0.04$ & $\mathbf{0.63 \pm 0.01}$ & $0.64 \pm 0.06$ & $0.64 \pm 0.03$ \\
 & RetCCL & $0.77 \pm 0.02$ & $0.62 \pm 0.02$ & $0.78 \pm 0.05$ & $0.73 \pm 0.03$ \\
 & SwAV & $0.78 \pm 0.03$ & $0.59 \pm 0.03$ & $0.77 \pm 0.01$ & $0.76 \pm 0.01$ \\
 & Swin & $0.61 \pm 0.05$ & $0.60 \pm 0.03$ & $0.57 \pm 0.03$ & $0.73 \pm 0.02$ \\
 & ViT & $0.69 \pm 0.03$ & $0.60 \pm 0.05$ & $0.69 \pm 0.01$ & $0.71 \pm 0.03$ \\
\cline{1-6}
\multirow[t]{9}{*}{Mean pool} & BT & $

In [109]:
# Show which augmentation test_aurocs (and the tables derived from it above) was filtered for.
augmentation

'Macenko_patchwise'