In [None]:
import jax.numpy as jnp
import jax
#jax.config.update("jax_debug_nans", True)
#jax.config.update('jax_disable_jit', True)
#jax.config.update("jax_enable_x64", True)
jnp.set_printoptions(formatter={'float': '{:.2e}'.format})

import matplotlib.pyplot as plt
import matplotlib as mpl

# Dark theme: start from matplotlib's built-in dark style, then override a few
# colors so panels, labels and tick marks stay readable on the dark background.
plt.style.use('dark_background')
mpl.rcParams.update({
    'axes.facecolor': '#2F2F2F',
    'figure.facecolor': '#1E1E1E',
    'text.color': 'white',
    'axes.labelcolor': 'white',
    'xtick.color': 'white',
    'ytick.color': 'white',
})

import time
from PIL import Image, ImageDraw

In [None]:
image_width, image_height = 200, 100

# Target: two radius-20 circles on a black 8-bit grayscale canvas.
canvas = Image.new('L', (image_width, image_height))
draw = ImageDraw.Draw(canvas)

# Filled discs; swap in the commented style to draw only the outlines.
style = {"fill": "white"}
#style = {"outline": "white"}

for circle_center in ([50, 30], [150, 30]):
    draw.circle(circle_center, 20, **style)

# Convert to a JAX array and normalise the brightest pixel to 1.
image = jnp.array(canvas)
image = image / image.max()

plt.imshow(image, cmap='gray', origin='lower')
plt.show()

def enumerate_pixels(image):
    """Flatten a 2-D image into a table with one row per pixel.

    Returns an array of shape (rows * cols, 3). Each row is (x, y, value), and
    rows are listed in row-major order (y outer, x inner), which is the same
    order as ``image.flatten()``.
    """
    n_rows, n_cols = image.shape[0], image.shape[1]
    row_idx, col_idx = jnp.meshgrid(jnp.arange(n_rows), jnp.arange(n_cols), indexing='ij')
    return jnp.column_stack((col_idx.flatten(), row_idx.flatten(), image.flatten()))

# Table of (x, y, value) for every pixel of the target image.
enumerated_pixels = enumerate_pixels(image)

# (x, y) coordinates of the occupied (non-zero) pixels only.
is_occupied = enumerated_pixels[:, 2] > 0
occupied_pixels = enumerated_pixels[is_occupied][:, :2]

In [None]:
def plot_loss_vs_radius(loss, radii_range=(0, 100), num_points=200,
                        title='Loss of a circle centered between the target circles',
                        center=(100., 30.), desired_params=(50., 30., 20.)):
    """Plot ``loss`` for circles of varying radius around a fixed center.

    Args:
        loss: callable mapping a length-3 array [cx, cy, r] to a scalar loss.
        radii_range: (min, max) radius of the sweep, both inclusive.
        num_points: number of radii sampled evenly in ``radii_range``.
        title: plot title.
        center: (cx, cy) of the swept circle. The default is the midpoint
            between the two target circles.
        desired_params: [cx, cy, r] of a known-good circle. Its loss is drawn
            as a dashed horizontal reference line.
    """
    radii = jnp.linspace(radii_range[0], radii_range[1], num_points)
    center = jnp.asarray(center, dtype=float)
    # One loss evaluation per radius; each call sees [cx, cy, r].
    losses = jnp.array([loss(jnp.append(center, r)) for r in radii])
    desired_solution_loss = loss(jnp.asarray(desired_params, dtype=float))

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(radii, losses)
    ax.set_title(title)
    ax.set_xlabel('Radius')
    ax.set_ylabel('Loss')
    ax.grid(True)
    ax.axhline(y=desired_solution_loss, color='r', linestyle='--', label='Desired solution')
    ax.legend()
    ax.yaxis.set_major_formatter(plt.ScalarFormatter(useMathText=True))
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
    plt.show()

In [None]:
def circle(x, y, radius):
    """Signed distance from point (x, y) to a circle centered at the origin.

    The result is negative inside the circle, zero on the outline and positive
    outside. The sign of ``radius`` is ignored.

    The Euclidean norm is computed with the "double where" trick. The
    derivative of ||p|| is undefined at p = 0, and ``jnp.linalg.norm`` yields a
    NaN gradient there. That NaN poisons any loss summed over a pixel that
    coincides with the circle center, which happens whenever the center is
    integer-valued on the pixel grid. With this trick the gradient at the
    center is 0, and the value is unchanged everywhere.
    """
    squared_norm = x * x + y * y
    at_center = squared_norm == 0
    # Give sqrt a dummy positive argument at the center so its derivative stays finite.
    safe_squared_norm = jnp.where(at_center, 1.0, squared_norm)
    norm = jnp.where(at_center, 0.0, jnp.sqrt(safe_squared_norm))
    return norm - jnp.abs(radius)

