In [238]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
from typing import *
from sdf_world.sdf_world import *
from sdf_world.sparse_ipopt import *

In [23]:
from jax._src import linear_util as lu
from jax._src.api import _std_basis, _jvp, _jacfwd_unravel
from jax._src.api_util import argnums_partial
from functools import partial

def value_and_jacfwd(fun: Callable, argnums=0):
    """Like `jax.jacfwd`, but the returned function also yields the primal value.

    Args:
        fun: function to differentiate. It must return an array (or a pytree of
            arrays). The arguments selected by `argnums` must be floating-point,
            as `jax.jacfwd` requires.
        argnums: int or tuple of ints selecting the positional arguments to
            differentiate with respect to (same meaning as in `jax.jacfwd`).

    Returns:
        A function with the same signature as `fun`. It returns
        `(fun(*args, **kwargs), jacobian)`, where `jacobian` has the structure
        `jax.jacfwd(fun, argnums)` would produce.
    """
    # The primal output is routed through `has_aux`, so value and Jacobian come
    # out of the same forward-mode trace. Only public JAX API is used here; the
    # `jax._src` helpers and `jax.tree_map` are private or deprecated and break
    # across JAX releases.
    def _with_value_as_aux(*args, **kwargs):
        out = fun(*args, **kwargs)
        return out, out

    _jac_and_value = jax.jacfwd(_with_value_as_aux, argnums=argnums, has_aux=True)

    def value_and_jacfwd_f(*args, **kwargs):
        jac, value = _jac_and_value(*args, **kwargs)
        return value, jac

    return value_and_jacfwd_f

In [7]:
# Create the SDF world. Its visualizer handle `world.vis` is passed to every
# Frame/Sphere object below; constructing it prints the visualizer URL.
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/


In [8]:
# Embed the visualizer in the notebook output (presumably the same viewer that
# is served at the URL printed when `world` was created — confirm).
world.show_in_jupyter()

In [9]:
# Coordinate-frame marker; `draw` sets its pose to the current rotation.
frame = Frame(world.vis, "frame")

In [10]:
# p1/p2 (default colour) mark the rotated points; p3/p4 (blue) mark the
# reference points. 0.02 is presumably the sphere radius — confirm in Sphere.
p1, p2 = (Sphere(world.vis, name, 0.02) for name in ("p1", "p2"))
p3, p4 = (Sphere(world.vis, name, 0.02, "blue") for name in ("p3", "p4"))
ps = [p1, p2, p3, p4]

In [232]:
def residual(ec, rot_points, ref_points):
    """Flattened differences between the rotated points and their references.

    Each row of `rot_points` is rotated by `SO3.exp(ec)`, then the matching row
    of `ref_points` is subtracted. The per-point differences are concatenated
    into one vector; for the (2, 3) point arrays used in this notebook it has
    length 6.
    """
    rotation = SO3.exp(ec)
    rotate_rows = jax.vmap(rotation.apply)
    diffs = rotate_rows(rot_points) - ref_points
    return jnp.hstack(diffs)

# Residual value together with its forward-mode Jacobian with respect to `ec`.
vg_residual_fn = value_and_jacfwd(residual, argnums=0)

In [227]:
# Sampling half-widths: rotated points lie in [-l, l]^3, reference points in [-L, L]^3.
l, L = 0.1, 0.4
locations = []  # NOTE(review): unused in the visible cells; kept in case later cells rely on it
# Seeded generator so that Restart & Run All reproduces the same test problem.
SEED = 0
points_rng = np.random.default_rng(SEED)
rot_points = points_rng.uniform(-l, l, size=(2, 3))
ref_points = points_rng.uniform(-L, L, size=(2, 3))

