In [2]:
import sys 
sys.path.append('../')
import jax 
import jax.numpy as jnp 
import numpy as np
from jaxlie import SE3, SO3
from jax.random import normal
from jax import grad, hessian, vmap, pmap
from jax.flatten_util import ravel_pytree
from jax.lax import scan
from functools import partial
import jax.random as jax_random
import matplotlib.pyplot as plt 
from IPython.display import clear_output
from plyfile import PlyData
import open3d as o3d 
import trimesh as tm
import trimesh
import time
import jaxopt
import polyscope as ps


import meshcat
import meshcat.geometry as mc_geom
from robomeshcat import Object, Robot, Scene



from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver
# from ergodic_mmd.aug_lagrange_solver import AugmentedLagrangeSolver

import adam
from adam.jax import KinDynComputations

In [3]:
# Load the submarine surface mesh from its PLY file, scale it by 0.001
# (presumably mm -> m — confirm against the asset) and rotate it by pi/2
# about the x-axis. Sampling/optimisation below works on its convex hull.
plydata = PlyData.read('../assets/submarine.ply')
vertices = np.column_stack([plydata['vertex'][axis] for axis in ('x', 'y', 'z')])
faces = np.vstack(plydata['face']['vertex_indices'])

mesh = tm.Trimesh(vertices=vertices, faces=faces)
mesh.apply_scale(0.001)
mesh.apply_transform(SE3.from_rotation(SO3.from_x_radians(np.pi / 2)).as_matrix())

# Convex hull of the transformed mesh; used for surface sampling and display.
ds_mesh = mesh.convex_hull

<robomeshcat.object.Object at 0x2b286bb50>

In [86]:
# Show the full-resolution mesh in a meshcat scene (URL printed below).
# NOTE(review): the object is still registered as 'red_sphere' although it is
# the submarine mesh drawn in dark grey; the name is kept as-is.
scene = Scene()
exp_obj = trimesh.exchange.obj.export_obj(mesh)
mc_obj = Object(
    mc_geom.ObjMeshGeometry.from_stream(trimesh.util.wrap_as_stream(exp_obj)),
    name='red_sphere',
    opacity=0.5,
    color=[.1, .1, .1],
)
scene.add_object(mc_obj)


You can open the visualizer by visiting the following URL:
http://127.0.0.1:7007/static/


In [78]:
# Sample target points on the convex hull and build the solver arguments.
num_points = 200  # number of surface samples; tune as needed
points, face_indices = tm.sample.sample_surface(ds_mesh, num_points, seed=0)

# Normal of the face each sample was drawn from.
normals = ds_mesh.face_normals[face_indices]

# Uniform information density over the samples, normalised to sum to 1.
info_distr = lambda x: 1.0
P_XI = vmap(info_distr, in_axes=(0,))(points)
P_XI = P_XI / jnp.sum(P_XI)

h = 0.1  # kernel bandwidth read by emmd_loss as args['h']
args = {
    'h': h,
    # targets are offset 0.1 along their face normals
    'points': points + 0.1 * normals,
    'normals': normals,
    'P_XI': P_XI,
}

In [79]:
def RBF_kernel(x, xp, h=0.01):
    """Gaussian (RBF) kernel ``exp(-||x - xp||^2 / h)``.

    Args:
        x, xp: points of equal shape.
        h: bandwidth; larger values give a wider kernel.

    Returns:
        Scalar similarity in (0, 1], equal to 1 when ``x == xp``.
    """
    squared_distance = jnp.sum((x - xp) ** 2)
    return jnp.exp(-squared_distance / h)
# def camera_view_penalty(x, xp, h=0.01):
#     w, v        = jnp.split(x,2)
#     norm, root  = jnp.split(xp, 2)
#     return jnp.exp(
#         -jnp.sum((root-v)**2)/h #- 1e-1*jnp.sum((norm-w)**2)
#     ) * jnp.exp(
#         -1e-1*jnp.sum((norm-w)**2)
#     )

