# JAX Jit Experiments

Some quick investigations into the performance differences between the jit and non-jit version of some basic SDF operations.

Takeaways:

- No issues with Steve's new SDF library (`polymorph-s2df`) and JAX's jit.
- On CPU, we get a reasonable speedup: the JIT version of calculating the area is 20x faster than the interpreter version.

In [1]:
from polymorph_s2df import *
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt

jax.devices()

[CpuDevice(id=0)]

In [2]:
def render(shape, bounds=(-3, 3), n=500):
    x = jnp.linspace(bounds[0], bounds[1], n)
    X, Y = jnp.meshgrid(x, x)

    grid_points = jnp.column_stack((X.flatten(), Y.flatten()))
    vec_is_inside = jax.jit(jax.vmap(lambda x: 1 - shape.is_inside(x)))
    plt.imshow(vec_is_inside(grid_points).reshape(n, n), cmap="gray", origin="lower", extent=[bounds[0], bounds[1],bounds[0], bounds[1]])

In [3]:
import time

COUNT = 10000

def bench(fn, *args, count=COUNT):
    vg = jax.value_and_grad(fn)
    start = time.time()
    for i in range(0, count):
        vg(*args)
    print(f"Non-jit: {1000 * (time.time() - start) / count}ms / iter")

    start = time.time()
    jvg = jax.jit(vg)
    for i in range(0, count):
        jvg(*args)
    print(f"Jit: {1000 * (time.time() - start) / count}ms / iter")

In [4]:
import time

def drawing(r1, r2, dx):
    return Union(Circle(r1), Translation(jnp.array([dx, 0]), Circle(r2)))

def dist_to_origin(*args):
    return drawing(*args).distance(jnp.array([0, 0]))

# bench(dist_to_origin, 1., 1., 1.)

In [7]:
import math

def generate_grid_points2(n):
    spacing = int(math.sqrt(n))
    x = jnp.linspace(BOUNDS[0], BOUNDS[1], spacing)
    y = jnp.linspace(BOUNDS[0], BOUNDS[1], spacing)
    xx, yy = jnp.meshgrid(x, y)
    return jnp.stack([xx.ravel(), yy.ravel()], axis=-1)

BOUNDS = (-3, 3)
n = 10000
rand_points = jax.random.uniform(jax.random.PRNGKey(0), (n, 2), jnp.float32, BOUNDS[0], BOUNDS[1])
grid_points = generate_grid_points2(n)

def area(shape, points, bounds=[-1, 1]):
    inside_points = jax.vmap(shape.is_inside)(points)
    return jnp.average(inside_points, axis=0) * (bounds[1] - bounds[0])**2

def circle_area_rand(r1, r2, d):
    s = drawing(r1, r2, d)
    return area(s, rand_points, BOUNDS)

def circle_area_grid(r1, r2, d):
    s = drawing(r1, r2, d)
    return area(s, grid_points, BOUNDS)

print(jax.devices())
print(f"Default backend: {jax.default_backend()}")
print('rand --')
print(circle_area_rand(1., 1., 2.))
bench(circle_area_rand, 1., 1., 2., count=1000)
print('grid --')
print(circle_area_grid(1., 1., 2.))
bench(circle_area_grid, 1., 1., 2., count=1000)

[CpuDevice(id=0)]
Default backend: cpu
rand --
6.397424
Non-jit: 3.08685302734375ms / iter
Jit: 0.10831212997436523ms / iter
grid --
6.1465154
Non-jit: 2.8333771228790283ms / iter
Jit: 0.11165332794189453ms / iter
