In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F

import wandb

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Project-local logger object. NOTE(review): presumably a print-based stand-in
# for the experiment logger — confirm in src/log_mock.py. `log` is created here
# but is not referenced by any of the cells visible in this part of the notebook.
from src.log_mock import PrintLog
log = PrintLog()

#import experiments.base.multiclass_classification as exp

In [2]:
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd
import math

def wandb_data(name, data, corruption):
    """Flatten one block of run results into a single plotting/CSV record.

    Args:
        name: Model label stored under the ``"model"`` key.
        data: Mapping that must provide ``"accuracy"``, ``"ece"``, ``"sece"``,
            ``"log_likelihood"``, ``"agreement"`` and ``"tv"``.
        corruption: Corruption code stored under the ``"corruption"`` key.

    Returns:
        A new dict with the keys ``model``, ``corruption``, ``accuracy``,
        ``ece``, ``sece``, ``avg_ll``, ``avg_l``, ``agreement`` and ``tv``
        (in that order).

    Raises:
        KeyError: If ``data`` lacks one of the required metrics.
    """
    record = {"model": name, "corruption": corruption}

    # Metrics that keep their name in the output record.
    for metric in ("accuracy", "ece", "sece"):
        record[metric] = data[metric]

    # The log likelihood is renamed on the way out.
    record["avg_ll"] = data["log_likelihood"]
    record["avg_l"] = 0.0  # TODO: placeholder, not read from `data` yet

    for metric in ("agreement", "tv"):  # TODO (carried over from the original)
        record[metric] = data[metric]

    return record

def extend_from_wandb(data, project="foobar/cifar_10"):
    """Append per-(model, corruption) aggregates of finished W&B runs to ``data``.

    Every run of ``project`` that is in state ``"finished"`` and is not tagged
    ``"old"`` contributes four rows (one per evaluation split found in its
    summary). Rows are grouped by model name (the part of the run name before
    the first ``"-"``) and corruption code, and each metric is reduced to its
    mean plus an error column.

    Note that the ``*_std`` columns do not hold a standard deviation: they are
    twice the standard error of the mean across the runs in the group.
    The placeholder ``avg_l`` field produced by ``wandb_data`` is not
    aggregated and therefore does not appear in the output records.

    Args:
        data: List that is extended in place with one dict per
            (model, corruption) group.
        project: W&B project path to query. Defaults to the project this
            notebook was written for.

    Returns:
        None. ``data`` is modified in place; it is left untouched when no run
        qualifies.

    Raises:
        KeyError: If a qualifying run's summary lacks one of the expected
            result entries or metrics.
    """
    # Summary entry of each evaluation split and the integer code it is plotted at.
    result_keys = (
        ("test_results", 0),
        ("c1_results", 1),
        ("c3_results", 2),
        ("c5_results", 3),
    )
    metrics = ["accuracy", "ece", "sece", "avg_ll", "agreement", "tv"]

    wapi = wandb.Api()
    rows = []
    for run in wapi.runs(project):
        if run.state != "finished":
            print("Skipping unfinished run " + run.name)
            continue
        if "old" in run.tags:
            print("Skipping old run " + run.name)
            continue
        model = run.name.split("-")[0]
        for summary_key, corruption in result_keys:
            rows.append(wandb_data(model, run.summary[summary_key], corruption))

    if not rows:
        # An empty frame has no "model"/"corruption" columns, so the groupby
        # below would fail with a KeyError; there is nothing to add anyway.
        print("No finished runs found in " + project)
        return

    # The grouping keys are re-aggregated with "first" so that they survive as
    # regular columns and end up inside every exported record.
    aggregations = {"model": "first", "corruption": "first"}
    for metric in metrics:
        aggregations[metric] = ["mean", "sem"]

    extra_data = pd.DataFrame(rows).groupby(["model", "corruption"]).agg(aggregations)

    # Flatten the (column, statistic) pairs: the "sem" statistic becomes
    # "<column>_std", every other statistic keeps the bare column name.
    extra_data.columns = [
        column + "_std" if statistic == "sem" else column
        for column, statistic in extra_data.columns.to_flat_index()
    ]

    # Widen the error columns to two standard errors.
    for metric in metrics:
        extra_data[metric + "_std"] *= 2.0

    data.extend(extra_data.to_dict("records"))

def plot(data, value):
    """Draw one metric against the corruption code, one line per model.

    Args:
        data: Records/frame handed straight to ``px.line``; needs the columns
            ``corruption``, ``model``, ``value`` and ``value + "_std"``.
        value: Name of the metric column plotted on the y axis.

    Returns:
        The plotly figure, with the x ticks 0-3 relabelled to the split names.
    """
    error_column = value + "_std"
    figure = px.line(data, x="corruption", y=value, color="model", error_y=error_column)

    # Replace the numeric corruption codes by readable split names.
    axis_settings = {
        "tickmode": "array",
        "tickvals": [0, 1, 2, 3],
        "ticktext": ["Standard", "Corrupted 1", "Corrupted 3", "Corrupted 5"],
    }
    figure.update_layout(xaxis=axis_settings)
    return figure

