In [1]:
from plyfile import PlyData
import jax 
import jax.numpy as jnp 
from jax import grad, jacfwd, vmap, jit
import polyscope as ps

from jaxlie import SO3, SE3
from jax.flatten_util import ravel_pytree

import numpy as np
import open3d as o3d 
import trimesh as tm

import meshcat
from meshcat import Visualizer
import meshcat.geometry as mc_geom
import meshcat.transformations as mc_trans

from kernel import create_kernel_matrix, RBF
import jaxopt

import time

objc[17123]: Class GLFWWindowDelegate is implemented in both /Users/ia285/miniconda3/lib/python3.10/site-packages/polyscope_bindings.cpython-310-darwin.so (0x1366bee90) and /Users/ia285/miniconda3/lib/python3.10/site-packages/open3d/cpu/pybind.cpython-310-darwin.so (0x16a286db0). One of the two will be used. Which one is undefined.
objc[17123]: Class GLFWApplicationDelegate is implemented in both /Users/ia285/miniconda3/lib/python3.10/site-packages/polyscope_bindings.cpython-310-darwin.so (0x1366bee68) and /Users/ia285/miniconda3/lib/python3.10/site-packages/open3d/cpu/pybind.cpython-310-darwin.so (0x16a286e28). One of the two will be used. Which one is undefined.
objc[17123]: Class GLFWContentView is implemented in both /Users/ia285/miniconda3/lib/python3.10/site-packages/polyscope_bindings.cpython-310-darwin.so (0x1366beee0) and /Users/ia285/miniconda3/lib/python3.10/site-packages/open3d/cpu/pybind.cpython-310-darwin.so (0x16a286e50). One of the two will be used. Which one is undefin

In [2]:
dataset = o3d.data.BunnyMesh()
plydata = PlyData.read(dataset.path)
vertices = np.vstack((
    plydata['vertex']['x'],
    plydata['vertex']['y'],
    plydata['vertex']['z']
)).T

# _min_pnt = vertices.min(axis=0)
# _max_pnt = vertices.max(axis=0)
# _mid_pnt = vertices.mean(axis=0)
# vertices = (vertices - _mid_pnt)/(_max_pnt-_min_pnt)

faces = np.array(np.vstack(plydata['face']['vertex_indices']))

mesh = tm.Trimesh(vertices=vertices, faces=faces)

ds_mesh = mesh.simplify_quadric_decimation(int(mesh.vertices.shape[0]/10))

In [None]:
# ps.init()

# ps_mesh = ps.register_surface_mesh("my mesh", ds_mesh.vertices, ds_mesh.faces)
# ps_mesh.add_vector_quantity("rand vecs", ds_mesh.vertex_normals, enabled=True)


# ps.show()

In [119]:
XP = jnp.hstack([-ds_mesh.vertex_normals, ds_mesh.vertices + ds_mesh.vertex_normals*0.1])
X  = jnp.zeros((100,6))

In [129]:
X_samples = ds_mesh.vertices+ds_mesh.vertex_normals*0.05
P_XI      = jnp.ones(X_samples.shape[0])
P_XI = P_XI/np.sum(P_XI)
h = 0.1
args = {'h' : h, 'vertices' : ds_mesh.vertices, 'normals' : ds_mesh.vertex_normals}

In [140]:
v1 = np.array([0., 0., 1.])
def phi(x, xp, h=0.01):
    w, v        = jnp.split(x,2)
    norm, root  = jnp.split(xp, 2)
    _R = SO3.exp(w)
    return jnp.exp(
        -(jnp.sum((norm+_R@v1)**2)+ np.sum((norm+root-v)**2))/h
    )

def RBF(x, xp, h=0.01):
    w, v        = jnp.split(x,2)
    norm, root  = jnp.split(xp, 2)
    return jnp.exp(
        -jnp.sum((root-v)**2)/h
    ) * jnp.exp(
        -0.01*jnp.sum((norm-w)**2)/h
    )
    # return jnp.exp(
    #     -jnp.sum((x-xp)**2)/h
    # )

In [141]:
phi_vmapped = vmap(vmap(phi, in_axes=(0,None,None)),in_axes=(None,0,None))

In [142]:
KernelMatrix = create_kernel_matrix(RBF)

x0 =  X_samples.min(axis=0)
xf =  X_samples.max(axis=0)
T = 100
X_init = np.linspace(x0, xf, num=T, endpoint=True)
V_init = np.linspace(x0, xf, num=T, endpoint=True)

