In [1]:
import wandb
import pandas as pd
from loguru import logger
from tqdm import tqdm
from itertools import product
import numpy as np
from pathlib import Path
import functools
import json
from typing import Literal
import multiprocessing as mp
from tqdm.contrib.concurrent import process_map

from histaug.utils import RunningStats, cached_df
from histaug.utils.display import (
    RENAME_MODELS,
    RENAME_FEATURE_EXTRACTORS,
    RENAME_TARGETS,
    TARGET_GROUPS,
    RENAME_AUGMENTATIONS,
)

# Client for the Weights & Biases public API; the cells below fetch runs through it.
api = wandb.Api()

# Columns that together identify a single run; they become the index of the results frame.
INDEX_COLS = [
    "target",
    "train_dataset",
    "test_dataset",
    "model",
    "feature_extractor",
    "augmentations",
    "seed",
]

# NOTE(review): absolute path, and not referenced by any cell visible here — confirm it is still needed.
RESULTS_DIR = Path("/app/results")

# Collect results from `wandb`

In [2]:
def filter_runs(runs, filters: dict):
    """Return the runs whose attributes equal every value given in ``filters``.

    Args:
        runs: Iterable of run objects.
        filters: Mapping of attribute name -> required value. An attribute that
            is missing on a run is treated as ``None``.

    Returns:
        A list with the matching runs, in their original order.
    """

    def satisfies_all(run):
        for attribute, expected in filters.items():
            if not (getattr(run, attribute, None) == expected):
                return False
        return True

    return list(filter(satisfies_all, runs))


def summarize_run(run):
    """Condense one run into a flat record of identifiers and AUROC scores.

    The train and validation AUROC are read from the logged history at the
    epoch with the lowest ``val/loss``. The test AUROC is taken from the
    ``"best"`` entry of the run summary when the summary has the test key, and
    otherwise from the maximum of the test AUROC column in the history.

    Args:
        run: Object exposing ``history()``, ``config`` and ``summary``.

    Returns:
        dict with the keys ``target``, ``train_dataset``, ``test_dataset``,
        ``model``, ``feature_extractor``, ``augmentations``, ``seed``,
        ``train_auroc``, ``val_auroc`` and ``test_auroc``.
    """
    # One row per epoch: the first non-null value logged for each column.
    per_epoch = run.history().groupby("epoch").first()
    with_epoch = per_epoch[~per_epoch.index.isna()]
    best_epoch = with_epoch.sort_values("val/loss", ascending=True).iloc[0]

    config = run.config
    target_column = config["dataset"]["targets"][0]["column"]

    test_key = f"test/{target_column}/auroc"
    if test_key in run.summary:
        test_auroc = run.summary[test_key]["best"]
    else:
        test_auroc = per_epoch[test_key].max()

    return {
        "target": target_column,
        "train_dataset": config["dataset"]["name"],
        "test_dataset": config["test"]["dataset"]["name"],
        # Keep only the class name of the dotted `_target_` path.
        "model": config["model"]["_target_"].split(".")[-1],
        "feature_extractor": config["settings"]["feature_extractor"],
        "augmentations": config["dataset"]["augmentations"]["name"],
        "seed": config["seed"],
        "train_auroc": best_epoch[f"train/{target_column}/auroc"],
        "val_auroc": best_epoch[f"val/{target_column}/auroc"],
        "test_auroc": test_auroc,
    }


@cached_df(lambda: "aurocs")
def load_aurocs():
    """Fetch all finished ``histaug`` runs and tabulate their AUROC summaries.

    Each finished run is reduced with ``summarize_run``; the records are
    assembled into a DataFrame indexed by ``INDEX_COLS`` and sorted by index.

    Exact repeats (same identifying columns AND same AUROC values) are removed
    before the identifying columns are moved into the index.
    ``DataFrame.drop_duplicates`` ignores the index, so deduplicating after
    ``set_index`` would also collapse distinct runs (e.g. different seeds) that
    happen to share identical AUROC values.

    NOTE(review): wrapped by ``cached_df`` with the key ``"aurocs"`` —
    presumably the frame is cached under that key; confirm how to invalidate it.

    Returns:
        DataFrame indexed by ``INDEX_COLS`` holding the AUROC columns produced
        by ``summarize_run``.
    """
    logger.info("Loading runs")
    all_runs = list(api.runs("histaug", order="+created_at", per_page=1000))
    finished_runs = filter_runs(all_runs, {"state": "finished"})
    summaries = [summarize_run(run) for run in tqdm(finished_runs, desc="Loading run data")]
    summaries = [summary for summary in summaries if summary is not None]
    # Deduplicate while the identifying columns are still ordinary columns so they take part in the comparison.
    records = pd.DataFrame(summaries).drop_duplicates()
    return records.set_index(INDEX_COLS).sort_index()


