In [1]:
import sys 
sys.path.append('../')
import jax 
import jax.numpy as jnp 
import numpy as np
from jaxlie import SE3, SO3
from jax.random import normal
from jax import grad, hessian, vmap, pmap
from jax.flatten_util import ravel_pytree
from jax.lax import scan
from functools import partial
import jax.random as jax_random
import matplotlib.pyplot as plt 
from IPython.display import clear_output
from plyfile import PlyData
import open3d as o3d 
import trimesh as tm
import trimesh
import time
import jaxopt
import polyscope as ps


import meshcat
import meshcat.geometry as mc_geom

# from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver
from ergodic_mmd.aug_lagrange_solver import AugmentedLagrangeSolver

import adam
from adam.jax import KinDynComputations

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
from robomeshcat import Object, Robot, Scene
from example_robot_data.robots_loader import PandaLoader
from pathlib import Path

# Create a scene that stores all objects and robots and has rendering capability.
scene = Scene()

# Instantiate the loader once and read both paths from the same instance
# instead of constructing a second PandaLoader just to fetch `model_path`.
panda_loader = PandaLoader()
robot = Robot(
    urdf_path=panda_loader.df_path,
    mesh_folder_path=Path(panda_loader.model_path).parent.parent,
)
scene.add_robot(robot)

# Render the initial scene.
scene.render()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/


In [96]:
# Read the raw PLY data of the bunny mesh provided by Open3D's dataset helper.
dataset = o3d.data.BunnyMesh()
plydata = PlyData.read(dataset.path)
# plydata = PlyData.read('../assets/sphere.ply')

# One row per vertex, columns ordered x, y, z.
_vertex_table = plydata['vertex']
verts = np.vstack([_vertex_table[axis] for axis in ('x', 'y', 'z')]).T
faces = np.array(np.vstack(plydata['face']['vertex_indices']))

# Build the mesh, then scale, rotate about x by 90 degrees, and finally shift it.
mesh = tm.Trimesh(vertices=verts, faces=faces)
mesh.apply_scale(1.2)
_x_rotation = SE3.from_rotation(SO3.from_x_radians(np.pi/2))
mesh.apply_transform(_x_rotation.as_matrix())
# Offset the mesh by 0.5 along x and 0.1 along z.
mesh.apply_translation(np.array([0.5,0.,0.1]))


# Number of surface samples drawn from the convex hull of the mesh.
num_points = 1000  # Change this number based on your requirement
points, face_indices = tm.sample.sample_surface(mesh.convex_hull, num_points)


def info_distr(x):
    """Unnormalised Gaussian-shaped bump centred at (0.5, -0.1, 0.4)."""
    dx = x[0]-0.5
    dy = x[1]+0.1
    dz = x[2]-0.4
    return jnp.exp(-20*dx**2 - 20*dy**2 - 20*dz**2)


# Evaluate the information density on the sampled points and normalise it to
# sum to one.
# NOTE(review): the density is evaluated at the raw surface samples, whereas
# the 'points' entry stored in `args` is pushed 0.05 along the face normal --
# confirm that this mismatch is intended.
_raw_weights = vmap(info_distr, in_axes=(0,))(points)
P_XI = _raw_weights/jnp.sum(_raw_weights)

# Kernel bandwidth forwarded to the loss through `args`.
h = 0.01

# Normal of the hull face each sample was drawn from.
_sample_normals = mesh.convex_hull.face_normals[face_indices]
args = {
    'h' : h,
    'points' : points+0.05*_sample_normals,
    'normals' : _sample_normals,
    'P_XI' : P_XI,
}


# exp_obj = trimesh.exchange.obj.export_obj(mesh)
# mc_obj = mc_geom.ObjMeshGeometry.from_stream(trimesh.util.wrap_as_stream(exp_obj))
# mc_obj = Object(mc_obj, name='red_sphere', opacity=0.5, color=[1., 0., 0.])
# scene.add_object(mc_obj)

