In [1]:
import sys 
sys.path.append('../')
import jax 
import jax.numpy as jnp 
import jaxlie
import numpy as np
from jax import vmap
import jax.random as jax_random
from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver
import adam
from adam.jax import KinDynComputations

from plyfile import PlyData
import trimesh as tm

import matplotlib.pyplot as plt 
from IPython.display import clear_output
import polyscope as ps

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
# --- Load the mesh -----------------------------------------------------------
plydata = PlyData.read('../assets/sphere.ply')
verts = np.vstack((
    plydata['vertex']['x'],
    plydata['vertex']['y'],
    plydata['vertex']['z']
)).T
# Stack the per-face index lists into a 2-D integer index array
# (np.vstack already returns an ndarray, so no extra np.array wrap is needed).
faces = np.vstack(plydata['face']['vertex_indices'])

mesh = tm.Trimesh(vertices=verts, faces=faces)

# --- Configuration / randomness ---------------------------------------------
# Seed *before* any random draw so the whole cell is reproducible.
h = 0.01      # RBF kernel bandwidth, passed to emmd_loss via args['h']
seed = 0
key = jax_random.PRNGKey(seed)

num_points = 100  # number of target samples drawn on the mesh surface
# NOTE: the `seed` keyword needs a trimesh release whose sample_surface accepts it.
points, face_indices = tm.sample.sample_surface(mesh, num_points, seed=seed)

# --- Target distribution ------------------------------------------------------
# Uniform information density over the sampled points, normalised to sum to 1.
info_distr = lambda x: 1.0
P_XI = vmap(info_distr, in_axes=(0,))(points)
P_XI = P_XI / jnp.sum(P_XI)

# One standard-normal scalar per sample, used to push each sample along the
# normal of the face it was drawn from.
# NOTE(review): unit-variance offsets may be large relative to h=0.01 and the
# mesh scale, and P_XI is evaluated on the unperturbed `points` — confirm intent.
key, subkey = jax_random.split(key, 2)
_std = jax_random.normal(subkey, shape=(num_points, 1))
_points = points + _std * mesh.face_normals[face_indices]
args = {'h': h, 'points': _points, 'P_XI': P_XI}

In [3]:
def RBF_kernel(x, xp, h=0.01):
    """Gaussian (RBF) similarity ``exp(-||x - xp||^2 / h)`` between two points.

    Args:
        x, xp: points of matching shape.
        h: bandwidth; smaller values make the kernel more local.

    Returns:
        Scalar JAX array, in [0, 1] for h > 0.
    """
    sq_dist = jnp.sum((x - xp) ** 2)
    return jnp.exp(-sq_dist / h)


def create_kernel_matrix(kernel):
    """Lift a pairwise kernel ``kernel(x, y, h)`` to a Gram-matrix function.

    The returned callable ``K(A, B, h)`` evaluates the kernel for every pair of
    rows and has shape ``(len(B), len(A))``:  ``K(A, B, h)[j, i] == kernel(A[i], B[j], h)``.
    """
    # Inner map walks the rows of the first argument (becomes the last axis) ...
    over_first = vmap(kernel, in_axes=(0, None, None))
    # ... outer map walks the rows of the second argument (becomes the first axis).
    return vmap(over_first, in_axes=(None, 0, None))

