In [1]:
import sys 
sys.path.append('../')
import jax 
import jax.numpy as jnp 
import numpy as np
from jax.random import normal
from jax import grad, hessian, vmap, pmap
from jax.flatten_util import ravel_pytree
from jax.lax import scan
from functools import partial
import jax.random as jax_random
import matplotlib.pyplot as plt 
from IPython.display import clear_output
from plyfile import PlyData
import open3d as o3d 
import trimesh as tm
from trimesh.points import tsp

import jaxopt
import polyscope as ps

from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def RBF_kernel(x, xp, h=0.01):
    """Gaussian (RBF) kernel between two points.

    Computes exp(-||x - xp||^2 / h). `h` divides the squared distance
    directly (there is no factor of 2), so it plays the role of 2*sigma^2.

    Args:
        x, xp: points of matching shape (e.g. 3-vectors).
        h: bandwidth; smaller values give a narrower kernel.

    Returns:
        Scalar kernel value in (0, 1] for h > 0 (up to floating-point underflow).
    """
    sq_dist = jnp.sum(jnp.square(x - xp))
    return jnp.exp(-sq_dist / h)
def create_kernel_matrix(kernel):
    """Lift a pointwise kernel k(x, xp, h) to a batched Gram-matrix function.

    The returned function has signature ``(X, Y, h) -> K``. X and Y are
    batches of points along axis 0, and h is shared (not mapped). The outer
    map runs over Y and the inner map over X, so
    ``K[j, i] == kernel(X[i], Y[j], h)`` and ``K.shape == (len(Y), len(X))``.
    """
    over_x = vmap(kernel, in_axes=(0, None, None))
    over_y_then_x = vmap(over_x, in_axes=(None, 0, None))
    return over_y_then_x

In [19]:
# Load the bunny mesh from Open3D's dataset module and build a trimesh mesh.
dataset = o3d.data.BunnyMesh()
plydata = PlyData.read(dataset.path)
verts = np.vstack((
    plydata['vertex']['x'],
    plydata['vertex']['y'],
    plydata['vertex']['z'],
)).T  # (V, 3)
# Each face record holds an index array; stacking gives an (F, k) int array
# (k presumably 3 for this triangle mesh -- TODO confirm).
faces = np.vstack(plydata['face']['vertex_indices'])

MESH_SCALE = 2          # uniform scale applied to the raw bunny
SAMPLE_SEED = 0         # fixes the surface sample so Restart & Run All reproduces results
NORMAL_OFFSET = 0.025   # distance to push each target point off the surface along its face normal

mesh = tm.Trimesh(vertices=verts, faces=faces)
mesh.apply_scale(MESH_SCALE)

# Target points: random samples on the mesh surface, plus the index of the
# face each sample was drawn from (used below to look up its normal).
num_points = 200
points, face_indices = tm.sample.sample_surface(mesh, num_points, seed=SAMPLE_SEED)

# Information distribution over the target points: uniform, normalised to sum to 1.
info_distr = lambda x: 1.0
P_XI = vmap(info_distr, in_axes=(0,))(points)
P_XI = P_XI / jnp.sum(P_XI)
h = 0.005  # RBF bandwidth passed to the kernel by emmd_loss (overrides RBF_kernel's default)
args = {
    'h': h,
    'points': points + NORMAL_OFFSET * mesh.face_normals[face_indices],
    'P_XI': P_XI,
}

In [20]:
KernelMatrix = create_kernel_matrix(RBF_kernel)


def emmd_loss(params, args):
    """Ergodic-MMD objective for a discrete trajectory.

    With T = len(X), target points p_j and weights P_j, this computes

        (1/T^2) * sum_{i,k} k(x_i, x_k)
        - (2/T) * sum_{i,j} P_j k(x_i, p_j)
        + w * mean((x_{t+1} - x_t)^2)   (averaged over all coordinates)

    The first two terms are the squared MMD between the empirical trajectory
    measure and the weighted target measure, minus the X-independent
    target-target term. The last term penalises large steps between
    consecutive waypoints.

    Args:
        params: dict with 'X', an array of T waypoints along axis 0
            (presumably (T, 3) here -- TODO confirm).
        args: mapping with 'h' (kernel bandwidth), 'points' (N target points),
            'P_XI' (N weights), and optional 'smooth_weight' (default 2.0).

    Returns:
        Scalar loss.
    """
    X = params['X']
    T = X.shape[0]
    h = args['h']
    points = args['points']
    P_XI = args['P_XI']
    smooth_weight = args.get('smooth_weight', 2.0)

    # Use jnp (not np) reductions: this loss is presumably traced and
    # differentiated by the JAX-based solver, so stay inside JAX.
    self_term = jnp.sum(KernelMatrix(X, X, h)) / (T ** 2)
    # KernelMatrix(X, points) has shape (N, T); P_XI @ it -> (T,) weighted sums.
    cross_term = 2 * jnp.sum(P_XI @ KernelMatrix(X, points, h)) / T
    smoothness = smooth_weight * jnp.mean(jnp.square(X[1:] - X[:-1]))
    return self_term - cross_term + smoothness