# Evaluate the information density on every mesh vertex and rescale it to [0, 1].
_vertex_info = vmap(info_distr)(mesh.vertices)
_shifted_info = _vertex_info - _vertex_info.min()
_color_map = _shifted_info/_shifted_info.max()

# Write the rescaled density into the first colour channel; the other two stay zero.
mesh_vertex_color_info = np.zeros_like(mesh.vertices)
mesh_vertex_color_info[:,0] = _color_map

# Publish the vertex-coloured mesh to the meshcat visualiser under 'obj'.
mc_obj = mc_geom.TriangularMeshGeometry(
    mesh.vertices,
    mesh.faces,
    color=mesh_vertex_color_info,
)
_obj_material = mc_geom.MeshLambertMaterial(reflectivity=0.9, vertexColors=True)
scene.vis['obj'].set_object(mc_obj, _obj_material)

# mc_obj = Object(mc_obj, name='red_sphere', opacity=1)
# scene.add_object(mc_obj)



In [91]:
# URDF used to build the kinematics object.
model_path = "../assets/panda.urdf"
# The joint list: panda_joint1 ... panda_joint7.
joints_name_list = ['panda_joint{}'.format(k) for k in range(1, 8)]

kinDyn = KinDynComputations(model_path, joints_name_list)
# kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION)

# Forward-kinematics function for the "panda_grasptarget" frame; it is called
# as w_H_ee(w_H_b, q) elsewhere in the notebook.
w_H_ee = kinDyn.forward_kinematics_fun("panda_grasptarget")
w_H_b = np.eye(4) # base frame

# Reference joint configuration, stored in `args` for the loss and constraints.
# NOTE: `args` is rebuilt without 'q0' whenever the mesh-sampling cell is
# re-run, so this cell has to be executed again afterwards.
q0 = jnp.array([0,-np.pi/3,0,-3, 0, 2.5, -0.7853])
args['q0'] = q0
# w_H_ee(w_H_b, q0)

# Put the displayed robot into the reference configuration and redraw.
for joint_idx, joint_value in enumerate(q0):
    robot[joint_idx] = joint_value
