In [1]:
import sys 
sys.path.append('../')
import jax 
import jax.numpy as jnp 
import jaxlie
import numpy as np
from jax import vmap
import jax.random as jax_random
from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver
import adam
from adam.jax import KinDynComputations

from plyfile import PlyData
import trimesh as tm

import matplotlib.pyplot as plt 
from IPython.display import clear_output
import polyscope as ps
import time

I0000 00:00:1726107909.580960       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [2]:
plydata = PlyData.read('../assets/sphere.ply')
verts = np.vstack((
    plydata['vertex']['x'],
    plydata['vertex']['y'],
    plydata['vertex']['z']
)).T
faces = np.vstack(plydata['face']['vertex_indices'])

mesh = tm.Trimesh(vertices=verts, faces=faces)

# RBF kernel bandwidth, forwarded to emmd_loss through args['h'].
h = 0.1
seed = 0
key = jax_random.PRNGKey(seed)

# Target point set for the MMD objective. Previously a jittered surface sample was
# computed and then immediately overwritten by vertex offsets; the choice is now
# explicit. False reproduces that previous (vertex-offset) behaviour.
USE_SURFACE_SAMPLES = False
num_points = 1000    # surface samples drawn when USE_SURFACE_SAMPLES is True
VERTEX_OFFSET = 0.2  # distance moved along vertex normals when USE_SURFACE_SAMPLES is False

if USE_SURFACE_SAMPLES:
    points, face_indices = tm.sample.sample_surface(mesh, num_points)
    key, subkey = jax_random.split(key, 2)
    # Random per-sample offset along the face normal, shifted so the smallest offset is 0.1.
    _std = 0.1 * jax_random.normal(subkey, shape=(num_points, 1))
    _std = _std - _std.min() + 0.1
    _points = points + _std * mesh.face_normals[face_indices]
    base_points = points
else:
    _points = mesh.vertices + VERTEX_OFFSET * mesh.vertex_normals
    base_points = mesh.vertices


def info_distr(x):
    """Unnormalized information density at a single point x (reads x[0], x[1], x[2])."""
    return (
        (jnp.sin(x[0] * x[1]) + 1) * jnp.exp(-60 * (x[0] - 0.5)**2 - 10 * (x[1] - 0.5)**2)
        + jnp.exp(-20 * (x[2] - 0.4)**2 - 30 * (x[1] - 0.5)**2)
    )


# One weight per target point: evaluated at the un-offset base points, shifted to be
# strictly positive, then normalized to sum to 1.
P_XI = vmap(info_distr, in_axes=(0,))(base_points)
P_XI = P_XI - P_XI.min() + 0.01
P_XI = P_XI / jnp.sum(P_XI)

args = {'h': h, 'points': _points, 'P_XI': P_XI}

In [3]:
# ps.init()

# ps_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
# ps_mesh.add_scalar_quantity("face vals",vmap(info_distr, in_axes=(0,))(mesh.vertices))
# ps_points = ps.register_point_cloud("sampled points", args['points'])
# ps_points.add_scalar_quantity("results", args['P_XI'])


# ps.show()

In [7]:
def RBF_kernel(x, xp, h=0.01):
    """Gaussian RBF kernel between two points.

    Computes exp(-||x - xp||^2 / h), where the squared distance is summed over all
    elements of (x - xp) and h acts as the bandwidth.
    """
    sq_dist = jnp.sum(jnp.square(x - xp))
    return jnp.exp(-sq_dist / h)
def create_kernel_matrix(kernel):
    """Lift a pairwise kernel ``kernel(x, xp, h)`` to a Gram-matrix function.

    The returned callable maps ``(xs, xps, h)`` to an array ``G`` with
    ``G[j, i] == kernel(xs[i], xps[j], h)``: rows follow ``xps`` and columns follow
    ``xs``. ``h`` is broadcast unchanged to every pair.
    """
    over_first = vmap(kernel, in_axes=(0, None, None))
    return vmap(over_first, in_axes=(None, 0, None))

# Gram-matrix form of RBF_kernel: KernelMatrix(xs, xps, h)[j, i] == RBF_kernel(xs[i], xps[j], h).
KernelMatrix = create_kernel_matrix(RBF_kernel)
def emmd_loss(params, args):
    """Ergodic-MMD objective of a trajectory against a weighted target point set.

    Args:
        params: dict whose 'X' entry has T rows; the first half of its columns
            (``jnp.split`` along axis 1) are the trajectory positions ``p``. The
            second half is not used by this loss.
        args: dict with 'h' (RBF bandwidth), 'points' (target points) and
            'P_XI' (one weight per target point).

    Returns:
        Scalar: mean pairwise kernel among trajectory points, minus twice the
        P_XI-weighted mean kernel between trajectory and target points, plus the
        mean squared step between consecutive trajectory points (smoothness term).
    """
    X = params['X']
    T = X.shape[0]
    p, _ = jnp.split(X, 2, axis=1)
    h = args['h']
    points = args['points']
    P_XI = args['P_XI']
    # Use jnp (not np) reductions: this function is traced/differentiated by JAX.
    self_term = jnp.sum(KernelMatrix(p, p, h)) / (T**2)
    # KernelMatrix(p, points, h) has one row per target point, so P_XI @ (...)
    # weights each target point's kernel row and leaves one value per waypoint.
    cross_term = 2 * jnp.sum(P_XI @ KernelMatrix(p, points, h)) / T
    smooth_term = jnp.mean(jnp.square(p[1:] - p[:-1]))
    return self_term - cross_term + smooth_term

