In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from tqdm.auto import tqdm
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import odds_datasets

from balif import Balif
from sklearn.metrics import average_precision_score

In [None]:
def run_sim(rng, train_data, train_labels, test_data, hyperplane_components):
    """Simulate one active-learning run of a Balif model and record test scores.

    A model is fitted on ``train_data`` (without labels). Then, for
    ``train_data.shape[0] // 10 + 1`` steps, ``test_data`` is scored and the
    not-yet-queried training point with the highest ``model.interest_for``
    value has its label revealed via ``model.register``.

    Args:
        rng: PRNG key forwarded to ``Balif.fit``.
        train_data: pool of candidate points, indexed along axis 0.
        train_labels: labels for ``train_data``; one is revealed per step.
        test_data: points passed to ``model.score_samples`` at every step.
        hyperplane_components: forwarded to ``Balif.fit``.

    Returns:
        The per-step outputs of ``model.score_samples(test_data)``, stacked
        along a new leading axis of length ``train_data.shape[0] // 10 + 1``.
        Row ``k`` is computed after ``k`` labels have been registered.
    """
    n_queries = train_data.shape[0] // 10

    def scan_body(carry, _):
        model, queriable = carry

        # Score before registering, so row 0 reflects the model before any label.
        scores = model.score_samples(test_data)

        # Mask already-queried points with -inf rather than 0.0. With a 0.0 fill,
        # a queried point ties with or beats every remaining candidate whose
        # interest is <= 0, and argmax (first index on ties) could pick it again.
        interests = model.interest_for(train_data)
        query_idx = jnp.argmax(jnp.where(queriable, interests, -jnp.inf))
        queriable = queriable.at[query_idx].set(False)

        model = model.register(train_data[query_idx], train_labels[query_idx])
        return (model, queriable), scores

    model = Balif.fit(rng, train_data, hyperplane_components=hyperplane_components)
    queriable = jnp.ones(train_data.shape[0], dtype=bool)

    _, scores = jax.lax.scan(scan_body, (model, queriable), None, length=n_queries + 1)
    return scores

# Batch run_sim over the leading axis of its first argument (one PRNG key per
# simulation) while sharing the four remaining arguments across the batch.
# hyperplane_components is a jit-static argument.
jitted_vectorized_run_sim = jax.jit(
    jax.vmap(run_sim, in_axes=(0,) + (None,) * 4),
    static_argnames="hyperplane_components",
)

In [None]:
seed = 42  # master seed; every simulation key below is derived from it
n_sims = 128  # number of independent simulations per model variant
rng = jax.random.PRNGKey(seed)
# Split the master key into one stream per model variant, then fan each
# stream out into n_sims per-simulation keys.
rng_balif, rng_ebalif = [jax.random.split(key, n_sims) for key in jax.random.split(rng)]

In [None]:
# Run every model variant on every ODDS dataset and save the per-step average
# precision of each simulation to RESULTS_DIR.

# "cover" is so large it causes OOM.
SKIPPED_DATASETS = {"cover"}
RESULTS_DIR = "results"
# jnp.save does not create missing directories.
os.makedirs(RESULTS_DIR, exist_ok=True)

# (file-name suffix, hyperplane_components, per-simulation PRNG keys)
variants = (
    ("balif", 1, rng_balif),
    ("ebalif", None, rng_ebalif),
)

for dataset_name in tqdm(sorted(odds_datasets.datasets_names)):
    if dataset_name in SKIPPED_DATASETS:
        continue
    data, labels = odds_datasets.load(dataset_name)
    print(data.shape)
    # The same data serves as query pool and evaluation set. For a held-out
    # split, odds_datasets.load_as_train_test(dataset_name, test_size=0.5,
    # random_state=seed) was the previously used alternative.
    train_data = test_data = data
    train_labels = test_labels = labels

    for suffix, hyperplane_components, sim_keys in variants:
        scores = jitted_vectorized_run_sim(
            sim_keys, train_data, train_labels, test_data, hyperplane_components
        )
        # Transfer to host once; slicing the device array row by row is slow.
        scores_np = np.asarray(scores)
        # ap[i, k]: average precision of simulation i after k queries.
        ap = jnp.array(
            [[average_precision_score(test_labels, s) for s in run_scores] for run_scores in scores_np]
        )
        jnp.save(f"{RESULTS_DIR}/{dataset_name}_{suffix}_perc.npy", ap, allow_pickle=True)