In [None]:
import jax.numpy as jnp
import jax
#jax.config.update("jax_debug_nans", True)
#jax.config.update('jax_disable_jit', True)
jax.config.update("jax_enable_x64", True)
jnp.set_printoptions(formatter={'float': '{:.2e}'.format})

import matplotlib.pyplot as plt
import time
from PIL import Image, ImageDraw


In [None]:
image_width = 200   # target image width in pixels (columns)
image_height = 100  # target image height in pixels (rows)

def gaussian_blur(image, kernel_size=3, sigma=1.0):
    """Blur ``image`` by convolving it with a normalised square Gaussian kernel.

    Args:
        image: 2-D array to blur.
        kernel_size: side length of the square kernel.
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        The convolved array, the same shape as ``image`` (``mode='same'``).
    """
    # Integer offsets centred on zero, e.g. [-1, 0, 1] for kernel_size=3.
    offsets = jnp.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    grid_x, grid_y = jnp.meshgrid(offsets, offsets)
    squared_radius = jnp.square(grid_x) + jnp.square(grid_y)
    weights = jnp.exp(-0.5 * squared_radius / jnp.square(sigma))
    kernel = weights / jnp.sum(weights)  # entries sum to 1, so brightness is preserved
    return jax.scipy.signal.convolve2d(image, kernel, mode='same')

def create_image(filename, draw_function):
    """Render a synthetic target image, blur it, save a preview and return it.

    A blank greyscale (``'L'``) PIL canvas of ``image_width`` x ``image_height``
    is passed to ``draw_function`` as an ``ImageDraw.Draw``. The result is
    blurred with ``gaussian_blur`` and scaled so its peak value is 1.

    Args:
        filename: path the preview PNG is written to.
        draw_function: callable that draws onto the supplied ``ImageDraw.Draw``.

    Returns:
        The blurred, peak-normalised image as a JAX array. If nothing was
        drawn, the all-zero image is returned as-is instead of dividing by zero.
    """
    canvas = Image.new('L', (image_width, image_height))
    draw_function(ImageDraw.Draw(canvas))

    im_blurred = gaussian_blur(jnp.array(canvas))
    peak = im_blurred.max()
    # Guard against a blank canvas: 0 / 0 would turn the whole target into NaN.
    if peak > 0:
        im_blurred = im_blurred / peak

    # Figure sized to image_width x image_height pixels at dpi=100.
    fig, ax = plt.subplots(figsize=(image_width / 100, image_height / 100), dpi=100)
    ax.imshow(im_blurred, cmap='gray', origin='lower')
    ax.axis('off')
    fig.tight_layout(pad=0)
    fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    return im_blurred

In [None]:
true_solution = jnp.array([20., 20, 10])  # ground-truth circle: (x, y, radius)
params = jnp.array([10.1, 10.1, 5])


def draw_one_circle(draw):
    """Draw the ground-truth circle outline in white onto ``draw``."""
    center_x, center_y, radius = true_solution
    draw.circle([center_x, center_y], radius, outline="white")


im_blurred = create_image("target_one_circle.png", draw_one_circle)

In [None]:
def smooth_abs(x, sharpness=10.0):
    """Differentiable approximation of ``|x|``: ``x * tanh(sharpness * x)``.

    Unlike ``jnp.abs`` its derivative is continuous through 0. Larger
    ``sharpness`` tracks ``|x|`` more closely. The default reproduces the
    previously hard-coded factor of 10.0.
    """
    return x * jnp.tanh(sharpness * x)

def softplus(x, sharpness=50.0):
    """Smooth approximation of ``max(x, 0)`` with a tunable corner.

    ``softplus(k * x) / k`` converges to ``relu(x)`` as ``k`` grows. The
    approximation error is at most ``log(2) / k``, reached at ``x = 0``. The
    default ``sharpness`` reproduces the previously hard-coded factor of 50.
    """
    return jax.nn.softplus(sharpness * x) / sharpness