def plot_all(data):
    """Combine the six metric plots into one 3x2 grid of panels.

    Each panel is produced by ``plot`` and its traces are copied into the
    grid. Only the traces of the first panel keep their legend entries, so
    every model is listed once. The metric name is attached to each panel as
    its x-axis title.

    Copying traces does not carry over the source figure's layout, so the
    x-axis tick settings that ``plot`` configured are re-applied to every
    panel explicitly; without that the panels would show the bare corruption
    codes instead of the split names.

    Args:
        data: Records/frame accepted by ``plot``; must contain every metric
            column listed below together with its ``*_std`` companion.

    Returns:
        The assembled plotly figure (1200 x 800 px).
    """
    # (metric column, panel caption, grid row, grid column), in fill order.
    panels = [
        ("accuracy", "Accuracy", 1, 1),
        ("avg_ll", "Log Likelihood", 1, 2),
        ("ece", "ECE", 2, 1),
        ("sece", "sECE", 2, 2),
        ("agreement", "Agreement with HMC", 3, 1),
        ("tv", "Total Variation vs. HMC", 3, 2),
    ]

    fig = make_subplots(rows=3, cols=2, column_widths=[500, 500])

    for index, (metric, caption, row, col) in enumerate(panels):
        panel = plot(data, metric)

        for trace in panel["data"]:
            if index > 0:
                # The first panel already provides the legend entries.
                trace["showlegend"] = False
            fig.add_trace(trace, row=row, col=col)

        # Re-apply the tick configuration of the source figure to this panel.
        source_axis = panel.layout.xaxis
        fig.update_xaxes(
            title_text=caption,
            tickmode=source_axis.tickmode,
            tickvals=source_axis.tickvals,
            ticktext=source_axis.ticktext,
            row=row,
            col=col,
        )

    fig.update_layout(
        margin={"l": 20, "r": 20, "t": 20, "b": 20},
        width=1200,
        height=800,
    )
    return fig

In [None]:
# Start from an empty record list; extend_from_wandb queries W&B and appends
# the aggregated per-(model, corruption) rows to it in place.
data = []

extend_from_wandb(data)

# Last expression of the cell: the combined figure returned by plot_all.
plot_all(data)

In [4]:
# Dump the aggregated rows as CSV text (header first, one line per row).
print("model,corruption,accuracy,accuracy_std,avg_ll,avg_ll_std,ece,ece_std,sece,sece_std,agreement,agreement_std,tv,tv_std")
for row in data:
    # Shorten some model names for the export; "iVORN" rows are left out.
    name = row['model']
    if name == "MC Dropout":
        name = "MCD"
    elif name == "Multi MC Dropout":
        name = "MultiMCD"
    elif name == "iVORN":
        continue
    elif name == "ivon_1":
        name = "iVON"

    if "agreement" in row:
        agreement, agreement_std, tv, tv_std = row["agreement"], row["agreement_std"], row["tv"], row["tv_std"]
    else:
        # Rows without the agreement/tv metrics export zeros for all four
        # fields. Previously only two of them were set, so the *_std values
        # were either undefined (NameError) or leaked from the previous row.
        agreement, agreement_std, tv, tv_std = 0, 0, 0, 0

    values = [
        row['corruption'],
        row['accuracy'], row['accuracy_std'],
        row['avg_ll'], row['avg_ll_std'],
        row['ece'], row['ece_std'],
        row['sece'], row['sece_std'],
        agreement, agreement_std,
        tv, tv_std,
    ]
    # NaN values are exported as 0.0. The substitution is applied to the
    # numeric fields only, so a model name containing "nan" is not rewritten.
    numeric_part = ",".join(str(value) for value in values).replace("nan", "0.0")
    print(f"{name},{numeric_part}")

model,corruption,accuracy,accuracy_std,avg_ll,avg_ll_std,ece,ece_std,sece,sece_std,agreement,agreement_std,tv,tv_std
bbb_5,0,0.9289399862289429,0.0013690870404477128,-0.22758567631244658,0.00171084822988486,0.018102198257297277,0.0018242617315003301,0.01782454021707177,0.0018975781189125255,0.9299999952316285,0.0011610377610865267,0.12154988497495652,0.0005377555209729788
bbb_5,1,0.8776562452316284,0.0077434539009636014,-0.3939965724945068,0.012232234947588828,0.01848919410724193,0.0051909746776861625,0.00635900350753218,0.00933838297422383,0.8800000071525573,0.004650908302370291,0.16986879110336303,0.0038786628268606977
bbb_5,2,0.7857812523841858,0.013758872296249516,-0.7411298751831055,0.0449542216083606,0.032945302804000674,0.009798379767512268,-0.025964745883829892,0.014145996482403252,0.8303125023841857,0.0059988724299937886,0.2039065182209015,0.002549201545521224
bbb_5,3,0.6206250190734863,0.013519012985723815,-1.4106752395629882,0.10516938354553373,0.114474222920835,0.0158915758