In [1]:
import sys 
sys.path.append('../')
import jax 
import jax.numpy as jnp 
import numpy as np
from jax.random import normal
from jax import grad, hessian, vmap, pmap
from jax.flatten_util import ravel_pytree
from jax.lax import scan
from functools import partial
import jax.random as jax_random
import matplotlib.pyplot as plt 
from IPython.display import clear_output
from plyfile import PlyData
import open3d as o3d 
import trimesh as tm
from trimesh.points import tsp

import jaxopt
import polyscope as ps

from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def RBF_kernel(x, xp, h=0.01):
    """Gaussian (RBF) kernel between two points.

    Args:
        x, xp: points of matching shape.
        h: bandwidth; larger values make the kernel decay more slowly with distance.

    Returns:
        Scalar ``exp(-||x - xp||^2 / h)``.
    """
    diff = x - xp
    sq_dist = jnp.sum(jnp.square(diff))
    return jnp.exp(-sq_dist / h)
def create_kernel_matrix(kernel):
    """Lift a pointwise ``kernel(x, xp, h)`` to a batched kernel-matrix function.

    The returned callable takes ``(xs, xps, h)`` and yields a matrix ``K`` with
    ``K[j, i] == kernel(xs[i], xps[j], h)``. The leading axis indexes ``xps`` and
    the trailing axis indexes ``xs``; ``h`` is shared across all pairs.
    """
    over_x = vmap(kernel, in_axes=(0, None, None))
    over_x_and_xp = vmap(over_x, in_axes=(None, 0, None))
    return over_x_and_xp

In [27]:
# Load the target surface and draw the points the trajectory should cover.
# NOTE(review): in the recorded run this line raised `KeyError: 'msh'`,
# presumably because this trimesh install has no loader registered for `.msh`
# (Gmsh) files. Later cells all read `mesh`, so the notebook does not survive
# Restart & Run All as written.
# TODO: convert the asset to a format trimesh can load, or load it through a
# Gmsh-capable path, before re-running the downstream cells.
mesh = tm.load_mesh("../assets/windturbine_structure.msh")
# mesh = tm.Trimesh(vertices=verts, faces=faces)
# mesh.apply_scale(10)

num_points = 200  # number of surface samples used as kernel target points
# NOTE(review): no seed is passed, so the samples presumably differ between
# runs; consider seeding for reproducibility.
points, face_indices = tm.sample.sample_surface(mesh, num_points)

# traversal, distances = tsp(mesh.vertices)

# Uniform information distribution: every sample gets the same weight,
# then the weights P_XI are normalized to sum to 1.
info_distr = lambda x: 1.0
P_XI = vmap(info_distr, in_axes=(0,))(points)
P_XI = P_XI/jnp.sum(P_XI)
# Kernel bandwidth; emmd_loss forwards args['h'] to RBF_kernel.
h = 0.1
# Targets are the samples pushed 0.025 along their face normals (presumably to
# keep the trajectory off the surface; confirm the intended clearance/units).
args = {'h' : h, 'points' : points+0.025*mesh.face_normals[face_indices], 'P_XI' : P_XI}

KeyError: 'msh'

In [20]:
KernelMatrix = create_kernel_matrix(RBF_kernel)


def emmd_loss(params, args):
    """Squared-MMD objective for a discrete trajectory plus a smoothness penalty.

    Computes ``mean_{i,k} k(x_i, x_k) - 2/T * sum_j P_XI[j] sum_i k(x_i, p_j)``
    (squared MMD without the constant points-points term) plus a penalty on
    consecutive waypoint steps.

    Args:
        params: dict with ``'X'``, the waypoint array whose first axis is time
            (the initial guess in this notebook is ``(T, 3)``).
        args: dict with ``'h'`` (kernel bandwidth), ``'points'`` (target points)
            and ``'P_XI'`` (weights over ``points``).

    Returns:
        Scalar loss.
    """
    X = params['X']
    T = X.shape[0]
    h = args['h']
    points = args['points']
    P_XI = args['P_XI']
    # Use jnp (not np) throughout: this function is traced by JAX, and NumPy
    # reductions only work on tracers by duck-typed dispatch to `.sum`.
    self_term = jnp.sum(KernelMatrix(X, X, h)) / (T**2)
    # KernelMatrix(X, points, h) has shape (len(points), T): rows index points,
    # so the weights P_XI contract against the leading axis.
    cross_term = 2 * jnp.sum(P_XI @ KernelMatrix(X, points, h)) / T
    # Penalize large steps between consecutive waypoints (weight 2 is a tuning constant).
    smooth_term = 2 * jnp.mean(jnp.square(X[1:] - X[:-1]))
    return self_term - cross_term + smooth_term

