## Initialization

In [1]:
from polymorph_s2df import *
from polymorph_s2df.devutils import *
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import numpy as np
from jax.tree_util import register_pytree_node_class


In [2]:
import optimistix
from timeit import default_timer as timer


def optimize_params(cost, params):
    """Minimise ``cost`` with BFGS starting from ``params`` and print the timing.

    Args:
        cost: Objective handed to ``optimistix.minimise``. It is called with
            the current parameters and an extra ``args`` value.
        params: Initial guess; a pytree whose leaves are the values to optimise.

    Returns:
        ``solution.value``, i.e. the optimised counterpart of ``params``.
    """
    solver = optimistix.BFGS(rtol=1e-5, atol=1e-6)
    start = timer()
    # throw=False: a solve that fails does not raise here.
    solution = optimistix.minimise(cost, solver, params, throw=False)
    # JAX dispatches work asynchronously, so wait for the result before
    # stopping the clock; otherwise only the dispatch may be measured.
    optimised = jax.block_until_ready(solution.value)
    elapsed = timer() - start
    print(
        "{0} steps in {1:.3f} seconds".format(solution.stats.get("num_steps"), elapsed)
    )
    return optimised


## Reimplement Circle and Translation as variable-aware pytrees

In [3]:
from polymorph_s2df.utils import length, indent_shape

@register_pytree_node_class
class Translation(Shape):
    """Shape that evaluates a wrapped shape at points shifted by an offset.

    ``offset`` is a reference that is resolved through
    ``register["array"].deref`` at evaluation time. For pytree purposes the
    wrapped ``shape`` is the only child and ``offset`` is carried as aux data.
    """

    def __init__(self, offset, shape):
        self.offset, self.shape = offset, shape

    def __repr__(self):
        inner = indent_shape(self.shape)
        return f"Translation(\n  {self.offset},\n{inner}\n)"

    def tree_flatten(self):
        """Return ``(children, aux_data)`` for JAX."""
        children = (self.shape,)
        aux_data = {"offset": self.offset}
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild the node from the output of ``tree_flatten``."""
        return cls(**aux_data, shape=children[0])

    def distance(self, register, p):
        """Distance of the wrapped shape evaluated at ``p`` minus the offset."""
        offset = register["array"].deref(self.offset)
        return self.shape.distance(register, p - offset)

@register_pytree_node_class
class Circle(Shape):
    """Circle whose radius is a reference into ``register["scalar"]``.

    The node has no pytree children; the radius reference is aux data.
    """

    def __init__(self, radius):
        self.radius = radius

    def __repr__(self):
        return "Circle({})".format(self.radius)

    def tree_flatten(self):
        """Return ``(children, aux_data)`` for JAX."""
        children = ()
        aux_data = {"radius": self.radius}
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild the node from the output of ``tree_flatten``."""
        return cls(*children, **aux_data)

    def distance(self, register, p):
        """Return ``length(p)`` minus the radius looked up in the register."""
        radius = register["scalar"].deref(self.radius)
        return length(p) - radius


## A most basic cost function 

In [4]:
def compute_cost(structure, _):
    """Evaluate the constraint of a problem description against its geometry.

    Args:
        structure: Mapping with the keys ``"geometry"``, ``"constraint"`` and
            ``"register"``.
        _: Unused second positional argument.

    Returns:
        The result of ``constraint.cost(register, geometry)``.
    """
    geometry, constraint, register = (
        structure[key] for key in ("geometry", "constraint", "register")
    )

    # This could be a tree of constraints
    cost = constraint.cost(register, geometry)
    # Printed each time this Python body runs, which makes re-tracing visible.
    print("tracing")
    return cost

## A very simple constraint node

In [5]:
@register_pytree_node_class
class DistanceToPointConstraint:
    """Constraint whose cost is the squared distance of a geometry at a point.

    The point is carried as pytree aux data; the node has no children.
    """

    def __init__(self, point):
        self.point = point

    def __repr__(self):
        return "DistanceToPointConstraint({})".format(self.point)

    def tree_flatten(self):
        """Return ``(children, aux_data)`` for JAX."""
        children = ()
        aux_data = {"point": self.point}
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild the node from its aux data; ``children`` is not used."""
        return cls(**aux_data)

    def cost(self, register, geometry):
        """Square of ``geometry.distance`` evaluated at the stored point."""
        # The point is wrapped in a one-row array and the first result is used.
        query = np.array([self.point])
        distances = geometry.distance(register, query)
        return distances[0] ** 2

In [6]:
@register_pytree_node_class
class ScalarRegister:
    """Indexed store of scalar values, each flagged as variable or constant.

    ``const`` and ``var`` append a value and return its index; ``deref`` reads
    a value back by index. When flattened as a pytree, variable entries become
    the children, while the constant entries and the ``is_var`` flags are
    carried in the aux data.
    """

    def __init__(self, values=None, is_var=None):
        """Create a register, empty unless ``values``/``is_var`` are given."""
        self.values = values if values is not None else []
        self.is_var = is_var if is_var is not None else []

    def __repr__(self):
        return f"ScalarRegister{self.values}"

    def tree_flatten(self):
        """Return ``(children, aux_data)``: variables vs. flags and constants."""
        variables = tuple(
            value for value, flag in zip(self.values, self.is_var) if flag
        )
        consts = tuple(
            value for value, flag in zip(self.values, self.is_var) if not flag
        )
        return variables, {"is_var": tuple(self.is_var), "consts": consts}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a register by interleaving variables and constants again."""
        flags = aux_data["is_var"]
        # Reversed copies are used as stacks so entries come off in their
        # original order with O(1) pops.
        variables = list(children)[::-1]
        constants = list(aux_data["consts"])[::-1]
        values = [variables.pop() if flag else constants.pop() for flag in flags]
        # The flags are stored as a tuple in the aux data; hand the instance a
        # list so const()/var() can still append to a rebuilt register.
        return cls(values=values, is_var=list(flags))

    def const(self, value):
        """Append ``value`` as a constant entry and return its index."""
        index = len(self.values)
        self.values.append(value)
        self.is_var.append(False)
        return index

    def var(self, value):
        """Append ``value`` as a variable entry and return its index."""
        index = len(self.values)
        self.values.append(value)
        self.is_var.append(True)
        return index

    def deref(self, ref):
        """Return the value stored at index ``ref``."""
        return self.values[ref]

