In [None]:
import os
#%env JAX_PLATFORMS=cpu
%env CUDA_VISIBLE_DEVICES=0
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
os.chdir('..')
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from balif import Balif

In [None]:
def plot_data(
    data_inlier,
    data_anomaly,
    x_lims: tuple[int, int] = (-5, 5),
    y_lims: tuple[int, int] = (-5, 5),
):
    """Scatter-plot inlier and anomaly points on a new square figure.

    Args:
        data_inlier: 2-D array-like; columns 0 and 1 are used as x and y.
        data_anomaly: 2-D array-like; columns 0 and 1 are used as x and y.
        x_lims: Limits forwarded to ``plt.xlim``.
        y_lims: Limits forwarded to ``plt.ylim``.

    Returns:
        None. The figure is left as matplotlib's current figure so the
        caller can save or show it.
    """
    plt.figure(figsize=(5, 5), dpi=80)

    # Inliers are drawn first so the anomaly markers end up on top.
    point_groups = (
        (data_inlier, "grey", "inlier"),
        (data_anomaly, "darksalmon", "anomaly"),
    )
    for points, colour, tag in point_groups:
        plt.scatter(points[:, 0], points[:, 1], c=colour, s=10, label=tag)

    plt.legend()
    plt.xlim(*x_lims)
    plt.ylim(*y_lims)
    plt.grid()


def plot_heatmap(
    model,
    queried=None,
    vmin=0,
    vmax=1,
    plot_interests=True,
    key=jr.key(0),
    x_lims=(-5, 5),
    y_lims=(-5, 5),
    resolution=100,
):
    """Draw filled-contour maps of ``model.score`` (and optionally
    ``model.interest``) evaluated on a regular 2-D grid.

    Args:
        model: Object exposing ``score(points, key=...)`` and
            ``interest(points, key=...)``; each result is reshaped to the grid.
        queried: Optional 2-D array-like of points (columns 0 and 1 are x, y)
            drawn as hollow black circles on every panel.
        vmin: Lower colour limit for the anomaly-score panel.
        vmax: Upper colour limit for the anomaly-score panel.
        plot_interests: When true, a two-panel figure is created with the
            interest map on the right; otherwise only the score map is drawn.
        key: JAX random key forwarded unchanged to both model calls.
        x_lims: ``(low, high)`` extent of the grid along x.
        y_lims: ``(low, high)`` extent of the grid along y.
        resolution: Number of grid points along each axis.

    Returns:
        None. The figure is left as matplotlib's current figure.
    """
    X, Y = jnp.meshgrid(
        jnp.linspace(*x_lims, resolution), jnp.linspace(*y_lims, resolution)
    )
    coord = jnp.stack([X.flatten(), Y.flatten()]).T
    scores = model.score(coord, key=key)

    def _overlay_queried():
        # Mark the queried points on whichever axes is current.
        if queried is not None:
            plt.scatter(
                queried[:, 0],
                queried[:, 1],
                facecolors="none",
                edgecolors="black",
                s=20,
            )

    if plot_interests:
        # Only evaluate the interest map when it is actually displayed.
        interests = model.interest(coord, key=key)

        plt.figure(figsize=(12, 5), dpi=80)
        plt.subplot(1, 2, 2)
        plt.title("Interests")
        plt.contourf(X, Y, interests.reshape(X.shape), levels=16, cmap="cividis")
        plt.colorbar()
        _overlay_queried()
        plt.xticks([])
        plt.yticks([])

        # Switch to the left panel last so it stays the current axes.
        plt.subplot(1, 2, 1)
    else:
        plt.figure(figsize=(6, 5), dpi=80)

    plt.title("Anomaly Score")
    plt.contourf(
        X, Y, scores.reshape(X.shape), levels=16, cmap="cividis", vmin=vmin, vmax=vmax
    )
    plt.colorbar()
    _overlay_queried()
    plt.xticks([])
    plt.yticks([])

In [None]:
# Keyword-argument presets, one per model variant, keyed by the short label
# that is also used in the output figure paths. Each entry is unpacked into
# the model constructor. All variants share a "uniform" intercept prior and
# differ only in the number of hyperplane components and the two
# normal-vector settings.
model_configs: dict = {
    label: dict(
        hyperplane_components=n_components,
        p_normal_idx=normal_idx,
        p_normal_value=normal_value,
        p_intercept="uniform",
    )
    for label, n_components, normal_idx, normal_value in (
        # label, hyperplane_components, p_normal_idx, p_normal_value
        ("IF", 1, "uniform", "uniform"),
        ("BIF", 1, "range", "covariant"),
        ("EIF", 2, "uniform", "uniform"),
        ("BEIF", 2, "range", "covariant"),
    )
}

Double Blob

In [None]:
# Number of label queries issued to each model.
N_labels = 3

# NOTE(review): the key names are crossed below -- `rng_inlier` seeds the
# anomaly blob and `rng_anom` seeds the inlier blob. The keys are independent,
# so the data are still valid; the bindings are left untouched so the generated
# samples (and any later use of these names) stay exactly as before.
rng_anom, rng_inlier, rng_forest, rng_labels = jr.split(jr.PRNGKey(0), 4)
# Two Gaussian blobs with standard deviation 0.5: 64 anomalies centred at
# (-2, -2) and 256 inliers centred at (2, 2).
data_anomaly = 0.5 * jr.normal(rng_inlier, (64, 2)) - 2
data_inlier = 0.5 * jr.normal(rng_anom, (256, 2)) + 2
data = jnp.concatenate([data_anomaly, data_inlier], axis=0)
# Anomalies occupy the first rows of `data`, so only that prefix is flagged.
is_anomaly = jnp.zeros(len(data)).at[: len(data_anomaly)].set(True)

plot_data(data_inlier, data_anomaly)
os.makedirs("figures/blobs/", exist_ok=True)
plt.savefig("figures/blobs/blobs_dataset.pdf", bbox_inches="tight")

for model_name, model_config in model_configs.items():
    # Every variant is fitted with the same forest key.
    model = Balif(**model_config)
    model = model.fit(data, key=rng_forest)
    plot_heatmap(model)
    os.makedirs(f"figures/blobs/blobs_{model_name}", exist_ok=True)
    plt.savefig(f"figures/blobs/blobs_{model_name}/0_queries.pdf", bbox_inches="tight")

    # Mask of points that have not been queried yet.
    queriable = jnp.ones(len(data), dtype=bool)
    for i, key in enumerate(jr.split(rng_labels, N_labels)):
        interests = model.interest(data, key=key)
        # Already-queried points are masked with -inf rather than the minimum
        # interest, so a tie with the minimum can never re-select one of them.
        queries_idx = jnp.where(queriable, interests, -jnp.inf).argmax()
        queriable = queriable.at[queries_idx].set(False)
        # NOTE(review): `key` is shared by `interest` and `register` in the same
        # iteration -- confirm this reuse is intended.
        model = model.register(data[queries_idx], is_anomaly=is_anomaly[queries_idx], key=key)
        plot_heatmap(model, queried=data[~queriable])
        plt.savefig(f"figures/blobs/blobs_{model_name}/{i+1}_queries.pdf", bbox_inches="tight")