In [None]:
import os
#%env JAX_PLATFORMS=cpu
%env CUDA_VISIBLE_DEVICES=0
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
# Change to the parent of the directory this notebook was started in; the next
# cell presumably relies on this to import the local `forests` / `balif` modules.
# The start directory is remembered across re-runs so executing this cell twice
# does not keep climbing the directory tree (a bare os.chdir('..') would).
_START_DIR = globals().get("_START_DIR", os.getcwd())
os.chdir(os.path.dirname(os.path.abspath(_START_DIR)))
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

In [None]:
from forests import IsolationForest
from balif import Balif

# Hyperplane-sampling presets shared by the forest variants below. The "B"
# variants differ from their plain counterparts only in how the normal's index
# and value are drawn ("range" / "covariant" instead of "uniform").
_UNIFORM_SAMPLING = dict(
    p_normal_idx="uniform",
    p_normal_value="uniform",
    p_intercept="uniform",
)
_RANGE_COVARIANT_SAMPLING = dict(
    p_normal_idx="range",
    p_normal_value="covariant",
    p_intercept="uniform",
)

# Keyword arguments for `Balif(...)`, keyed by variant name. Single-component
# hyperplanes give axis-parallel splits (IF/BIF); two components give the
# "extended" variants (EIF/BEIF). Insertion order fixes the plotting order.
model_configs: dict = {
    variant: {"hyperplane_components": n_components, **sampling}
    for variant, n_components, sampling in (
        ("IF", 1, _UNIFORM_SAMPLING),
        ("BIF", 1, _RANGE_COVARIANT_SAMPLING),
        ("EIF", 2, _UNIFORM_SAMPLING),
        ("BEIF", 2, _RANGE_COVARIANT_SAMPLING),
    )
}

In [None]:
# Synthetic 2-D dataset: an isotropic standard-normal cloud of anomalies
# overlapping a thin, elongated cloud of inliers (std 2.5 along x, 0.1 along y).
N_ANOMALY, N_INLIER = 64, 1024

# Key names now match the data they generate. The split order is unchanged, so
# the sampled data and `rng_forest` are identical to before. `rng_labels` is not
# used in the visible cells, but keeping four splits keeps the other keys stable.
rng_inlier, rng_anom, rng_forest, rng_labels = jr.split(jr.PRNGKey(0), 4)
data_anomaly = jr.normal(rng_anom, (N_ANOMALY, 2))
# Per-column scale: stretch x by 5 and squash y by 0.2 (on top of the 0.5 factor).
data_inlier = 0.5 * jr.normal(rng_inlier, (N_INLIER, 2)) * jnp.array([5.0, 0.2])

# Anomalies come first, so the leading rows of `data` are the anomalous ones
# and the trailing rows are inliers.
data = jnp.concatenate([data_anomaly, data_inlier], axis=0)
is_anomaly = jnp.concatenate(
    [jnp.ones(len(data_anomaly), dtype=bool), jnp.zeros(len(data_inlier), dtype=bool)]
)


# Scatter the raw data; inliers are drawn first so the (far fewer) anomalies
# end up on top. The view window is fixed to [-5, 5] on both axes.
fig, ax = plt.subplots(figsize=(5, 5), dpi=80)
ax.scatter(data_inlier[:, 0], data_inlier[:, 1], c="grey", s=10, label="inlier")
ax.scatter(
    data_anomaly[:, 0], data_anomaly[:, 1], c="darksalmon", s=10, label="anomaly"
)
ax.legend()
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
ax.grid()
plt.show()

In [None]:
def plot_heatmap(
    model,
    title="Anomaly Score",
    vmin=None,
    vmax=None,
    key=jr.PRNGKey(0),
    resolution=100,
    extent=5,
):
    """Plot ``model.score`` as a filled contour map over a square 2-D grid.

    A new figure is created; the caller is responsible for ``plt.show()``.

    Args:
        model: Fitted model exposing ``score(coords, key=...)``. It presumably
            returns one score per row of ``coords``; the result is reshaped to
            the grid, so it must have ``resolution ** 2`` elements.
        title: Figure title.
        vmin, vmax: Colour-normalisation limits forwarded to ``plt.contourf``.
        key: PRNG key forwarded to ``model.score``.
        resolution: Number of grid points along each axis (default 100).
        extent: Half-width of the grid, which spans ``[-extent, extent]`` on
            both axes (default 5, the same window as the data scatter plot).
    """
    axis = jnp.linspace(-extent, extent, resolution)
    X, Y = jnp.meshgrid(axis, axis)
    # One (x, y) row per grid node, in the same row-major order as X.ravel(),
    # so the scores can be reshaped straight back onto the grid.
    coord = jnp.column_stack([X.ravel(), Y.ravel()])
    scores = model.score(coord, key=key)
    plt.figure(figsize=(6, 5), dpi=80)
    plt.title(title)
    plt.contourf(
        X, Y, scores.reshape(X.shape), levels=16, cmap="cividis", vmin=vmin, vmax=vmax
    )
    plt.colorbar()
    plt.xticks([])
    plt.yticks([])


# Number of (first, last) sample pairs whose labels are registered per model.
N_FEEDBACK = 1

for model_name, model_config in model_configs.items():
    model = Balif(**model_config).fit(data, key=rng_forest)
    plot_heatmap(model, title=model_name)
    plt.show()

    # Label samples from both ends of `data`: the leading rows are anomalies and
    # the trailing rows are inliers. Use -1 - i for the tail, because data[-i]
    # aliases data[0] when i == 0. Each call gets a fresh subkey instead of
    # reusing one key, and the returned model is always kept (as `register`
    # returns the updated model).
    feedback_key = jr.key(0)
    for i in range(N_FEEDBACK):
        for idx in (i, -1 - i):
            feedback_key, subkey = jr.split(feedback_key)
            model = model.register(data[idx], is_anomaly=is_anomaly[idx], key=subkey)

    plot_heatmap(model, title=f"{model_name} (after labelling)")
    plt.show()