def camera_view_penalty(g, v, h=0.01):
    """Signed viewing score of pose ``g`` with respect to one surface sample.

    Args:
        g: pose object exposing ``translation()`` and ``rotation().apply(vec)``
            (an ``SE3`` in this notebook).
        v: concatenation ``[point, normal]``; split in half by ``jnp.split``.
        h: bandwidth of the Gaussian proximity factor.

    Returns:
        ``exp(-||point - g.translation()||^2 / h) * dot(-normal, z_cam)``, where
        ``z_cam`` is the pose's body z-axis ``[0, 0, 1]`` rotated into the world.
        It is positive when the body z-axis points against the surface normal.
    """
    optical_axis = g.rotation().apply(jnp.array([0., 0., 1.]))
    target_point, target_normal = jnp.split(v, 2)
    proximity = jnp.exp(-jnp.sum((target_point - g.translation()) ** 2) / h)
    alignment = jnp.dot(-target_normal, optical_axis)
    return proximity * alignment

def create_kernel_matrix(kernel):
    return vmap(vmap(kernel, in_axes=(0, None, None)), in_axes=(None, 0, None))

In [80]:
# Pairwise evaluators built from the scalar kernels above:
# KernelMatrix(A, B, h)[i, j] = RBF_kernel(A[j], B[i], h), and likewise for the view score.
KernelMatrix, camera_view_matrix = (
    create_kernel_matrix(RBF_kernel),
    create_kernel_matrix(camera_view_penalty),
)


def emmd_loss(params, args):
    """Ergodic-MMD objective for a camera trajectory over the sampled target points.

    Args:
        params: dict with key ``'X'``, one se(3) tangent vector per trajectory knot
            (mapped to poses with ``SE3.exp``); presumably shape (T, 6) — confirm.
        args: dict with keys
            ``'h'``       -- bandwidth passed to both kernel matrices,
            ``'points'``  -- target points, one row per sample,
            ``'normals'`` -- face normal for each target point,
            ``'P_XI'``    -- weight of each target point,
            ``'w_smooth'`` (optional, default 1.0) -- weight of the smoothness term,
            ``'w_view'``   (optional, default 1.0) -- weight of the viewing reward.

    Returns:
        Scalar loss
            sum_{t,s} k(p_t, p_s) / T^2
          - 2 * sum_t sum_i P_XI[i] k(p_t, x_i) / T
          + w_smooth * mean((X[t+1] - X[t])^2)
          - w_view * sum_t sum_i P_XI[i] c(g_t, [x_i, n_i]) / T
        where p_t are knot positions, k is ``RBF_kernel`` and c is ``camera_view_penalty``.
    """
    X = params['X']
    T = X.shape[0]

    g = vmap(SE3.exp)(X)
    p = g.translation()

    h = args['h']
    points = args['points']
    norms = args['normals']
    P_XI = args['P_XI']
    w_smooth = args.get('w_smooth', 1.0)
    w_view = args.get('w_view', 1.0)

    # Kernel matrices are (len(second arg), len(first arg)); ``P_XI @ M`` weights
    # each row by its target and leaves one value per knot.
    self_term = jnp.sum(KernelMatrix(p, p, h)) / (T ** 2)
    cross_term = 2 * jnp.sum(P_XI @ KernelMatrix(p, points, h)) / T
    # A single knot has no finite differences; avoid mean-of-empty (NaN).
    smooth_term = jnp.mean((X[1:] - X[:-1]) ** 2) if T > 1 else 0.0
    view_matrix = camera_view_matrix(g, jnp.hstack([points, norms]), h)
    view_term = jnp.sum(P_XI @ view_matrix) / T

    return self_term - cross_term + w_smooth * smooth_term - w_view * view_term

def eq_constr(params, args):
    """Equality-constraint residuals passed to AugmentedLagrangeSolver.

    Currently a placeholder that returns a single zero, so no equality constraint
    is active. ``params`` and ``args`` are accepted to match the solver's calling
    convention but are not read.
    """
    # An earlier variant pinned both trajectory endpoints to a start pose q0:
    #     jnp.hstack([params['X'][0] - args['q0'], params['X'][-1] - args['q0']])
    return jnp.zeros(1)