# Seeds that every (target, model, feature_extractor, augmentations) group must contain to be kept.
EXPECTED_SEEDS = list(range(5))

aurocs = load_aurocs()
# `DataFrame.drop_duplicates()` ignores the index, so it would also drop distinct runs
# (e.g. other seeds) whose AUROC values coincide. Flag repeats on index + values instead.
is_repeat = aurocs.reset_index().duplicated().to_numpy()
df_all = aurocs[~is_repeat].droplevel(["train_dataset", "test_dataset"])

# Keep only the groups for which exactly the expected seeds are present.
df = (
    df_all.reset_index()
    .groupby(["target", "model", "feature_extractor", "augmentations"])
    .filter(lambda group: sorted(group.seed.values) == EXPECTED_SEEDS)
    .set_index(list(df_all.index.names))
    .sort_index()
)
print("Removed runs:", len(df_all) - len(df))
# Per configuration, how many runs were discarded by the seed-completeness filter.
print(
    df_all.index.difference(df.index)
    .to_frame(index=False)
    .groupby([x for x in df.index.names if x != "seed"])
    .seed.count()
)

Removed runs: 0
Series([], Name: seed, dtype: int64)


In [8]:
def compute_overall_average(df):
    """Append an ``"average"`` target holding the across-target mean and std.

    For every row, the overall mean is the mean of the per-target ``mean``
    values, and the overall std is ``sqrt(sum(std ** 2)) / n``, where ``n`` is
    the number of targets that have a mean in that row.

    Targets that are missing (NaN) in a row are excluded from both the sums
    and from ``n``. Dividing by the total number of targets would bias rows
    with missing targets toward zero. A row with no targets at all gets NaN.

    Args:
        df: DataFrame whose columns are a MultiIndex with the levels
            ``"target"`` and ``"stats"`` (``"mean"`` / ``"std"``). It must not
            already contain an ``"average"`` target.

    Returns:
        A copy of ``df`` with the columns ``("average", "mean")`` and
        ``("average", "std")`` appended. The input frame is left unmodified.
    """
    # Compute overall mean and std (across targets)
    targets = df.columns.get_level_values("target").unique()
    assert "average" not in targets

    means = df.xs("mean", axis="columns", level="stats")
    stds = df.xs("std", axis="columns", level="stats")

    # Per-row count of targets that actually contribute to the sums below.
    n_available = means.notna().sum(axis="columns")
    overall_mean = means.sum(axis="columns").divide(n_available)
    overall_std = stds.pow(2).sum(axis="columns").pow(0.5).divide(n_available)

    # Work on a copy so re-running a cell on the same input does not trip the assert above.
    result = df.copy()
    result["average", "mean"] = overall_mean
    result["average", "std"] = overall_std
    return result


