In [None]:
import jax.numpy as jnp
import jax
# jax.config.update("jax_debug_nans", True)
# jax.config.update('jax_disable_jit', True)

import matplotlib.pyplot as plt
import time
from PIL import Image, ImageDraw


In [None]:
# Synthetic target: an 8-bit grayscale ('L') canvas with two circle outlines drawn on it.
image_width, image_height = 200, 100
image = Image.new('L', (image_width, image_height))
draw = ImageDraw.Draw(image)

# Outline-only circles (no fill): centre [x, y] and radius, in pixels.
draw.circle([20, 50], 10, outline="white")
draw.circle([100, 20], 5, outline="white")

# PIL -> JAX array of shape (image_height, image_width): rows are y, columns are x.
im = jnp.array(image)

def gaussian_blur(image, kernel_size=3, sigma=1.0):
    """Return ``image`` convolved with a normalised square Gaussian kernel.

    Args:
        image: 2-D array to blur.
        kernel_size: side length of the square kernel, in pixels.
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        Array with the same shape as ``image`` (``mode='same'``).
    """
    # Tap offsets along one axis; for odd sizes they run -(k//2) .. k//2, centred on 0.
    offsets = jnp.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    # Squared radius of every tap: a row of x^2 broadcast against a column of y^2.
    radius_sq = jnp.square(offsets)[None, :] + jnp.square(offsets)[:, None]
    weights = jnp.exp(-0.5 * radius_sq / jnp.square(sigma))
    weights = weights / jnp.sum(weights)  # normalise so the blur preserves total intensity
    return jax.scipy.signal.convolve2d(image, weights, mode='same')

# Blur the 1-px outlines so they become thicker; the fits below use every
# nonzero pixel of im_blurred as a target point.
im_blurred = gaussian_blur(im)

# origin='lower' draws row 0 at the bottom, so y increases upward.
plt.imshow(im_blurred, origin='lower')


In [None]:
def smooth_abs(x, sharpness=10.0):
    """Differentiable approximation of ``|x|``: ``x * tanh(sharpness * x)``.

    Args:
        x: scalar or array.
        sharpness: larger values track ``|x|`` more closely near 0 (default 10.0,
            the value previously hard-coded).

    Returns:
        Array like ``x``; equal to 0 at ``x == 0`` and ~``|x|`` once
        ``|sharpness * x|`` is a few units or more.
    """
    return x * jnp.tanh(sharpness * x)

def softplus(x, beta=50):
    """Sharpened softplus, a smooth approximation of ``max(x, 0)``.

    Computes ``log(1 + exp(beta * x)) / beta``; larger ``beta`` gives a sharper
    corner at 0. The default 50 is the value previously hard-coded.
    """
    return jax.nn.softplus(beta * x) / beta

def softminus(x, beta=50):
    """Smooth approximation of ``min(x, 0)``: ``x - log(1 + exp(beta * x)) / beta``.

    For the default ``beta`` (50, the value previously implied) this equals
    ``x - softplus(x)``. It is computed directly so the sharpness can be chosen here.
    """
    return x - jax.nn.softplus(beta * x) / beta

def box_unsigned(x, y, width, height):
    """Smooth, roughly unsigned distance from ``(x, y)`` to an origin-centred box.

    Both per-axis excesses ``|p| - half_extent`` pass through ``softplus`` first, so
    they are never negative. The result therefore stays near zero instead of going
    negative for points inside the box.

    Args:
        x, y: query point relative to the box centre.
        width, height: full side lengths of the box.
    """
    q_x = softplus(smooth_abs(x) - width / 2)
    q_y = softplus(smooth_abs(y) - height / 2)
    sq_len = q_x * q_x + q_y * q_y
    # Deep inside the box softplus underflows to exactly 0, and d/dz sqrt(z) is
    # infinite at 0, which would give a NaN gradient. A "double where" keeps the
    # forward value identical while feeding sqrt a safe argument in that case.
    nonzero = sq_len > 0.0
    length = jnp.where(nonzero, jnp.sqrt(jnp.where(nonzero, sq_len, 1.0)), 0.0)
    return length + softminus(jnp.maximum(q_x, q_y))


