In [1]:
import wandb
import pandas as pd
from loguru import logger
from tqdm import tqdm
from itertools import product
import numpy as np
from pathlib import Path
import functools
import json
from typing import Literal

from histaug.utils import RunningStats, cached_df

# Client for the wandb public API; `load_aurocs` uses it to list the project's runs.
api = wandb.Api()

# Fields that identify a single training run; `load_aurocs` uses them as the DataFrame index.
INDEX_COLS = ["target", "train_dataset", "test_dataset", "model", "feature_extractor", "augmentations", "seed"]

# Display names for the model class names, applied to the "model" index level in `results_to_latex`.
RENAME_MODELS = {
    "AttentionMIL": "AttMIL",
    "MeanAveragePooling": "Mean pool",
    "Transformer": "Transformer",
}
# Display names for the feature-extractor identifiers, applied to the "feature_extractor"
# index level in `results_to_latex`.
RENAME_FEATURE_EXTRACTORS = {
    "bt": "Lunit-BT",
    "swav": "Lunit-SwAV",
    "dino_p16": "Lunit-DINO",
    "ctranspath": "CTransPath",
    "owkin": "Phikon",
    "resnet50": "ResNet-50",
    "retccl": "RetCCL",
    "swin": "Swin",
    "vit": "ViT",
}

# NOTE(review): not referenced in the visible cells — presumably the directory where results
# (or cached tables) are stored; confirm against `histaug.utils`.
RESULTS_DIR = Path("/app/results")

# Collect results from `wandb`

In [8]:
def filter_runs(runs, filters: dict):
    """Return the runs whose attributes equal every value given in ``filters``.

    ``filters`` maps attribute names to required values. A run that lacks one
    of the attributes is treated as having ``None`` for it.
    """
    kept = []
    for run in runs:
        for attr_name, expected in filters.items():
            if not (getattr(run, attr_name, None) == expected):
                break
        else:
            # No filter rejected this run.
            kept.append(run)
    return kept


def summarize_run(run):
    """Condense one wandb run into a flat record of its configuration and AUROCs.

    Train and validation AUROC are taken from the epoch with the lowest
    ``val/loss``. The test AUROC is the ``best`` entry of the run summary when
    the summary has one, and otherwise the maximum found in the run history.
    """
    # One row per epoch, holding the first non-null value logged for each metric.
    per_epoch = run.history().groupby("epoch").first()
    with_epoch = per_epoch.loc[~per_epoch.index.isna()]
    best_epoch = with_epoch.sort_values("val/loss", ascending=True).iloc[0]

    config = run.config
    target_column = config["dataset"]["targets"][0]["column"]
    test_key = f"test/{target_column}/auroc"

    if test_key in run.summary:
        test_auroc = run.summary[test_key]["best"]
    else:
        test_auroc = per_epoch[test_key].max()

    record = {
        "target": target_column,
        "train_dataset": config["dataset"]["name"],
        "test_dataset": config["test"]["dataset"]["name"],
        # Keep only the class name of the fully qualified model target.
        "model": config["model"]["_target_"].split(".")[-1],
        "feature_extractor": config["settings"]["feature_extractor"],
        "augmentations": config["dataset"]["augmentations"]["name"],
        "seed": config["seed"],
        "train_auroc": best_epoch[f"train/{target_column}/auroc"],
        "val_auroc": best_epoch[f"val/{target_column}/auroc"],
        "test_auroc": test_auroc,
    }
    return record