def results_to_latex(r, goal="min"):
    """Render a table of mean/std statistics as a LaTeX ``tabular`` string.

    Every cell is formatted as ``$mean \\pm std$`` with two decimals, and the
    best mean of each column (after ``model`` and ``augmentations`` have been
    moved into the columns) is wrapped in ``\\mathbf{...}``.

    Args:
        r: DataFrame whose index contains the levels ``"model"``,
            ``"augmentations"`` and ``"feature_extractor"`` and whose columns
            contain a ``"stats"`` level with the values ``"mean"`` and
            ``"std"``. NOTE(review): the other column level is presumably
            ``"target"`` — confirm against the callers.
        goal: Name of the DataFrame reduction (e.g. ``"min"`` or ``"max"``)
            called with ``axis="index"`` to decide which mean is the best.

    Returns:
        The string produced by ``DataFrame.to_latex``; rows are indexed by
        (augmentations, model, feature extractor) and columns follow the order
        of ``RENAME_TARGETS``, plus ``"average"`` when it is present.
    """
    # Format for appearance
    r = r.unstack(["model", "augmentations"])
    means = r.stack("stats").query("stats == 'mean'").droplevel("stats")
    stds = r.stack("stats").query("stats == 'std'").droplevel("stats")

    # Remember whether an overall "average" column exists; it gets its own column in the layout.
    has_average = "average" in means.columns

    formatted = means.map(lambda x: f"{x:.2f}") + " \\pm " + stds.map(lambda x: f"{x:.2f}")

    # Bold the best mean within each column, i.e. across the rows that remain in the index
    # after `model` and `augmentations` were unstacked into the columns.
    best_mask = means == getattr(means, goal)(axis="index")
    formatted[best_mask] = "\\mathbf{" + formatted[best_mask] + "}"
    formatted = "$" + formatted + "$"
    formatted = formatted.stack("model")
    formatted = formatted.swaplevel("feature_extractor", "model").sort_index()

    formatted = formatted.stack(["augmentations"])

    # Set index order to augmentations, model, feature_extractor
    # and order the columns like RENAME_TARGETS (with "average" last, if present).
    formatted = formatted.reorder_levels(["augmentations", "model", "feature_extractor"]).reindex(
        [*RENAME_TARGETS.keys(), *(["average"] if has_average else [])], axis=1
    )

    def sort_df_index(df, keys):
        """Sort a dataframe by index level values.

        Args:
            df: Dataframe to sort.
            keys: dict of {level: order} where:
                level: Name of index level to sort by.
                order: List of values in the order you want them to appear in that level
        """

        # Levels listed in `keys` are sorted by the position of each value in the given order;
        # any other level keeps its natural ordering.
        return df.sort_index(key=lambda x: x.map({v: i for i, v in enumerate(keys[x.name])}) if x.name in keys else x)

    def rename_df(df):
        """Apply the display names to the level values and level names of both axes."""
        rename_levels = {
            "feature_extractor": "Feature extractor",
            "model": "Model",
            "augmentations": "Augmentations",
            "target": "Target",
        }
        # Rename the values of every known level, on the index (axis 0) and on the columns (axis 1).
        for axis in [0, 1]:
            d = df.index if axis == 0 else df.columns
            for level in d.names:
                if level == "feature_extractor":
                    df = df.rename(RENAME_FEATURE_EXTRACTORS, level=level, axis=axis)
                elif level == "model":
                    df = df.rename(RENAME_MODELS, level=level, axis=axis)
                elif level == "augmentations":
                    df = df.rename(RENAME_AUGMENTATIONS, level=level, axis=axis)
                elif level == "target":
                    df = df.rename(RENAME_TARGETS, level=level, axis=axis)

        # Finally rename the levels themselves; names without an entry are kept as they are.
        df.index.set_names([rename_levels.get(x, x) for x in df.index.names], inplace=True)
        df.columns.set_names([rename_levels.get(x, x) for x in df.columns.names], inplace=True)
        return df

    # Sort rows in the order in which the keys appear in the rename tables (before renaming).
    formatted = sort_df_index(
        formatted,
        {
            "augmentations": RENAME_AUGMENTATIONS.keys(),
            "model": RENAME_MODELS.keys(),
            "feature_extractor": RENAME_FEATURE_EXTRACTORS.keys(),
        },
    )
    formatted = rename_df(formatted).rename(columns={"average": "Average"})

    # Three left-aligned index columns, then one centered column per target with a vertical
    # rule between the groups of TARGET_GROUPS, and a final column for the average if present.
    col_format = (
        "lll|" + "|".join(("c" * len(group) for group in TARGET_GROUPS.values())) + ("|c" if has_average else "")
    )
    return formatted.to_latex(escape=False, column_format=col_format, multicolumn_format="c")

## Show results

In [9]:
# Mean and standard deviation of the training AUROC per configuration and target,
# aggregated over the index levels that are not grouped on.
train_auroc_stats = (
    df["train_auroc"]
    .groupby(["augmentations", "model", "feature_extractor", "target"])
    .agg(["mean", "std"])
)
train_auroc_stats.columns.name = "stats"

