In [None]:
import jax.numpy as jnp
import jax
#jax.config.update("jax_debug_nans", True)
#jax.config.update('jax_disable_jit', True)
#jax.config.update("jax_enable_x64", True)
jnp.set_printoptions(formatter={'float': '{:.2e}'.format})

import matplotlib.pyplot as plt
import matplotlib as mpl

# Start from matplotlib's built-in dark theme...
plt.style.use('dark_background')

# ...then override a few colours so panels, text and ticks stay legible
# against the dark figure background.
mpl.rcParams.update({
    'axes.facecolor': '#2F2F2F',
    'figure.facecolor': '#1E1E1E',
    'text.color': 'white',
    'axes.labelcolor': 'white',
    'xtick.color': 'white',
    'ytick.color': 'white',
})

import time
from PIL import Image, ImageDraw

def enumerate_pixels(image):
    """Flatten a 2-D image into one row per pixel.

    Rows follow the row-major order of ``image``; each row is
    ``(x, y, value)`` where ``x`` is the column index and ``y`` the row index.
    The dtype is whatever ``jnp.column_stack`` promotes the integer indices and
    the pixel values to.
    """
    n_rows, n_cols = image.shape[0], image.shape[1]
    row_idx, col_idx = jnp.meshgrid(jnp.arange(n_rows), jnp.arange(n_cols), indexing='ij')
    return jnp.column_stack((col_idx.flatten(), row_idx.flatten(), image.flatten()))

In [None]:
%matplotlib inline
# Target scene: two filled circles of radius `target_radius` centred at
# (50, 30) and (150, 30) on a 200x100 canvas.
image_width, image_height = 200, 100

image = Image.new('L', (image_width, image_height))
draw = ImageDraw.Draw(image)

style = {"fill": "white"}
#style = {"outline": "white"}
target_radius = 20

# NOTE(review): ImageDraw.circle requires a recent Pillow (10.4+, I believe) — confirm.
draw.circle([50, 30], target_radius, **style)
draw.circle([150, 30], target_radius, **style)

# Convert to a float array scaled so the brightest pixel is 1.0
# (an all-black image would produce NaN here).
image = jnp.array(image)
image = image / image.max()

# origin='lower' draws row 0 (y = 0) at the bottom, so the displayed axes match
# the (x, y) pixel coordinates used by the loss functions.
plt.imshow(image, cmap='gray', origin='lower')
plt.show()

# x/y/value of all pixels
# NOTE: the jitted loss functions below read these globals, and jit bakes their
# values in when a function is first traced.
enumerated_pixels = enumerate_pixels(image)

# x/y coordinates of nonzero pixels
occupied_pixels = enumerated_pixels[enumerated_pixels[:, 2] > 0][:, 0:2]

In [None]:
def plot_loss_vs_radius(loss, radii_range=(0, 100), num_points=200, title='Loss of a circle centered between the target circles',
                        center=(100., 30.), desired_params=(50., 30., 20.)):
    """Plot ``loss`` for circles of varying radius around a fixed centre.

    Args:
        loss: callable mapping a ``[cx, cy, r]`` parameter array to a scalar.
        radii_range: ``(min, max)`` of the radius sweep (inclusive).
        num_points: number of radii sampled in ``radii_range``.
        title: plot title.
        center: ``(cx, cy)`` shared by every swept circle. The default is the
            midpoint between the two target circles of the demo image.
        desired_params: ``(cx, cy, r)`` of a known-good fit. Its loss is drawn
            as a dashed reference line labelled "Desired solution".
    """
    radii = jnp.linspace(radii_range[0], radii_range[1], num_points)
    center = jnp.array([float(c) for c in center])
    # One call per radius, so any jittable loss works without needing vmap support.
    losses = jnp.array([loss(jnp.concatenate([center, jnp.array([r])])) for r in radii])
    desired_solution_loss = loss(jnp.array([float(v) for v in desired_params]))

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(radii, losses)
    ax.set_title(title)
    ax.set_xlabel('Radius')
    ax.set_ylabel('Loss')
    ax.grid(True)
    ax.axhline(y=desired_solution_loss, color='r', linestyle='--', label='Desired solution')
    ax.legend()
    # Scientific notation on the y axis: several losses are tiny or very large.
    ax.yaxis.set_major_formatter(plt.ScalarFormatter(useMathText=True))
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
    plt.show()

In [None]:
def circle(x, y, radius):
    """Signed distance from ``(x, y)`` to a circle of ``radius`` centred at the origin.

    Negative inside, zero on the boundary, positive outside. The sign of
    ``radius`` is ignored.
    """
    distance_from_origin = jnp.linalg.norm(jnp.array([x, y]))
    return distance_from_origin - jnp.abs(radius)

def distance_to_circle(x, y, circle_x, circle_y, circle_r):
    """Signed distance from ``(x, y)`` to the circle centred at ``(circle_x, circle_y)``.

    Negative inside, positive outside; ``circle_r`` is used by absolute value.
    """
    # Move the query point into the circle's local frame (circle at the
    # origin), then reuse the origin-centred SDF.
    local_x = x - circle_x
    local_y = y - circle_y
    return circle(local_x, local_y, circle_r)

# Squared distance loss

In [None]:
@jax.jit
def loss_fn_squared_distance(params):
    """Sum of squared distances from every occupied pixel to the circle boundary.

    ``params`` is ``[cx, cy, r, ...]``; only the first three entries are used.
    Reads the module-level ``occupied_pixels`` (baked in when jit traces).
    """
    cx, cy, r = params[0], params[1], params[2]

    def squared_boundary_distance(point):
        return distance_to_circle(point[0], point[1], cx, cy, r) ** 2.0

    return jax.vmap(squared_boundary_distance)(occupied_pixels).sum()
plot_loss_vs_radius(loss_fn_squared_distance)

# Inside losses

In [None]:
@jax.jit
def loss_fn_unoccupied_inside(params):
    """Total (fractional) emptiness of the pixels inside or on the circle ``params[0:3]``.

    A pixel with signed distance <= 0 contributes ``1 - value``; pixels
    strictly outside contribute 0. Reads the module-level ``enumerated_pixels``.
    """
    cx, cy, r = params[0], params[1], params[2]

    def pixel_cost(pixel):
        px, py, value = pixel
        outside = distance_to_circle(px, py, cx, cy, r) > 0
        return jnp.where(outside, 0, 1 - value)

    return jax.vmap(pixel_cost)(enumerated_pixels).sum()
