In [None]:
import jax.numpy as jnp
import numpy as np
import jax
import matplotlib.pyplot as plt

In [None]:
def _length(a):
    return jnp.linalg.norm(a, axis=1)

# Fixed PRNG key used by the Monte Carlo sampling in this notebook, so every
# estimate is reproducible across runs.
SEED = jax.random.PRNGKey(0)

# The 2D origin as an (integer) JAX array.
ORIGIN = jnp.array([0, 0])


def circle(radius=1, center=(0, 0)):
    """Build the signed distance function of a circle.

    The returned function computes `_length(p - center) - radius`: negative
    inside the circle, zero on its boundary and positive outside.
    """
    c = jnp.array(center)

    def sdf(points):
        offset = points - c
        return _length(offset) - radius

    return sdf

""" float sdBox( in vec2 p, in vec2 b )
{
    vec2 d = abs(p)-b;
    return length(max(d,0.0)) + min(max(d.x,d.y),0.0);
} """

# https://www.youtube.com/watch?v=62-pRVZuS5c
def box(bounds=jnp.array([1.0, 2.0])):
    b = bounds / 2
    def f(p):
        d = jnp.abs(p) - b
        return (_length(jnp.maximum(d, 0.0))
                 + jnp.minimum(jnp.max(d), 0.0))
    return f

def translate(sdf, offset):
    """Move a shape by `offset`.

    The returned SDF evaluates the original `sdf` at `p - offset`, so the shape
    appears shifted by `offset`.
    """
    shift = jnp.array(offset)
    return lambda p: sdf(p - shift)

In [None]:
def render(sdf, bounds=(-1, 1), n=500):
    """Draw the soft inside-mask of `sdf` on an n x n grid with `plt.imshow`.

    The grid spans [bounds[0], bounds[1]] in both x and y. The mask is inverted
    before display, so with the gray colormap the shape's interior is dark.
    """
    lo, hi = bounds[0], bounds[1]
    axis_values = jnp.linspace(lo, hi, n)
    grid_x, grid_y = jnp.meshgrid(axis_values, axis_values)

    samples = jnp.stack([grid_x.ravel(), grid_y.ravel()], axis=1)
    occupancy = inside(sdf)(samples).reshape(n, n)

    plt.imshow(
        1 - occupancy,
        cmap="gray",
        origin="lower",
        extent=[lo, hi, lo, hi],
    )

In [None]:
def naive_area(sdf, bounds=(-1, 1), n=1000, key=None):
    """Monte Carlo area estimate using a hard (non-differentiable) inside test.

    Samples `n` points uniformly in the square [bounds[0], bounds[1]]^2, counts
    those with `sdf(p) < 0`, and scales the hit fraction by the square's area.

    Args:
        sdf: signed distance function taking an (n, 2) batch of points.
        bounds: (low, high) limits of the sampling square along both axes.
        n: number of samples.
        key: optional JAX PRNG key; defaults to the notebook-wide `SEED`.

    Returns:
        Scalar area estimate.
    """
    key = SEED if key is None else key
    points = jax.random.uniform(key, (n, 2), jnp.float32, bounds[0], bounds[1])
    hits = sdf(points) < 0
    side = bounds[1] - bounds[0]
    return jnp.sum(hits) / n * side**2

In [None]:
def inside(sdf, sharpness=100.0):
    """Turn an SDF into a soft, differentiable inside indicator.

    The returned function maps points p to `sigmoid(-sharpness * sdf(p))`, which
    equals `1 - sigmoid(sharpness * sdf(p))`. It is about 1 well inside the
    shape, 0.5 on the boundary and about 0 well outside. A larger `sharpness`
    gives a crisper but less smooth edge.

    Args:
        sdf: signed distance function (negative inside the shape).
        sharpness: slope of the sigmoid transition; defaults to 100.

    Returns:
        A function p -> soft inside weight, with the same shape as `sdf(p)`.
    """
    def f(p):
        distance = sdf(p)
        # sigmoid(-x) rather than 1 - sigmoid(x): the subtraction cancels
        # catastrophically and rounds small outside weights to exactly 0.
        return jax.nn.sigmoid(-sharpness * distance)

    return f

In [None]:
def area(sdf, bounds=(-1, 1), n=1000, key=None):
    """Differentiable Monte Carlo estimate of the area where `sdf` is negative.

    Samples `n` points uniformly in the square [bounds[0], bounds[1]]^2. Each
    point is weighted by its soft inside value (`inside(sdf)`), and the mean
    weight is scaled by the square's area. Because the weights are smooth, the
    result can be differentiated (e.g. with `jax.grad`) with respect to shape
    parameters.

    Args:
        sdf: signed distance function taking an (n, 2) batch of points.
        bounds: (low, high) limits of the sampling square along both axes.
        n: number of samples.
        key: optional JAX PRNG key; defaults to the notebook-wide `SEED`.

    Returns:
        Scalar area estimate.
    """
    key = SEED if key is None else key
    points = jax.random.uniform(key, (n, 2), jnp.float32, bounds[0], bounds[1])
    weights = inside(sdf)(points)
    side = bounds[1] - bounds[0]
    return jnp.sum(weights) / n * side**2
    

In [None]:
# Sanity check: the unit circle's area should come out roughly pi (~3.14).
# Monte Carlo noise and the soft sigmoid edge make it inexact.
area(circle(radius=1), bounds=[-1.2, 1.2], n=10_000)

In [None]:
# Differentiate the estimated area with respect to the radius. For an exact
# circle dA/dr = 2*pi*r, so at r = 1 the gradient should be roughly 6.28.
jax.value_and_grad(
    lambda radius: area(circle(radius), bounds=[-1.2, 1.2], n=1_000_000)
)(jnp.array([1.0]))

In [None]:
def centroid(sdf, bounds=(-1, 1), n=1000, key=None):
    """Monte Carlo estimate of the centroid of the region where `sdf` is negative.

    Samples `n` points uniformly in the square [bounds[0], bounds[1]]^2 and
    returns their average, weighted by the soft inside value `inside(sdf)`.

    Args:
        sdf: signed distance function taking an (n, 2) batch of points.
        bounds: (low, high) limits of the sampling square along both axes.
        n: number of samples.
        key: optional JAX PRNG key; defaults to the notebook-wide `SEED`.

    Returns:
        A length-2 array with the estimated (x, y) centroid.
    """
    key = SEED if key is None else key
    points = jax.random.uniform(key, (n, 2), jnp.float32, bounds[0], bounds[1])
    weights = inside(sdf)(points)
    return jnp.average(points, weights=weights, axis=0)

In [None]:
# Centroid computed by hand: weight each uniform sample in [-1, 3]^2 by its soft
# inside value and average. The unit circle is shifted to (1, 0), so the result
# should land approximately at (1, 0).
points = jax.random.uniform(SEED, (100_000, 2), jnp.float32, minval=-1, maxval=3)
inside_points = inside(translate(circle(), jnp.array([1, 0])))(points)

jnp.average(points, axis=0, weights=inside_points)

In [None]:
# Centroid of a unit circle shifted up by one unit; it should land near (0, 1).
sdf = translate(circle(), (0, 1))
center = centroid(sdf, bounds=(-5, 5), n=10000)
print(center)

render(sdf, (-5, 5))
plt.scatter(0, 0, c="red", label="origin")
plt.scatter(center[0], center[1], c="blue", label="estimated centroid")
plt.title("Translated unit circle and its Monte Carlo centroid")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
# plt.show() also stops the PathCollection repr from leaking into the cell output.
plt.show()