# Targets go into the columns, giving sorted (target, stats) column pairs.
d = (
    train_auroc_stats.unstack("target")
    .reorder_levels(["target", "stats"], axis=1)
    .sort_index(axis=1)
)
# Overall average intentionally disabled for this table:
# d = compute_overall_average(d)
d

Unnamed: 0_level_0,Unnamed: 1_level_0,target,BRAF,BRAF,CDH1,CDH1,KRAS,KRAS,MSI,MSI,PIK3CA,PIK3CA,SMAD4,SMAD4,TP53,TP53,lymph,lymph,subtype,subtype
Unnamed: 0_level_1,Unnamed: 1_level_1,stats,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
augmentations,model,feature_extractor,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
Macenko_patchwise,AttentionMIL,bt,0.564886,0.105031,0.810861,0.015204,0.533168,0.011238,0.516811,0.035431,0.500907,0.027700,0.513397,0.075430,0.742970,0.058297,0.591109,0.134943,0.712033,0.022810
Macenko_patchwise,AttentionMIL,ctranspath,0.765525,0.038265,0.825021,0.015995,0.626643,0.038938,0.861595,0.063567,0.642871,0.039540,0.570395,0.075514,0.832830,0.035437,0.913725,0.036503,0.852554,0.018805
Macenko_patchwise,AttentionMIL,dino_p16,0.748784,0.060397,0.807309,0.021994,0.612474,0.043347,0.859966,0.052028,0.631749,0.069804,0.572466,0.036883,0.823179,0.038200,0.905756,0.027026,0.881699,0.018108
Macenko_patchwise,AttentionMIL,owkin,0.704834,0.061536,0.810115,0.027982,0.632745,0.075821,0.772467,0.085178,0.627692,0.015109,0.531816,0.034404,0.811122,0.044054,0.878750,0.024066,0.869585,0.016115
Macenko_patchwise,AttentionMIL,resnet50,0.661318,0.051140,0.781056,0.031460,0.634187,0.078625,0.732887,0.097856,0.535411,0.032449,0.677793,0.074466,0.709676,0.040066,0.852046,0.062901,0.799108,0.036404
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
simple_rotate,Transformer,resnet50,,,0.755094,0.099878,,,,,0.590978,0.079378,,,0.758024,0.064721,,,0.823907,0.040391
simple_rotate,Transformer,retccl,,,0.722042,0.096937,,,,,0.565743,0.026086,,,0.810190,0.055196,,,0.846774,0.051194
simple_rotate,Transformer,swav,,,0.744211,0.065471,,,,,0.659194,0.057559,,,0.835403,0.023378,,,0.850344,0.020295
simple_rotate,Transformer,swin,,,0.732951,0.082671,,,,,0.590591,0.112111,,,0.727465,0.064490,,,0.766315,0.060721


In [10]:
# Restrict the table to the three augmentation settings shown, then bold the highest AUROC.
shown_augmentations = "augmentations in ['none', 'Macenko_slidewise', 'simple_rotate']"
d = d.query(shown_augmentations)
latex_table = results_to_latex(d, goal="max")
print(latex_table)

\begin{tabular}{lll|cccc|c|cccc}
\toprule
 &  & Target & Subtype & CDH1 & TP53 & PIK3CA & LN status & MSI & KRAS & BRAF & SMAD4 \\