def softminus(x):
    """Smooth approximation of ``min(x, 0)``.

    Uses the identity ``min(x, 0) = x - max(x, 0)``, with ``softplus``
    standing in for ``max(x, 0)``.
    """
    positive_part = softplus(x)
    return x - positive_part

def box_unsigned(x, y, width, height):
    """Smoothed distance from ``(x, y)`` to a ``width`` x ``height`` box centred at the origin.

    Both per-axis excesses go through ``softplus``, which makes them (softly)
    non-negative. The ``softminus`` interior term is therefore close to zero,
    and the result behaves as an unsigned distance.
    """
    excess_x = softplus(smooth_abs(x) - width / 2)
    excess_y = softplus(smooth_abs(y) - height / 2)
    outside = jnp.sqrt(excess_x * excess_x + excess_y * excess_y)
    return outside + softminus(jnp.maximum(excess_x, excess_y))


def box(x, y, width, height):
    """Smoothed signed distance from ``(x, y)`` to a ``width`` x ``height`` box centred at the origin.

    Follows the box SDF at https://iquilezles.org/articles/distfunctions/.
    ``abs`` is replaced by ``smooth_abs`` and ``min(., 0)`` by ``softminus``.

    NOTE(review): inside the box ``maximum(q, 0)`` is the zero vector, where
    the gradient of ``jnp.linalg.norm`` is NaN.
    """
    point = jnp.array([x, y])
    half_extent = jnp.array([width / 2, height / 2])
    q = smooth_abs(point) - half_extent
    outside_distance = jnp.linalg.norm(jnp.maximum(q, 0.0))
    inside_distance = softminus(jnp.max(q))
    return outside_distance + inside_distance

def circle(x, y, radius):
    """Signed distance from ``(x, y)`` to a circle of ``radius`` centred at the origin.

    The result is negative inside, zero on the outline and positive outside.
    The sign of ``radius`` is ignored.

    The Euclidean norm uses the "double ``where``" pattern. This makes its
    gradient at the exact centre ``(0, 0)`` equal to 0 instead of NaN; the
    norm's derivative is 0/0 there. Returned values are unchanged.
    """
    squared_norm = x * x + y * y
    is_center = squared_norm == 0
    # Feed sqrt a dummy 1.0 at the centre so neither branch produces NaN gradients.
    safe_squared_norm = jnp.where(is_center, 1.0, squared_norm)
    norm = jnp.where(is_center, 0.0, jnp.sqrt(safe_squared_norm))
    return norm - jnp.abs(radius)

def distance_to_rectangle(x, y, rect_x, rect_y, rect_w, rect_h):
    """Smoothed signed distance from ``(x, y)`` to a ``rect_w`` x ``rect_h`` box centred at ``(rect_x, rect_y)``."""
    dx = x - rect_x
    dy = y - rect_y
    return box(dx, dy, rect_w, rect_h)

def distance_to_circle(x, y, circle_x, circle_y, circle_r):
    """Signed distance from ``(x, y)`` to the circle centred at ``(circle_x, circle_y)`` with radius ``circle_r``."""
    offset_x = x - circle_x
    offset_y = y - circle_y
    return circle(offset_x, offset_y, circle_r)


def signed_distance_field(params, image_width, image_height):
    """Evaluate the circle SDF at every integer pixel of an image.

    Args:
        params: ``(circle_x, circle_y, circle_r)``. ``x`` indexes columns and
            ``y`` indexes rows. This matches the losses below, which pass
            ``point[1]`` (column) as x and ``point[0]`` (row) as y.
        image_width: number of columns.
        image_height: number of rows.

    Returns:
        Array of shape ``(image_height, image_width)``. Entry ``[row, col]``
        is the signed distance from pixel ``(x=col, y=row)`` to the circle.
    """
    # Default 'xy' indexing: xx[row, col] = col and yy[row, col] = row.
    xx, yy = jnp.meshgrid(jnp.arange(image_width), jnp.arange(image_height))

    # Rectangle variant:
    # rect_x, rect_y, rect_w, rect_h = params
    # sdf = jax.vmap(jax.vmap(lambda x, y: distance_to_rectangle(x, y, rect_x, rect_y, rect_w, rect_h)))(xx, yy)

    circle_x, circle_y, circle_r = params
    # x must come from the column grid (xx) and y from the row grid (yy).
    # Passing (yy, xx) would transpose the field relative to ``params``.
    sdf = jax.vmap(jax.vmap(
        lambda x, y: distance_to_circle(x, y, circle_x, circle_y, circle_r)
    ))(xx, yy)
    return sdf


