## Initialization

In [1]:
from polymorph_s2df import *
from polymorph_s2df.devutils import *
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import numpy as np
from jax.tree_util import register_pytree_node_class


In [2]:
import optimistix
from timeit import default_timer as timer


def optimize_params(cost, params):
    solver = optimistix.BFGS(rtol=1e-5, atol=1e-6)
    start = timer()
    solution = optimistix.minimise(cost, solver, params, throw=False)
    elapsed = timer() - start
    print(
        "{0} steps in {1:.3f} seconds".format(solution.stats.get("num_steps"), elapsed)
    )
    return solution.value


In [26]:
from polymorph_s2df.utils import length, indent_shape

@register_pytree_node_class
class Translation(Shape):
    def __init__(self, offset, shape, variables=()):
        self.offset = offset
        self.shape = shape
        self.variables = variables

    def __repr__(self):
        return f"Translation(\n  {self.offset},\n{indent_shape(self.shape)}\n)"

    def tree_flatten(self):
        if "offset" not in self.variables:
            return (
                (self.shape,),
                {
                    "offset": (self.offset[0], self.offset[1]),
                    "variables": self.variables,
                },
            )
        return (self.offset, self.shape), {"variables": self.variables}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        if len(aux_data) == 1:
            return cls(*children, **aux_data)
        return cls(
            shape=children[0],
            variables=aux_data["variables"],
            offset=np.array(aux_data["offset"]),
        )

    def distance(self, p):
        return self.shape.distance(p - self.offset)

@register_pytree_node_class
class Circle(Shape):
    def __init__(self, radius, variables=()):
        self.radius = radius
        self.variables = variables

    def __repr__(self):
        return f"Circle({self.radius})"

    def tree_flatten(self):
        if "radius" in self.variables:
            return (self.radius,), {"variables": self.variables}
        return (), {"radius": self.radius, "variables": self.variables}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)

    def distance(self, p):
        return length(p) - self.radius


## A most basic cost function 

In [27]:
def compute_cost(structure, _):
    geometry = structure["geometry"]
    constraint = structure["constraint"]

    # This could be a tree a constraints
    cost = constraint.cost(geometry)  
    print("tracing")
    return cost

## A very simple constraint node

In [28]:
@register_pytree_node_class
class DistanceToPointConstraint:
    def __init__(self, point):
        self.point = point

    def __repr__(self):
        return f"DistanceToPointConstraint({self.point})"
        

    def tree_flatten(self):
        return (), {"point": self.point}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(**aux_data)

    def cost(self, geometry):
        return geometry.distance(np.array([self.point]))[0]**2 

In [29]:

structure = {
    "constraint": DistanceToPointConstraint((2., 2.)),
    "geometry": Translation(p(0, 0), Circle(2), variables=["offset"]),
}

We set a basic problem with a single constraint and a geometry where we want to very the position of the circle to touch the point `(2, 2)`

In [30]:
optimize_params(compute_cost, structure)

tracing
4 steps in 0.196 seconds


{'constraint': DistanceToPointConstraint((2.0, 2.0)),
 'geometry': Translation(
   [0.5857864 0.5857864],
   Circle(2)
 )}

If you rerun this function you can see that tracing is not called again.

In [31]:
optimize_params(compute_cost, structure)

4 steps in 0.002 seconds


{'constraint': DistanceToPointConstraint((2.0, 2.0)),
 'geometry': Translation(
   [0.5857864 0.5857864],
   Circle(2)
 )}

In [32]:
jax.tree.flatten(structure)

([Array([0, 0], dtype=int32)],
 PyTreeDef({'constraint': CustomNode(DistanceToPointConstraint[{'point': (2.0, 2.0)}], []), 'geometry': CustomNode(Translation[{'variables': ['offset']}], [*, CustomNode(Circle[{'radius': 2, 'variables': ()}], [])])}))

If we change the problem to change the radius of the circle, the function is traced again

In [33]:
structure2 = {
    "constraint": DistanceToPointConstraint((2., 2.)),
    "geometry": Translation(p(0., 0.), Circle(2, variables=["radius"]))
}

optimize_params(compute_cost, structure2)

tracing
4 steps in 0.155 seconds


{'constraint': DistanceToPointConstraint((2.0, 2.0)),
 'geometry': Translation(
   [0. 0.],
   Circle(2.8284270763397217)
 )}