In [13]:
from pathlib import Path
import pandas as pd

from make_summary import summarize_group
from mil.utils import human_format

selected_model_type = "mnist_collage"
selected_model_type = "mnist_collage_ablations"
selected_model_type = "camelyon16"

if "mnist_collage" in selected_model_type:
    datasets = ["mnist_collage", "mnist_collage_inverse"]
else:
    datasets = ["camelyon16"]
yaml_folder = Path("conf") / "selected_model" / selected_model_type

In [11]:
models = [file.stem for file in yaml_folder.glob("*.yaml") if file.name[0] != "_"]
if selected_model_type == "mnist_collage_ablations":
    models.append("distance_aware_self_attention")
models.sort()
models

['discrete_rel_pos_self_attention',
 'distance_aware_self_attention',
 'distance_aware_self_attention_fc',
 'distance_aware_self_attention_fc_t3',
 'transmil']

In [17]:
all_stats = {
    dataset: {
        model: summarize_group(f"selected-{dataset}-{model}", log_to_wandb=False)
        for model in models
    } for dataset in datasets
}
dfs = []
for dataset, dataset_stats in all_stats.items():
    df = pd.DataFrame(dataset_stats).T
    df["dataset"] = dataset
    dfs.append(df)

# Merge
df = pd.concat(dfs)
df.index.name = "model"
df.reset_index(inplace=True)
df.set_index(["dataset", "model"], inplace=True)
df[["mean(train/acc)", "std(train/acc)", "mean(test/acc)", "std(test/acc)", "mean(train/auc)", "std(train/auc)", "mean(test/auc)", "std(test/auc)"]].round(3)

