In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

DATA_PATH = "../../data"

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

from src.log_mock import PrintLog
log = PrintLog()

import wandb

wandb.init(mode="disabled")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [2]:
# W&B "<entity>/<project>" whose runs are analysed below.
WANDB_PROJECT = "foobar/rxrx1"

wapi = wandb.Api()
runs = wapi.runs(WANDB_PROJECT)

In [3]:
# Show which summary keys each run logged; build_data treats runs without
# "test_results" as crashed.
for run in runs:
    summary_keys = run.summary.keys()
    print(run.name, summary_keys)

ll_ivon_1-(5) dict_keys(['_timestamp', 'train_loss', 'test_results', 'id_test_results', 'eval', '_step', '_wandb', '_runtime'])
ll_ivon_1-(4) dict_keys(['_step', '_wandb', '_runtime', '_timestamp', 'train_loss', 'test_results', 'id_test_results', 'eval'])
ll_ivon_1-(3) dict_keys(['_timestamp', 'train_loss', 'test_results', 'id_test_results', 'eval', '_step', '_wandb', '_runtime'])
ll_ivon_1-(2) dict_keys(['eval', '_step', '_wandb', '_runtime', '_timestamp', 'train_loss', 'test_results', 'id_test_results'])
ll_ivon_1-(0) dict_keys(['test_results', 'id_test_results', 'eval', '_step', '_wandb', '_runtime', '_timestamp', 'train_loss'])
ll_ivon_1-(1) dict_keys(['test_results', 'id_test_results', 'eval', '_step', '_wandb', '_runtime', '_timestamp', 'train_loss'])
laplace-5-1 dict_keys([])
laplace-1-(1) dict_keys(['_step', '_wandb', '_runtime', '_timestamp', 'test_results', 'id_test_results'])
laplace-5-0 dict_keys(['_step', '_wandb', '_runtime', '_timestamp', 'test_results', 'id_test_results

In [4]:
import plotly.express as px
import pandas as pd
import dateutil
import datetime

def create_plot_data_for_run(run):
    """Flatten one W&B run into a dict of its o.o.d. and i.d. test metrics.

    The model label is derived from the run name: if the name has more than
    one "-", its first two dash-separated parts are kept ("laplace-5-0" ->
    "laplace-5"); otherwise only the first part is kept ("ll_ivon_1-(3)" ->
    "ll_ivon_1").

    Args:
        run: object with a ``name`` string and a ``summary`` mapping that holds
            "test_results" and "id_test_results" sub-mappings.

    Returns:
        dict with "model" followed by the eight "ood ..." / "id ..." metrics.
        A missing section or metric propagates the lookup error.
    """
    name_parts = run.name.split("-")
    model_name = "-".join(name_parts[:2]) if len(name_parts) > 2 else name_parts[0]

    # (output column, summary section, metric key): o.o.d. metrics come from
    # "test_results", i.d. metrics from "id_test_results".
    columns = (
        ("ood accuracy", "test_results", "accuracy"),
        ("ood log likelihood", "test_results", "log_likelihood"),
        ("ood ece", "test_results", "ece"),
        ("ood sece", "test_results", "sece"),
        ("id accuracy", "id_test_results", "accuracy"),
        ("id log likelihood", "id_test_results", "log_likelihood"),
        ("id ece", "id_test_results", "ece"),
        ("id sece", "id_test_results", "sece"),
    )

    row = {"model": model_name}
    for column, section, metric in columns:
        row[column] = run.summary[section][metric]
    return row

def plot(data, value):
    """Box plot of column ``value`` grouped (and coloured) by the "model" column."""
    return px.box(data, x="model", y=value, color="model")

def pareto_plot(data, x, y):
    """Scatter column ``y`` against column ``x``, one colour per model.

    Error bars are read from the companion "<x>_std" and "<y>_std" columns,
    as produced by ``aggregate_data``.
    """
    return px.scatter(
        data,
        x=x,
        y=y,
        error_x=x + "_std",
        error_y=y + "_std",
        color="model",
    )

def build_data(runs, min_created_at=datetime.datetime(2023, 3, 10, 10, 0)):
    """Collect one metrics row per usable W&B run into a DataFrame.

    A run is skipped when it
      * was created before ``min_created_at`` (silently),
      * is not in the "finished" state (silently),
      * is tagged "old" (reported), or
      * lacks "test_results" or "id_test_results" in its summary (reported as
        crashed) -- both sections are read by ``create_plot_data_for_run``.

    Args:
        runs: iterable of W&B run objects exposing ``created_at`` (a date
            string), ``state``, ``tags``, ``name`` and ``summary``.
        min_created_at: cutoff; runs created strictly earlier are ignored.
            Compared against the parsed ``created_at``, so it must use the same
            tz-awareness (naive by default) -- TODO confirm W&B's format.

    Returns:
        DataFrame with one row per kept run (an empty frame if none are kept).
    """
    # Import the submodule explicitly: `import dateutil` alone does not
    # guarantee that `dateutil.parser` has been loaded.
    from dateutil import parser as date_parser

    rows = []
    for run in runs:
        if date_parser.parse(run.created_at) < min_created_at:
            continue
        if run.state != "finished":
            continue
        if "old" in run.tags:
            print("Skipping old run " + run.name)
            continue
        # Both summary sections are required downstream; a run that logged only
        # one of them would otherwise raise KeyError in create_plot_data_for_run.
        if "test_results" not in run.summary or "id_test_results" not in run.summary:
            print("Skipping crashed run " + run.name)
            continue
        rows.append(create_plot_data_for_run(run))
    return pd.DataFrame(rows)

def aggregate_data(data, error_scale=2.0):
    """Aggregate per-run rows into one row per model.

    For every metric this produces the mean (column named after the metric)
    and ``error_scale`` times the standard error of the mean (column
    "<metric>_std"). Despite the suffix, "_std" is NOT a standard deviation:
    with the default ``error_scale=2.0`` it is twice the SEM. A model with a
    single run gets NaN there.

    Args:
        data: frame from ``build_data`` with a "model" column plus the metric
            columns listed below.
        error_scale: multiplier applied to the SEM (default 2.0, the previous
            hard-coded factor).

    Returns:
        DataFrame indexed by model (index name "model") whose first column is
        "model", followed by "<metric>", "<metric>_std" pairs in the order of
        ``metrics`` below.
    """
    # Result column order (unchanged from the previous explicit agg spec).
    metrics = [
        "ood accuracy",
        "ood log likelihood",
        "ood sece",
        "ood ece",
        "id accuracy",
        "id log likelihood",
        "id sece",
        "id ece",
    ]

    aggregated = data.groupby(["model"])[metrics].agg(["mean", "sem"])
    aggregated.columns = [
        metric + "_std" if stat == "sem" else metric
        for metric, stat in aggregated.columns.to_flat_index()
    ]

    error_columns = [metric + "_std" for metric in metrics]
    aggregated[error_columns] = aggregated[error_columns] * error_scale

    # Expose the group key as a regular first column as well; taking it from the
    # index avoids aggregating the groupby key itself.
    aggregated.insert(0, "model", aggregated.index.to_numpy())
    return aggregated

In [5]:
# One row per finished run, then collapsed to one row per model.
run_rows = build_data(runs)
data = aggregate_data(run_rows)

Skipping old run laplace_1-(3)
Skipping old run laplace_1-(4)
Skipping old run laplace_1-(5)
Skipping old run laplace_1-(1)
Skipping old run laplace_1-(0)
Skipping old run laplace_1-(2)
Skipping crashed run rank_1-(5)
Skipping crashed run rank_1-(4)
Skipping crashed run rank_1-(3)
Skipping crashed run rank_1-(1)
Skipping crashed run rank_1-(2)
Skipping crashed run rank_1-(0)
Skipping old run map_1-(0)


In [6]:
# Per-model aggregates: mean of each metric plus a "_std" column holding 2 x SEM.
data

Unnamed: 0_level_0,model,ood accuracy,ood accuracy_std,ood log likelihood,ood log likelihood_std,ood sece,ood sece_std,ood ece,ood ece_std,id accuracy,id accuracy_std,id log likelihood,id log likelihood_std,id sece,id sece_std,id ece,id ece_std
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
bbb_1,bbb_1,0.038419,0.000462,-6.837261,0.011627,-0.045138,0.00237,0.045138,0.00237,0.045824,0.000914,-6.657397,0.018955,-0.031804,0.002484,0.031804,0.002484
bbb_kl0.2_1,bbb_kl0.2_1,0.046115,0.00059,-7.95693,0.050442,-0.117012,0.002304,0.117012,0.002304,0.0542,0.001119,-7.597903,0.044255,-0.10245,0.00188,0.10245,0.00188
bbb_prior0.1_1,bbb_prior0.1_1,0.031245,0.001481,-6.680904,0.020499,-0.005616,0.003754,0.008236,0.001644,0.03565,0.001681,-6.618399,0.017035,0.002831,0.003979,0.009908,0.001207
bbb_prior0.5_1,bbb_prior0.5_1,0.034609,0.001202,-6.737007,0.015326,-0.015129,0.003086,0.015129,0.003087,0.039582,0.001369,-6.631842,0.016928,-0.006364,0.002733,0.007343,0.001914
laplace-1,laplace-1,0.061425,0.002323,-6.909185,0.27071,-0.027876,0.007075,0.036969,0.00679,0.077391,0.000246,-6.624318,0.205962,-0.011021,0.009276,0.033845,0.004239
laplace-5,laplace-5,0.091368,,-6.013787,,0.036116,,0.036116,,0.115385,,-5.736875,,0.060952,,0.060952,
ll_ivon_1,ll_ivon_1,0.002812,0.000191,-7.213479,0.013037,-0.008676,0.000579,0.00868,0.000577,0.002959,0.000271,-7.176245,0.011567,-0.008217,0.000405,0.008221,0.000404
map_1,map_1,0.082815,0.001258,-7.197267,0.148738,-0.262187,0.015402,0.262187,0.015402,0.104649,0.00156,-6.668784,0.120867,-0.232259,0.014672,0.232259,0.014672
map_5,map_5,0.122392,0.000452,-5.676773,0.018626,-0.06104,0.001684,0.071099,0.001494,0.155934,0.000936,-5.211211,0.011806,-0.026425,0.001606,0.065685,0.001497
mcd_1,mcd_1,0.082612,0.001065,-7.502937,0.043494,-0.288195,0.003563,0.288195,0.003563,0.105884,0.001173,-6.924071,0.035038,-0.25692,0.003104,0.25692,0.003104


In [7]:
# o.o.d.: accuracy vs. signed calibration error, with per-model error bars.
pareto_plot(data, x="ood accuracy", y="ood sece")

In [8]:
# i.d.: accuracy vs. signed calibration error, with per-model error bars.
pareto_plot(data, x="id accuracy", y="id sece")

In [9]:
# path_or_buf=None makes to_csv return the CSV text instead of writing a file.
data.to_csv(path_or_buf=None, sep=",", header=True)

'model,model,ood accuracy,ood accuracy_std,ood log likelihood,ood log likelihood_std,ood sece,ood sece_std,ood ece,ood ece_std,id accuracy,id accuracy_std,id log likelihood,id log likelihood_std,id sece,id sece_std,id ece,id ece_std\nbbb_1,bbb_1,0.03841871892412504,0.00046235824352025464,-6.837260961532593,0.011626799455830054,-0.0451375101337201,0.0023704270226309826,0.0451375101337201,0.0023704270226309826,0.04582389506200949,0.000913740403228948,-6.657396713892619,0.01895532059172637,-0.03180355972813914,0.002484118713351236,0.03180355972813914,0.002484118713351236\nbbb_kl0.2_1,bbb_kl0.2_1,0.04611504760881265,0.000589582319979473,-7.95693023999532,0.050442203277153326,-0.11701159709280855,0.0023043171070491787,0.11701159709280855,0.0023043171070491787,0.05419990854958693,0.00111934630261978,-7.59790317217509,0.04425474582900856,-0.10244980697090511,0.0018796998496521603,0.10244980697090511,0.0018796998496521603\nbbb_prior0.1_1,bbb_prior0.1_1,0.03124515898525715,0.0014810028060359508

In [10]:
# (model key, LaTeX row label) pairs in table order. The key is matched against
# data["model"], i.e. the name derived by create_plot_data_for_run.
# Commented-out entries are left out of the table.
algo_names = [
    ("map_1", "MAP"),
    ("map_5", "Deep Ensemble"),
    ("mcd_1", "MCD"),
    ("mcd_5", "MultiMCD"),
    ("swag_1", "SWAG"),
    ("swag_5", "MultiSWAG"),
    # ("swag_ll-1", "LL SWAG"),
    ("laplace-1", "LL Laplace"),
    # ("laplace-5", "LL MultiLaplace"),
    ("bbb_1", "LL BBB ($\\sigma=1.0, \\lambda=1.0$)"),
    ("bbb_prior0.5_1", "LL BBB ($\\sigma=0.5, \\lambda=1.0$)"),
    ("bbb_prior0.1_1", "LL BBB ($\\sigma=0.1, \\lambda=1.0$)"),
    ("bbb_kl0.2_1", "LL BBB ($\\sigma=1.0, \\lambda=0.2$)"),
    # ("bbb_5", "LL MultiBBB"),
    # NOTE(review): the runs listed earlier are named "rank_1-(i)", which
    # create_plot_data_for_run maps to "rank_1", not "rank1_1" -- confirm this key.
    ("rank1_1", "Rank-1 VI"),
    ("ll_ivon_1", "LL iVON"),
    # ("ll_ivon_5", "LL MultiiVON"),
    ("svgd_1", "SVGD"),
    # ("ll_svgd-1", "LL SVGD"),
]

def num(value, std, best=None, ty=None, digits=3):
    r"""Format ``value ± std`` as inline LaTeX math, in bold when it ties the best.

    Args:
        value: the mean; a float-like, a numpy scalar, or a one-element pandas
            Series/array.
        std: the error bar; same accepted types as ``value``.
        best: bolding threshold computed by the caller. If None, never bold.
        ty: how ``best`` is applied -- "max": bold if value >= best;
            "min": bold if value <= best; "zero": bold if |value| <= best.
            None or any other string: never bold.
        digits: number of decimal places (default 3).

    Returns:
        e.g. ``$0.123 \pm 0.004$`` or ``$\bm{0.123 \pm 0.004}$``.
    """
    def to_float(x):
        # float() on a one-element Series is deprecated since pandas 2.1;
        # .item() unwraps Series / numpy scalars and raises unless exactly one element.
        return float(x.item()) if hasattr(x, "item") else float(x)

    value = to_float(value)
    std = to_float(std)
    num_string = f"{value:.{digits}f} \\pm {std:.{digits}f}"

    if best is None or ty is None:
        return f"${num_string}$"

    if ty == "max":
        is_best = value >= best
    elif ty == "min":
        is_best = value <= best
    elif ty == "zero":
        is_best = abs(value) <= best
    else:
        is_best = False

    if is_best:
        num_string = f"\\bm{{{num_string}}}"
    return f"${num_string}$"

def col_name(name, align):
    """Wrap a header label in a one-column LaTeX \\multicolumn with the given alignment."""
    return "\\multicolumn{1}{%s}{%s}" % (align, name)

def create_table(data, prefix):
    """Print a LaTeX ``tabular`` comparing the models of ``algo_names`` found in ``data``.

    Columns are i.d. accuracy / ECE / sECE followed by the o.o.d. ones, rows
    follow the order of ``algo_names`` and models absent from ``data`` are
    skipped. For each column the best model is found (highest accuracy, lowest
    ECE, sECE closest to zero); the best value widened by that model's error
    bar is the threshold ``num`` uses to decide which entries to bold.

    Args:
        data: aggregated frame (see ``aggregate_data``) with a "model" column
            and "<metric>" / "<metric>_std" columns.
        prefix: string prepended to every metric column name when reading ``data``.
    """
    # (metric column, how "best" is chosen), in the order the columns are printed.
    metrics = [
        ("id accuracy", "max"),
        ("id ece", "min"),
        ("id sece", "zero"),
        ("ood accuracy", "max"),
        ("ood ece", "min"),
        ("ood sece", "zero"),
    ]

    def scalar(row, column):
        # `row` is the slice of `data` for one model; .iloc[0] avoids the
        # deprecated float(Series) conversion.
        return float(row[prefix + column].iloc[0])

    print("\\begin{tabular}{l|rrrrrr}")
    print(f"    {col_name('Model', 'l')} & {col_name('i.d. Accuracy', 'c')} & {col_name('i.d. ECE', 'c')} & {col_name('i.d. sECE', 'c')}& {col_name('o.o.d. Accuracy', 'c')} & {col_name('o.o.d. ECE', 'c')} & {col_name('o.o.d. sECE', 'c')} \\\\")
    print("    \\hline")

    # Models to show, in algo_names order; models without results are dropped.
    present = []
    for algo, name in algo_names:
        row = data[data["model"] == algo]
        if not row.empty:
            present.append((name, row))

    # Each metric tracks its OWN running best. Previously the i.d. columns were
    # compared against the o.o.d. bests and o.o.d. sECE was selected via ECE.
    thresholds = {}
    for metric, ty in metrics:
        # Initial sentinels: any accuracy beats 0, any (s)ECE beats 1000.
        best, best_std = (0.0, 0.0) if ty == "max" else (1000.0, 0.0)
        for _, row in present:
            value = scalar(row, metric)
            if ty == "zero":
                value = abs(value)
            improved = value > best if ty == "max" else value < best
            if improved:
                best, best_std = value, scalar(row, metric + "_std")
        # A model with a single run has a NaN error bar; treat it as 0 so the
        # NaN threshold does not silently disable bolding for the whole column.
        if math.isnan(best_std):
            best_std = 0.0
        thresholds[metric] = best - best_std if ty == "max" else best + best_std

    for name, row in present:
        cells = " & ".join(
            num(scalar(row, metric), scalar(row, metric + "_std"), thresholds[metric], ty)
            for metric, ty in metrics
        )
        print(f"    {name} & {cells} \\\\")
    print("\\end{tabular}")
# Print the LaTeX comparison table from the aggregated results (no column prefix).
create_table(data, "")

\begin{tabular}{l|rrrrrr}
    \multicolumn{1}{l}{Model} & \multicolumn{1}{c}{i.d. Accuracy} & \multicolumn{1}{c}{i.d. ECE} & \multicolumn{1}{c}{i.d. sECE}& \multicolumn{1}{c}{o.o.d. Accuracy} & \multicolumn{1}{c}{o.o.d. ECE} & \multicolumn{1}{c}{o.o.d. sECE} \\
    \hline
    MAP & $0.105 \pm 0.002$ & $0.232 \pm 0.015$ & $-0.232 \pm 0.015$ & $0.083 \pm 0.001$ & $0.262 \pm 0.015$ & $-0.262 \pm 0.015$ \\
    Deep Ensemble & $0.156 \pm 0.001$ & $0.066 \pm 0.001$ & $-0.026 \pm 0.002$ & $0.122 \pm 0.000$ & $0.071 \pm 0.001$ & $-0.061 \pm 0.002$ \\
    MCD & $0.106 \pm 0.001$ & $0.257 \pm 0.003$ & $-0.257 \pm 0.003$ & $0.083 \pm 0.001$ & $0.288 \pm 0.004$ & $-0.288 \pm 0.004$ \\
    MultiMCD & $0.158 \pm 0.001$ & $0.069 \pm 0.001$ & $-0.035 \pm 0.001$ & $0.121 \pm 0.000$ & $0.081 \pm 0.001$ & $-0.073 \pm 0.000$ \\
    SWAG & $0.110 \pm 0.001$ & $0.269 \pm 0.009$ & $-0.269 \pm 0.009$ & $0.086 \pm 0.001$ & $0.301 \pm 0.010$ & $-0.301 \pm 0.010$ \\
    MultiSWAG & $\bm{0.161 \pm 0.001}$ & $0.07