In [7]:
import sys 
sys.path.append('../')
import jax 
import jax.numpy as jnp 
import numpy as np
from jaxlie import SE3, SO3
from jax.random import normal
from jax import grad, hessian, vmap, pmap
from jax.flatten_util import ravel_pytree
from jax.lax import scan
from functools import partial
import jax.random as jax_random
import matplotlib.pyplot as plt 
from IPython.display import clear_output
from plyfile import PlyData
import open3d as o3d 
import trimesh as tm
import time
import jaxopt
import polyscope as ps

# from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver
from ergodic_mmd.aug_lagrange_solver import AugmentedLagrangeSolver

import adam
from adam.jax import KinDynComputations

In [8]:
# Load the Stanford bunny shipped with open3d and build a trimesh from its PLY data.
dataset = o3d.data.BunnyMesh()
plydata = PlyData.read(dataset.path)
ply_vertex = plydata['vertex']

# Offset the mesh location by half a meter along +y; x and z are left unchanged.
mesh_y_offset = 0.5
verts = np.column_stack((
    ply_vertex['x'],
    ply_vertex['y'] + mesh_y_offset,
    ply_vertex['z'],
))
# Each face entry is an index array; stack them into one integer array, one face per row.
faces = np.vstack(plydata['face']['vertex_indices'])

mesh = tm.Trimesh(vertices=verts, faces=faces)

# Number of target points sampled on the mesh surface.
# NOTE(review): sampling is unseeded, so the targets differ on every run.
num_points = 100
points, face_indices = tm.sample.sample_surface(mesh, num_points)


def info_distr(x):
    """Uniform information density: every sampled point gets the same weight."""
    return 1.0


# Normalized target weights (uniform here, so each entry equals 1 / num_points).
P_XI = vmap(info_distr, in_axes=(0,))(points)
P_XI = P_XI / jnp.sum(P_XI)

h = 0.01  # RBF kernel bandwidth, passed to the loss through args['h']
# Lift the targets off the surface along the normals of the faces they were sampled from.
target_standoff = 0.1
args = {
    'h': h,
    'points': points + target_standoff * mesh.face_normals[face_indices],
    'P_XI': P_XI,
}

In [3]:
model_path = "../assets/panda.urdf"

# Joints exposed to the kinematics model: panda_joint1 ... panda_joint8, in order.
joints_name_list = [f"panda_joint{i}" for i in range(1, 9)]

kinDyn = KinDynComputations(model_path, joints_name_list)
# Optional (left disabled):
#   kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION)

# Forward-kinematics function for the "panda_hand" frame, evaluated as w_H_ee(w_H_b, q).
w_H_ee = kinDyn.forward_kinematics_fun("panda_hand")
w_H_b = np.eye(4)  # base pose: identity, i.e. base frame coincides with the world frame