@register_pytree_node_class
class ArrayRegister:
    """Indexed store of two-component arrays, flagged as variable or constant.

    ``const`` stores a NumPy array and ``var`` stores a JAX array; both return
    the index of the new entry. When flattened as a pytree, variable entries
    become the children, while the flags and the constant entries (as
    2-tuples of their first two components) are carried in the aux data.
    """

    def __init__(self, values=None, is_var=None):
        """Create a register, empty unless ``values``/``is_var`` are given."""
        self.values = values if values is not None else []
        self.is_var = is_var if is_var is not None else []

    def __repr__(self):
        return f"ArrayRegister{self.values}"

    def tree_flatten(self):
        """Return ``(children, aux_data)``: variables vs. flags and constants."""
        variables = tuple(
            value for value, flag in zip(self.values, self.is_var) if flag
        )
        consts = tuple(
            (value[0], value[1])
            for value, flag in zip(self.values, self.is_var)
            if not flag
        )
        return variables, {"is_var": tuple(self.is_var), "consts": consts}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a register by interleaving variables and constants again."""
        flags = aux_data["is_var"]
        # Reversed copies are used as stacks so entries come off in their
        # original order with O(1) pops.
        variables = list(children)[::-1]
        constants = list(aux_data["consts"])[::-1]
        values = [
            variables.pop() if flag else np.array(constants.pop())
            for flag in flags
        ]
        # The flags are stored as a tuple in the aux data; hand the instance a
        # list so const()/var() can still append to a rebuilt register.
        return cls(values=values, is_var=list(flags))

    def const(self, x, y):
        """Append ``[x, y]`` as a constant NumPy array and return its index."""
        index = len(self.values)
        self.values.append(np.array([x, y]))
        self.is_var.append(False)
        return index

    def var(self, x, y):
        """Append ``[x, y]`` as a variable JAX array and return its index."""
        index = len(self.values)
        self.values.append(jnp.array([x, y]))
        self.is_var.append(True)
        return index

    def deref(self, ref):
        """Return the array stored at index ``ref``."""
        return self.values[ref]

## Play with the framework

In [7]:
# The registers hold every numeric value; the shapes only keep the indices
# returned by const()/var().
scalar_register = ScalarRegister()
array_register = ArrayRegister()

# Constant offset (0, 0) and a variable radius starting at 2.
structure = dict(
    register=dict(scalar=scalar_register, array=array_register),
    constraint=DistanceToPointConstraint((2., 2.)),
    geometry=Translation(
        array_register.const(0, 0),
        Circle(scalar_register.var(2)),
    ),
)

We set up a basic problem with a single constraint and a geometry where we want to vary the radius of the circle so that it touches the point `(2, 2)`.

In [8]:
# First solve of this problem; compute_cost prints "tracing" when its body runs.
optimize_params(compute_cost, structure)

tracing
4 steps in 0.249 seconds


{'constraint': DistanceToPointConstraint((2.0, 2.0)),
 'geometry': Translation(
   0,
   Circle(0)
 ),
 'register': {'array': ArrayRegister[array([0, 0])],
  'scalar': ScalarRegister[Array(2.828427, dtype=float32)]}}

If you rerun this call, you can see that the cost function is not traced again.

In [9]:
# Same cost function and same structure as the previous call.
optimize_params(compute_cost, structure)

4 steps in 0.001 seconds


{'constraint': DistanceToPointConstraint((2.0, 2.0)),
 'geometry': Translation(
   0,
   Circle(0)
 ),
 'register': {'array': ArrayRegister[array([0, 0])],
  'scalar': ScalarRegister[Array(2.828427, dtype=float32)]}}

In [10]:
# Flatten only the registers: variable entries come back as leaves, while the
# flags and constants are part of the tree definition.
jax.tree.flatten(structure["register"])

([2],
 PyTreeDef({'array': CustomNode(ArrayRegister[{'is_var': (False,), 'consts': ((0, 0),)}], []), 'scalar': CustomNode(ScalarRegister[{'is_var': (True,), 'consts': ()}], [*])}))

If we change the problem so that the position of the circle varies instead of its radius, the function is traced again.

In [13]:
# Fresh registers for the second problem.
scalar_register = ScalarRegister()
array_register = ArrayRegister()

# Variable offset starting at (0, 0) and a constant radius of 2.
structure2 = dict(
    register=dict(scalar=scalar_register, array=array_register),
    constraint=DistanceToPointConstraint((2., 2.)),
    geometry=Translation(
        array_register.var(0, 0),
        Circle(scalar_register.const(2)),
    ),
)

In [14]:
# Solve the second problem, where the offset is variable and the radius is constant.
optimize_params(compute_cost, structure2)

tracing
4 steps in 0.389 seconds


{'constraint': DistanceToPointConstraint((2.0, 2.0)),
 'geometry': Translation(
   0,
   Circle(0)
 ),
 'register': {'array': ArrayRegister[Array([0.5857864, 0.5857864], dtype=float32)],
  'scalar': ScalarRegister[2]}}