def ineq_constr(params, args):
    """Inequality-constraint residuals passed to AugmentedLagrangeSolver.

    Currently a placeholder that returns a single zero, so no inequality constraint
    is active. ``params`` and ``args`` are accepted to match the solver's calling
    convention but are not read.
    """
    # An earlier variant bounded the per-step change of the trajectory:
    #     jnp.square(params['X'][1:] - params['X'][:-1]) - 0.01**2
    return jnp.zeros(1)

In [81]:
# Initial guess: T knots on a straight line in se(3) tangent coordinates, every
# component running from -0.1 to 0.1 (emmd_loss maps each knot through SE3.exp).
T = 100
X = jnp.linspace(
    start=-0.1 * jnp.ones(6),
    stop=0.1 * jnp.ones(6),
    num=T,
)
params = {'X': X}

# Default solver settings (an earlier run also passed max_stepsize=3e-1).
solver = AugmentedLagrangeSolver(params, emmd_loss, eq_constr, ineq_constr, args=args)

In [82]:
# Run the augmented-Lagrangian solve; it prints the iteration index and a scalar
# per step (presumably the objective — confirm against AugmentedLagrangeSolver).
solver.solve(eps=1e-4, max_iter=50_000)

0 0.24664223567435373
1 0.24692779527754885
2 0.24702373387999513
3 0.24693551653530696
4 0.24666921511094725
5 0.24623128351674237
6 0.24562875861536476
7 0.24486916339135426
8 0.24396041371996696
9 0.24291072805312258
10 0.24172854125292326
11 0.24042242356858562
12 0.239001005520767
13 0.237472909235925
14 0.2358466865685306
15 0.2341307641649507
16 0.23233339546228632
17 0.23046261947966423
18 0.22852622614847953
19 0.22653172784097358
20 0.22448633669169343
21 0.22239694726169984
22 0.22027012406846844
23 0.21811209349272245
24 0.2159287395744122
25 0.2137256032212837
26 0.2115078843727011
27 0.2092804466865704
28 0.2070478243465639
29 0.20481423061884207
30 0.20258356782081743
31 0.200359438398172
32 0.19814515683948344
33 0.19594376218981255
34 0.1937580309549787
35 0.19159049021668598
36 0.18944343080495044
37 0.18731892039831652
38 0.18521881644410912
39 0.18314477881047694
40 0.1810982820993239
41 0.17908062756450174
42 0.1770929545929882
43 0.1751362517183372
44 0.1732113671

In [83]:
# Show the optimised trajectory in meshcat: a red polyline through the knot
# positions, then a coordinate triad stepped through every knot pose.
sol = solver.solution
X = sol['X']
traj_poses = vmap(SE3.exp)(X)

# (3, N) float32 layout for PointsGeometry, as before. Stored under its own name
# so the mesh `vertices` loaded from the PLY file is not overwritten.
traj_points = np.asarray(traj_poses.translation(), dtype=np.float32).T
scene.vis['lines_segments'].set_object(
    mc_geom.Line(
        mc_geom.PointsGeometry(traj_points),
        mc_geom.MeshLambertMaterial(color=0xff0000),
    )
)

# 4x4 homogeneous transforms for all knots, computed once outside the animation loop.
traj_tfs = np.asarray(vmap(lambda q: SE3.exp(q).as_matrix())(X), dtype=np.float64)

ee_frame = scene.vis['ee']
ee_frame.set_object(mc_geom.triad(scale=0.2))
for _tf in traj_tfs:
    ee_frame.set_transform(_tf)
    time.sleep(0.1)
    scene.render()

In [8]:
# Live polyscope view, intended to animate the trajectory while the solver iterates.
# FIXME(review): this cell cannot run as written — `X_init`, `solver_state`,
# `bounds` and `unflatten_X` are not defined anywhere in this notebook.
# NOTE(review): `solver.update(params=..., state=..., hyperparams_proj=..., args=...)`
# uses a jaxopt-style signature; confirm AugmentedLagrangeSolver exposes it.
ps.init()

