In [48]:
import sys 
sys.path.append('../')
import jax 
import jax.numpy as jnp 
import numpy as np
from jaxlie import SE3, SO3
from jax.random import normal
from jax import grad, hessian, vmap, pmap
from jax.flatten_util import ravel_pytree
from jax.lax import scan
from functools import partial
import jax.random as jax_random
import matplotlib.pyplot as plt 
from IPython.display import clear_output
from plyfile import PlyData
import open3d as o3d 
import trimesh as tm
import trimesh
import time
import jaxopt
import polyscope as ps


import meshcat
import meshcat.geometry as mc_geom

from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver
# from ergodic_mmd.aug_lagrange_solver import AugmentedLagrangeSolver

import adam
from adam.jax import KinDynComputations

In [49]:
from robomeshcat import Object, Robot, Scene
from example_robot_data.robots_loader import PandaLoader
from pathlib import Path

"Create a scene that stores all objects and robots and has rendering capability"
scene = Scene()

robot = Robot(urdf_path=PandaLoader().df_path, mesh_folder_path=Path(PandaLoader().model_path).parent.parent)
scene.add_robot(robot)
"Render the initial scene"
scene.render()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7004/static/


In [50]:
# Load the bunny mesh provided by Open3D and build a trimesh from its PLY vertex/face data.
dataset = o3d.data.BunnyMesh()
plydata = PlyData.read(dataset.path)
# plydata = PlyData.read('../assets/sphere.ply')
verts = np.vstack((
    plydata['vertex']['x'],
    plydata['vertex']['y'],
    plydata['vertex']['z'],
)).T
faces = np.vstack(plydata['face']['vertex_indices'])  # vstack already returns a fresh ndarray

# Place the mesh in the robot workspace: scale by 1.5, rotate +pi/2 about x,
# then translate by (0.5, 0, 0.1) in world coordinates.
mesh = tm.Trimesh(vertices=verts, faces=faces)
mesh.apply_scale(1.5)
mesh.apply_transform(
    SE3.from_rotation(SO3.from_x_radians(np.pi/2)).as_matrix()
)
mesh.apply_translation(np.array([0.5, 0., 0.1]))

# Add a semi-transparent red copy of the mesh to the meshcat scene.
exp_obj = trimesh.exchange.obj.export_obj(mesh)
mc_obj = mc_geom.ObjMeshGeometry.from_stream(trimesh.util.wrap_as_stream(exp_obj))
mc_obj = Object(mc_obj, name='red_sphere', opacity=0.5, color=[1., 0., 0.])
scene.add_object(mc_obj)

# Sample target points on the surface. No seed is passed, so the sample may differ between runs.
num_points = 200  # Change this number based on your requirement
points, face_indices = tm.sample.sample_surface(mesh, num_points)
point_normals = mesh.face_normals[face_indices]  # normal of the face each sample came from

# Unnormalized Gaussian-shaped importance over 3D points, peaked at (0.4, 0, 0.4).
info_distr = lambda x: jnp.exp(-80*(x[0]-0.4)**2 - 20*x[1]**2 - 80*(x[2]-0.4)**2)
P_XI = vmap(info_distr, in_axes=(0,))(points)
P_XI_total = jnp.sum(P_XI)
# Guard against every sample underflowing to 0, which would make the weights NaN.
assert float(P_XI_total) > 0, "info_distr is ~0 on every sampled point; weights cannot be normalized"
P_XI = P_XI / P_XI_total

h = 0.1              # RBF bandwidth used by the MMD terms
normal_offset = 0.05  # targets are pushed this far off the surface along the face normal
args = {'h': h,
        'points': points + normal_offset * point_normals,
        'normals': point_normals,
        'P_XI': P_XI}

In [51]:
model_path = "../assets/panda.urdf"
# The joint list
joints_name_list = [
    'panda_joint1', 'panda_joint2', 'panda_joint3', 'panda_joint4', 
    'panda_joint5', 'panda_joint6', 'panda_joint7'
]

kinDyn = KinDynComputations(model_path, joints_name_list)
# kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION)
w_H_ee = kinDyn.forward_kinematics_fun("panda_grasptarget")
w_H_b = np.eye(4) # base frame 

q0 = jnp.array([0,-np.pi/3,0,-3, 0, 2.5, -0.7853])
args.update({'q0':q0})
# w_H_ee(w_H_b, q0)
for i, jt in enumerate(q0):
    robot[i] = jt