plot_loss_vs_radius(loss_fn_unoccupied_inside, title="loss = unoccupied pixels inside the circle")

In [None]:
@jax.jit
def loss_fn_unoccupied_ratio(params):
    """Hard-edged loss ``1 - occupied_inside / total_inside`` for the circle ``params[0:3]``.

    A pixel counts as inside when its signed distance is <= 0. If no pixel
    centre is inside (e.g. a tiny radius between pixel centres), the ratio is
    undefined. The loss is then 1.0 (nothing occupied captured) instead of NaN.
    Reads the module-level ``enumerated_pixels``.
    """
    def f(pixel):
        x, y, occupied = pixel
        dist = distance_to_circle(x, y, *params[0:3])
        # [1, occupied] for pixels inside, [0, 0] for pixels outside.
        return jnp.where(dist > 0, 0, jnp.array([1, occupied]))
    inside_pixels = jax.vmap(f)(enumerated_pixels)
    total_inside, occupied_inside = inside_pixels.sum(axis=0)
    # total_inside counts whole pixels, so clamping to 1 only changes the 0/0 case.
    return 1 - (occupied_inside / jnp.maximum(total_inside, 1))
plot_loss_vs_radius(loss_fn_unoccupied_ratio, title="loss = 1 - occupied / total pixels")

In [None]:
@jax.jit
def loss_fn_unoccupied_ratio_sigmoid(params):
    """Smooth version of ``1 - occupied_inside / total_inside`` for the circle ``params[0:3]``.

    Membership is soft: a pixel at signed distance ``d`` counts as
    ``sigmoid(-d)`` inside. That is about 1 well inside, 0.5 on the boundary
    and about 0 well outside, which makes the loss differentiable with respect
    to centre and radius. Reads the module-level ``enumerated_pixels``.
    """
    cx, cy, r = params[0], params[1], params[2]

    def soft_counts(pixel):
        px, py, value = pixel
        signed_dist = distance_to_circle(px, py, cx, cy, r)
        membership = jax.nn.sigmoid(-signed_dist * 1.0)
        return jnp.array([membership, membership * value])

    counts = jax.vmap(soft_counts)(enumerated_pixels)
    soft_total, soft_occupied = counts.sum(axis=0)
    return 1 - soft_occupied / soft_total

plot_loss_vs_radius(loss_fn_unoccupied_ratio_sigmoid, title="loss = 1 - occupied / total pixels (sigmoid)")

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt

@jax.jit
def loss_fn_unoccupied_ratio_sigmoid_fixed_radius(params):
    """Sigmoid ratio loss with the radius pinned to 20.0.

    Only ``params[:2]`` (the centre) is read. Extra entries, such as the pixel
    value column when this is vmapped over ``enumerated_pixels``, are ignored.
    """
    centre = params[:2]
    return loss_fn_unoccupied_ratio_sigmoid(jnp.concatenate([centre, jnp.array([20.0])]))

# Gradient of the fixed-radius loss with respect to the circle centre (x, y).
grad_fn = jax.grad(loss_fn_unoccupied_ratio_sigmoid_fixed_radius)

def pixel_gradient(pixel):
    """Loss gradient for a circle centred just off ``pixel``.

    The centre is nudged by 0.01 in x and y. Centring exactly on a pixel would
    make that pixel's distance ``norm([0, 0])``, whose gradient involves a
    division by zero.
    """
    centre = jnp.array([pixel[0] + 0.01, pixel[1] + 0.01])
    return grad_fn(centre)

gradient_map = jax.vmap(pixel_gradient)(enumerated_pixels)

In [None]:
# Calculate the magnitude of the gradient
# (Euclidean norm of each pixel's (d/dx, d/dy) entry in gradient_map)
gradient_magnitude = jnp.linalg.norm(gradient_map, axis=1)

# Reshape the gradient magnitude to match the image dimensions
# (valid because enumerate_pixels lists pixels in row-major order)
gradient_magnitude_2d = gradient_magnitude.reshape(image.shape)

# Create a figure and axis
fig, ax = plt.subplots(figsize=(10, 6))

# Plot the gradient magnitude as a heatmap
# (origin='lower' matches the other plots: y = 0 at the bottom)
im = ax.imshow(gradient_magnitude_2d, cmap='viridis', origin='lower')

# Add a colorbar
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('Gradient Magnitude')

# Set the title and labels
ax.set_title('Magnitude of x/y gradient of loss function')
ax.set_xlabel('X')
ax.set_ylabel('Y')

# Show the plot
plt.show()

In [None]:
# Calculate the loss only for points 20 px away from the edges.
# Each inner pixel is used as a candidate centre for the fixed-radius loss;
# border pixels are left as NaN in the resulting map.
margin = 20
inner_mask = ((enumerated_pixels[:, 0] >= margin) &
              (enumerated_pixels[:, 0] < image.shape[1] - margin) &
              (enumerated_pixels[:, 1] >= margin) &
              (enumerated_pixels[:, 1] < image.shape[0] - margin))
inner_pixels = enumerated_pixels[inner_mask]

# Rows are (x, y, value); the fixed-radius loss only reads the first two.
loss_map = jax.vmap(loss_fn_unoccupied_ratio_sigmoid_fixed_radius)(inner_pixels)

# Create a full-sized loss map filled with NaN, then fill in the calculated
# values using the same mask (computed once above, so the two cannot diverge).
full_loss_map = jnp.full(enumerated_pixels.shape[0], jnp.nan)
inner_indices = jnp.where(inner_mask)[0]
full_loss_map = full_loss_map.at[inner_indices].set(loss_map)

# Reshape the loss map to match the image dimensions (pixels are row-major)
loss_map_2d = full_loss_map.reshape(image.shape)

fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(loss_map_2d, cmap='viridis', origin='lower')
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('Loss Value')
ax.set_title(f'Heatmap of loss_fn_unoccupied_ratio_sigmoid_fixed_radius ({margin}px margin)')
ax.set_xlabel('X')
ax.set_ylabel('Y')