def box(x, y, width, height):
    """Smoothed signed distance from ``(x, y)`` to an axis-aligned, origin-centred box.

    Based on the 2-D box SDF from https://iquilezles.org/articles/distfunctions/.
    ``abs`` and the final ``min(., 0)`` are replaced by ``smooth_abs`` and
    ``softminus`` so the result is differentiable.

    Args:
        x, y: query point relative to the box centre.
        width, height: full side lengths of the box.

    Returns:
        Scalar distance: positive outside, (approximately) negative inside.
    """
    p = jnp.array([x, y])
    b = jnp.array([width / 2, height / 2])
    q = smooth_abs(p) - b
    outside = jnp.maximum(q, 0.0)
    # ||outside|| via a "double where": jnp.linalg.norm has a NaN gradient at the zero
    # vector, which is exactly what every point inside the box produces. The forward
    # value is unchanged.
    sq_len = jnp.sum(outside * outside)
    is_outside = sq_len > 0.0
    outside_len = jnp.where(is_outside, jnp.sqrt(jnp.where(is_outside, sq_len, 1.0)), 0.0)
    return outside_len + softminus(jnp.max(q))

def circle(x, y, radius):
    """Signed distance from ``(x, y)`` to a circle of ``radius`` centred at the origin.

    Negative inside, zero on the outline, positive outside. Safe to differentiate at
    the centre: ``jnp.linalg.norm`` would return a NaN gradient at ``(0, 0)``.
    """
    p = jnp.array([x, y])
    sq_len = jnp.sum(p * p)
    at_centre = sq_len == 0.0
    # Double where: the inner where keeps sqrt's argument away from 0 so its
    # derivative stays finite; the outer where restores the exact value 0.
    length = jnp.where(at_centre, 0.0, jnp.sqrt(jnp.where(at_centre, 1.0, sq_len)))
    return length - radius


def signed_distance_field(params, image_width, image_height):
    """Evaluate a shape's signed distance at every pixel of an image-sized grid.

    Args:
        params: shape parameters. Three values ``(cx, cy, r)`` describe a circle.
            Four values ``(cx, cy, w, h)`` describe an axis-aligned rectangle centred
            at ``(cx, cy)`` with full side lengths ``w`` and ``h``.
        image_width: number of columns (x extent) of the grid.
        image_height: number of rows (y extent) of the grid.

    Returns:
        Array of shape ``(image_height, image_width)``. Entry ``[row, col]`` is the
        signed distance from point ``(x=col, y=row)`` to the shape (negative inside).
        This is the same (x=column, y=row) convention the loss functions use.

    Raises:
        ValueError: if ``params`` does not have 3 or 4 entries.
    """
    x_coords = jnp.arange(image_width)
    y_coords = jnp.arange(image_height)
    # Default 'xy' indexing: xx[row, col] == col, yy[row, col] == row; both (H, W).
    xx, yy = jnp.meshgrid(x_coords, y_coords)

    if len(params) == 3:
        circle_x, circle_y, circle_r = params

        def distance(x, y):
            return circle(x - circle_x, y - circle_y, circle_r)
    elif len(params) == 4:
        rect_x, rect_y, rect_w, rect_h = params

        def distance(x, y):
            return box(x - rect_x, y - rect_y, rect_w, rect_h)
    else:
        raise ValueError(
            f"expected 3 (circle) or 4 (rectangle) params, got {len(params)}"
        )

    # x (columns) first, then y (rows). Passing (yy, xx) here would transpose the shape.
    return jax.vmap(jax.vmap(distance))(xx, yy)


