In [1]:
from pathlib import Path
import pandas as pd

from make_summary import summarize_group
from mil.utils import human_format

# Experiment family to report on. The values handled in this notebook are
# "mnist_collage", "mnist_collage_ablations" and "camelyon16".
selected_model_type = "mnist_collage"

# Every "mnist_collage*" family is evaluated on both collage datasets; any
# other family is evaluated on CAMELYON16 only.
datasets = (
    ["mnist_collage", "mnist_collage_inverse"]
    if "mnist_collage" in selected_model_type
    else ["camelyon16"]
)

# Folder that is globbed for the per-model YAML configs of this family.
yaml_folder = Path("conf", "selected_model", selected_model_type)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Each YAML file in the folder names one model; files whose name starts with
# "_" are skipped.
models = [
    config_path.stem
    for config_path in yaml_folder.glob("*.yaml")
    if not config_path.name.startswith("_")
]
# The ablation family additionally reports "distance_aware_self_attention",
# which is added by hand rather than discovered from the folder.
if selected_model_type == "mnist_collage_ablations":
    models += ["distance_aware_self_attention"]
models = sorted(models)
models

['abmil',
 'discrete_rel_pos_self_attention',
 'distance_aware_self_attention',
 'gnn_gat',
 'gnn_gcn',
 'induced_set_transformer',
 'just_pool',
 'mil_gnn',
 'mil_gnn_ds',
 'self_attention',
 'self_attention_axial_pe',
 'self_attention_fourier_pe',
 'set_transformer',
 'transmil']

In [3]:
# Summarise every (dataset, model) run group; nothing is logged to W&B.
all_stats = {}
for dataset_name in datasets:
    all_stats[dataset_name] = {
        model_name: summarize_group(f"selected-{dataset_name}-{model_name}", log_to_wandb=False)
        for model_name in models
    }

# One frame per dataset with models as rows, tagged with the dataset name.
# NOTE(review): assumes summarize_group returns a mapping of statistic name
# to value - confirm against make_summary.
dfs = [
    pd.DataFrame(stats_by_model).T.assign(dataset=dataset_name)
    for dataset_name, stats_by_model in all_stats.items()
]

# Stack the per-dataset frames and index the result by (dataset, model).
df = (
    pd.concat(dfs)
    .rename_axis("model")
    .reset_index()
    .set_index(["dataset", "model"])
)

# Show mean/std of accuracy and AUC for both splits, rounded for display.
summary_columns = [
    f"{stat}({split}/{metric})"
    for metric in ("acc", "auc")
    for split in ("train", "test")
    for stat in ("mean", "std")
]
df[summary_columns].round(3)

2023-05-19 16:21:45.050 | INFO     | make_summary:summarize_group:43 - Summarizing group selected-mnist_collage-abmil
2023-05-19 16:21:47.893 | INFO     | make_summary:summarize_group:43 - Summarizing group selected-mnist_collage-discrete_rel_pos_self_attention
2023-05-19 16:21:50.683 | INFO     | make_summary:summarize_group:43 - Summarizing group selected-mnist_collage-distance_aware_self_attention
2023-05-19 16:21:53.730 | INFO     | make_summary:summarize_group:43 - Summarizing group selected-mnist_collage-gnn_gat
2023-05-19 16:21:56.903 | INFO     | make_summary:summarize_group:43 - Summarizing group selected-mnist_collage-gnn_gcn
2023-05-19 16:21:59.870 | INFO     | make_summary:summarize_group

dataset,model,mean(train/acc),std(train/acc),mean(test/acc),std(test/acc),mean(train/auc),std(train/auc),mean(test/auc),std(test/auc)
mnist_collage,abmil,0.799,0.014,0.74,0.01,0.893,0.017,0.825,0.008
mnist_collage,discrete_rel_pos_self_attention,0.921,0.014,0.926,0.03,0.97,0.005,0.975,0.005
mnist_collage,distance_aware_self_attention,0.955,0.011,0.958,0.013,0.99,0.003,0.992,0.009
mnist_collage,gnn_gat,0.811,0.089,0.758,0.041,0.873,0.086,0.82,0.068
mnist_collage,gnn_gcn,0.865,0.031,0.79,0.058,0.932,0.029,0.851,0.033
mnist_collage,induced_set_transformer,0.813,0.017,0.734,0.04,0.889,0.008,0.829,0.013
mnist_collage,just_pool,0.846,0.021,0.828,0.033,0.924,0.007,0.905,0.008
mnist_collage,mil_gnn,0.546,0.061,0.578,0.084,0.648,0.022,0.708,0.059
mnist_collage,mil_gnn_ds,0.686,0.143,0.656,0.123,0.771,0.13,0.778,0.093
mnist_collage,self_attention,0.881,0.013,0.848,0.024,0.932,0.028,0.888,0.013