scene.render()

Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link0']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link1']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link2']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link4']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link5']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link6']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_link7']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_hand']/collision[1]
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_leftfinger']/collision[1]
Unknown tag "contact" in /robot[@name='panda']/link[@name='panda_leftfinger']
Unknown tag "material" in /robot[@name='panda']/link[@name='panda_rightfinger']/collision[1]
Unknown tag "contact" in /robot[@name='pan

In [52]:
def get_fk(q):
    """Return the end-effector pose for joint vector ``q`` as a jaxlie ``SE3``.

    Evaluates the module-level FK function ``w_H_ee`` at base pose ``w_H_b``
    and wraps the resulting homogeneous matrix.
    """
    ee_pose_matrix = w_H_ee(w_H_b, q)
    return SE3.from_matrix(ee_pose_matrix)

def RBF_kernel(x, xp, h=0.01):
    """Squared-exponential kernel exp(-||x - xp||^2 / h).

    Note that ``h`` divides the squared distance directly (no factor of 2).
    """
    sq_dist = jnp.sum((x - xp) ** 2)
    return jnp.exp(-sq_dist / h)

def camera_view_penalty(g, v, h=0.01):
    """Score how well pose ``g`` views a surface sample ``v``.

    ``v`` is split into two halves ``[point, normal]``. The score is
    exp(-||point - p||^2 / h - ||-normal - z||^2 / h), where ``p`` is the pose
    translation and ``z`` is the pose's local +z axis expressed in the world.
    It equals 1 when the frame sits on the point with +z pointing against the normal.
    """
    cam_pos = g.translation()
    cam_axis = g.rotation().apply(jnp.array([0., 0., 1.]))
    surf_pt, surf_n = jnp.split(v, 2)
    pos_err = jnp.sum((surf_pt - cam_pos) ** 2)
    dir_err = jnp.sum((-surf_n - cam_axis) ** 2)
    return jnp.exp(-pos_err / h - dir_err / h)

# def RBF_kernel(g, v, h=0.01):
#     p = g.translation()
#     rot = g.rotation()
#     w = rot.apply(jnp.array([0.,0.,1.]))
#     root, norm  = jnp.split(v, 2)
#     return jnp.exp(
#         -jnp.sum((root-p)**2)/h #- 1e-1*jnp.sum((norm-w)**2)
#     ) * jnp.exp(
#         -10*jnp.sum((-norm-w)**2)
#     )

def create_kernel_matrix(kernel):
    """Lift a pairwise ``kernel(a, b, h)`` to a batched Gram-matrix function.

    The returned function maps ``(A, B, h)`` to ``K`` with
    ``K[i, j] = kernel(A[j], B[i], h)``, i.e. shape ``(len(B), len(A))``.
    ``h`` is shared (not batched).
    """
    over_first = vmap(kernel, in_axes=(0, None, None))
    return vmap(over_first, in_axes=(None, 0, None))

# Batched (Gram-matrix) versions of the per-pair functions above.
KernelMatrix = create_kernel_matrix(RBF_kernel)
camera_view_matrix = create_kernel_matrix(camera_view_penalty)
def emmd_loss(params, args):
    """Ergodic-MMD objective for a joint-space trajectory.

    Args:
        params: dict with 'X', the (T, n_joints) joint trajectory being optimized.
        args: dict with
            'h'       -- RBF bandwidth for the MMD terms,
            'q0'      -- reference joint configuration for the regularizer,
            'points'  -- (N, 3) target points,
            'normals' -- (N, 3) normals paired with 'points',
            'P_XI'    -- (N,) target weights (normalized to sum 1 in this notebook),
          and optionally
            'view_weight' (default 10.0) -- weight of the camera-view reward,
            'view_h'      (default 0.1)  -- bandwidth of the camera-view kernel,
            'reg_weight'  (default 1e-4) -- weight of the stay-near-q0 regularizer.

    Returns:
        Scalar: self_term - 2 * cross_term - view_weight * view_term
        + reg_weight * mean((X - q0)**2). The first two terms are the MMD^2
        between trajectory and weighted targets, up to the constant
        target-target term.
    """
    X = params['X']
    T = X.shape[0]
    h = args['h']
    q0 = args['q0']
    points = args['points']
    norms = args['normals']
    P_XI = args['P_XI']
    view_weight = args.get('view_weight', 10.)
    view_h = args.get('view_h', 0.1)
    reg_weight = args.get('reg_weight', 1e-4)

    g = vmap(get_fk)(X)
    p = g.translation()

    # Use jnp (not np) reductions: this function is traced/differentiated by JAX.
    self_term = jnp.sum(KernelMatrix(p, p, h)) / (T**2)
    # KernelMatrix(p, points, h) is (N, T); weighting rows by P_XI gives (T,).
    cross_term = jnp.sum(P_XI @ KernelMatrix(p, points, h)) / T
    view_matrix = camera_view_matrix(g, jnp.hstack([points, norms]), view_h)
    view_term = jnp.sum(P_XI @ view_matrix) / T
    reg_term = jnp.mean(jnp.square(X - q0))
    # Disabled alternative smoothness term: jnp.mean(jnp.square(X[1:] - X[:-1]))

    return self_term - 2 * cross_term - view_weight * view_term + reg_weight * reg_term

def eq_constr(params, args):
    """Equality residuals pinning both trajectory endpoints to ``args['q0']``.

    Returns the concatenation ``[X[0] - q0, X[-1] - q0]``; it is zero when satisfied.
    """
    traj = params['X']
    home = args['q0']
    start_err = traj[0] - home
    end_err = traj[-1] - home
    return jnp.hstack([start_err, end_err])

def ineq_constr(params, args):
    """Per-joint step-size bound between consecutive waypoints.

    Returns ``(X[t+1] - X[t])**2 - max_step**2`` elementwise, shape (T-1, n_joints);
    the constraint is satisfied where this is <= 0. ``max_step`` is read from
    ``args['max_step']`` when present and otherwise defaults to 0.01 (the
    previously hard-coded value). ``args`` may be None.
    """
    X = params['X']
    max_step = (args or {}).get('max_step', 0.01)
    return jnp.square(X[1:] - X[:-1]) - max_step**2


In [53]:
T = 100  # number of waypoints

# Initial guess: a straight joint-space line from q0 to q0 + 0.1 on every joint.
# The end point deliberately violates eq_constr, which the solver must fix.
X = jnp.linspace(q0, q0 + 0.1, num=T)
params = dict(X=X)

solver = AugmentedLagrangeSolver(params, emmd_loss, eq_constr, ineq_constr, args=args)

In [54]:
# Run the augmented-Lagrangian optimization.
# NOTE(review): in the logged run below, every outer iteration hit jaxopt
# ZoomLineSearch's 1000-iteration cap with stepsizes ~1e-303. The inner line search
# made no progress, the printed value rose from ~0.160 to ~0.225, and the run was
# interrupted. TODO: investigate (e.g. loss/penalty scaling) before trusting the result.
solver.solve(max_iter=5_000, eps=1e-3)

0 0.16036986648848145
INFO: jaxopt.ZoomLineSearch: Iter: 1000, Stepsize: 7.06943815353509e-303, Decrease error: 0.006843806071840589, Curvature error: 0.003330872720651008
1 0.16709587162487632
INFO: jaxopt.ZoomLineSearch: Iter: 1000, Stepsize: 8.166033424396642e-303, Decrease error: 0.006843881364568234, Curvature error: 0.00442684433751629
2 0.17746151225614235
INFO: jaxopt.ZoomLineSearch: Iter: 1000, Stepsize: 9.500687150175778e-303, Decrease error: 0.006843956658049002, Curvature error: 0.005626035192125506
3 0.19087485219997588
INFO: jaxopt.ZoomLineSearch: Iter: 1000, Stepsize: 1.1254816604068813e-302, Decrease error: 0.00684403195228267, Curvature error: 0.006978392069386691
4 0.2067436431810234
INFO: jaxopt.ZoomLineSearch: Iter: 1000, Stepsize: 3.376689030762164e-303, Decrease error: 0.00684410724726924, Curvature error: 0.008470215239167673
5 0.22454799884368268
INFO: jaxopt.ZoomLineSearch: Iter: 1000, Stepsize: 5.752423944466565e-303, Decrease error: 0.0068441825430087125, Cur

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: KeyboardInterrupt: <EMPTY MESSAGE>

At:
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py(2309): _wrapped_callback
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py(1144): __call__
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/profiler.py(340): wrapper
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(1152): _pjit_call_impl_python
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(1196): call_impl_cache_miss
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(1212): _pjit_call_impl
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/core.py(868): process_primitive
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/core.py(388): bind_with_trace
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/core.py(2656): bind
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(167): _python_pjit_helper
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(256): cache_miss
  /Users/ia285/miniconda3/lib/python3.10/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
  /Users/ia285/PlayGround/ergodic_MMD/notebooks/../ergodic_mmd/aug_lagrange_jaxopt.py(50): step
  /Users/ia285/PlayGround/ergodic_MMD/notebooks/../ergodic_mmd/aug_lagrange_jaxopt.py(80): solve
  /var/folders/87/hlqv70yd0n79wj330vzqffx10y2fzc/T/ipykernel_84827/3059930979.py(1): <module>
  /Users/ia285/miniconda3/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3442): run_code
  /Users/ia285/miniconda3/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3382): run_ast_nodes
  /Users/ia285/miniconda3/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3203): run_cell_async
  /Users/ia285/miniconda3/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /Users/ia285/miniconda3/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3000): _run_cell
  /Users/ia285/miniconda3/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2945): run_cell
  /Users/ia285/miniconda3/lib/python3.10/site-packages/ipykernel/zmqshell.py(540): run_cell
  /Users/ia285/miniconda3/lib/python3.10/site-packages/ipykernel/ipkernel.py(417): do_execute
  /Users/ia285/miniconda3/lib/python3.10/site-packages/ipykernel/kernelbase.py(731): execute_request
  /Users/ia285/miniconda3/lib/python3.10/site-packages/ipykernel/kernelbase.py(408): dispatch_shell
  /Users/ia285/miniconda3/lib/python3.10/site-packages/ipykernel/kernelbase.py(501): process_one
  /Users/ia285/miniconda3/lib/python3.10/site-packages/ipykernel/kernelbase.py(512): dispatch_queue
  /Users/ia285/miniconda3/lib/python3.10/asyncio/events.py(80): _run
  /Users/ia285/miniconda3/lib/python3.10/asyncio/base_events.py(1881): _run_once
  /Users/ia285/miniconda3/lib/python3.10/asyncio/base_events.py(595): run_forever
  /Users/ia285/miniconda3/lib/python3.10/site-packages/tornado/platform/asyncio.py(195): start
  /Users/ia285/miniconda3/lib/python3.10/site-packages/ipykernel/kernelapp.py(724): start
  /Users/ia285/miniconda3/lib/python3.10/site-packages/traitlets/config/application.py(846): launch_instance
  /Users/ia285/miniconda3/lib/python3.10/site-packages/ipykernel_launcher.py(17): <module>
  /Users/ia285/miniconda3/lib/python3.10/runpy.py(86): _run_code
  /Users/ia285/miniconda3/lib/python3.10/runpy.py(196): _run_module_as_main