def distance_to_circle(x, y, circle_x, circle_y, circle_r):
    """Signed distance from (x, y) to the circle at (circle_x, circle_y) with radius circle_r."""
    # Move the query point into the circle's own frame, then reuse the origin-centered SDF.
    local_x = x - circle_x
    local_y = y - circle_y
    return circle(local_x, local_y, circle_r)

# Squared distance loss

In [None]:
@jax.jit
def loss_fn_squared_distance(params):
    """Sum over occupied pixels of the squared signed distance to circle params[0:3]."""
    cx, cy, r = params[0:3]

    def squared_distance(point):
        return distance_to_circle(point[0], point[1], cx, cy, r) ** 2.0

    return jax.vmap(squared_distance)(occupied_pixels).sum()
plot_loss_vs_radius(loss_fn_squared_distance)

# Inside losses

In [None]:
@jax.jit
def loss_fn_unoccupied_inside(params):
    """Sum of (1 - value) over pixels that are not strictly outside circle params[0:3]."""
    def unoccupied_if_inside(pixel):
        x, y, occupied = pixel
        is_outside = distance_to_circle(x, y, *params[0:3]) > 0
        return jnp.where(is_outside, 0, 1 - occupied)

    return jax.vmap(unoccupied_if_inside)(enumerated_pixels).sum()
plot_loss_vs_radius(loss_fn_unoccupied_inside, title="loss = unoccupied pixels inside the circle")

In [None]:
@jax.jit
def loss_fn_unoccupied_ratio(params):
    """1 - (occupied pixels inside the circle / all pixels inside the circle).

    A pixel counts as inside when its signed distance to circle params[0:3] is
    not > 0. If no pixel center falls inside (for example a tiny, off-grid
    circle), the raw ratio would be 0/0 = NaN. In that case the loss is 1
    ("covers nothing"), so it stays finite.
    """
    def f(pixel):
        x, y, occupied = pixel
        dist = distance_to_circle(x, y, *params[0:3])
        # Per-pixel [is_inside, occupied_and_inside].
        return jnp.where(dist > 0, 0, jnp.array([1, occupied]))
    inside_pixels = jax.vmap(f)(enumerated_pixels)
    total_inside, occupied_inside = inside_pixels.sum(axis=0)
    # total_inside is a pixel count: clamping to >= 1 only affects the empty case,
    # where occupied_inside is 0 as well.
    return 1 - (occupied_inside / jnp.maximum(total_inside, 1))
plot_loss_vs_radius(loss_fn_unoccupied_ratio, title="loss = 1 - occupied / total pixels")

In [None]:
@jax.jit
def loss_fn_unoccupied_ratio_sigmoid(params):
    """Smooth variant of the inside-ratio loss for circle params[0:3].

    Each pixel's membership in the circle is soft: sigmoid(-signed_distance).
    That is close to 1 well inside, 0.5 on the outline and close to 0 well
    outside. The loss is 1 - (soft occupied count inside / soft count inside).
    """
    cx, cy, r = params[0:3]

    def soft_membership(pixel):
        x, y, occupied = pixel
        # Negated distance so that inside -> 1 and outside -> 0; 1.0 is the sharpness.
        weight = jax.nn.sigmoid(-distance_to_circle(x, y, cx, cy, r) * 1.0)
        return jnp.array([weight, weight * occupied])

    soft_inside, soft_occupied_inside = jax.vmap(soft_membership)(enumerated_pixels).sum(axis=0)
    return 1 - (soft_occupied_inside / soft_inside)

plot_loss_vs_radius(loss_fn_unoccupied_ratio_sigmoid, title="loss = 1 - occupied / total pixels (sigmoid)")

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt

@jax.jit
def loss_fn_unoccupied_ratio_sigmoid_fixed_radius(params):
    """Sigmoid ratio loss for a radius-20 circle centered at (params[0], params[1]).

    Any entries after the first two (such as a radius) are ignored, so their
    gradient is zero.
    """
    radius = jnp.array([20.0])
    return loss_fn_unoccupied_ratio_sigmoid(jnp.concatenate([params[:2], radius]))

# Gradient of the fixed-radius loss with respect to the circle center (x, y).
grad_fn = jax.grad(loss_fn_unoccupied_ratio_sigmoid_fixed_radius)

