In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import odds_datasets


def load_sims(dataset, batch_size, interest_method, data_copies, independent_queries):
    """Load and stack the ``avp_test`` arrays of every run in one results directory.

    Runs are read from
    ``results/{data_copies}x{dataset}/{independent|worstcase}/{interest_method}/batch_size_{batch_size}``
    (relative to the current working directory). Every ``*.npz`` file in that
    directory is treated as one run.

    Args:
        dataset: Dataset name used in the directory path.
        batch_size: Query batch size used in the directory path.
        interest_method: Query-strategy name (e.g. ``"bald"``) used in the path.
        data_copies: Value used as the ``{data_copies}x`` prefix of the dataset folder.
        independent_queries: Truthy selects the ``independent`` sub-directory,
            falsy selects ``worstcase``.

    Returns:
        ``np.ndarray`` of shape ``avp_test.shape + (n_runs,)``. The runs are
        stacked along a new last axis, ordered by file name.

    Raises:
        FileNotFoundError: if the directory does not exist or contains no
            ``.npz`` files.
        KeyError: if a file has no ``avp_test`` array.
        ValueError: if the runs' ``avp_test`` arrays differ in shape.
    """
    # Computed outside the f-string: nesting same-type quotes in an f-string
    # is a SyntaxError before Python 3.12 (PEP 701).
    query_mode = "independent" if independent_queries else "worstcase"
    save_dir = (
        f"results/{data_copies}x{dataset}/{query_mode}/"
        f"{interest_method}/batch_size_{batch_size}"
    )
    # Sorted so the run order along the last axis is deterministic;
    # os.listdir returns entries in arbitrary order.
    files = sorted(name for name in os.listdir(save_dir) if name.endswith(".npz"))
    if not files:
        raise FileNotFoundError(f"no .npz result files in {save_dir!r}")
    runs = []
    for name in files:
        # An NpzFile keeps the archive open until closed. Indexing it reads the
        # array into memory, so the data stays valid after the `with` exits.
        with np.load(f"{save_dir}/{name}") as data:
            runs.append(data["avp_test"])
    return np.stack(runs, axis=-1)


def plot_sim(
    dataset,
    batch_size,
    data_copies,
    strategies=("bald", "margin", "anom"),
    xlim=(0, 20),
):
    """Plot mean ``avp_test`` curves with confidence bands for one dataset on the current axes.

    Each strategy gets one colour from matplotlib's ``tab10`` cycle. Each query
    mode gets one line style: solid for ``independent_queries=False``, dotted
    for ``independent_queries=True``. For every (strategy, mode) pair the runs
    returned by ``load_sims`` are averaged over the last axis. The band drawn
    around the mean is ``mean +/- 1.96 * std / sqrt(n_runs)``. ``std`` is
    numpy's default population standard deviation (ddof=0).

    The x axis is the position along the curve scaled to 0-100. The x limits
    default to ``(0, 20)``. What a curve step represents (e.g. fraction of
    queries issued) is not visible here; verify against the experiment code.

    Passing ``dataset="legend"`` draws nothing but legend proxies, one per
    strategy and one per query mode, and then returns.

    Args:
        dataset: Dataset name, or the sentinel ``"legend"``.
        batch_size: Forwarded to ``load_sims``.
        data_copies: Forwarded to ``load_sims``.
        strategies: Strategy names to plot. Only the first 10 get a colour,
            because ``zip`` stops at the end of ``tab10``.
        xlim: ``(left, right)`` passed to ``plt.xlim``.
    """
    if dataset == "legend":
        for strategy, color in zip(strategies, plt.cm.tab10.colors):
            plt.plot([], [], label=f"{strategy}", color=color)
        for independent, linestyle in zip([False, True], ["-", ":"]):
            plt.plot([], [], label=f"independent={independent}", linestyle=linestyle, color="black")
        plt.legend()
        return
    for strategy, color in zip(strategies, plt.cm.tab10.colors):
        for independent, linestyle in zip([False, True], ["-", ":"]):
            try:
                sim = load_sims(dataset, batch_size, strategy, data_copies, independent_queries=independent)
            except (OSError, KeyError, ValueError):
                # Best-effort: skip combinations whose results cannot be
                # loaded (e.g. a missing results directory). Only the load is
                # guarded, so plotting bugs are no longer silently swallowed.
                continue
            n_steps, n_runs = len(sim), sim.shape[-1]
            x = np.arange(n_steps) / n_steps * 100
            value = sim.mean(-1)
            # Normal-approximation 95% CI half-width of the mean across runs.
            half_width = sim.std(-1) * 1.96 / np.sqrt(n_runs)
            plt.plot(x, value, color=color, linestyle=linestyle)
            plt.fill_between(x, value - half_width, value + half_width, alpha=0.1, color=color)
    plt.grid(True)
    plt.xlim(*xlim)

# Build one figure per (data_copies, batch_size) setting. Each figure has one
# panel per dataset in odds_datasets.datasets_names plus a final legend-only
# panel, and is saved as a PDF under results/ before being shown.
DATASET_NAMES = list(odds_datasets.datasets_names)
N_COLS = 3
# Reserve one extra cell for the legend. The grid is sized from the dataset
# count instead of a fixed 6x3, so 18+ datasets neither overwrite the legend
# cell nor overflow the grid (for 17 datasets this is the original 6x3 layout).
N_ROWS = (len(DATASET_NAMES) + 1 + N_COLS - 1) // N_COLS

for data_copies in [3]:
    for batch_size in [1]:
        plt.figure(figsize=(3 * N_COLS, 3 * N_ROWS))
        for panel, dataset in enumerate(DATASET_NAMES, start=1):
            plt.subplot(N_ROWS, N_COLS, panel)
            plt.title(f"{dataset}")
            plot_sim(dataset, batch_size=batch_size, data_copies=data_copies)
        # Legend-only panel in the last grid cell.
        plt.subplot(N_ROWS, N_COLS, N_ROWS * N_COLS)
        plot_sim("legend", None, None)
        plt.axis("off")
        plt.suptitle(f"Copies={data_copies}, Batch={batch_size}")
        plt.tight_layout()
        plt.savefig(f"results/plot_batch_{batch_size}_copies_{data_copies}.pdf")
        plt.show()