In [None]:
import os
#%env JAX_PLATFORMS=cpu
%env CUDA_VISIBLE_DEVICES=0
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
os.chdir('..')
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import matplotlib.pyplot as plt

In [None]:
# NOTE(review): this cell held only the bare name `distr`, which is not defined
# anywhere above and raised NameError on a fresh kernel (aborting "Run All").
# Removed; restore with a real definition if this cell was meant to show one.

In [None]:
from balif import BetaDistribution

# Build N random Beta beliefs and collapse them into a single Beta in two ways
# (compared visually in the next cell).
N = 100
key_alpha, key_beta = jr.split(jr.key(0))
# Beta parameters must be strictly positive. Drawing them directly from
# U(-1, 1) gave non-positive (invalid) parameters for most beliefs, so draw the
# log-parameters from U(-1, 1) instead, i.e. alpha, beta in [1/e, e).
# NOTE(review): assumes BetaDistribution takes (alpha, beta) on the natural
# scale, not in log space -- confirm.
alpha = jnp.exp(jr.uniform(key_alpha, (N,), minval=-1, maxval=1))
beta = jnp.exp(jr.uniform(key_beta, (N,), minval=-1, maxval=1))
beliefs = BetaDistribution(alpha, beta)

# Beta with the average belief mean and a variance equal to the average belief
# variance divided by N.
var = BetaDistribution.from_mean_and_var(
    mean=beliefs.mean.mean(),
    var=beliefs.var.mean() / N,
)

# Beta with the average belief mean and `nu` equal to the sum of the beliefs'
# `nu` values.
nu = BetaDistribution.from_mean_and_nu(
    mean=beliefs.mean.mean(),
    nu=beliefs.nu.sum(),
)


In [None]:
# Compare the densities of the two aggregated Beta distributions on [0, 1].
x = jnp.linspace(0, 1, 100)
fig, ax = plt.subplots()
ax.plot(x, jax.vmap(var.pdf)(x), label='var')
ax.plot(x, jax.vmap(nu.pdf)(x), label='nu')
ax.set(xlabel='x', ylabel='density', title='Aggregated Beta beliefs')
ax.legend()
plt.show()

In [None]:
from forests import IsolationForest
from balif import Balif

# Three independent PRNG streams: data generation, model fitting, scoring.
rng_data, rng_fit, rng_score = jr.split(jr.key(42), num=3)
# Toy training set: 1024 standard-normal points in 2-D.
data = jr.normal(rng_data, shape=(1024, 2))
# Unfitted models, both configured with `hyperplane_components=1`.
balif = Balif(hyperplane_components=1)
forest = IsolationForest(hyperplane_components=1)

def heatmap(model, points=None, *, key_fit=None, key_score=None, lim=5.0, resolution=100):
    """Fit `model` on 2-D points and plot them next to the model's score map.

    Left panel: the training points. Right panel: `model.score` evaluated on a
    regular `resolution` x `resolution` grid covering [-lim, lim]^2.

    Args:
        model: unfitted model whose `fit(points, key=...)` returns the fitted
            model and whose fitted `score(coords, key=...)` returns one value
            per row of `coords`.
        points: training points with two columns; defaults to the
            notebook-level `data` (looked up at call time).
        key_fit: PRNG key for fitting; defaults to the notebook-level `rng_fit`.
        key_score: PRNG key for scoring; defaults to the notebook-level
            `rng_score`.
        lim: half-width of the plotted square.
        resolution: number of grid points per axis.
    """
    if points is None:
        points = data
    if key_fit is None:
        key_fit = rng_fit
    if key_score is None:
        key_score = rng_score

    fitted = model.fit(points, key=key_fit)
    axis = jnp.linspace(-lim, lim, resolution)
    X, Y = jnp.meshgrid(axis, axis)
    # Row i of the grid has y = axis[i]; column j has x = axis[j].
    coord = jnp.stack([X.ravel(), Y.ravel()], axis=1)
    scores = fitted.score(coord, key=key_score)

    fig, (ax_data, ax_scores) = plt.subplots(1, 2, figsize=(12, 6))
    ax_data.scatter(points[:, 0], points[:, 1], marker="o", c="grey", s=10)
    ax_data.set(xlim=(-lim, lim), ylim=(-lim, lim), title="training data")
    ax_data.grid()
    # origin="lower" places grid row 0 (y = -lim) at the bottom, so the score
    # map has the same y orientation as the scatter panel. With the default
    # origin="upper" and a (.., lim, -lim) extent, the y-axis was inverted.
    im = ax_scores.imshow(
        scores.reshape(resolution, resolution),
        extent=(-lim, lim, -lim, lim),
        origin="lower",
        cmap="YlOrRd",
    )
    ax_scores.set_title("score")
    fig.colorbar(im, ax=ax_scores)
    plt.show()

# One figure per model: training points (left) and score map (right).
heatmap(model=forest)
heatmap(model=balif)

In [None]:
data_dim = 2
data = jr.normal(rng_data, (10000, data_dim))
forest = forest.fit(data, key=rng_fit)
balif = balif.fit(data, key=rng_fit)
scores = forest.score(data, key=rng_score)
scores = balif.score(data, key=rng_score)

from sklearn.ensemble import IsolationForest
model = IsolationForest(n_estimators=128)
model.fit(data)
print("\nIForest (sklearn)")
print("fit time:")
%timeit model.fit(data)
print("score time:")
%timeit model.score_samples(data)

print("\nIForest (jax)")
print("fit time:")
%timeit forest.fit(data, key=rng_fit).trees.normals.block_until_ready()
print("score time:")
%timeit forest.score(data, key=rng_score).block_until_ready()

print("\nBalif (jax)")
print("fit time:")
%timeit balif.fit(data, key=rng_fit).beliefs.alpha.block_until_ready()
print("score time:")
%timeit balif.score(data, key=rng_score).block_until_ready()
print("update time:")
%timeit balif.register(data[0], key=rng_score, is_anomaly=False).beliefs.alpha.block_until_ready()
print("interest time:")
%timeit balif.interest(data, key=rng_score).block_until_ready()

In [None]:
vectorized_fit = eqx.filter_vmap(lambda key: forest.fit(data, key=key))
vectorized_score = eqx.filter_vmap(lambda forest: forest.score(data, key=rng_score))
forests = vectorized_fit(jr.split(rng_fit, 32))
scores = vectorized_score(forests)
%timeit vectorized_fit(jr.split(rng_fit, 32)).trees.normals.block_until_ready()
%timeit vectorized_score(forests).block_until_ready()

In [None]:
vectorized_fit = eqx.filter_vmap(lambda key: balif.fit(data, key=key))
vectorized_score = eqx.filter_vmap(lambda balif: balif.score(data, key=rng_score))
vectorized_register = eqx.filter_vmap(lambda balif: balif.register(data[0], key=jr.key(0), is_anomaly=False))
vectorized_interest = eqx.filter_vmap(lambda balif: balif.interest(data, key=rng_score))
balifs = vectorized_fit(jr.split(rng_fit, 32))
scores = vectorized_score(balifs)
balifs = vectorized_register(balifs)
%timeit vectorized_fit(jr.split(rng_fit, 32)).beliefs.alpha.block_until_ready()
%timeit vectorized_score(balifs).block_until_ready()
%timeit vectorized_register(balifs).beliefs.alpha.block_until_ready()
%timeit vectorized_interest(balifs).block_until_ready()