In [1]:
import wandb
import pandas as pd
from loguru import logger
from tqdm import tqdm
from itertools import product
import numpy as np
from pathlib import Path
import functools
import json
from typing import Literal

from histaug.utils import RunningStats, cached_df

# Client used by `load_aurocs` to list the runs of the wandb project.
api = wandb.Api()

# Columns that together identify a single run; `load_aurocs` uses them as the DataFrame index.
INDEX_COLS = ["target", "train_dataset", "test_dataset", "model", "feature_extractor", "augmentations", "seed"]

# Display names for the `model` index level (applied in `results_to_latex`).
RENAME_MODELS = {
    "AttentionMIL": "AttMIL",
    "MeanAveragePooling": "Mean pool",
    "Transformer": "Transformer",
}
# Display names for the `feature_extractor` index level (applied in `results_to_latex`).
RENAME_FEATURE_EXTRACTORS = {
    "bt": "Lunit-BT",
    "swav": "Lunit-SwAV",
    "dino_p16": "Lunit-DINO",
    "ctranspath": "CTransPath",
    "owkin": "Phikon",
    "resnet50": "ResNet-50",
    "retccl": "RetCCL",
    "swin": "Swin",
    "vit": "ViT",
}

# NOTE(review): hard-coded absolute path; it is not referenced in the visible part of this
# notebook — confirm whether it is still needed (e.g. by `cached_df`) before changing it.
RESULTS_DIR = Path("/app/results")

# Collect results from `wandb`

In [2]:
def filter_runs(runs, filters: dict):
    """Return the runs whose attributes equal every value given in `filters`.

    Args:
        runs: Iterable of run objects.
        filters: Mapping of attribute name to required value. An attribute that is
            missing on a run is treated as ``None``.

    Returns:
        A new list with the matching runs, in their original order.
    """

    def matches(run) -> bool:
        for attribute, expected in filters.items():
            if not getattr(run, attribute, None) == expected:
                return False
        return True

    return list(filter(matches, runs))


def summarize_run(run):
    """Condense one wandb run into a flat record of its configuration and AUROCs.

    The run history is reduced to one row per epoch (first non-null value of each metric).
    Train and validation AUROC are read from the epoch with the lowest ``val/loss``.
    Test AUROC is the ``best`` entry of the run summary when the summary has the key,
    otherwise the maximum test AUROC found in the per-epoch history.

    Args:
        run: Object exposing ``history()``, ``config`` and ``summary`` like a wandb run.

    Returns:
        dict with the run's identifying fields (target, datasets, model, feature extractor,
        augmentations, seed) and ``train_auroc`` / ``val_auroc`` / ``test_auroc``.
    """
    per_epoch = run.history().groupby("epoch").first()
    has_epoch = ~per_epoch.index.isna()
    # Model selection: the epoch with the smallest validation loss.
    best_epoch = per_epoch[has_epoch].sort_values("val/loss", ascending=True).iloc[0]

    column = run.config["dataset"]["targets"][0]["column"]
    test_key = f"test/{column}/auroc"
    if test_key in run.summary:
        test_auroc = run.summary[test_key]["best"]
    else:
        test_auroc = per_epoch[test_key].max()

    config = run.config
    return {
        "target": column,
        "train_dataset": config["dataset"]["name"],
        "test_dataset": config["test"]["dataset"]["name"],
        "model": config["model"]["_target_"].split(".")[-1],
        "feature_extractor": config["settings"]["feature_extractor"],
        "augmentations": config["dataset"]["augmentations"]["name"],
        "seed": config["seed"],
        "train_auroc": best_epoch[f"train/{column}/auroc"],
        "val_auroc": best_epoch[f"val/{column}/auroc"],
        "test_auroc": test_auroc,
    }


@cached_df(lambda: "aurocs")
def load_aurocs():
    """Load the AUROC summary of every finished run of the ``histaug`` wandb project.

    Returns:
        DataFrame indexed by ``INDEX_COLS`` (sorted) with the columns ``train_auroc``,
        ``val_auroc`` and ``test_auroc``. When several runs share the same index (the same
        configuration and seed was run more than once) only the most recently created
        run is kept.

    NOTE(review): the function is wrapped in ``cached_df``; presumably a previously cached
    table is returned without re-querying wandb, so a stale cache has to be cleared for
    changes here to take effect — confirm against ``cached_df``.
    """
    logger.info("Loading runs")
    # Oldest first, so that "last" below means "most recently created".
    runs = list(api.runs("histaug", order="+created_at", per_page=1000))
    runs = filter_runs(runs, {"state": "finished"})
    summaries = [summarize_run(run) for run in tqdm(runs, desc="Loading run data")]
    summaries = [summary for summary in summaries if summary is not None]
    df = pd.DataFrame(summaries).set_index(INDEX_COLS)
    # DataFrame.drop_duplicates() ignores the index and would compare only the AUROC
    # columns: distinct runs with identical metrics would be dropped while genuine
    # re-runs of the same configuration would be kept. De-duplicate on the index instead.
    df = df[~df.index.duplicated(keep="last")]
    return df.sort_index()