circle_params  = jnp.array([50.1, 20.1, 20.1])
sdf = signed_distance_field(circle_params, image_width, image_height)

# Symmetric colour limits centre the diverging colormap on the zero level set.
sdf_limit = jnp.max(jnp.abs(sdf))
fig, ax = plt.subplots(figsize=(6, 6))
field = ax.imshow(sdf, cmap='BrBG', origin='lower', vmin=-sdf_limit, vmax=sdf_limit)
fig.colorbar(field, ax=ax, label='Signed Distance')
ax.contour(sdf, levels=[0], colors='black')  # the shape outline, sdf == 0
ax.set_title('Signed Distance Field')
ax.set_xlabel('X')
ax.set_ylabel('Y')
plt.show()

In [None]:
def plot_circle_on_image(image, circles):
    """Show ``image`` in greyscale with each ``(cx, cy, r)`` in ``circles`` outlined on top.

    Outlines cycle through red, green, blue, yellow and cyan. ``origin='lower'``
    draws row 0 at the bottom of the axes.
    """
    palette = ['r', 'g', 'b', 'y', 'c']
    fig, ax = plt.subplots()
    ax.imshow(image, cmap='gray', origin='lower')
    for index, (cx, cy, r) in enumerate(circles):
        outline = plt.Circle((cx, cy), r, color=palette[index % len(palette)], fill=False)
        ax.add_patch(outline)
    plt.show()

In [None]:
import optax
import optax.tree_utils as otu

from functools import partial
# Shared L-BFGS optimizer instance (optax defaults).
_lbfgs = optax.lbfgs()


@partial(jax.jit, static_argnames=["fun"])
def _run_lbfgs(init_params, lbfgs_state, fun, max_steps, tolerance, **kwargs):
    """Minimise ``fun`` with L-BFGS inside a single jitted ``lax.while_loop``.

    Args:
        init_params: starting parameters.
        lbfgs_state: an existing ``_lbfgs`` state to warm-start from, or
            ``None`` to initialise a fresh one. A reused state has its
            ``count`` reset to 0.
        fun: objective called as ``fun(params, **kwargs)``, returning a scalar.
            It is a static jit argument, so each distinct function is
            compiled separately.
        max_steps: stop once the iteration count reaches this value.
        tolerance: stop once the L2 norm of the gradient stored in the state
            falls below this value.
        **kwargs: extra keyword arguments passed to ``fun`` and also to
            ``_lbfgs.update`` (presumably forwarded to ``value_fn`` by the
            line search).

    Returns:
        ``(final_params, final_state)``.
    """
    # Reuses the value/grad cached in the optimizer state when available.
    value_and_grad_fun = optax.value_and_grad_from_state(fun)
    
    def step(carry):
        # One L-BFGS iteration: evaluate (or reuse) value/grad, compute the update, apply it.
        params, state = carry
        value, grad = value_and_grad_fun(params, **kwargs, state=state)
        #jax.debug.print("value: {value}, grad: {grad}, params: {params}", value=value, grad=grad, params=params)

        updates, state = _lbfgs.update(
            grad, state, params, **kwargs, value=value, grad=grad, value_fn=fun
        )
        params = optax.apply_updates(params, updates)

        return params, state

    def continuing_criterion(carry):
        # Always run the first iteration; after that, continue while under the
        # step budget and the stored gradient norm is at least `tolerance`.
        # NOTE(review): a NaN gradient norm makes `err >= tolerance` False,
        # which ends the loop.
        _, state = carry
        iter_num = otu.tree_get(state, "count")
        grad = otu.tree_get(state, "grad")
        err = otu.tree_l2_norm(grad)
        return (iter_num == 0) | ((iter_num < max_steps) & (err >= tolerance))

    if lbfgs_state is None:
        state = _lbfgs.init(init_params)
    else:
        # reset iteration count leftover from previous solve
        state = otu.tree_set(lbfgs_state, count=0)


    init_carry = (init_params, state)
    final_params, final_state = jax.lax.while_loop(
        continuing_criterion, step, init_carry
    )
    return final_params, final_state


