In [None]:
%env CUDA_VISIBLE_DEVICES=2
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from balif import Balif

In [None]:
def plot_data(
    data_inlier,
    data_anomaly,
    x_lims: tuple[int, int] = (-5, 5),
    y_lims: tuple[int, int] = (-5, 5),
):
    """Scatter the inlier and anomaly samples on a new 5x5 inch figure.

    Args:
        data_inlier: 2-D array; columns 0 and 1 are used as x and y of the inliers.
        data_anomaly: 2-D array; columns 0 and 1 are used as x and y of the anomalies.
        x_lims: (low, high) limits applied to the x axis.
        y_lims: (low, high) limits applied to the y axis.

    The figure is left open as the current pyplot figure so the caller can
    save or show it.
    """
    plt.figure(figsize=(5, 5), dpi=80)

    # (points, colour, legend label) for every group drawn, in drawing order.
    groups = (
        (data_inlier, "grey", "inlier"),
        (data_anomaly, "darksalmon", "anomaly"),
    )
    for points, colour, label in groups:
        plt.scatter(points[:, 0], points[:, 1], marker="o", c=colour, s=10, label=label)

    plt.legend()
    plt.xlim(*x_lims)
    plt.ylim(*y_lims)
    plt.grid()


def plot_heatmap(
    model,
    queried=None,
    vmin=0,
    vmax=1,
    x_lims: tuple[float, float] = (-5, 5),
    y_lims: tuple[float, float] = (-5, 5),
    plot_interests: bool = False,
    resolution: int = 100,
):
    """Draw the model's anomaly score (and optionally its interest) on a regular 2-D grid.

    Args:
        model: object exposing ``score_samples(coord)`` and, when ``plot_interests``
            is set, ``interest_for(coord)``. Each must return ``resolution ** 2``
            values for the ``(resolution ** 2, 2)`` array of grid coordinates.
        queried: optional 2-D array of points (columns 0 and 1 are x and y) that are
            overlaid as hollow black circles on every panel.
        vmin: lower colour limit passed to ``contourf`` for the anomaly-score panel.
        vmax: upper colour limit passed to ``contourf`` for the anomaly-score panel.
        x_lims: (low, high) extent of the grid along x.
        y_lims: (low, high) extent of the grid along y.
        plot_interests: if true, draw a two-panel figure with the anomaly score on
            the left and the interest on the right; otherwise only the score.
        resolution: number of grid points per axis.

    The figure is left open as the current pyplot figure so the caller can
    save or show it.
    """
    grid_x, grid_y = jnp.meshgrid(
        jnp.linspace(*x_lims, resolution), jnp.linspace(*y_lims, resolution)
    )
    coord = jnp.stack([grid_x.flatten(), grid_y.flatten()]).T
    scores = model.score_samples(coord)

    def _draw_panel(title, values, **contour_kwargs):
        # Fills the current axes: contour map, colorbar, optional queried points, no ticks.
        plt.title(title)
        plt.contourf(
            grid_x,
            grid_y,
            values.reshape(resolution, resolution),
            levels=16,
            cmap="cividis",
            **contour_kwargs,
        )
        plt.colorbar()
        if queried is not None:
            plt.scatter(queried[:, 0], queried[:, 1], facecolors="none", edgecolors="black", s=20)
        plt.xticks([])
        plt.yticks([])

    if plot_interests:
        # The interest is only evaluated when it is actually drawn.
        interests = model.interest_for(coord)
        plt.figure(figsize=(12, 5), dpi=80)
        plt.subplot(1, 2, 2)
        _draw_panel("Interests", interests)
        # Switch to the left panel so the score plot ends up as the current axes.
        plt.subplot(1, 2, 1)
    else:
        plt.figure(figsize=(6, 5), dpi=80)

    _draw_panel("Anomaly Score", scores, vmin=vmin, vmax=vmax)

In [None]:
# One entry per model variant to evaluate. "model_cls" is the class whose `fit` is
# called; the remaining items are forwarded to `fit` as keyword arguments.
# "model_cls" is deliberately kept as the first key of every entry.
model_configs = {
    "BALIF": dict(model_cls=Balif, hyperplane_components=1, path_score=False),
    "EBALIF": dict(model_cls=Balif, hyperplane_components=None, path_score=False),
    # Path-score variants, currently disabled:
    # "BALIF (Path)": dict(model_cls=Balif, hyperplane_components=1, path_score=True),
    # "EBALIF (Path)": dict(model_cls=Balif, hyperplane_components=None, path_score=True),
}

Double Blob

In [None]:
N_labels = 3

# `plt.savefig` does not create missing directories, so make sure the target exists.
blobs_fig_dir = Path("figures/blobs")
blobs_fig_dir.mkdir(parents=True, exist_ok=True)