# Every configuration is expected to have exactly one run for each seed 0..N_SEEDS-1.
N_SEEDS = 5

df_all = load_aurocs()
# DataFrame.drop_duplicates() ignores the index and compares only the AUROC columns, so it
# cannot remove repeated runs of the same configuration/seed (and may drop distinct runs
# that happen to share identical metrics). De-duplicate on the index instead; a leftover
# duplicate seed would otherwise make the completeness check below discard the whole group.
df_all = df_all[~df_all.index.duplicated(keep="last")]

# Index levels that identify a configuration (everything except the seed).
config_levels = [name for name in df_all.index.names if name != "seed"]

# Keep only configurations for which all seeds are present.
df = (
    df_all.reset_index()
    .groupby(config_levels)
    .filter(lambda group: sorted(group.seed.values) == list(range(N_SEEDS)))
    .set_index(df_all.index.names)
    .sort_index()
)
print("Removed runs:", len(df_all) - len(df))
# Number of removed runs per incomplete configuration.
print(
    df_all.index.difference(df.index)
    .to_frame(index=False)
    .groupby(config_levels)
    .seed.count()
)

[32m2023-11-07 12:09:18.281[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_aurocs[0m:[36m42[0m - [1mLoading runs[0m
Loading run data:  58%|█████▊    | 2734/4725 [22:57<15:55,  2.08it/s]

In [3]:
# Keep only the prediction targets listed here; rows of `df` for any other target are dropped.
df = df.query("target in ['BRAF', 'CDH1', 'KRAS', 'MSI', 'PIK3CA', 'SMAD4', 'TP53', 'subtype']")

# Compare original vs Macenko (vs Macenko_slidewise)

In [4]:
# Index levels that remain once the seed has been aggregated away.
SEED_FREE_LEVELS = ["target", "train_dataset", "test_dataset", "model", "feature_extractor"]


def _mean_std_over_seeds(values: pd.Series, name: str) -> pd.DataFrame:
    """Return mean and std of `values` across seeds, per configuration, under column `name`."""
    flat = values.rename(name).reset_index()
    flat = flat.drop(columns="seed")
    return flat.groupby(SEED_FREE_LEVELS).agg(["mean", "std"])


# Per-run test AUROC with patch-wise Macenko normalisation and without any augmentation.
# Use 'Macenko_slidewise' in the first query to compare slide-wise normalisation instead.
macenko = df.query("augmentations == 'Macenko_patchwise'")["test_auroc"].droplevel("augmentations")
orig = df.query("augmentations == 'none'")["test_auroc"].droplevel("augmentations")

# Mean diff across seeds, next to the un-augmented baseline it is measured against.
d = _mean_std_over_seeds(macenko - orig, "test_auroc_diff")
o = _mean_std_over_seeds(orig, "test_auroc_orig")
d = pd.concat([d, o], axis=1)
d.query("model == 'Transformer'")

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,test_auroc_diff,test_auroc_diff,test_auroc_orig,test_auroc_orig
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,mean,std,mean,std
target,train_dataset,test_dataset,model,feature_extractor,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
BRAF,tcga_crc_BRAF,cptac_crc_BRAF,Transformer,bt,-0.012434,0.032997,0.629920,0.041557
BRAF,tcga_crc_BRAF,cptac_crc_BRAF,Transformer,ctranspath,0.013032,0.034159,0.678457,0.072021
BRAF,tcga_crc_BRAF,cptac_crc_BRAF,Transformer,dino_p16,-0.060306,0.112090,0.737766,0.052131
BRAF,tcga_crc_BRAF,cptac_crc_BRAF,Transformer,owkin,0.032447,0.099511,0.653457,0.073348
BRAF,tcga_crc_BRAF,cptac_crc_BRAF,Transformer,resnet50,0.036968,0.062584,0.569282,0.058041
...,...,...,...,...,...,...,...,...
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,resnet50,-0.034491,0.038394,0.690357,0.035691
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,retccl,0.031092,0.039379,0.727910,0.034958
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,swav,0.013397,0.072871,0.737178,0.045536
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,swin,-0.029599,0.030260,0.736527,0.036581


# What is the best feature extractor?

In [5]:
# Augmentation setting analysed in the cells below.
# Supported values seen in this notebook: "none", "Macenko_patchwise", "Macenko_slidewise".
augmentation = "Macenko_patchwise"

# Per-run test AUROC for the selected augmentation (the `augmentations` level is dropped).
test_aurocs = df.query(f"augmentations == '{augmentation}'").droplevel("augmentations")["test_auroc"]
test_aurocs

target   train_dataset      test_dataset        model         feature_extractor  seed
BRAF     tcga_crc_BRAF      cptac_crc_BRAF      AttentionMIL  bt                 0       0.423537
                                                                                 1       0.655585
                                                                                 2       0.455452
                                                                                 3       0.450798
                                                                                 4       0.405585
                                                                                           ...   
subtype  tcga_brca_subtype  cptac_brca_subtype  Transformer   vit                0       0.731312
                                                                                 1       0.722702
                                                                                 2       0.657148
                                

In [6]:
def compute_norm_diff_auroc(sub_df):
    """Compute, per feature extractor, the statistics of its AUROC gap to the best extractor.

    Every way of choosing one seed per feature extractor is enumerated. For each choice the
    gap between the best AUROC in that choice and each extractor's AUROC is fed into that
    extractor's ``RunningStats``.

    Args:
        sub_df: Frame with the columns ``seed``, ``feature_extractor`` and ``test_auroc``;
            each (seed, feature_extractor) pair must occur at most once (required by ``pivot``).

    Returns:
        dict mapping feature extractor name to the result of ``RunningStats.compute()``.
    """
    per_seed = sub_df.pivot(index="seed", columns="feature_extractor", values="test_auroc")
    extractor_names = per_seed.columns.values
    n_seeds = len(per_seed.index.values)
    # One sequence of per-seed AUROCs per extractor; the Cartesian product yields every
    # combination of "one seed per extractor", in extractor-column order.
    seed_choices = product(*per_seed.values.T)
    total = int(n_seeds ** len(extractor_names))
    trackers = {name: RunningStats() for name in extractor_names}

    for choice in tqdm(seed_choices, total=total):
        aurocs = np.array(choice)
        gaps = aurocs.max() - aurocs
        for name, gap in zip(extractor_names, gaps):
            trackers[name].update(gap)

    return {name: tracker.compute() for name, tracker in trackers.items()}


@cached_df(lambda *args, **kwargs: f"norm_diff_auroc_{augmentation}")
def compute_results_table(test_aurocs: pd.Series) -> pd.DataFrame:
    """Compute average offsets from best for each (target, model) pair.

    For every distinct (target, model) pair in `test_aurocs`, `compute_norm_diff_auroc` is
    run on the rows of that pair and one progress line is printed listing the feature
    extractors in ascending order of their stats.

    Args:
        test_aurocs: Series of test AUROCs whose index provides (at least) the levels
            ``target``, ``model``, ``feature_extractor`` and ``seed``.

    Returns:
        DataFrame indexed by (model, feature_extractor) with two column levels
        (target, stats), sorted by column.

    NOTE(review): the cache name passed to ``cached_df`` depends only on the global
    ``augmentation``, not on the contents of `test_aurocs`. Presumably a table cached for
    a different subset of targets would be returned unchanged — confirm against
    ``cached_df`` and clear the cache when the input selection changes.
    """
    d = test_aurocs.reset_index()

    results = {}
    unique_pairs = d[["target", "model"]].drop_duplicates().values

    for target, model in unique_pairs:
        sub_data = d[(d["target"] == target) & (d["model"] == model)]
        results[(target, model)] = compute_norm_diff_auroc(sub_data)
        # Progress line: extractors sorted ascending by their stats value, which is
        # unpacked as (mean, std) below.
        print(
            f"{target:10s} {model:20s}:",
            ", ".join(
                f"{k}={mean:.2f}+-{std:.2f}"
                for (k, (mean, std)) in sorted(results[(target, model)].items(), key=lambda x: x[1], reverse=False)
            ),
        )

    # Rows: feature extractors; columns: (target, model). Each cell is converted to a dict;
    # assumes the stats objects expose `_asdict()` (namedtuple-like) — TODO confirm.
    r = pd.DataFrame(results).map(lambda x: x._asdict())
    r.index.name = "feature_extractor"
    r.columns.names = ["target", "model"]
    # Move both column levels into the index, then expand each dict into one column per field.
    r = r.stack().stack().apply(pd.Series)
    r.columns.names = ["stats"]
    # Back to wide format: one (target, stats) column pair per target.
    r = (
        r.pivot_table(index=["model", "feature_extractor"], columns="target")
        .reorder_levels([1, 0], axis=1)
        .sort_index(axis=1)
    )
    return r


# Build (or load from cache) the table for the currently selected `test_aurocs` and display it.
r = compute_results_table(test_aurocs)
r

  0%|          | 0/1953125 [00:00<?, ?it/s]

100%|██████████| 1953125/1953125 [01:26<00:00, 22661.97it/s]


BRAF       AttentionMIL        : dino_p16=0.02+-0.04, ctranspath=0.06+-0.04, swav=0.07+-0.05, owkin=0.09+-0.05, swin=0.11+-0.07, vit=0.13+-0.06, retccl=0.14+-0.04, resnet50=0.17+-0.07, bt=0.28+-0.10


100%|██████████| 1953125/1953125 [01:23<00:00, 23320.14it/s]


BRAF       MeanAveragePooling  : dino_p16=0.00+-0.00, swav=0.04+-0.03, ctranspath=0.07+-0.04, owkin=0.09+-0.03, retccl=0.14+-0.05, swin=0.15+-0.04, vit=0.16+-0.06, bt=0.18+-0.03, resnet50=0.21+-0.04


100%|██████████| 1953125/1953125 [01:24<00:00, 23237.77it/s]


BRAF       Transformer         : swav=0.07+-0.10, ctranspath=0.07+-0.07, owkin=0.08+-0.04, dino_p16=0.09+-0.07, retccl=0.15+-0.08, bt=0.15+-0.04, resnet50=0.16+-0.05, vit=0.21+-0.04, swin=0.22+-0.07


100%|██████████| 1953125/1953125 [01:21<00:00, 23858.51it/s]


CDH1       AttentionMIL        : dino_p16=0.01+-0.01, ctranspath=0.02+-0.03, swav=0.03+-0.03, retccl=0.03+-0.02, bt=0.04+-0.01, owkin=0.04+-0.01, vit=0.12+-0.03, resnet50=0.18+-0.04, swin=0.20+-0.04


100%|██████████| 1953125/1953125 [01:26<00:00, 22458.95it/s]


CDH1       MeanAveragePooling  : ctranspath=0.01+-0.01, owkin=0.01+-0.01, dino_p16=0.02+-0.01, retccl=0.02+-0.01, swav=0.02+-0.01, bt=0.02+-0.01, vit=0.09+-0.01, resnet50=0.11+-0.04, swin=0.12+-0.02


100%|██████████| 1953125/1953125 [01:24<00:00, 23100.51it/s]


CDH1       Transformer         : swav=0.01+-0.01, bt=0.04+-0.03, dino_p16=0.05+-0.04, ctranspath=0.05+-0.04, retccl=0.06+-0.04, owkin=0.09+-0.03, vit=0.13+-0.03, resnet50=0.18+-0.10, swin=0.19+-0.05


100%|██████████| 1953125/1953125 [01:20<00:00, 24237.03it/s]


KRAS       AttentionMIL        : retccl=0.01+-0.02, owkin=0.05+-0.05, ctranspath=0.06+-0.03, dino_p16=0.07+-0.04, bt=0.07+-0.06, vit=0.09+-0.04, swin=0.11+-0.04, swav=0.12+-0.08, resnet50=0.12+-0.03


100%|██████████| 1953125/1953125 [01:20<00:00, 24138.38it/s]


KRAS       MeanAveragePooling  : bt=0.02+-0.03, resnet50=0.02+-0.02, swin=0.04+-0.03, retccl=0.05+-0.05, owkin=0.05+-0.03, dino_p16=0.07+-0.04, ctranspath=0.09+-0.05, vit=0.10+-0.05, swav=0.15+-0.02


100%|██████████| 1953125/1953125 [01:25<00:00, 22812.43it/s]


KRAS       Transformer         : owkin=0.03+-0.05, ctranspath=0.05+-0.04, bt=0.07+-0.05, retccl=0.07+-0.05, dino_p16=0.07+-0.03, vit=0.11+-0.06, resnet50=0.11+-0.04, swin=0.13+-0.03, swav=0.13+-0.04


100%|██████████| 1953125/1953125 [01:22<00:00, 23795.20it/s]


MSI        AttentionMIL        : dino_p16=0.01+-0.02, owkin=0.03+-0.03, retccl=0.07+-0.03, ctranspath=0.07+-0.06, swav=0.12+-0.04, swin=0.14+-0.03, vit=0.14+-0.04, resnet50=0.18+-0.03, bt=0.33+-0.07


100%|██████████| 1953125/1953125 [01:20<00:00, 24198.06it/s]


MSI        MeanAveragePooling  : owkin=0.02+-0.03, ctranspath=0.03+-0.03, dino_p16=0.04+-0.04, retccl=0.04+-0.03, bt=0.11+-0.04, swin=0.11+-0.03, swav=0.12+-0.03, resnet50=0.15+-0.05, vit=0.17+-0.04


100%|██████████| 1953125/1953125 [01:29<00:00, 21861.22it/s]


MSI        Transformer         : dino_p16=0.02+-0.03, ctranspath=0.04+-0.04, owkin=0.04+-0.02, bt=0.05+-0.03, swav=0.08+-0.04, retccl=0.08+-0.04, swin=0.19+-0.04, vit=0.21+-0.07, resnet50=0.21+-0.06


100%|██████████| 1953125/1953125 [01:28<00:00, 22110.03it/s]


PIK3CA     AttentionMIL        : resnet50=0.01+-0.01, dino_p16=0.02+-0.02, retccl=0.02+-0.02, ctranspath=0.04+-0.01, vit=0.04+-0.04, swin=0.05+-0.03, swav=0.05+-0.03, owkin=0.05+-0.02, bt=0.13+-0.02


100%|██████████| 1953125/1953125 [01:17<00:00, 25154.76it/s]


PIK3CA     MeanAveragePooling  : ctranspath=0.01+-0.01, swin=0.01+-0.01, resnet50=0.03+-0.01, retccl=0.03+-0.01, vit=0.03+-0.01, dino_p16=0.03+-0.02, bt=0.05+-0.01, owkin=0.08+-0.04, swav=0.11+-0.01


100%|██████████| 1953125/1953125 [01:25<00:00, 22826.97it/s]


PIK3CA     Transformer         : dino_p16=0.03+-0.03, resnet50=0.04+-0.05, ctranspath=0.06+-0.04, retccl=0.06+-0.03, owkin=0.07+-0.04, bt=0.07+-0.04, vit=0.08+-0.03, swav=0.08+-0.05, swin=0.09+-0.04


100%|██████████| 1953125/1953125 [01:23<00:00, 23340.52it/s]


SMAD4      AttentionMIL        : vit=0.02+-0.02, dino_p16=0.05+-0.03, owkin=0.06+-0.06, ctranspath=0.06+-0.03, retccl=0.07+-0.03, swav=0.11+-0.05, resnet50=0.15+-0.07, bt=0.15+-0.08, swin=0.20+-0.03


100%|██████████| 1953125/1953125 [01:24<00:00, 22984.74it/s]


SMAD4      MeanAveragePooling  : bt=0.00+-0.01, vit=0.02+-0.02, retccl=0.03+-0.01, dino_p16=0.03+-0.02, swin=0.04+-0.01, ctranspath=0.05+-0.02, owkin=0.07+-0.06, swav=0.09+-0.02, resnet50=0.11+-0.10


100%|██████████| 1953125/1953125 [01:25<00:00, 22718.91it/s]


SMAD4      Transformer         : owkin=0.03+-0.03, dino_p16=0.06+-0.06, retccl=0.06+-0.04, bt=0.06+-0.06, ctranspath=0.08+-0.06, vit=0.09+-0.04, swin=0.09+-0.06, swav=0.11+-0.02, resnet50=0.30+-0.11


100%|██████████| 1953125/1953125 [01:27<00:00, 22292.99it/s]


TP53       AttentionMIL        : bt=0.02+-0.02, ctranspath=0.03+-0.02, retccl=0.03+-0.04, swav=0.04+-0.02, dino_p16=0.04+-0.03, owkin=0.09+-0.04, vit=0.12+-0.02, resnet50=0.17+-0.05, swin=0.24+-0.03


100%|██████████| 1953125/1953125 [01:29<00:00, 21736.08it/s]


TP53       MeanAveragePooling  : ctranspath=0.02+-0.02, bt=0.03+-0.03, swav=0.04+-0.02, dino_p16=0.04+-0.04, retccl=0.05+-0.03, vit=0.07+-0.03, resnet50=0.07+-0.03, swin=0.11+-0.04, owkin=0.11+-0.03


100%|██████████| 1953125/1953125 [01:22<00:00, 23806.77it/s]


TP53       Transformer         : ctranspath=0.02+-0.02, dino_p16=0.02+-0.02, swav=0.03+-0.02, retccl=0.04+-0.04, bt=0.05+-0.03, owkin=0.07+-0.03, vit=0.18+-0.03, swin=0.20+-0.05, resnet50=0.23+-0.04


100%|██████████| 1953125/1953125 [01:23<00:00, 23474.11it/s]


subtype    AttentionMIL        : ctranspath=0.00+-0.00, dino_p16=0.04+-0.03, swav=0.05+-0.03, swin=0.08+-0.03, retccl=0.09+-0.04, vit=0.11+-0.04, owkin=0.11+-0.03, bt=0.11+-0.04, resnet50=0.17+-0.04


100%|██████████| 1953125/1953125 [01:23<00:00, 23420.04it/s]


subtype    MeanAveragePooling  : ctranspath=0.00+-0.00, retccl=0.01+-0.00, vit=0.03+-0.00, dino_p16=0.05+-0.01, bt=0.06+-0.03, swav=0.06+-0.00, swin=0.06+-0.01, resnet50=0.08+-0.00, owkin=0.11+-0.01


100%|██████████| 1953125/1953125 [01:24<00:00, 23209.08it/s]
  r = r.stack().stack().apply(pd.Series)


subtype    Transformer         : ctranspath=0.01+-0.02, bt=0.03+-0.03, dino_p16=0.04+-0.03, retccl=0.06+-0.03, swav=0.06+-0.03, owkin=0.08+-0.03, swin=0.11+-0.04, vit=0.11+-0.03, resnet50=0.16+-0.05


Unnamed: 0_level_0,target,BRAF,BRAF,CDH1,CDH1,KRAS,KRAS,MSI,MSI,PIK3CA,PIK3CA,SMAD4,SMAD4,TP53,TP53,subtype,subtype
Unnamed: 0_level_1,stats,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
model,feature_extractor,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2
AttentionMIL,bt,0.278917,0.096484,0.040381,0.014883,0.07306,0.056746,0.332843,0.073579,0.131119,0.023538,0.149196,0.083864,0.023166,0.023604,0.113091,0.040531
AttentionMIL,ctranspath,0.058704,0.036685,0.021078,0.029036,0.061292,0.030044,0.071732,0.056848,0.036662,0.013538,0.063807,0.025953,0.028048,0.023851,0.00135,0.003717
AttentionMIL,dino_p16,0.022135,0.037561,0.009999,0.011961,0.065046,0.03918,0.010466,0.016273,0.016805,0.015374,0.048753,0.034943,0.042863,0.026436,0.044301,0.032044
AttentionMIL,owkin,0.09381,0.045302,0.044607,0.013955,0.050181,0.047491,0.028625,0.030363,0.052319,0.018709,0.060518,0.0556,0.088795,0.035863,0.109806,0.030453
AttentionMIL,resnet50,0.168677,0.070385,0.183557,0.036119,0.121765,0.031194,0.180785,0.028316,0.013376,0.011095,0.146413,0.065345,0.167975,0.054579,0.17176,0.035164
AttentionMIL,retccl,0.141018,0.036274,0.034899,0.020714,0.012231,0.022048,0.070188,0.034005,0.023462,0.018311,0.072662,0.03285,0.032341,0.038449,0.085243,0.039041
AttentionMIL,swav,0.072068,0.05284,0.026332,0.027243,0.118124,0.076239,0.123481,0.039511,0.051405,0.026158,0.109221,0.047082,0.036241,0.017693,0.054317,0.028419
AttentionMIL,swin,0.107773,0.074119,0.196406,0.042599,0.105474,0.036597,0.137472,0.026095,0.04509,0.029057,0.20005,0.028271,0.237223,0.028644,0.07968,0.030228
AttentionMIL,vit,0.12619,0.063989,0.11531,0.031582,0.091134,0.039121,0.142719,0.038258,0.041605,0.040126,0.017634,0.024247,0.122184,0.017218,0.105321,0.039103
MeanAveragePooling,bt,0.181955,0.025563,0.021193,0.008257,0.019133,0.027382,0.108877,0.037839,0.050165,0.013826,0.004534,0.007526,0.031026,0.027218,0.056,0.034604


In [18]:
# Compute the overall mean and std across targets for every (model, feature extractor) row.
# A previously added "average" column is excluded from the inputs, so re-running this cell
# does not count the average itself as an additional target.
is_target_column = r.columns.get_level_values("target") != "average"
target_means = r.loc[:, is_target_column].xs("mean", axis="columns", level="stats")
target_stds = r.loc[:, is_target_column].xs("std", axis="columns", level="stats")
n_targets = target_means.shape[1]

# min_count=1 keeps a row with no values at all as NaN instead of reporting an average of 0.
overall_mean = target_means.sum(axis="columns", min_count=1).divide(n_targets)
# Std of the average of the per-target values: sqrt(sum of squared stds) / n_targets.
overall_std = target_stds.pow(2).sum(axis="columns", min_count=1).pow(0.5).divide(n_targets)
r["average", "mean"] = overall_mean
r["average", "std"] = overall_std
r

Unnamed: 0_level_0,target,BRAF,BRAF,CDH1,CDH1,KRAS,KRAS,MSI,MSI,PIK3CA,PIK3CA,SMAD4,SMAD4,TP53,TP53,subtype,subtype,average,average
Unnamed: 0_level_1,stats,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
model,feature_extractor,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2
AttentionMIL,bt,0.345998,0.133092,0.064523,0.035145,0.032743,0.036754,0.27714,0.133136,0.129801,0.03683,0.251079,0.031551,0.016759,0.014204,0.129456,0.032143,0.155937,0.025499
AttentionMIL,ctranspath,0.06036,0.027964,0.006499,0.012455,0.059225,0.032613,0.083982,0.030418,0.044115,0.031709,0.059175,0.025216,0.012438,0.012595,0.0,0.0,0.040724,0.008593
AttentionMIL,dino_p16,0.01528,0.024514,0.034255,0.034877,0.061027,0.041159,0.001214,0.004766,0.02163,0.029526,0.020339,0.022028,0.027814,0.024693,0.079472,0.027152,0.032629,0.009873
AttentionMIL,owkin,0.069402,0.060869,0.089423,0.022736,0.065269,0.042018,0.064743,0.043849,0.092258,0.033278,0.170054,0.075789,0.094846,0.030961,0.092403,0.021614,0.0923,0.015904
AttentionMIL,resnet50,0.226583,0.068262,0.087824,0.038983,0.112228,0.033884,0.220247,0.042516,0.013601,0.0193,0.212875,0.094667,0.10938,0.03076,0.148987,0.025604,0.141466,0.017709
AttentionMIL,retccl,0.141078,0.030904,0.037567,0.01697,0.028707,0.024362,0.081512,0.034558,0.049258,0.030994,0.056645,0.025801,0.035474,0.028321,0.065062,0.031917,0.061913,0.010061
AttentionMIL,swav,0.070399,0.065243,0.064295,0.030173,0.128707,0.059317,0.102654,0.033832,0.125201,0.060857,0.136342,0.081458,0.058061,0.024875,0.062774,0.022717,0.093554,0.018245
AttentionMIL,swin,0.142407,0.065605,0.168407,0.025006,0.140156,0.042411,0.176626,0.043431,0.065573,0.036333,0.158606,0.048224,0.277842,0.022148,0.070815,0.023356,0.150054,0.014413
AttentionMIL,vit,0.179442,0.074738,0.110668,0.019652,0.029571,0.041542,0.146893,0.034201,0.074144,0.031368,0.007309,0.012596,0.152618,0.031691,0.081363,0.044916,0.097751,0.01427
MeanAveragePooling,bt,0.207247,0.090669,0.035422,0.006928,0.020973,0.023927,0.080158,0.019227,0.072272,0.017518,0.02544,0.016856,0.057807,0.037245,0.061474,0.027821,0.070099,0.013672


In [7]:
def results_to_latex(r: pd.DataFrame, goal: Literal["min", "max"] = "min") -> str:
    """Render a results table as a LaTeX tabular with ``mean \\pm std`` cells.

    Args:
        r: Frame indexed by (model, feature_extractor) with column levels (target, stats),
            where ``stats`` holds ``mean`` and ``std``. A target named ``average`` is
            given its own column group separated by a vertical rule.
        goal: Whether the smallest (``"min"``) or largest (``"max"``) mean is the best one.
            The best feature extractor per (target, model) column is set in bold.

    Returns:
        The LaTeX source of the table, with models and feature extractors renamed for display.
    """
    # Format for appearance: one "mean \pm std" string per (feature extractor, target, model)
    wide = r.unstack("model")
    by_stat = wide.stack("stats")
    mean_table = by_stat.query("stats == 'mean'").droplevel("stats")
    std_table = by_stat.query("stats == 'std'").droplevel("stats")
    mean_text = mean_table.map(lambda value: f"{value:.2f}")
    std_text = std_table.map(lambda value: f"{value:.2f}")
    cells = mean_text + " \\pm " + std_text

    # Make best model bold (best is determined per column, across feature extractors)
    column_best = getattr(mean_table, goal)(axis="index")
    is_best = mean_table == column_best
    cells[is_best] = "\\mathbf{" + cells[is_best] + "}"
    cells = "$" + cells + "$"
    cells = cells.stack("model")
    cells = cells.swaplevel("feature_extractor", "model").sort_index()

    # Check if we have average column; it is not counted among the regular targets
    target_names = cells.columns.get_level_values("target").unique()
    has_average = "average" in target_names
    if has_average:
        target_names = target_names.drop("average")

    # Rename to display names
    cells = cells.rename(RENAME_MODELS, level="model")
    cells = cells.rename(RENAME_FEATURE_EXTRACTORS, level="feature_extractor")
    cells.index.names = ["Model", "Feature extractor"]
    cells.columns.names = ["Target"]

    # Column spec: left-aligned index columns | centred targets [| centred average]
    index_spec = "l" * len(cells.index.levels)
    target_spec = "c" * len(target_names)
    average_spec = "|c" if has_average else ""
    return cells.to_latex(
        column_format=index_spec + "|" + target_spec + average_spec,
        escape=False,
    )


print(results_to_latex(r))

\begin{tabular}{ll|cccc|c}
\toprule
 & Target & CDH1 & PIK3CA & TP53 & subtype & average \\
Model & Feature extractor &  &  &  &  &  \\
\midrule
\multirow[t]{9}{*}{AttMIL} & Lunit-BT & $0.04 \pm 0.01$ & $0.13 \pm 0.02$ & $\mathbf{0.02 \pm 0.02}$ & $0.11 \pm 0.04$ & $0.08 \pm 0.01$ \\
 & CTransPath & $0.02 \pm 0.03$ & $0.04 \pm 0.01$ & $0.03 \pm 0.02$ & $\mathbf{0.00 \pm 0.00}$ & $\mathbf{0.02 \pm 0.01}$ \\
 & Lunit-DINO & $\mathbf{0.01 \pm 0.01}$ & $0.02 \pm 0.02$ & $0.04 \pm 0.03$ & $0.04 \pm 0.03$ & $0.03 \pm 0.01$ \\
 & Phikon & $0.04 \pm 0.01$ & $0.05 \pm 0.02$ & $0.09 \pm 0.04$ & $0.11 \pm 0.03$ & $0.07 \pm 0.01$ \\
 & ResNet-50 & $0.18 \pm 0.04$ & $\mathbf{0.01 \pm 0.01}$ & $0.17 \pm 0.05$ & $0.17 \pm 0.04$ & $0.13 \pm 0.02$ \\
 & RetCCL & $0.03 \pm 0.02$ & $0.02 \pm 0.02$ & $0.03 \pm 0.04$ & $0.09 \pm 0.04$ & $0.04 \pm 0.02$ \\
 & Lunit-SwAV & $0.03 \pm 0.03$ & $0.05 \pm 0.03$ & $0.04 \pm 0.02$ & $0.05 \pm 0.03$ & $0.04 \pm 0.01$ \\
 & Swin & $0.20 \pm 0.04$ & $0.05 \pm 0.03$ & 

In [8]:
# Aggregate test AUROCs over seeds: mean and std per (target, model, feature extractor).
t = (
    test_aurocs.droplevel(["train_dataset", "test_dataset", "seed"])
    .reset_index()
    .groupby(["target", "model", "feature_extractor"])
    .agg(["mean", "std"])
    .droplevel(0, axis="columns")
)
t.columns.names = ["stats"]
# Reshape to the layout `results_to_latex` expects: rows (model, feature_extractor),
# columns (target, stats).
t = t.unstack("target").swaplevel("stats", "target", axis="columns").sort_index(axis="columns")
# For raw AUROCs the highest mean is the best one.
print(results_to_latex(t, goal="max"))

\begin{tabular}{ll|cccccc}
\toprule
 & Target & CDH1 & KRAS & MSI & PIK3CA & TP53 & subtype \\
Model & Feature extractor &  &  &  &  &  &  \\
\midrule
\multirow[t]{9}{*}{AttMIL} & Lunit-BT & $0.76 \pm 0.01$ & $nan \pm nan$ & $0.57 \pm 0.08$ & $0.51 \pm 0.02$ & $\mathbf{0.79 \pm 0.03}$ & $0.70 \pm 0.03$ \\
 & CTransPath & $0.78 \pm 0.04$ & $0.61 \pm 0.03$ & $0.83 \pm 0.06$ & $0.60 \pm 0.01$ & $0.78 \pm 0.02$ & $\mathbf{0.81 \pm 0.03}$ \\
 & Lunit-DINO & $\mathbf{0.79 \pm 0.01}$ & $nan \pm nan$ & $\mathbf{0.89 \pm 0.03}$ & $0.62 \pm 0.02$ & $0.77 \pm 0.03$ & $0.77 \pm 0.02$ \\
 & Phikon & $0.76 \pm 0.01$ & $0.64 \pm 0.03$ & $0.87 \pm 0.03$ & $0.59 \pm 0.02$ & $0.72 \pm 0.04$ & $0.70 \pm 0.02$ \\
 & ResNet-50 & $0.62 \pm 0.04$ & $0.54 \pm nan$ & $0.72 \pm 0.02$ & $\mathbf{0.63 \pm 0.01}$ & $0.64 \pm 0.06$ & $0.64 \pm 0.03$ \\
 & RetCCL & $0.77 \pm 0.02$ & $\mathbf{0.65 \pm nan}$ & $0.83 \pm 0.03$ & $0.62 \pm 0.02$ & $0.78 \pm 0.05$ & $0.73 \pm 0.03$ \\
 & Lunit-SwAV & $0.78 \pm 0.03$ & $n

In [9]:
# Display the augmentation setting currently selected for the analysis.
augmentation

'Macenko_patchwise'