In [None]:
%env CUDA_VISIBLE_DEVICES=2
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import os

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import average_precision_score
from tqdm.auto import tqdm

import odds_datasets
from balif import Balif

In [None]:
# Models to benchmark, keyed by the label used in result file names and plot legends.
# Each entry is splatted into `get_multi_run(**config)`: `model_cls` is the class to
# fit and every remaining key is forwarded to `model_cls.fit` as a keyword argument.
model_configs = {
    name: {"model_cls": Balif, "hyperplane_components": components, "path_score": path_score}
    for name, components, path_score in (
        ("BALIF", 1, False),
        ("EBALIF", None, False),
        # Path-score variants, currently disabled:
        # ("BALIF (Path)", 1, True),
        # ("EBALIF (Path)", None, True),
    )
}

seed = 42  # root PRNG seed; split into one key per simulated run
n_sims = 128  # number of independent simulated runs (vmapped PRNG keys) per model/dataset

In [None]:
def get_multi_run(model_cls, **fit_kwargs):
    """Build a jitted, key-vmapped active-learning simulation for ``model_cls``.

    The returned ``run_fn(rng, train_data, train_labels, test_data)`` works as follows:

    1. It fits ``model_cls.fit(rng_init, train_data, **fit_kwargs)``.
    2. It runs ``1 + len(train_data) // 10`` iterations. Each iteration:
       - scores ``test_data`` with ``model.score_samples``;
       - selects the not-yet-queried training point with the highest
         ``model.interest_for`` value;
       - registers that point and its label via ``model.register``.

    Each training point is queried at most once.

    ``rng`` is a batch of PRNG keys and is vmapped over its leading axis. The data
    arguments are shared by every run in the batch.

    Returns:
        A jitted function whose output stacks the per-iteration scores as
        ``(n_keys, iterations, *score_samples_shape)``. Scores at step ``t`` are
        computed before the ``t``-th registration, so they reflect ``t`` queried
        labels.
    """

    def run_fn(rng, train_data, train_labels, test_data):
        def scan_body(carry, _key):
            # `_key` is this step's PRNG key; the model API used here takes no key.
            model, queriable = carry
            # Score before registering, so step t reflects exactly t queried labels.
            scores = model.score_samples(test_data)
            interests = model.interest_for(train_data)
            # Already-queried points get -inf so they can never win the argmax,
            # even when their interest ties with the best queriable one.
            queries_idx = jnp.where(queriable, interests, -jnp.inf).argmax()
            model = model.register(train_data[queries_idx], train_labels[queries_idx])
            # Remove the queried point from the pool so it is not selected again.
            queriable = queriable.at[queries_idx].set(False)
            return (model, queriable), scores

        # Query budget: 10% of the training pool, plus the initial no-label step.
        iterations = 1 + train_data.shape[0] // 10
        rng_init, rng_steps = jax.random.split(rng)
        rng_steps = jax.random.split(rng_steps, iterations)

        model = model_cls.fit(rng_init, train_data, **fit_kwargs)
        queriable = jnp.ones(train_data.shape[0], dtype=bool)
        _, scores = jax.lax.scan(scan_body, (model, queriable), rng_steps)
        return scores

    return jax.jit(jax.vmap(run_fn, in_axes=(0, None, None, None)))


# Run every configured model on every dataset and store the per-run AP curves.
# One PRNG key per simulated run; the jitted runner vmaps over these keys.
rng_keys = jax.random.split(jax.random.PRNGKey(seed), n_sims)
# np.save does not create parent directories.
os.makedirs("results/margin_query", exist_ok=True)

for label, model_config in tqdm(model_configs.items(), desc="models"):
    run_fn = get_multi_run(**model_config)

    for dataset_name in (pbar := tqdm(odds_datasets.datasets_names)):
        data, labels = odds_datasets.load(dataset_name)
        pbar.set_description(f"{dataset_name}, shape: {data.shape}")
        # The query pool and the evaluation set are the same data here.
        # Copy scores to host once: indexing a device array row by row in Python
        # issues one device operation per element.
        scores = np.asarray(run_fn(rng_keys, data, labels, data))
        # ap[r, t] = average precision of run r after t queried labels.
        ap = np.array(
            [[average_precision_score(labels, s) for s in run_scores]
             for run_scores in scores]
        )
        save_path = f"results/margin_query/{dataset_name}_{label}.npy"
        np.save(save_path, ap, allow_pickle=True)

In [None]:
# AP-vs-queries curves: one panel per dataset (mean ± std over the simulated runs),
# plus a final legend-only panel. The grid grows with the number of datasets instead
# of being fixed at 6x3, which only fits 17 datasets plus the legend.
plot_dataset_names = sorted(odds_datasets.datasets_names)
n_cols = 3
n_panels = len(plot_dataset_names) + 1  # +1 for the legend-only panel
n_rows = -(-n_panels // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)
axes_flat = axes.ravel()

for i, dataset_name in enumerate(tqdm(plot_dataset_names, desc="plotting")):
    ax = axes_flat[i]
    for j, label in enumerate(model_configs):
        # One explicit color per model, so its line, band and legend entry always match.
        color = f"C{j}"
        # The saved results are plain float arrays, so pickle loading stays disabled.
        ap = np.load(f"results/margin_query/{dataset_name}_{label}.npy")
        ap_mean, ap_std = ap.mean(axis=0), ap.std(axis=0)
        # x = 1 + number of queries, so the log-scaled axis can show the 0-query step.
        n_queries = 1 + np.arange(len(ap_mean))
        ax.semilogx(n_queries, ap_mean, color=color)
        ax.fill_between(
            n_queries,
            np.maximum(0, ap_mean - ap_std),
            np.minimum(1, ap_mean + ap_std),
            color=color,
            alpha=0.3,
        )
    n_samples = odds_datasets.load(dataset_name)[0].shape[0]
    ax.set_title(dataset_name)
    ax.set_xlabel(f"Queries (10% ={1 + n_samples // 10})")
    if i % n_cols == 0:
        ax.set_ylabel("Avg. Precision")
    ax.set_ylim(-0.05, 1.05)
    ax.set_xlim(1, 1 + n_samples / 10)
    ax.set_xticks([1, 1 + n_samples / 100, 1 + n_samples / 10])
    ax.set_xticklabels(["0%", "1%", "10%"])
    ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    ax.grid(True)

# Legend-only panel right after the last dataset panel.
legend_ax = axes_flat[len(plot_dataset_names)]
for j, label in enumerate(model_configs):
    legend_ax.plot([], [], color=f"C{j}", label=label)
legend_ax.legend(loc="center")
legend_ax.axis("off")
# Hide any leftover empty grid cells.
for spare_ax in axes_flat[n_panels:]:
    spare_ax.axis("off")

# savefig does not create parent directories.
os.makedirs("figures/ap_evolution", exist_ok=True)
fig.tight_layout()
fig.savefig("figures/ap_evolution/ap_all_margin_query.pdf", bbox_inches="tight")
plt.show()