## Table of results

In [4]:
# LaTeX row labels for the tables below, keyed by model name. Dict order is
# the row order; an entry whose value is None stands for a horizontal rule
# rather than a model.
if selected_model_type == "mnist_collage":
    NAMES = {
        # Baselines without a transformer-style encoder.
        "just_pool": "MIL with max pool",
        "abmil": "AB-MIL~\\cite{ilse2018attention}",
        # Graph-based models. Of the two tu2019multiple variants only
        # "mil_gnn_ds" gets a row (labelled plain "MIL-GNN"); "mil_gnn" has none.
        "gnn_gat": "MIL with GNN (GAT~\\cite{velickovic2018graph})",
        "gnn_gcn": "MIL with GNN (GCN~\\cite{kipf2017semisupervised})",
        "mil_gnn_ds": "MIL-GNN~\\cite{tu2019multiple}",
        # Set-transformer and self-attention variants.
        "induced_set_transformer": "MIL with iSet Transformer~\\cite{lee2019set}",
        "set_transformer": "MIL with Set Transformer~\\cite{lee2019set}",
        "self_attention": "MIL with SA~\\cite{vaswani2017attention}",
        "self_attention_axial_pe": "MIL with SA + axial PE~\\cite{ramachandran2019stand}",
        "self_attention_fourier_pe": "MIL with SA + Fourier PE~\\cite{yang2021learnable}",
        "discrete_rel_pos_self_attention": "MIL with disc.\\ rel.\\ SA~\\cite{wu2021rethinking}",
        "transmil": "TransMIL~\\cite{shao2021transmil}",
        # Proposed model, listed last.
        "distance_aware_self_attention": "DAS-MIL (ours)",
    }
elif selected_model_type == "mnist_collage_ablations":
    NAMES = {
        # One embedding, then pairs, then all three.
        "distance_aware_self_attention_embedk": "DAS-MIL ($\\vb^K$)",
        "distance_aware_self_attention_embedq": "DAS-MIL ($\\vb^Q$)",
        "distance_aware_self_attention_embedv": "DAS-MIL ($\\vb^V$)",
        "distance_aware_self_attention_embedkq": "DAS-MIL ($\\vb^K, \\vb^Q$) ($+ \\vb^Q{\\vb^K}^\\top$ in \\cref{eq:das:compatibility:impl})",
        "distance_aware_self_attention_embedkq_noterm3": "DAS-MIL ($\\vb^K, \\vb^Q$)",
        "distance_aware_self_attention_embedkv": "DAS-MIL ($\\vb^K, \\vb^V$)",
        "distance_aware_self_attention_embedqv": "DAS-MIL ($\\vb^Q, \\vb^V$)",
        "distance_aware_self_attention_embedkqv": "DAS-MIL ($\\vb^K, \\vb^Q, \\vb^V$) ($+ \\vb^Q{\\vb^K}^\\top$ in \\cref{eq:das:compatibility:impl})",
        "line1": None,
        "distance_aware_self_attention_fixedembed": "DAS-MIL (non-trainable $\\vb^K, \\vb^Q, \\vb^V$)",
        "line2": None,
        # "distance_aware_self_attention_embedkqv_noterm3" is not given a row.
        "distance_aware_self_attention": "DAS-MIL ($\\vb^K, \\vb^Q, \\vb^V$)",
    }
elif selected_model_type == "camelyon16":
    NAMES = {
        "just_pool": "MIL with max pool",
        "abmil": "AB-MIL~\\cite{ilse2018attention}",
        "self_attention": "MIL with SA~\\cite{vaswani2017attention}",
        "self_attention_axial_pe": "MIL with SA + axial PE~\\cite{ramachandran2019stand}",
        "self_attention_fourier_pe": "MIL with SA + Fourier PE~\\cite{yang2021learnable}",
        "discrete_rel_pos_self_attention": "MIL with disc.\\ rel.\\ SA~\\cite{wu2021rethinking}",
        "transmil": "TransMIL~\\cite{shao2021transmil}",
        # Only the "_fc" DAS-MIL variant is reported here; "transmil_ourparams",
        # the plain variant and the "_fc_t3" variant have no row.
        "distance_aware_self_attention_fc": "DAS-MIL (ours)",
    }

# Whether the table gets a "Pos" column: only for these two families.
pos_col = selected_model_type in ("mnist_collage", "camelyon16")


def if_pos(s):
    """Return `s` when the table has a "Pos" column, otherwise an empty string."""
    if not pos_col:
        return ""
    return s