scene.render()

Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link0']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link1']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link2']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link4']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link5']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link6']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link7']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_hand']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_leftfinger']/collision[1]
Unknown tag "contact" in /robot[@name='panda']/link[@name='panda_leftfinger']
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_rightfinger']/collision[1]
Unknown tag "contact" in /robot[@name='pan

In [92]:
def get_fk(q):
    """Return the end-effector pose for joint vector `q` as a jaxlie `SE3`.

    The homogeneous transform comes from the module-level forward-kinematics
    function `w_H_ee`, evaluated with the base transform `w_H_b`.
    """
    ee_matrix = w_H_ee(w_H_b, q)
    return SE3.from_matrix(ee_matrix)

def RBF_kernel(x, xp, h=0.01):
    """Radial-basis-function kernel ``exp(-||x - xp||^2 / h)``.

    Args:
        x: first point.
        xp: second point, broadcast-compatible with `x`.
        h: bandwidth dividing the squared distance.

    Returns:
        Scalar kernel value; 1 when the points coincide.
    """
    delta = x - xp
    squared_distance = jnp.sum(delta**2)
    return jnp.exp(-squared_distance/h)

# def camera_view_penalty(g, v, h=0.01):
#     p = g.translation()
#     rot = g.rotation()
#     w = rot.apply(jnp.array([0.,0.,1.]))
#     root, norm  = jnp.split(v, 2)
#     return jnp.exp(-jnp.sum((root-p)**2)/h - jnp.sum((-norm-w)**2)/h)
#     # return jnp.exp(-jnp.sum((root-p)**2)/h)*jnp.sum((-norm-w)**2)

def camera_view_penalty(g, v, h=0.01):
    """Score how well pose `g` views the surface sample `v`.

    Args:
        g: pose object exposing ``translation()`` and ``rotation()`` (the
            latter with an ``apply`` method), e.g. a jaxlie `SE3`.
        v: vector whose first half is a point and whose second half is the
            normal at that point.
        h: bandwidth of the positional falloff.

    Returns:
        ``exp(-||point - position||^2 / h)`` multiplied by the dot product of
        the negated normal with the pose's rotated z-axis.
    """
    surface_point, surface_normal = jnp.split(v, 2)
    position = g.translation()
    # z-axis of the pose expressed in the world frame (presumably the
    # direction the camera looks along).
    view_axis = g.rotation().apply(jnp.array([0.,0.,1.]))

    proximity = jnp.exp(-jnp.sum((surface_point-position)**2)/h)
    facing = jnp.dot(-surface_normal, view_axis)
    return proximity*facing



# def RBF_kernel(g, v, h=0.01):
#     p = g.translation()
#     rot = g.rotation()
#     w = rot.apply(jnp.array([0.,0.,1.]))
#     root, norm  = jnp.split(v, 2)
#     return jnp.exp(
#         -jnp.sum((root-p)**2)/h #- 1e-1*jnp.sum((norm-w)**2)
#     ) * jnp.exp(
#         -10*jnp.sum((-norm-w)**2)
#     )

def create_kernel_matrix(kernel):
    """Lift a three-argument `kernel(a, b, h)` to all pairs of two batches.

    The returned function maps over the leading axis of the first argument
    (inner map) and of the second argument (outer map); the third argument is
    shared. The result is therefore indexed ``[second_batch, first_batch]``.
    """
    over_first = vmap(kernel, in_axes=(0, None, None))
    over_both = vmap(over_first, in_axes=(None, 0, None))
    return over_both

# Pairwise (batched) versions of the two kernels: each evaluates its kernel for
# every combination of an element of the first batch with an element of the
# second batch, sharing the bandwidth argument.
camera_view_matrix = create_kernel_matrix(camera_view_penalty)
KernelMatrix = create_kernel_matrix(RBF_kernel)
def emmd_loss(params, args):
    """Trajectory objective combining kernel coverage, viewing and regularisers.

    Args:
        params: dict with key 'X', the joint-space trajectory (one row per
            time step).
        args: dict with keys 'h' (kernel bandwidth), 'q0' (reference joint
            configuration), 'points' and 'normals' (target samples and their
            normals) and 'P_XI' (weight of each target sample).

    Returns:
        Scalar loss:
          mean pairwise RBF kernel between end-effector positions
          - 2 * P_XI-weighted RBF kernel between positions and targets / T
          - P_XI-weighted camera-view score / T
          + 1e-5 * mean squared deviation of X from q0
          + mean squared difference between consecutive rows of X.
    """
    X = params['X']
    T = X.shape[0]
    h = args['h']
    q0 = args['q0']
    points    = args['points']
    norms     = args['normals']
    P_XI      = args['P_XI']

    # End-effector pose for every time step, and its position.
    g = vmap(get_fk)(X)
    p = g.translation()

    # Each target is passed as [point, normal]; the view kernel uses a fixed
    # bandwidth of 0.1 rather than `h`.
    view_matrix = camera_view_matrix(g, jnp.hstack([points, norms]), 0.1)

    # Use jnp (not np) reductions so every operation in the differentiated
    # loss is an explicit JAX op instead of relying on NumPy's fallback to
    # the array's own `.sum` method.
    self_similarity = jnp.sum(KernelMatrix(p, p, h))/(T**2)
    target_coverage = 2 * jnp.sum(P_XI @ KernelMatrix(p, points, h))/T
    view_score = jnp.sum(P_XI@view_matrix)/T
    home_penalty = 1e-5*jnp.mean(jnp.square(X-q0))
    smoothness_penalty = jnp.mean(jnp.square(X[1:] - X[:-1]))

    return self_similarity \
            - target_coverage \
            - view_score \
            + home_penalty \
            + smoothness_penalty

def eq_constr(params, args):
    """Equality residuals pinning both trajectory ends to ``args['q0']``.

    Returns the concatenation of ``X[0] - q0`` and ``X[-1] - q0``; it is zero
    exactly when the first and last rows of ``params['X']`` equal `q0`.
    """
    trajectory = params['X']
    anchor = args['q0']
    start_residual = trajectory[0]-anchor
    end_residual = trajectory[-1]-anchor
    return jnp.hstack([start_residual, end_residual])

def ineq_constr(params, args):
    """Per-joint step-size residuals for consecutive trajectory rows.

    Each entry is ``(X[t+1] - X[t])**2 - 0.1**2``, i.e. non-positive when the
    corresponding joint moves by at most 0.1 between the two steps. `args` is
    accepted for interface symmetry and not read.
    """
    trajectory = params['X']
    step = trajectory[1:] - trajectory[:-1]
    max_step = 0.1
    return jnp.square(step) - max_step**2


In [93]:
# Number of time steps in the trajectory.
T = 100

# Initial guess: T joint configurations interpolated linearly from q0 to
# q0 + 0.1 (every joint shifted by 0.1 at the final step).
X = jnp.linspace(q0, q0+0.1, num=T)
params = dict(X=X)

# Augmented-Lagrangian solver over the loss and both constraint functions.
solver = AugmentedLagrangeSolver(
    params,
    emmd_loss,
    eq_constr,
    ineq_constr,
    args=args,
    step_size=1e-3,
)
# 

In [94]:
# Run the augmented-Lagrangian solver on the trajectory problem.
# NOTE(review): `max_iter`, `eps` and `alpha` are forwarded to
# AugmentedLagrangeSolver.solve; presumably an iteration cap, a convergence
# tolerance and a penalty growth factor -- confirm against the solver.
# NOTE(review): in the recorded output, both the printed loss and the gradient
# norm increase over iterations 0-19 before the run was interrupted.
solver.solve(max_iter=15_000, eps=1e-3, alpha=1+1e-4)

iter  0  loss  0.5871675  grad l2 norm  0.46018854
iter  1  loss  0.62646693  grad l2 norm  0.64370817
iter  2  loss  0.6644217  grad l2 norm  0.8517197
iter  3  loss  0.6996696  grad l2 norm  1.066953
iter  4  loss  0.7321252  grad l2 norm  1.2830049
iter  5  loss  0.7619387  grad l2 norm  1.4972678
iter  6  loss  0.7892666  grad l2 norm  1.7085488
iter  7  loss  0.8142911  grad l2 norm  1.9162997
iter  8  loss  0.8371874  grad l2 norm  2.1202693
iter  9  loss  0.85811675  grad l2 norm  2.3203523
iter  10  loss  0.87722415  grad l2 norm  2.5165217
iter  11  loss  0.8946333  grad l2 norm  2.7087927
iter  12  loss  0.9104439  grad l2 norm  2.8972018
iter  13  loss  0.9247347  grad l2 norm  3.0817952
iter  14  loss  0.9375694  grad l2 norm  3.2626207
iter  15  loss  0.94900393  grad l2 norm  3.439729
iter  16  loss  0.95908844  grad l2 norm  3.6131694
iter  17  loss  0.9678657  grad l2 norm  3.782988
iter  18  loss  0.97537136  grad l2 norm  3.9492266
iter  19  loss  0.98163617  grad l2 

KeyboardInterrupt: 

In [14]:
# sol = solver.solution
# X = sol['X']
# vertices = vmap(get_fk)(X).translation().astype(np.float32).T
# scene.vis['lines_segments'].set_object(mc_geom.Line(mc_geom.PointsGeometry(vertices), mc_geom.MeshLambertMaterial(color=0xff0000)))
# ee_frame = scene.vis['ee']
# ee_frame.set_object(mc_geom.triad(scale=0.2))
# for q in X:
#     _tf = np.array(get_fk(q).as_matrix()).astype(np.float64)
#     ee_frame.set_transform(_tf)
#     for i, jt in enumerate(q):
#         robot[i] = jt
#     time.sleep(0.01)
#     scene.render()

In [95]:
# Start a fresh scene for the result and fetch the solver's trajectory.
scene = Scene()
sol = solver.solution
X = sol['X']

# End-effector positions along the trajectory, drawn as a red line.
vertices = vmap(get_fk)(X).translation().astype(np.float32).T
scene.vis['lines_segments'].set_object(mc_geom.Line(mc_geom.PointsGeometry(vertices), mc_geom.MeshLambertMaterial(color=0xff0000, linewidth=2)))

# Resolve the robot description once instead of building two PandaLoader
# instances for every snapshot.
panda_loader = PandaLoader()
panda_urdf_path = panda_loader.df_path
panda_mesh_folder = Path(panda_loader.model_path).parent.parent

# Show roughly five evenly spaced configurations. The stride is derived from
# the trajectory itself (not the global `T`) and clamped to at least 1 so
# that short trajectories do not produce a zero slice step.
snapshot_stride = max(1, X.shape[0] // 5)

robot_img = []
for i, q in enumerate(X[::snapshot_stride]):
    # _tf = np.array(get_fk(q).as_matrix()).astype(np.float64)
    # ee_frame.set_transform(_tf)
    _robot = Robot(
        name='robot{}'.format(i),
        urdf_path=panda_urdf_path,
        mesh_folder_path=panda_mesh_folder,
        opacity=0.75,
    )
    robot_img.append(_robot)
    scene.add_robot(_robot)
    for j, jt in enumerate(q):
        _robot[j] = jt

# Render the scene with the trajectory line and all robot snapshots.
scene.render()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7012/static/


In [None]:
# # ps.set_up_dir("z_up")
# # ps.set_front_dir("neg_y_front")
# ps.init()


# ps_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
# # ps_points = ps.register_point_cloud("sampled points", args['points'][:,:3])
# # ps_mesh.add_scalar_quantity("face vals",vmap(info_distr, in_axes=(0,))(mesh.vertices))

# # ps_points.add_scalar_quantity("results", args['P_XI'])

# # ps_bunny_mesh.set_transparency(0.8)
# # ps_bunny_mesh.add_scalar_quantity('info_distr', mesh_func.func_vals, defined_on='vertices', cmap='blues')

# ps_traj         = ps.register_curve_network("trajectory", vmap(get_fk)(X).translation() , edges="line")


# for _ in range(1000):
#     solver.solve(eps=1e-5, max_iter=10)
#     sol = solver.solution
#     X = sol['X']

#     ps_traj.update_node_positions(vmap(get_fk)(X).translation() )
#     # ps_traj.add_vector_quantity("vec img", X[:,:3], enabled=True)
#     ps.frame_tick()

#     time.sleep(0.001)

In [None]:
# Call the solver again, overriding only `eps` and leaving every other option
# at the solver's default.
# NOTE(review): whether this resumes from the current solution or restarts is
# determined by AugmentedLagrangeSolver -- confirm before relying on it.
solver.solve(eps=1e-3)

In [22]:
# Fetch the solver's current solution; its 'X' entry is the trajectory read by
# the plotting cells.
sol = solver.solution

In [64]:
# Trajectory taken from the most recently fetched solution.
X = sol['X']

ps.init()

# Bunny surface, shaded by the information density evaluated on its vertices.
ps_bunny_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
# ps_bunny_mesh.set_transparency(0.8)
_vertex_density = vmap(info_distr, in_axes=(0,))(mesh.vertices)
ps_bunny_mesh.add_scalar_quantity(
    'info_distr',
    _vertex_density,
    defined_on='vertices',
    cmap='blues',
)

# End-effector positions along the trajectory, joined into a polyline.
_ee_positions = vmap(get_fk)(X).translation()
ps_traj         = ps.register_curve_network("trajectory", _ee_positions, edges="line")

ps.show()

[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 4.1 Metal - 88