In [55]:
sol = solver.solution
X = sol['X']

# End-effector positions along the trajectory, transposed to a 3 x T float32 array.
vertices = vmap(get_fk)(X).translation().astype(np.float32).T
path_line = mc_geom.Line(
    mc_geom.PointsGeometry(vertices),
    mc_geom.MeshLambertMaterial(color=0xff0000),
)
scene.vis['lines_segments'].set_object(path_line)

# Animate: move a triad to each end-effector pose and update the robot joints.
ee_frame = scene.vis['ee']
ee_frame.set_object(mc_geom.triad(scale=0.2))
for q in X:
    ee_matrix = np.array(get_fk(q).as_matrix()).astype(np.float64)
    ee_frame.set_transform(ee_matrix)
    for joint_idx, joint_val in enumerate(q):
        robot[joint_idx] = joint_val
    time.sleep(0.01)
    scene.render()

In [None]:
# # ps.set_up_dir("z_up")
# # ps.set_front_dir("neg_y_front")
# ps.init()


# ps_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
# # ps_points = ps.register_point_cloud("sampled points", args['points'][:,:3])
# # ps_mesh.add_scalar_quantity("face vals",vmap(info_distr, in_axes=(0,))(mesh.vertices))

# # ps_points.add_scalar_quantity("results", args['P_XI'])

