## Shape optimization, using the new s2df functions

(Same as the `boat_shape_optimization` notebook, but using the new s2df functions.)

In [None]:
from polymorph_s2df import *
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt

def render(shape, bounds=(-3, 3), n=500):
    """Draw ``shape`` as a grayscale image on the current matplotlib axes.

    The square window ``[bounds[0], bounds[1]]`` on both axes is sampled on an
    ``n`` x ``n`` grid. Each pixel shows ``1 - shape.is_inside(...)``, so points
    inside the shape map to 0 and points outside map to 1.
    """
    lo, hi = bounds[0], bounds[1]
    axis = jnp.linspace(lo, hi, n)
    xs, ys = jnp.meshgrid(axis, axis)
    samples = jnp.column_stack((xs.flatten(), ys.flatten()))

    # Complement of the inside mask, jitted once per call.
    outside_mask = jax.jit(lambda pts: 1 - shape.is_inside(pts))
    image = outside_mask(samples).reshape(n, n)

    # origin="lower" puts row 0 (smallest y) at the bottom, matching meshgrid.
    plt.imshow(image, cmap="gray", origin="lower", extent=[lo, hi, lo, hi])

In [None]:
#########################
## Analysis

## TODO: combine the sample points and their bounds into a single object, since they are always used together.

## TODO: should these functions move into the shared s2df library?

def area(sdf, points, bounds=(-1, 1)):
    """Monte-Carlo estimate of the area of ``sdf``.

    Multiplies the fraction of ``points`` for which ``sdf.is_inside`` holds by
    the area of the square sampling window ``(bounds[1] - bounds[0])**2``.
    ``bounds`` must describe the window the points were drawn from. Note that
    the default ``(-1, 1)`` differs from the window used elsewhere in this
    notebook, so callers pass it explicitly.
    """
    side = bounds[1] - bounds[0]
    inside_fraction = sdf.is_inside(points).mean()
    return inside_fraction * side**2

def centroid(sdf, points):
    """Monte-Carlo estimate of the centroid of ``sdf``.

    Averages ``points`` along axis 0, weighting each point by
    ``sdf.is_inside(point)``. If no point is inside, the weights sum to zero and
    the result is presumably non-finite (NaN). Callers should keep the shape
    within the sampled window.
    """
    weights = sdf.is_inside(points)
    return jnp.average(points, axis=0, weights=weights)


In [None]:
# Shape density relative to the fluid. alignment_cost compares this-scaled
# shape area against the unscaled displaced area, i.e. fluid density is 1.
RELATIVE_DENSITY = 0.5
# Number of Monte-Carlo sample points used for area/centroid estimates.
SAMPLE_SIZE = 10_000
# Square sampling / rendering window, same (lo, hi) on both axes.
INTEGRATION_BOUNDS = (-3, 3)
# Fixed PRNG key so the sample points are reproducible between runs.
SEED = jax.random.PRNGKey(0)

In [None]:
def update_position(shape, original_shape_center, p1, p2):
    """Rotate ``shape`` about a centre, then shift it vertically.

    The rotation amount is ``jnp.atan(p1)``, applied via ``rotate_around``
    about ``original_shape_center``. The angle convention is whatever
    ``rotate_around`` uses; presumably radians, which is not confirmed here.
    The rotated shape is then translated by ``p(0, p2)``.
    """
    angle = jnp.atan(p1)
    offset = p(0, p2)
    rotated = shape.rotate_around(angle, original_shape_center)
    return rotated.translate(offset)

def displacement(shape):
    """Return the part of ``shape`` that overlaps ``BottomHalfPlane``.

    NOTE(review): ``BottomHalfPlane`` comes from the s2df star import. Judging
    by its name it is the region below the waterline, but that is not confirmed
    in this notebook.
    """
    submerged = shape.intersect(BottomHalfPlane)
    return submerged

def alignment_cost(shape, points):
    """Squared residual of the two floating-equilibrium conditions for ``shape``.

    Two residuals are combined as ``gravity_cost**2 + torque_cost**2``:

    * gravity_cost: the submerged area minus ``RELATIVE_DENSITY`` times the
      total area. Zero means the displaced fluid balances the shape's weight.
    * torque_cost: the x-offset between the centre of buoyancy (the centroid of
      the submerged part) and the centre of gravity (the centroid of the whole
      shape). Zero means there is no righting torque.

    All areas and centroids are Monte-Carlo estimates over ``points``, which
    must be drawn from ``INTEGRATION_BOUNDS``.
    """
    cog = centroid(shape, points)
    shape_weight = RELATIVE_DENSITY * area(shape, points, bounds=INTEGRATION_BOUNDS)

    # Bind the submerged shape to a distinct name. Assigning to ``displacement``
    # here would make it a local for the whole function, and the call would
    # raise UnboundLocalError.
    submerged = displacement(shape)
    cob = centroid(submerged, points)
    displacement_weight = area(submerged, points, bounds=INTEGRATION_BOUNDS)

    gravity_cost = displacement_weight - shape_weight
    torque_cost = (cob - cog)[0]  # horizontal (x) component only

    costs = jnp.array([gravity_cost, torque_cost])
    return jnp.dot(costs, costs)