# Evaluate the SDF for a sample circle (circle_x, circle_y, circle_r) and show it.
params  = jnp.array([50.1, 20.1, 20.1])
sdf = signed_distance_field(params, image_width, image_height)
plt.figure(figsize=(6, 6))
# Diverging colormap with limits symmetric about 0, so the sign change sits at the
# colormap midpoint.
plt.imshow(sdf, cmap='BrBG', origin='lower', vmin=-jnp.max(jnp.abs(sdf)), vmax=jnp.max(jnp.abs(sdf)))
plt.colorbar(label='Signed Distance')
# Zero level set = the shape's outline.
plt.contour(sdf, levels=[0], colors='black')
plt.title('Signed Distance Field')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

In [None]:
import optax
import optax.tree_utils as otu

from functools import partial
# Module-level L-BFGS transformation (default settings); _run_lbfgs reads it as a global.
_lbfgs = optax.lbfgs()


@partial(jax.jit, static_argnames=["fun"])
def _run_lbfgs(init_params, lbfgs_state, fun, max_steps, tolerance, **kwargs):
    """Minimise ``fun`` with the module-level optax L-BFGS inside one jitted while_loop.

    Args:
        init_params: starting parameters.
        lbfgs_state: optimiser state to warm-start from, or ``None`` to start fresh.
            A supplied state has its ``count`` reset to 0.
        fun: objective ``fun(params, **kwargs) -> scalar``. It is a static jit
            argument, so it must be hashable.
        max_steps: iteration cap.
        tolerance: stop once the L2 norm of the gradient held in the state drops below it.
        **kwargs: forwarded to ``fun`` and to ``_lbfgs.update``.

    Returns:
        ``(final_params, final_state)``.
    """
    value_and_grad_fun = optax.value_and_grad_from_state(fun)

    def _take_step(carry):
        params, opt_state = carry
        value, grad = value_and_grad_fun(params, **kwargs, state=opt_state)
        updates, opt_state = _lbfgs.update(
            grad, opt_state, params, **kwargs, value=value, grad=grad, value_fn=fun
        )
        return optax.apply_updates(params, updates), opt_state

    def _should_continue(carry):
        opt_state = carry[1]
        step_count = otu.tree_get(opt_state, "count")
        grad_norm = otu.tree_l2_norm(otu.tree_get(opt_state, "grad"))
        # Always take the first step; presumably the stored grad is not meaningful
        # before any update has run.
        first_step = step_count == 0
        unconverged = (step_count < max_steps) & (grad_norm >= tolerance)
        return first_step | unconverged

    if lbfgs_state is not None:
        # Reset the iteration count left over from a previous solve so the
        # max_steps budget applies to this call.
        start_state = otu.tree_set(lbfgs_state, count=0)
    else:
        start_state = _lbfgs.init(init_params)

    return jax.lax.while_loop(
        _should_continue, _take_step, (init_params, start_state)
    )


In [None]:
def solve(params, loss_fn, target_image, max_steps=1000, tolerance=1e-3):
    """Fit ``params`` to the nonzero pixels of ``target_image`` with L-BFGS.

    Prints the wall-clock time and iteration count of the solve. The first call for a
    given ``loss_fn`` also includes JIT compilation of ``_run_lbfgs``.

    Args:
        params: initial parameter vector for ``loss_fn``.
        loss_fn: objective ``loss_fn(params, nonzero_points) -> scalar``. It must be
            hashable because it is a static argument of the jitted solver.
        target_image: 2-D image; every nonzero pixel becomes a ``(row, col)`` target point.
        max_steps: iteration cap (default 1000, the value previously hard-coded).
        tolerance: gradient-norm stopping tolerance (default 1e-3, previously hard-coded).

    Returns:
        The optimised parameters.
    """
    nonzero_points = jnp.argwhere(target_image)

    start_time = time.perf_counter()
    soln, state = _run_lbfgs(
        params,
        None,  # no warm-start optimiser state
        loss_fn,
        max_steps=max_steps,
        tolerance=tolerance,
        nonzero_points=nonzero_points,
    )
    # JAX dispatches asynchronously. Without blocking here, the timer would stop
    # before the device work finishes and the printed time would be meaningless.
    soln, state = jax.block_until_ready((soln, state))
    elapsed = time.perf_counter() - start_time

    iter_num = otu.tree_get(state, "count")
    print(f"Elapsed time: {elapsed:.5f} seconds; {iter_num} iterations")

    return soln