ps_bunny_mesh = ps.register_surface_mesh("bunny", ds_mesh.vertices, ds_mesh.faces)
# Sample points drawn offset 0.2 along their normals (args['points'] uses 0.1).
ps_bunny_pcl = ps.register_point_cloud("bunny points", points+normals*0.2)


ps_bunny_mesh.set_transparency(0.8)
ps_traj         = ps.register_curve_network("trajectory", X_init * 1.25, edges="line")

for _ in range(1000):
    (sol, solver_state) = solver.update(
                        params=sol, 
                        state=solver_state, 
                        hyperparams_proj=bounds, 
                                args=args)
    X = unflatten_X(sol)
    # NOTE(review): slices X as [direction(3), position(3)], but emmd_loss treats
    # each row of X as an se(3) tangent vector passed through SE3.exp.
    ps_traj.update_node_positions(X[:,3:])
    ps_traj.add_vector_quantity("vec img", X[:,:3], enabled=True)
    ps.frame_tick()

    time.sleep(0.001)

[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 4.1 Metal - 88


In [8]:
# Inspect the optimised trajectory in polyscope: the convex-hull mesh, the knot
# positions as a curve, and one polyscope camera per knot. ps.show() blocks until
# the window is closed.
ps.init()

ps_bunny_mesh = ps.register_surface_mesh("bunny", ds_mesh.vertices, ds_mesh.faces)
ps_bunny_mesh.set_transparency(0.8)

# X holds se(3) tangent vectors (emmd_loss maps them through SE3.exp), so take
# positions and viewing directions from the poses rather than slicing X. The viewing
# direction is the body z-axis, the same axis camera_view_penalty scores.
traj_poses = vmap(SE3.exp)(X)
cam_positions = np.asarray(traj_poses.translation())
cam_look_dirs = np.asarray(
    vmap(lambda q: SE3.exp(q).rotation().apply(jnp.array([0., 0., 1.])))(X)
)

ps_traj = ps.register_curve_network("trajectory", cam_positions, edges="line")
ps_traj.add_vector_quantity("vec img", cam_look_dirs, enabled=True)

intrinsics = ps.CameraIntrinsics(fov_vertical_deg=60, aspect=2)
for i, (_root, _dir) in enumerate(zip(cam_positions, cam_look_dirs)):
    extrinsics = ps.CameraExtrinsics(root=_root, look_dir=_dir, up_dir=(0., 1., 0.))
    # `cam_params`, not `params`: the global `params` holds the solver's variables.
    cam_params = ps.CameraParameters(intrinsics, extrinsics)
    ps.register_camera_view('cam{}'.format(i), cam_params)
ps.show()

KeyboardInterrupt: 

In [None]:
def lie_RBF(x, xp, h=0.01):
    """Experimental kernel between x = [w, v] and a surface sample xp = [norm, root].

    ``w`` is mapped to a rotation with ``SO3.exp``. The result is
    ``exp(-(A + B) / h)``, where A and B are squared-norm terms built from
    ``norm``, ``root``, ``v`` and the rotation.

    FIXME(review): ``v1`` is not defined in this function or anywhere in the visible
    notebook, so calling ``lie_RBF`` raises NameError. The intended operand (perhaps
    ``v``, or a fixed camera axis) needs confirming.
    NOTE(review): mixes ``np.sum`` with ``jnp.sum``; ``jnp.sum`` is the JAX-native
    choice inside traced code. The function is not called anywhere in the visible
    notebook.
    """
    w, v        = jnp.split(x,2)
    norm, root  = jnp.split(xp, 2)
    _R = SO3.exp(w)
    return jnp.exp(
        -(jnp.sum((norm+_R@v1)**2) + np.sum((norm+root-v)**2))/h
    )

def create_kernel_matrix(kernel):
    """Return ``M(A, B, h)`` with ``M[i, j] = kernel(A[j], B[i], h)``.

    NOTE(review): identical re-definition of the earlier ``create_kernel_matrix``;
    one of the two copies can be removed.
    """
    per_row = vmap(kernel, in_axes=(0, None, None))
    return vmap(per_row, in_axes=(None, 0, None))