In [4]:
import sys 
sys.path.append('../')
import jax 
import jax.numpy as jnp 
import numpy as np
from jax.random import normal
from jax import grad, hessian, vmap, pmap
from jax.flatten_util import ravel_pytree
from jax.lax import scan
from functools import partial
import jax.random as jax_random
import matplotlib.pyplot as plt 
from IPython.display import clear_output
from plyfile import PlyData
import open3d as o3d 
import trimesh as tm

import jaxopt
import polyscope as ps

from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver

import adam
from adam.jax import KinDynComputations

In [6]:
# Target surface: Open3D's bundled bunny mesh, sampled into a weighted point set.
MESH_X_OFFSET = 0.5       # offset mesh location along x by half a meter
NUM_SURFACE_POINTS = 100  # number of surface samples forming the target distribution
NORMAL_OFFSET = 0.1       # distance each sample is pushed out along its face normal
KERNEL_BANDWIDTH = 0.01   # RBF bandwidth, handed to the loss as args['h']

dataset = o3d.data.BunnyMesh()
plydata = PlyData.read(dataset.path)
verts = np.vstack((
    plydata['vertex']['x'] + MESH_X_OFFSET,
    plydata['vertex']['y'],
    plydata['vertex']['z'],
)).T  # (num_vertices, 3)
# np.vstack already returns a fresh ndarray; wrapping it in np.array(...) only copied it again.
faces = np.vstack(plydata['face']['vertex_indices'])

mesh = tm.Trimesh(vertices=verts, faces=faces)

num_points = NUM_SURFACE_POINTS
# NOTE(review): sampling is unseeded, so `points` differs between runs; newer trimesh
# versions accept `seed=` here — TODO confirm against the installed version and pin it.
points, face_indices = tm.sample.sample_surface(mesh, num_points)


def info_distr(x):
    """Unnormalised information density at surface point ``x`` (uniform: every sample weighs 1)."""
    return 1.0


# One weight per sampled point, normalised so the weights sum to 1.
P_XI = vmap(info_distr, in_axes=(0,))(points)
P_XI = P_XI / jnp.sum(P_XI)

h = KERNEL_BANDWIDTH
args = {
    'h': h,
    # Samples displaced outward along the normal of the face they were drawn from
    # (presumably so the trajectory hovers off the surface — TODO confirm intent).
    'points': points + NORMAL_OFFSET * mesh.face_normals[face_indices],
    'P_XI': P_XI,
}

In [11]:
# Franka Panda forward kinematics through ADAM's JAX backend.
model_path = "../assets/panda.urdf"

# Joints that make up the configuration vector q, in this order.
# NOTE(review): in the stock Franka URDF 'panda_joint8' is a fixed flange joint and the
# fingers are 'panda_finger_joint1/2'; q0's trailing 0.04 looks like a finger opening.
# Confirm against ../assets/panda.urdf that this list is what was intended.
joints_name_list = [
    'panda_joint1',
    'panda_joint2',
    'panda_joint3',
    'panda_joint4',
    'panda_joint5',
    'panda_joint6',
    'panda_joint7',
    'panda_joint8',
]

kinDyn = KinDynComputations(model_path, joints_name_list)

# Callable (w_H_b, q) -> 4x4 homogeneous transform of the "panda_hand" frame.
w_H_ee = kinDyn.forward_kinematics_fun("panda_hand")

# Robot base placed at the world origin with identity orientation.
w_H_b = np.eye(4)