def draw(ec):
    """Show the rotation `SO3.exp(ec)` and both point sets in the visualizer.

    Moves `frame` to the rotation, places spheres p1/p2 at the rotated
    `rot_points`, and places the blue spheres p3/p4 at `ref_points`.

    Args:
        ec: 3-vector passed to `SO3.exp`.
    """
    rot = SO3.exp(ec)
    frame.set_pose(SE3.from_rotation(rot))
    for sphere, point in zip((p1, p2), rot_points):
        sphere.set_translate(rot.apply(point))
    for sphere, point in zip((p3, p4), ref_points):
        sphere.set_translate(point)

In [228]:
# Initial guess: a zero vector for `SO3.exp`. Render that starting configuration.
ec = jnp.zeros(shape=(3,))
draw(ec)

In [192]:
from typing import NamedTuple

class Carry(NamedTuple):
    """Loop state for the damped Gauss-Newton rotation fit run by `jax.lax.while_loop`."""

    rot_points: jnp.ndarray            # points rotated by SO3.exp(ec)
    ref_points: jnp.ndarray            # target points compared against the rotated ones
    ec: jnp.ndarray = jnp.zeros(3)     # current rotation estimate (argument to SO3.exp)
    d: jnp.ndarray = jnp.zeros(3)      # most recent step added to `ec`
    i: int = 0                         # number of steps taken so far
    damping: float = 0.04              # added to the diagonal of J^T J in each step
    threshold: float = 1e-4            # step-norm convergence threshold
    max_iter: int = 10                 # iteration limit

    def update(self, d):
        """Return a new Carry advanced by step `d`.

        Only `ec`, `d` and `i` change. Every other field, including
        `damping`, `threshold` and `max_iter`, is carried over unchanged, so
        custom settings are not reset to the class defaults.
        """
        return self._replace(ec=self.ec + d, d=d, i=self.i + 1)

In [209]:
def get_rot_body(carry:Carry):
    """Take one damped Gauss-Newton step on the rotation parameters.

    Solves (J^T J + damping * I) step = -J^T r, where r and J are the residual
    and its Jacobian with respect to `carry.ec`. Returns the carry advanced by
    `step` via `Carry.update`.
    """
    r, J = vg_residual_fn(carry.ec, carry.rot_points, carry.ref_points)
    damped_normal = J.T @ J + carry.damping * jnp.eye(3)
    step = jnp.linalg.solve(damped_normal, -J.T @ r)
    return carry.update(step)
    
def get_rot_cond(carry: "Carry"):
    """Continue-condition for the `jax.lax.while_loop` rotation fit.

    Iteration continues while the last step is larger than `carry.threshold`
    AND fewer than `carry.max_iter` steps have been taken, so `max_iter` is a
    hard cap. The first iteration (i == 0) always runs, because the initial
    step `d` defaults to zero and would otherwise look converged.

    Returns:
        A scalar boolean array.
    """
    first_step = carry.i == 0
    still_moving = jnp.linalg.norm(carry.d) > carry.threshold
    under_limit = carry.i < carry.max_iter
    return (first_step | still_moving) & under_limit

In [229]:
# Run the fit from Carry's default initial state and render the resulting rotation.
carry = Carry(rot_points=rot_points, ref_points=ref_points)
result = jax.lax.while_loop(cond_fun=get_rot_cond, body_fun=get_rot_body, init_val=carry)
draw(result.ec)

In [231]:
def residual_sumsqr(ec, rot_points, ref_points):
    """Half the squared norm of `residual(ec, rot_points, ref_points)`."""
    r = residual(ec, rot_points, ref_points)
    return 0.5 * (r.T @ r)

In [237]:
# Gradient of the half sum of squares with respect to `ec` (the first argument);
# call as grad_residual_sumsqr(ec, rot_points, ref_points).
grad_residual_sumsqr = jax.grad(residual_sumsqr, argnums=0)

In [234]:
# NOTE(review): this displays the initial guess from the `ec = jnp.zeros(3)` cell,
# not the optimized rotation, which is stored in `result.ec`.
ec

Array([0., 0., 0.], dtype=float32)

In [40]:
class cyipopt