In [2]:
import sys 
sys.path.append('../')
import jax 
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp 
from jaxlie import SE3, SO3
import numpy as np
from jax import vmap
import jax.random as jax_random
# from ergodic_mmd.aug_lagrange_jaxopt import AugmentedLagrangeSolver
from ergodic_mmd.aug_lagrange_solver import AugmentedLagrangeSolver

import adam
from adam.jax import KinDynComputations

from plyfile import PlyData
import trimesh as tm

import matplotlib.pyplot as plt 
from IPython.display import clear_output
import polyscope as ps
import time

In [14]:
plydata = PlyData.read('../assets/sphere.ply')
verts = np.vstack((
    plydata['vertex']['x'],
    plydata['vertex']['y'],
    plydata['vertex']['z']
)).T
faces = np.array(np.vstack(plydata['face']['vertex_indices']))

mesh = tm.Trimesh(vertices=verts, faces=faces)

num_points = 1000  # Change this number based on your requirement
points, face_indices = tm.sample.sample_surface(mesh, num_points)

h = 0.1
seed = 0 
key = jax_random.PRNGKey(seed)
key, subkey = jax_random.split(key, 2)
_std = 0.1*jax_random.normal(subkey, shape=(num_points,1))
_std = _std - _std.min()+0.1
_points = points + _std*mesh.face_normals[face_indices]
_points = mesh.vertices + 0.2 * mesh.vertex_normals

info_distr = lambda x: (jnp.sin(x[0]*x[1])+1)*jnp.exp(-60*(x[0]-0.5)**2 - 10*(x[1]-0.5)**2) + 2*jnp.exp(-30*(x[0]-1.)**2 - 30*(x[1]-0.5)**2 - 20*(x[2])**2)
P_XI = vmap(info_distr, in_axes=(0,))(mesh.vertices)
P_XI = P_XI-P_XI.min()+0.01
P_XI = P_XI/jnp.sum(P_XI)

# args = {'h' : h, 'points' : jnp.hstack([_points,mesh.vertex_normals]), 'P_XI' : P_XI, 
#         'x0' : jnp.hstack([_points[0], jnp.zeros(3)]), 'xf' : jnp.hstack([_points[-1], jnp.zeros(3)])}
args = {'h' : h, 'points' : jnp.hstack([_points,mesh.vertex_normals]), 'P_XI' : P_XI, 
        'x0' : jnp.hstack([_points[0], jnp.zeros(3)]), 'xf' : jnp.hstack([_points[-1], jnp.zeros(3)])}

In [15]:
# ps.init()

# ps_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
# ps_mesh.add_scalar_quantity("face vals",vmap(info_distr, in_axes=(0,))(mesh.vertices))
# ps_points = ps.register_point_cloud("sampled points", args['points'])
# ps_points.add_scalar_quantity("results", args['P_XI'])


# ps.show()

In [16]:
class ThreeDAirCraftModel(object):
    def __init__(self) -> None:
        self.dt = 0.1
        self.n = 5
        self.m = 3 
        def dfdt(x, u):
            # v  = np.clip(u[0], 0.1, 5)
            # w1 = np.clip(u[1], -15,15)
            # w2 = np.clip(u[2], -15,15) 
            v  = u[0]
            w1 = u[1]
            w2 = u[2]
            w3 = u[3]
            return jnp.array([
                v * jnp.cos(x[4]) * jnp.cos(x[3]),
                v * jnp.sin(x[4]) * jnp.cos(x[3]),
                v * jnp.sin(x[3]),
                w1, 
                w2,
                w3
            ]) 
        def f(x, u):
            return x + self.dt * dfdt(x, u)
        self.f      = f
        self.dfdt   = dfdt

robot_model = ThreeDAirCraftModel()

def f_lie_constr(twist2, twist1, U):
    dtwist = jnp.array([U[0],U[1],U[2],0.,0.,0.])
    # dtwist = jnp.array([U[0],0.,0.,U[1],U[2], U[3]])
    # return SE3.log(SE3.exp(twist2).inverse()@SE3.exp(twist1+dtwist))
    return twist2 - (twist1 + dtwist)