# Nominal joint configuration, one entry per name in joints_name_list.
q0 = jnp.array([0, 0, 0, -1.57079, 0, 1.57079, -0.7853, 0.04])

Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link0']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link1']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link2']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link4']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link5']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link6']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link7']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_hand']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_leftfinger']/collision[1]
Unknown tag "contact" in /robot[@name='panda']/link[@name='panda_leftfinger']
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_rightfinger']/collision[1]
Unknown tag "contact" in /robot[@name='pan

In [4]:
def get_fk(q):
    """Return the pose of the ``panda_hand`` frame for joint vector ``q`` as a ``jaxlie.SE3``.

    Evaluates the module-level forward-kinematics function ``w_H_ee`` with the
    fixed base pose ``w_H_b``, then wraps the resulting homogeneous transform.
    """
    hand_transform = w_H_ee(w_H_b, q)
    return SE3.from_matrix(hand_transform)


def RBF_kernel(x, xp, h=0.01):
    """Gaussian (RBF) kernel ``exp(-||x - xp||^2 / h)`` between two vectors.

    ``h`` is the bandwidth; it divides the squared distance directly (no 2 or
    squaring of ``h``).
    """
    sq_dist = jnp.sum((x - xp) ** 2)
    return jnp.exp(-sq_dist / h)


def create_kernel_matrix(kernel):
    """Lift a pairwise ``kernel(x, xp, h)`` into a Gram-matrix function.

    The returned ``K(A, B, h)`` maps over the rows of both point sets and has
    shape ``(len(B), len(A))`` with ``K(A, B, h)[j, i] == kernel(A[i], B[j], h)``.
    The bandwidth ``h`` is shared (not mapped).
    """
    over_rows_of_first = vmap(kernel, in_axes=(0, None, None))
    return vmap(over_rows_of_first, in_axes=(None, 0, None))


KernelMatrix = create_kernel_matrix(RBF_kernel)
def emmd_loss(params, args):
    """Ergodic-MMD-style objective for a joint-space trajectory.

    Args:
        params: dict with key ``'X'``, the joint trajectory with one row per
            time step (``X.shape[0]`` is the number of steps ``T``).
        args: dict with ``'h'`` (RBF bandwidth), ``'points'`` (target points,
            one per row) and ``'P_XI'`` (one weight per target point).

    Returns:
        Scalar loss, the sum of:
          * ``(1/T^2) sum_{t,s} k(p_t, p_s)`` minus ``(2/T) sum_t sum_j P_XI[j] k(p_t, xi_j)``,
            i.e. the squared MMD between end-effector positions and the weighted
            targets, without the constant target-target term;
          * the mean squared step between consecutive end-effector positions;
          * the mean squared deviation of the joints from the global ``q0``.
    """
    X = params['X']
    p = vmap(get_fk)(X).translation()  # end-effector positions, one row per time step
    T = X.shape[0]
    h = args['h']
    points = args['points']
    P_XI = args['P_XI']

    # Use jnp (not np) reductions: this function is traced and differentiated by JAX.
    self_term = jnp.sum(KernelMatrix(p, p, h)) / (T ** 2)
    # KernelMatrix(p, points, h) has shape (len(points), T); P_XI @ ... weights each target.
    cross_term = 2 * jnp.sum(P_XI @ KernelMatrix(p, points, h)) / T
    smoothness = jnp.mean(jnp.square(p[1:] - p[:-1]))
    nominal_pose = jnp.mean(jnp.square(X - q0))
    return self_term - cross_term + smoothness + nominal_pose

def eq_constr(params, args):
    """Equality-constraint placeholder: this problem imposes no equality constraints.

    Always returns a constant zero scalar, independent of ``params`` and ``args``,
    so the ``AugmentedLagrangeSolver`` constraint interface is satisfied.
    NOTE(review): presumes the solver treats a zero residual as satisfied — confirm.
    """
    return jnp.array(0.)

def ineq_constr(params, args):
    """Inequality-constraint placeholder: this problem imposes no inequality constraints.

    Always returns a constant zero scalar, independent of ``params`` and ``args``.
    NOTE(review): zero sits on the boundary under either sign convention
    (g <= 0 or g >= 0); confirm which one ``AugmentedLagrangeSolver`` uses.
    """
    return jnp.array(0.)


In [5]:
T = 100  # number of trajectory time steps

# Initial guess: T joint configurations linearly interpolated from q0 to q0 + 0.1
# (every joint shifted by the same 0.1 over the trajectory).
X = jnp.linspace(q0, q0 + 0.1, num=T)
params = dict(X=X)

# Alternative also tried: passing max_stepsize=3e-1 to the solver.
solver = AugmentedLagrangeSolver(
    params,
    emmd_loss,
    eq_constr,
    ineq_constr,
    args=args,
)


In [6]:
def _ee_positions(joint_traj):
    """End-effector (panda_hand) positions for each row of a joint trajectory."""
    return vmap(get_fk)(joint_traj).translation()


ps.init()

# Target surface.
ps_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)

