In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Dataset and Hugging Face cache locations, relative to the notebook's working directory.
# The commented-out assignments are alternative absolute locations kept for reference.
#DATA_PATH = "/mnt/d/Uni/Bachelorarbeit/linux/data"
#MODELS_PATH = "/mnt/d/Uni/Bachelorarbeit/linux/huggingface"
DATA_PATH = "./data/"
MODELS_PATH = "./data/huggingface"

# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# PrintLog comes from the project's src.log_mock module (not shown in this notebook);
# presumably a print-based stand-in for the project logger — TODO confirm.
from src.log_mock import PrintLog
log = PrintLog()

import wandb

# NOTE(review): mode="disabled" presumably keeps this analysis notebook from recording
# a run of its own — confirm against the wandb documentation.
wandb.init(mode="disabled")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [2]:
# wandb API handle and the runs of the "foobar/civil" project; `runs` is the input
# of the build_data(...) call further down.
wapi = wandb.Api()
runs = wapi.runs("foobar/civil")

In [3]:
def create_reliability_plot(results):
    """Print a TikZ/pgfplots reliability diagram for one result entry.

    Args:
        results: Mapping with the numeric entries ``"sece"`` and ``"ece"`` and the
            parallel sequences ``"bin_confidences"``, ``"bin_accuracies"`` and
            ``"bin_counts"``.

    Returns:
        None. A one-line sECE/ECE summary followed by the LaTeX source is written
        to stdout. The emitted picture relies on the pgfplots styles ``calstyle``
        and ``calline``, which must be defined by the including document.
    """
    print(f"sECE: {results['sece']:.4f}, ECE: {results['ece']:.4f}")

    # Only bins with a strictly positive confidence are drawn.
    points = [
        (confidence, accuracy, count)
        for confidence, accuracy, count in zip(
            results["bin_confidences"], results["bin_accuracies"], results["bin_counts"]
        )
        if confidence > 0
    ]
    curve = " ".join(f"({confidence}, {accuracy})" for confidence, accuracy, _ in points)

    lines = [
        "\\begin{tikzpicture}",
        "    \\begin{axis}[calstyle, xmin=0, xmax=1, ymin=0, ymax=1]",
        # Diagonal of perfect calibration.
        "        \\addplot[dashed, color=black] coordinates {(0,0) (1,1)};",
        "        \\addplot[calline] coordinates {" + curve + "};",
    ]
    for confidence, accuracy, count in points:
        # Bin count as a rotated label at the top edge, joined to its point by a dotted line.
        lines.append(f"        \\node[above, anchor=south west, rotate=60, font=\\tiny] at (axis cs:{confidence}, 1.0) {{{count}}};")
        lines.append(f"        \\draw[dotted, color=black] (axis cs:{confidence}, {accuracy}) -- (axis cs:{confidence}, 1.0);")
    lines.append("    \\end{axis}")
    lines.append("\\end{tikzpicture}")

    for line in lines:
        print(line)

In [4]:
import plotly.express as px
import pandas as pd
import dateutil
import datetime
import json

