In [None]:
import os

# Select the GPU before importing JAX so the setting is in place when the CUDA
# backend initialises (the CUDA runtime reads CUDA_VISIBLE_DEVICES only once).
# setdefault keeps this notebook's preferred device ("2") but lets a value
# exported in the launching environment take precedence instead of being
# silently overridden.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [None]:
from isolation_forest import ExtendedIsolationForest

# Fit a small forest on 2-D standard-normal data and show its score landscape
# next to the training points.
GRID_LIMIT = 5.0  # half-width of the square plotting window
GRID_SIZE = 100   # grid points per axis

rng_data, rng_forest = jax.random.split(jax.random.PRNGKey(42))
data = jax.random.normal(rng_data, (1024, 2))
forest = ExtendedIsolationForest.fit(rng_forest, data, hyperplane_components=1)

# meshgrid's default "xy" indexing: row i holds the i-th y value, column j the
# j-th x value, so reshaping the flat scores to (GRID_SIZE, GRID_SIZE) below
# restores that same row/column layout.
grid_axis = jnp.linspace(-GRID_LIMIT, GRID_LIMIT, GRID_SIZE)
X, Y = jnp.meshgrid(grid_axis, grid_axis)
coord = jnp.column_stack([X.ravel(), Y.ravel()])
scores = forest.score_samples(coord)

fig, (ax_data, ax_score) = plt.subplots(1, 2, figsize=(12, 6))

ax_data.scatter(data[:, 0], data[:, 1], marker="o", c="grey", s=10)
ax_data.set(
    xlim=(-GRID_LIMIT, GRID_LIMIT),
    ylim=(-GRID_LIMIT, GRID_LIMIT),
    title="Training data",
    xlabel="x",
    ylabel="y",
)
ax_data.grid()

# origin="lower" places row 0 (y = -GRID_LIMIT) at the bottom so the y-axis
# points up, matching the scatter panel. The previous default origin="upper"
# with an inverted extent drew the heatmap vertically mirrored.
image = ax_score.imshow(
    scores.reshape(GRID_SIZE, GRID_SIZE),
    extent=(-GRID_LIMIT, GRID_LIMIT, -GRID_LIMIT, GRID_LIMIT),
    origin="lower",
    cmap="YlOrRd",
)
ax_score.set(title="score_samples on grid", xlabel="x", ylabel="y")
fig.colorbar(image, ax=ax_score)
plt.show()

In [None]:
# Benchmark dataset: 10k standard-normal points in `data_dim` dimensions.
# Uses the same seed and split as the demo cell, so `rng_forest` is the same key.
data_dim = 2
n_samples = 10_000
rng_data, rng_forest = jax.random.split(jax.random.PRNGKey(42), num=2)
data = jax.random.normal(rng_data, shape=(n_samples, data_dim))

In [None]:
forest = ExtendedIsolationForest.fit(rng_forest, data, hyperplane_components = 1)
scores = forest.score_samples(data)
%timeit ExtendedIsolationForest.fit(rng_forest, data, hyperplane_components = 1).trees.normals.block_until_ready()
%timeit forest.score_samples(data).block_until_ready()

In [None]:
vectorized_fit = jax.vmap(ExtendedIsolationForest.fit, in_axes=(0, None))
vectorized_score = jax.vmap(ExtendedIsolationForest.score_samples, in_axes=(0, None))
rng = jax.random.split(jax.random.PRNGKey(42), 32)
forests = vectorized_fit(rng, data)
scores = vectorized_score(forests, data)
%timeit vectorized_fit(rng, data).trees.normals.block_until_ready()
%timeit vectorized_score(forests, data).block_until_ready()

In [None]:
from sklearn.ensemble import IsolationForest
model = IsolationForest(n_estimators=128)
model.fit(data)
%timeit model.fit(data)
%timeit model.score_samples(data)