@cached_df(lambda: "aurocs")
def load_aurocs():
    """Collect the AUROC summary of every finished run in the ``histaug`` wandb project.

    Returns:
        A DataFrame indexed by ``INDEX_COLS`` (sorted), with the columns
        ``train_auroc``, ``val_auroc`` and ``test_auroc``. Exact duplicate
        records (same identifying fields AND same metrics) are removed.

    NOTE(review): the function is wrapped by ``cached_df`` with the constant key
    "aurocs"; presumably the result is stored and reused across calls — confirm
    in ``histaug.utils``.
    """
    logger.info("Loading runs")
    runs = list(api.runs("histaug", order="+created_at", per_page=1000))
    runs = filter_runs(runs, {"state": "finished"})
    records = [summarize_run(run) for run in tqdm(runs, desc="Loading run data")]
    # Defensive: tolerate a summarizer that signals "skip this run" with None.
    records = [record for record in records if record is not None]
    # De-duplicate BEFORE moving the identifying fields into the index:
    # DataFrame.drop_duplicates ignores the index, so doing it afterwards would
    # also discard distinct runs that merely share identical AUROC values.
    df = pd.DataFrame(records).drop_duplicates().set_index(INDEX_COLS).sort_index()
    return df


df_all = load_aurocs()
# Remove exact duplicates, comparing the index fields as well as the metric columns.
# (DataFrame.drop_duplicates ignores the index and would therefore also drop distinct
# runs that happen to share identical AUROC values.)
df_all = df_all[~df_all.reset_index().duplicated().to_numpy()]

# Everything that identifies an experimental configuration apart from the seed.
config_cols = [name for name in df_all.index.names if name != "seed"]

# Keep only configurations for which exactly the seeds 0..4 are present.
df = (
    df_all.reset_index()
    .groupby(config_cols)
    .filter(lambda x: sorted(x.seed.values) == list(range(5)))
    .set_index(df_all.index.names)
    .sort_index()
)
print("Removed runs:", len(df_all) - len(df))
# Number of discarded runs per incomplete configuration.
print(
    df_all.index.difference(df.index)
    .to_frame(index=False)
    .groupby(config_cols)
    .seed.count()
)

Removed runs: 0
Series([], Name: seed, dtype: int64)


In [9]:
# Restrict the analysis to the prediction targets of interest.
df = df.loc[
    df.index.get_level_values("target").isin(
        ["BRAF", "CDH1", "KRAS", "MSI", "PIK3CA", "SMAD4", "TP53", "subtype"]
    )
]

# Compare original vs Macenko (vs Macenko_slidewise)

In [10]:
# Fields identifying one configuration once the seed has been aggregated away.
group_keys = ["target", "train_dataset", "test_dataset", "model", "feature_extractor"]

# Test AUROC per run, split by augmentation setting.
# (Swap "Macenko_patchwise" for "Macenko_slidewise" to compare the slide-wise variant.)
augmentation_labels = df.index.get_level_values("augmentations")
macenko = df.loc[augmentation_labels == "Macenko_patchwise", "test_auroc"].droplevel("augmentations")
orig = df.loc[augmentation_labels == "none", "test_auroc"].droplevel("augmentations")

# Per-seed difference (Macenko minus original), then mean/std across seeds.
seed_diff = macenko.sub(orig).rename("test_auroc_diff")
diff_stats = seed_diff.reset_index().drop(columns="seed").groupby(group_keys).agg(["mean", "std"])

# Mean/std across seeds of the un-augmented test AUROC, for reference.
o = orig.rename("test_auroc_orig").reset_index().drop(columns="seed").groupby(group_keys).agg(["mean", "std"])