def pixel_gradient(pixel):
    """Loss gradient for a circle centered just off the given pixel's (x, y)."""
    px, py = pixel[0], pixel[1]
    # Nudge the center off the integer pixel grid: the norm inside the SDF is not
    # differentiable where the center coincides exactly with a pixel.
    return grad_fn(jnp.array([px + 0.01, py + 0.01]))

# One (d/dx, d/dy) row per pixel, in enumerated_pixels order.
gradient_map = jax.vmap(pixel_gradient)(enumerated_pixels)

In [None]:
# Per-pixel length of the (d/dx, d/dy) gradient, laid out as an image.
gradient_magnitude = jnp.linalg.norm(gradient_map, axis=1)
gradient_magnitude_2d = gradient_magnitude.reshape(image.shape)

fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(gradient_magnitude_2d, cmap='viridis', origin='lower')
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('Gradient Magnitude')
ax.set(title='Magnitude of x/y gradient of loss function', xlabel='X', ylabel='Y')
plt.show()

In [None]:
# Heatmap of the fixed-radius loss as a function of the circle center. It is
# evaluated only for centers at least `margin` px from every image edge.
from scipy.ndimage import minimum_filter

margin = 20
img_h, img_w = image.shape

# Build the inner-region mask once; it drives both the selection and the scatter back.
pixel_x, pixel_y = enumerated_pixels[:, 0], enumerated_pixels[:, 1]
inner_mask = ((pixel_x >= margin) & (pixel_x < img_w - margin) &
              (pixel_y >= margin) & (pixel_y < img_h - margin))
inner_indices = jnp.nonzero(inner_mask)[0]
inner_pixels = enumerated_pixels[inner_indices]

loss_map = jax.vmap(loss_fn_unoccupied_ratio_sigmoid_fixed_radius)(inner_pixels)

# Full-size map: NaN (rendered blank) outside the inner region.
full_loss_map = jnp.full(enumerated_pixels.shape[0], jnp.nan).at[inner_indices].set(loss_map)
loss_map_2d = full_loss_map.reshape(image.shape)

fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(loss_map_2d, cmap='viridis', origin='lower')
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('Loss Value')
ax.set_title(f'Heatmap of loss_fn_unoccupied_ratio_sigmoid_fixed_radius ({margin}px margin)')
ax.set_xlabel('X')
ax.set_ylabel('Y')

# Local minima inside the inner region are pixels equal to the minimum of their
# 50x50 neighbourhood. Use explicit end indices so margin == 0 still works
# (a `-0` end index would select nothing).
inner_loss_map_2d = loss_map_2d[margin:img_h - margin, margin:img_w - margin]
local_min = minimum_filter(inner_loss_map_2d, size=50) == inner_loss_map_2d
# (row, col) of each minimum, shifted back to full-image coordinates.
minima_coords = jnp.array(jnp.nonzero(local_min)) + margin

# With origin='lower', columns map to x and rows map to y.
ax.scatter(minima_coords[1], minima_coords[0], color='red', s=20, label='Local Minima')
ax.legend()
plt.show()


In [None]:
@jax.jit
def gradient_wrt_radius(params):
    """Derivative of the sigmoid ratio loss with respect to the radius params[2]."""
    full_gradient = jax.grad(loss_fn_unoccupied_ratio_sigmoid)(params)
    return full_gradient[2]

@jax.jit
def loss_fn(params):
    """Alias for loss_fn_unoccupied_ratio_sigmoid(params)."""
    value = loss_fn_unoccupied_ratio_sigmoid(params)
    return value

# Sweep the radius of a circle centered on the left target circle at (50, 30).
radii = jnp.linspace(0, 50, 500)

def _params_at(r):
    return jnp.array([50, 30, r])

gradients = jax.vmap(lambda r: gradient_wrt_radius(_params_at(r)))(radii)
losses = jax.vmap(lambda r: loss_fn(_params_at(r)))(radii)

# Top panel: d(loss)/d(radius); bottom panel: loss. Both share the radius axis.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12), sharex=True)

ax1.plot(radii, gradients)
ax1.set_title('Gradient with respect to radius (circle center: 50, 30)')
ax1.set_ylabel('Gradient')
ax1.grid(True)

ax2.plot(radii, losses)
ax2.set_title('Loss with respect to radius (circle center: 50, 30)')
ax2.set_xlabel('Radius')
ax2.set_ylabel('Loss')
ax2.grid(True)

plt.tight_layout()
plt.show()

## Chamfer loss

From Perplexity: https://www.perplexity.ai/search/given-a-signed-distance-field-5XWR4dzITHCpXcVcbm4CdA

