In [17]:
import sys 
sys.path.append('../')
import jax 
import jax.numpy as jnp 
import numpy as np
from jaxlie import SE3, SO3
from jax.random import normal
from jax import grad, hessian, vmap, pmap
from jax.flatten_util import ravel_pytree
from jax.lax import scan
from functools import partial
import jax.random as jax_random
import matplotlib.pyplot as plt 
from IPython.display import clear_output
from plyfile import PlyData
import open3d as o3d 
import trimesh as tm
import trimesh
import time
import jaxopt
import polyscope as ps


import meshcat
import meshcat.geometry as mc_geom

from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver
# from ergodic_mmd.aug_lagrange_solver import AugmentedLagrangeSolver

import adam
from adam.jax import KinDynComputations

In [2]:
from robomeshcat import Object, Robot, Scene
from example_robot_data.robots_loader import PandaLoader
from pathlib import Path

"Create a scene that stores all objects and robots and has rendering capability"
scene = Scene()

# Build the loader once and reuse it: both the URDF path and the mesh folder
# are read from the same loader, so there is no need to construct it twice.
panda_loader = PandaLoader()
robot = Robot(
    urdf_path=panda_loader.df_path,
    # The mesh folder is taken two directory levels above the loader's model path.
    mesh_folder_path=Path(panda_loader.model_path).parent.parent,
)
scene.add_robot(robot)

# Render the initial scene.
scene.render()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


In [24]:
# Read the Open3D bunny asset as raw PLY data.
dataset = o3d.data.BunnyMesh()
plydata = PlyData.read(dataset.path)
# plydata = PlyData.read('../assets/sphere.ply')

# One row per vertex (columns x, y, z) and one row per face.
verts = np.stack([plydata['vertex'][axis] for axis in ('x', 'y', 'z')], axis=1)
faces = np.vstack(plydata['face']['vertex_indices'])

mesh = tm.Trimesh(vertices=verts, faces=faces)
mesh.apply_scale(1.2)
# Rotate a quarter turn about x, then a quarter turn about z.
for _rotation in (SO3.from_x_radians(np.pi/2), SO3.from_z_radians(np.pi/2)):
    mesh.apply_transform(SE3.from_rotation(_rotation).as_matrix())
# Finally shift the mesh by (0.5, 0, 0.1).
mesh.apply_translation(np.array([0.5, 0., 0.1]))


num_points = 1000  # number of surface samples; adjust as needed
points, face_indices = tm.sample.sample_surface(mesh.convex_hull, num_points)


def info_distr(x):
    """Unnormalised Gaussian-shaped bump centred at (0.5, -0.1, 0.4)."""
    return jnp.exp(-20*(x[0]-0.5)**2 - 20*(x[1]+0.1)**2 - 20*(x[2]-0.4)**2)


# Evaluate the bump at every sample and normalise the weights to sum to one.
P_XI = vmap(info_distr)(points)
P_XI = P_XI / P_XI.sum()
h = 0.01

# Normal of the convex-hull face associated with each sample.
_sample_normals = mesh.convex_hull.face_normals[face_indices]
args = dict(
    h=h,
    points=points + 0.05*_sample_normals,  # samples pushed 0.05 along their normal
    normals=_sample_normals,
    P_XI=P_XI,
)


# exp_obj = trimesh.exchange.obj.export_obj(mesh)
# mc_obj = mc_geom.ObjMeshGeometry.from_stream(trimesh.util.wrap_as_stream(exp_obj))
# mc_obj = Object(mc_obj, name='red_sphere', opacity=0.5, color=[1., 0., 0.])
# scene.add_object(mc_obj)

# Rescale the information density at the mesh vertices to the range [0, 1].
_color_map = vmap(info_distr)(mesh.vertices)
_color_map = _color_map - _color_map.min()
_color_map = _color_map / _color_map.max()

# Store the rescaled density in the first colour column; the others stay zero.
mesh_vertex_color_info = np.zeros_like(mesh.vertices)
mesh_vertex_color_info[:, 0] = _color_map

mc_obj = mc_geom.TriangularMeshGeometry(
    mesh.vertices, mesh.faces, color=mesh_vertex_color_info
)
_obj_material = mc_geom.MeshLambertMaterial(reflectivity=0.9, vertexColors=True)
scene.vis['obj'].set_object(mc_obj, _obj_material)

# mc_obj = Object(mc_obj, name='red_sphere', opacity=1)
# scene.add_object(mc_obj)