# The first key seeds the inlier blob and the second the anomaly blob; the names are
# unpacked in that order so each key is named after the data it actually generates.
rng_inlier, rng_anom, rng_forest, rng_labels = jax.random.split(jax.random.PRNGKey(0), 4)
# 64 anomalies centred at (-2, -2) and 256 inliers centred at (2, 2), both with std 0.5.
data_anomaly = 0.5 * jax.random.normal(rng_anom, (64, 2)) - 2
data_inlier = 0.5 * jax.random.normal(rng_inlier, (256, 2)) + 2
data = jnp.concatenate([data_anomaly, data_inlier], axis=0)
# Ground-truth labels in the same order as `data`: anomalies first, then inliers.
is_anomaly = jnp.concatenate(
    [jnp.ones(len(data_anomaly)), jnp.zeros(len(data_inlier))]
).astype(bool)

plot_data(data_inlier, data_anomaly)
plt.savefig(blobs_fig_dir / "blobs_dataset.pdf", bbox_inches="tight")

for model_name, model_config in model_configs.items():
    # Split the entry into the model class and the keyword arguments for `fit`
    # by key, without relying on the position of "model_cls" in the dict.
    fit_kwargs = dict(model_config)
    model_cls = fit_kwargs.pop("model_cls")

    forest = model_cls.fit(rng_forest, data, **fit_kwargs)
    plot_heatmap(forest, plot_interests=False)
    plt.savefig(blobs_fig_dir / f"blobs_{model_name}_0_queries.pdf", bbox_inches="tight")

    # True for every point that has not been labelled yet.
    queriable = jnp.ones(len(data)).astype(bool)
    for i, key in enumerate(jax.random.split(rng_labels, N_labels)):
        # Points are queried uniformly at random. Alternative strategy (disabled):
        # query the still-queriable point with the highest `forest.interest_for(data)`.
        # NOTE(review): each draw uses its own key and ignores `queriable`, so the same
        # index can be drawn more than once; kept as is to keep the figures reproducible.
        queries_idx = jax.random.choice(key, len(data))
        queriable = queriable.at[queries_idx].set(False)
        forest = forest.register(data[queries_idx], is_anomaly[queries_idx])
        plot_heatmap(forest, queried=data[~queriable], plot_interests=True)
        plt.savefig(
            blobs_fig_dir / f"blobs_{model_name}_{i+1}_queries.pdf", bbox_inches="tight"
        )
        plt.show()

Wave

In [None]:
N_labels = 3

# `plt.savefig` does not create missing directories, so make sure the target exists.
wave_fig_dir = Path("figures/wave")
wave_fig_dir.mkdir(parents=True, exist_ok=True)

rng_datax, rng_datay, rng_forest, rng_labels = jax.random.split(jax.random.PRNGKey(0), 4)
# 512 points on a noisy sine wave: x ~ N(0, (pi/2)^2), y = 2*sin(3x) + N(0, 0.3^2).
datax = jax.random.normal(rng_datax, (512,)) * jnp.pi / 2
datay = 2 * jnp.sin(3 * datax) + 0.3 * jax.random.normal(rng_datay, (512,))
wave_points = jnp.stack([datax, datay], axis=1)

# Points with x > pi/2 are the anomalies.
# Alternative band (disabled): (jnp.pi/8 < x) & (x < jnp.pi*3/8)
wave_is_anomaly = jnp.pi / 2 < wave_points[:, 0]
data_anomaly = wave_points[wave_is_anomaly]
data_inlier = wave_points[~wave_is_anomaly]

# Reorder so anomalies come first, and rebuild the labels to match that order.
data = jnp.concatenate([data_anomaly, data_inlier], axis=0)
is_anomaly = jnp.concatenate(
    [jnp.ones(len(data_anomaly)), jnp.zeros(len(data_inlier))]
).astype(bool)

plot_data(data_inlier, data_anomaly)
plt.savefig(wave_fig_dir / "wave_dataset.pdf", bbox_inches="tight")

for model_name, model_config in model_configs.items():
    # Split the entry into the model class and the keyword arguments for `fit`
    # by key, without relying on the position of "model_cls" in the dict.
    fit_kwargs = dict(model_config)
    model_cls = fit_kwargs.pop("model_cls")

    forest = model_cls.fit(rng_forest, data, **fit_kwargs)
    plot_heatmap(forest, plot_interests=False)
    plt.savefig(wave_fig_dir / f"wave_{model_name}_0_queries.pdf", bbox_inches="tight")

    # True for every point that has not been labelled yet.
    queriable = jnp.ones(len(data)).astype(bool)
    for i, key in enumerate(jax.random.split(rng_labels, N_labels)):
        # Points are queried uniformly at random. Alternative strategy (disabled):
        # query the still-queriable point with the highest `forest.interest_for(data)`.
        # NOTE(review): each draw uses its own key and ignores `queriable`, so the same
        # index can be drawn more than once; kept as is to keep the figures reproducible.
        queries_idx = jax.random.choice(key, len(data), ())
        queriable = queriable.at[queries_idx].set(False)
        forest = forest.register(data[queries_idx], is_anomaly[queries_idx])
        plot_heatmap(forest, queried=data[~queriable], plot_interests=True)
        plt.savefig(
            wave_fig_dir / f"wave_{model_name}_{i+1}_queries.pdf", bbox_inches="tight"
        )
        plt.show()