In [6]:
from pathlib import Path
import pandas as pd

from make_summary import summarize_group
from mil.utils import human_format

# Experiment family to summarize. It selects the YAML folder
# conf/selected_model/<type> and the datasets whose run groups are summarized.
MODEL_TYPES = ("mnist_collage", "mnist_collage_ablations", "camelyon16")
selected_model_type = "camelyon16"  # pick one of MODEL_TYPES
if selected_model_type not in MODEL_TYPES:
    raise ValueError(f"selected_model_type must be one of {MODEL_TYPES}, got {selected_model_type!r}")

# Both MNIST-COLLAGE variants are evaluated on the regular and the inverse collage datasets.
if selected_model_type.startswith("mnist_collage"):
    datasets = ["mnist_collage", "mnist_collage_inverse"]
else:
    datasets = ["camelyon16"]
yaml_folder = Path("conf") / "selected_model" / selected_model_type

In [7]:
# One model per YAML config in `yaml_folder`. Files whose name starts with "_" are skipped
# (presumably shared/base configs; TODO confirm).
models = [path.stem for path in yaml_folder.glob("*.yaml") if not path.name.startswith("_")]
if selected_model_type == "mnist_collage_ablations":
    # NOTE(review): the full DAS-MIL model is added by hand here, presumably because it
    # has no YAML of its own in the ablations folder; confirm.
    models += ["distance_aware_self_attention"]
# glob() order is filesystem-dependent, so sort for a stable listing.
models = sorted(models)
models

['abmil',
 'discrete_rel_pos_self_attention',
 'distance_aware_self_attention',
 'distance_aware_self_attention_fc',
 'distance_aware_self_attention_fc_t3',
 'just_pool',
 'self_attention',
 'self_attention_axial_pe',
 'self_attention_fourier_pe',
 'transmil',
 'transmil_ourparams']

In [8]:
# Summary statistics for every (dataset, model) run group "selected-<dataset>-<model>".
# One summarize_group call is made per group (each logs a "Summarizing group ..." line).
all_stats = {
    dataset: {
        model: summarize_group(f"selected-{dataset}-{model}", log_to_wandb=False)
        for model in models
    }
    for dataset in datasets
}

# Merge into one frame indexed by (dataset, model); columns are the summary statistics.
# Each per-dataset frame is transposed so models become rows.
df = pd.concat(
    [pd.DataFrame(dataset_stats).T for dataset_stats in all_stats.values()],
    keys=list(all_stats),
    names=["dataset", "model"],
)
# Later cells look rows up with .loc[dataset, model], which must resolve to a single row.
assert df.index.is_unique, "duplicate (dataset, model) rows in summary table"

df[["mean(train/acc)", "std(train/acc)", "mean(test/acc)", "std(test/acc)", "mean(train/auc)", "std(train/auc)", "mean(test/auc)", "std(test/auc)"]].round(3)