def eq_constr(params, args):
    """Equality-constraint function handed to the solver (placeholder).

    Ignores its inputs and always returns the scalar 0.0, so it imposes no
    dependence on the trajectory; it exists to fill the solver's constraint slot.
    """
    zero = jnp.array(0.)
    return zero

def ineq_constr(params, args):
    """Inequality-constraint function handed to the solver (placeholder).

    Returns the scalar 0.0 regardless of `params`/`args`, so the constraint
    value never depends on the trajectory.
    TODO: replace with a real constraint on params['X'] if one is needed.
    """
    return jnp.array(0.)


In [21]:
# Initial guess: T waypoints evenly spaced on the segment from mesh.bounds[0]
# to mesh.bounds[1] (presumably the min/max corners of the mesh's
# axis-aligned bounding box).
T = 500
X = jnp.linspace(mesh.bounds[0], mesh.bounds[1], num=T)
params = dict(X=X)

solver = AugmentedLagrangeSolver(
    params,
    emmd_loss,
    eq_constr,
    ineq_constr,
    args=args,
)


In [22]:
# Run the augmented-Lagrangian optimisation. `eps` is presumably the stopping
# tolerance; the exact criterion lives in AugmentedLagrangeSolver (not shown).
# The printed "<iter> <value>" lines are presumably a per-iteration
# convergence measure -- TODO confirm against the solver source.
solver.solve(eps=1e-4)

0 0.07709807230370763
1 0.07485473313062958
2 0.08305244951693054
3 0.06721189341924351
4 0.04262858741614743
5 0.03306376631091034
6 0.024941022484519612
7 0.023004378521831667
8 0.014867215498917985
9 0.012538333558766388
10 0.008473451951043301
11 0.006785202559427398
12 0.006147109766439859
13 0.003863132432229546
14 0.0037515849876878873
15 0.00305579477684015
16 0.0038511308773333847
17 0.0021402466171832287
18 0.0020938313288338373
19 0.0021689064857600324
20 0.0025422015430923016
21 0.0018183639172000602
22 0.0015960255942373308
23 0.001506179546990802
24 0.0017297882598857558
25 0.0014440408406909513
26 0.0019552968383680677
27 0.0012674391295229965
28 0.001272206426638075
29 0.0012775424284821363
30 0.0016100701618236752
31 0.0011600829292179678
32 0.001157904668960654
33 0.0010586051633030794
34 0.0008643603554284957
35 0.001036033786814931
36 0.0009837621760655014
37 0.0012365374935777655
38 0.000916091410607164
39 0.0008814763073950943
40 0.0008583186067441989
41 0.0007744

In [23]:
# Optimised parameters; expected to mirror `params` (a dict with key 'X'),
# since the plotting cell reads sol['X'].
sol = solver.solution

In [27]:
# Baseline: order the normal-offset target points into a tour with trimesh's
# TSP heuristic. `traversal` is used as an index array into args['points'] to
# draw the "tsp_traj" curve; `distances` is not used in the visible cells.
traversal, distances = tsp(args['points'])

In [31]:
# Overlay the optimised trajectory and the TSP baseline tour on the bunny
# mesh and its target points in an interactive polyscope window.
X = sol['X']
# polyscope's Python bindings take numpy arrays; convert the JAX result explicitly
# rather than relying on implicit array conversion in the C++ bindings.
traj_nodes = np.asarray(X)
target_points = np.asarray(args['points'])
tsp_nodes = target_points[traversal]

ps.init()

ps_bunny_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
ps_bunny_pcl = ps.register_point_cloud("bunny points", target_points)
ps_traj = ps.register_curve_network("trajectory", traj_nodes, edges="line")
ps_tsp_traj = ps.register_curve_network("tsp_traj", tsp_nodes, edges="line")

ps.show()