def create_plot_data_for_run(run):
    """Condense one run into a flat row of per-group and overall metrics.

    The entries of ``run.summary["test_results"]`` whose name contains ``"toxic"``
    are scanned for the group with the lowest sECE, the group with the highest
    sECE and the group with the lowest accuracy. As a side effect the run name is
    printed and the reliability plot of the lowest-sECE group is written to stdout
    via ``create_reliability_plot``.

    Args:
        run: Run object exposing ``name``, ``tags`` and ``summary["test_results"]``.
            Every scanned entry must provide ``"sece"`` and ``"accuracy"``; the
            worst-accuracy group and the ``"all"`` entry must also provide ``"ece"``.

    Returns:
        dict with the model label, the three selected group names, their metrics
        and the metrics of the ``"all"`` entry.

    Raises:
        KeyError: if no entry qualifies for one of the three selections, or if an
            expected key is missing from the results.
    """
    print(run.name)

    # Look the nested summary entry up once instead of on every access.
    test_results = run.summary["test_results"]

    # Infinite sentinels instead of the bounds 1 / -1 / 1: with strict comparisons a
    # group whose value sits exactly on such a bound would otherwise never be selected.
    lowest_sece, lowest_sece_group = float("inf"), None
    highest_sece, highest_sece_group = float("-inf"), None
    worst_acc, worst_acc_group = float("inf"), None

    for name, results in test_results.items():
        # NOTE(review): this is a substring test, so names such as "...-non-toxic"
        # qualify as well — confirm that this is the intended group filter.
        if "toxic" not in name:
            continue
        # Strict comparisons keep the first group encountered on ties.
        if results["sece"] < lowest_sece:
            lowest_sece = results["sece"]
            lowest_sece_group = name
        if results["sece"] > highest_sece:
            highest_sece = results["sece"]
            highest_sece_group = name
        if results["accuracy"] < worst_acc:
            worst_acc = results["accuracy"]
            worst_acc_group = name

    if None in (lowest_sece_group, highest_sece_group, worst_acc_group):
        # Fail with an explicit message rather than an opaque KeyError: 'None' below.
        raise KeyError(f"run {run.name!r}: no group with comparable 'sece'/'accuracy' values in test_results")

    create_reliability_plot(test_results[lowest_sece_group])

    # The model label is the first dash-separated part of the run name; runs tagged
    # "drop-rates" additionally carry the third part.
    name_parts = run.name.split("-")
    model_name = name_parts[0]
    if "drop-rates" in run.tags:
        model_label = model_name + "-" + name_parts[2]
    else:
        model_label = model_name

    return {
        "model": model_label,
        "worst_acc": worst_acc_group,
        # Computed above rather than read from test_results["worst group accuracy"].
        "worst_acc accuracy": worst_acc,
        "worst_acc sece": test_results[worst_acc_group]["sece"],
        "worst_acc ece": test_results[worst_acc_group]["ece"],

        "lowest_sece": lowest_sece_group,
        "lowest_sece accuracy": test_results[lowest_sece_group]["accuracy"],
        "lowest_sece sece": lowest_sece,

        "highest_sece": highest_sece_group,
        "highest_sece accuracy": test_results[highest_sece_group]["accuracy"],
        "highest_sece sece": highest_sece,

        "all accuracy": test_results["all"]["accuracy"],
        "all sece": test_results["all"]["sece"],
        "all ece": test_results["all"]["ece"],
    }

def plot(data, value):
    """Return a plotly box plot of column ``value``, one box and colour per ``"model"``."""
    return px.box(data, color="model", x="model", y=value)

def pareto_plot(data, group, ece=False):
    """Scatter accuracy against a calibration metric for one column group.

    Args:
        data: Frame holding ``"model"`` plus ``"<group> accuracy"``, the chosen
            calibration column and their ``"_std"`` companions.
        group: Column prefix, e.g. ``"worst_acc"``.
        ece: When truthy plot the ``ece`` column, otherwise the ``sece`` column.

    Returns:
        The plotly figure, with error bars taken from the ``"_std"`` columns and
        one colour per model.
    """
    calibration_metric = "ece" if ece else "sece"
    accuracy_column = f"{group} accuracy"
    calibration_column = f"{group} {calibration_metric}"
    return px.scatter(
        data,
        x=accuracy_column,
        error_x=f"{accuracy_column}_std",
        y=calibration_column,
        error_y=f"{calibration_column}_std",
        color="model",
    )

def build_data(runs, min_created_at=datetime.datetime(2023, 3, 10, 10, 0)):
    """Build one table row per accepted run.

    A run is accepted when it was created at or after ``min_created_at``, its state
    is ``"finished"`` and it is not tagged ``"old"``. Names of runs rejected because
    of the ``"old"`` tag are printed; for accepted runs the summary keys are printed
    before the row is built with ``create_plot_data_for_run``.

    Args:
        runs: Iterable of run objects exposing ``created_at`` (timestamp string),
            ``state``, ``tags``, ``name`` and ``summary``.
        min_created_at: Naive ``datetime`` cutoff; defaults to the previously
            hard-coded 2023-03-10 10:00. Timezone-aware run timestamps are converted
            to UTC before the comparison, so the cutoff is then read as UTC.

    Returns:
        pandas.DataFrame with one row per accepted run (empty if none qualifies).
    """
    # A bare `import dateutil` is not guaranteed to expose the `parser` submodule (it is
    # only reachable if some other import already loaded it), so import it explicitly.
    from dateutil import parser as date_parser

    rows = []
    for run in runs:
        created_at = date_parser.parse(run.created_at)
        if created_at.tzinfo is not None:
            # Aware and naive datetimes cannot be compared (TypeError); normalise
            # aware timestamps to naive UTC first.
            created_at = created_at.astimezone(datetime.timezone.utc).replace(tzinfo=None)
        if created_at < min_created_at:
            continue
        if run.state != "finished":
            continue
        if "old" in run.tags:
            print(run.name)
            continue
        print(run.summary.keys())
        rows.append(create_plot_data_for_run(run))
    return pd.DataFrame.from_dict(rows)