def solve(params, loss_fn, target_image, max_steps=1000, tolerance=1e-3):
    """Fit ``params`` to the nonzero pixels of ``target_image`` with L-BFGS.

    Args:
        params: initial parameter vector.
        loss_fn: called as ``loss_fn(params, nonzero_points)``, where
            ``nonzero_points`` is ``jnp.argwhere(target_image)``, i.e.
            ``(row, col)`` index pairs. Returns a scalar.
        target_image: 2-D image; every nonzero pixel is a data point.
        max_steps: iteration cap for L-BFGS.
        tolerance: gradient-norm threshold below which L-BFGS stops.

    Returns:
        The optimised parameters. The function also prints the wall-clock
        time and the iteration count. The time includes jit compilation on
        the first call for a given ``loss_fn``.
    """
    nonzero_points = jnp.argwhere(target_image)

    start_time = time.perf_counter()
    soln, state = _run_lbfgs(
        params,
        None,  # minimizer state
        loss_fn,
        max_steps=max_steps,
        tolerance=tolerance,
        nonzero_points=nonzero_points,
    )
    # JAX dispatches asynchronously. Wait for the result so the timing covers
    # the optimisation itself, not just the dispatch.
    soln, state = jax.block_until_ready((soln, state))
    elapsed = time.perf_counter() - start_time

    iter_num = otu.tree_get(state, "count")
    print(f"Elapsed time: {elapsed:.5f} seconds; {iter_num} iterations")

    return soln

In [None]:
@jax.jit
def loss_fn_circle(params, nonzero_points):
    """Sum of squared distances from each point to the circle ``params[0:3]``.

    ``nonzero_points`` is expected to hold ``(row, col)`` pairs, as produced by
    ``jnp.argwhere``. The column is therefore used as x and the row as y.
    """
    circle_x, circle_y, radius = params[0:3]

    def squared_distance(point):
        row, col = point[0], point[1]
        return distance_to_circle(col, row, circle_x, circle_y, radius) ** 2.0

    return jnp.sum(jax.vmap(squared_distance)(nonzero_points))

In [None]:
initial_circle = jnp.array([10., 10, 5])  # starting guess: (x, y, radius)
solution = solve(initial_circle, loss_fn_circle, im_blurred)
plot_circle_on_image(im_blurred, solution.reshape(-1, 3))
# Loss at the fitted circle (shown as the cell output).
loss_fn_circle(solution, jnp.argwhere(im_blurred))

In [None]:
@jax.jit
def loss_fn_circle(params, nonzero_points):
    """Sum of squared distances from each point to the circle ``params[0:3]``.

    Redefinition of ``loss_fn_circle`` with its own local distance helper,
    which still calls the module-level ``circle``. ``nonzero_points`` is
    expected to hold ``(row, col)`` pairs.
    """
    def local_distance(px, py, cx, cy, r):
        return circle(px - cx, py - cy, r)

    residuals = jax.vmap(lambda pt: local_distance(pt[1], pt[0], *params[0:3]))(nonzero_points)
    return (residuals ** 2.0).sum()

In [None]:
# Re-fit with the redefined loss_fn_circle from the same starting guess.
initial_circle = jnp.array([10., 10, 5])
solution = solve(initial_circle, loss_fn_circle, im_blurred)
plot_circle_on_image(im_blurred, solution.reshape(-1, 3))