# Nominal configuration; 1.57079 ~ pi/2 and -0.7853 ~ -pi/4 (values kept exactly as tuned).
q0 = jnp.array([0, 0, 0, -1.57079, 0, 1.57079, -0.7853, 0.04])

Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link0']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link1']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link2']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link4']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link5']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link6']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link7']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_hand']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_leftfinger']/collision[1]
Unknown tag "contact" in /robot[@name='panda']/link[@name='panda_leftfinger']
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_rightfinger']/collision[1]
Unknown tag "contact" in /robot[@name='pan

In [10]:
# Homogeneous position of the panda_hand origin with every joint at 1 rad:
# the last column of w_H_ee, i.e. the transform applied to [0, 0, 0, 1].
w_H_ee(w_H_b, np.ones(len(joints_name_list)))[:, 3]

Array([0.30016132, 0.18034758, 0.8164567 , 1.        ], dtype=float64)

In [4]:
def RBF_kernel(x, xp, h=0.01):
    """Gaussian (RBF) kernel ``k(x, xp) = exp(-||x - xp||^2 / h)``.

    Args:
        x, xp: points of matching shape.
        h: bandwidth; larger values give a wider kernel.

    Returns:
        Scalar kernel value (1 when ``x == xp``).
    """
    diff = x - xp
    sq_dist = jnp.sum(diff ** 2)
    return jnp.exp(-sq_dist / h)


def create_kernel_matrix(kernel):
    """Lift a pairwise ``kernel(x, xp, h)`` into a Gram-matrix function.

    The returned function maps ``(X, Xp, h)`` to ``K`` with
    ``K[j, i] = kernel(X[i], Xp[j], h)``: rows index ``Xp``, columns index ``X``.
    ``h`` is shared across all pairs (not mapped).
    """
    # Stage 1: vectorise over the rows of the first argument.
    over_first = vmap(kernel, in_axes=(0, None, None))
    # Stage 2: vectorise over the rows of the second argument -> leading output axis.
    return vmap(over_first, in_axes=(None, 0, None))


KernelMatrix = create_kernel_matrix(RBF_kernel)
def emmd_loss(params, args):
    """Kernel (MMD-style) coverage objective for a discrete trajectory plus a smoothness term.

    loss = (1/T^2) * sum_{s,t} k(x_s, x_t)
           - (2/T) * sum_t sum_i P_XI[i] * k(x_t, p_i)
           + mean((x_{t+1} - x_t)^2)

    Args:
        params: dict with 'X', the (T, d) trajectory waypoints being optimised.
        args: dict with
            'h'      -- kernel bandwidth passed to KernelMatrix,
            'points' -- target sample points p_i,
            'P_XI'   -- one weight per target point (assumed normalised; TODO confirm).

    Returns:
        Scalar JAX array.
    """
    X = params['X']
    T = X.shape[0]
    h = args['h']
    points = args['points']
    P_XI = args['P_XI']

    # Use jnp throughout: this function is differentiated by the solver, and plain
    # np.sum only works on JAX values by accident (duck-typed dispatch to .sum()).
    self_similarity = jnp.sum(KernelMatrix(X, X, h)) / (T ** 2)
    # KernelMatrix(X, points, h) has rows per target point and columns per waypoint,
    # so P_XI @ ... gives one weighted similarity per waypoint.
    target_similarity = jnp.sum(P_XI @ KernelMatrix(X, points, h)) / T

    if T > 1:
        smoothness = jnp.mean(jnp.square(X[1:] - X[:-1]))
    else:
        # A single waypoint has no segments; the mean of an empty array would be NaN.
        smoothness = 0.0

    return self_similarity - 2 * target_similarity + smoothness

def eq_constr(params, args):
    """Equality-constraint residual passed to the solver.

    Placeholder: no equality constraints are imposed yet, so a single zero
    residual is returned regardless of ``params`` and ``args``.
    """
    del params, args  # intentionally unused until real constraints are added
    return jnp.array(0.)

def ineq_constr(params, args):
    """Inequality-constraint residual passed to the solver.

    Placeholder: no inequality constraints are imposed yet, so a single zero
    residual is returned. (The sign convention expected by the solver is not
    visible here — TODO confirm in AugmentedLagrangeSolver before adding real ones.)
    """
    # The previous body read params['X'] into an unused local; dropped.
    del params, args
    return jnp.array(0.)


In [5]:
T = 100  # number of trajectory waypoints

# Initial guess: T evenly spaced points from mesh.bounds[0] to mesh.bounds[1]
# (assumed to be the min/max corners of the mesh's bounding box — trimesh convention).
X = jnp.linspace(*mesh.bounds, num=T)
params = dict(X=X)

solver = AugmentedLagrangeSolver(
    params,
    emmd_loss,
    eq_constr,
    ineq_constr,
    step_size=3e-3,
    args=args,
)


In [6]:
# Run the augmented-Lagrangian optimisation; `eps` is forwarded to the solver
# (presumably its convergence tolerance — confirm in ergodic_mmd.aug_lagrange_jaxopt).
solver.solve(eps=1e-4)

0 0.1519310733095478
1 0.14834730386690362
2 0.18167940255484638
3 0.20739011172496477
4 0.16595977250016256
5 0.13859488221391295
6 0.12358499626706629
7 0.1130825514907888
8 0.09858905683996971
9 0.08246523865183482
10 0.06912886270444302
11 0.058801942301906655
12 0.05078126599189241
13 0.04457195209853205
14 0.03974251829114232
15 0.03589855672909651
16 0.03273108806440456
17 0.030033909506735484
18 0.02768587780073049
19 0.025620394041113444
20 0.023798775476625185
21 0.02219372067593518
22 0.02078161493323195
23 0.019539891611139875
24 0.0184464900911999
25 0.01748003475149007
26 0.016620306850830616
27 0.01584886206599042
28 0.015149659817085092
29 0.01450955787416282
30 0.013918547976616631
31 0.013369664652147452
32 0.012858575183499015
33 0.012382930748459997
34 0.011941604938706801
35 0.011533955088502423
36 0.011159217096988878
37 0.010816098643772759
38 0.010502585268897193
39 0.010215932724473603
40 0.00995279514719938
41 0.009709432631238865
42 0.009481948964253698
43 0.

In [13]:
# Optimised parameters; indexed as sol['X'] below (presumably the same structure as `params`).
sol = solver.solution

In [14]:
X = sol['X']

# Show the optimised trajectory next to the bunny mesh in polyscope's interactive viewer.
ps.init()

ps_bunny_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)

# polyscope's Python API takes NumPy arrays; convert the JAX result explicitly.
# edges="line" joins consecutive waypoints into an open polyline.
ps_traj = ps.register_curve_network("trajectory", np.asarray(X), edges="line")

# Enters the viewer loop; the cell does not return until the window is closed.
ps.show()

KeyboardInterrupt: 