## Optimization of points/lines together with s2df shapes

In [None]:
from polymorph_s2df import *
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import plotly.graph_objects as go

SAMPLE_SIZE = 10_000
INTEGRATION_BOUNDS = (-3, 3)
SEED = jax.random.PRNGKey(0)
# Monte Carlo sample: SAMPLE_SIZE (x, y) pairs drawn uniformly from the square
# INTEGRATION_BOUNDS x INTEGRATION_BOUNDS with a fixed seed (reproducible).
points = jax.random.uniform(
    SEED,
    shape=(SAMPLE_SIZE, 2),
    dtype=jnp.float32,
    minval=INTEGRATION_BOUNDS[0],
    maxval=INTEGRATION_BOUNDS[1],
)

In [None]:
#########################
## Analysis

## TODO: combine sample points and bounds into single object since they're always related.

## TODO: should these fns be put in the s2df shared library?

def area(shape, points, bounds=(-1, 1)):
    """Monte Carlo estimate of the area enclosed by `shape`.

    The fraction of `points` for which ``shape.is_inside`` holds is scaled by
    the area of the square ``[bounds[0], bounds[1]]**2``. The estimate is only
    meaningful when `points` were drawn uniformly from that same square.

    NOTE(review): the default bounds (-1, 1) differ from INTEGRATION_BOUNDS
    used to draw the module-level `points`; pass `bounds` explicitly when
    using that sample.
    """
    side_length = bounds[1] - bounds[0]
    inside_fraction = shape.is_inside(points).mean()
    return inside_fraction * side_length**2

def centroid(shape, points):
    """Monte Carlo estimate of the centroid of `shape`.

    Averages `points` along axis 0, weighting each one by the output of
    ``shape.is_inside`` (so points outside the shape contribute nothing).
    NOTE(review): if no point lies inside, the weight sum is zero and the
    result is presumably NaN rather than an error.
    """
    inside_weights = shape.is_inside(points)
    return jnp.average(points, axis=0, weights=inside_weights)


In [None]:
##############################
## Some drawing fns

# Pixel width and height of figures from create_figure, and the default
# per-axis resolution of the grid add_shape samples the distance field on.
PX_WIDTH: int = 600


def add_shape(fig: go.FigureWidget, shape, n=PX_WIDTH):
    """Draw the boundary (zero level set) of `shape`'s distance field on `fig`.

    ``shape.distance`` is evaluated on an ``n`` x ``n`` grid spanning
    ``fig.__bounds__`` (the bounds stored by ``create_figure``) in both x and y,
    and the contour at distance 0 is added as a line.

    Args:
        fig: figure created by ``create_figure`` (must carry ``__bounds__``).
        shape: object whose ``distance`` maps an array of (x, y) rows to one
            signed distance per row.
        n: grid resolution per axis.

    Returns:
        The contour trace that was added.
    """
    bounds = fig.__bounds__
    axis = jnp.linspace(bounds[0], bounds[1], n)
    X, Y = jnp.meshgrid(axis, axis)

    grid_points = jnp.column_stack((X.ravel(), Y.ravel()))
    values = shape.distance(grid_points).reshape(n, n)

    fig.add_contour(
        x=axis,
        y=axis,
        z=values,
        # Pin exactly one contour at distance 0. With autocontour plotly picks
        # "nice" levels itself, which is not guaranteed to be the boundary.
        autocontour=False,
        contours=dict(start=0, end=0, size=1, coloring="lines"),
        showscale=False,
        line_width=2,
    )
    return fig.data[-1]


def add_polyline(fig: go.FigureWidget, coordinates):
    """Append an open polyline through the (x, y) rows of `coordinates`.

    `coordinates` is transposed and unpacked into x and y, so it must have
    exactly two columns. Returns the scatter trace that was added.
    """
    xs, ys = coordinates.T
    fig.add_scatter(mode="lines", x=xs, y=ys, showlegend=False)
    new_trace = fig.data[-1]
    return new_trace


def add_points(fig: go.FigureWidget, coordinates):
    """Draw the (x, y) rows of `coordinates` on `fig` as unconnected markers.

    `coordinates` is transposed and unpacked into x and y, so it must have
    exactly two columns. Returns the scatter trace that was added.
    """
    x, y = coordinates.T
    # Plotly's default scatter mode is "lines+markers" for < 20 points and
    # "lines" otherwise, which would join the points (or hide the markers).
    fig.add_scatter(x=x, y=y, mode="markers", showlegend=False)
    return fig.data[-1]