# Two Circles

In [None]:
# Ground truth: two circles packed as (x0, y0, r0, x1, y1, r1).
true_solution = jnp.array([20., 50, 10, 100, 51, 10])
params = jnp.array([20.1, 50.1, 9, 100.1, 50.1, 11])  # initial guess near the truth


def draw_two_circles(draw):
    """Draw both ground-truth circle outlines in white onto ``draw``."""
    for first in (0, 3):
        draw.circle(true_solution[first:first + 2], true_solution[first + 2], outline="white")


im_blurred = create_image("target_two_circles.png", draw_two_circles)

In [None]:
## Cross-check value: see the "loss_matches_jax" test in
## autodiff-tests/fidget-test/src/main.rs (separate `constraint` workspace).
a = jnp.array([30., 60, 5])
two_circle_points = jnp.argwhere(im_blurred)
str(loss_fn_circle(a, two_circle_points))

In [None]:
@jax.jit
def loss_fn_two_circles_minimum(params, nonzero_points):
    """Sum over points of the squared distance to whichever circle is nearer.

    ``params`` packs two circles as ``(x0, y0, r0, x1, y1, r1)``. Each
    ``(row, col)`` point contributes ``min(d0**2, d1**2)``.
    """
    first, second = params[0:3], params[3:6]

    def nearest_squared_distance(point):
        col, row = point[1], point[0]
        d_first = distance_to_circle(col, row, *first) ** 2
        d_second = distance_to_circle(col, row, *second) ** 2
        return jnp.array([d_first, d_second]).min()

    return jax.vmap(nearest_squared_distance)(nonzero_points).sum()

solution = solve(params, loss_fn_two_circles_minimum, im_blurred)
fitted_circles = solution.reshape(-1, 3)  # one (x, y, r) row per circle
plot_circle_on_image(im_blurred, fitted_circles)
solution

In [None]:
@jax.jit
def loss_fn_one_circle_log(params, nonzero_points):
    """Sum over points of ``log(d**2)``, where ``d`` is the distance to circle ``params[0:3]``.

    NOTE(review): ``log(0) = -inf`` for any point lying exactly on the circle,
    so this loss is unbounded below. Entries of ``params`` beyond index 2 are
    ignored.
    """
    def log_squared_distance(point):
        d = distance_to_circle(point[1], point[0], *params[0:3])
        return jnp.log(d**2)

    per_point = jax.vmap(log_squared_distance)(nonzero_points)
    return per_point.sum()

# Only params[0:3] enter this loss; params[3:6] receive zero gradient.
solution = solve(params, loss_fn_one_circle_log, im_blurred)
plot_circle_on_image(im_blurred, solution.reshape(-1, 3))

In [None]:
def soft_min(sdfs, beta=0.001):
    """Smooth minimum ``-log(sum(exp(-beta * sdfs))) / beta``.

    Approaches ``min(sdfs)`` as ``beta`` grows; ``beta`` is assumed positive.
    It is evaluated as a log-sum-exp shifted by the true minimum. This is
    algebraically identical to the direct form but cannot underflow.

    The direct form returns ``inf`` once every ``beta * sdfs`` exceeds
    roughly 745 (float64) or 103 (float32), which happens for squared pixel
    distances with ``beta=0.01``.
    """
    # Shift by the minimum so the largest exponent is exactly 0 and the sum is >= 1.
    # The result does not depend on the shift, so it is excluded from differentiation.
    shift = jax.lax.stop_gradient(jnp.min(sdfs))
    scaled = jnp.exp(-beta * (sdfs - shift))
    return shift - jnp.log(jnp.sum(scaled)) / beta

@jax.jit
def loss_fn_two_circles_soft_minimum(params, nonzero_points):
    """Sum over points of the ``soft_min`` of the squared distances to two circles.

    ``params`` packs two circles as ``(x0, y0, r0, x1, y1, r1)``.
    """
    circles = (params[0:3], params[3:6])

    def blended_squared_distance(point):
        squared = jnp.array([
            distance_to_circle(point[1], point[0], *c) ** 2 for c in circles
        ])
        return soft_min(squared)

    return jax.vmap(blended_squared_distance)(nonzero_points).sum()