def aggregate_data(data):
    """Aggregate the per-run rows into one row per model.

    For every metric column the per-model mean and standard error of the mean are
    computed. The result keeps the mean under the original column name and stores
    the error under ``"<column>_std"``. Despite that suffix the stored value is
    twice the standard error of the mean, not a standard deviation.

    Args:
        data: Frame with a ``"model"`` column and the metric columns listed below.

    Returns:
        Frame indexed by model that also carries ``"model"`` as a regular column.
    """
    metric_columns = [
        "worst_acc accuracy",
        "worst_acc sece",
        "worst_acc ece",
        "lowest_sece accuracy",
        "lowest_sece sece",
        "highest_sece accuracy",
        "highest_sece sece",
        "all accuracy",
        "all sece",
        "all ece",
    ]

    # "model" is aggregated too so that it survives as a column next to the index.
    aggregation = {"model": "first"}
    for column in metric_columns:
        aggregation[column] = ["mean", "sem"]
    aggregated_data = data.groupby(["model"]).agg(aggregation)

    # Flatten the (column, statistic) pairs: means keep the name, errors get "_std".
    flat_names = []
    for column, statistic in aggregated_data.columns.to_flat_index():
        flat_names.append(column + "_std" if statistic == "sem" else column)
    aggregated_data.columns = flat_names

    # Widen every error column to two standard errors.
    for column in metric_columns:
        aggregated_data[column + "_std"] *= 2.0
    return aggregated_data

In [5]:
# Build one row per accepted run, then reduce to per-model means and error columns.
data = aggregate_data(build_data(runs))