# Find local minima (only in the inner region). A pixel is a minimum if it
# equals the minimum of its 50x50 neighbourhood, so flat plateaus also count.
from scipy.ndimage import minimum_filter
# Explicit end indices: `margin:-margin` would give an empty slice for margin == 0.
inner_loss_map_2d = loss_map_2d[margin:image.shape[0] - margin, margin:image.shape[1] - margin]
local_min = minimum_filter(inner_loss_map_2d, size=50) == inner_loss_map_2d
minima_coords = jnp.array(jnp.where(local_min))  # row 0: y indices, row 1: x indices
minima_coords += margin  # Adjust coordinates to match full image

# Mark the minima with red dots (scatter takes x, i.e. the column index, first)
ax.scatter(minima_coords[1], minima_coords[0], color='red', s=20, label='Local Minima')
ax.legend()
plt.show()


In [None]:
@jax.jit
def gradient_wrt_radius(params):
    """Partial derivative of the sigmoid ratio loss with respect to the radius ``params[2]``."""
    d_loss_d_params = jax.grad(loss_fn_unoccupied_ratio_sigmoid)(params)
    return d_loss_d_params[2]

@jax.jit
def loss_fn(params):
    """Jitted pass-through to ``loss_fn_unoccupied_ratio_sigmoid``."""
    value = loss_fn_unoccupied_ratio_sigmoid(params)
    return value

# Generate radius values
# Sweep the radius of a circle centred on the left target circle (50, 30).
radii = jnp.linspace(0, 50, 500)

# Calculate gradients and losses
# (the integer centre coordinates are promoted to float alongside r by jnp.array)
gradients = jax.vmap(lambda r: gradient_wrt_radius(jnp.array([50, 30, r])))(radii)
losses = jax.vmap(lambda r: loss_fn(jnp.array([50, 30, r])))(radii)

# Plot the results
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12), sharex=True)

# Plot gradient
ax1.plot(radii, gradients)
ax1.set_title('Gradient with respect to radius (circle center: 50, 30)')
ax1.set_ylabel('Gradient')
ax1.grid(True)

# Plot loss
ax2.plot(radii, losses)
ax2.set_title('Loss with respect to radius (circle center: 50, 30)')
ax2.set_xlabel('Radius')
ax2.set_ylabel('Loss')
ax2.grid(True)

plt.tight_layout()
plt.show()

## Chamfer loss

From Perplexity: https://www.perplexity.ai/search/given-a-signed-distance-field-5XWR4dzITHCpXcVcbm4CdA

The Chamfer distance is an excellent choice for comparing SDFs to pixel images while allowing for partial matching:

- For each point on the SDF surface, find the closest pixel in the image.
- For each occupied pixel in the image, find the closest point on the SDF surface.
- Sum the squared distances for both directions.

This loss encourages the SDF to fit at least some parts of the image well, rather than trying to fit everything poorly. (The implementation below averages absolute, not squared, distances over all pixels.)


In [None]:
@jax.jit
def loss_fn_chamfer(params):
    """Two-sided Chamfer-style loss between the circle ``params[0:3]`` and the image.

    Term 1 (image -> shape): every occupied pixel contributes its unsigned
    distance to the circle boundary.
    Term 2 (shape -> image): every pixel within ``threshold`` of the boundary
    is treated as a point on the shape. It contributes its distance to the
    nearest occupied pixel.
    Each term is averaged over *all* pixels (non-contributing pixels add 0),
    and the two means are summed. Distances are absolute, not squared.
    Reads the module-level ``enumerated_pixels`` and ``occupied_pixels``.
    """
    # if closer than threshold, consider the pixel to be on the surface
    threshold = 5.0

    def pixel_losses(pixel):
        x, y, v = pixel
        loc = pixel[0:2]
        dist = jnp.abs(distance_to_circle(x, y, *params[0:3]))

        sdf_dist_loss = jnp.where(v > 0, dist, 0.0)
        # Distance from this pixel to the nearest occupied pixel. This does not
        # depend on params, and searching occupied_pixels directly gives the same
        # minimum as masking every pixel with inf, with far less work.
        closest_occupied_pixel_distance = jnp.linalg.norm(occupied_pixels - loc, axis=1).min()
        # Gate on the distance to the boundary. v is normalised to [0, 1], so
        # comparing it to threshold would select every pixel.
        occupied_pixel_dist_loss = jnp.where(dist < threshold, closest_occupied_pixel_distance, 0.0)

        return jnp.array([
            sdf_dist_loss,
            occupied_pixel_dist_loss,
        ])

    sdf_dist_loss, occupied_pixel_dist_loss = jax.vmap(pixel_losses)(enumerated_pixels).T
    return jnp.abs(sdf_dist_loss).mean() + jnp.abs(occupied_pixel_dist_loss).mean()

# this takes about 2 minutes to compute on my Mac M1
# plot_loss_vs_radius(loss_fn_chamfer)

# Fitting with LBFGS

In [None]:
import optax
import optax.tree_utils as otu

from functools import partial
# Shared L-BFGS gradient transformation (optax defaults, including its line
# search). A fresh optimizer state is created per solve in _run_lbfgs.
_lbfgs = optax.lbfgs()


@partial(jax.jit, static_argnames=["fun"])
def _run_lbfgs(init_params, lbfgs_state, fun, max_steps, tolerance, **kwargs):
    """Minimise ``fun`` with L-BFGS inside one jitted ``lax.while_loop``.

    Args:
        init_params: starting parameters (here a 1-D array).
        lbfgs_state: optimizer state to warm-start from, or None for a fresh
            ``_lbfgs.init`` state.
        fun: loss ``fun(params, **kwargs)``. It is static, so each distinct
            function object triggers a new trace/compile.
        max_steps: iteration cap (traced, not static).
        tolerance: stop once the L2 norm of the gradient stored in the state
            drops below this (traced, not static).
        **kwargs: forwarded to ``fun`` (e.g. ``threshold``). They are traced,
            so new values reuse the compiled loop.

    Returns:
        ``(final_params, final_state)``.
    """
    value_and_grad_fun = optax.value_and_grad_from_state(fun)
    
    def step(carry):
        params, state = carry
        # Per the optax docs, this reuses a value/grad already computed by the
        # line search and stored in the state, when one is available.
        value, grad = value_and_grad_fun(params, **kwargs, state=state)

        # Debug trace: prints on every iteration of every solve.
        jax.debug.print("loss: {value}, grad: {grad}, params: {params}", value=jnp.array([value]), grad=grad, params=params)
        # NaN gradient entries become 0 (and +/-inf become large finite values)
        # so a single bad component does not poison the update.
        grad = jnp.nan_to_num(grad, nan=0.0)

        updates, state = _lbfgs.update(
            grad, state, params, **kwargs, value=value, grad=grad, value_fn=fun
        )
        params = optax.apply_updates(params, updates)

        return params, state

    def continuing_criterion(carry):
        _, state = carry
        iter_num = otu.tree_get(state, "count")
        grad = otu.tree_get(state, "grad")
        err = otu.tree_l2_norm(grad)
        # Always take the first step (the initial state holds no gradient yet),
        # then continue while under the step cap and the gradient is not small.
        return (iter_num == 0) | ((iter_num < max_steps) & (err >= tolerance))

    if lbfgs_state is None:
        state = _lbfgs.init(init_params)
    else:
        # reset iteration count leftover from previous solve
        state = otu.tree_set(lbfgs_state, count=0)


    init_carry = (init_params, state)
    final_params, final_state = jax.lax.while_loop(
        continuing_criterion, step, init_carry
    )
    return final_params, final_state


