In [None]:
%env CUDA_VISIBLE_DEVICES=2
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import os

import jax
import jax.numpy as jnp
import numpy as np
from sklearn.metrics import average_precision_score
from tqdm.auto import tqdm

import odds_datasets
from balif import Balif
from isolation_forest import ExtendedIsolationForest

In [None]:
# Models under comparison, keyed by the label used in result file names and plots.
# Each value is unpacked into get_multi_fit_predict: "model_cls" picks the
# implementation and every remaining key is forwarded to `model_cls.fit`.
# NOTE(review): hyperplane_components=1 presumably gives axis-aligned splits
# (IF-style) and None full-dimensional hyperplanes (EIF-style); path_score is
# presumably Balif's switch between path-length and its default scoring — confirm
# against the balif / isolation_forest implementations.
model_configs = {
    "IF": dict(model_cls=ExtendedIsolationForest, hyperplane_components=1),
    "pIF": dict(model_cls=Balif, hyperplane_components=1, path_score=False),
    "pIF (Path)": dict(model_cls=Balif, hyperplane_components=1, path_score=True),
    "EIF": dict(model_cls=ExtendedIsolationForest, hyperplane_components=None),
    "pEIF": dict(model_cls=Balif, hyperplane_components=None, path_score=False),
    "pEIF (Path)": dict(model_cls=Balif, hyperplane_components=None, path_score=True),
}

# Root PRNG seed, and how many independent fits (one split key each) are run
# for every model/dataset pair.
seed = 42
n_sims = 32

In [None]:
def get_multi_fit_predict(model_cls, **fit_kwargs):
    def fit_pred_fn(rng, data):
        return model_cls.fit(rng, data, **fit_kwargs).score_samples(data)
    return jax.jit(jax.vmap(fit_pred_fn, in_axes=(0, None)))
    

# One PRNG key per simulation. The same n_sims keys are reused for every model
# and every dataset.
rng_keys = jax.random.split(jax.random.PRNGKey(seed), n_sims)

# np.save does not create missing parent directories, so make sure the output
# directory exists before the (long) loop starts writing into it.
results_dir = "results/fit_only"
os.makedirs(results_dir, exist_ok=True)

for label, model_config in tqdm(model_configs.items(), desc="models"):
    # Build the batched fit+score function once per model; jax.jit traces and
    # compiles it again for each new input shape (i.e. per dataset).
    fit_pred_fn = get_multi_fit_predict(**model_config)

    for dataset_name in (pbar := tqdm(odds_datasets.datasets_names)):
        data, labels = odds_datasets.load(dataset_name)
        pbar.set_description(f"{dataset_name}, shape: {data.shape}")
        # Move the whole (n_sims, n_samples) score matrix to host memory once,
        # rather than one device->host transfer per row in the comprehension.
        scores = np.asarray(fit_pred_fn(rng_keys, data))
        # One average-precision value per simulation.
        ap = np.array([average_precision_score(labels, s) for s in scores])
        save_path = os.path.join(results_dir, f"{dataset_name}_{label}.npy")
        # Plain float array: no pickling needed.
        np.save(save_path, ap)

In [None]:
import matplotlib.pyplot as plt

# Violin-plot grid: one panel per dataset, one violin per model. Each violin is
# the distribution of average precision over the saved simulations.
dataset_names = list(odds_datasets.datasets_names)
model_labels = list(model_configs)

# Size the grid from the number of datasets instead of a fixed 6x3 layout,
# keeping the original 5 inches of height per row.
n_cols = 3
n_rows = max(1, -(-len(dataset_names) // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
axes_flat = axes.ravel()

for ax, dataset_name in zip(axes_flat, tqdm(dataset_names, desc="datasets")):
    for j, label in enumerate(model_labels):
        save_path = f"results/fit_only/{dataset_name}_{label}.npy"
        # The files hold plain float arrays; leave allow_pickle at its safe
        # default so no pickled object is ever unpickled here.
        ap = np.load(save_path)
        ax.violinplot([ap], positions=[j], showmeans=True)
    ax.set_xticks(range(len(model_labels)))
    ax.set_xticklabels(model_labels, rotation=30, ha="right")
    ax.set_title(dataset_name)
    ax.set_ylabel("average precision")
    ax.grid()
    ax.set_ylim(0, 1)

# Hide panels left over when the dataset count is not a multiple of n_cols.
for ax in axes_flat[len(dataset_names):]:
    ax.set_visible(False)

# Lay out once, after all panels are drawn.
fig.tight_layout()