In [19]:
model_path = "../assets/panda.urdf"
# Joints handed to the kinematics model, in order.
joints_name_list = [
    'panda_joint1',
    'panda_joint2',
    'panda_joint3',
    'panda_joint4',
    'panda_joint5',
    'panda_joint6',
    'panda_joint7',
]

kinDyn = KinDynComputations(model_path, joints_name_list)
# kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION)
# Forward-kinematics function for the "panda_grasptarget" frame.
w_H_ee = kinDyn.forward_kinematics_fun("panda_grasptarget")
w_H_b = np.eye(4)  # base frame

# Initial joint configuration; also stored in args for the loss.
q0 = jnp.array([0, -np.pi/3, 0, -3, 0, 2.5, -0.7853])
args['q0'] = q0
# w_H_ee(w_H_b, q0)
# Mirror q0 onto the visualised robot, one joint at a time.
for joint_index, joint_angle in enumerate(q0):
    robot[joint_index] = joint_angle
scene.render()

Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link0']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link1']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link2']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link4']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link5']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link6']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link7']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_hand']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_leftfinger']/collision[1]
Unknown tag "contact" in /robot[@name='panda']/link[@name='panda_leftfinger']
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_rightfinger']/collision[1]
Unknown tag "contact" in /robot[@name='pan

In [20]:
def get_fk(q):
    """Return the forward-kinematics pose for joint configuration ``q`` as an ``SE3``.

    Evaluates the module-level forward-kinematics function ``w_H_ee`` with the
    base transform ``w_H_b`` and wraps the resulting matrix in an ``SE3``.
    """
    pose_matrix = w_H_ee(w_H_b, q)
    return SE3.from_matrix(pose_matrix)

def RBF_kernel(x, xp, h=0.01):
    """Radial basis function kernel ``exp(-||x - xp||^2 / h)``.

    Args:
        x: first point.
        xp: second point, broadcastable against ``x``.
        h: bandwidth dividing the squared distance.

    Returns:
        Scalar kernel value; equals 1 when ``x`` and ``xp`` coincide.
    """
    squared_distance = jnp.sum(jnp.square(x - xp))
    return jnp.exp(-squared_distance / h)

# def camera_view_penalty(g, v, h=0.01):
#     p = g.translation()
#     rot = g.rotation()
#     w = rot.apply(jnp.array([0.,0.,1.]))
#     root, norm  = jnp.split(v, 2)
#     return jnp.exp(-jnp.sum((root-p)**2)/h - jnp.sum((-norm-w)**2)/h)
#     # return jnp.exp(-jnp.sum((root-p)**2)/h)*jnp.sum((-norm-w)**2)

def camera_view_penalty(g, v, h=0.01):
    """Score how closely pose ``g`` sits at, and faces, an oriented point ``v``.

    Args:
        g: pose object providing ``translation()`` and ``rotation().apply(vec)``.
        v: flat vector split into two equal halves: a root position followed
            by a normal.
        h: bandwidth of the positional falloff.

    Returns:
        ``exp(-||root - p||^2 / h) * dot(-normal, w)`` where ``p`` is the
        translation of ``g`` and ``w`` is its rotated local z-axis
        (presumably the viewing direction).
    """
    root, norm = jnp.split(v, 2)
    view_axis = g.rotation().apply(jnp.array([0., 0., 1.]))
    proximity = jnp.exp(-jnp.sum((root - g.translation())**2) / h)
    # The alignment factor is +1 when the axis points straight against the normal.
    return proximity * jnp.dot(-norm, view_axis)



# def RBF_kernel(g, v, h=0.01):
#     p = g.translation()
#     rot = g.rotation()
#     w = rot.apply(jnp.array([0.,0.,1.]))
#     root, norm  = jnp.split(v, 2)
#     return jnp.exp(
#         -jnp.sum((root-p)**2)/h #- 1e-1*jnp.sum((norm-w)**2)
#     ) * jnp.exp(
#         -10*jnp.sum((-norm-w)**2)
#     )

def create_kernel_matrix(kernel):
    """Lift a pairwise ``kernel(a, b, h)`` to a function over two batches.

    The returned function takes ``(batch_a, batch_b, h)`` and produces a matrix
    whose entry ``[j, i]`` is ``kernel(batch_a[i], batch_b[j], h)``; ``h`` is
    shared by every pair.
    """
    # Inner map runs over the first argument, outer map over the second.
    over_first = vmap(kernel, in_axes=(0, None, None))
    return vmap(over_first, in_axes=(None, 0, None))

# Batched versions of the two kernels. Each takes (batch_a, batch_b, h) and
# returns a matrix whose entry [j, i] is kernel(batch_a[i], batch_b[j], h).
camera_view_matrix = create_kernel_matrix(camera_view_penalty)
KernelMatrix = create_kernel_matrix(RBF_kernel)
def emmd_loss(params, args):
    """Trajectory loss combining kernel coverage, a view term and regularisers.

    Args:
        params: dict with ``'X'``, the trajectory with one joint configuration
            per row.
        args: dict with
            ``'h'`` kernel bandwidth,
            ``'q0'`` reference joint configuration for both trajectory ends,
            ``'points'`` and ``'normals'`` arrays with one row per target point,
            ``'P_XI'`` weight per target point.

    Returns:
        Scalar loss: the mean pairwise kernel between trajectory positions,
        minus twice the weighted kernel between trajectory positions and the
        target points, minus the weighted view term, plus small penalties
        tying the first and last configuration to ``q0`` and a mean-squared
        step penalty between consecutive configurations.
    """
    X = params['X']
    T = X.shape[0]
    h = args['h']
    q0 = args['q0']
    points = args['points']
    norms = args['normals']
    P_XI = args['P_XI']

    # Poses and positions for every configuration in the trajectory.
    g = vmap(get_fk)(X)
    p = g.translation()

    # Each target is passed as [point, normal]; the view term uses a fixed
    # bandwidth of 0.01 regardless of args['h'].
    view_matrix = camera_view_matrix(g, jnp.hstack([points, norms]), 0.01)

    # jnp.sum (not np.sum) keeps every reduction inside JAX, so the loss does
    # not depend on NumPy falling back to the traced value's own .sum method.
    self_similarity = jnp.sum(KernelMatrix(p, p, h)) / (T**2)
    target_coverage = 2 * jnp.sum(P_XI @ KernelMatrix(p, points, h)) / T
    view_reward = jnp.sum(P_XI @ view_matrix) / T
    start_penalty = 1e-5 * jnp.sum(jnp.square(X[0] - q0))
    end_penalty = 1e-5 * jnp.sum(jnp.square(X[-1] - q0))
    step_penalty = jnp.mean(jnp.square(X[1:] - X[:-1]))

    return self_similarity - target_coverage - view_reward \
        + start_penalty + end_penalty + step_penalty

def eq_constr(params, args):
    """Placeholder equality constraint: always returns a single zero.

    ``params['X']`` and ``args['q0']`` are still looked up, so missing keys
    fail here, but their values do not affect the result.
    """
    _trajectory, _reference = params['X'], args['q0']
    # Disabled alternative: jnp.hstack([X[0]-q0, X[-1]-q0])
    return jnp.zeros((1,))

def ineq_constr(params, args):
    """Placeholder inequality constraint: always returns a single zero.

    ``params['X']`` is still looked up, so a missing key fails here, but its
    value does not affect the result. ``args`` is unused.
    """
    _trajectory = params['X']
    # Disabled alternative: jnp.square(X[1:] - X[:-1]) - 0.1**2
    return jnp.zeros((1,))


In [21]:
T = 100  # number of configurations in the trajectory
# Initial guess: every joint ramps linearly from q0 to q0 + 0.1 over T steps.
X = jnp.linspace(q0, q0 + 0.1, num=T)
params = dict(X=X)

solver = AugmentedLagrangeSolver(
    params,
    emmd_loss,
    eq_constr,
    ineq_constr,
    args=args,
)
# solver = AugmentedLagrangeSolver(
#     params, emmd_loss, eq_constr, ineq_constr, 
#     step_size=1e-3,
#     args=args)
# 

In [22]:
# Run the solver with max_iter=15,000, eps=1e-3 and alpha=1+1e-4.
# NOTE(review): eps looks like a convergence tolerance and alpha like a penalty
# growth factor — confirm against AugmentedLagrangeSolver.solve.
solver.solve(max_iter=15_000, eps=1e-3, alpha=1+1e-4)

0 0.47328748999613557
1 0.17358217662566086
2 0.11804033040151654
3 0.09449714828094317
4 0.08799482100322785
5 0.06329191526315699
6 0.05324587854644871
7 0.050800520954583435
8 0.06725505125372441
9 0.046572872289560065
10 0.047268825940630206
11 0.0329454013932032
12 0.03327330212012984
13 0.02553714235757124
14 0.02149103257807741
15 0.024674913341543377
16 0.016485948872712684
17 0.014984893394986876
18 0.020307604859846486
19 0.013849436983269557
20 0.018964331698563198
21 0.014315101986166003
22 0.0195554027668191
23 0.01687723855449231
24 0.01941670722963453
25 0.012415795684411054
26 0.01150971921175097
27 0.011082436599056023
28 0.010159759678567842
29 0.008346029451778855
30 0.008194667028220961
31 0.011595811976211866
32 0.010374366988341172
33 0.01635657803872897
34 0.012089558576117774
35 0.009453344258223126
36 0.009393168233341756
37 0.009330841638580264
38 0.010368763030808413
39 0.012385524378820172
40 0.009033466760047466
41 0.008840634513568141
42 0.0098297562593006

In [14]:
# sol = solver.solution
# X = sol['X']
# vertices = vmap(get_fk)(X).translation().astype(np.float32).T
# scene.vis['lines_segments'].set_object(mc_geom.Line(mc_geom.PointsGeometry(vertices), mc_geom.MeshLambertMaterial(color=0xff0000)))
# ee_frame = scene.vis['ee']
# ee_frame.set_object(mc_geom.triad(scale=0.2))
# for q in X:
#     _tf = np.array(get_fk(q).as_matrix()).astype(np.float64)
#     ee_frame.set_transform(_tf)
#     for i, jt in enumerate(q):
#         robot[i] = jt
#     time.sleep(0.01)
#     scene.render()

In [29]:
# scene = Scene()
sol = solver.solution
X = sol['X']

# Forward-kinematics positions along the trajectory, cast to float32 and
# transposed before being handed to meshcat.
vertices = vmap(get_fk)(X).translation().astype(np.float32).T
_path_line = mc_geom.Line(
    mc_geom.PointsGeometry(vertices),
    mc_geom.MeshLambertMaterial(color=0x0000FF, linewidth=5),
)
scene.vis['lines_segments'].set_object(_path_line)


# robot_img = []
# for i, q in enumerate(X[::int(T/5)]):
#     # _tf = np.array(get_fk(q).as_matrix()).astype(np.float64)
#     # ee_frame.set_transform(_tf)
#     _robot = Robot(name='robot{}'.format(i),urdf_path=PandaLoader().df_path, mesh_folder_path=Path(PandaLoader().model_path).parent.parent, opacity=0.75)
#     robot_img.append(_robot)
#     scene.add_robot(_robot)
#     for j, jt in enumerate(q):
#         _robot[j] = jt
# "Render the initial scene"
scene.render()

In [None]:
# # ps.set_up_dir("z_up")
# # ps.set_front_dir("neg_y_front")
# ps.init()


# ps_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
# # ps_points = ps.register_point_cloud("sampled points", args['points'][:,:3])
# # ps_mesh.add_scalar_quantity("face vals",vmap(info_distr, in_axes=(0,))(mesh.vertices))

# # ps_points.add_scalar_quantity("results", args['P_XI'])

# # ps_bunny_mesh.set_transparency(0.8)
# # ps_bunny_mesh.add_scalar_quantity('info_distr', mesh_func.func_vals, defined_on='vertices', cmap='blues')

# ps_traj         = ps.register_curve_network("trajectory", vmap(get_fk)(X).translation() , edges="line")


# for _ in range(1000):
#     solver.solve(eps=1e-5, max_iter=10)
#     sol = solver.solution
#     X = sol['X']

#     ps_traj.update_node_positions(vmap(get_fk)(X).translation() )
#     # ps_traj.add_vector_quantity("vec img", X[:,:3], enabled=True)
#     ps.frame_tick()

#     time.sleep(0.001)

In [None]:
# Call solve again passing only eps; every other option uses the solver's defaults.
solver.solve(eps=1e-3)

In [22]:
# Grab the solver's current solution; later cells read its 'X' entry.
sol = solver.solution

In [64]:
X = sol['X']

ps.init()

# Bunny surface with the information density attached as a per-vertex scalar.
ps_bunny_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
# ps_bunny_mesh.set_transparency(0.8)
_vertex_info = vmap(info_distr)(mesh.vertices)
ps_bunny_mesh.add_scalar_quantity('info_distr', _vertex_info, defined_on='vertices', cmap='blues')

# Forward-kinematics positions along the trajectory, joined into a curve.
_traj_nodes = vmap(get_fk)(X).translation()
ps_traj = ps.register_curve_network("trajectory", _traj_nodes, edges="line")

ps.show()

[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 4.1 Metal - 88
