In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

DATA_PATH = "../data"

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

from src.log_mock import PrintLog
log = PrintLog()

import wandb

wandb.init(mode="disabled")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [2]:
wapi = wandb.Api()
runs = wapi.runs("foobar/fmow")

In [3]:
for run in runs:
    print(run.name, run.summary.keys())

swag_ll-1-(3) dict_keys(['train_loss', 'val_results', 'test_results', 'id_val_results', 'eval', '_wandb', '_runtime', '_timestamp', '_step'])
swag_ll-1-(2) dict_keys(['_timestamp', 'val_results', 'id_val_results', 'eval', '_step', '_wandb', '_runtime', 'train_loss', 'test_results'])
swag_ll-1-(1) dict_keys(['_step', '_wandb', '_runtime', '_timestamp', 'train_loss', 'test_results', 'eval', 'val_results', 'id_val_results'])
swag_ll-1-(0) dict_keys(['_step', 'val_results', 'test_results', 'eval', '_wandb', '_runtime', '_timestamp', 'train_loss', 'id_val_results'])
swag_ll-1-(2) dict_keys(['_wandb'])
swag_ll-1-(5) dict_keys(['_wandb'])
swag_ll-1-(4) dict_keys(['_wandb'])
swag_ll-1-(0) dict_keys(['_wandb'])
swag_ll-1-(1) dict_keys(['_wandb'])
swag_ll-1-(3) dict_keys(['_wandb'])
ll_ivon-5-(3) dict_keys(['_step', '_wandb', '_runtime', '_timestamp', 'val_results', 'test_results', 'id_val_results'])
ll_ivon-5-(1) dict_keys(['test_results', 'id_val_results', '_step', '_wandb', '_runtime', '_time

In [4]:
import plotly.express as px
import pandas as pd
import dateutil
import datetime

def create_plot_data_for_run(run):
    parts = run.name.split("-")
    if len(parts) > 2:
        model_name = parts[0] + "-" + parts[1]
    else:
        model_name = parts[0]

    worst_acc = 1
    worst_acc_group = "None"
    for name, results in run.summary["test_results"].items():
        if "region" in name and name != "worst_region_acc":
            if results["accuracy"] < worst_acc:
                worst_acc = results["accuracy"]
                worst_acc_group = name

    return {
        "model": model_name,
        "worst_region_acc": run.summary["test_results"]["worst_region_acc"],
        "all accuracy": run.summary["test_results"]["all"]["accuracy"],
        "all log likelihood": run.summary["test_results"]["all"]["log_likelihood"],
        "all ece": run.summary["test_results"]["all"]["ece"],
        "all sece": run.summary["test_results"]["all"]["sece"],
        "worst_acc accuracy": run.summary["test_results"][worst_acc_group]["accuracy"],
        "worst_acc sece": run.summary["test_results"][worst_acc_group]["sece"],
        "worst_acc ece": run.summary["test_results"][worst_acc_group]["ece"],
        "worst_acc log_likelihood": run.summary["test_results"][worst_acc_group]["log_likelihood"]
    }

def plot(data, value):
    plot = px.box(data, x="model", y=value, color="model")
    return plot

def pareto_plot(data, x, y):
    plot = px.scatter(data, x=x, error_x=f"{x}_std", y=y, error_y=f"{y}_std", color="model")
    return plot

def build_data(runs):
    rows = []
    for run in runs:
        if dateutil.parser.parse(run.created_at) < datetime.datetime(2023, 3, 10, 10, 0):
            continue
        if run.state != "finished":
            continue
        if "old" in run.tags:
            print("Skipping old run " + run.name)
            continue
        if "test_results" not in run.summary:
            print("Skipping crashed run " + run.name)
            continue
        rows.append(create_plot_data_for_run(run))
    return pd.DataFrame.from_dict(rows)

def aggregate_data(data):
    aggregated_data = data.groupby(["model"]).agg({
        "model": "first",
        "worst_region_acc": ["mean", "sem"],
        "all accuracy": ["mean", "sem"],
        "all log likelihood": ["mean", "sem"], 
        "all sece": ["mean", "sem"],
        "all ece": ["mean", "sem"],
        "worst_acc accuracy": ["mean", "sem"],
        "worst_acc sece": ["mean", "sem"],
        "worst_acc ece": ["mean", "sem"],
        "worst_acc log_likelihood": ["mean", "sem"],
    })
    aggregated_data.columns = [a[0] + "_std" if a[1] == "sem" else a[0] for a in aggregated_data.columns.to_flat_index()]
    aggregated_data["worst_region_acc_std"] *= 2.0
    aggregated_data["all accuracy_std"] *= 2.0
    aggregated_data["all log likelihood_std"] *= 2.0
    aggregated_data["all sece_std"] *= 2.0
    aggregated_data["all ece_std"] *= 2.0
    aggregated_data["worst_acc accuracy_std"] *= 2.0
    aggregated_data["worst_acc sece_std"] *= 2.0
    aggregated_data["worst_acc ece_std"] *= 2.0
    aggregated_data["worst_acc log_likelihood_std"] *= 2.0
    return aggregated_data

In [5]:
data = aggregate_data(build_data(runs))

Skipping crashed run swag_ll-1-(2)
Skipping crashed run swag_ll-1-(5)
Skipping crashed run swag_ll-1-(4)
Skipping crashed run swag_ll-1-(0)
Skipping crashed run swag_ll-1-(1)
Skipping crashed run swag_ll-1-(3)
Skipping crashed run laplace-5-0
Skipping old run laplace-1-(0)
Skipping old run laplace-1-(5)
Skipping old run laplace-1-(4)
Skipping old run laplace-1-(3)
Skipping old run laplace-1-(0)
Skipping old run laplace-1-(1)
Skipping old run laplace-1-(2)
Skipping old run swag-1-(5)
Skipping old run swag-1-(4)
Skipping old run swag-1-(3)
Skipping old run swag-1-(2)
Skipping old run swag-1-(1)
Skipping old run swag-1-(0)
Skipping old run swag_p-1-(0)
Skipping old run bbb-1-(5)
Skipping old run bbb-1-(4)
Skipping old run bbb-1-(3)
Skipping old run bbb-1-(2)
Skipping old run bbb-1-(1)
Skipping old run bbb-1-(0)
Skipping old run mcd_p0.1-1-(5)
Skipping old run mcd_p0.1-1-(4)
Skipping old run mcd_p0.1-1-(2)
Skipping old run mcd_p0.1-1-(3)
Skipping old run mcd_p0.1-1-(0)
Skipping old run map

In [6]:
data

Unnamed: 0_level_0,model,worst_region_acc,worst_region_acc_std,all accuracy,all accuracy_std,all log likelihood,all log likelihood_std,all sece,all sece_std,all ece,all ece_std,worst_acc accuracy,worst_acc accuracy_std,worst_acc sece,worst_acc sece_std,worst_acc ece,worst_acc ece_std,worst_acc log_likelihood,worst_acc log_likelihood_std
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
bbb-1,bbb-1,0.305502,0.007895,0.508654,0.003198,-4.250698,0.052966,-0.292621,0.00266,0.292621,0.00266,0.305502,0.007895,-0.44776,0.00966,0.447876,0.009653,-6.673707,0.342666
bbb-5,bbb-5,0.339452,0.006172,0.561019,0.00131,-2.617455,0.007514,-0.101704,0.001357,0.101704,0.001357,0.339452,0.006172,-0.233121,0.007834,0.233121,0.007834,-4.174393,0.084816
laplace-1,laplace-1,0.217123,0.012019,0.373952,0.013472,-5.872745,0.300547,-0.445054,0.012131,0.44506,0.01213,0.217123,0.012019,-0.583118,0.014738,0.583223,0.014781,-8.141227,0.415136
laplace-5,laplace-5,0.301118,0.004135,0.51689,0.002044,-2.743831,0.017199,0.020211,0.00166,0.059305,0.002124,0.301118,0.004135,-0.123035,0.004166,0.123035,0.004166,-4.085591,0.046997
ll_ivon-1,ll_ivon-1,0.300296,0.008917,0.505262,0.002679,-3.106737,0.023407,-0.348323,0.002462,0.348331,0.002453,0.300296,0.008917,-0.514225,0.00878,0.514225,0.00878,-4.556923,0.112206
ll_ivon-5,ll_ivon-5,0.340609,0.004493,0.559662,0.001205,-2.060377,0.009235,-0.112013,0.001672,0.112013,0.001672,0.340609,0.004493,-0.241298,0.00398,0.24142,0.004038,-3.176987,0.022844
ll_ivon_p100-1,ll_ivon_p100-1,0.292325,,0.506649,,-3.4387,,-0.342936,,0.342936,,0.292325,,-0.513999,,0.513999,,-5.133786,
ll_ivon_p500-1,ll_ivon_p500-1,0.313536,,0.512529,,-3.088313,,-0.341564,,0.341599,,0.313536,,-0.51399,,0.51399,,-4.645092,
map-1,map-1,0.310066,0.00782,0.517882,0.003051,-3.502953,0.025069,-0.352619,0.002421,0.352622,0.00242,0.310066,0.00782,-0.526066,0.008797,0.526229,0.00871,-5.439209,0.116967
map-5,map-5,0.342383,0.003366,0.569341,0.001367,-2.14051,0.005768,-0.128129,0.00103,0.128129,0.00103,0.342383,0.003366,-0.270757,0.004468,0.270757,0.004468,-3.446131,0.007291


In [7]:
pareto_plot(data, "worst_acc accuracy", "worst_acc sece")

In [8]:
pareto_plot(data, "all accuracy", "all sece")

In [9]:
data.to_csv(sep=",", header=True)

'model,model,worst_region_acc,worst_region_acc_std,all accuracy,all accuracy_std,all log likelihood,all log likelihood_std,all sece,all sece_std,all ece,all ece_std,worst_acc accuracy,worst_acc accuracy_std,worst_acc sece,worst_acc sece_std,worst_acc ece,worst_acc ece_std,worst_acc log_likelihood,worst_acc log_likelihood_std\nbbb-1,bbb-1,0.305501992503802,0.007894554664193006,0.5086544851462046,0.003197545213125401,-4.2506983280181885,0.05296627858945722,-0.29262070674284074,0.002660070441326867,0.29262070674284074,0.002660070441326867,0.305501992503802,0.007894554664193006,-0.44776048123221623,0.009660342831999697,0.4478758809064889,0.009653120064726189,-6.673706690470378,0.342666161857036\nbbb-5,bbb-5,0.3394523739814758,0.0061723829949812715,0.5610186457633972,0.0013100246226058216,-2.6174545764923094,0.007514391558887915,-0.10170435984707427,0.0013571882710110338,0.10170435984707427,0.0013571882710110338,0.3394523739814758,0.0061723829949812715,-0.23312052044567802,0.007833556076787

In [59]:
algo_names = [
    ("map-1", "MAP"),
    ("map-5", "Deep Ensemble"),
    ("mcd_p0.1-1", "MCD"),
    ("mcd-5", "MultiMCD"),
    ("swag-1", "SWAG"),
    ("swag-5", "MultiSWAG"),
    ("swag_ll-1", "LL SWAG"),
    ("laplace-1", "LL Laplace"),
    ("laplace-5", "LL MultiLaplace"),
    ("bbb-1", "LL BBB"),
    ("bbb-5", "LL MultiBBB"),
    ("rank1-1", "Rank-1 VI"),
    ("ll_ivon-1", "LL iVON"),
    ("ll_ivon-5", "LL MultiiVON"),
    ("svgd-1", "SVGD"),
]

def num(value, std, best=None, ty=None):
    value = float(value)
    std = float(std)
    num_string = f"{value:.3f} \\pm {std:.3f}"

    if best is None or ty is None:
        return f"${num_string}$"

    if ty == "max":
        if value >= best:
            num_string = f"\\bm{{{num_string}}}"
    elif ty == "min":
        if value <= best:
            num_string = f"\\bm{{{num_string}}}"
    elif ty == "zero":
        if abs(value) <= best:
            num_string = f"\\bm{{{num_string}}}"
    return f"${num_string}$"

def col_name(name, align):
    return f"\\multicolumn{{1}}{{{align}}}{{{name}}}"

def create_table(data, prefix):
    print("\\begin{tabular}{l|rrrrrr}")
    print(f"    {col_name('Model', 'l')} & {col_name('WR Accuracy', 'c')} & {col_name('WR ECE', 'c')} & {col_name('WR sECE', 'c')} & {col_name('Avg Accuracy', 'c')} & {col_name('Avg ECE', 'c')} & {col_name('Avg sECE', 'c')} \\\\")
    print("    \\hline")

    best_acc, best_acc_std = 0, 0
    best_ece, best_ece_std = 1000, 0
    best_sece, best_sece_std = 1000, 0
    best_avg_acc, best_avg_acc_std = 0, 0
    best_avg_ece, best_avg_ece_std = 1000, 0
    best_avg_sece, best_avg_sece_std = 1000, 0

    for algo, name in algo_names:
        row = data[data["model"] == algo]

        if float(row[prefix + "worst_acc accuracy"]) > best_acc:
            best_acc = float(row[prefix + "worst_acc accuracy"])
            best_acc_std = float(row[prefix + "worst_acc accuracy_std"])
        
        if float(row[prefix + "worst_acc ece"]) < best_ece:
            best_ece = float(row[prefix + "worst_acc ece"])
            best_ece_std = float(row[prefix + "worst_acc ece_std"])
        
        if abs(float(row[prefix + "worst_acc sece"])) < best_sece:
            best_sece = abs(float(row[prefix + "worst_acc sece"]))
            best_sece_std = float(row[prefix + "worst_acc sece_std"])
        
        if float(row[prefix + "all accuracy"]) > best_avg_acc:
            best_avg_acc = float(row[prefix + "all accuracy"])
            best_avg_acc_std = float(row[prefix + "all accuracy_std"])
        
        if float(row[prefix + "all ece"]) < best_avg_ece:
            best_avg_ece = float(row[prefix + "all ece"])
            best_avg_ece_std = float(row[prefix + "all ece_std"])
        
        if abs(float(row[prefix + "all sece"])) < best_avg_sece:
            best_avg_sece = abs(float(row[prefix + "all sece"]))
            best_avg_sece_std = float(row[prefix + "all sece_std"])

    best_acc -= best_acc_std
    best_ece += best_ece_std
    best_sece = abs(best_sece) + best_sece_std

    best_avg_acc -= best_avg_acc_std
    best_avg_ece += best_avg_ece_std
    best_avg_sece = abs(best_avg_sece) + best_avg_sece_std

    for algo, name in algo_names:
        row = data[data["model"] == algo]
        print(f"    {name} & {num(row[prefix + 'worst_acc accuracy'], row[prefix + 'worst_acc accuracy_std'], best_acc, 'max')} & {num(row[prefix + 'worst_acc ece'], row[prefix + 'worst_acc ece_std'], best_ece, 'min')} & {num(row[prefix + 'worst_acc sece'], row[prefix + 'worst_acc sece_std'], best_sece, 'zero')} & {num(row[prefix + 'all accuracy'], row[prefix + 'all accuracy_std'], best_avg_acc, 'max')} & {num(row[prefix + 'all ece'], row[prefix + 'all ece_std'], best_avg_ece, 'min')} & {num(row[prefix + 'all sece'], row[prefix + 'all sece_std'], best_avg_sece, 'zero')} \\\\")
    print("\\end{tabular}")
create_table(data, "")

\begin{tabular}{l|rrrrrr}
    \multicolumn{1}{l}{Model} & \multicolumn{1}{c}{WR Accuracy} & \multicolumn{1}{c}{WR ECE} & \multicolumn{1}{c}{WR sECE} & \multicolumn{1}{c}{Avg Accuracy} & \multicolumn{1}{c}{Avg ECE} & \multicolumn{1}{c}{Avg sECE} \\
    \hline
    MAP & $0.310 \pm 0.008$ & $0.526 \pm 0.009$ & $-0.526 \pm 0.009$ & $0.518 \pm 0.003$ & $0.353 \pm 0.002$ & $-0.353 \pm 0.002$ \\
    Deep Ensemble & $0.342 \pm 0.003$ & $0.271 \pm 0.004$ & $-0.271 \pm 0.004$ & $0.569 \pm 0.001$ & $0.128 \pm 0.001$ & $-0.128 \pm 0.001$ \\
    MCD & $0.307 \pm 0.009$ & $0.520 \pm 0.011$ & $-0.520 \pm 0.011$ & $0.515 \pm 0.002$ & $0.349 \pm 0.004$ & $-0.349 \pm 0.004$ \\
    MultiMCD & $\bm{0.353 \pm 0.005}$ & $0.253 \pm 0.005$ & $-0.253 \pm 0.005$ & $\bm{0.571 \pm 0.000}$ & $0.122 \pm 0.001$ & $-0.122 \pm 0.001$ \\
    SWAG & $0.308 \pm 0.009$ & $0.501 \pm 0.007$ & $-0.500 \pm 0.007$ & $0.520 \pm 0.003$ & $0.327 \pm 0.003$ & $-0.327 \pm 0.003$ \\
    MultiSWAG & $0.338 \pm 0.003$ & $0.270 \pm 0.0