In [None]:
import jax.numpy as jnp
import jax
#jax.config.update("jax_debug_nans", True)
#jax.config.update('jax_disable_jit', True)
#jax.config.update("jax_enable_x64", True)
jnp.set_printoptions(formatter={'float': '{:.2e}'.format})

import matplotlib.pyplot as plt
import time
from PIL import Image, ImageDraw


In [None]:
# Build the synthetic target image: two filled white circles on a single-channel canvas.
image_width, image_height = 200, 100

# 'L' is PIL's single-channel (grayscale) mode; the size tuple is (width, height).
image = Image.new('L', (image_width, image_height))
draw = ImageDraw.Draw(image)

# Filled discs by default; swap in the commented-out style below to draw outlines only.
style = {"fill": "white"}
#style = {"outline": "white"}

# Two target circles of radius 20, centred at (50, 30) and (150, 30).
# NOTE(review): ImageDraw.circle appears to exist only in recent Pillow releases — confirm the minimum version.
draw.circle([50, 30], 20, **style)
draw.circle([150, 30], 20, **style)

# Convert to a JAX array and rescale so the brightest pixel has value 1.
# From here on `image` is an array (rows are treated as y and columns as x below), no longer a PIL image.
image = jnp.array(image)
image = image / image.max()
# origin='lower' puts row 0 at the bottom of the plot, so y increases upwards.
plt.imshow(image, cmap='gray', origin='lower')
plt.show()

def enumerate_pixels(image):
    """Flatten a 2-D image into one row per pixel.

    Args:
        image: 2-D array indexed as [y, x].

    Returns:
        Array of shape (height * width, 3) whose columns are (x, y, value),
        listed in row-major order (x varies fastest).
    """
    height, width = image.shape[0], image.shape[1]
    # Row-major traversal: x cycles 0..width-1 once per row, y is constant within a row.
    xs = jnp.tile(jnp.arange(width), height)
    ys = jnp.repeat(jnp.arange(height), width)
    return jnp.column_stack((xs, ys, image.reshape(-1)))

# One (x, y, value) row for every pixel of the image
enumerated_pixels = enumerate_pixels(image)

# (x, y) coordinates of the pixels whose value is positive
occupied_pixels = enumerated_pixels[:, :2][enumerated_pixels[:, 2] > 0]

In [None]:
def plot_loss_vs_radius(loss, radii_range=(0, 100), num_points=200, title='Loss of a circle centered between the target circles'):
    """Plot a loss as a function of radius for a circle fixed at centre (100, 30).

    Args:
        loss: Callable taking a parameter array [center_x, center_y, radius]
            and returning a scalar loss.
        radii_range: (min, max) radius to sweep, both ends included.
        num_points: Number of evenly spaced radii to evaluate.
        title: Figure title.

    The loss at parameters [50, 30, 20] is drawn as a dashed red horizontal
    line labelled 'Desired solution' for comparison.
    """
    radii = jnp.linspace(radii_range[0], radii_range[1], num_points)
    center = jnp.array([100., 30.])

    # One loss evaluation per radius, with the centre held fixed.
    sampled = []
    for radius in radii:
        sampled.append(loss(jnp.append(center, radius)))
    losses = jnp.array(sampled)

    desired_solution_loss = loss(jnp.array([50., 30., 20.]))

    plt.figure(figsize=(10, 6))
    ax = plt.gca()
    ax.plot(radii, losses)
    ax.set_title(title)
    ax.set_xlabel('Radius')
    ax.set_ylabel('Loss')
    ax.grid(True)
    ax.axhline(y=desired_solution_loss, color='r', linestyle='--', label='Desired solution')
    ax.legend()
    # Scientific notation on the y axis, rendered with math text.
    ax.yaxis.set_major_formatter(plt.ScalarFormatter(useMathText=True))
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
    plt.show()

In [None]:
def circle(x, y, radius):
    """Signed distance from the point (x, y) to a circle centred at the origin.

    Args:
        x, y: Point coordinates relative to the circle centre.
        radius: Circle radius; its absolute value is used, so the sign is ignored.

    Returns:
        Distance to the circle outline: negative inside, zero on the outline,
        positive outside.
    """
    p = jnp.array([x, y])
    sq_norm = jnp.sum(p * p)
    # sqrt has an infinite derivative at 0, which turns the gradient into NaN
    # when the point sits exactly on the centre. Route that case through a
    # constant branch (the "double where" trick) so the gradient there is 0.
    at_origin = sq_norm == 0
    safe_sq_norm = jnp.where(at_origin, 1.0, sq_norm)
    norm = jnp.where(at_origin, 0.0, jnp.sqrt(safe_sq_norm))
    return norm - jnp.abs(radius)

def distance_to_circle(x, y, circle_x, circle_y, circle_r):
    """Signed distance from (x, y) to the circle centred at (circle_x, circle_y) with radius circle_r.

    Shifts the point into the circle's local frame and defers to `circle`.
    """
    dx = x - circle_x
    dy = y - circle_y
    return circle(dx, dy, circle_r)

# Squared distance loss