# Models tagged ABS or REL in the "Pos" column; all others are printed as \xmark.
abs_pos = [
    "self_attention_axial_pe",
    "self_attention_fourier_pe",
    "transmil",
]
rel_pos = [
    "discrete_rel_pos_self_attention",
    "distance_aware_self_attention_fc",
    "distance_aware_self_attention",
    "mil_gnn",
    "mil_gnn_ds",
    "gnn_gat",
    "gnn_gcn",
]


# Emit the LaTeX preamble and the two header rows of the results table, and
# pick the metrics reported for the selected family.

# Params/Train/Test column titles: identical in both table layouts.
split_columns = (
    " & \\multicolumn{1}{c|}{Params}"
    " & \\multicolumn{1}{c}{Train} & \\multicolumn{1}{c|}{Test}"
    " & \\multicolumn{1}{c}{Train} & \\multicolumn{1}{c}{Test} \\\\"
)

header_lines = []
if "mnist_collage" in selected_model_type:
    # One metric, reported on both collage datasets (alternative: ["auc"]).
    metrics = ["balanced_acc"]
    header_lines = [
        "\\begin{tabular}{l" + if_pos("|c") + "|r|rr|rr}",
        "\\toprule",
        " & " + if_pos("&") + " & \\multicolumn{2}{c|}{\\smaller{MNIST-COLLAGE}} & \\multicolumn{2}{c}{\\smaller{MNIST-COLLAGE-INV}} \\\\",
        "Model " + if_pos("& \\multicolumn{1}{c|}{Pos}") + split_columns,
        "\\midrule",
    ]
elif selected_model_type == "camelyon16":
    # One dataset, two metrics.
    metrics = ["auc", "balanced_acc"]
    header_lines = [
        "\\begin{tabular}{l|c|c|rr|rr}",
        "\\toprule",
        " & & & \\multicolumn{2}{c|}{AUROC} & \\multicolumn{2}{c}{Balanced accuracy} \\\\",
        "Model & Pos" + split_columns,
        "\\midrule",
    ]

for header_line in header_lines:
    print(header_line)

# Print one LaTeX row per entry of NAMES (label, optional Pos tag, parameter
# count, then "mean $\pm$ std" per dataset/metric/split), followed by the
# table footer. The best mean in each column is set in bold.

# Not used in this cell. NOTE(review): kept in case a later cell reads `rows`;
# drop it otherwise.
rows = {name: row for name, row in df.iterrows()}

# Only real models take part in the "best value" comparison. Separator keys
# (value None) are not rows of `df`, and passing labels that are missing from
# the index to .loc raises KeyError.
table_models = [name for name, desc in NAMES.items() if desc is not None]

# Best mean per column, rounded like the printed values and computed once
# instead of once per row.
best = {
    (dataset, metric, split): df.loc[dataset][f"mean({split}/{metric})"].loc[table_models].max().round(3)
    for dataset in datasets
    for metric in metrics
    for split in ("train", "test")
}

for name, desc in NAMES.items():
    if desc is None:
        print("\\hline")
        continue

    # Leading space in the ABS/REL tags is kept so the output is unchanged.
    if name in abs_pos:
        pos = " \\smaller{ABS}"
    elif name in rel_pos:
        pos = " \\smaller{REL}"
    else:
        pos = "\\xmark"

    row = f"{desc}"
    row += if_pos(f" & {pos}")
    # The parameter count is read from the first dataset's row.
    row += f" & {human_format(df.loc[datasets[0], name]['num_parameters'])}"

    for dataset in datasets:
        df_ds = df.loc[dataset]
        for metric in metrics:
            for split in ("train", "test"):
                mean_value = df_ds[f"mean({split}/{metric})"].loc[name]
                std_value = df_ds[f"std({split}/{metric})"].loc[name]
                # "\\pm" (escaped) avoids the invalid "\p" escape sequence.
                cell = f"{mean_value:.3f} $\\pm$ {std_value:.3f}"
                if mean_value.round(3) == best[dataset, metric, split]:
                    cell = f"\\textbf{{{cell}}}"
                row += f" & {cell}"
    print(row + " \\\\")
print("\\bottomrule")
print("\\end{tabular}")

\begin{tabular}{l|c|r|rr|rr}
\toprule
 & & & \multicolumn{2}{c|}{\smaller{MNIST-COLLAGE}} & \multicolumn{2}{c}{\smaller{MNIST-COLLAGE-INV}} \\
