In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

DATA_PATH = "../data"

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

from src.log_mock import PrintLog
log = PrintLog()

import wandb

wandb.init(mode="disabled")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [2]:
# wandb "entity/project" path whose runs are analysed in this notebook.
WANDB_PROJECT = "foobar/iwildcam"

wapi = wandb.Api()
runs = wapi.runs(WANDB_PROJECT)

In [3]:
# Print every run with its position in `runs`, so a specific run can be picked
# by index. Note: the index depends on the order wandb returns runs in.
for i, run in enumerate(runs):
    print(i, run.name)

0 swag_ll-1-(0)
1 swag_ll-1-(1)
2 swag_ll-1-(2)
3 swag_ll-1-(1)
4 swag_ll-1-(2)
5 swag_ll-1-(0)
6 swag_ll-1-(0)
7 swag_ll-1-(2)
8 swag_ll-1-(1)
9 ll_svgd-1-(2)
10 ll_svgd-1-(1)
11 ll_svgd-1-(0)
12 ll_ivon_5-(0)
13 ll_ivon_5-(1)
14 ll_ivon_5-(2)
15 ll_ivon_5-(0)
16 ll_ivon-1-(5)
17 ll_ivon-1-(4)
18 ll_ivon-1-(1)
19 ll_ivon-1-(2)
20 ll_ivon-1-(3)
21 ll_ivon-1-(0)
22 ll_ivon-1-(0)
23 ll_ivon-1-(0)
24 ll_ivon-1-(0)
25 ll_ivon-1-(0)
26 ll_svgd-1-(2)
27 ll_svgd-1-(1)
28 ll_svgd-1-(0)
29 laplace-1-(5)
30 laplace_5-4
31 laplace-1-(4)
32 laplace_5-3
33 laplace-1-(3)
34 laplace_5-2
35 laplace-1-(2)
36 laplace_5-1
37 laplace-1-(1)
38 laplace_5-0
39 laplace-1-(0)
40 laplace_5-5
41 bbb_5-(1)
42 bbb_5-(0)
43 bbb_5-(2)
44 bbb-1-(2)
45 bbb-1-(1)
46 bbb-1-(0)
47 bbb_full-1-(0)
48 ivon-1-(0)
49 laplace-1-(0)
50 laplace-1-(2)
51 laplace-1-(1)
52 ivon-1-(0)
53 swag_5-(0)
54 swag_5-(2)
55 swag_5-(1)
56 laplace-1-(0)
57 mcd_5-(1)
58 mcd_5-(2)
59 mcd_5-(0)
60 swag-1-(2)
61 swag-1-(0)
62 swag-1-(1)
63 mcd-1-(

In [4]:
def create_reliability_plot(run, results_name):
    """Print a TikZ/pgfplots reliability diagram for one run's calibration bins.

    First prints the run's sECE and ECE, then a ``tikzpicture`` containing:
    the diagonal of perfect calibration, a line through each bin's
    (confidence, accuracy) point, and for every bin a dotted guide to the top
    edge labelled with the bin's sample count.

    Args:
        run: wandb run whose ``summary[results_name]`` holds ``sece``, ``ece``,
            ``bin_confidences``, ``bin_accuracies`` and ``bin_counts``.
        results_name: Summary key of the results dict, e.g. ``"test_results"``.

    NOTE(review): assumes the LaTeX preamble defines the ``calstyle`` and
    ``calline`` pgfplots styles used below.
    """
    results = run.summary[results_name]
    print(f"sECE: {results['sece']:.4f}, ECE: {results['ece']:.4f}")

    # Drop bins whose mean confidence is not positive (presumably empty bins,
    # which report a confidence of 0).
    bins = [
        (conf, acc, count)
        for conf, acc, count in zip(results["bin_confidences"], results["bin_accuracies"], results["bin_counts"])
        if conf > 0
    ]
    curve = " ".join(f"({conf}, {acc})" for conf, acc, _ in bins)

    lines = [
        "\\begin{tikzpicture}",
        "    \\begin{axis}[calstyle, xmin=0, xmax=1, ymin=0, ymax=1]",
        "        \\addplot[dashed, color=black] coordinates {(0,0) (1,1)};",
        "        \\addplot[calline] coordinates {" + curve + "};",
    ]
    for conf, acc, count in bins:
        lines.append(f"        \\node[above, anchor=south west, rotate=60, font=\\tiny] at (axis cs:{conf}, 1.0) {{{count}}};")
        lines.append(f"        \\draw[dotted, color=black] (axis cs:{conf}, {acc}) -- (axis cs:{conf}, 1.0);")
    lines.append("    \\end{axis}")
    lines.append("\\end{tikzpicture}")
    print("\n".join(lines))

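# Usage examples. The indices depend on the order of `runs` and may be stale:
# in the run listing printed above, runs[10] is ll_svgd-1-(1), not a MultiLaplace
# run. Prefer selecting the run by name before plotting.
#create_reliability_plot(runs[76], "test_results") # MAP (overconfident)
#create_reliability_plot(runs[10], "test_results") # MultiLaplace (underconfident)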


In [5]:
import plotly.express as px
import pandas as pd
import dateutil
import datetime

def create_plot_data_for_run(run):
    """Flatten one wandb run into a dict of metrics (one DataFrame row).

    The ``model`` label is the run name with its trailing ``-``-separated
    suffix removed (presumably the seed/replica index):
    names with three or more parts keep the first two (``"bbb-1-(0)"`` ->
    ``"bbb-1"``), otherwise only the first part is kept (``"bbb_5-(0)"`` ->
    ``"bbb_5"``).

    Metrics are read from ``run.summary["test_results"]`` (unprefixed columns)
    and ``run.summary["id_val_results"]`` (columns prefixed ``"id_val "``).
    """
    name_parts = run.name.split("-")
    keep = 2 if len(name_parts) > 2 else 1
    row = {"model": "-".join(name_parts[:keep])}

    # (summary key, column-name prefix) per evaluation split.
    splits = (("test_results", ""), ("id_val_results", "id_val "))
    # (column name, key inside the results dict), in output column order.
    fields = (
        ("accuracy", "accuracy"),
        ("macro f1", "macro_f1"),
        ("log likelihood", "log_likelihood"),
        ("ece", "ece"),
        ("sece", "sece"),
    )
    for summary_key, prefix in splits:
        results = run.summary[summary_key]
        for column, key in fields:
            row[prefix + column] = results[key]
    return row

def plot(data, value):
    """Return a plotly box plot of column ``value`` grouped (and coloured) by model."""
    return px.box(data, x="model", y=value, color="model")

def pareto_plot(data, x, y):
    """Return a plotly scatter of metric ``y`` against metric ``x``, coloured by model.

    Error bars come from the ``<x>_std`` and ``<y>_std`` columns, so ``data``
    is expected to be the output of ``aggregate_data``.
    """
    return px.scatter(
        data,
        x=x,
        y=y,
        error_x=f"{x}_std",
        error_y=f"{y}_std",
        color="model",
    )

def build_data(runs, created_after=datetime.datetime(2023, 3, 10, 10, 0)):
    """Collect one metrics row per usable wandb run into a DataFrame.

    A run is skipped when any of the following holds, checked in this order:
    it was created before ``created_after``; its state is not ``"finished"``;
    it is tagged ``"old"``; it has no ``test_results`` in its summary (treated
    as crashed). Only the last two cases print a message.

    Args:
        runs: Iterable of wandb runs, e.g. ``wandb.Api().runs(...)``.
        created_after: Runs whose parsed ``created_at`` is strictly earlier than
            this are ignored. Defaults to 2023-03-10 10:00.

    Returns:
        DataFrame with one row per kept run, as built by
        ``create_plot_data_for_run`` (empty if no run is kept).
    """
    # Import the submodule explicitly: a bare ``import dateutil`` does not
    # guarantee ``dateutil.parser`` is loaded (python-dateutil < 2.9 only exposes
    # it once another module, e.g. pandas, has imported it).
    from dateutil import parser as date_parser

    rows = []
    for run in runs:
        # NOTE(review): naive comparison; a timezone-aware created_at would raise
        # TypeError here -- confirm wandb's timestamp format.
        if date_parser.parse(run.created_at) < created_after:
            continue
        if run.state != "finished":
            continue
        if "old" in run.tags:
            print("Skipping old run " + run.name)
            continue
        if "test_results" not in run.summary:
            print("Skipping crashed run " + run.name)
            continue
        rows.append(create_plot_data_for_run(run))
    return pd.DataFrame(rows)

def aggregate_data(data, metrics=None, sem_scale=2):
    """Aggregate per-run rows into one row per model: mean and scaled SEM per metric.

    Args:
        data: DataFrame with a ``model`` column and one column per metric, one
            row per run (as returned by ``build_data``).
        metrics: Metric columns to aggregate, in output order. Defaults to the
            test and id_val accuracy / macro F1 / log likelihood / sECE / ECE.
        sem_scale: Multiplier applied to the standard error of the mean
            (default 2; presumably meant as an approximate 95% interval
            half-width).

    Returns:
        DataFrame indexed by model, with a ``model`` column followed by, for
        each metric, ``<metric>`` (mean) and ``<metric>_std`` (``sem_scale`` x
        SEM). Despite the suffix, ``_std`` is a scaled standard error rather
        than a standard deviation, and it is NaN for models with a single run.
    """
    if metrics is None:
        metrics = [
            "accuracy",
            "macro f1",
            "log likelihood",
            "sece",
            "ece",
            "id_val accuracy",
            "id_val macro f1",
            "id_val log likelihood",
            "id_val sece",
            "id_val ece",
        ]

    # A single metric list drives both the aggregation spec and the scaling,
    # so the two cannot drift apart.
    spec = {"model": "first"}
    spec.update({metric: ["mean", "sem"] for metric in metrics})
    aggregated_data = data.groupby(["model"]).agg(spec)

    # Flatten (column, statistic) pairs: mean -> "<metric>", sem -> "<metric>_std".
    aggregated_data.columns = [
        name + "_std" if stat == "sem" else name
        for name, stat in aggregated_data.columns.to_flat_index()
    ]
    for metric in metrics:
        aggregated_data[metric + "_std"] *= sem_scale
    return aggregated_data

In [6]:
# Per-run rows -> one aggregated row per model (mean and scaled SEM per metric).
data = build_data(runs).pipe(aggregate_data)

Skipping crashed run swag_ll-1-(1)
Skipping crashed run swag_ll-1-(2)
Skipping crashed run swag_ll-1-(0)
Skipping crashed run swag_ll-1-(0)
Skipping crashed run swag_ll-1-(2)
Skipping crashed run swag_ll-1-(1)
Skipping crashed run ll_ivon_5-(0)
Skipping old run laplace-1-(0)
Skipping old run map-1-(2)
Skipping old run map-1-(1)
Skipping old run map-1-(0)
Skipping old run map-1-(2)
Skipping old run map-1-(1)
Skipping old run map-1-(0)


In [7]:
# One row per model: metric means plus `<metric>_std` columns (2 x SEM; NaN for
# models with a single run).
data

Unnamed: 0_level_0,model,accuracy,accuracy_std,macro f1,macro f1_std,log likelihood,log likelihood_std,sece,sece_std,ece,...,id_val accuracy,id_val accuracy_std,id_val macro f1,id_val macro f1_std,id_val log likelihood,id_val log likelihood_std,id_val sece,id_val sece_std,id_val ece,id_val ece_std
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
bbb-1,bbb-1,0.718247,0.008699,0.282251,0.01111,-1.543047,0.054406,-0.093181,0.00505,0.097063,...,0.81638,0.004792,0.442135,0.011096,-1.142696,0.029683,-0.074734,0.003189,0.074997,0.003103
bbb_5,bbb_5,0.748101,0.00215,0.312464,0.007549,-1.164416,0.011009,-0.011962,0.002286,0.014952,...,0.748911,0.002241,0.316098,0.006113,-1.164969,0.009095,-0.011022,0.00282,0.014622,0.003256
bbb_full-1,bbb_full-1,0.478652,,0.036373,,-1.84225,,0.010132,,0.019468,...,0.569593,,0.062815,,-1.775371,,0.033727,,0.078839,
ivon-1,ivon-1,0.471852,,0.030404,,-1.943763,,0.004276,,0.088849,...,0.568636,,0.042897,,-1.831677,,0.04519,,0.085354,
laplace-1,laplace-1,0.694311,0.014671,0.270444,0.010353,-1.566674,0.083208,-0.051817,0.016714,0.052988,...,0.809697,0.004706,0.45615,0.016663,-1.044671,0.05847,-0.02605,0.009799,0.027941,0.0092
laplace_5,laplace_5,0.738644,0.004205,0.304144,0.007366,-1.197276,0.012082,0.046295,0.004604,0.046403,...,0.835863,0.001233,0.48911,0.011984,-0.839482,0.012383,0.026704,0.003108,0.027474,0.002715
ll_ivon-1,ll_ivon-1,0.724716,0.009575,0.264779,0.008799,-1.331397,0.048783,-0.083594,0.014255,0.088223,...,0.81246,0.004619,0.447176,0.014917,-1.001551,0.028363,-0.075534,0.008068,0.076179,0.008131
ll_ivon_5,ll_ivon_5,0.76284,0.003274,0.299456,0.006257,-1.036238,0.005838,0.010722,0.003309,0.019399,...,0.762544,0.002766,0.293952,0.00388,-1.035152,0.005248,0.010329,0.003291,0.018889,0.002852
ll_svgd-1,ll_svgd-1,0.736962,0.01416,0.265332,0.017513,-1.447085,0.045127,-0.11747,0.002627,0.117699,...,0.822076,0.012116,0.452774,0.01811,-1.135232,0.234305,-0.094283,0.009489,0.094324,0.009527
map-1,map-1,0.707536,0.016173,0.280066,0.020495,-1.513872,0.093951,-0.140396,0.014726,0.140465,...,0.81328,0.007338,0.460402,0.017044,-1.121351,0.086854,-0.10401,0.007057,0.104014,0.007057


In [13]:
# Rows of the LaTeX tables printed by create_table, in display order, as
# (model key as it appears in data["model"], display name) pairs.
# Keys ending in "-1" vs "_5" presumably distinguish single models from
# 5-member ensembles ("Multi*" / "Deep Ensemble") -- confirm against run configs.
algo_names = [
    ("map-1", "MAP"),
    ("map_5", "Deep Ensemble"),
    ("mcd-1", "MCD"),
    ("mcd_5", "MultiMCD"),
    ("swag-1", "SWAG"),
    ("swag_5", "MultiSWAG"),
    ("swag_ll-1", "LL SWAG"),
    ("laplace-1", "LL Laplace"),
    ("laplace_5", "LL MultiLaplace"),
    ("bbb-1", "LL BBB"),
    ("bbb_5", "LL MultiBBB"),
    ("rank1-1", "Rank-1 VI"),
    ("ll_ivon-1", "LL iVON"),
    ("ll_ivon_5", "LL MultiiVON"),
    ("svgd-1", "SVGD"),
    ("ll_svgd-1", "LL SVGD"),
]

def num(value, std, best=None, ty=None):
    r"""Format ``value \pm std`` (3 decimals) as inline LaTeX math.

    When both ``best`` and ``ty`` are given, the number is wrapped in ``\bm{...}``
    if it reaches the threshold ``best``:

    * ``"max"``  -- ``value >= best``
    * ``"min"``  -- ``value <= best``
    * ``"zero"`` -- ``abs(value) <= best``

    Any other ``ty`` leaves the number unbolded.
    """
    value, std = float(value), float(std)
    body = f"{value:.3f} \\pm {std:.3f}"

    if best is not None and ty is not None:
        reaches_best = {
            "max": lambda: value >= best,
            "min": lambda: value <= best,
            "zero": lambda: abs(value) <= best,
        }.get(ty, lambda: False)
        if reaches_best():
            body = f"\\bm{{{body}}}"
    return f"${body}$"

def col_name(name, align):
    r"""Return a one-column ``\multicolumn`` header cell with alignment ``align``."""
    return "\\multicolumn{1}{%s}{%s}" % (align, name)

def create_table(data, prefix="", algos=None):
    """Print a LaTeX ``tabular`` comparing models on macro F1, accuracy, ECE and sECE.

    A cell is set in bold when it lies within the error bar of the best model
    in its column: ``value >= best - err`` for macro F1 and accuracy,
    ``value <= best + err`` for ECE, and ``|value| <= |best| + err`` for sECE,
    whose ideal value is zero.

    Args:
        data: Aggregated results (one row per model) with a ``model`` column and
            ``<prefix><metric>`` / ``<prefix><metric>_std`` columns, as produced
            by ``aggregate_data``.
        prefix: Column-name prefix selecting the evaluation split, e.g.
            ``"id_val "``; the default ``""`` selects the test-split columns.
        algos: ``(model_key, display_name)`` pairs giving row order and labels.
            Defaults to the module-level ``algo_names``.

    Models listed in ``algos`` but absent from ``data`` are skipped; a LaTeX
    comment line naming each of them is printed before the table.
    """
    if algos is None:
        algos = algo_names

    def metric(row, col):
        # (mean, error bar) of one metric for one model row, as plain floats.
        return float(row[prefix + col]), float(row[prefix + col + "_std"])

    def finite_err(err):
        # Single-run models have a NaN error bar; treat it as 0 so the bolding
        # threshold stays finite instead of silently disabling highlighting.
        return 0.0 if math.isnan(err) else err

    # Resolve every requested model first. A missing model gives an empty
    # selection that cannot be converted to float, so skip it instead.
    rows = []
    for algo, name in algos:
        match = data[data["model"] == algo]
        if match.empty:
            print(f"% skipped {name}: no results for model '{algo}'")
            continue
        rows.append((name, match.iloc[0]))

    print("\\begin{tabular}{l|rrrr}")
    print(f"    {col_name('Model', 'l')} & {col_name('Macro F1 Score', 'c')} & {col_name('Accuracy', 'c')} & {col_name('ECE', 'c')} & {col_name('sECE', 'c')} \\\\")
    print("    \\hline")

    if rows:
        # max()/min() return the first extreme element, so ties resolve to the
        # earliest model in `algos`.
        best_f1, f1_err = max((metric(row, "macro f1") for _, row in rows), key=lambda m: m[0])
        best_acc, acc_err = max((metric(row, "accuracy") for _, row in rows), key=lambda m: m[0])
        best_ece, ece_err = min((metric(row, "ece") for _, row in rows), key=lambda m: m[0])
        best_sece, sece_err = min((metric(row, "sece") for _, row in rows), key=lambda m: abs(m[0]))

        f1_cut = best_f1 - finite_err(f1_err)
        acc_cut = best_acc - finite_err(acc_err)
        ece_cut = best_ece + finite_err(ece_err)
        sece_cut = abs(best_sece) + finite_err(sece_err)

        for name, row in rows:
            cells = (
                num(*metric(row, "macro f1"), f1_cut, "max"),
                num(*metric(row, "accuracy"), acc_cut, "max"),
                num(*metric(row, "ece"), ece_cut, "min"),
                num(*metric(row, "sece"), sece_cut, "zero"),
            )
            print(f"    {name} & {' & '.join(cells)} \\\\")
    print("\\end{tabular}")
create_table(data, prefix="id_val ")

\begin{tabular}{l|rrrr}
    \multicolumn{1}{l}{Model} & \multicolumn{1}{c}{Macro F1 Score} & \multicolumn{1}{c}{Accuracy} & \multicolumn{1}{c}{ECE} & \multicolumn{1}{c}{sECE} \\
    \hline
    MAP & $0.460 \pm 0.017$ & $0.813 \pm 0.007$ & $0.104 \pm 0.007$ & $-0.104 \pm 0.007$ \\
    Deep Ensemble & $0.308 \pm 0.005$ & $0.752 \pm 0.007$ & $0.020 \pm 0.001$ & $-0.015 \pm 0.005$ \\
    MCD & $0.457 \pm 0.010$ & $0.814 \pm 0.002$ & $0.100 \pm 0.011$ & $-0.100 \pm 0.011$ \\
    MultiMCD & $0.311 \pm 0.001$ & $0.762 \pm 0.006$ & $\bm{0.013 \pm 0.003}$ & $\bm{-0.008 \pm 0.006}$ \\
    SWAG & $\bm{0.491 \pm 0.011}$ & $0.832 \pm 0.003$ & $0.087 \pm 0.002$ & $-0.087 \pm 0.002$ \\
    MultiSWAG & $0.333 \pm 0.011$ & $0.761 \pm 0.002$ & $0.033 \pm 0.002$ & $-0.033 \pm 0.002$ \\
    LL SWAG & $0.465 \pm 0.043$ & $0.819 \pm 0.016$ & $0.088 \pm 0.012$ & $-0.088 \pm 0.012$ \\
    LL Laplace & $0.456 \pm 0.017$ & $0.810 \pm 0.005$ & $0.028 \pm 0.009$ & $-0.026 \pm 0.010$ \\
    LL MultiLaplace & $\bm{

In [9]:
# Comma separator and a header row are pandas' to_csv defaults.
data.to_csv()

'model,model,accuracy,accuracy_std,macro f1,macro f1_std,log likelihood,log likelihood_std,sece,sece_std,ece,ece_std,id_val accuracy,id_val accuracy_std,id_val macro f1,id_val macro f1_std,id_val log likelihood,id_val log likelihood_std,id_val sece,id_val sece_std,id_val ece,id_val ece_std\nbbb-1,bbb-1,0.7182468275229136,0.008699303471846553,0.28225127760948293,0.011109714438048106,-1.543046732743581,0.05440643967498269,-0.0931807081048114,0.005049549605790326,0.09706304044997006,0.005918370840583741,0.8163795471191406,0.004791726065836142,0.4421345258928023,0.011096235578060604,-1.142696221669515,0.029682856085293162,-0.0747335595085982,0.0031889590346456833,0.07499668611312546,0.0031034233086544433\nbbb_5,bbb_5,0.7481012543042501,0.002150148547266455,0.31246402346004376,0.007548790591173385,-1.1644161542256672,0.011009497262232335,-0.011961853079257593,0.002285698673582957,0.014951700955406712,0.003203509821140049,0.7489113807678223,0.002240587514982809,0.31609830592401145,0.00611326

In [10]:
pareto_plot(data, x="macro f1", y="sece")

In [11]:
pareto_plot(data, x="id_val macro f1", y="id_val sece")