def solve(params, loss_fn, max_steps=100, tolerance=1e-7, **kwargs):
    """Run L-BFGS from ``params`` on ``loss_fn`` and return the optimised parameters.

    Prints the wall-clock time (including any jit compilation) and the number
    of iterations taken. ``**kwargs`` are forwarded to ``loss_fn``.
    """
    start_time = time.time()
    soln, state = _run_lbfgs(
        params,
        None, # minimizer state
        loss_fn,
        max_steps=max_steps,
        tolerance=tolerance,
        **kwargs,
    )
    # JAX dispatches work asynchronously. Wait for the result so the timing
    # covers the optimisation itself, not just the dispatch.
    soln = jax.block_until_ready(soln)
    end_time = time.time()
    iter_num = otu.tree_get(state, "count")
    print(f"Elapsed time: {end_time - start_time:.5f} seconds; {iter_num} iterations")

    return soln

def plot_circles_on_image(circles, img=None):
    """Overlay circle outlines on an image.

    Args:
        circles: iterable of ``(cx, cy, r)`` triples, e.g. ``solution.reshape(-1, 3)``.
        img: image to draw on. Defaults to the module-level ``image`` as it is
            at call time.
    """
    if img is None:
        img = image
    fig, ax = plt.subplots()
    ax.imshow(img, cmap='gray', origin='lower')
    colors = ['r', 'g', 'b', 'y', 'c']  # cycled when there are more than 5 circles
    for i, (cx, cy, r) in enumerate(circles):
        # Named `patch` so it does not shadow the module-level `circle` SDF.
        patch = plt.Circle((cx, cy), r, color=colors[i % len(colors)], fill=False)
        ax.add_patch(patch)
    plt.show()

In [None]:
# Fit one circle with L-BFGS, starting at the left target's centre (50, 30)
# with a radius that is too large. The .1 offsets presumably keep the centre off
# an exact pixel centre, where the distance gradient is undefined.
params = jnp.array([50.1, 30.1, 40])

solution = solve(params, loss_fn_unoccupied_ratio_sigmoid, max_steps=10)
print(solution)
plot_circles_on_image(solution.reshape(-1, 3))


In [None]:
def perform_random_search(solve_function, loss_fn, initial_radius=20, image=image, num_samples=200):
    """Fit circles from random starting centres and plot which ones find a target circle.

    Starting centres are uniform over the image extent, shifted by 0.1 px.
    A run is a "solution" when its fitted centre is within 1 px (per axis) of
    (50, 30) or (150, 30). Solutions are split further by whether the fitted
    radius is within 3 of the global ``target_radius``.

    Args:
        solve_function: ``solve_function(params, loss_fn) -> params``.
        loss_fn: loss passed through to ``solve_function``.
        initial_radius: radius of every starting guess.
        image: background image, which also sets the sampling area. Defaults
            to the module-level ``image`` at definition time.
        num_samples: number of random starting points.
    """
    # Sampling area derived from the image (0.1..W+0.1, 0.1..H+0.1).
    x_min, x_max = 0.1, image.shape[1] + 0.1
    y_min, y_max = 0.1, image.shape[0] + 0.1

    # Define epsilon for comparison
    epsilon = 1

    # Function to check if solution is close to actual positions
    def is_close_to_actual(solution, actual_position, epsilon):
        return jnp.all(jnp.abs(solution[:2] - actual_position) < epsilon)

    # Actual positions of the two circles
    actual_positions = jnp.array([[50, 30], [150, 30]])

    # Independent keys for x and y: a key that has been split must not also be
    # used directly for sampling.
    key = jax.random.PRNGKey(0)  # Set a random seed for reproducibility
    key_x, key_y = jax.random.split(key)
    x_samples = jax.random.uniform(key_x, (num_samples,), minval=x_min, maxval=x_max)
    y_samples = jax.random.uniform(key_y, (num_samples,), minval=y_min, maxval=y_max)

    # Perform random search
    results = []
    for x, y in zip(x_samples, y_samples):
        params = jnp.array([x, y, initial_radius])
        solution = solve_function(params, loss_fn)

        # Check if solution is close to either of the actual positions
        is_solution = any(is_close_to_actual(solution, pos, epsilon) for pos in actual_positions)
        results.append((params, is_solution, solution))

    # Plot the initial points, classified by the outcome of their solve
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.imshow(image, cmap='gray', origin='lower')

    for params, is_solution, solution in results:
        # The NaN check is on the fitted result (the starting params are never NaN).
        if jnp.any(jnp.isnan(solution)):
            ax.plot(params[0], params[1], 'x', color='red', markersize=5)
        elif is_solution:
            near_target_radius = jnp.abs(solution[2] - target_radius) < 3
            shape = "o" if near_target_radius else "+"
            ax.plot(params[0], params[1], shape, color='green', markersize=5)
        else:
            ax.plot(params[0], params[1], 'o', color='red', markersize=5)

    ax.set_title(f"Points that lead to fit. Initial radius {initial_radius}. (Green o: solution with radius within 3 of target, Green +: solution with radius off by 3 or more, Red o: No Solution, Red x: NaN)")
    plt.show()

    # Print summary
    num_solutions = sum(is_solution for _, is_solution, _ in results)
    print(f"Found solutions for {num_solutions} out of {num_samples} initial points.")

