In [2]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
DATA_PATH = "../../data/"

# Prefer the GPU when PyTorch reports one as available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Project-local logger; importable because "../../" was appended to sys.path above.
# NOTE(review): PrintLog is presumably a stand-in that prints instead of logging remotely - confirm in src/log_mock.
from src.log_mock import PrintLog
log = PrintLog()

# NOTE(review): mode="disabled" presumably makes run logging a no-op, so this notebook only
# reads results through wandb.Api and never records a run of its own - confirm against the W&B docs.
import wandb
wandb.init(mode="disabled")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [2]:
# Client for querying runs through the W&B API (timeout=60; presumably seconds - confirm).
wapi = wandb.Api(timeout=60)
# Runs stored under the "foobar/amazon" path (presumably "<entity>/<project>" - confirm).
runs = wapi.runs("foobar/amazon")

In [3]:
# Quick inspection: print every run's name together with the keys present in its summary.
for run in runs:
    print(run.name, run.summary.keys())

svgd-0 dict_keys(['_step', '_timestamp', '10th_percentile_acc', 'ece', 'ood_test_results', '_wandb', '_runtime', 'accuracy', 'train_loss', 'bin_accuracies', 'log_likelihood', 'bin_confidences', 'sece', 'id_test_results', 'bin_counts'])
svgd-1 dict_keys(['bin_counts', '_timestamp', '_step', 'log_likelihood', 'ood_test_results', '10th_percentile_acc', 'ece', 'bin_accuracies', 'bin_confidences', 'sece', '_runtime', 'accuracy', 'train_loss', 'id_test_results', '_wandb'])
svgd-3 dict_keys(['log_likelihood', 'ood_test_results', '_step', '_wandb', '_runtime', 'accuracy', 'bin_counts', 'sece', 'train_loss', 'bin_confidences', 'id_test_results', '10th_percentile_acc', 'ece', '_timestamp', 'bin_accuracies'])
svgd-4 dict_keys(['sece', 'id_test_results', '10th_percentile_acc', 'ood_test_results', '_step', '_timestamp', 'bin_counts', 'log_likelihood', 'bin_confidences', '_wandb', '_runtime', 'train_loss', 'ece', 'accuracy', 'bin_accuracies'])
svgd-2 dict_keys(['log_likelihood', 'bin_confidences', '

In [4]:
def create_reliability_plot(results):
    """Print a TikZ/pgfplots reliability diagram for one set of calibration results.

    Args:
        results: Mapping providing "sece" and "ece" (numbers) plus the parallel
            sequences "bin_confidences", "bin_accuracies" and "bin_counts".

    The sECE/ECE values are printed first, followed by the LaTeX source. Bins
    whose confidence is not strictly positive are left out of the diagram. The
    emitted code refers to the pgfplots styles ``calstyle`` and ``calline``,
    which are not defined here (presumably they live in the LaTeX preamble).
    """
    print(f"sECE: {results['sece']:.4f}, ECE: {results['ece']:.4f}")

    # Keep (confidence, accuracy, count) triples of bins with positive confidence only.
    kept_bins = [
        (conf, acc, count)
        for conf, acc, count in zip(
            results["bin_confidences"], results["bin_accuracies"], results["bin_counts"]
        )
        if conf > 0
    ]
    curve_points = " ".join(f"({conf}, {acc})" for conf, acc, _ in kept_bins)

    print("\\begin{tikzpicture}")
    print("    \\begin{axis}[calstyle, xmin=0, xmax=1, ymin=0, ymax=1]")
    # Diagonal = perfect calibration, drawn as a reference line.
    print("        \\addplot[dashed, color=black] coordinates {(0,0) (1,1)};")
    print("        \\addplot[calline] coordinates {" + curve_points + "};")
    for conf, acc, count in kept_bins:
        # Sample count as a label at the top edge, joined to its point by a dotted line.
        print(f"        \\node[above, anchor=south west, rotate=60, font=\\tiny] at (axis cs:{conf}, 1.0) {{{count}}};")
        print(f"        \\draw[dotted, color=black] (axis cs:{conf}, {acc}) -- (axis cs:{conf}, 1.0);")
    print("    \\end{axis}")
    print("\\end{tikzpicture}")

In [5]:
import plotly.express as px
import pandas as pd
import dateutil
import datetime

def create_plot_data_for_run(run):
    """Turn one run into a flat metrics record and print its o.o.d. reliability plot.

    The model name is derived from ``run.name`` split on "-": names with more
    than two parts keep the first two parts joined by "-", all other names keep
    only the first part.

    Args:
        run: Object exposing ``name`` and a ``summary`` mapping that contains
            "ood_test_results" and "id_test_results".

    Returns:
        dict with the key "model" plus accuracy, 10th-percentile accuracy, log
        likelihood, ECE and sECE for the "ood" and "id" results.
    """
    name_parts = run.name.split("-")
    model_name = name_parts[0] + "-" + name_parts[1] if len(name_parts) > 2 else name_parts[0]

    print(f"Including {model_name}")
    ood_results = run.summary["ood_test_results"]
    create_reliability_plot(ood_results)

    id_results = run.summary["id_test_results"]
    record = {
        "model": model_name,
        "ood accuracy": ood_results["accuracy"],
        "ood 10th_percentile_acc": ood_results["10th_percentile_acc"],
        "ood log likelihood": ood_results["log_likelihood"],
        "ood ece": ood_results["ece"],
        "ood sece": ood_results["sece"],
        "id accuracy": id_results["accuracy"],
        "id 10th_percentile_acc": id_results["10th_percentile_acc"],
        "id log likelihood": id_results["log_likelihood"],
        "id ece": id_results["ece"],
        "id sece": id_results["sece"],
    }
    return record

def plot(data, value):
    """Return a Plotly box plot of column ``value`` grouped and coloured by "model"."""
    return px.box(data, x="model", y=value, color="model")

def pareto_plot(data, x, y):
    """Return a Plotly scatter of column ``y`` against column ``x``, coloured by "model".

    Error bars are read from the columns named ``f"{x}_std"`` and ``f"{y}_std"``.
    """
    error_columns = {"error_x": f"{x}_std", "error_y": f"{y}_std"}
    return px.scatter(data, x=x, y=y, color="model", **error_columns)

def build_data(runs, cutoff=datetime.datetime(2023, 4, 27, 0, 0)):
    """Collect one metrics record per usable run into a DataFrame.

    A run is skipped when it was created before ``cutoff``, when its state is
    not "finished", when it carries the "old" tag, or when its summary has no
    "ood_test_results" entry. Every remaining run is converted with
    ``create_plot_data_for_run`` (which also prints its reliability plot).

    Args:
        runs: Iterable of run objects exposing ``created_at`` (a timestamp
            string), ``state``, ``tags``, ``name`` and ``summary``.
        cutoff: Runs created strictly before this datetime are ignored.
            Defaults to the previously hard-coded 2023-04-27 00:00.

    Returns:
        pandas.DataFrame with one row per included run (empty if none qualify).
    """
    # Import the submodule explicitly: a bare `import dateutil` is not guaranteed to
    # expose `dateutil.parser` on every python-dateutil release.
    import dateutil.parser

    def as_naive_utc(moment):
        # Timezone-aware and naive datetimes cannot be compared, so aware values are
        # normalised to naive UTC. NOTE(review): this treats naive timestamps as UTC - confirm.
        if moment.tzinfo is None:
            return moment
        return moment.astimezone(datetime.timezone.utc).replace(tzinfo=None)

    cutoff = as_naive_utc(cutoff)

    rows = []
    for run in runs:
        if as_naive_utc(dateutil.parser.parse(run.created_at)) < cutoff:
            print("Skipping run " + run.name + " because it is older than the cutoff time")
            continue
        # Unfinished runs are dropped silently, as before.
        if run.state != "finished":
            continue
        if "old" in run.tags:
            print("Skipping old run " + run.name)
            continue
        if "ood_test_results" not in run.summary:
            print("Skipping crashed run " + run.name)
            continue
        rows.append(create_plot_data_for_run(run))
    return pd.DataFrame(rows)

def aggregate_data(data):
    """Collapse per-run records into one row per model.

    For every metric column the result holds the per-model mean under the
    original column name and, under ``"<column>_std"``, twice the standard
    error of the mean. The "model" value is kept as a regular column in
    addition to being the group index.

    Args:
        data: DataFrame with a "model" column and the ood/id metric columns
            listed below.

    Returns:
        DataFrame indexed by model with flat column names.
    """
    metric_columns = [
        "ood 10th_percentile_acc",
        "ood accuracy",
        "ood log likelihood",
        "ood sece",
        "ood ece",
        "id 10th_percentile_acc",
        "id accuracy",
        "id log likelihood",
        "id sece",
        "id ece",
    ]

    spec = {"model": "first"}
    for column in metric_columns:
        spec[column] = ["mean", "sem"]
    summary = data.groupby(["model"]).agg(spec)

    # Flatten the (column, statistic) pairs: "sem" columns get the "_std" suffix,
    # every other statistic keeps the plain column name.
    flat_names = []
    for column, statistic in summary.columns.to_flat_index():
        flat_names.append(column + "_std" if statistic == "sem" else column)
    summary.columns = flat_names

    # Widen every error column from one to two standard errors.
    for column in metric_columns:
        summary[column + "_std"] *= 2.0
    return summary

In [6]:
# Build the per-run table (this also prints a TikZ reliability diagram for every included run)
# and collapse it to one row per model.
data = aggregate_data(build_data(runs))

Including svgd
sECE: -0.0549, ECE: 0.0549
\begin{tikzpicture}
    \begin{axis}[calstyle, xmin=0, xmax=1, ymin=0, ymax=1]
        \addplot[dashed, color=black] coordinates {(0,0) (1,1)};
        \addplot[calline] coordinates {(0.27743783593177795, 0.2384105920791626) (0.35889512300491333, 0.32244008779525757) (0.46239572763442993, 0.40612584352493286) (0.5498671531677246, 0.4888852536678314) (0.6498734951019287, 0.5814056992530823) (0.7493167519569397, 0.6803209781646729) (0.8529765605926514, 0.8019340634346008) (0.9401679635047911, 0.9101652503013612)};
        \node[above, anchor=south west, rotate=60, font=\tiny] at (axis cs:0.27743783593177795, 1.0) {453};
        \draw[dotted, color=black] (axis cs:0.27743783593177795, 0.2384105920791626) -- (axis cs:0.27743783593177795, 1.0);
        \node[above, anchor=south west, rotate=60, font=\tiny] at (axis cs:0.35889512300491333, 1.0) {2295};
        \draw[dotted, color=black] (axis cs:0.35889512300491333, 0.32244008779525757) -- (axis cs:0

In [7]:
# Display the aggregated per-model table.
data

Unnamed: 0_level_0,model,ood 10th_percentile_acc,ood 10th_percentile_acc_std,ood accuracy,ood accuracy_std,ood log likelihood,ood log likelihood_std,ood sece,ood sece_std,ood ece,...,id 10th_percentile_acc,id 10th_percentile_acc_std,id accuracy,id accuracy_std,id log likelihood,id log likelihood_std,id sece,id sece_std,id ece,id ece_std
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
bbb,bbb,0.526667,0.005963,0.695284,0.00743,-0.897627,0.018957,-0.154253,0.005337,0.154268,...,0.56,0.0,0.730401,0.006394,-0.778404,0.01473,-0.127717,0.003822,0.127719,0.003823
bbb_5,bbb_5,0.533333,0.0,0.709237,0.001719,-0.747748,0.001311,-0.105061,0.001206,0.105078,...,0.573333,0.0,0.746454,0.001248,-0.647611,0.001506,-0.078813,0.001608,0.078843,0.001599
laplace_1,laplace_1,0.455,0.009344,0.653856,0.002768,-0.816272,0.006181,-0.067162,0.006139,0.067172,...,0.481667,0.009344,0.677654,0.001722,-0.756493,0.004184,-0.048184,0.006952,0.048225,0.006972
laplace_5,laplace_5,0.453333,0.0,0.658836,0.000594,-0.800253,0.001207,-0.058317,0.001397,0.058317,...,0.48,0.0,0.681944,0.000324,-0.741925,0.000799,-0.040416,0.001691,0.040422,0.001695
ll_ivon,ll_ivon,0.458444,0.010452,0.661316,0.002213,-0.793888,0.008097,-0.053443,0.010477,0.053455,...,0.484444,0.008889,0.683862,0.002238,-0.736763,0.005636,-0.037241,0.010398,0.037436,0.010355
ll_ivon_5,ll_ivon_5,0.458667,0.006532,0.66474,0.000736,-0.779269,0.001075,-0.044756,0.002133,0.044756,...,0.484,0.005333,0.687353,0.000992,-0.72378,0.000549,-0.028628,0.002169,0.029258,0.002528
ll_swag,ll_swag,0.451778,0.011734,0.65569,0.002916,-0.801564,0.006417,-0.047702,0.008718,0.047731,...,0.474444,0.009988,0.678619,0.001914,-0.746732,0.004581,-0.02987,0.009171,0.031185,0.008184
map,map,0.453333,0.009737,0.655169,0.003163,-0.814966,0.007054,-0.067469,0.006492,0.067495,...,0.476667,0.008255,0.678346,0.001974,-0.755257,0.005475,-0.049262,0.007244,0.049316,0.007245
map_5,map_5,0.453333,0.0,0.658803,0.000723,-0.800404,0.001484,-0.058439,0.001605,0.058439,...,0.48,0.0,0.682126,0.000439,-0.741926,0.001058,-0.040308,0.001994,0.040308,0.001994
mcd,mcd,0.446667,0.013333,0.656633,0.002001,-0.788582,0.003969,-0.019232,0.012343,0.020422,...,0.472222,0.012131,0.678289,0.001208,-0.740641,0.002757,-0.002382,0.013184,0.014871,0.00628


In [8]:
# o.o.d. 10th-percentile accuracy against o.o.d. sECE, one point per model.
pareto_plot(data, "ood 10th_percentile_acc", "ood sece")

In [9]:
# o.o.d. accuracy against o.o.d. sECE, one point per model.
pareto_plot(data, "ood accuracy", "ood sece")

In [10]:
# i.d. 10th-percentile accuracy against i.d. sECE, one point per model.
pareto_plot(data, "id 10th_percentile_acc", "id sece")

In [11]:
# No path is given, so to_csv returns the CSV text and it is shown as the cell output.
data.to_csv(sep=",", header=True)

'model,model,ood 10th_percentile_acc,ood 10th_percentile_acc_std,ood accuracy,ood accuracy_std,ood log likelihood,ood log likelihood_std,ood sece,ood sece_std,ood ece,ood ece_std,id 10th_percentile_acc,id 10th_percentile_acc_std,id accuracy,id accuracy_std,id log likelihood,id log likelihood_std,id sece,id sece_std,id ece,id ece_std\nbbb,bbb,0.526666671037674,0.005962868909391993,0.6952840288480123,0.0074301630236107774,-0.8976268370946249,0.018957477938627728,-0.1542534022557225,0.0053372576004840945,0.15426780016148625,0.005334186039254019,0.5600000023841858,0.0,0.7304011285305023,0.006394461072875588,-0.778404027223587,0.014730367097626311,-0.12771731354798732,0.0038216606073879565,0.127718534245611,0.0038234245204340936\nbbb_5,bbb_5,0.5333333611488342,0.0,0.7092373847961426,0.0017189139319656723,-0.7477477192878723,0.0013106337875799705,-0.10506073958474,0.0012057160544015242,0.10507805465207228,0.0012006995731298447,0.5733333230018616,0.0,0.7464536786079407,0.0012479995936486932,-

In [13]:
# (model identifier, display name) pairs. create_table matches the identifier against the
# "model" column and prints the display name; rows are emitted in the order listed here.
algo_names = [
    ("map", "MAP"),
    ("map_5", "Deep Ensemble"),
    ("mcd", "MCD"),
    ("mcd_5", "MultiMCD"),
    ("mcd_ll", "LL MCD"),
    ("swag", "SWAG"),
    ("swag_5", "MultiSWAG"),
    ("ll_swag", "LL SWAG"),
    ("laplace_1", "LL Laplace"),
    ("laplace_5", "LL MultiLaplace"),
    ("bbb", "LL BBB"),
    ("bbb_5", "LL MultiBBB"),
    ("rank1", "Rank-1 VI"),
    ("ll_ivon", "LL iVON"),
    ("ll_ivon_5", "LL MultiiVON"),
    ("svgd", "SVGD"),
]

def num(value, std, best=None, ty=None):
    """Format ``value ± std`` as inline LaTeX math, bold when it meets the threshold.

    Args:
        value: Number (anything ``float()`` accepts) shown with three decimals.
        std: Uncertainty shown after ``\\pm``, also with three decimals.
        best: Threshold to compare against; no highlighting when ``None``.
        ty: Comparison rule: "max" bolds when ``value >= best``, "min" when
            ``value <= best``, "zero" when ``abs(value) <= best``. ``None`` or
            any other value never bolds.

    Returns:
        The string ``$...$``, with the content wrapped in ``\\bm{...}`` when the
        rule is satisfied.
    """
    value, std = float(value), float(std)
    body = f"{value:.3f} \\pm {std:.3f}"

    if best is not None and ty is not None:
        # One predicate per supported rule; unknown rules fall through unhighlighted.
        rules = {
            "max": lambda: value >= best,
            "min": lambda: value <= best,
            "zero": lambda: abs(value) <= best,
        }
        meets_threshold = rules.get(ty)
        if meets_threshold is not None and meets_threshold():
            body = f"\\bm{{{body}}}"

    return f"${body}$"

def col_name(name, align):
    """Return a one-column ``\\multicolumn`` header cell showing ``name`` with alignment ``align``."""
    header_cell = f"\\multicolumn{{1}}{{{align}}}{{{name}}}"
    return header_cell

def create_table(data, prefix):
    """Print a LaTeX ``tabular`` with one row per entry of ``algo_names``.

    For every metric the best model is determined ("max": largest value,
    "min": smallest value, "zero": smallest absolute value). Every model whose
    value lies within the best model's error margin of that best value is
    printed in bold via ``num``.

    Args:
        data: Aggregated DataFrame with a "model" column and, for each metric,
            the columns ``prefix + metric`` and ``prefix + metric + "_std"``.
        prefix: String prepended to every metric column name.

    Models listed in ``algo_names`` but absent from ``data`` are left out of the
    table and reported as a LaTeX comment line instead of raising.
    """
    # (column name without prefix, header text, highlight rule understood by num()).
    metrics = [
        ("ood 10th_percentile_acc", "o.o.d. 10 Accuracy", "max"),
        ("ood accuracy", "o.o.d. Accuracy", "max"),
        ("ood ece", "o.o.d. ECE", "min"),
        ("ood sece", "o.o.d. sECE", "zero"),
        ("id 10th_percentile_acc", "i.d. 10 Accuracy", "max"),
        ("id accuracy", "i.d. Avg Accuracy", "max"),
        ("id ece", "i.d. Avg ECE", "min"),
        ("id sece", "i.d. Avg sECE", "zero"),
    ]

    def scalar(frame, column):
        # float(Series) on a one-element Series is deprecated in pandas >= 2.1,
        # so the single value is picked explicitly.
        return float(frame[column].iloc[0])

    # Resolve every algorithm once; None marks algorithms without aggregated results.
    entries = []
    for algo, name in algo_names:
        match = data[data["model"] == algo]
        entries.append((name, match if len(match) > 0 else None))

    print("\\begin{tabular}{l|" + "r" * len(metrics) + "}")
    header_cells = [col_name("Model", "l")] + [col_name(title, "c") for _, title, _ in metrics]
    print("    " + " & ".join(header_cells) + " \\\\")
    print("    \\hline")

    # Per metric: the value a model must reach (or stay below) to be printed in bold.
    thresholds = {}
    for column, _, rule in metrics:
        best, best_std = None, 0.0
        for _, row in entries:
            if row is None:
                continue
            value = scalar(row, prefix + column)
            score = abs(value) if rule == "zero" else value
            if math.isnan(score):
                continue
            # Starting from None instead of fixed sentinels keeps this correct for any value range.
            if best is None or (score > best if rule == "max" else score < best):
                best = score
                best_std = scalar(row, prefix + column + "_std")
        if best is None:
            thresholds[column] = None
            continue
        # A NaN error would turn the threshold into NaN and suppress all bolding,
        # so it is treated as a zero margin.
        margin = 0.0 if math.isnan(best_std) else best_std
        thresholds[column] = best - margin if rule == "max" else best + margin

    for name, row in entries:
        if row is None:
            print(f"    % {name}: no aggregated results found, row omitted")
            continue
        cells = [
            num(
                scalar(row, prefix + column),
                scalar(row, prefix + column + "_std"),
                thresholds[column],
                rule,
            )
            for column, _, rule in metrics
        ]
        print("    " + " & ".join([name] + cells) + " \\\\")
    print("\\end{tabular}")
# Render the table from the aggregated data; the empty prefix selects the plain column names.
create_table(data, "")

\begin{tabular}{l|rrrrrrrr}
    \multicolumn{1}{l}{Model} & \multicolumn{1}{c}{o.o.d. 10 Accuracy} & \multicolumn{1}{c}{o.o.d. Accuracy} & \multicolumn{1}{c}{o.o.d. ECE} & \multicolumn{1}{c}{o.o.d. sECE} & \multicolumn{1}{c}{i.d. 10 Accuracy} & \multicolumn{1}{c}{i.d. Avg Accuracy} & \multicolumn{1}{c}{i.d. Avg ECE} & \multicolumn{1}{c}{i.d. Avg sECE} \\
    \hline
    MAP & $0.453 \pm 0.010$ & $0.655 \pm 0.003$ & $0.067 \pm 0.006$ & $-0.067 \pm 0.006$ & $0.477 \pm 0.008$ & $0.678 \pm 0.002$ & $0.049 \pm 0.007$ & $-0.049 \pm 0.007$ \\
    Deep Ensemble & $0.453 \pm 0.000$ & $0.659 \pm 0.001$ & $0.058 \pm 0.002$ & $-0.058 \pm 0.002$ & $0.480 \pm 0.000$ & $0.682 \pm 0.000$ & $0.040 \pm 0.002$ & $-0.040 \pm 0.002$ \\
    MCD & $0.447 \pm 0.013$ & $0.657 \pm 0.002$ & $0.020 \pm 0.011$ & $-0.019 \pm 0.012$ & $0.472 \pm 0.012$ & $0.678 \pm 0.001$ & $0.015 \pm 0.006$ & $\bm{-0.002 \pm 0.013}$ \\
    MultiMCD & $0.451 \pm 0.005$ & $0.660 \pm 0.001$ & $\bm{0.012 \pm 0.003}$ & $\bm{-0.012 \pm 0.