The Chamfer distance is an excellent choice for comparing SDFs to pixel images while allowing for partial matching:

- For each point on the SDF surface, find the closest pixel in the image.
- For each occupied pixel in the image, find the closest point on the SDF surface.
- Sum the squared distances for both directions. (The implementation below instead averages the absolute distances in each direction.)

This loss encourages the SDF to fit at least some parts of the image well, rather than trying to fit everything poorly.


In [None]:
@jax.jit
def loss_fn_chamfer(params):
    # if closer than threshold, consider the pixel to be on the surface
    threshold = 5.0

    def pixel_losses(pixel):
        x, y, v = pixel
        loc = pixel[0:2]
        dist = jnp.abs(distance_to_circle(x, y, *params[0:3]))
        
        sdf_dist_loss = jnp.where(v > 0, dist, 0.0)
        closest_occupied_pixel_distance = jnp.where(
            enumerated_pixels[:, 2] > 0,
            jnp.linalg.norm(enumerated_pixels[:, 0:2] - loc, axis=1),
            jnp.inf
        ).min()
        occupied_pixel_dist_loss = jnp.where(v < threshold, closest_occupied_pixel_distance, 0.0)
        
        return jnp.array([
            sdf_dist_loss,
            occupied_pixel_dist_loss,
        ])

    sdf_dist_loss, occupied_pixel_dist_loss = jax.vmap(pixel_losses)(enumerated_pixels).T
    return jnp.abs(sdf_dist_loss).mean() + jnp.abs(occupied_pixel_dist_loss).mean()

# this takes about 2 minutes to compute on my Mac M1
# plot_loss_vs_radius(loss_fn_chamfer)

# Fitting with L-BFGS

In [None]:
import optax
import optax.tree_utils as otu

from functools import partial
# Shared optax L-BFGS optimizer (default settings); _run_lbfgs uses its init/update.
_lbfgs = optax.lbfgs()


@partial(jax.jit, static_argnames=["fun"])
def _run_lbfgs(init_params, lbfgs_state, fun, max_steps, tolerance, **kwargs):
    """Minimize `fun` with L-BFGS inside one compiled `jax.lax.while_loop`.

    Args:
        init_params: starting parameters.
        lbfgs_state: optimizer state from a previous call to warm-start from, or
            None to create fresh state with `_lbfgs.init(init_params)`. A
            supplied state has its "count" field reset to 0.
        fun: objective returning a scalar loss. It is a static jit argument, so
            it must be hashable, and a different function object triggers a new
            compilation.
        max_steps: iteration cap, compared against the state's "count" field.
        tolerance: looping stops once the L2 norm of the state's "grad" field
            drops below this value (after at least one iteration).
        **kwargs: extra keyword arguments forwarded to `value_and_grad_fun` and
            to `_lbfgs.update`. Presumably optax passes them on to `fun`; TODO
            confirm against the installed optax version.

    Returns:
        (final_params, final_state) when the loop exits.
    """
    # NOTE(review): per the optax docs this wrapper presumably reuses the value and
    # gradient cached in the optimizer state by the line search when available,
    # and calls `fun` otherwise. Confirm against the installed optax version.
    value_and_grad_fun = optax.value_and_grad_from_state(fun)
    
    def step(carry):
        """One L-BFGS iteration: evaluate loss and gradient, then apply the update."""
        params, state = carry
        value, grad = value_and_grad_fun(params, **kwargs, state=state)

        # Debug trace printed on every iteration from inside the compiled loop.
        jax.debug.print("loss: {value}, grad: {grad}, params: {params}", value=jnp.array([value]), grad=grad, params=params)
        # Zero out NaN gradient entries, presumably so a non-differentiable point in
        # the loss does not propagate NaN into the update. Only the gradient passed
        # to `_lbfgs.update` below is sanitized; `value` is not.
        grad = jnp.nan_to_num(grad, nan=0.0)

        updates, state = _lbfgs.update(
            grad, state, params, **kwargs, value=value, grad=grad, value_fn=fun
        )
        params = optax.apply_updates(params, updates)

        return params, state

    def continuing_criterion(carry):
        """Keep looping if no iteration has run yet, or while under `max_steps`
        and the norm of the gradient stored in the state is still >= `tolerance`."""
        _, state = carry
        iter_num = otu.tree_get(state, "count")
        grad = otu.tree_get(state, "grad")
        err = otu.tree_l2_norm(grad)
        return (iter_num == 0) | ((iter_num < max_steps) & (err >= tolerance))

    if lbfgs_state is None:
        state = _lbfgs.init(init_params)
    else:
        # reset iteration count leftover from previous solve
        state = otu.tree_set(lbfgs_state, count=0)


    init_carry = (init_params, state)
    final_params, final_state = jax.lax.while_loop(
        continuing_criterion, step, init_carry
    )
    return final_params, final_state