def eq_constr(params, args):
    """Equality residuals tying the trajectory endpoints to the target set's endpoints.

    ``p`` is the first half of the columns of ``params['X']``. The result stacks
    ``p[0] - points[0]`` and ``p[-1] - points[-1]``, where ``points`` is ``args['points']``.
    """
    positions = jnp.split(params['X'], 2, axis=1)[0]
    targets = args['points']
    start_residual = positions[0] - targets[0]
    end_residual = positions[-1] - targets[-1]
    return jnp.array([start_residual, end_residual])

def ineq_constr(params, args):
    """Placeholder inequality constraint that always returns a scalar zero.

    Both arguments are accepted only so the signature matches the other constraint
    callbacks handed to the solver; neither is read. NOTE(review): presumably this
    makes the inequality inactive under the solver's convention; confirm.
    """
    inactive = jnp.array(0.)
    return inactive


In [8]:
T = 60  # number of trajectory waypoints (rows of X)

# Initial guess, concatenated column-wise:
#  - first half: positions interpolated linearly from the first to the last target point;
#  - second half: a sweep between mesh.bounds[0] and mesh.bounds[1]
#    (presumably the bounding-box min/max corners; confirm against trimesh).
start_to_end = jnp.linspace(args['points'][0], args['points'][-1], num=T)
bounds_sweep = jnp.linspace(mesh.bounds[0], mesh.bounds[1], num=T)
X = jnp.concatenate([start_to_end, bounds_sweep], axis=1)

params = {'X': X}
solver = AugmentedLagrangeSolver(
    params,
    emmd_loss,
    eq_constr,
    ineq_constr,
    max_stepsize=1e-1,
    args=args,
)


In [11]:
# Solve schedule: N_OUTER_ROUNDS calls to solver.solve, each capped at INNER_MAX_ITER iterations.
N_OUTER_ROUNDS = 1000
INNER_MAX_ITER = 10
SOLVE_EPS = 1e-5

ps.init()

ps_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
ps_points = ps.register_point_cloud("sampled points", args['points'])
# Unnormalized info_distr evaluated at the mesh vertices (one value per vertex).
ps_mesh.add_scalar_quantity("face vals", vmap(info_distr, in_axes=(0,))(mesh.vertices))

ps_points.add_scalar_quantity("results", args['P_XI'])

# Draw the initial guess X from the solver-setup cell. The loop keeps its iterate in
# X_opt instead of overwriting X, so re-running this cell does not silently replace
# the initial guess with a previous solution.
ps_traj = ps.register_curve_network("trajectory", X[:, :3], edges="line")

for _ in range(N_OUTER_ROUNDS):
    solver.solve(eps=SOLVE_EPS, max_iter=INNER_MAX_ITER)
    sol = solver.solution
    X_opt = sol['X']

    # Only the first three columns (trajectory positions) are drawn as curve nodes.
    ps_traj.update_node_positions(X_opt[:, :3])
    ps.frame_tick()

    time.sleep(0.001)

0 0.001278553484157236
1 0.001278530996127902
2 0.0012785074471404123
3 0.0012784826683328208
4 0.0012784565523964973
5 0.0012784290587533433
6 0.0012784002154824696
7 0.0012783701179070177
8 0.0012783389239143122
9 0.001278306846241242
unsuccessful, tol:  0.001278306846241242
0 0.0012782741421041323
1 0.00127824110068085
2 0.0012782080290567519
3 0.0012781752373200543
4 0.0012781430235325557
5 0.0012781116593069224
6 0.0012780813766911346
7 0.0012780523569961202
8 0.0012780247221071458
9 0.0012779985286976935
unsuccessful, tol:  0.0012779985286976935
0 0.0012779737656226228
1 0.0012779503546120276
2 0.0012779281542262722
3 0.001277906966873817
4 0.0012778865485446062
5 0.0012778666207801216
6 0.0012778468842931537
7 0.0012778270335711088
8 0.0012778067717500005
9 0.0012777858250345897
unsuccessful, tol:  0.0012777858250345897
0 0.0012777639559634407
1 0.0012777409748754886
2 0.0012777167490232054
3 0.0012776912088932233
4 0.0012776643514318836
5 0.0012776362400243506
6 0.0012776070012

In [None]:
# Latest iterate held by the solver; its 'X' entry is read by the next cell.
sol = solver.solution

In [None]:
X = sol['X']

ps.init()

ps_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
ps_points = ps.register_point_cloud("sampled points", args['points'])
# Unnormalized info_distr evaluated at the mesh vertices (one value per vertex).
ps_mesh.add_scalar_quantity("face vals", vmap(info_distr, in_axes=(0,))(mesh.vertices))

ps_points.add_scalar_quantity("results", args['P_XI'])

# Curve-network nodes must be 2-D/3-D positions: pass only the first three columns
# (trajectory positions), as the live-solve cell does, not the full decision matrix X.
ps_traj = ps.register_curve_network("trajectory", X[:, :3], edges="line")

ps.show()