perform_random_search(solve, loss_fn_unoccupied_ratio_sigmoid_fixed_radius)


In [None]:
def solve_with_gradient_descent(initial_params, loss_fn, learning_rate=0.1, num_steps=1000, tolerance=1e-3, verbose=True):
    """Minimise ``loss_fn`` with Adam, stopping once a step becomes tiny.

    Stops when every parameter moves by less than ``tolerance`` in one step,
    or after ``num_steps`` steps.

    Args:
        initial_params: starting parameter array.
        loss_fn: scalar loss ``loss_fn(params)``.
        learning_rate: Adam learning rate.
        num_steps: maximum number of steps.
        tolerance: per-parameter step size below which the loop stops.
        verbose: print loss/grad/params at every step.

    Returns:
        The latest parameters, including the final sub-tolerance update.
    """
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(initial_params)
    # Build and jit value-and-grad once, so each step is one compiled call.
    value_and_grad = jax.jit(jax.value_and_grad(loss_fn))

    params = initial_params
    steps_taken = 0
    for _ in range(num_steps):
        loss, grad = value_and_grad(params)
        if verbose:
            print(f"loss: {jnp.array([loss])}, grad: {grad}, params: {params}")
        updates, opt_state = optimizer.update(grad, opt_state)
        new_params = optax.apply_updates(params, updates)
        steps_taken += 1
        step_is_tiny = jnp.all(jnp.abs(new_params - params) < tolerance)
        params = new_params
        if step_is_tiny:
            break

    print(f"Total steps taken: {steps_taken}")
    return params

In [None]:
# Single-circle fit with Adam (solve_with_gradient_descent) instead of L-BFGS,
# starting off-centre from the left target with a too-large radius.
params = jnp.array([55.1, 20.1, 40])
solution = solve_with_gradient_descent(params, loss_fn_unoccupied_ratio_sigmoid)
print(solution)
plot_circles_on_image(solution.reshape(-1, 3))

In [None]:
# Variant of the sigmoid ratio loss whose denominator also includes the total
# number of occupied pixels, to encourage small fitted circles to expand and
# capture more of the occupied area.
total_occupied = occupied_pixels.shape[0]

@jax.jit
def loss_fn_sigmoid2(params):
    """Soft ``1 - occupied_inside / (unoccupied_inside + total_occupied)`` for circle ``params[0:3]``.

    Inside-membership of a pixel at signed distance ``d`` is ``sigmoid(-d)``.
    Reads the module-level ``enumerated_pixels`` and ``total_occupied``.
    """
    cx, cy, r = params[0], params[1], params[2]

    def soft_counts(pixel):
        px, py, value = pixel
        membership = jax.nn.sigmoid(-distance_to_circle(px, py, cx, cy, r) * 1.0)
        return jnp.array([membership, membership * value])

    soft_total, soft_occupied = jax.vmap(soft_counts)(enumerated_pixels).sum(axis=0)
    soft_unoccupied = soft_total - soft_occupied
    # suggestion from Avi: https://discord.com/channels/721409715602587698/1295518715793768458/1295777188158640182
    return 1 - soft_occupied / (soft_unoccupied + total_occupied)

plot_loss_vs_radius(loss_fn_sigmoid2, title="Sigmoid2 loss")

In [None]:
# Start from a small circle (r = 10) beside the left target and fit
# loss_fn_sigmoid2 with L-BFGS (the Adam alternative is left commented out).
params = jnp.array([60.1, 20.1, 10])
#solution = solve_with_gradient_descent(params, loss_fn_sigmoid2)
solution = solve(params, loss_fn_sigmoid2)
print(solution)
plot_circles_on_image(solution.reshape(-1, 3))

In [None]:
# Random-restart sweep for loss_fn_sigmoid2 with every start at radius 10
# (background image: the default bound when perform_random_search was defined).
perform_random_search(solve, loss_fn_sigmoid2, initial_radius=10)

In [None]:
#perform_random_search(solve_with_gradient_descent, loss_fn_sigmoid2, initial_radius=10)

# Does a similar loss function work with unfilled circles?

In [None]:
# Rebuild the target scene with outlined (unfilled) circles: same canvas,
# centres and radius as before. This rebinds the module-level image,
# enumerated_pixels and occupied_pixels; losses jitted earlier keep the
# pixel data they were traced with.
image_width, image_height = 200, 100

image = Image.new('L', (image_width, image_height))
draw = ImageDraw.Draw(image)


style = {"outline": "white"}
target_radius = 20

draw.circle([50, 30], target_radius, **style)
draw.circle([150, 30], target_radius, **style)

# Scale so the brightest pixel is 1.0
image = jnp.array(image)
image = image / image.max()
plt.imshow(image, cmap='gray', origin='lower')
plt.show()

# x/y/value of all pixels
enumerated_pixels = enumerate_pixels(image)

# x/y coordinates of nonzero pixels
occupied_pixels = enumerated_pixels[enumerated_pixels[:, 2] > 0][:, 0:2]

In [None]:
# Unfilled-circle loss: only pixels near the circle boundary count. The
# denominator includes the total number of occupied pixels, to encourage small
# fitted circles to expand and capture more of the occupied area.
total_occupied = occupied_pixels.shape[0]
threshold = 5.0

@jax.jit
def loss_fn_sigmoid_unfilled(params, threshold=threshold):
    """Soft ``1 - occupied_near / (unoccupied_near + total_occupied)`` for circle ``params[0:3]``.

    A pixel at signed distance ``d`` from the boundary has weight
    ``sigmoid(threshold - d**2)``. That is roughly 1 when ``d**2`` is well
    below ``threshold`` and roughly 0 when it is well above.

    ``threshold`` is an argument defaulting to the 5.0 set just above, rather
    than a global read. jit bakes globals in at trace time, so a later
    rebinding of the global would otherwise be ignored by cached traces but
    picked up by new ones. Reads the module-level ``enumerated_pixels`` and
    ``total_occupied``.
    """
    def f(pixel):
        x, y, occupied = pixel
        dist = distance_to_circle(x, y, *params[0:3])
        near_surface_prob = jax.nn.sigmoid(threshold - dist**2)
        return jnp.array([near_surface_prob, near_surface_prob * occupied])

    pixels = jax.vmap(f)(enumerated_pixels)
    near_surface, occupied_near_surface = pixels.sum(axis=0)
    unoccupied_near_surface = near_surface - occupied_near_surface

    return 1 - (occupied_near_surface / (unoccupied_near_surface + total_occupied))

