In [None]:
import pandas as pd
import wandb
from tqdm.notebook import tqdm
import pickle
from os.path import exists
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import math
import ast
import scipy as sp
import scipy.stats as sps

from matplotlib.ticker import MaxNLocator
#...

# Default text settings for every figure drawn in this notebook: the "times"
# family at size 14, applied through matplotlib's rc mechanism.
font = dict(family='times', size=14)
matplotlib.rc('font', **font)

In [None]:
class Experiment:
    """Snapshot of one run, taken at construction time.

    The constructor copies ``name``, ``config``, ``summary`` and ``tags`` from
    the given run object, stores the result of calling ``run.history()`` as
    ``history``, and keeps the run object itself as ``run``.
    """

    def __init__(self, run):
        """Copy the fields of ``run`` that the notebook reads later.

        ``run`` must expose ``name``, ``config``, ``summary`` and ``tags``
        attributes and a ``history()`` method (in this notebook it is an item
        yielded by ``wandb.Api().runs(...)``).
        """
        # Plain metadata first ...
        self.name = run.name
        self.config = run.config
        self.summary = run.summary
        # ... then the logged history, fetched once here so later cells can
        # index it without calling the run again.
        self.history = run.history()
        self.tags = run.tags
        # Keep the original object for anything not copied above.
        self.run = run

    def get_id(self):
        """Return the ``(formula, mol_idx)`` pair read from the run config."""
        cfg = self.config
        return cfg['formula'], cfg['mol_idx']

    def get_history(self):
        """Return the running total of the ``additional_steps`` history column."""
        steps = np.array(list(self.history['additional_steps']))
        return np.cumsum(steps)

In [None]:
def fetch(project, entity="bogp"):
    """Download all runs of a W&B project and wrap each one in an Experiment.

    Args:
        project: Name of the W&B project to read.
        entity: W&B entity that owns the project. Defaults to "bogp", the
            value this notebook has always used.

    Returns:
        A list of Experiment objects, one per run that could be wrapped.
        Loading is best effort: a run whose wrapping raises is left out and
        reported after the loop instead of aborting the whole download.
    """
    api = wandb.Api()
    hdata = []
    skipped = []  # (run name, error) for every run that could not be wrapped
    runs = api.runs(entity + "/" + project)
    for run in tqdm(runs):
        try:
            hdata.append(Experiment(run))
        except Exception as err:
            # `Exception` rather than a bare `except`, so that interrupting the
            # kernel (KeyboardInterrupt) still stops a long download.
            skipped.append((getattr(run, "name", "<unnamed>"), err))
    # Report after the loop so the messages do not interleave with the
    # progress bar.
    for run_name, err in skipped:
        print(f"fetch: skipped run {run_name!r}: {err!r}")
    return hdata

In [None]:
# Download every run of the "scale_master" project once; the cells below
# pick their runs out of `raw` by run group.
raw = fetch("scale_master")

# Width

In [None]:
# Index the runs of the "bayes_wide" group by their `num_particles` config
# value. If two runs share a particle count, the later one in `raw` wins.
exps = {}
for exp in raw:
    if exp.run.group != "bayes_wide":
        continue
    print(exp.name)
    exps[exp.config["num_particles"]] = exp
exps

In [None]:
def _misclass_and_total(confusion):
    """Return (off-diagonal count, total count) for a square nested-list matrix."""
    total = sum(sum(row) for row in confusion)
    on_diagonal = sum(row[c] for c, row in enumerate(confusion))
    return total - on_diagonal, total


def calc(exps, p):
    """Compute standard and Multi-SWAG accuracy for the run stored under ``p``.

    Two matrices are read from ``exps[p].history`` and parsed with
    ``ast.literal_eval``: entry 11 of ``orig_dist0`` and the last entry of
    ``max_dist0``. Each is treated as a confusion matrix whose diagonal holds
    the correctly classified counts.
    # NOTE(review): why entry 11 is the right row of `orig_dist0` is not
    # visible in this notebook -- confirm against the training script.

    Args:
        exps: Mapping from particle count to an object with a ``history``
            attribute exposing ``orig_dist0`` and ``max_dist0``.
        p: Key into ``exps``.

    Returns:
        ``(original_accuracy, mswag_accuracy)`` as floats.

    Raises:
        ZeroDivisionError: if either matrix contains no samples.
    """
    history = exps[p].history
    d = ast.literal_eval(history.orig_dist0[11])
    d_swag = ast.literal_eval(history.max_dist0[len(history.max_dist0) - 1])

    # Every row of each matrix is counted (the class count is no longer fixed
    # at 10), and each accuracy is normalised by its OWN matrix total, so the
    # Multi-SWAG figure stays correct even if the two matrices were built from
    # a different number of samples.
    misclass, total = _misclass_and_total(d)
    misclass_swag, total_swag = _misclass_and_total(d_swag)

    print("original misclass", misclass, "mswag misclass", misclass_swag)
    return 1 - (misclass / total), 1 - (misclass_swag / total_swag)

In [None]:
# Accuracy of standard training vs. Multi-SWAG for each particle count of the
# width experiments held in `exps`.
ps = [1, 2, 4, 8, 16, 32]
orig, mswag = [], []
for p in ps:
    m1, m2 = calc(exps, p)
    orig.append(m1)
    mswag.append(m2)