d = pd.concat([diff_stats, o], axis="columns")
d[d.index.get_level_values("model") == "Transformer"]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,test_auroc_diff,test_auroc_diff,test_auroc_orig,test_auroc_orig
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,mean,std,mean,std
target,train_dataset,test_dataset,model,feature_extractor,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
BRAF,tcga_crc_BRAF,cptac_crc_BRAF,Transformer,bt,-0.012434,0.032997,0.629920,0.041557
BRAF,tcga_crc_BRAF,cptac_crc_BRAF,Transformer,ctranspath,0.013032,0.034159,0.678457,0.072021
BRAF,tcga_crc_BRAF,cptac_crc_BRAF,Transformer,dino_p16,-0.060306,0.112090,0.737766,0.052131
BRAF,tcga_crc_BRAF,cptac_crc_BRAF,Transformer,owkin,0.032447,0.099511,0.653457,0.073348
BRAF,tcga_crc_BRAF,cptac_crc_BRAF,Transformer,resnet50,0.036968,0.062584,0.569282,0.058041
...,...,...,...,...,...,...,...,...
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,resnet50,-0.034491,0.038394,0.690357,0.035691
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,retccl,0.031092,0.039379,0.727910,0.034958
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,swav,0.013397,0.072871,0.737178,0.045536
subtype,tcga_brca_subtype,cptac_brca_subtype,Transformer,swin,-0.029599,0.030260,0.736527,0.036581


# What is the best feature extractor?

In [14]:
# Augmentation setting to analyse. Alternatives: "Macenko_patchwise", "Macenko_slidewise".
augmentation = "none"

# Test AUROC of every run trained with the selected augmentation setting.
# (Add a condition on the "target" level, e.g. target == "MSI", to look at a single target.)
test_aurocs = df.loc[
    df.index.get_level_values("augmentations") == augmentation, "test_auroc"
].droplevel("augmentations")
test_aurocs

target   train_dataset      test_dataset        model         feature_extractor  seed
BRAF     tcga_crc_BRAF      cptac_crc_BRAF      AttentionMIL  bt                 0       0.359707
                                                                                 1       0.684840
                                                                                 2       0.331782
                                                                                 3       0.362367
                                                                                 4       0.397606
                                                                                           ...   
subtype  tcga_brca_subtype  cptac_brca_subtype  Transformer   vit                0       0.738184
                                                                                 1       0.738881
                                                                                 2       0.699713
                                

In [15]:
def compute_norm_diff_auroc(sub_df):
    """Accumulate, per feature extractor, its test-AUROC gap to the best extractor.

    ``sub_df`` must provide the columns ``seed``, ``feature_extractor`` and
    ``test_auroc``. Every way of picking one seed per feature extractor is
    enumerated; for each pick, the gap ``max(aurocs) - auroc`` of every
    extractor is fed into that extractor's ``RunningStats``.

    Returns:
        A dict mapping each feature extractor to ``RunningStats.compute()``.
    """
    per_seed = sub_df.pivot(index="seed", columns="feature_extractor", values="test_auroc")
    extractor_names = per_seed.columns.values
    n_seeds = len(per_seed.index.values)
    # Size of the full cartesian product; only used for the progress bar.
    n_picks = int(n_seeds ** len(extractor_names))
    running = {name: RunningStats() for name in extractor_names}

    # One iterable of AUROCs per extractor (a pivot column); product() picks one from each.
    for picked in tqdm(product(*per_seed.values.T), total=n_picks):
        aurocs = np.array(picked)
        gaps = aurocs.max() - aurocs
        for name, gap in zip(extractor_names, gaps):
            running[name].update(gap)

    return {name: stats.compute() for name, stats in running.items()}


