In [1]:
import jax.numpy as np
import jax

In [2]:
def _length(a):
    return np.linalg.norm(a, axis=1)

# Default centre for shapes: the origin of the 2-D plane.
ORIGIN = np.array([0, 0])

SEED = jax.random.PRNGKey(seed=0)  # fixed key so the sampling cells are reproducible


def circle(radius=1, center=ORIGIN):
    """Build the signed distance function of a circle.

    The returned callable maps a batch of points to each point's distance
    from ``center`` minus ``radius``: negative inside, zero on the boundary,
    positive outside.
    """
    def signed_distance(points):
        offsets = points - center
        return _length(offsets) - radius

    return signed_distance

In [3]:
a = np.array([[0, 1], [2, 3]])  # small 2x2 scratch array for checking norm axes

In [4]:
a  # display the scratch array

Array([[0, 1],
       [2, 3]], dtype=int32)

In [5]:
# axis=0 reduces down each column, giving one squared norm per column
# (0²+2², 1²+3²) — not one per row/point as `_length` does.
np.square(np.linalg.norm(a, axis=0))

Array([ 4., 10.], dtype=float32)

In [6]:
sdf = circle(radius=1, center=ORIGIN)  # unit circle centred on the origin

In [7]:
sdf(np.array([[0, 0]]))  # the origin lies one radius inside the unit circle

Array([-1.], dtype=float32)

In [8]:
# Monte Carlo estimate of the unit circle's area (≈ π): sample N points
# uniformly from the box [box_low, box_high]², count those with negative
# signed distance, and scale the hit fraction by the box area.
N = 1000
shape = (N, 2)
dtype = np.float32
box_low, box_high = -1, 1

random_numbers = jax.random.uniform(SEED, shape, dtype, box_low, box_high)
computed_sdf = sdf(random_numbers) < 0  # boolean mask: True for points inside the circle
np.sum(computed_sdf) / N * (box_high - box_low) ** 2

Array(3.16, dtype=float32)

In [28]:
def inside(sdf, sharpness=100):
    """Wrap a signed distance function as a smooth "is inside" indicator.

    The returned callable maps points to ``sigmoid(-sharpness * sdf(p))``:
    close to 1 well inside the shape (negative distance), 0.5 on the
    boundary, close to 0 outside. Unlike a hard ``sdf(p) < 0`` test it is
    differentiable, so ``jax.grad`` can flow through it.

    Args:
        sdf: callable returning signed distances for a batch of points.
        sharpness: slope of the sigmoid; larger values approach the hard
            indicator but give useful gradients only very near the boundary.
            Defaults to 100, the value previously hard-coded.

    Returns:
        A callable ``f(p)`` returning the soft indicator for points ``p``.
    """
    def f(p):
        distance = sdf(p)
        # sigmoid(-x) == 1 - sigmoid(x), but avoids the float32 cancellation
        # in ``1 - sigmoid(x)`` that rounds small values to exactly 0.
        return jax.nn.sigmoid(-sharpness * distance)

    return f



In [10]:
# Monte Carlo estimate of the unit circle's area using the smooth indicator
# `inside(sdf)` instead of a hard `sdf(p) < 0` test: each point contributes a
# weight between 0 and 1 rather than 0/1, which makes the estimate
# differentiable. The weights are scaled by the area of the sampling box.
N = 1000
shape = (N, 2)
dtype = np.float32
box_low, box_high = -1, 1

random_numbers = jax.random.uniform(SEED, shape, dtype, box_low, box_high)
computed_sdf = inside(sdf)(random_numbers)  # soft inside-weights, not a boolean mask
np.sum(computed_sdf) / N * (box_high - box_low) ** 2

Array(3.150337, dtype=float32)

In [29]:
def area(sdf, bounds=(-1, 1), n=1000, key=None):
    """Estimate the area of the shape described by ``sdf`` by Monte Carlo.

    Draws ``n`` points uniformly from the square ``[low, high]²`` given by
    ``bounds``, sums the smooth indicator ``inside(sdf)`` over them, and
    scales the mean by the square's area. Because the indicator is smooth,
    the estimate is differentiable with respect to parameters of ``sdf``.

    Args:
        sdf: signed distance function taking an (n, 2) array of points.
        bounds: ``(low, high)`` limits shared by both axes. The square must
            contain the whole shape, otherwise the area is under-estimated.
        n: number of sample points.
        key: JAX PRNG key for the samples. ``None`` falls back to the
            notebook-wide ``SEED``, so repeated calls draw the same points.

    Returns:
        Scalar array with the area estimate.
    """
    low, high = bounds[0], bounds[1]
    key = SEED if key is None else key
    points = jax.random.uniform(key, (n, 2), np.float32, low, high)
    weights = inside(sdf)(points)
    return np.sum(weights) / n * (high - low) ** 2

In [49]:
# Enlarged box [-1.2, 1.2]² so the whole unit circle, including its soft edge, fits inside.
area(circle(radius=1), bounds=[-1.2, 1.2], n=10000)

Array(3.1196575, dtype=float32)

In [56]:
def _area_of_circle(radius):
    return area(circle(radius), n=1_000_000, bounds=[-1.2, 1.2])


# Area should approach π·r² and d(area)/d(radius) the circumference 2π·r (≈ 6.283 at r = 1).
jax.value_and_grad(_area_of_circle)(np.array([1.0]))

(Array(3.144044, dtype=float32), Array([6.2955976], dtype=float32))