In [None]:
@jax.jit
def loss_fn_circle(params, nonzero_points):
    """Sum of squared signed distances from target pixels to one circle.

    Args:
        params: array whose first three entries are ``(cx, cy, r)``.
        nonzero_points: ``(row, col)`` coordinates, one point per row (as produced by
            ``jnp.argwhere`` in ``solve``). The column is treated as x, the row as y.

    Returns:
        Scalar ``sum_i d_i ** 2``, where ``d_i`` is the signed distance of point i
        to the circle.
    """
    cx, cy, r = params[0], params[1], params[2]

    def point_distance(point):
        # (row, col) -> circle SDF expects (x, y) = (col, row)
        return circle(point[1] - cx, point[0] - cy, r)

    distances = jax.vmap(point_distance)(nonzero_points)
    return jnp.sum(distances ** 2.0)

In [None]:
# Fit a single circle to the blurred target; initial guess (cx, cy, r) = (10, 10, 5).
solution = solve(jnp.array([10., 10, 5]), loss_fn_circle, im_blurred)

def plot_circle_on_image(image, circles):
    """Display ``image`` in grayscale (row 0 at the bottom) with red circle outlines.

    Args:
        image: 2-D array to show.
        circles: iterable of ``(cx, cy, r)`` triples in pixel units, where x is the
            column and y is the row.
    """
    _, ax = plt.subplots()
    ax.imshow(image, cmap='gray', origin='lower')
    # Named `outlines`, not `circle`, so the module-level circle() SDF is not shadowed.
    outlines = [plt.Circle((cx, cy), r, color='r', fill=False) for cx, cy, r in circles]
    for outline in outlines:
        ax.add_patch(outline)
    plt.show()

# Overlay the fit; each row of the reshaped solution is one (cx, cy, r) triple.
plot_circle_on_image(im_blurred, solution.reshape(-1, 3))

In [None]:
@jax.jit
def loss_fn_two_circles(params, nonzero_points):
    """Sum over target pixels of the squared distance to the nearer of two circle outlines.

    Args:
        params: ``[cx0, cy0, r0, cx1, cy1, r1]``; only the first six entries are used.
        nonzero_points: ``(row, col)`` coordinates, one point per row (as produced by
            ``jnp.argwhere`` in ``solve``). The column is treated as x, the row as y.

    Returns:
        Scalar loss.
    """
    def distance_to_circle(x, y, circle_x, circle_y, circle_r):
        return circle(x - circle_x, y - circle_y, circle_r)

    def point_loss(idx):
        x, y = idx[1], idx[0]
        d0 = distance_to_circle(x, y, *params[0:3])
        d1 = distance_to_circle(x, y, *params[3:6])
        # Square *before* taking the minimum. circle() is signed (negative inside), so
        # min(d0, d1) ** 2 would pick a large negative distance for a pixel inside one
        # circle, even when that pixel lies exactly on the other outline.
        return jnp.minimum(d0 ** 2, d1 ** 2)

    return jax.vmap(point_loss)(nonzero_points).sum()

In [None]:
# Fit two circles at once; params are [cx0, cy0, r0, cx1, cy1, r1] (initial guesses).
params = jnp.array([10., 10, 5, 150, 20, 5])
solution = solve(params, loss_fn_two_circles, im_blurred)
plot_circle_on_image(im_blurred, solution.reshape(-1, 3))