In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import odds_datasets


def load_sim(dataset, strategy, batch_size, root="results", key="avp_test"):
    """Load every simulation run for one (dataset, strategy, batch size) setting.

    Reads all ``.npz`` files in ``{root}/{dataset}/{strategy}/batch_size_{batch_size}``
    and stacks the ``key`` array of each file along a new last axis, so the
    result has one slice per run along that axis.

    Args:
        dataset: Dataset name (first directory level under ``root``).
        strategy: Acquisition-strategy name (second directory level).
        batch_size: Used to build the ``batch_size_{batch_size}`` directory name.
        root: Top-level results directory. Defaults to ``"results"``.
        key: Name of the array to read from each ``.npz`` file.
            Defaults to ``"avp_test"``.

    Returns:
        np.ndarray: The per-run arrays stacked along the last axis, with runs
        ordered by file name.

    Raises:
        FileNotFoundError: If the results directory does not exist.
        ValueError: If the directory holds no ``.npz`` files, or the runs have
            shapes that cannot be stacked.
    """
    save_dir = f"{root}/{dataset}/{strategy}/batch_size_{batch_size}"
    # os.listdir order is arbitrary; sort so the run axis is reproducible.
    names = sorted(f for f in os.listdir(save_dir) if f.endswith(".npz"))
    if not names:
        raise ValueError(f"no .npz simulation files found in {save_dir!r}")
    runs = []
    for name in names:
        # NpzFile keeps its file handle open until closed; the context manager
        # closes it once the array has been read into memory.
        with np.load(f"{save_dir}/{name}") as data:
            runs.append(data[key])
    return np.stack(runs, axis=-1)


def plot_sim(sim, label, color, linestyle):
    """Plot the mean learning curve of a set of runs with a 95% confidence band.

    Draws on the current pyplot axes. It adds a line for the mean over runs and
    a translucent band of +/- 1.96 standard errors around that mean.

    Args:
        sim: Array whose last axis indexes runs (as returned by ``load_sim``).
            ``len(sim)`` is taken as the number of points on the curve.
            Presumably 2-D ``(n_points, n_runs)``; confirm before using other shapes.
        label: Legend label for the mean line.
        color: Matplotlib color shared by the line and the band.
        linestyle: Matplotlib line style for the mean line.
    """
    n_points = len(sim)
    n_runs = sim.shape[-1]
    # Map point k to a budget percentage in [0, 100) and shift it by 1, so the
    # first point (0%) can still be drawn on a log-scaled x-axis.
    x = (np.arange(n_points) / n_points * 100) + 1.0
    value = sim.mean(-1)
    # Standard error from the *sample* standard deviation (ddof=1). The
    # population std (ddof=0) underestimates the spread when there are few runs.
    # A single run carries no spread information, so its band collapses onto
    # the mean.
    if n_runs > 1:
        std = sim.std(-1, ddof=1)
    else:
        std = np.zeros_like(value)
    half_width = std * 1.96 / np.sqrt(n_runs)
    lb = value - half_width
    ub = value + half_width
    plt.plot(x, value, label=label, color=color, linestyle=linestyle)
    # Clip the band to [0, 1], the valid range of an average-precision score.
    plt.fill_between(x, lb.clip(0, 1), ub.clip(0, 1), alpha=0.1, color=color)

In [None]:
# Compare the worst-case and independent variants of the BALD and margin
# strategies against BatchBALD and random sampling. The script draws one figure
# per batch size, with one subplot per ODDS dataset.