plot_loss_vs_radius(loss_fn_sigmoid_unfilled, title="Sigmoid_unfilled loss")

In [None]:
# Fit a small starting circle (r = 10) near the left outlined target with
# L-BFGS (the Adam alternative is left commented out).
params = jnp.array([35.1, 19.1, 10])
#solution = solve_with_gradient_descent(params, loss_fn_sigmoid_unfilled)
solution = solve(params, loss_fn_sigmoid_unfilled)
print(solution)
plot_circles_on_image(solution.reshape(-1, 3))

In [None]:
# Pass image explicitly: the function's default is still the filled-circle image.
perform_random_search(solve, loss_fn_sigmoid_unfilled, initial_radius=20, image=image)

In [None]:
# Here's an example of a bad solution. A potential issue is that we're not penalizing the parts of the circle that end up offscreen, as there are no pixels there to evaluate.
# A prior on size would probably also help.

# Starting point that converges to a poor fit.
params = jnp.array([38.1, 18.1, 10])
solution = solve(params, loss_fn_sigmoid_unfilled)
plot_circles_on_image(solution.reshape(-1, 3))

In [None]:
# Same two outlined circles on a larger 400x400 canvas (centres (150, 130) and
# (250, 130)), presumably so fitted circles have room before running offscreen.
# Rebinds image, enumerated_pixels and occupied_pixels.
image_width, image_height = 400, 400

image = Image.new('L', (image_width, image_height))
draw = ImageDraw.Draw(image)


style = {"outline": "white"}
target_radius = 20

draw.circle([150, 130], target_radius, **style)
draw.circle([250, 130], target_radius, **style)

# Scale so the brightest pixel is 1.0
image = jnp.array(image)
image = image / image.max()
plt.imshow(image, cmap='gray', origin='lower')
plt.show()

# x/y/value of all pixels
enumerated_pixels = enumerate_pixels(image)

# x/y coordinates of nonzero pixels
occupied_pixels = enumerated_pixels[enumerated_pixels[:, 2] > 0][:, 0:2]

In [None]:
# Redefined after building the 400x400 image. The previous jitted version has
# the 200x100 pixel data baked in, and because params keeps the same shape it
# would silently reuse that stale trace.
# The denominator includes the total number of occupied pixels, to encourage
# small fitted circles to expand and capture more of the occupied area.
total_occupied = occupied_pixels.shape[0]
threshold = 5.0

@jax.jit
def loss_fn_sigmoid_unfilled(params, threshold=threshold):
    """Soft ``1 - occupied_near / (unoccupied_near + total_occupied)`` for circle ``params[0:3]``.

    A pixel at signed distance ``d`` from the boundary has weight
    ``sigmoid(threshold - d**2)``. That is roughly 1 when ``d**2`` is well
    below ``threshold`` and roughly 0 when it is well above.

    ``threshold`` is an argument defaulting to the 5.0 set just above, rather
    than a global read. The global ``threshold`` is rebound later (e.g. by the
    threshold-sweep loop), and jit would otherwise bake in whichever value was
    current at trace time. Reads the module-level ``enumerated_pixels`` and
    ``total_occupied``.
    """
    def f(pixel):
        x, y, occupied = pixel
        dist = distance_to_circle(x, y, *params[0:3])
        near_surface_prob = jax.nn.sigmoid(threshold - dist**2)
        return jnp.array([near_surface_prob, near_surface_prob * occupied])

    pixels = jax.vmap(f)(enumerated_pixels)
    near_surface, occupied_near_surface = pixels.sum(axis=0)
    unoccupied_near_surface = near_surface - occupied_near_surface

    return 1 - (occupied_near_surface / (unoccupied_near_surface + total_occupied))

In [None]:
# The "bad solution" start from before, shifted by +100 px in x and y to
# match the larger canvas.
params = jnp.array([138.1, 118.1, 10])
solution = solve(params, loss_fn_sigmoid_unfilled)
plot_circles_on_image(solution.reshape(-1, 3))

# Fitting lines too

In [None]:
# Line-fitting scene: two 1-px segments meeting at (100, 100), one vertical
# (l1) and one horizontal (l2). Rebinds image, enumerated_pixels and
# occupied_pixels.
image_width, image_height = 400, 400

image = Image.new('L', (image_width, image_height))
draw = ImageDraw.Draw(image)

# NOTE(review): `style` is not used by draw.line below (leftover from the circle cells).
style = {"outline": "white"}

# Segments as [x1, y1, x2, y2]; also used as ground truth by the line random search.
l1 = [100, 100, 100, 150]
l2 = [100, 100, 200, 100]

draw.line(l1, fill="white", width=1)
draw.line(l2, fill="white", width=1)

# Scale so the brightest pixel is 1.0
image = jnp.array(image)
image = image / image.max()
plt.imshow(image, cmap='gray', origin='lower')
plt.show()

# x/y/value of all pixels
enumerated_pixels = enumerate_pixels(image)

# x/y coordinates of nonzero pixels
occupied_pixels = enumerated_pixels[enumerated_pixels[:, 2] > 0][:, 0:2]