[32m2023-05-07 21:13:10.628[0m | [1mINFO    [0m | [36mmake_summary[0m:[36msummarize_group[0m:[36m43[0m - [1mSummarizing group selected-camelyon16-discrete_rel_pos_self_attention[0m
[32m2023-05-07 21:13:12.127[0m | [1mINFO    [0m | [36mmake_summary[0m:[36msummarize_group[0m:[36m43[0m - [1mSummarizing group selected-camelyon16-distance_aware_self_attention[0m
[32m2023-05-07 21:13:12.472[0m | [1mINFO    [0m | [36mmake_summary[0m:[36msummarize_group[0m:[36m43[0m - [1mSummarizing group selected-camelyon16-distance_aware_self_attention_fc[0m
[32m2023-05-07 21:13:13.881[0m | [1mINFO    [0m | [36mmake_summary[0m:[36msummarize_group[0m:[36m43[0m - [1mSummarizing group selected-camelyon16-distance_aware_self_attention_fc_t3[0m
[32m2023-05-07 21:13:15.283[0m | [1mINFO    [0m | [36mmake_summary[0m:[36msummarize_group[0m:[36m43[0m - [1mSummarizing group selected-camelyon16-transmil[0m


Unnamed: 0_level_0,Unnamed: 1_level_0,mean(train/acc),std(train/acc),mean(test/acc),std(test/acc),mean(train/auc),std(train/auc),mean(test/auc),std(test/auc)
dataset,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
camelyon16,discrete_rel_pos_self_attention,0.83,0.212,0.762,0.131,0.877,0.201,0.757,0.108
camelyon16,distance_aware_self_attention,,,,,,,,
camelyon16,distance_aware_self_attention_fc,0.985,0.017,0.891,0.021,0.999,0.002,0.919,0.007
camelyon16,distance_aware_self_attention_fc_t3,0.976,0.037,0.882,0.022,0.995,0.008,0.922,0.006
camelyon16,transmil,0.899,0.218,0.84,0.12,0.901,0.22,0.85,0.15


In [19]:
if selected_model_type == "mnist_collage":
    NAMES = {
        "just_pool": "MIL with max pool",
        "abmil": "AB-MIL~\\cite{ilse2018attention}",
        "gnn_gat": "MIL with GNN (GAT~\\cite{velickovic2018graph})",
        "gnn_gcn": "MIL with GNN (GCN~\\cite{kipf2017semisupervised})",
        "induced_set_transformer": "MIL with iSet Transformer~\\cite{lee2019set}",
        "set_transformer": "MIL with Set Transformer~\\cite{lee2019set}",
        "mil_gnn": "MIL-GNN~\\cite{tu2019multiple}",
        "mil_gnn_ds": "MIL-GNN-DS~\\cite{tu2019multiple}",
        "self_attention": "MIL with SA~\\cite{vaswani2017attention}",
        "self_attention_axial_pe": "MIL with SA + axial PE~\\cite{ramachandran2019stand}",
        "self_attention_fourier_pe": "MIL with SA + Fourier PE~\\cite{yang2021learnable}",
        "discrete_rel_pos_self_attention": "MIL with disc.\\ rel.\\ SA~\\cite{wu2021rethinking}",
        "transmil": "TransMIL~\\cite{shao2021transmil}",
        "distance_aware_self_attention": "DAS-MIL (ours)",
    }
elif selected_model_type == "mnist_collage_ablations":
    NAMES = {
        "distance_aware_self_attention_embedk": "DAS-MIL ($\\vb^K$)",
        "distance_aware_self_attention_embedq": "DAS-MIL ($\\vb^Q$)",
        "distance_aware_self_attention_embedv": "DAS-MIL ($\\vb^V$)",
        "distance_aware_self_attention_embedkq": "DAS-MIL ($\\vb^K, \\vb^Q$) ($+ \\vb^Q{\\vb^K}^\\top$ in \\cref{eq:das:compatibility:impl})",
        "distance_aware_self_attention_embedkq_noterm3": "DAS-MIL ($\\vb^K, \\vb^Q$)",
        "distance_aware_self_attention_embedkv": "DAS-MIL ($\\vb^K, \\vb^V$)",
        "distance_aware_self_attention_embedqv": "DAS-MIL ($\\vb^Q, \\vb^V$)",
        "distance_aware_self_attention_embedkqv": "DAS-MIL ($\\vb^K, \\vb^Q, \\vb^V$) ($+ \\vb^Q{\\vb^K}^\\top$ in \\cref{eq:das:compatibility:impl})",
        "line1": None,
        "distance_aware_self_attention_fixedembed": "DAS-MIL (non-trainable $\\vb^K, \\vb^Q, \\vb^V$)",
        "line2": None,
        # "distance_aware_self_attention_embedkqv_noterm3": "DAS-MIL ($\\vb^K, \\vb^Q, \\vb^V$)",
        "distance_aware_self_attention": "DAS-MIL ($\\vb^K, \\vb^Q, \\vb^V$)",
    }
elif selected_model_type == "camelyon16":
    NAMES = {
        'discrete_rel_pos_self_attention': "MIL with disc.\\ rel.\\ SA~\\cite{wu2021rethinking}",
        'transmil': "TransMIL~\\cite{shao2021transmil}",
        'transmil_ourparams': "TransMIL~\\cite{shao2021transmil} after tuning",
        'distance_aware_self_attention': "DAS-MIL",
        'distance_aware_self_attention_fc': "DAS-MIL (FC)",
        'distance_aware_self_attention_fc_t3': "DAS-MIL (FC, T3)",
    }

pos_col = selected_model_type == "mnist_collage"
def if_pos(s):
    return s if pos_col else ""

abs_pos = ["self_attention_axial_pe", "self_attention_fourier_pe", "transmil"]
rel_pos = ["discrete_rel_pos_self_attention", "distance_aware_self_attention", "mil_gnn", "mil_gnn_ds", "gnn_gat", "gnn_gcn"]


if selected_model_type == "mnist_collage" or selected_model_type == "mnist_collage_ablations":
    metrics = ["balanced_acc"]
    print("\\begin{tabular}{l" + if_pos("|c") + "|r|rr|rr}")
    print("\\toprule")
    print(" & " + if_pos("&") + " & \\multicolumn{2}{c|}{\\smaller{MNIST-COLLAGE}} & \\multicolumn{2}{c}{\\smaller{MNIST-COLLAGE-INV}} \\\\")
    print("Model " + if_pos("& \\multicolumn{1}{c|}{Pos}") + " & \\multicolumn{1}{c|}{Params} & \\multicolumn{1}{c}{Train} & \\multicolumn{1}{c|}{Test} & \\multicolumn{1}{c}{Train} & \\multicolumn{1}{c}{Test} \\\\")
    print("\\midrule")

elif selected_model_type == "camelyon16":
    metrics = ["auc", "balanced_acc"]
    print("\\begin{tabular}{l|c|rr|rr}")
    print("\\toprule")
    print(" & & \\multicolumn{2}{c|}{AUROC} & \\multicolumn{2}{c}{Balanced accuracy} \\\\")
    print("Model & \\multicolumn{1}{c|}{Params} & \\multicolumn{1}{c}{Train} & \\multicolumn{1}{c|}{Test} & \\multicolumn{1}{c}{Train} & \\multicolumn{1}{c}{Test} \\\\")
    print("\\midrule")

rows = {name: row for name, row in df.iterrows()}
for name, desc in NAMES.items():
    if desc is None:
        print("\\hline")
        continue
    pos = "abs" if name in abs_pos else "rel" if name in rel_pos else None
    row = f"{NAMES[name]}"
    pos = f" \\smaller{{{pos.upper()}}}" if pos in ("abs", "rel") else "\\xmark"
    row += if_pos(f" & {pos}")
    row += f" & {human_format(df.loc[datasets[0], name]['num_parameters'])}"


    for dataset in datasets:
        for metric in metrics:
            df_ds = df.loc[dataset]
            for split in ("train", "test"):
                if df_ds[f"mean({split}/{metric})"].max() == df_ds[f"mean({split}/{metric})"].loc[name]:
                    row += f" & \\textbf{{{df_ds[f'mean({split}/{metric})'].loc[name]:.03f} $\pm$ {df_ds[f'std({split}/{metric})'].loc[name]:.03f}}}"
                else:
                    row += f" & {df_ds[f'mean({split}/{metric})'].loc[name]:.03f} $\pm$ {df_ds[f'std({split}/{metric})'].loc[name]:.03f}"
    print(row + " \\\\")
print("\\bottomrule")
print("\\end{tabular}")

\begin{tabular}{l|rr|rr}
\toprule
 & & \multicolumn{2}{c|}{AUROC} & \multicolumn{2}{c}{Balanced accuracy} \\
Model & \multicolumn{1}{c|}{Params} & \multicolumn{1}{c}{Train} & \multicolumn{1}{c|}{Test} & \multicolumn{1}{c}{Train} & \multicolumn{1}{c}{Test} \\
\midrule
MIL with disc.\ rel.\ SA~\cite{wu2021rethinking} & nan & 0.877 $\pm$ 0.201 & 0.757 $\pm$ 0.108 & 0.829 $\pm$ 0.203 & 0.734 $\pm$ 0.109 \\
TransMIL~\cite{shao2021transmil} & nan & 0.901 $\pm$ 0.220 & 0.850 $\pm$ 0.150 & 0.899 $\pm$ 0.217 & 0.802 $\pm$ 0.153 \\
DAS-MIL & nan & nan $\pm$ nan & nan $\pm$ nan & nan $\pm$ nan & nan $\pm$ nan \\
DAS-MIL (FC) & nan & \textbf{0.999 $\pm$ 0.002} & 0.919 $\pm$ 0.007 & \textbf{0.985 $\pm$ 0.018} & \textbf{0.867 $\pm$ 0.028} \\
DAS-MIL (FC, T3) & nan & 0.995 $\pm$ 0.008 & \textbf{0.922 $\pm$ 0.006} & 0.975 $\pm$ 0.038 & 0.855 $\pm$ 0.036 \\
\bottomrule
\end{tabular}