Model & \multicolumn{1}{c|}{Pos} & \multicolumn{1}{c|}{Params} & \multicolumn{1}{c}{Train} & \multicolumn{1}{c|}{Test} & \multicolumn{1}{c}{Train} & \multicolumn{1}{c}{Test} \\
\midrule
MIL with max pool & \xmark & 15.6K & 0.846 $\pm$ 0.021 & 0.828 $\pm$ 0.033 & 0.840 $\pm$ 0.009 & 0.788 $\pm$ 0.029 \\
AB-MIL~\cite{ilse2018attention} & \xmark & 16.1K & 0.799 $\pm$ 0.014 & 0.740 $\pm$ 0.010 & 0.805 $\pm$ 0.006 & 0.692 $\pm$ 0.015 \\
MIL with GNN (GAT~\cite{velickovic2018graph}) &  \smaller{REL} & 16.7K & 0.811 $\pm$ 0.089 & 0.758 $\pm$ 0.041 & 0.745 $\pm$ 0.033 & 0.716 $\pm$ 0.018 \\
MIL with GNN (GCN~\cite{kipf2017semisupervised}) &  \smaller{REL} & 16.3K & 0.865 $\pm$ 0.031 & 0.790 $\pm$ 0.058 & 0.883 $\pm$ 0.034 & 0.794 $\pm$ 0.036 \\
MIL-GNN~\cite{tu2019multiple} &  \smaller{REL} & 19.2K & 0.686 $\pm$ 0.143 & 0.656 $\pm$ 0.123 & 0.784 $\pm$ 0

## Table of hyperparameters

In [28]:
from omegaconf import OmegaConf, DictConfig

# Print a LaTeX table with the optimiser settings and main architecture
# settings read from each listed model's YAML config.
print("\\begin{tabular}{l|rrrrr}")
print("\\toprule")
print("Model & optimiser & LR & weight decay & hidden dim & agg \\\\")
print("\\midrule")
NA = "N/A"

for name, desc in NAMES.items():
    if desc is None:
        # Separator entry, not a model: there is no YAML to load for it.
        # Print a rule, as the results table does.
        print("\\hline")
        continue

    cfg = OmegaConf.load(yaml_folder / f"{name}.yaml")

    # When the optimiser config has a `base_optimizer` section it is reported
    # as Lookahead, and lr / weight decay are read from that section if they
    # are not set at the top level.
    optimizer_cfg = cfg.optimizer
    optim = "Lookahead" if "base_optimizer" in optimizer_cfg else "Adam"
    lr = optimizer_cfg.lr if "lr" in optimizer_cfg else optimizer_cfg.base_optimizer.lr
    weight_decay = (
        optimizer_cfg.weight_decay
        if "weight_decay" in optimizer_cfg
        else optimizer_cfg.base_optimizer.weight_decay
    )

    settings = cfg.settings if "settings" in cfg else {}
    hidden_dim = settings.get("hidden_dim", NA)
    agg = settings.get("agg", NA)

    if agg == NA:
        # No explicit aggregation setting: fall back to the name of the base
        # model in the defaults list, and report "max" if it contains "max".
        # A config without a defaults list simply keeps "N/A".
        model_defaults = [
            entry.get("/model")
            for entry in (cfg.get("defaults") or [])
            if isinstance(entry, DictConfig) and "/model" in entry
        ]
        if model_defaults and "max" in model_defaults[0]:
            agg = "max"

    print(f"{desc} & {optim} & {lr} & {weight_decay} & {hidden_dim} & {agg} \\\\")
print("\\bottomrule")
print("\\end{tabular}")

\begin{tabular}{l|rrrrr}
\toprule
Model & optimiser & LR & weight decay & hidden dim & agg \\
\midrule
MIL with max pool & Adam & 0.0001 & 0.01 & 10 & max \\
AB-MIL~\cite{ilse2018attention} & Adam & 0.0001 & 0.001 & 15 & N/A \\
MIL with GNN (GAT~\cite{velickovic2018graph}) & Adam & 0.001 & 0.1 & 20 & max \\
MIL with GNN (GCN~\cite{kipf2017semisupervised}) & Adam & 0.001 & 0.1 & 15 & max \\
MIL-GNN~\cite{tu2019multiple} & Adam & 0.001 & 0.01 & 20 & N/A \\
MIL with iSet Transformer~\cite{lee2019set} & Adam & 0.0001 & 0.1 & 15 & N/A \\
MIL with Set Transformer~\cite{lee2019set} & Adam & 0.001 & 0.1 & 10 & N/A \\
MIL with SA~\cite{vaswani2017attention} & Adam & 0.001 & 0.1 & 10 & max \\
MIL with SA + axial PE~\cite{ramachandran2019stand} & Adam & 0.0001 & 0.1 & 15 & max \\
MIL with SA + Fourier PE~\cite{yang2021learnable} & Adam & 0.001 & 0.001 & 15 & max \\
MIL with disc.\ rel.\ SA~\cite{wu2021rethinking} & Adam & 0.0001 & 0.1 & 10 & max \\
TransMIL~\cite{shao2021transmil} & Lookahead & 0