def create_figure(bounds=INTEGRATION_BOUNDS) -> go.FigureWidget:
    """Create a square, equal-aspect, dark-themed figure covering `bounds`.

    Both axes span ``bounds`` (the same range the drawing helpers sample
    over), with ticks every 0.5 and a label on every second tick.

    Args:
        bounds: (low, high) range used for both the x and y axes.

    Returns:
        A FigureWidget with ``__bounds__`` set to `bounds`, so helpers such as
        ``add_shape`` can sample over the displayed region.
    """
    fig = go.FigureWidget()
    fig.__bounds__ = bounds  # read back by add_shape to pick its sampling grid

    # Use `bounds` (not the global default) so the displayed region matches
    # the region add_shape samples over.
    fig.update_xaxes(
        range=bounds,
        constrain="domain",
        dtick=0.5,
        ticklabelstep=2,
    )

    fig.update_yaxes(
        range=bounds,
        scaleanchor="x",  # lock aspect ratio so circles look round
        scaleratio=1,
        dtick=0.5,
        ticklabelstep=2,
    )

    fig.update_layout(
        template="plotly_dark",
        width=PX_WIDTH,
        height=PX_WIDTH,
    )
    return fig

# f = create_figure()
# add_shape(f, Circle(2))
# add_shape(f, Circle(1))
# add_points(f, jnp.array([[1,1]]))
# add_polyline(f, jnp.array([[0,0], [1,1], [2, 1]]))
# f

In [None]:
%%html
<style>
/* Make notebook output-area backgrounds transparent so widget output (e.g.
   the plotly FigureWidgets in this notebook) is not drawn on an opaque box.
   .cell-output-ipywidget-background presumably targets VS Code's widget
   output and .jp-OutputArea-output JupyterLab's output area - confirm. */
.cell-output-ipywidget-background {
   background-color: transparent !important;
}
.jp-OutputArea-output {
   background-color: transparent;
}  
</style>

In [None]:
#############################
## Optimization

import optimistix
from timeit import default_timer as timer
def optimize_params(cost, params, points, *, solver=None):
    """Minimise ``cost(params, points)`` starting from `params`; return the optimum.

    Args:
        cost: objective ``cost(params, points) -> scalar``, passed to
            ``optimistix.minimise`` as ``fn`` with `points` as its ``args``.
        params: initial guess.
        points: extra data forwarded unchanged to `cost`.
        solver: optional optimistix minimiser; defaults to
            ``optimistix.BFGS(rtol=1e-5, atol=1e-6)``. Alternative, e.g.:
            ``optimistix.GradientDescent(learning_rate=1e-3, rtol=1e-5, atol=1e-6)``.

    Returns:
        The final parameter value. Because ``throw=False``, it is returned
        even when the solver fails; a warning line is printed in that case.
    """
    if solver is None:
        solver = optimistix.BFGS(rtol=1e-5, atol=1e-6)
    start = timer()
    solution = optimistix.minimise(cost, solver, params, points, throw=False)
    # JAX dispatches work asynchronously; wait for the result so the timing
    # covers the solve itself rather than just its dispatch.
    jax.block_until_ready(solution.value)
    elapsed = timer() - start
    print(f"{solution.stats.get('num_steps')} steps in {elapsed:.2f} seconds")
    # throw=False suppresses solver errors, so surface failures explicitly
    # (e.g. NaN gradients when starting at a non-differentiable point).
    if not bool(solution.result == optimistix.RESULTS.successful):
        print("Warning: solver did not converge:", solution.result)
    return solution.value

In [None]:
##################################################
## Problem 1: Move a point to a fixed distance from the SDF shape's boundary

initial_params = jnp.array([
    0.1,  # point x
    0,  # point y
])

shape = Circle(2)


def cost(params, points):
    """Squared error between the point's distance to `shape` and a 0.5 target.

    ``params[0:2]`` holds the point's (x, y). `points` is accepted to match
    the ``cost(params, args)`` signature handed to the optimiser but is not
    used by this objective.
    """
    target_distance = 0.5
    point_batch = params[jnp.newaxis, 0:2]  # shape (1, 2): a batch of one point
    residual = shape.distance(point_batch) - target_distance
    squared_error = residual**2
    return squared_error[0]

## Note: BFGS fails if the gradient is undefined (NaN), e.g., when starting at the center of the circle
# cost_grad = jax.grad(cost)
# print(cost_grad(jnp.array([.0,0.]), points))

f = create_figure()
trace_shape = add_shape(f, shape)
trace_point = add_points(f, initial_params[jnp.newaxis, :2])


def solve_and_draw(initial_x, initial_y):
    """Optimise the point starting from (initial_x, initial_y) and redraw it.

    The module-level Monte Carlo sample `points` is forwarded to `cost` as
    its second argument. Prints and returns the optimised parameters, and
    moves `trace_point` to the solution.
    """
    params = initial_params.at[0].set(initial_x).at[1].set(initial_y)
    params = optimize_params(cost, params, points)
    print("Final params:", params)
    trace_point.update(x=params[jnp.newaxis, 0],
                       y=params[jnp.newaxis, 1])
    return params


def handle_click(trace, click, selector):
    """Plotly click callback: restart the optimisation from the clicked spot.

    The second argument is plotly's clicked-points object. It is named
    `click` so it does not shadow the module-level sample `points` used by
    solve_and_draw.
    """
    if not click.xs:  # nothing on this trace was clicked
        return
    solve_and_draw(click.xs[0], click.ys[0])


trace_shape.on_click(handle_click)

f