def eq_constr(params, args):
    """Equality-constraint callback for the solver; currently a constant placeholder.

    Both arguments are ignored; the signature mirrors ``emmd_loss(params, args)``.
    Returns the scalar ``jnp.array(0.)``.
    """
    del params, args  # unused
    zero = jnp.array(0.)
    return zero

def ineq_constr(params, args):
    """Inequality-constraint callback for the solver; currently a constant placeholder.

    Args:
        params: decision-variable dict (same structure ``emmd_loss`` receives); unused.
        args: problem-data dict; unused.

    Returns:
        The scalar ``jnp.array(0.)`` regardless of the inputs.
    """
    return jnp.array(0.)


In [21]:
T = 500  # number of trajectory waypoints

# Initial guess: T waypoints evenly spaced on the straight line from
# mesh.bounds[0] to mesh.bounds[1] (the bounding-box diagonal).
X = jnp.linspace(*mesh.bounds, num=T)

params = dict(X=X)

solver = AugmentedLagrangeSolver(params, emmd_loss, eq_constr, ineq_constr, args=args)


In [22]:
# Run the augmented-Lagrangian optimization.
# NOTE(review): `eps` is presumably the convergence tolerance of
# AugmentedLagrangeSolver; confirm against its source.
solver.solve(eps=1e-4)

0 0.043504893980498276
1 0.04312788739230472
2 0.042242623769351624
3 0.04021909177118872
4 0.03574719736004658
5 0.026344919522577996
6 0.012936463114031528
7 0.009139310508476375
8 0.009529863605251692
9 0.007210860288795055
10 0.004359299199463897
11 0.004007207019870438
12 0.0036348482347789974
13 0.0032445329976897824
14 0.0026148654200426652
15 0.0022821588358991
16 0.00197307656976419
17 0.0018206682241725164
18 0.0015791751349817412
19 0.0013118955410145676
20 0.0012427933584157086
21 0.001153552858674567
22 0.0011650224383109558
23 0.0010520169222306606
24 0.0009445923948159367
25 0.001303052715963851
26 0.0009254173507695908
27 0.0007247476399851134
28 0.0007478495161108236
29 0.0010897562614875424
30 0.0008424202274733107
31 0.0006417259307417839
32 0.0006955548575823953
33 0.0006819009308838727
34 0.000691249956137817
35 0.0006991426593163815
36 0.0006606958803526615
37 0.0005934270554403217
38 0.0005955249211159914
39 0.0005730393238671126
40 0.0005697452652613789
41 0.000

In [23]:
# Optimized decision variables; the visualization cell reads the trajectory as sol['X']
# (presumably this mirrors the structure of `params`).
sol = solver.solution

In [24]:
# Baseline path: an ordering of the target points from trimesh's `tsp` helper,
# drawn as "tsp_traj" in the visualization cell. `distances` is not used in the visible cells.
traversal, distances = tsp(args['points'])

In [25]:
# Scratch check: the mesh bounds scaled by 0.01.
0.01 * mesh.bounds

array([[-0.0012454 ,  0.        , -0.00623514],
       [ 0.00131869,  0.02030674,  0.00604418]])

In [26]:
# Visualize the optimized trajectory against the mesh, the target points,
# and a TSP-ordered baseline path through the same targets.
X = sol['X']

ps.init()

# Surface being covered (loaded from the wind-turbine asset in the data cell).
ps_mesh = ps.register_surface_mesh("mesh", mesh.vertices, mesh.faces)
# Target points the loss pulls the trajectory toward.
ps_target_pcl = ps.register_point_cloud("target points", args['points'])

# Optimized trajectory; edges="line" connects consecutive waypoints in order.
ps_traj = ps.register_curve_network("trajectory", X, edges="line")

# Baseline: the target points visited in the order returned by trimesh's tsp().
ps_tsp_traj = ps.register_curve_network("tsp_traj", args['points'][traversal], edges="line")

ps.show()

In [20]:
# Inspect the axis-aligned bounds of the loaded mesh: row 0 is the min corner, row 1 the max corner.
mesh.bounds

array([[-0.12454 ,  0.      , -0.623514],
       [ 0.131869,  2.030674,  0.604418]])