In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".1"
os.environ["CUDA_VISIBLE_DEVICES"]="2"
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from balif import Balif
from isolation_forest import ExtendedIsolationForest

In [None]:
def plot_data(
    data_inlier,
    data_anomaly,
    x_lims: tuple[int, int] = (-5, 5),
    y_lims: tuple[int, int] = (-5, 5),
):
    """Scatter-plot a 2-D dataset, colouring inliers grey and anomalies salmon.

    Opens a new 5x5-inch figure at 80 dpi and leaves it as the current figure,
    so the caller can follow up with ``plt.savefig(...)``.

    Args:
        data_inlier: points indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y);
            presumably shape (n, 2) — TODO confirm.
        data_anomaly: same layout as ``data_inlier``.
        x_lims: (min, max) limits of the x axis.
        y_lims: (min, max) limits of the y axis.
    """
    fig, ax = plt.subplots(figsize=(5, 5), dpi=80)
    point_groups = (
        (data_inlier, "grey", "inlier"),
        (data_anomaly, "darksalmon", "anomaly"),
    )
    for points, colour, group_label in point_groups:
        ax.scatter(points[:, 0], points[:, 1], marker="o", c=colour, s=10, label=group_label)
    ax.legend()
    ax.set_xlim(*x_lims)
    ax.set_ylim(*y_lims)
    ax.grid()

def plot_heatmap(
    model,
    queried = None,
    vmin: float = 0.35,
    vmax: float = 0.65,
    x_lims: tuple[int, int] = (-5, 5),
    y_lims: tuple[int, int] = (-5, 5),
    resolution: int = 100,
):
    """Draw a filled-contour map of ``model.score_samples`` over a regular 2-D grid.

    Opens a new 6x5-inch figure at 80 dpi and leaves it as the current figure,
    so the caller can follow up with ``plt.savefig(...)``.

    Args:
        model: any object exposing ``score_samples(points)``. It is called with a
            ``(resolution**2, 2)`` array of grid coordinates and must return one
            score per point.
        queried: optional points (indexed ``[:, 0]`` / ``[:, 1]``), drawn as
            hollow black circles on top of the map.
        vmin, vmax: colour-map normalisation limits. The contour levels themselves
            are picked automatically by matplotlib (``levels=8``).
        x_lims, y_lims: (min, max) extent of the evaluated grid on each axis.
        resolution: number of grid points per axis (default 100, the original
            hard-coded value).
    """
    xs = jnp.linspace(*x_lims, resolution)
    ys = jnp.linspace(*y_lims, resolution)
    X, Y = jnp.meshgrid(xs, ys)
    # One (x, y) row per grid node, in the same order as X.ravel()/Y.ravel().
    coord = jnp.stack([X.ravel(), Y.ravel()], axis=1)
    scores = model.score_samples(coord)

    plt.figure(figsize=(6, 5), dpi=80)
    # Reshape back onto the grid using its actual shape rather than a repeated literal.
    plt.contourf(X, Y, scores.reshape(X.shape), levels=8, cmap="YlOrRd", vmin=vmin, vmax=vmax)
    plt.colorbar()
    if queried is not None:
        plt.scatter(queried[:, 0], queried[:, 1], facecolors="none", edgecolors="black", s=20)
    plt.xticks([])
    plt.yticks([])
    

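## Double blob

Two Gaussian blobs with standard deviation 0.5: 64 anomalies centred at (-2, -2) and 256 inliers centred at (2, 2). For each model variant ("IF", "EIF") we fit a forest. We then query `N_labels` points one at a time, each time choosing the unlabelled point of highest interest. The score heatmap is saved after every query.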

In [None]:
N_labels = 10  # number of active-learning queries per model

# Three independent keys: inliers, anomalies, forest. The names now match how
# each key is used.
rng_inlier, rng_anom, rng_forest = jax.random.split(jax.random.PRNGKey(0), 3)
data_anomaly = 0.5*jax.random.normal(rng_anom, (64, 2)) - 2     # 64 anomalies around (-2, -2)
data_inlier = 0.5*jax.random.normal(rng_inlier, (256, 2)) + 2   # 256 inliers around (2, 2)
# Anomalies first, then inliers; is_anomaly follows the same order.
data = jnp.concatenate([data_anomaly, data_inlier], axis=0)
is_anomaly = jnp.concatenate([
    jnp.ones(len(data_anomaly), dtype=bool),
    jnp.zeros(len(data_inlier), dtype=bool),
])

# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
plot_data(data_inlier, data_anomaly)
plt.savefig("figures/blobs_dataset.pdf", bbox_inches="tight")

In [None]:
# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)

# hyperplane_components is 1 for "IF" and 2 for "EIF"
# (presumably axis-aligned vs. oblique splits — confirm against Balif).
for hyp_comp, model_name in enumerate(["IF", "EIF"], start=1):
    forest = Balif.fit(rng_forest, data, hyperplane_components=hyp_comp, max_samples=64)
    plot_heatmap(forest)
    plt.savefig(f"figures/blobs_{model_name}_0_queries.pdf", bbox_inches="tight")
    # Display and close the figure so the 22 figures in this cell don't pile up
    # (matplotlib warns once more than 20 are open).
    plt.show()

    queriable = jnp.ones(len(data), dtype=bool)
    for i in range(N_labels):
        interests = forest.interest_for(data)
        # Mask labelled points with -inf rather than 0.0. With 0.0, argmax could
        # re-select an already-labelled point whenever no remaining interest is > 0.
        query_idx = jnp.where(queriable, interests, -jnp.inf).argmax()
        forest = forest.register(data[query_idx], is_anomaly[query_idx])
        queriable = queriable.at[query_idx].set(False)
        plot_heatmap(forest, queried=data[~queriable])
        plt.savefig(f"figures/blobs_{model_name}_{i+1}_queries.pdf", bbox_inches="tight")
        plt.show()

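## Wave

512 points along a noisy sine wave: $x \sim \mathcal{N}(0, (\pi/2)^2)$ and $y = 2\sin(3x) + 0.3\,\varepsilon$ with $\varepsilon \sim \mathcal{N}(0, 1)$. Points with $x > \pi/2$ are labelled anomalous. The same active-learning loop as above is then run on this dataset.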

In [None]:
N_labels = 10  # number of active-learning queries per model

rng_datax, rng_datay, rng_forest = jax.random.split(jax.random.PRNGKey(42), 3)
datax = jax.random.normal(rng_datax, (512,))*jnp.pi/2
datay = 2*jnp.sin(3*datax) + 0.3*jax.random.normal(rng_datay, (512,))
data = jnp.stack([datax, datay], axis=1)
# Points on the right tail of the wave (x > pi/2) are the anomalies.
# Alternative band-shaped definition: (jnp.pi/8 < x) & (x < jnp.pi*3/8).
is_anomaly = jnp.pi/2 < data[:, 0]
data_anomaly = data[is_anomaly]
data_inlier = data[~is_anomaly]
# Reorder so all anomalies come first; rebuild is_anomaly to match the new order.
data = jnp.concatenate([data_anomaly, data_inlier], axis=0)
is_anomaly = jnp.concatenate([
    jnp.ones(len(data_anomaly), dtype=bool),
    jnp.zeros(len(data_inlier), dtype=bool),
])

# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
plot_data(data_inlier, data_anomaly)
plt.savefig("figures/wave_dataset.pdf", bbox_inches="tight")


# hyperplane_components is 1 for "IF" and 2 for "EIF"
# (presumably axis-aligned vs. oblique splits — confirm against Balif).
for hyp_comp, model_name in enumerate(["IF", "EIF"], start=1):
    forest = Balif.fit(rng_forest, data, hyperplane_components=hyp_comp, max_samples=64)
    plot_heatmap(forest)
    plt.savefig(f"figures/wave_{model_name}_0_queries.pdf", bbox_inches="tight")
    # Display and close the figure so the 22+ figures in this cell don't pile up
    # (matplotlib warns once more than 20 are open).
    plt.show()

    queriable = jnp.ones(len(data), dtype=bool)
    for i in range(N_labels):
        interests = forest.interest_for(data)
        # Mask labelled points with -inf rather than 0.0. With 0.0, argmax could
        # re-select an already-labelled point whenever no remaining interest is > 0.
        query_idx = jnp.where(queriable, interests, -jnp.inf).argmax()
        forest = forest.register(data[query_idx], is_anomaly[query_idx])
        queriable = queriable.at[query_idx].set(False)
        plot_heatmap(forest, queried=data[~queriable])
        plt.savefig(f"figures/wave_{model_name}_{i+1}_queries.pdf", bbox_inches="tight")
        plt.show()