dict_keys(['_step', '_wandb', '_runtime', 'train_loss', 'other_religion-toxic', 'black-non-toxic', 'female-non-toxic', 'other_religion-non-toxic', 'all', 'male-toxic', 'white-toxic', 'female-toxic', 'male-non-toxic', 'christian-toxic', 'muslim-non-toxic', 'christian-non-toxic', 'all-toxic', '_timestamp', 'black-toxic', 'test_results', 'all-non-toxic', 'lgbtq-toxic', 'muslim-toxic', 'lgbtq-non-toxic', 'white-non-toxic', 'worst group accuracy'])
ll_svgd-2
sECE: -0.3798, ECE: 0.3798
\begin{tikzpicture}
    \begin{axis}[calstyle, xmin=0, xmax=1, ymin=0, ymax=1]
        \addplot[dashed, color=black] coordinates {(0,0) (1,1)};
        \addplot[calline] coordinates {(0.5515111684799194, 0.44255319237709045) (0.6487820744514465, 0.4377880096435547) (0.7487647533416748, 0.4919354915618897) (0.8525769114494324, 0.481203019618988) (0.9580023288726808, 0.12585033476352692)};
        \node[above, anchor=south west, rotate=60, font=\tiny] at (axis cs:0.5515111684799194, 1.0) {235};
        \draw[dot

In [6]:
# (value in the "model" column, label printed in the LaTeX table) pairs;
# create_table emits one table row per pair, in this order.
algo_names = [
    ("map", "MAP"),
    ("map_4", "Deep Ensemble"),
    ("mcd", "MCD ($p=0.2$)"),
    ("mcd-p0.1", "MCD ($p=0.1$)"),
    ("mcd-p0.05", "MCD ($p=0.05$)"),
    ("mcd-p0.01", "MCD ($p=0.01$)"),
    ("mcd_4", "MultiMCD ($p=0.2$)"),
    ("swag", "SWAG"),
    ("swag_4", "MultiSWAG"),
    ("laplace", "Laplace"),
    ("laplace_4", "MultiLaplace"),
    ("bbb", "BBB"),
    ("bbb_4", "MultiBBB"),
    ("rank1", "Rank-1 VI"),
    ("ll_ivon", "iVON"),
    ("ll_ivon_5", "MultiiVON"),
    ("svgd", "SVGD"),
]

def num(value, std):
    """Format ``value`` and ``std`` as the LaTeX math cell ``$v \\pm s$`` with three decimals."""
    mean_text = format(float(value), ".3f")
    spread_text = format(float(std), ".3f")
    return "$" + mean_text + " \\pm " + spread_text + "$"

def col_name(name, align):
    """Wrap ``name`` in a one-column ``\\multicolumn`` header cell with alignment ``align``."""
    return "\\multicolumn{{1}}{{{}}}{{{}}}".format(align, name)

def create_table(data, prefix):
    """Print a LaTeX table with accuracy, ECE and sECE of one column group.

    One row is printed per ``algo_names`` entry, in list order, from the columns
    ``"<prefix> accuracy"``, ``"<prefix> ece"``, ``"<prefix> sece"`` and their
    ``"_std"`` companions.

    Args:
        data: Aggregated frame with a ``"model"`` column and the columns above.
        prefix: Column group to tabulate, e.g. ``"worst_acc"``.

    Returns:
        None. The table source is written to stdout. A model without a row in
        ``data`` yields a LaTeX comment line instead of a table row; if several
        rows match a model, the first one is used.
    """
    print("\\begin{tabular}{l|rrr}")
    print(f"    {col_name('Model', 'l')} & {col_name('Accuracy', 'c')} & {col_name('ECE', 'c')} & {col_name('sECE', 'c')} \\\\")
    print("    \\hline")
    for algo, name in algo_names:
        matches = data[data["model"] == algo]
        if len(matches) == 0:
            # Previously float() on the empty selection raised an opaque TypeError.
            print(f"    % {name}: no aggregated data for model '{algo}'")
            continue
        # Work on scalars: float() on a one-element Series is deprecated in pandas.
        record = matches.iloc[0]
        cells = [
            num(record[prefix + " " + metric], record[prefix + " " + metric + "_std"])
            for metric in ("accuracy", "ece", "sece")
        ]
        print("    " + " & ".join([name, *cells]) + " \\\\")
    print("\\end{tabular}")
create_table(data, "worst_acc")

\begin{tabular}{l|rrr}
    \multicolumn{1}{l}{Model} & \multicolumn{1}{c}{Accuracy} & \multicolumn{1}{c}{ECE} & \multicolumn{1}{c}{sECE} \\
    \hline
    MAP & $0.420 \pm 0.021$ & $0.353 \pm 0.025$ & $-0.353 \pm 0.025$ \\
    Deep Ensemble & $0.419 \pm 0.008$ & $0.349 \pm 0.010$ & $-0.349 \pm 0.010$ \\
    MCD ($p=0.2$) & $0.326 \pm 0.023$ & $0.417 \pm 0.030$ & $-0.417 \pm 0.030$ \\
    MCD ($p=0.1$) & $0.325 \pm 0.021$ & $0.418 \pm 0.027$ & $-0.418 \pm 0.027$ \\
    MCD ($p=0.05$) & $0.364 \pm 0.018$ & $0.393 \pm 0.024$ & $-0.393 \pm 0.024$ \\
    MCD ($p=0.01$) & $0.396 \pm 0.015$ & $0.374 \pm 0.023$ & $-0.374 \pm 0.023$ \\
    MultiMCD ($p=0.2$) & $0.326 \pm 0.005$ & $0.412 \pm 0.007$ & $-0.412 \pm 0.007$ \\
    SWAG & $0.448 \pm 0.021$ & $0.197 \pm 0.041$ & $-0.184 \pm 0.027$ \\
    MultiSWAG & $0.429 \pm 0.016$ & $0.183 \pm 0.018$ & $-0.183 \pm 0.018$ \\
    Laplace & $0.424 \pm 0.016$ & $0.348 \pm 0.018$ & $-0.347 \pm 0.018$ \\
    MultiLaplace & $0.420 \pm 0.008$ & $0.348 \pm 0

In [7]:
# No path is passed, so the CSV text is returned and shown as the cell result.
data.to_csv(sep=",", header=True)

'model,model,worst_acc accuracy,worst_acc accuracy_std,worst_acc sece,worst_acc sece_std,worst_acc ece,worst_acc ece_std,lowest_sece accuracy,lowest_sece accuracy_std,lowest_sece sece,lowest_sece sece_std,highest_sece accuracy,highest_sece accuracy_std,highest_sece sece,highest_sece sece_std,all accuracy,all accuracy_std,all sece,all sece_std,all ece,all ece_std\nbbb,bbb,0.5368559569120407,0.03200839396511114,-0.36101158601252437,0.032593469208323206,0.3620227043002385,0.031865070970091586,0.5374175786972046,0.031602751415921976,-0.36217909601381704,0.03269167832750197,0.9586798846721649,0.005067274099174752,-0.022064603723480063,0.0037759260563161516,0.917814064025879,0.0016092263536423084,-0.05580003168551693,0.0018544118248516141,0.05580003168551693,0.0018544118248516141\nbbb_0.01,bbb_0.01,0.5417108178138733,0.04470278052827165,-0.3591472928683913,0.04313249069067954,0.3596272585324829,0.04273540173720286,0.5417108178138733,0.04470278052827165,-0.3591472928683913,0.04313249069067954

In [8]:
# Accuracy vs. sECE for the "worst_acc" columns.
pareto_plot(data, "worst_acc")

In [9]:
# Accuracy vs. ECE for the "worst_acc" columns.
pareto_plot(data, "worst_acc", ece=True)

In [10]:
# Accuracy vs. sECE for the "lowest_sece" columns.
pareto_plot(data, "lowest_sece")

In [11]:
# Accuracy vs. sECE for the "highest_sece" columns.
pareto_plot(data, "highest_sece")

In [18]:
# Rebinds the algo_names list defined in an earlier cell, with different display labels.
# Each pair is (value in the "model" column, label printed in the LaTeX table);
# create_table emits one table row per pair, in this order.
algo_names = [
    ("map", "MAP"),
    ("map_4", "Deep Ensemble"),
    ("mcd", "MCD ($p=0.2$)"),
    ("mcd-p0.1", "MCD ($p=0.1$)"),
    ("mcd-p0.05", "MCD ($p=0.05$)"),
    ("mcd-p0.01", "MCD ($p=0.01$)"),
    ("mcd_4", "MultiMCD ($p=0.2$)"),
    ("swag", "SWAG"),
    ("swag_4", "MultiSWAG"),
    # Disabled entry, kept for reference.
    # ("swag_ll-1", "LL SWAG"),
    ("laplace", "LL Laplace"),
    ("laplace_4", "LL MultiLaplace"),
    ("bbb", "LL BBB"),
    ("bbb_4", "LL MultiBBB"),
    ("rank1", "Rank-1 VI"),
    ("ll_ivon", "LL iVON"),
    ("ll_ivon_5", "LL MultiiVON"),
    ("svgd", "SVGD"),
]

def num(value, std, best=None, ty=None):
    """Format ``value ± std`` as a LaTeX math cell, bold when it meets ``best``.

    Args:
        value: Quantity to print; converted with ``float``.
        std: Spread to print after ``\\pm``; converted with ``float``.
        best: Threshold for highlighting, or None to disable highlighting.
        ty: ``"max"`` (bold if ``value >= best``), ``"min"`` (bold if
            ``value <= best``) or ``"zero"`` (bold if ``abs(value) <= best``).
            None or any other value leaves the cell plain.

    Returns:
        ``"$v \\pm s$"`` with three decimals, wrapped in ``\\bm{...}`` when highlighted.
    """
    value = float(value)
    std = float(std)
    body = f"{value:.3f} \\pm {std:.3f}"

    highlight = False
    if best is not None and ty is not None:
        if ty == "max":
            highlight = value >= best
        elif ty == "min":
            highlight = value <= best
        elif ty == "zero":
            highlight = abs(value) <= best

    if highlight:
        body = f"\\bm{{{body}}}"
    return f"${body}$"

def col_name(name, align):
    """Wrap ``name`` in a one-column ``\\multicolumn`` header cell with alignment ``align``."""
    return "\\multicolumn{{1}}{{{}}}{{{}}}".format(align, name)

def create_table(data, prefix):
    """Print a LaTeX table of worst-group and overall metrics, bolding the best cells.

    For each metric the best value among the ``algo_names`` models is determined:
    highest accuracy, lowest ECE, sECE closest to zero. A cell is printed in bold
    when its value lies within the best model's ``"_std"`` margin of that best value.

    Args:
        data: Aggregated frame with a ``"model"`` column plus the
            ``"worst_acc ..."`` and ``"all ..."`` metric columns and their
            ``"_std"`` companions.
        prefix: String prepended to every metric column name (``""`` for none).

    Returns:
        None. The table source is written to stdout. A model without a row in
        ``data`` yields a LaTeX comment line instead of a table row and is ignored
        when determining the best values; if several rows match a model, the
        first one is used.
    """
    # (header label, column name without prefix, highlighting rule understood by num)
    metrics = [
        ("WG Accuracy", "worst_acc accuracy", "max"),
        ("WG ECE", "worst_acc ece", "min"),
        ("WG sECE", "worst_acc sece", "zero"),
        ("Avg Accuracy", "all accuracy", "max"),
        ("Avg ECE", "all ece", "min"),
        ("Avg sECE", "all sece", "zero"),
    ]

    print("\\begin{tabular}{l|rrrrrr}")
    header_cells = [col_name("Model", "l")] + [col_name(label, "c") for label, _, _ in metrics]
    print("    " + " & ".join(header_cells) + " \\\\")
    print("    \\hline")

    # Resolve every model to a single record once. Working on scalars avoids float()
    # on a one-element Series (deprecated in pandas) and lets missing models be skipped
    # instead of failing with an opaque TypeError.
    records = []
    for algo, name in algo_names:
        matches = data[data["model"] == algo]
        records.append((algo, name, matches.iloc[0] if len(matches) > 0 else None))

    thresholds = {}
    for _, column, ty in metrics:
        # Same starting points as before: 0 when maximising, 1000 when minimising.
        best, best_std = (0, 0) if ty == "max" else (1000, 0)
        for _, _, record in records:
            if record is None:
                continue
            value = float(record[prefix + column])
            if ty == "zero":
                # sECE is signed; only its distance from zero matters.
                value = abs(value)
            # Strict comparison: on ties the first model in list order provides the margin.
            improved = value > best if ty == "max" else value < best
            if improved:
                best = value
                best_std = float(record[prefix + column + "_std"])
        # Everything within the best model's margin counts as best.
        thresholds[column] = best - best_std if ty == "max" else best + best_std

    for algo, name, record in records:
        if record is None:
            print(f"    % {name}: no aggregated data for model '{algo}'")
            continue
        cells = [
            num(record[prefix + column], record[prefix + column + "_std"], thresholds[column], ty)
            for _, column, ty in metrics
        ]
        print("    " + " & ".join([name, *cells]) + " \\\\")
    print("\\end{tabular}")
create_table(data, "")

\begin{tabular}{l|rrrrrr}
    \multicolumn{1}{l}{Model} & \multicolumn{1}{c}{WG Accuracy} & \multicolumn{1}{c}{WG ECE} & \multicolumn{1}{c}{WG sECE} & \multicolumn{1}{c}{Avg Accuracy} & \multicolumn{1}{c}{Avg ECE} & \multicolumn{1}{c}{Avg sECE} \\
    \hline
    MAP & $0.420 \pm 0.021$ & $0.353 \pm 0.025$ & $-0.353 \pm 0.025$ & $0.916 \pm 0.001$ & $0.012 \pm 0.003$ & $-0.012 \pm 0.003$ \\
    Deep Ensemble & $0.419 \pm 0.008$ & $0.349 \pm 0.010$ & $-0.349 \pm 0.010$ & $0.916 \pm 0.000$ & $0.010 \pm 0.001$ & $-0.010 \pm 0.001$ \\
    MCD ($p=0.2$) & $0.326 \pm 0.023$ & $0.417 \pm 0.030$ & $-0.417 \pm 0.030$ & $0.918 \pm 0.000$ & $0.007 \pm 0.005$ & $\bm{0.006 \pm 0.006}$ \\
    MCD ($p=0.1$) & $0.325 \pm 0.021$ & $0.418 \pm 0.027$ & $-0.418 \pm 0.027$ & $0.918 \pm 0.000$ & $\bm{0.007 \pm 0.005}$ & $\bm{0.006 \pm 0.005}$ \\
    MCD ($p=0.05$) & $0.364 \pm 0.018$ & $0.393 \pm 0.024$ & $-0.393 \pm 0.024$ & $0.918 \pm 0.000$ & $\bm{0.005 \pm 0.002}$ & $\bm{-0.003 \pm 0.004}$ \\
    MCD ($p=