In [18]:
# def RBF_kernel(x, xp, h=0.01):
#     return jnp.exp(
#         -jnp.sum((x-xp)**2)/h
#     )
def RBF_kernel(x, xp, h=0.01):
    return jnp.exp(
        -jnp.sum(SE3.log(SE3.exp(x).inverse()@SE3.exp(xp))**2)/h
    )
def create_kernel_matrix(kernel):
    return vmap(vmap(kernel, in_axes=(0, None, None)), in_axes=(None, 0, None))

KernelMatrix = create_kernel_matrix(RBF_kernel)
def emmd_loss(params, args):
    X = params['X']
    T = X.shape[0]
    p, w = jnp.split(X, 2, axis=1)
    h = args['h']
    points    = args['points']
    P_XI      = args['P_XI']
    return np.sum(KernelMatrix(X, X, h))/(T**2) \
            - 2 * np.sum(P_XI @ KernelMatrix(X, points, h))/T

def eq_constr(params, args):
    X = params['X']
    U = params['U']
    # p, w = jnp.split(X, 2, axis=1)
    return jnp.vstack([
        (X[0]-args['x0']),
        vmap(f_lie_constr)(X[1:], X[:-1], U[:-1]),
        (X[-1]-args['xf'])
    ])

def ineq_constr(params, args):
    return jnp.array(0.)


In [19]:
T = 60
# X = jnp.linspace(mesh.bounds[0], mesh.bounds[1], num=T)
X = jnp.concatenate(
    [ 
        jnp.linspace(args['points'][0], args['points'][-1], num=T),
        jnp.linspace(mesh.bounds[0], mesh.bounds[1], num=T)*0
    ],
    axis=1
)

U = jnp.zeros((T, 4)) # forward v, theta dot, phi dot


params = {'X' : X, 'U' : U}
# solver = AugmentedLagrangeSolver(params, emmd_loss, eq_constr, ineq_constr, max_stepsize=1e-1, args=args)
solver = AugmentedLagrangeSolver(params, emmd_loss, eq_constr, ineq_constr, args=args)


TypeError: sub got incompatible shapes for broadcasting: (9,), (6,).

In [9]:
ps.init()

ps_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
ps_points = ps.register_point_cloud("sampled points", args['points'])
ps_mesh.add_scalar_quantity("face vals",vmap(info_distr, in_axes=(0,))(mesh.vertices))

ps_points.add_scalar_quantity("results", args['P_XI'])

# ps_bunny_mesh.set_transparency(0.8)
# ps_bunny_mesh.add_scalar_quantity('info_distr', mesh_func.func_vals, defined_on='vertices', cmap='blues')

ps_traj         = ps.register_curve_network("trajectory", X[:,:3], edges="line")


for _ in range(1000):
    solver.solve(eps=1e-5, max_iter=10)
    sol = solver.solution
    X = sol['X']

    ps_traj.update_node_positions(X[:,:3])
    # ps_traj.add_vector_quantity("vec img", X[:,:3], enabled=True)
    ps.frame_tick()

    time.sleep(0.001)

[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 4.6 (Core Profile) Mesa 21.2.6


AssertionError: 

In [None]:
sol = solver.solution

In [None]:
X = sol['X']

ps.init()

ps_mesh = ps.register_surface_mesh("bunny", mesh.vertices, mesh.faces)
ps_points = ps.register_point_cloud("sampled points", args['points'])
ps_mesh.add_scalar_quantity("face vals",vmap(info_distr, in_axes=(0,))(mesh.vertices))

ps_points.add_scalar_quantity("results", args['P_XI'])

# ps_bunny_mesh.set_transparency(0.8)
# ps_bunny_mesh.add_scalar_quantity('info_distr', mesh_func.func_vals, defined_on='vertices', cmap='blues')

ps_traj         = ps.register_curve_network("trajectory", X , edges="line")

ps.show()