# (results sub-directory, legend label, color, linestyle) for each curve.
STRATEGIES = [
    ("worstcase_bald", "bald worstcase", "tab:blue", "-"),
    ("independent_bald", "bald independent", "tab:blue", ":"),
    ("worstcase_margin", "margin worstcase", "tab:orange", "-"),
    ("independent_margin", "margin independent", "tab:orange", ":"),
    ("batchbald", "batchbald", "tab:green", "-"),
    ("random", "random", "black", ":"),
]
DATASETS = (
    odds_datasets.small_datasets_names
    + odds_datasets.medium_datasets_names
    + odds_datasets.large_datasets_names
)
N_COLS = 3
N_ROWS = -(-len(DATASETS) // N_COLS)  # ceiling division: 6 rows for 18 datasets

for batch_size in [10, 5, 1]:
    plt.figure(figsize=(9, 2 * N_ROWS))
    skipped = []
    # The legend is drawn next to the last subplot. It reuses the handles of
    # whichever subplot shows the most curves, so it is not empty when the
    # last dataset happens to lack results.
    legend_handles, legend_labels = [], []
    for i, dataset in enumerate(DATASETS):
        ax = plt.subplot(N_ROWS, N_COLS, i + 1)
        plt.title(f"{dataset}")
        for strategy, label, color, linestyle in STRATEGIES:
            # Skip only strategies whose results are unavailable: a missing
            # directory raises FileNotFoundError, and no stackable runs raise
            # ValueError. Any other error is a real bug and propagates.
            try:
                sim = load_sim(dataset, strategy, batch_size)
            except (FileNotFoundError, ValueError):
                skipped.append(f"{dataset}/{strategy}")
                continue
            plot_sim(sim, label=label, color=color, linestyle=linestyle)

        handles, labels = ax.get_legend_handles_labels()
        if len(labels) > len(legend_labels):
            legend_handles, legend_labels = handles, labels

        plt.grid(True)
        plt.xscale("log")
        # plot_sim maps the 0-100% budget to x in [1, 101) for the log axis.
        plt.xticks([1, 11, 101], ["0%", "10%", "100%"])
        plt.xlabel("Labelling Budget")
        plt.ylabel("AP")
        if i == len(DATASETS) - 1 and legend_labels:
            # Place the legend outside the grid of subplots.
            plt.legend(
                legend_handles,
                legend_labels,
                loc="upper left",
                bbox_to_anchor=(1.05, 1),
                borderaxespad=0,
                fontsize=12,
            )

    plt.tight_layout()
    os.makedirs("figures", exist_ok=True)
    plt.savefig(f"figures/ablation_variants_{batch_size}.pdf", bbox_inches="tight")
    plt.show()
    if skipped:
        print(
            f"batch_size={batch_size}: no results for {len(skipped)} "
            f"dataset/strategy pairs: {', '.join(skipped)}"
        )

In [None]:
import numpy as np
from scipy.special import digamma
from scipy.stats import beta
import matplotlib.pyplot as plt
from bad import BetaDistr
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


def bald(a, b):
    """Return the log of a BALD-style mutual-information score for Beta(a, b).

    The score is the entropy of a Bernoulli with success probability
    ``BetaDistr(a, b).mu()`` minus the digamma expression below. That expression
    is the expected Bernoulli entropy under w ~ Beta(a, b), provided ``mu()`` is
    the Beta mean a / (a + b). This is assumed from the `bad` module; confirm
    there. The natural log of the difference is returned. Inputs may be scalars
    or arrays, since all operations are element-wise.
    """
    p = BetaDistr(a, b).mu()
    # Entropy of the predictive Bernoulli(p).
    predictive_entropy = -p * np.log(p) - (1 - p) * np.log(1 - p)
    # Expected entropy of Bernoulli(w) over the Beta belief.
    expected_entropy = digamma(a + b + 1) - p * digamma(a + 1) - (1 - p) * digamma(b + 1)
    information_gain = predictive_entropy - expected_entropy
    return np.log(information_gain)


def margin(a, b, r=0.4):
    """Ratio of the Beta(a, b) density at ``r`` to its density at the mode.

    The ratio is computed in log space as exp(-(logpdf(mode) - logpdf(r))) for
    numerical stability. The mode comes from ``BetaDistr(a, b).mode()`` in the
    `bad` module. ``r`` is a fixed reference point in [0, 1]; its meaning is
    assumed to be a decision threshold, so confirm with the acquisition code.
    """
    peak = BetaDistr(a, b).mode()
    log_peak_density = beta.logpdf(peak, a, b)
    log_reference_density = beta.logpdf(r, a, b)
    log_ratio = log_peak_density - log_reference_density
    return np.exp(-log_ratio)


def anom(a, b):
    """Return ``BetaDistr(a, b).mu()`` as the score for a Beta(a, b) belief.

    This is presumably the Beta mean a / (a + b), i.e. the expected anomaly
    probability; confirm against the `bad` module.
    """
    return BetaDistr(a, b).mu()


# Contour maps of the bald / margin / anom scores over the Beta(alpha, beta)
# parameter plane. The red lines are curves of constant alpha + beta
# (alpha = ss * mu, beta = ss * (1 - mu)), and each inset shows the score
# along those three lines as a function of mu.
alphas = 1 + 10 ** np.linspace(-1, 2, 100)
betas = 1 + 10 ** np.linspace(-1, 2, 100)
A, B = np.meshgrid(alphas, betas)

# Reference belief (red x) and the total pseudo-counts alpha + beta of the lines.
alpha0, beta0 = 15, 10
line_totals = [alpha0 + beta0 + extra for extra in (30, 60, 90)]
# Drop mu = 0 and mu = 1. They give alpha = 0 or beta = 0, which are not valid
# Beta parameters and would yield NaN/inf scores in the insets.
mu = np.linspace(0, 1, 100)[1:-1]

plt.figure(figsize=(15, 4))
for i, fn in enumerate([bald, margin, anom]):
    plt.subplot(1, 3, i + 1)
    plt.contourf(A, B, fn(A, B), levels=20, cmap="viridis")
    plt.colorbar()

    plt.plot(alpha0, beta0, "rx")
    for total in line_totals:
        plt.plot(total * mu, total * (1 - mu), "r")

    plt.title(fn.__name__)
    plt.xlabel("$\\alpha$")
    plt.ylabel("$\\beta$")
    plt.xlim(1, A.max())
    plt.ylim(1, B.max())
    plt.gca().set_aspect("equal", adjustable="box")

    axins = inset_axes(plt.gca(), width="40%", height="40%", loc="upper right")
    for total in line_totals:
        axins.plot(mu, fn(total * mu, total * (1 - mu)))
    axins.set_xticks([])
    axins.set_yticks([])

plt.tight_layout()
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/bald_margin_anom.pdf")
plt.show()