# NOTE(review): the key function ignores its arguments and reads the module-level
# `augmentation` when it is called. Presumably `cached_df` reuses a stored table under
# this key; if so, passing a different `test_aurocs` without changing `augmentation`
# would return a stale table — confirm in `histaug.utils`.
@cached_df(lambda *args, **kwargs: f"norm_diff_auroc_{augmentation}")
def compute_results_table(test_aurocs: pd.Series):
    """Compute average offsets from best for each (target, model) pair.

    Args:
        test_aurocs: Series of test AUROCs; after ``reset_index`` it must expose the
            columns ``target``, ``model``, ``feature_extractor``, ``seed`` and
            ``test_auroc``.

    Returns:
        DataFrame indexed by (model, feature_extractor) with two column levels,
        (target, stats), holding the statistics returned by
        ``compute_norm_diff_auroc`` for every (target, model) pair.

    Side effects:
        Prints one line per (target, model) pair listing the feature extractors in
        ascending order of their statistics.
    """
    d = test_aurocs.reset_index()

    results = {}
    unique_pairs = d[["target", "model"]].drop_duplicates().values

    for target, model in unique_pairs:
        sub_data = d[(d["target"] == target) & (d["model"] == model)]
        results[(target, model)] = compute_norm_diff_auroc(sub_data)
        # Progress report: extractors sorted by their (mean, std) result, smallest first.
        print(
            f"{target:10s} {model:20s}:",
            ", ".join(
                f"{k}={mean:.2f}+-{std:.2f}"
                for (k, (mean, std)) in sorted(results[(target, model)].items(), key=lambda x: x[1], reverse=False)
            ),
        )

    # Rows: feature extractor; columns: (target, model); cells: dict of the statistics.
    r = pd.DataFrame(results).map(lambda x: x._asdict())
    r.index.name = "feature_extractor"
    r.columns.names = ["target", "model"]
    # Move both column levels into the index, then expand each dict into one column per statistic.
    r = r.stack().stack().apply(pd.Series)
    r.columns.names = ["stats"]
    # Reshape to (model, feature_extractor) rows and (target, stats) columns.
    r = (
        r.pivot_table(index=["model", "feature_extractor"], columns="target")
        .reorder_levels([1, 0], axis=1)
        .sort_index(axis=1)
    )
    return r


# Offsets-from-best table for `test_aurocs`; the call goes through the `cached_df`
# wrapper on `compute_results_table`.
r = compute_results_table(test_aurocs)
r

  0%|          | 1631/1953125 [00:00<01:59, 16298.94it/s]

100%|██████████| 1953125/1953125 [01:23<00:00, 23422.25it/s]

PIK3CA     AttentionMIL        : resnet50=0.01+-0.02, dino_p16=0.02+-0.03, ctranspath=0.04+-0.03, retccl=0.05+-0.03, swin=0.07+-0.04, vit=0.07+-0.03, owkin=0.09+-0.03, swav=0.13+-0.06, bt=0.13+-0.04



  r = r.stack().stack().apply(pd.Series)


Unnamed: 0_level_0,target,PIK3CA,PIK3CA
Unnamed: 0_level_1,stats,mean,std
model,feature_extractor,Unnamed: 2_level_2,Unnamed: 3_level_2
AttentionMIL,bt,0.129801,0.03683
AttentionMIL,ctranspath,0.044115,0.031709
AttentionMIL,dino_p16,0.02163,0.029526
AttentionMIL,owkin,0.092258,0.033278
AttentionMIL,resnet50,0.013601,0.0193
AttentionMIL,retccl,0.049258,0.030994
AttentionMIL,swav,0.125201,0.060857
AttentionMIL,swin,0.065573,0.036333
AttentionMIL,vit,0.074144,0.031368


In [16]:
# Compute overall mean and std (across targets).
# An "average" column left over from a previous run of this cell is excluded first, so that
# re-running the cell does not count the old average as an additional target.
per_target = r.drop(columns="average", level="target", errors="ignore")
target_means = per_target.xs("mean", axis="columns", level="stats")
target_stds = per_target.xs("std", axis="columns", level="stats")
n_targets = target_means.shape[1]
# min_count=1 keeps NaN (instead of 0) for rows that have no value for any target.
overall_mean = target_means.sum(axis="columns", min_count=1).divide(n_targets)
# Std of the across-target mean: sqrt(sum_i std_i^2) / n_targets.
overall_std = target_stds.pow(2).sum(axis="columns", min_count=1).pow(0.5).divide(n_targets)
r["average", "mean"] = overall_mean
r["average", "std"] = overall_std
r