plt.plot(ps, orig, linestyle="--", marker='s', label='Standard')
plt.plot(ps, mswag, linestyle=":", marker='o', label="Multi-Swag")
plt.title("Standard Training vs. Multi-Swag on MNIST")
plt.xlabel("Particles")
plt.ylabel("Accuracy")
plt.legend()

In [None]:
# `num_params` config value of the width run at each particle count in `ps`.
params = [exps[p].config["num_params"] for p in ps]
params

In [None]:
# Display the two accuracy lists filled in by the width loop, as a tuple
# (standard training first, Multi-SWAG second).
orig, mswag

In [None]:
# Width summary table: one row per entry of `ps`, columns in the order listed.
df = pd.DataFrame(
    dict(
        zip(
            ("parameters", "original accuracy", "particles", "mswag accuracy"),
            (params, orig, ps, mswag),
        )
    )
)
df

In [None]:
# Write the current `df` to table_width.tex as LaTeX, without the index column.
df.to_latex(buf="table_width.tex", index=False)

# Depth

In [None]:
# Index the runs of the "bayes_deep" group by their `num_particles` config
# value. If two runs share a particle count, the later one in `raw` wins.
bayes5 = {}
for exp in raw:
    if exp.run.group != "bayes_deep":
        continue
    print(exp.name)
    bayes5[exp.config["num_particles"]] = exp
bayes5

In [None]:
def calc2(exps, p):
    """Compute standard and Multi-SWAG accuracy for the run stored under ``p``.

    This used to be a line-for-line copy of ``calc``. It now forwards to
    ``calc`` so the two entry points cannot drift apart; the name is kept
    because later cells call ``calc2``.

    Args:
        exps: Mapping from particle count to an object with a ``history``
            attribute exposing ``orig_dist0`` and ``max_dist0``.
        p: Key into ``exps``.

    Returns:
        ``(original_accuracy, mswag_accuracy)``, exactly as ``calc`` returns
        them (``calc`` also prints the two misclassification counts).
    """
    return calc(exps, p)

In [None]:
# Depth summary table, built from the "bayes_deep" runs in `bayes5` for the
# particle counts already listed in `ps`.
params = [bayes5[p].config["num_params"] for p in ps]

orig, mswag = [], []
for p in ps:
    m1, m2 = calc2(bayes5, p)
    orig.append(m1)
    mswag.append(m2)

df = pd.DataFrame(
    dict(
        zip(
            ("parameters", "original accuracy", "particles", "mswag accuracy"),
            (params, orig, ps, mswag),
        )
    )
)
df

In [None]:
# Write the current `df` to table_depth.tex as LaTeX, without the index column.
df.to_latex(buf="table_depth.tex", index=False)

# Plot

In [None]:
# Accuracy vs. particle count for the runs in `exps`, with a second x-axis on
# top that shows the parameter count corresponding to each particle count.
ps = [1, 2, 4, 8, 16, 32]
orig = []
mswag = []
for p in ps:
    m1, m2 = calc2(exps, p)
    orig += [m1]
    mswag += [m2]
fig, ax = plt.subplots()
ax.plot(ps, mswag, label="Multi-Swag", marker='o', linestyle=":")
ax.plot(ps, orig, label='Standard', marker='s', linestyle="--")


def _linear_map(x, xp, fp):
    """Piecewise-linear map through the points (xp, fp); `xp` must be increasing.

    Inside [xp[0], xp[-1]] this is `np.interp`. Outside, the first and last
    segments are continued in a straight line instead of being clamped, so the
    map stays invertible over the axis margins beyond the data.
    """
    x = np.asarray(x, dtype=float)
    xp = np.asarray(xp, dtype=float)
    fp = np.asarray(fp, dtype=float)
    inside = np.interp(x, xp, fp)
    if len(xp) < 2:
        # A single point defines no slope to continue with.
        return inside
    below = fp[0] + (x - xp[0]) * (fp[1] - fp[0]) / (xp[1] - xp[0])
    above = fp[-1] + (x - xp[-1]) * (fp[-1] - fp[-2]) / (xp[-1] - xp[-2])
    return np.where(x < xp[0], below, np.where(x > xp[-1], above, inside))


# Reverse lookup: parameter count -> particle count.
params_to_p = {}
for p in ps:
    params_to_p[exps[p].config["num_params"]] = p


# matplotlib applies the secondary-axis functions to arrays of axis
# coordinates, not to the individual values in `ps`, so a dict lookup on `x`
# cannot work; both directions interpolate between the measured runs instead.
def foo(x):
    """Map particle counts to parameter counts (vectorised)."""
    return _linear_map(x, ps, [exps[p].config["num_params"] for p in ps])


def foo_inv(x):
    """Map parameter counts back to particle counts (vectorised).

    This is the inverse of `foo` only if the parameter count grows with the
    particle count -- assumed here, verify against the run configs.
    """
    counts = sorted(params_to_p)
    return _linear_map(x, counts, [params_to_p[c] for c in counts])


secax = ax.secondary_xaxis('top', functions=(foo, foo_inv))
secax.set_xlabel("Parameters")

ax.set_xlabel("Particles")
ax.set_ylabel("Accuracy")
ax.set_title("Standard Training vs. Multi-Swag on MNIST")
ax.legend()