X = jnp.hstack([V_init, X_init])

sol, unflatten_X = ravel_pytree(X)
bounds = (-np.inf * np.ones_like(sol), np.inf * np.ones_like(sol))

def ergodic_mmd(flat_X, args):
    X_init = unflatten_X(flat_X)
    w, v        = jnp.split(X_init,2, axis=1)

    T = X_init.shape[0]
    h = args['h']
    # x0 = args['x0']
    # vertices = args['vertices']
    # norms    = args['normals']
    # XP = jnp.hstack([norms, vertices])
    # P_XI      = jnp.sum(phi_vmapped(X_init, XP, h), axis=1)
    return np.sum(KernelMatrix(X_init, X_init, h))/(T**2) \
            - 2 * np.sum(P_XI @ KernelMatrix(X_init, XP, h))/T \
            + np.mean((v[1:]-v[:-1])**2)
            # + np.sum((X_init[0]-x0)**2)

solver = jaxopt.ProjectedGradient(fun=ergodic_mmd, projection=jaxopt.projection.projection_box, tol=1e-6)
solver_state = solver.init_state(sol, hyperparams_proj=bounds)
# sol = solver.run(init_params=sol, 
#                         hyperparams_proj=bounds,  
#                         args=args).params

X = unflatten_X(sol)

In [144]:
ps.init()

ps_bunny_mesh = ps.register_surface_mesh("bunny", ds_mesh.vertices, ds_mesh.faces)
ps_bunny_mesh.set_transparency(0.8)
ps_traj         = ps.register_curve_network("trajectory", X_init * 1.25, edges="line")

for _ in range(1000):
    (sol, solver_state) = solver.update(
                        params=sol, 
                        state=solver_state, 
                        hyperparams_proj=bounds, 
                                args=args)
    X = unflatten_X(sol)
    ps_traj.update_node_positions(X[:,3:])
    ps_traj.add_vector_quantity("vec img", X[:,:3])
    ps.frame_tick()

    time.sleep(0.001)

KeyboardInterrupt: 

In [23]:
ps.init()

ps_bunny_mesh = ps.register_surface_mesh("bunny", ds_mesh.vertices, ds_mesh.faces)
ps_bunny_mesh.set_transparency(0.8)
# ps_bunny_mesh.add_scalar_quantity('info_distr', mesh_func.func_vals, defined_on='vertices', cmap='blues')

ps_traj         = ps.register_curve_network("trajectory", X * 1.25, edges="line")

ps.show()

KeyboardInterrupt: 

In [108]:
X[:,:3]

Array([[-1.2171284 , -1.1115862 , -1.1009053 ],
       [-1.2152026 , -0.95313704, -0.7008316 ],
       [-0.95265466, -0.7279403 , -1.0201361 ],
       [-0.42642927, -0.5275799 , -0.7298151 ],
       [-1.0408267 , -0.7709126 , -0.4051031 ],
       [-0.74579436, -0.26336494, -0.6057465 ],
       [-0.44910622, -0.7464509 , -0.4842489 ],
       [-0.57261825, -0.31917834, -0.75050336],
       [-0.8659484 , -0.31514153, -0.37906358],
       [-0.57987833, -0.8105865 , -0.03539886],
       [-0.9161823 , -0.09556733, -0.38340583],
       [-0.70823085, -0.6259475 , -0.31731534],
       [-0.49622777, -0.83794636, -0.21826875],
       [ 0.14867975, -0.31548867, -0.9335648 ],
       [-0.95671546, -0.18534584, -0.21623643],
       [-0.5573607 , -0.16308616, -0.339997  ],
       [-0.41934413, -0.0366478 , -0.9036501 ],
       [-0.30767146, -0.2785412 , -0.10740983],
       [-0.69494236, -0.67872554,  0.22421643],
       [ 0.21281777, -0.1129107 , -0.9673629 ],
       [-0.9792235 ,  0.00802065, -0.187

In [22]:
def lie_RBF(x, xp, h=0.01):
    w, v        = jnp.split(x,2)
    norm, root  = jnp.split(xp, 2)
    _R = SO3.exp(w)
    return jnp.exp(
        -(jnp.sum((norm+_R@v1)**2) + np.sum((norm+root-v)**2))/h
    )

def create_kernel_matrix(kernel):
    return vmap(vmap(kernel, in_axes=(0, None, None)), in_axes=(None, 0, None))