In [1]:
from polymorph_s2df import *
import jax.numpy as jnp
import jax

# Monte Carlo sampling setup shared by every experiment below.
SAMPLE_SIZE = 200                 # number of random sample points
INTEGRATION_BOUNDS = (-3, 3)      # (low, high) of the square sampling region, per axis
SEED = jax.random.PRNGKey(0)      # fixed PRNG key -> reproducible samples

# SAMPLE_SIZE points drawn uniformly from [low, high) x [low, high).
points = jax.random.uniform(
    SEED,
    shape=(SAMPLE_SIZE, 2),
    dtype=jnp.float32,
    minval=INTEGRATION_BOUNDS[0],
    maxval=INTEGRATION_BOUNDS[1],
)

def area(shape, points, bounds):
    """Monte Carlo estimate of the area covered by `shape`.

    Assumes `points` are uniform samples from the square
    [bounds[0], bounds[1]]^2 and that `shape.is_inside` returns one value
    per point (presumably 0/1 or a soft value in [0, 1] -- confirm against
    polymorph_s2df). The mean of those values times the square's area is
    the estimate.
    """
    inside = shape.is_inside(points)
    side = bounds[1] - bounds[0]
    return inside.mean() * side ** 2


import optimistix
from timeit import default_timer as timer
def optimize_params(cost, params, points):
    """Minimise `cost` over `params` with BFGS, printing step count and wall time.

    Args:
        cost: objective, called by optimistix as ``cost(params, points)``;
            must return a scalar.
        params: initial guess for the parameters (JAX array / pytree).
        points: passed through unchanged to `cost` as optimistix's ``args``.

    Returns:
        The final parameter value reported by the solver (``solution.value``).
    """
    solver = optimistix.BFGS(rtol=1e-5, atol=1e-6)
    start = timer()
    # throw=False: an unsuccessful solve does not raise; the last iterate is
    # still returned, so a "converged-looking" result may not be converged.
    solution = optimistix.minimise(cost, solver, params, points, throw=False)
    # JAX dispatches work asynchronously: without this wait the clock can stop
    # before the solve has actually finished, under-reporting the time.
    value = jax.block_until_ready(solution.value)
    elapsed = timer() - start
    print("{0} steps in {1:.2f} seconds".format(
            solution.stats.get('num_steps'),
            elapsed))
    return value

In [2]:
# Initial guess for the optimiser; layout: [circle radius].
initial_params = jnp.array([0.5])

def cost(params, points):
    """Squared error between a target area and the estimated area of a circle.

    params[0] is the circle radius. The target 3.141 approximates pi, i.e.
    the area of a circle of radius 1.
    """
    circle = Circle(params[0])
    estimated = area(circle, points, INTEGRATION_BOUNDS)
    residual = 3.141 - estimated
    return residual ** 2

# First solve; its reported time presumably includes JIT compilation
# (compare the recorded 0.12 s here with the second run).
params = optimize_params(cost, initial_params, points)
print("Final params:", params)

# Run again so we can see how long it takes post-JIT
# (same inputs; only the timing line is of interest here).
params = optimize_params(cost, initial_params, points)


22 steps in 0.12 seconds
Final params: [0.8747944]
22 steps in 0.00 seconds


In [3]:
import sys
from polymorph_num.node import as_node
from polymorph_num import loss, ops, optimizer, point

In [4]:
# Convert each row of the JAX sample array into a polymorph_num Point
# built from plain Python floats.
apoints = [point.Point(x.item(), y.item()) for x, y in points]

def circle_sdf(radius, center, point):
    """Signed distance from `point` to the circle of `radius` around `center`.

    Negative inside the circle, zero on it, positive outside. `center` and
    `point` need `.x`/`.y`; arithmetic goes through polymorph_num (`ops.sqrt`).
    """
    offset_x, offset_y = center.x - point.x, center.y - point.y
    euclidean = ops.sqrt(offset_x * offset_x + offset_y * offset_y)
    return euclidean - radius

# Alternative declaration: r as a free parameter of the optimisation.
#r = ops.param()
# r is declared as an observation named 'r'; its value is supplied when the
# optimiser runs (opt.optimize({'r': 0.5}) later in this cell).
# NOTE(review): the recorded output of this cell prints 0.5 -- the same value
# passed in -- so the radius appears to be held fixed rather than solved for.
# If the intent is to recover the radius (as the JAX version does),
# ops.param() is presumably what is wanted; confirm against polymorph_num.
r = ops.observation('r')
# Circle center at the origin.
c = point.Point(0.,0.)

# NOTE(review): this redefinition shadows the JAX `area(shape, points, bounds)`
# from the first cell; re-running `cost` after this cell will call this
# version instead. Consider a distinct name (e.g. `soft_area`).
def area(f, points, bounds, radius=None, center=None, scale=100):
    """Smooth Monte Carlo estimate of the area enclosed by the SDF `f`.

    Each sample contributes ``1 - sigmoid(scale * f(radius, center, p))``.
    Assuming ``ops.sigmoid`` is the logistic function, this is a soft
    inside-indicator: near 1 where the signed distance is negative (inside)
    and near 0 where it is positive. The mean indicator is multiplied by the
    area of the square sampling region ``[bounds[0], bounds[1]]^2``.

    Args:
        f: signed-distance function called as ``f(radius, center, point)``,
            e.g. ``circle_sdf``.
        points: non-empty sequence of sample points accepted by `f`.
        bounds: (low, high) of the square sampling region on each axis.
        radius: radius argument for `f`; defaults to the module-level ``r``.
        center: center argument for `f`; defaults to the module-level ``c``.
        scale: sigmoid sharpness; larger values approach a hard 0/1 step.

    Returns:
        The area estimate, built from polymorph_num node arithmetic.

    Raises:
        ValueError: if `points` is empty (the mean would divide by zero).
    """
    if len(points) == 0:
        raise ValueError("area() needs at least one sample point")
    if radius is None:
        radius = r
    if center is None:
        center = c
    # Loop-invariant constants: build them once, not once per sample.
    one = as_node(1)
    sharpness = as_node(scale)
    total = as_node(0.0)  # avoid shadowing the builtin `sum`
    for p in points:
        distance = f(radius, center, p)
        total += one - ops.sigmoid(sharpness * distance)
    mean = total / len(points)
    return mean * (bounds[1] - bounds[0])**2

# Target area 3.141 ~= pi, i.e. the area of a circle of radius 1.
target_area = as_node(3.141)
error = target_area - area(circle_sdf, apoints, INTEGRATION_BOUNDS)

# Squared-error loss; r is registered as an output, presumably so that
# soln.eval(r) below can read its value back.
l = loss.Loss(error*error)
l.register_output(r)

opt = optimizer.Optimizer(l)
print("---")
start = timer()
# Supplies the value of the 'r' observation. The recorded output shows a
# "tracing" message inside this timed region, so the measured time likely
# includes tracing, not just the solve.
soln = opt.optimize({'r': 0.5})
elapsed = timer() - start
print("optimized in {0:.2f} seconds".format(elapsed))
# NOTE(review): recorded output is 0.5, identical to the supplied value,
# so r does not appear to have been optimised -- confirm intent.
print(soln.eval(r))

tracing
---
tracing
optimized in 1.54 seconds
0.5


In [5]:
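## Quick test: does JAX recognise a Python-loop sum as a single reduction, or unroll it?
## (The lowered sum1 below shows the loop unrolled into one slice/reshape/add per element.)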
import jax

# Tiny int32 input. jit specialises on its static shape (3,), so the Python
# loop in sum1 is traced once per element.
a = jnp.arange(3)

@jax.jit
def sum1(a):
    """Sum the elements of `a` with an explicit Python loop.

    Under jit the loop runs at trace time, once per element of the
    statically shaped input, so the lowered program contains one add per
    element rather than a single reduction op.
    """
    running_total = 0
    for element in a:
        running_total = running_total + element
    return running_total

@jax.jit
def sum2(a):
    """Sum the elements of `a` via jnp's built-in reduction (baseline for sum1)."""
    return jnp.sum(a)

# Print the lowered StableHLO of each function to compare the explicit loop
# with the built-in reduction.
# NOTE(review): .lower(...).as_text() is the IR *before* XLA optimisation; XLA
# may still fuse the unrolled adds. sum1.lower(a).compile().as_text() shows
# the post-optimisation HLO if that is the question being asked.
print(sum1.lower(a).as_text())

print(sum2.lower(a).as_text())


module @jit_sum1 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.slice %arg0 [0:1] : (tensor<3xi32>) -> tensor<1xi32>
    %1 = stablehlo.reshape %0 : (tensor<1xi32>) -> tensor<i32>
    %c = stablehlo.constant dense<0> : tensor<i32>
    %2 = stablehlo.add %c, %1 : tensor<i32>
    %3 = stablehlo.slice %arg0 [1:2] : (tensor<3xi32>) -> tensor<1xi32>
    %4 = stablehlo.reshape %3 : (tensor<1xi32>) -> tensor<i32>
    %5 = stablehlo.add %2, %4 : tensor<i32>
    %6 = stablehlo.slice %arg0 [2:3] : (tensor<3xi32>) -> tensor<1xi32>
    %7 = stablehlo.reshape %6 : (tensor<1xi32>) -> tensor<i32>
    %8 = stablehlo.add %5, %7 : tensor<i32>
    return %8 : tensor<i32>
  }
}

module @jit_sum2 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32>