def solve(params, loss_fn, max_steps=5, tolerance=1e-7):
    """Minimize ``loss_fn`` from ``params`` with L-BFGS and print timing info.

    Args:
        params: initial parameters passed to ``loss_fn``.
        loss_fn: objective mapping params to a scalar loss.
        max_steps: maximum number of L-BFGS iterations (default 5).
        tolerance: stop early once the gradient L2 norm falls below this.

    Returns:
        The optimized parameters.

    The printed elapsed time includes JIT compilation on the first call for a
    given ``loss_fn``.
    """
    start_time = time.perf_counter()
    soln, state = _run_lbfgs(
        params,
        None,  # no previous minimizer state: start fresh
        loss_fn,
        max_steps=max_steps,
        tolerance=tolerance,
    )
    # JAX dispatches asynchronously. Wait for the result before stopping the
    # clock; otherwise the timer can stop before the solve has finished.
    jax.block_until_ready(soln)
    end_time = time.perf_counter()
    iter_num = otu.tree_get(state, "count")
    print(f"Elapsed time: {end_time - start_time:.5f} seconds; {iter_num} iterations")

    return soln

def plot_circles_on_image(circles):
    """Overlay circle outlines on the target image.

    ``circles`` is an iterable of (cx, cy, r) triples. Outline colors cycle
    through red, green, blue, yellow and cyan.
    """
    _, ax = plt.subplots()
    ax.imshow(image, cmap='gray', origin='lower')
    palette = ['r', 'g', 'b', 'y', 'c']
    for idx, (cx, cy, r) in enumerate(circles):
        outline = plt.Circle((cx, cy), r, color=palette[idx % len(palette)], fill=False)
        ax.add_patch(outline)
    plt.show()

In [None]:
# Fit one circle, starting on the left target's center with twice its radius.
params = jnp.array([50.0, 30.0, 40.0])

solution = solve(params, loss_fn_unoccupied_ratio_sigmoid)
print(solution)
# plot_circles_on_image expects rows of (cx, cy, r).
plot_circles_on_image(solution.reshape(-1, 3))


In [None]:
# Random search over starting centers: run the fixed-radius L-BFGS fit from many
# random initial centers and record which runs end on one of the target circles.
num_samples = 200
x_min, x_max = 0.1, 200.1
y_min, y_max = 0.1, 100.1
fixed_radius = 20

# A run counts as a hit when both center coordinates are within epsilon of a target.
epsilon = 1e-2

def is_close_to_actual(solution, actual_position, epsilon):
    """True if solution[:2] is within ``epsilon`` of ``actual_position`` in both x and y."""
    return jnp.all(jnp.abs(solution[:2] - actual_position) < epsilon)

# Centers of the two target circles drawn into `image`.
actual_positions = jnp.array([[50, 30], [150, 30]])

# Fixed seed for reproducibility. Sample x and y with independent subkeys: using
# `key` directly *and* splitting it reuses the same randomness, which JAX's PRNG
# design does not guarantee to be independent.
key = jax.random.PRNGKey(0)
key_x, key_y = jax.random.split(key)
x_samples = jax.random.uniform(key_x, (num_samples,), minval=x_min, maxval=x_max)
y_samples = jax.random.uniform(key_y, (num_samples,), minval=y_min, maxval=y_max)

results = []
for x, y in zip(x_samples, y_samples):
    initial_params = jnp.array([x, y, fixed_radius])
    solution = solve(initial_params, loss_fn_unoccupied_ratio_sigmoid_fixed_radius)

    # Did the fit converge onto either target circle?
    is_solution = any(is_close_to_actual(solution, pos, epsilon) for pos in actual_positions)
    results.append((initial_params, is_solution))

# Starting points, colored by outcome.
fig, ax = plt.subplots(figsize=(10, 6))
ax.imshow(image, cmap='gray', origin='lower')
for initial_params, is_solution in results:
    color = 'green' if is_solution else 'red'
    ax.plot(initial_params[0], initial_params[1], 'o', color=color, markersize=5)
ax.set_title("Initial Points (Green: Found Solution, Red: No Solution)")
plt.show()

num_solutions = sum(is_solution for _, is_solution in results)
print(f"Found solutions for {num_solutions} out of {num_samples} initial points.")