Augmentations & Model & Feature extractor &  &  &  &  &  &  &  &  &  \\
\midrule
\multirow[t]{27}{*}{Original} & \multirow[t]{9}{*}{AttMIL} & Swin & $0.84 \pm 0.04$ & $0.75 \pm 0.06$ & $0.76 \pm 0.06$ & $0.58 \pm 0.03$ & $0.87 \pm 0.08$ & $0.69 \pm 0.12$ & $0.58 \pm 0.05$ & $0.63 \pm 0.10$ & $0.60 \pm 0.01$ \\
 &  & CTransPath & $0.87 \pm 0.01$ & $0.84 \pm 0.02$ & $0.81 \pm 0.02$ & $0.63 \pm 0.03$ & $0.90 \pm 0.04$ & $0.83 \pm 0.13$ & $0.62 \pm 0.04$ & $0.73 \pm 0.09$ & $0.59 \pm 0.05$ \\
 &  & ViT-B & $0.85 \pm 0.02$ & $0.75 \pm 0.02$ & $0.76 \pm 0.05$ & $0.61 \pm 0.05$ & $0.80 \pm 0.07$ & $0.76 \pm 0.08$ & $0.58 \pm 0.04$ & $\mathbf{0.77 \pm 0.08}$ & $0.57 \pm 0.06$ \\
 &  & Phikon & $0.85 \pm 0.04$ & $0.79 \pm 0.11$ & $0.83 \pm 0.06$ & $0.58 \pm 0.08$ & $0.89 \pm 0.05$ & $0.85 \pm 0.07$ & $0.64 \pm 0.06$ & $0.74 \pm 0.05$ & $0.59 \pm 0.15$ \\
 &  & Lunit

# What is the best feature extractor?

In [6]:
def compute_norm_diff_auroc(sub_df, show_progress: bool = False):
    """Accumulate each feature extractor's gap to the best test AUROC.

    The rows of ``sub_df`` are pivoted to a seed x feature-extractor table of
    ``test_auroc``. For every way of picking one value per feature extractor
    (the Cartesian product over its seeds), the gap between the highest picked
    value and each extractor's own value is fed into that extractor's
    ``RunningStats``.

    Args:
        sub_df: DataFrame with the columns ``seed``, ``feature_extractor`` and
            ``test_auroc``.
        show_progress: Wrap the iteration over all combinations in ``tqdm``.

    Returns:
        dict mapping each feature extractor to the result of
        ``RunningStats.compute()``.
    """
    auroc_table = sub_df.pivot(index="seed", columns="feature_extractor", values="test_auroc")
    extractor_names = auroc_table.columns.values
    n_seeds = len(auroc_table.index.values)

    running = {name: RunningStats() for name in extractor_names}

    # One column of AUROCs per feature extractor -> every pick of one value from each column.
    picks = product(*auroc_table.values.T)
    if show_progress:
        picks = tqdm(picks, total=int(n_seeds ** len(extractor_names)))

    for pick in picks:
        aurocs = np.array(pick)
        gaps = aurocs.max() - aurocs
        for name, gap in zip(extractor_names, gaps):
            running[name].update(gap)

    return {name: accumulator.compute() for name, accumulator in running.items()}


def compute_norm_diff_auroc_worker(args):
    """Single-argument adapter around ``compute_norm_diff_auroc`` for mapping over a list.

    Args:
        args: Tuple ``(target, model, augmentations, sub_data)``.

    Returns:
        ``((target, model, augmentations), stats)`` where ``stats`` is the
        result of ``compute_norm_diff_auroc(sub_data)``.
    """
    target, model, augmentations, sub_data = args
    key = (target, model, augmentations)
    stats = compute_norm_diff_auroc(sub_data)
    return key, stats


@cached_df(lambda *args, **kwargs: "norm_diff")
def compute_results_table(test_aurocs: pd.Series, n_workers: int = 32):
    """Compute average offsets from best for each (target, model, augmentation) pair using multiprocessing.

    NOTE(review): the key handed to ``cached_df`` is the constant
    ``"norm_diff"`` regardless of the arguments — if the cache is keyed on it,
    a stale table is returned when ``test_aurocs`` changes; confirm.

    Args:
        test_aurocs: Series of test AUROCs whose index provides (at least) the
            levels ``target``, ``model``, ``augmentations``,
            ``feature_extractor`` and ``seed``, with the values in a column
            named ``test_auroc`` after ``reset_index()``.
        n_workers: Maximum number of worker processes passed to ``process_map``.

    Returns:
        DataFrame indexed by (augmentations, model, feature_extractor) with
        (target, stats) column pairs, sorted by column.
    """
    d = test_aurocs.reset_index()

    # One task per (target, model, augmentations) triple. A single groupby pass replaces
    # re-scanning the whole frame with a boolean mask for every triple; sort=False keeps
    # the triples in order of first appearance.
    args_list = [
        (target, model, augmentations, sub_data)
        for (target, model, augmentations), sub_data in d.groupby(["target", "model", "augmentations"], sort=False)
    ]

    # Use a process pool to compute results in parallel
    results_list = process_map(
        compute_norm_diff_auroc_worker, args_list, max_workers=n_workers, tqdm_class=tqdm, desc="Computing results"
    )

    # Convert list of results into dictionary
    results = {(target, model, augmentations): result for (target, model, augmentations), result in results_list}

    # Rows: feature extractors; columns: (target, model, augmentations); cells: stats converted to dicts.
    r = pd.DataFrame(results).map(lambda x: x._asdict())
    r.index.name = "feature_extractor"
    r.columns.names = ["target", "model", "augmentations"]
    # Long format with one row per cell, then expand each stats dict into its own columns.
    r = r.stack(["target", "model", "augmentations"]).apply(pd.Series)
    r.columns.names = ["stats"]
    r = (
        r.pivot_table(index=["augmentations", "model", "feature_extractor"], columns="target")
        .reorder_levels([1, 0], axis=1)
        .sort_index(axis=1)
    )
    return r


