In [None]:
import os
#%env JAX_PLATFORMS=cpu
%env CUDA_VISIBLE_DEVICES=0
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
os.chdir('..')
from tqdm.auto import tqdm
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import numpy as np
import comet_ml
import matplotlib.pyplot as plt
from sklearn.metrics import average_precision_score as avg_precision

import odds_datasets
from balif import Balif

In [None]:
# Grid of Balif configurations to benchmark: every combination of query
# strategy, normal-index prior and number of hyperplane components. All other
# hyper-parameters are held fixed. Iteration order (and therefore list order)
# is query_strategy -> p_normal_idx -> hyperplane_components.
model_configs: list[dict] = [
    {
        "n_estimators": 128,
        "max_samples": 256,
        "bootstrap": True,
        "standardize": False,
        "hyperplane_components": n_components,
        "p_normal_idx": normal_idx_prior,
        "p_normal_value": "uniform",
        "p_intercept": "uniform",
        "prior_sample_size": 0.1,
        "score_reduction": "mean",
        "query_strategy": strategy,
    }
    for strategy in ("margin", "random")
    for normal_idx_prior in ("uniform", "range")
    for n_components in (1, -1)
]

In [None]:
@eqx.filter_jit
def run_fn(model_config, data, labels, key):
    """Simulate one active-learning run of Balif on a single dataset.

    A ``Balif`` model is built from ``model_config`` and fitted on ``data``.
    Then, for ``1 + n_samples // 10`` steps, the model scores every point, the
    not-yet-queried point with the highest interest is selected, and its label
    is fed back to the model through ``model.register``.

    Args:
        model_config: Keyword arguments forwarded to ``Balif(**model_config)``.
        data: 2-D array whose first axis indexes samples.
        labels: Per-sample labels, indexed like ``data``; the queried entry is
            passed to ``model.register`` as ``is_anomaly``.
        key: JAX PRNG key; split once for fitting and once per step.

    Returns:
        The per-step outputs of ``model.score(data)`` stacked by
        ``jax.lax.scan`` along a new leading axis. Entry ``i`` is computed
        before the ``i``-th label is registered, so entry 0 comes from the
        freshly fitted model with no feedback yet.
    """

    def scan_body(carry, key):
        key_score, key_query, key_update = jr.split(key, 3)
        model, queriable = carry

        # Score first, so the recorded scores do not yet include this step's label.
        scores = model.score(data, key=key_score)

        interests = model.interest(data, key=key_query)
        # Mask already-queried points with -inf rather than interests.min():
        # with the minimum as filler, a remaining point whose interest equals
        # that minimum ties with the masked entries and argmax (which returns
        # the first maximum) could pick an already-queried index again.
        masked_interests = jnp.where(queriable, interests, -jnp.inf)
        query_idx = masked_interests.argmax()
        queriable = queriable.at[query_idx].set(False)
        point, is_anomaly = data[query_idx], labels[query_idx]

        model = model.register(point, is_anomaly=is_anomaly, key=key_update)
        return (model, queriable), scores

    # Unpacking also asserts that ``data`` is 2-D; the feature count is unused.
    samples, _ = data.shape
    # Labelling budget: roughly 10% of the dataset, and at least one step.
    iterations = 1 + samples // 10
    rng_fit, rng_steps = jr.split(key)
    model = Balif(**model_config)
    model = model.fit(data, key=rng_fit)

    # Every point starts out available for querying.
    queriable = jnp.ones(samples, dtype=bool)
    _, scores = jax.lax.scan(
        scan_body, (model, queriable), jr.split(rng_steps, iterations)
    )
    return scores

In [None]:
comet_ml.init()

# Load every dataset once up front: the data does not depend on the model
# configuration or the seed, so reloading it inside the nested loops below
# would repeat the same work for every (config, seed) pair.
datasets = {}
for dataset_name in odds_datasets.datasets_names:
    data, labels = odds_datasets.load(dataset_name)
    datasets[dataset_name] = (data, labels)

# One offline Comet experiment per (model configuration, seed); within it, each
# dataset logs its metrics under its own context.
for model_config in tqdm(model_configs, desc="models"):
    for seed in range(32):
        experiment = comet_ml.OfflineExperiment(project_name="balif", offline_directory="comet")
        try:
            experiment.log_parameters({"seed": seed, **model_config})

            for dataset_name in tqdm(odds_datasets.datasets_names, desc="datasets"):
                data, labels = datasets[dataset_name]

                with experiment.context_manager(dataset_name):
                    # Pull the whole score history to host memory in one
                    # transfer instead of slicing the device array per step.
                    sim_results = np.asarray(run_fn(model_config, data, labels, jr.key(seed)))
                    labels_host = np.asarray(labels)

                    for step, scores in enumerate(sim_results):
                        ap = avg_precision(labels_host, scores)
                        experiment.log_metric("average_precision", ap, step=step)
        finally:
            # Always close the experiment, even if a run fails part-way, so the
            # offline archive is finalised and the next experiment starts clean.
            experiment.end()

# Uploading is left to the dedicated `!comet upload` cell that follows, so the
# same archives are not submitted twice on Restart & Run All.

In [None]:
# Shell escape: invoke the `comet upload` CLI on every .zip archive in the
# comet/ directory (the offline_directory the experiments above were given).
!comet upload comet/*.zip