def plot_fitted_line(params, image=image):
    """Show ``image`` with the segment ``params = (x1, y1, x2, y2)`` drawn in red.

    ``image`` defaults to the module-level image bound when this function was
    defined (the two-line target image).
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.imshow(image, cmap='gray', origin='lower')

    x1, y1, x2, y2 = params
    ax.plot([x1, x2], [y1, y2], color='red', linewidth=2, label='Fitted Line')
    ax.set_title('Image with Fitted Line')
    ax.legend()
    plt.show()


In [None]:
def distance_to_line(x, y, x1, y1, x2, y2):
    """Unsigned Euclidean distance from ``(x, y)`` to the segment ``(x1, y1)-(x2, y2)``.

    Safe for degenerate input and for autodiff:
    * a zero-length segment (both endpoints equal) is treated as a point,
      instead of dividing 0 by 0 and returning NaN;
    * for points exactly on the segment the distance is 0 and its gradient is
      0, not NaN (the derivative of ``sqrt`` at 0 is infinite).
    """
    # Calculate the vector components of the line
    dx = x2 - x1
    dy = y2 - y1

    # Calculate the length of the line squared
    line_length_sq = dx**2 + dy**2

    # Projection parameter t of (x, y) onto the infinite line. A zero-length
    # segment gets t = 0, so the closest point is (x1, y1). The denominator is
    # swapped out before dividing so neither where-branch produces NaN.
    is_degenerate = line_length_sq == 0
    safe_length_sq = jnp.where(is_degenerate, 1.0, line_length_sq)
    t = jnp.where(is_degenerate, 0.0, ((x - x1) * dx + (y - y1) * dy) / safe_length_sq)

    # Clamp t to [0, 1] to ensure the closest point is on the line segment
    t = jnp.clip(t, 0, 1)

    # Calculate the closest point on the line
    closest_x = x1 + t * dx
    closest_y = y1 + t * dy

    # "Double where": sqrt only ever sees positive inputs, so the gradient at a
    # zero distance is 0 instead of NaN. The returned value is unchanged.
    dist_sq = (x - closest_x)**2 + (y - closest_y)**2
    is_zero = dist_sq == 0
    return jnp.where(is_zero, 0.0, jnp.sqrt(jnp.where(is_zero, 1.0, dist_sq)))


In [None]:
total_occupied = occupied_pixels.shape[0]

@jax.jit
def loss_fn_sigmoid_line(params, threshold=2.0):
    """Soft near-the-segment ratio loss for the segment ``params = (x1, y1, x2, y2)``.

    Each pixel is weighted by ``sigmoid(threshold - d**2)``, where ``d`` is its
    distance to the segment. The loss is
    ``1 - occupied_near / (unoccupied_near + total_occupied)``.
    Reads the module-level ``enumerated_pixels`` and ``total_occupied``.
    """
    def soft_counts(pixel):
        px, py, value = pixel
        weight = jax.nn.sigmoid(threshold - distance_to_line(px, py, *params) ** 2)
        return jnp.array([weight, weight * value])

    near_total, near_occupied = jax.vmap(soft_counts)(enumerated_pixels).sum(axis=0)
    near_unoccupied = near_total - near_occupied
    return 1 - near_occupied / (near_unoccupied + total_occupied)

In [None]:
# Fit one segment starting far from the targets, with a wider threshold (5.0)
# than the loss default (2.0).
params = jnp.array([10.1, 10.1, 100.1, 150.1])
solution = solve(params, loss_fn_sigmoid_line, threshold=5.0)
plot_fitted_line(solution)


In [None]:
color_success = "#20B2AA"  # Light Sea Green
color_failure = "#FF4500"  # Orange Red
def perform_random_search(n, max_line_length=50, img=image, loss_fn=loss_fn_sigmoid_line, solve_fn=solve, **kwargs):
    """Fit ``n`` random starting segments and record which converge to a target line.

    NOTE: this redefines the circle-search ``perform_random_search`` above with
    a different signature; later cells get this version.

    Each start has a uniform random first endpoint inside ``img``, a uniform
    random direction, and a random length capped by ``max_line_length`` and
    by the distance to the image edge along that direction. A solution
    "converged" if all four coordinates are within 5 px of ``l1`` or ``l2``
    (endpoints in either order). It "moved" if any coordinate changed by more
    than 1 px.

    Args:
        n: number of random starts.
        max_line_length: upper bound on the starting segment length.
        img: image defining the sampling area (default: module ``image`` at definition).
        loss_fn: loss passed to ``solve_fn``.
        solve_fn: ``solve_fn(params, loss_fn, **kwargs) -> params``.
        **kwargs: forwarded to ``solve_fn`` (e.g. ``threshold``).

    Returns:
        List of ``(initial_params, solution, did_converge, did_move)`` tuples.
    """
    width, height = img.shape[1], img.shape[0]
    key = jax.random.PRNGKey(42)  # Fixed seed for reproducibility

    def generate_random_line(key):
        key, k_x, k_y, k_angle, k_length = jax.random.split(key, 5)
        x1 = jax.random.uniform(k_x, minval=0, maxval=width)
        y1 = jax.random.uniform(k_y, minval=0, maxval=height)
        angle = jax.random.uniform(k_angle, minval=0, maxval=2*jnp.pi)
        dx, dy = jnp.cos(angle), jnp.sin(angle)

        # Distance along (dx, dy) until the ray leaves the image in x and in y.
        # A zero direction component never reaches that pair of edges.
        exit_x = jnp.where(dx > 0, (width - x1) / dx, -x1 / dx)
        exit_y = jnp.where(dy > 0, (height - y1) / dy, -y1 / dy)
        max_t = jnp.minimum(jnp.where(dx != 0, exit_x, jnp.inf),
                            jnp.where(dy != 0, exit_y, jnp.inf))

        length = jax.random.uniform(k_length, minval=0, maxval=jnp.minimum(max_t, max_line_length))
        return jnp.array([x1, y1, x1 + length * dx, y1 + length * dy]), key

    def converged(params, target_lines, tolerance=5):
        reversed_params = params[jnp.array([2, 3, 0, 1])]  # same segment, endpoints swapped
        return any(
            jnp.all(jnp.abs(candidate - target) < tolerance)
            for target in target_lines
            for candidate in (params, reversed_params)
        )

    def moved(initial_params, solution_params):
        return bool(jnp.any(jnp.abs(initial_params - solution_params) > 1))

    target_lines = jnp.array([l1, l2])
    results = []
    for _ in range(n):
        start, key = generate_random_line(key)
        fitted = solve_fn(start, loss_fn, **kwargs)
        results.append((start, fitted, converged(fitted, target_lines), moved(start, fitted)))

    n_converged = sum(entry[2] for entry in results)
    print(f"Found {n_converged}/{n} solutions")

    return results

def plot_random_search_results(results, img=image):
    """Draw the starting segment of every run that moved, coloured by outcome.

    ``results`` are ``(initial_params, solution, converged, moved)`` tuples.
    Runs that did not move are omitted. Converged starts use
    ``color_success`` and the rest use ``color_failure``.
    """
    plt.figure(figsize=(12, 8))
    plt.imshow(img, cmap='gray', origin='lower')

    for (x1, y1, x2, y2), _, did_converge, did_move in results:
        if not did_move:
            continue
        line_color = color_success if did_converge else color_failure
        plt.plot([x1, x2], [y1, y2], color=line_color, linewidth=2.0, alpha=0.5)

    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()

In [None]:
# 1000 random starts with a tight threshold. The plot shows starting segments
# that moved: green = converged to a target line, orange-red = did not.
search_results = perform_random_search(1000, threshold=1.0)
plot_random_search_results(search_results, image)

In [None]:
# Collect all results that moved but did not converge
# (tuple layout: (initial_params, solution, converged, moved))
moved_results = [result for result in search_results if result[3] and not result[2]]
print(f"len(moved_results): {len(moved_results)}")

In [None]:
def plot_results(moved_results, image, max_plots=9, window=(75, 225, 75, 175)):
    """Show up to ``max_plots`` runs as initial (blue) vs final (red) segments.

    Every panel is a close-up of ``image`` over
    ``window = (x_min, x_max, y_min, y_max)`` in pixel coordinates. Panels are
    laid out three per row, and the grid grows with ``max_plots``; the
    default of 9 gives a 3x3 grid.

    Args:
        moved_results: ``(initial_params, solution, converged, moved)`` tuples.
        image: full image the segments were fitted to.
        max_plots: maximum number of panels.
        window: close-up region shown in every panel.
    """
    # Determine the number of plots (up to max_plots, never negative)
    num_plots = max(0, min(max_plots, len(moved_results)))
    x_min, x_max, y_min, y_max = window

    # Size the grid from max_plots (a fixed 3x3 grid failed for max_plots > 9).
    ncols = 3
    nrows = max(1, -(-max_plots // ncols))  # ceiling division, at least one row
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 5 * nrows), squeeze=False)
    fig.suptitle("Moved Results: Initial (Blue) and Final (Red) Positions", fontsize=16)

    # Flatten the axes array for easier indexing
    axes_flat = axes.flatten()

    for i in range(num_plots):
        ax = axes_flat[i]
        initial_params, solution, _, _ = moved_results[i]

        # Close-up of the target image; extent maps array indices to pixel coords
        ax.imshow(image[y_min:y_max, x_min:x_max], cmap='gray', origin='lower',
                  extent=[x_min, x_max, y_min, y_max])

        # Plot initial line
        x1, y1, x2, y2 = initial_params
        ax.plot([x1, x2], [y1, y2], color='blue', linewidth=2.0, label='Initial', alpha=0.7)

        # Plot final line
        x1, y1, x2, y2 = solution
        ax.plot([x1, x2], [y1, y2], color='red', linewidth=2.0, label='Final', alpha=0.7)

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_title(f'Result {i + 1}')

        if i == 0:  # Only add legend to the first subplot
            ax.legend(loc='upper right', fontsize='small')

    # Remove any unused subplots
    for ax in axes_flat[num_plots:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()

In [None]:
# Close-ups of up to 9 runs that moved but did not converge.
moved_results = [result for result in search_results if result[3] and not result[2]]
plot_results(moved_results, image)

It seems like as the threshold increases, we end up moving more lines but fewer of them converge. Let's plot to get a better sense.

In [None]:
# Sweep the loss threshold. Each value runs 1000 L-BFGS solves, and
# _run_lbfgs prints on every iteration, so this cell is slow and very verbose.
# NOTE: the loop variable rebinds the module-level `threshold` global.
thresholds = jnp.array([0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0, 200.0, 500.0])

results = []
for threshold in thresholds:
    search_results = perform_random_search(1000, max_line_length=50, threshold=threshold)
    total_converged = sum(result[2] for result in search_results)
    total_moved = sum(result[3] for result in search_results)
    # (threshold, #converged, #moved, raw per-run results)
    results.append((threshold, total_converged, total_moved, search_results))

In [None]:
# Extract data for plotting
thresholds = [result[0] for result in results]
total_converged = [result[1] for result in results]
total_moved = [result[2] for result in results]

# Create the scatter plot
plt.figure(figsize=(10, 6))
plt.scatter(thresholds, total_converged, color=color_success, label='Total Converged', s=50)
plt.scatter(thresholds, total_moved, color=color_failure, label='Total Moved', s=50)

# Connect points with lines
plt.plot(thresholds, total_converged, color=color_success, linestyle='--', alpha=0.5)
plt.plot(thresholds, total_moved, color=color_failure, linestyle='--', alpha=0.5)

# Set labels and title
plt.xlabel('Threshold')
plt.ylabel('Count')
plt.title('Convergence and Movement vs Threshold')

# Set x-axis to log scale and use the threshold values as ticks
plt.xscale('log')
plt.xticks(thresholds, [str(t) for t in thresholds])

# Add legend
plt.legend()

# Add grid for better readability
plt.grid(True, which="both", ls="--", alpha=0.5)

# Show the plot
plt.tight_layout()
plt.show()


In [None]:
# Let's try a multistep (coarse-to-fine) solve: start with a high threshold,
# then reduce it to smaller ones.
thresholds = jnp.array([100.0, 10.0, 1.0])

def multistep_solve(params, loss_fn, thresholds):
    """Run ``solve`` once per threshold, each stage starting from the previous result.

    Returns the parameters after the final threshold in the sequence.
    """
    current = params
    for stage_threshold in thresholds:
        current = solve(current, loss_fn, threshold=stage_threshold)
    return current

multistep_results = perform_random_search(1000, solve_fn=multistep_solve, thresholds=thresholds)

In [None]:
# Starting segments that moved under the multistep schedule, coloured by outcome.
plot_random_search_results(multistep_results, image)

In [None]:
# Calculate the total number of multi-step results that moved at all
# (result[3] is the "moved" flag; this rebinds the global total_moved to an int)
total_moved = sum(1 for result in multistep_results if result[3])

print(f"Total number of multi-step results that moved: {total_moved}")


In [None]:
# Baseline for comparison: a single solve at the multistep schedule's first
# (largest) threshold.
results100 = perform_random_search(1000, threshold=100.0)
moved_results = [result for result in results100 if result[3]]
print(f"Total number of single-step results that moved: {len(moved_results)}")

In [None]:
# Close-ups of multistep runs that moved but did not converge.
moved_results = [result for result in multistep_results if result[3] and not result[2]]
plot_results(moved_results, image)