In [None]:
@jax.jit
def loss_fn_squared_distance(params):
    """Sum over occupied pixels of the squared signed distance to the circle.

    Args:
        params: Array whose first three entries are (circle_x, circle_y, circle_r).

    Returns:
        Scalar loss.
    """
    def signed_distance(point):
        return distance_to_circle(point[0], point[1], *params[0:3])

    distances = jax.vmap(signed_distance)(occupied_pixels)
    squared = distances**2.0
    return squared.sum()
plot_loss_vs_radius(loss_fn_squared_distance)

# Inside losses

In [None]:
@jax.jit
def loss_fn_unoccupied_inside(params):
    """Total "emptiness" (1 - pixel value) of the pixels inside or on the circle.

    Args:
        params: Array whose first three entries are (circle_x, circle_y, circle_r).

    Returns:
        Scalar loss; pixels strictly outside the circle contribute 0.
    """
    def emptiness_if_inside(pixel):
        px, py, value = pixel
        is_outside = distance_to_circle(px, py, *params[0:3]) > 0
        return jnp.where(is_outside, 0, 1-value)

    per_pixel = jax.vmap(emptiness_if_inside)(enumerated_pixels)
    return per_pixel.sum()
plot_loss_vs_radius(loss_fn_unoccupied_inside, title="loss = unoccupied pixels inside the circle")

In [None]:
@jax.jit
def loss_fn_unoccupied_ratio(params):
    """Fraction of the area inside the circle that is not occupied.

    Args:
        params: Array whose first three entries are (circle_x, circle_y, circle_r).

    Returns:
        1 - (sum of pixel values inside) / (number of pixels inside), where
        "inside" includes pixels exactly on the outline. If no pixel is inside,
        the ratio is 0/0 and the result is NaN.
    """
    def count_and_value(pixel):
        px, py, value = pixel
        is_outside = distance_to_circle(px, py, *params[0:3]) > 0
        # Per pixel: [1, value] when inside the circle, [0, 0] otherwise.
        return jnp.where(is_outside, 0, jnp.array([1, value]))

    totals = jax.vmap(count_and_value)(enumerated_pixels).sum(axis=0)
    total_inside, occupied_inside = totals
    return 1 - (occupied_inside / total_inside)
plot_loss_vs_radius(loss_fn_unoccupied_ratio, title="loss = unoccupied / total pixels")

## Chamfer loss

From Perplexity: https://www.perplexity.ai/search/given-a-signed-distance-field-5XWR4dzITHCpXcVcbm4CdA

The Chamfer distance is an excellent choice for comparing SDFs to pixel images while allowing for partial matching:

- For each point on the SDF surface, find the closest pixel in the image.
- For each occupied pixel in the image, find the closest point on the SDF surface.
- Sum the squared distances for both directions.

This loss encourages the SDF to fit at least some parts of the image well, rather than trying to fit everything poorly.


In [None]:
@jax.jit
def _nearest_occupied_pixel_distances(pixels):
    """For every pixel, the distance to the closest occupied pixel.

    Args:
        pixels: Array of (x, y, value) rows; a pixel is occupied when value > 0.

    Returns:
        1-D array with one distance per input row (inf if no pixel is occupied).
    """
    locations = pixels[:, 0:2]
    is_occupied = pixels[:, 2] > 0

    def nearest(loc):
        return jnp.where(
            is_occupied,
            jnp.linalg.norm(locations - loc, axis=1),
            jnp.inf
        ).min()

    # lax.map handles one pixel at a time, so memory stays linear in the
    # number of pixels instead of materialising an N x N distance matrix.
    return jax.lax.map(nearest, locations)


# These distances depend only on the image, not on the circle parameters,
# so they are computed once here rather than inside every loss evaluation.
_nearest_occupied_distances = _nearest_occupied_pixel_distances(enumerated_pixels)


@jax.jit
def loss_fn_chamfer(params):
    """Two-sided, Chamfer-style loss between the circle outline and the image.

    Args:
        params: Array whose first three entries are (circle_x, circle_y, circle_r).

    Returns:
        Scalar: mean over all pixels of
          * image -> surface: |distance to the outline| for occupied pixels, 0 otherwise;
          * surface -> image: distance to the nearest occupied pixel for pixels
            within `threshold` of the outline, 0 otherwise.
    """
    # if closer than threshold, consider the pixel to be on the surface
    threshold = 5.0

    def pixel_losses(pixel, nearest_occupied_distance):
        x, y, v = pixel
        dist = jnp.abs(distance_to_circle(x, y, *params[0:3]))

        sdf_dist_loss = jnp.where(v > 0, dist, 0.0)
        # The surface test must use the distance to the outline. Comparing the
        # pixel value `v` (at most 1) with the threshold was always true, which
        # made this term independent of the circle parameters.
        occupied_pixel_dist_loss = jnp.where(dist < threshold, nearest_occupied_distance, 0.0)

        return jnp.array([
            sdf_dist_loss,
            occupied_pixel_dist_loss,
        ])

    sdf_dist_loss, occupied_pixel_dist_loss = jax.vmap(pixel_losses)(
        enumerated_pixels, _nearest_occupied_distances
    ).T
    return jnp.abs(sdf_dist_loss).mean() + jnp.abs(occupied_pixel_dist_loss).mean()

# Originally noted as taking about 2 minutes on a Mac M1, when the nearest-pixel
# search ran inside every loss evaluation; with that search precomputed this
# should be considerably faster (not re-measured).
plot_loss_vs_radius(loss_fn_chamfer)