Unnamed: 0_level_0,target,PIK3CA,PIK3CA,average,average
Unnamed: 0_level_1,stats,mean,std,mean,std
model,feature_extractor,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
AttentionMIL,bt,0.129801,0.03683,0.129801,0.03683
AttentionMIL,ctranspath,0.044115,0.031709,0.044115,0.031709
AttentionMIL,dino_p16,0.02163,0.029526,0.02163,0.029526
AttentionMIL,owkin,0.092258,0.033278,0.092258,0.033278
AttentionMIL,resnet50,0.013601,0.0193,0.013601,0.0193
AttentionMIL,retccl,0.049258,0.030994,0.049258,0.030994
AttentionMIL,swav,0.125201,0.060857,0.125201,0.060857
AttentionMIL,swin,0.065573,0.036333,0.065573,0.036333
AttentionMIL,vit,0.074144,0.031368,0.074144,0.031368


In [7]:
def results_to_latex(r: pd.DataFrame, goal: Literal["min", "max"] = "min") -> str:
    """Render a results table as a LaTeX ``tabular``.

    Args:
        r: DataFrame with a ``model`` and a ``feature_extractor`` index level and
            column levels named ``target`` and ``stats`` (``mean`` / ``std``).
        goal: Which mean counts as best. Within every (target, model) column the
            entries whose mean equals the column's minimum ("min") or maximum
            ("max") are set in bold.

    Returns:
        The LaTeX source. Cells show mean and std with two decimals in math mode;
        rows are labelled by (Model, Feature extractor) using the display names
        from ``RENAME_MODELS`` / ``RENAME_FEATURE_EXTRACTORS``. A target called
        "average" gets its own column group in the column format.
    """

    def two_decimals(value):
        return f"{value:.2f}"

    # Put the model into the columns so feature extractors are compared per (target, model).
    by_extractor = r.unstack("model").stack("stats")
    means = by_extractor.query("stats == 'mean'").droplevel("stats")
    stds = by_extractor.query("stats == 'std'").droplevel("stats")
    cells = means.map(two_decimals) + " \\pm " + stds.map(two_decimals)

    # Bold every entry that attains the best mean of its column.
    column_best = getattr(means, goal)(axis="index")
    is_best = means == column_best
    cells = cells.mask(is_best, "\\mathbf{" + cells + "}")
    cells = "$" + cells + "$"

    # Back to one row per (model, feature_extractor).
    cells = cells.stack("model").swaplevel("feature_extractor", "model").sort_index()

    # The optional "average" target is separated from the real targets by a rule.
    target_labels = cells.columns.get_level_values("target").unique()
    has_average = "average" in target_labels
    if has_average:
        target_labels = target_labels.drop("average")

    # Human-readable labels.
    cells = cells.rename(RENAME_MODELS, level="model")
    cells = cells.rename(RENAME_FEATURE_EXTRACTORS, level="feature_extractor")
    cells.index.names = ["Model", "Feature extractor"]
    cells.columns.names = ["Target"]

    column_format = "l" * len(cells.index.levels) + "|" + "c" * len(target_labels)
    if has_average:
        column_format += "|c"
    return cells.to_latex(column_format=column_format, escape=False)


# LaTeX table of the offsets from the best extractor; the default goal="min" bolds the smallest offset.
print(results_to_latex(r))

\begin{tabular}{ll|cccc|c}
\toprule
 & Target & CDH1 & PIK3CA & TP53 & subtype & average \\