def center_of_gravity(shape, points):
    """Centroid of the whole ``shape``, estimated from ``points``."""
    cog = centroid(shape, points)
    return cog

def center_of_buoyancy(shape, points):
    """Centroid of the submerged part of ``shape``, as returned by ``displacement``."""
    submerged = displacement(shape)
    return centroid(submerged, points)

In [None]:
# Uniform Monte-Carlo sample points over the square INTEGRATION_BOUNDS window,
# shape (SAMPLE_SIZE, 2).
points = jax.random.uniform(SEED, (SAMPLE_SIZE, 2), jnp.float32, *INTEGRATION_BOUNDS)

# Parameter layout. params_to_shape forwards params[0] and params[1] to
# update_position as (p1, p2). p1 goes through jnp.atan and becomes the
# rotation about the centre of gravity. p2 becomes the translation p(0, p2).
initial_params = jnp.array(
    [ 0. # tan(angle) around center of gravity
    , 0. # y position
    , 1. # r1 width
    , 1. # r2 width
    , 1. # r3 width
    ])


def params_to_shape(params):
    """Build the stacked three-box shape described by ``params`` and place it.

    ``params[2]``, ``params[3]`` and ``params[4]`` are the first arguments of
    the three ``Box(..., 1)`` calls. The second and third boxes are translated
    by ``p(0, 1)`` and ``p(0, 2)`` before the union.

    ``params[0]`` and ``params[1]`` are forwarded to ``update_position`` as
    ``p1`` (rotation via ``jnp.atan``) and ``p2`` (translation ``p(0, p2)``).
    The rotation centre is the centroid estimated with the module-level
    ``points`` sample.
    """
    r1 = Box(params[2], 1)
    r2 = Box(params[3], 1).translate(p(0, 1))
    r3 = Box(params[4], 1).translate(p(0, 2))
    stacked = r1.union(r2).union(r3)

    cog = center_of_gravity(stacked, points)
    return update_position(stacked, cog, params[0], params[1])

# Preview the starting shape before optimisation.
render(params_to_shape(initial_params), bounds=INTEGRATION_BOUNDS)

In [None]:
def optimize_params(cost, params, points, *, rtol=1e-5, atol=1e-6):
    """Minimise ``cost(params, points)`` with optimistix BFGS and return the optimum.

    Args:
        cost: objective, called as ``cost(params, points)``; must return a scalar.
        params: initial parameter vector.
        points: extra argument forwarded unchanged to ``cost``.
        rtol: relative tolerance for the BFGS solver.
        atol: absolute tolerance for the BFGS solver.

    Returns:
        The parameter vector found by the solver (``solution.value``).

    Prints the number of solver steps and the wall-clock time taken.
    """
    # ``timer`` was never imported in this notebook unless the star import
    # happens to provide it. Use the stdlib high-resolution duration clock.
    from time import perf_counter

    # NOTE(review): ``optimistix`` is not imported anywhere visible. It is
    # presumably re-exported by ``from polymorph_s2df import *``. Confirm this,
    # or add ``import optimistix`` to the first cell.
    solver = optimistix.BFGS(rtol=rtol, atol=atol)
    start = perf_counter()
    solution = optimistix.minimise(cost, solver, params, points)
    elapsed = perf_counter() - start
    print(f"{solution.stats.get('num_steps')} steps in {elapsed:.2f} seconds")
    return solution.value

In [None]:
def cost(params, points):
    """Objective minimised by ``optimize_params``.

    It is the sum of three terms:

    * the squared difference between a target area of 3 and the shape's
      Monte-Carlo area;
    * ``alignment_cost`` (the floating-equilibrium residual);
    * ``2.0 ** d``, where ``d`` is the shape's field value at the point (0, 3).
      If the field is a signed distance that is negative inside (usual SDF
      convention, not confirmed here), this term rewards covering that point.
    """
    target_area = 3
    shape = params_to_shape(params)
    area_error = target_area - area(shape, points, INTEGRATION_BOUNDS)
    sdf_at_anchor = shape(jnp.array([[0, 3]]))
    return area_error**2 + alignment_cost(shape, points) + 2.0**sdf_at_anchor[0]

# Run BFGS from the initial guess, report the optimum, then draw the result.
params = optimize_params(cost, initial_params, points)
print(params)
end_shape = params_to_shape(params)
render(end_shape, INTEGRATION_BOUNDS)