# Offsets from the best feature extractor, computed on the test AUROCs.
test_aurocs = df["test_auroc"]
r = compute_results_table(test_aurocs)
r

Unnamed: 0_level_0,Unnamed: 1_level_0,target,BRAF,BRAF,CDH1,CDH1,KRAS,KRAS,MSI,MSI,PIK3CA,PIK3CA,SMAD4,SMAD4,TP53,TP53,lymph,lymph,subtype,subtype
Unnamed: 0_level_1,Unnamed: 1_level_1,stats,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
augmentations,model,feature_extractor,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
Macenko_patchwise,AttentionMIL,bt,0.278917,0.096484,0.040381,0.014883,0.073060,0.056746,0.332843,0.073579,0.131119,0.023538,0.149196,0.083864,0.023166,0.023604,0.252021,0.129876,0.113091,0.040531
Macenko_patchwise,AttentionMIL,ctranspath,0.058704,0.036685,0.021078,0.029036,0.061292,0.030044,0.071732,0.056848,0.036662,0.013538,0.063807,0.025953,0.028048,0.023851,0.036961,0.049644,0.001350,0.003717
Macenko_patchwise,AttentionMIL,dino_p16,0.022135,0.037561,0.009999,0.011961,0.065046,0.039180,0.010466,0.016273,0.016805,0.015374,0.048753,0.034943,0.042863,0.026436,0.064955,0.058930,0.044301,0.032044
Macenko_patchwise,AttentionMIL,owkin,0.093810,0.045302,0.044607,0.013955,0.050181,0.047491,0.028625,0.030363,0.052319,0.018709,0.060518,0.055600,0.088795,0.035863,0.091281,0.092033,0.109806,0.030453
Macenko_patchwise,AttentionMIL,resnet50,0.168677,0.070385,0.183557,0.036119,0.121765,0.031194,0.180785,0.028316,0.013376,0.011095,0.146413,0.065345,0.167975,0.054579,0.167246,0.069956,0.171760,0.035164
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
simple_rotate,Transformer,resnet50,,,0.092867,0.047223,,,,,0.043375,0.054267,,,0.176895,0.041851,,,0.088892,0.015435
simple_rotate,Transformer,retccl,,,0.063570,0.043956,,,,,0.094775,0.041955,,,0.020974,0.020281,,,0.069215,0.048930
simple_rotate,Transformer,swav,,,0.041240,0.029316,,,,,0.119060,0.044002,,,0.048107,0.022257,,,0.061270,0.042487
simple_rotate,Transformer,swin,,,0.103490,0.038366,,,,,0.044318,0.027570,,,0.239717,0.037082,,,0.062011,0.034360


In [7]:
# Add the across-target average, keep two augmentation settings and print the table.
# `results_to_latex` defaults to goal="min", so the smallest offset in each column is bolded.
# NOTE: `compute_overall_average` asserts that no "average" target exists yet, so this cell
# has to be run on the table as returned by `compute_results_table`.
r = compute_overall_average(r).query("augmentations in ['none', 'Macenko_slidewise']")
print(results_to_latex(r))