Model & Feature extractor &  &  &  &  &  \\
\midrule
\multirow[t]{9}{*}{AttMIL} & Lunit-BT & $0.04 \pm 0.01$ & $0.13 \pm 0.02$ & $\mathbf{0.02 \pm 0.02}$ & $0.11 \pm 0.04$ & $0.08 \pm 0.01$ \\
 & CTransPath & $0.02 \pm 0.03$ & $0.04 \pm 0.01$ & $0.03 \pm 0.02$ & $\mathbf{0.00 \pm 0.00}$ & $\mathbf{0.02 \pm 0.01}$ \\
 & Lunit-DINO & $\mathbf{0.01 \pm 0.01}$ & $0.02 \pm 0.02$ & $0.04 \pm 0.03$ & $0.04 \pm 0.03$ & $0.03 \pm 0.01$ \\
 & Phikon & $0.04 \pm 0.01$ & $0.05 \pm 0.02$ & $0.09 \pm 0.04$ & $0.11 \pm 0.03$ & $0.07 \pm 0.01$ \\
 & ResNet-50 & $0.18 \pm 0.04$ & $\mathbf{0.01 \pm 0.01}$ & $0.17 \pm 0.05$ & $0.17 \pm 0.04$ & $0.13 \pm 0.02$ \\
 & RetCCL & $0.03 \pm 0.02$ & $0.02 \pm 0.02$ & $0.03 \pm 0.04$ & $0.09 \pm 0.04$ & $0.04 \pm 0.02$ \\
 & Lunit-SwAV & $0.03 \pm 0.03$ & $0.05 \pm 0.03$ & $0.04 \pm 0.02$ & $0.05 \pm 0.03$ & $0.04 \pm 0.01$ \\
 & Swin & $0.20 \pm 0.04$ & $0.05 \pm 0.03$ & 

In [8]:
# Aggregate test AUROCs over seeds: mean and std per (target, model, feature extractor).
t = (
    test_aurocs.droplevel(["train_dataset", "test_dataset", "seed"])
    .reset_index()
    .groupby(["target", "model", "feature_extractor"])
    .agg(["mean", "std"])
    .droplevel(0, axis="columns")
)
t.columns.names = ["stats"]
# Reshape to the layout `results_to_latex` expects: (model, feature_extractor) rows and
# (target, stats) columns.
t = t.unstack("target").swaplevel("stats", "target", axis="columns").sort_index(axis="columns")
# Higher AUROC is better, so the per-column maximum is set in bold.
print(results_to_latex(t, goal="max"))

\begin{tabular}{ll|cccccc}
\toprule
 & Target & CDH1 & KRAS & MSI & PIK3CA & TP53 & subtype \\
Model & Feature extractor &  &  &  &  &  &  \\
\midrule
\multirow[t]{9}{*}{AttMIL} & Lunit-BT & $0.76 \pm 0.01$ & $nan \pm nan$ & $0.57 \pm 0.08$ & $0.51 \pm 0.02$ & $\mathbf{0.79 \pm 0.03}$ & $0.70 \pm 0.03$ \\
 & CTransPath & $0.78 \pm 0.04$ & $0.61 \pm 0.03$ & $0.83 \pm 0.06$ & $0.60 \pm 0.01$ & $0.78 \pm 0.02$ & $\mathbf{0.81 \pm 0.03}$ \\
 & Lunit-DINO & $\mathbf{0.79 \pm 0.01}$ & $nan \pm nan$ & $\mathbf{0.89 \pm 0.03}$ & $0.62 \pm 0.02$ & $0.77 \pm 0.03$ & $0.77 \pm 0.02$ \\
 & Phikon & $0.76 \pm 0.01$ & $0.64 \pm 0.03$ & $0.87 \pm 0.03$ & $0.59 \pm 0.02$ & $0.72 \pm 0.04$ & $0.70 \pm 0.02$ \\
 & ResNet-50 & $0.62 \pm 0.04$ & $0.54 \pm nan$ & $0.72 \pm 0.02$ & $\mathbf{0.63 \pm 0.01}$ & $0.64 \pm 0.06$ & $0.64 \pm 0.03$ \\
 & RetCCL & $0.77 \pm 0.02$ & $\mathbf{0.65 \pm nan}$ & $0.83 \pm 0.03$ & $0.62 \pm 0.02$ & $0.78 \pm 0.05$ & $0.73 \pm 0.03$ \\
 & Lunit-SwAV & $0.78 \pm 0.03$ & $n

In [9]:
# Show which augmentation setting is currently active in the kernel.
augmentation

'Macenko_patchwise'