KernelMatrix = create_kernel_matrix(RBF_kernel)
def emmd_loss(params, args):
    """Ergodic-MMD objective plus a smoothness penalty on the trajectory.

    Computes::

        (1/T^2) * sum_{i,k} k(x_i, x_k)
        - (2/T) * sum_t sum_j P_XI[j] * k(x_t, p_j)
        + mean((x_{t+1} - x_t)^2)

    where ``k`` is ``KernelMatrix``'s kernel with bandwidth ``args['h']``.

    Args:
        params: dict with ``'X'``, the trajectory indexed along axis 0
            (presumably ``(T, d)`` waypoints — TODO confirm).
        args: dict with ``'h'`` (kernel bandwidth), ``'points'`` (target samples)
            and ``'P_XI'`` (one weight per target sample).

    Returns:
        Scalar JAX array.
    """
    X = params['X']
    T = X.shape[0]
    h = args['h']
    points = args['points']
    P_XI = args['P_XI']

    # Use jnp (not np) so every reduction is a JAX op under jit/grad tracing.
    self_term = jnp.sum(KernelMatrix(X, X, h)) / (T ** 2)
    # KernelMatrix(X, points, h) has shape (len(points), T); P_XI @ (...) gives
    # the P_XI-weighted kernel mean for each waypoint.
    cross_term = 2 * jnp.sum(P_XI @ KernelMatrix(X, points, h)) / T
    # The smoothness penalty needs at least two waypoints; the mean over an empty
    # difference array would otherwise be NaN. T is a static shape, so a Python
    # branch is trace-safe.
    if T > 1:
        smooth_term = jnp.mean(jnp.square(X[1:] - X[:-1]))
    else:
        smooth_term = 0.0
    return self_term - cross_term + smooth_term

def eq_constr(params, args):
    """Equality-constraint residual for the solver.

    Placeholder: no equality constraints are imposed, so a constant zero is
    returned regardless of ``params`` and ``args``.
    """
    zero_residual = jnp.array(0.)
    return zero_residual

def ineq_constr(params, args):
    """Inequality-constraint residual for the solver.

    Placeholder: no inequality constraints are imposed, so a constant zero is
    returned. ``params`` and ``args`` are accepted only to match the solver's
    callback signature. The unused ``params['X']`` lookup was removed, so missing
    keys no longer raise.
    """
    return jnp.array(0.)


In [4]:
# Initial guess: T waypoints evenly spaced on the straight line from
# mesh.bounds[0] to mesh.bounds[1] (trimesh's bounding-box min/max corners).
T = 100
lo_corner, hi_corner = mesh.bounds
X = jnp.linspace(lo_corner, hi_corner, num=T)
params = {'X': X}

solver = AugmentedLagrangeSolver(
    params,
    emmd_loss,
    eq_constr,
    ineq_constr,
    max_stepsize=1e-1,
    args=args,
)


In [5]:
# Run the augmented-Lagrangian optimisation.
# NOTE(review): the saved run was stopped manually; consider %%time or an iteration cap.
SOLVE_EPS = 1e-4  # forwarded as `eps` (presumably the convergence tolerance — confirm in ergodic_mmd)
solver.solve(eps=SOLVE_EPS)

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 1.3004464870189751e-05 Stepsize:0.1  Decrease Error:0.0  Curvature Error:1.3004464870189751e-05 
0 0.01152769139982022
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 1.2942871182327355e-05 Stepsize:0.1  Decrease Error:0.0  Curvature Error:1.2942871182327355e-05 
1 0.011497821930373514
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 1.2882954208821314e-05 Stepsize:0.1  Decrease Error:0.0  Curvature Error:1.2882954208821314e-05 
2 0.011468644891253522
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 1.2824695585442186e-05 Stepsize:0.1  Decrease Error:0.0  Curvature Error:1.2824695585442186e-05 
3 0.01144015620729404
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 1.2768077683614097e-05 Stepsize:0.1  Decrease Error:0.0  Curvature Error:1.2768077

KeyboardInterrupt: 

In [6]:
# Fetch the solver's result; the visualisation cell indexes it with 'X', like `params`.
sol = solver.solution

In [1]:
# Visualise the optimised trajectory on the mesh with polyscope.
# Depends on `sol` from the previous cell — run the notebook top to bottom.
X = sol['X']

ps.init()

# NOTE(review): structure is labelled "bunny" but the mesh loaded above is sphere.ply.
ps_bunny_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)

# Hand polyscope a plain NumPy array rather than a JAX array.
ps_traj = ps.register_curve_network("trajectory", np.asarray(X), edges="line")

ps.show()

NameError: name 'sol' is not defined

In [None]:
# Compact look at the trajectory: its shape and the first few waypoints,
# instead of dumping every row.
X.shape, X[:5]