solution = solve(params, loss_fn_two_circles_soft_minimum, im_blurred)
for shown in (params, solution):  # initial guess first, then the fitted circles
    plot_circle_on_image(im_blurred, shown.reshape(-1, 3))
solution

In [None]:
# Sanity check: compare the loss at the ground-truth circles with the loss at
# the fitted solution; the ground truth is expected to score lower.
# Use the circles actually drawn into ``im_blurred`` (``true_solution`` from
# the "Two Circles" cell). A hard-coded vector would drift out of sync with
# the target image.
known_solution = true_solution
nonzero_points = jnp.argwhere(im_blurred)
(loss_fn_two_circles_soft_minimum(known_solution, nonzero_points),
 loss_fn_two_circles_soft_minimum(solution, nonzero_points))

In [None]:
def weighted_sum(sdfs, epsilon=0.01):
    """Inverse-square-weighted average of ``sdfs``.

    Each value gets weight ``1 / (sdf**2 + epsilon)``, so the entries closest
    to zero dominate. ``epsilon`` keeps the weight finite at ``sdf == 0``.
    Despite the name, the result is divided by the total weight, so it is a
    weighted mean rather than a sum.
    """
    inverse_sq = 1 / (sdfs**2 + epsilon)
    numerator = jnp.sum(inverse_sq * sdfs)
    return numerator / jnp.sum(inverse_sq)

@jax.jit
def loss_fn_two_circles_weighted_sum(params, nonzero_points):
    """Sum over points of the squared ``weighted_sum`` of the two circle distances.

    ``params`` packs two circles as ``(x0, y0, r0, x1, y1, r1)``.
    """
    def blended_distance_sq(point):
        x, y = point[1], point[0]
        distances = jnp.array([
            distance_to_circle(x, y, *params[0:3]),
            distance_to_circle(x, y, *params[3:6]),
        ])
        return weighted_sum(distances) ** 2

    per_point = jax.vmap(blended_distance_sq)(nonzero_points)
    return per_point.sum()

solution = solve(params, loss_fn_two_circles_weighted_sum, im_blurred)
fitted_circles = solution.reshape(-1, 3)  # one (x, y, r) row per circle
plot_circle_on_image(im_blurred, fitted_circles)
solution

## New loss

Let's try a new approach where we evaluate the SDFs across the entire target image, not just on the nonzero pixels.
We also have to sanitize the gradient so that it doesn't turn into NaN at degenerate points (e.g., a pixel exactly at a circle's center, where the gradient of the Euclidean norm is undefined).

In [None]:
import optax
import optax.tree_utils as otu

from functools import partial
# Shared L-BFGS optimizer instance (optax defaults).
_lbfgs = optax.lbfgs()