# Current end-effector path as a polyline; its nodes are updated in place below.
ps_traj = ps.register_curve_network("trajectory", _ee_positions(X), edges="line")

# Interleave short solver runs (max_iter=10 each) with viewer redraws so the
# trajectory animates while it is being optimized.
for _ in range(1000):
    solver.solve(eps=1e-5, max_iter=10)
    sol = solver.solution
    X = sol['X']

    ps_traj.update_node_positions(_ee_positions(X))
    ps.frame_tick()

    time.sleep(0.001)

[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 4.1 Metal - 88
iter  0  loss  0.7956927  grad l2 norm  0.63298243
iter  1  loss  0.7656877  grad l2 norm  0.6296775
iter  2  loss  0.7426031  grad l2 norm  0.6236332
iter  3  loss  0.721843  grad l2 norm  0.61635333
iter  4  loss  0.70251274  grad l2 norm  0.60715234
iter  5  loss  0.6844517  grad l2 norm  0.5962354
iter  6  loss  0.6675812  grad l2 norm  0.5841345
iter  7  loss  0.651803  grad l2 norm  0.5714155
iter  8  loss  0.63698995  grad l2 norm  0.55858505
iter  9  loss  0.6230155  grad l2 norm  0.54597986
unsuccessful, tol:  0.54597986
iter  0  loss  0.6097695  grad l2 norm  0.5337978
iter  1  loss  0.5971662  grad l2 norm  0.5220881
iter  2  loss  0.58515364  grad l2 norm  0.510787
iter  3  loss  0.5737073  grad l2 norm  0.49980766
iter  4  loss  0.5628101  grad l2 norm  0.4890992
iter  5  loss  0.5524435  grad l2 norm  0.47864652
iter  6  loss  0.54258394  grad l2 norm  0.4684561
iter  7  loss  0.53320533  grad l2 n

KeyboardInterrupt: 

In [21]:
# Longer refinement run with eps=1e-3 (the interactive loop above uses eps=1e-5, max_iter=10).
solver.solve(eps=0.001)

0 0.5785384635452112
1 0.4719889521287219
2 0.38854021631700314
3 0.3337916149304583
4 0.2969637819884055
5 0.26975457353110294
6 0.24786540918348815
7 0.22942432756095105
8 0.21364087098349313
9 0.20009496651057743
10 0.18845733776603718
11 0.17841816501052601
12 0.16968983978746957
13 0.16202301611325395
14 0.15521576036429874
15 0.14911331832509261
16 0.14360171796136723
17 0.13859896018162907
18 0.1340464188846757
19 0.12990178597525162
20 0.1261339547003065
21 0.12271969129930642
22 0.11964170515026401
23 0.11688767400223166
24 0.11444982023446947
25 0.1123247058542261
26 0.11051298563174271
27 0.10901891729545474
28 0.10784947673489931
29 0.10701297467662003
30 0.10651713249435582
31 0.1063666590179928
32 0.1065604777829396
33 0.10708955755647143
34 0.10793551615652915
35 0.10906598959221869
36 0.11043307656765335
37 0.111973165303124
38 0.11360828157413358
39 0.11524898150613269
40 0.1167987187438838
41 0.11815958349524677
42 0.11923921692962128
43 0.11995840420229066
44 0.12025

KeyboardInterrupt: 

In [22]:
# Snapshot the solver's current iterate; presumably the same dict structure as
# `params` (it is indexed with 'X' in the next cell).
sol = solver.solution

In [23]:
X = sol['X']

# Show the optimized end-effector path next to the bunny in a blocking viewer window.
ps.init()
ps_bunny_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)

ee_path = vmap(get_fk)(X).translation()
ps_traj = ps.register_curve_network("trajectory", ee_path, edges="line")

ps.show()

[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 4.1 Metal - 88


KeyboardInterrupt: 