# # ps_bunny_mesh.set_transparency(0.8)
# # ps_bunny_mesh.add_scalar_quantity('info_distr', mesh_func.func_vals, defined_on='vertices', cmap='blues')

# ps_traj         = ps.register_curve_network("trajectory", vmap(get_fk)(X).translation() , edges="line")


# for _ in range(1000):
#     solver.solve(eps=1e-5, max_iter=10)
#     sol = solver.solution
#     X = sol['X']

#     ps_traj.update_node_positions(vmap(get_fk)(X).translation() )
#     # ps_traj.add_vector_quantity("vec img", X[:,:3], enabled=True)
#     ps.frame_tick()

#     time.sleep(0.001)

In [None]:
# Call solve again with only the tolerance set (max_iter left at the solver's default).
# NOTE(review): whether this continues from the previous iterate depends on
# AugmentedLagrangeSolver's internal state — TODO confirm.
solver.solve(eps=1e-3)

In [22]:
# Fetch the solver's current solution dict (trajectory under key 'X').
# NOTE(review): this cell's execution count (In [22]) is lower than the solve cells
# above, so the saved state may be stale. Re-run the notebook top to bottom.
sol = solver.solution

In [64]:
X = sol['X']

ps.init()

# Bunny surface, colored by the (unnormalized) information density at each vertex.
info_at_vertices = vmap(info_distr, in_axes=(0,))(mesh.vertices)
ps_bunny_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
ps_bunny_mesh.add_scalar_quantity('info_distr', info_at_vertices, defined_on='vertices', cmap='blues')

# End-effector trajectory drawn as a polyline.
ee_positions = vmap(get_fk)(X).translation()
ps_traj = ps.register_curve_network("trajectory", ee_positions, edges="line")

ps.show()

[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 4.1 Metal - 88