@partial(jax.jit, static_argnames=["fun", "verbose"])
def _run_lbfgs(init_params, lbfgs_state, fun, max_steps, tolerance, verbose=True, **kwargs):
    """Minimise ``fun`` with L-BFGS inside a jitted ``lax.while_loop``, tolerating NaN gradients.

    NaNs in the gradient are replaced with 0, for example those from the
    norm's gradient at a circle centre. This applies both to the gradient
    fed to the optimizer and to the gradient norm in the stopping test, so a
    NaN cannot silently end the loop.

    Args:
        init_params: starting parameters.
        lbfgs_state: an existing ``_lbfgs`` state to warm-start from, or
            ``None`` for a fresh one. A reused state has its ``count`` reset to 0.
        fun: objective called as ``fun(params, **kwargs)``, returning a scalar.
            It is static under jit.
        max_steps: stop once the iteration count reaches this value.
        tolerance: stop once the (sanitised) gradient norm falls below this value.
        verbose: if True (the default), print loss, gradient and parameters
            every iteration via ``jax.debug.print``. Static under jit.
        **kwargs: extra keyword arguments passed to ``fun`` and to ``_lbfgs.update``.

    Returns:
        ``(final_params, final_state)``.
    """
    value_and_grad_fun = optax.value_and_grad_from_state(fun)

    def _sanitize(tree):
        # Replace NaN entries with 0 in every leaf of a gradient pytree.
        return jax.tree_util.tree_map(lambda leaf: jnp.nan_to_num(leaf, nan=0.0), tree)

    def step(carry):
        params, state = carry
        value, grad = value_and_grad_fun(params, **kwargs, state=state)

        if verbose:
            jax.debug.print("loss: {value}, grad: {grad}, params: {params}", value=jnp.array([value]), grad=grad, params=params)
        grad = _sanitize(grad)

        updates, state = _lbfgs.update(
            grad, state, params, **kwargs, value=value, grad=grad, value_fn=fun
        )
        params = optax.apply_updates(params, updates)

        return params, state

    def continuing_criterion(carry):
        _, state = carry
        iter_num = otu.tree_get(state, "count")
        # Sanitise the same way as `step`: a NaN norm would make
        # `err >= tolerance` False and end the loop prematurely.
        grad = _sanitize(otu.tree_get(state, "grad"))
        err = otu.tree_l2_norm(grad)
        return (iter_num == 0) | ((iter_num < max_steps) & (err >= tolerance))

    if lbfgs_state is None:
        state = _lbfgs.init(init_params)
    else:
        # reset iteration count leftover from previous solve
        state = otu.tree_set(lbfgs_state, count=0)

    init_carry = (init_params, state)
    final_params, final_state = jax.lax.while_loop(
        continuing_criterion, step, init_carry
    )
    return final_params, final_state


def enumerate_pixels(image):
    """List every pixel of a 2-D ``image`` as an ``(x, y, value)`` row.

    ``x`` is the column index and ``y`` the row index. Rows are emitted in
    row-major order, the same order as ``image.flatten()``, so per-pixel
    results can be reshaped back to ``image.shape``.
    """
    n_rows, n_cols = image.shape
    rows = jnp.repeat(jnp.arange(n_rows), n_cols)
    cols = jnp.tile(jnp.arange(n_cols), n_rows)
    return jnp.stack([cols, rows, image.ravel()], axis=1)

def solve(params, loss_fn, target_image, max_steps=1000, tolerance=1e-3):
    """Fit ``params`` by minimising ``loss_fn`` over every pixel of ``target_image``.

    Args:
        params: initial parameter vector.
        loss_fn: called as ``loss_fn(params, enumerated_pixels)``, where
            ``enumerated_pixels`` is ``enumerate_pixels(target_image)``
            (one ``(x, y, value)`` row per pixel). Returns a scalar.
        target_image: 2-D image to fit.
        max_steps: iteration cap for L-BFGS.
        tolerance: gradient-norm threshold below which L-BFGS stops.

    Returns:
        The optimised parameters. The function also prints the wall-clock
        time and the iteration count. The time includes jit compilation on
        the first call for a given ``loss_fn``.
    """
    enumerated_pixels = enumerate_pixels(target_image)

    start_time = time.perf_counter()
    soln, state = _run_lbfgs(
        params,
        None,  # minimizer state
        loss_fn,
        max_steps=max_steps,
        tolerance=tolerance,
        enumerated_pixels=enumerated_pixels,
    )
    # JAX dispatches asynchronously. Block on the result so the timing covers
    # the solve itself rather than only the dispatch.
    soln, state = jax.block_until_ready((soln, state))
    elapsed = time.perf_counter() - start_time

    iter_num = otu.tree_get(state, "count")
    print(f"Elapsed time: {elapsed:.5f} seconds; {iter_num} iterations")

    return soln