[32m2023-05-09 00:13:13.288[0m | [1mINFO    [0m | [36mmake_summary[0m:[36msummarize_group[0m:[36m43[0m - [1mSummarizing group selected-camelyon16-abmil[0m
[32m2023-05-09 00:13:15.478[0m | [1mINFO    [0m | [36mmake_summary[0m:[36msummarize_group[0m:[36m43[0m - [1mSummarizing group selected-camelyon16-discrete_rel_pos_self_attention[0m
[32m2023-05-09 00:13:17.874[0m | [1mINFO    [0m | [36mmake_summary[0m:[36msummarize_group[0m:[36m43[0m - [1mSummarizing group selected-camelyon16-distance_aware_self_attention[0m
[32m2023-05-09 00:13:20.416[0m | [1mINFO    [0m | [36mmake_summary[0m:[36msummarize_group[0m:[36m43[0m - [1mSummarizing group selected-camelyon16-distance_aware_self_attention_fc[0m
[32m2023-05-09 00:13:22.900[0m | [1mINFO    [0m | [36mmake_summary[0m:[36msummarize_group[0m:[36m43[0m - [1mSummarizing group selected-camelyon16-distance_aware_self_attention_fc_t3[0m
[32m2023-05-09 00:13:24.968[0m | [1mINFO    [0m | [36

Unnamed: 0_level_0,Unnamed: 1_level_0,mean(train/acc),std(train/acc),mean(test/acc),std(test/acc),mean(train/auc),std(train/auc),mean(test/auc),std(test/auc)
dataset,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
camelyon16,abmil,0.928,0.004,0.809,0.018,0.964,0.003,0.795,0.009
camelyon16,discrete_rel_pos_self_attention,0.931,0.016,0.825,0.01,0.976,0.006,0.806,0.015
camelyon16,distance_aware_self_attention,0.923,0.015,0.831,0.011,0.974,0.004,0.807,0.019
camelyon16,distance_aware_self_attention_fc,0.978,0.021,0.887,0.022,0.997,0.005,0.914,0.007
camelyon16,distance_aware_self_attention_fc_t3,0.993,0.008,0.878,0.02,0.999,0.002,0.919,0.008
camelyon16,just_pool,0.63,0.039,0.648,0.069,0.686,0.018,0.739,0.017
camelyon16,self_attention,0.93,0.014,0.829,0.023,0.972,0.014,0.823,0.05
camelyon16,self_attention_axial_pe,0.883,0.019,0.521,0.031,0.956,0.016,0.485,0.018
camelyon16,self_attention_fourier_pe,0.935,0.011,0.791,0.067,0.975,0.007,0.819,0.028
camelyon16,transmil,0.997,0.007,0.884,0.026,0.999,0.002,0.911,0.027


In [12]:
# LaTeX row label for every model to include in the table. Dict insertion order is the
# row order; a None value marks a separator that is rendered as \hline.
if selected_model_type == "mnist_collage":
    NAMES = {
        "just_pool": "MIL with max pool",
        "abmil": "AB-MIL~\\cite{ilse2018attention}",
        "gnn_gat": "MIL with GNN (GAT~\\cite{velickovic2018graph})",
        "gnn_gcn": "MIL with GNN (GCN~\\cite{kipf2017semisupervised})",
        # "mil_gnn": "MIL-GNN~\\cite{tu2019multiple}",
        # "mil_gnn_ds": "MIL-GNN-DS~\\cite{tu2019multiple}",
        "mil_gnn_ds": "MIL-GNN~\\cite{tu2019multiple}",
        "induced_set_transformer": "MIL with iSet Transformer~\\cite{lee2019set}",
        "set_transformer": "MIL with Set Transformer~\\cite{lee2019set}",
        "self_attention": "MIL with SA~\\cite{vaswani2017attention}",
        "self_attention_axial_pe": "MIL with SA + axial PE~\\cite{ramachandran2019stand}",
        "self_attention_fourier_pe": "MIL with SA + Fourier PE~\\cite{yang2021learnable}",
        "discrete_rel_pos_self_attention": "MIL with disc.\\ rel.\\ SA~\\cite{wu2021rethinking}",
        "transmil": "TransMIL~\\cite{shao2021transmil}",
        "distance_aware_self_attention": "DAS-MIL (ours)",
    }
elif selected_model_type == "mnist_collage_ablations":
    NAMES = {
        "distance_aware_self_attention_embedk": "DAS-MIL ($\\vb^K$)",
        "distance_aware_self_attention_embedq": "DAS-MIL ($\\vb^Q$)",
        "distance_aware_self_attention_embedv": "DAS-MIL ($\\vb^V$)",
        "distance_aware_self_attention_embedkq": "DAS-MIL ($\\vb^K, \\vb^Q$) ($+ \\vb^Q{\\vb^K}^\\top$ in \\cref{eq:das:compatibility:impl})",
        "distance_aware_self_attention_embedkq_noterm3": "DAS-MIL ($\\vb^K, \\vb^Q$)",
        "distance_aware_self_attention_embedkv": "DAS-MIL ($\\vb^K, \\vb^V$)",
        "distance_aware_self_attention_embedqv": "DAS-MIL ($\\vb^Q, \\vb^V$)",
        "distance_aware_self_attention_embedkqv": "DAS-MIL ($\\vb^K, \\vb^Q, \\vb^V$) ($+ \\vb^Q{\\vb^K}^\\top$ in \\cref{eq:das:compatibility:impl})",
        "line1": None,
        "distance_aware_self_attention_fixedembed": "DAS-MIL (non-trainable $\\vb^K, \\vb^Q, \\vb^V$)",
        "line2": None,
        # "distance_aware_self_attention_embedkqv_noterm3": "DAS-MIL ($\\vb^K, \\vb^Q, \\vb^V$)",
        "distance_aware_self_attention": "DAS-MIL ($\\vb^K, \\vb^Q, \\vb^V$)",
    }
elif selected_model_type == "camelyon16":
    NAMES = {
        "just_pool": "MIL with max pool",
        "abmil": "AB-MIL~\\cite{ilse2018attention}",
        "self_attention": "MIL with SA~\\cite{vaswani2017attention}",
        "self_attention_axial_pe": "MIL with SA + axial PE~\\cite{ramachandran2019stand}",
        "self_attention_fourier_pe": "MIL with SA + Fourier PE~\\cite{yang2021learnable}",
        "discrete_rel_pos_self_attention": "MIL with disc.\\ rel.\\ SA~\\cite{wu2021rethinking}",
        "transmil": "TransMIL~\\cite{shao2021transmil}",
        # "transmil_ourparams": "TransMIL~\\cite{shao2021transmil}",
        # "distance_aware_self_attention": "DAS-MIL (no FC)",
        "distance_aware_self_attention_fc": "DAS-MIL (ours)",
        # "distance_aware_self_attention_fc_t3": "DAS-MIL (FC, T3)",
    }
else:
    # Without this branch a NAMES left over from an earlier kernel run would be reused silently.
    raise ValueError(f"No table labels defined for selected_model_type={selected_model_type!r}")

# The "Pos" column (kind of positional information) is shown only for these two tables.
pos_col = selected_model_type in ("mnist_collage", "camelyon16")


def if_pos(s):
    """Return ``s`` if the Pos column is enabled (``pos_col``), else an empty string."""
    if not pos_col:
        return ""
    return s


# Models rendered as ABS / REL in the Pos column; any other model gets \xmark.
abs_pos = ["self_attention_axial_pe", "self_attention_fourier_pe", "transmil"]
rel_pos = ["discrete_rel_pos_self_attention", "distance_aware_self_attention_fc", "distance_aware_self_attention", "mil_gnn", "mil_gnn_ds", "gnn_gat", "gnn_gcn"]


# LaTeX table header. The metric columns depend on the experiment family:
#   * MNIST-COLLAGE variants: one metric, one Train/Test pair per dataset (regular, inverse).
#   * CAMELYON16: one Train/Test pair per metric (AUROC, balanced accuracy).
if "mnist_collage" in selected_model_type:
    metrics = ["balanced_acc"]
    # metrics = ["auc"]
    header_lines = (
        "\\begin{tabular}{l" + if_pos("|c") + "|r|rr|rr}",
        "\\toprule",
        " & " + if_pos("&") + " & \\multicolumn{2}{c|}{\\smaller{MNIST-COLLAGE}} & \\multicolumn{2}{c}{\\smaller{MNIST-COLLAGE-INV}} \\\\",
        "Model " + if_pos("& \\multicolumn{1}{c|}{Pos}") + " & \\multicolumn{1}{c|}{Params} & \\multicolumn{1}{c}{Train} & \\multicolumn{1}{c|}{Test} & \\multicolumn{1}{c}{Train} & \\multicolumn{1}{c}{Test} \\\\",
        "\\midrule",
    )
elif selected_model_type == "camelyon16":
    metrics = ["auc", "balanced_acc"]
    header_lines = (
        "\\begin{tabular}{l|c|c|rr|rr}",
        "\\toprule",
        " & & & \\multicolumn{2}{c|}{AUROC} & \\multicolumn{2}{c}{Balanced accuracy} \\\\",
        "Model & Pos & \\multicolumn{1}{c|}{Params} & \\multicolumn{1}{c}{Train} & \\multicolumn{1}{c|}{Test} & \\multicolumn{1}{c}{Train} & \\multicolumn{1}{c}{Test} \\\\",
        "\\midrule",
    )
else:
    header_lines = ()

for header_line in header_lines:
    print(header_line)

# Print one LaTeX row per NAMES entry, in NAMES order; a None description prints \hline.
# Only entries with a description correspond to models in `df`. The ablation table's
# separator keys ("line1"/"line2") are not in the index, so they must be excluded from
# any .loc lookup; including them raises KeyError.
table_models = [name for name, desc in NAMES.items() if desc is not None]
splits = ("train", "test")

# Best mean per (dataset, metric, split) over the tabled models, rounded to the 3 decimals
# that are printed. Every model tied at display precision is therefore bolded.
best = {
    (dataset, metric, split): round(float(df.loc[dataset][f"mean({split}/{metric})"].loc[table_models].max()), 3)
    for dataset in datasets
    for metric in metrics
    for split in splits
}

for name, desc in NAMES.items():
    if desc is None:
        print("\\hline")
        continue

    # Absolute-position models take precedence over relative ones; all others get \xmark.
    if name in abs_pos:
        pos_cell = " \\smaller{ABS}"
    elif name in rel_pos:
        pos_cell = " \\smaller{REL}"
    else:
        pos_cell = "\\xmark"

    row = desc
    row += if_pos(f" & {pos_cell}")
    # Parameter count is taken from the first dataset's summary.
    row += f" & {human_format(df.loc[datasets[0], name]['num_parameters'])}"

    for dataset in datasets:
        df_ds = df.loc[dataset]
        for metric in metrics:
            for split in splits:
                mean = df_ds[f"mean({split}/{metric})"].loc[name]
                std = df_ds[f"std({split}/{metric})"].loc[name]
                # "\\pm" is escaped explicitly; "\p" is an invalid Python escape sequence.
                cell = f"{mean:.03f} $\\pm$ {std:.03f}"
                if round(float(mean), 3) == best[dataset, metric, split]:
                    cell = f"\\textbf{{{cell}}}"
                row += f" & {cell}"
    print(row + " \\\\")
print("\\bottomrule")
print("\\end{tabular}")

\begin{tabular}{l|c|c|rr|rr}
\toprule
 & & & \multicolumn{2}{c|}{AUROC} & \multicolumn{2}{c}{Balanced accuracy} \\
Model & Pos & \multicolumn{1}{c|}{Params} & \multicolumn{1}{c}{Train} & \multicolumn{1}{c|}{Test} & \multicolumn{1}{c}{Train} & \multicolumn{1}{c}{Test} \\
\midrule
MIL with max pool & \xmark & 769 & 0.686 $\pm$ 0.018 & 0.739 $\pm$ 0.017 & 0.639 $\pm$ 0.033 & 0.573 $\pm$ 0.077 \\
AB-MIL~\cite{ilse2018attention} & \xmark & 8.47K & 0.964 $\pm$ 0.003 & 0.795 $\pm$ 0.009 & 0.923 $\pm$ 0.004 & 0.773 $\pm$ 0.010 \\
MIL with SA~\cite{vaswani2017attention} & \xmark & 27.7K & 0.972 $\pm$ 0.014 & 0.823 $\pm$ 0.050 & 0.925 $\pm$ 0.016 & 0.803 $\pm$ 0.033 \\
MIL with SA + axial PE~\cite{ramachandran2019stand} &  \smaller{ABS} & 27.7K & 0.956 $\pm$ 0.016 & 0.485 $\pm$ 0.018 & 0.887 $\pm$ 0.017 & 0.490 $\pm$ 0.034 \\
MIL with SA + Fourier PE~\cite{yang2021learnable} &  \smaller{ABS} & 41K & 0.975 $\pm$ 0.007 & 0.819 $\pm$ 0.028 & 0.932 $\pm$ 0.011 & 0.770 $\pm$ 0.041 \\
MIL with disc.\ 

In [10]:
# Full merged summary table: one row per (dataset, model) with every aggregated metric plus num_parameters.
df

Unnamed: 0_level_0,Unnamed: 1_level_0,mean(train/acc),std(train/acc),mean(train/balanced_acc),std(train/balanced_acc),mean(train/auc),std(train/auc),mean(train/f1),std(train/f1),mean(train/precision),std(train/precision),...,std(min(test/auc)),mean(min(test/f1)),std(min(test/f1)),mean(min(test/precision)),std(min(test/precision)),mean(min(test/recall)),std(min(test/recall)),mean(min(test/loss)),std(min(test/loss)),num_parameters
dataset,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
camelyon16,abmil,0.927881,0.004239,0.923064,0.003746,0.964272,0.003434,0.911107,0.004777,0.927427,0.012386,...,0.061901,0.268628,0.050201,0.365524,0.020979,0.179592,0.044244,0.596805,0.026047,8470.0
camelyon16,discrete_rel_pos_self_attention,0.930855,0.016332,0.925864,0.01728,0.976428,0.006195,0.914552,0.02016,0.932974,0.024142,...,0.037146,0.274951,0.182859,0.346472,0.06501,0.216327,0.164915,0.617486,0.024532,28025.0
camelyon16,distance_aware_self_attention,0.922677,0.015417,0.916758,0.014946,0.97449,0.00358,0.904156,0.018442,0.926653,0.026374,...,0.07016,0.244284,0.169057,0.27971,0.158209,0.195918,0.147589,0.625754,0.006302,27739.0
camelyon16,distance_aware_self_attention_fc,0.978439,0.02093,0.977894,0.021974,0.996693,0.004961,0.973845,0.025454,0.972939,0.02313,...,0.058293,0.264301,0.186722,0.30495,0.17104,0.204082,0.152037,0.362803,0.036618,412251.0
camelyon16,distance_aware_self_attention_fc_t3,0.993309,0.007619,0.993232,0.008117,0.999259,0.001626,0.991888,0.009269,0.991022,0.008972,...,0.058176,0.271275,0.193523,0.306236,0.171701,0.236735,0.191772,0.362471,0.038941,412251.0
camelyon16,just_pool,0.62974,0.039042,0.639252,0.032975,0.68574,0.018163,0.607408,0.03117,0.54327,0.045114,...,0.060577,0.0,0.0,0.0,0.0,0.0,0.0,0.774685,0.003218,769.0
camelyon16,self_attention,0.930112,0.014253,0.925499,0.016017,0.971593,0.014413,0.913789,0.018102,0.929142,0.011092,...,0.034367,0.215871,0.126854,0.295628,0.165837,0.142857,0.087779,0.570828,0.072815,27665.0
camelyon16,self_attention_axial_pe,0.883271,0.019065,0.886965,0.017209,0.955787,0.01569,0.865483,0.019918,0.827147,0.032431,...,0.020396,0.169513,0.113222,0.252315,0.14355,0.130612,0.096373,0.79204,0.004769,27665.0
camelyon16,self_attention_fourier_pe,0.935316,0.011337,0.932341,0.010749,0.974615,0.006743,0.921191,0.013121,0.927514,0.022981,...,0.073365,0.071197,0.08798,0.227548,0.207723,0.040816,0.052031,0.608706,0.017315,41009.0
camelyon16,transmil,0.997026,0.00665,0.996664,0.007459,0.999259,0.001657,0.996364,0.008131,0.998165,0.004103,...,0.079077,0.192099,0.229438,0.33936,0.204542,0.155102,0.226613,0.39304,0.034516,2540561.0