In [None]:
def prior_penalty(value, min_value, max_value, penalty_coefficient=1000):
    """Quadratic penalty for leaving the interval ``[min_value, max_value]``.

    Returns 0 inside the interval (bounds included). Outside, it returns
    ``penalty_coefficient`` times the squared distance to the nearest bound.
    """
    below = jnp.maximum(min_value - value, 0.0)
    above = jnp.maximum(value - max_value, 0.0)
    overshoot = below + above  # at most one of the two is nonzero
    return penalty_coefficient * overshoot**2

@jax.jit
def pixel_loss(pixel, params):
    """Loss contribution of one ``(x, y, value)`` pixel row.

    A lit pixel (``value > 0``) contributes its squared distance to the
    nearer of two circles centred at ``(params[0], params[1])`` and
    ``(params[3], params[4])``. Unlit pixels contribute 0. The radius is
    currently fixed at 5, so ``params[2]`` and ``params[5]`` do not enter
    this term.

    Other formulations considered:
      * using the fitted radii ``params[2]`` / ``params[5]``;
      * ``soft_min(distances, beta=0.01)`` instead of ``min``;
      * ``(v > 0) * exp(-distance)``;
      * ``(exp(-distance) - v)**2``, i.e. matching an expected intensity
        (about 1 near the shape, 0 far away) to the pixel value.
    """
    x, y, v = pixel
    fixed_radius = 5
    squared_distances = jnp.array([
        distance_to_circle(x, y, params[0], params[1], fixed_radius)**2,
        distance_to_circle(x, y, params[3], params[4], fixed_radius)**2,
    ])
    nearest = jnp.min(squared_distances)

    # Shape distance from nonzero points is the most reliable way to get x/y to move.
    is_lit = v > 0
    return is_lit * nearest



@jax.jit
def loss(params, enumerated_pixels):
    """Total objective: summed ``pixel_loss`` plus interval priors on the parameters.

    ``params`` packs two circles as ``(x0, y0, r0, x1, y1, r1)``. Radii are
    penalised outside [4, 10]. Centres are penalised outside
    [0, image_width] for x and [0, image_height] for y.
    """
    data_term = jax.vmap(lambda pixel: pixel_loss(pixel, params))(enumerated_pixels).sum()

    # (parameter index, lower bound, upper bound), applied in this order.
    priors = (
        (2, 4, 10),            # radius of circle 0
        (5, 4, 10),            # radius of circle 1
        (0, 0, image_width),   # stay on page: x of circle 0
        (1, 0, image_height),  # y of circle 0
        (3, 0, image_width),   # x of circle 1
        (4, 0, image_height),  # y of circle 1
    )
    total = data_term
    for index, lower, upper in priors:
        total = total + prior_penalty(params[index], lower, upper)
    return total


# Create a heatmap of the pixel loss
def create_pixel_loss_heatmap(params, image):
    """Plot the per-pixel ``pixel_loss`` of ``params`` as a heatmap over ``image``'s grid."""
    pixels = enumerate_pixels(image)
    per_pixel = jax.vmap(lambda pixel: pixel_loss(pixel, params))(pixels)
    # enumerate_pixels walks the image in row-major order, so reshape restores the grid.
    heatmap = per_pixel.reshape(image.shape)

    fig, ax = plt.subplots(figsize=(10, 8))
    shown = ax.imshow(heatmap, cmap='hot', interpolation='nearest', origin='lower')
    fig.colorbar(shown, ax=ax, label='Pixel Loss')
    ax.set_title('Heatmap of Pixel Loss')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    plt.show()



params = jnp.array([50.1, 20.1, 5, 105.1, 80.1, 5])  # initial guess: (x0, y0, r0, x1, y1, r1)

# Fit, then plot the initial guess followed by the fitted circles.
solution = solve(params, loss, im_blurred)
#create_pixel_loss_heatmap(solution, im_blurred)
for shown in (params, solution):
    plot_circle_on_image(im_blurred, shown.reshape(-1, 3))

# Display: fitted params, loss at the fit, loss at the ground truth.
pixels = enumerate_pixels(im_blurred)
(solution,
 loss(solution, pixels),
 loss(true_solution, pixels))