In [2]:
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from IPython.display import clear_output
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *
from sdf_world.loss import *

In [3]:
# Panda kinematic model (provides fk_fn, jac_fn, lb/ub, neutral used below);
# PANDA_URDF / PANDA_PACKAGE come from the sdf_world star imports above.
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)

In [4]:
# Scene container with an attached visualizer; show_in_jupyter prints the
# viewer URL (see the cell output below).
world = SDFWorld()
world.show_in_jupyter()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/


In [5]:
# Visual Panda instance in the viewer, drawn with alpha=0.5.
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
# Fix joints 7 and 8 at 0.04 -- presumably the two gripper finger joints held
# open; TODO confirm against the URDF joint order.
panda.reduce_dim([7, 8], [0.04, 0.04])

In [6]:
# Limits of the first 7 joints (the arm); the trailing entries of
# panda_model.lb/ub (presumably the finger joints fixed above) are dropped.
lb, ub = panda_model.lb[:7], panda_model.ub[:7]
# The joint-limit barrier below divides by (q - lb) * (ub - q), so every
# range must be non-degenerate.
assert bool(jnp.all(lb < ub)), "each joint needs lb < ub"

In [7]:
def _joint_limit_potential(q, scale, lower=None, upper=None, gain=0.1):
    """Smooth factor that vanishes as any joint approaches one of its limits.

    Computes ``scale * exp(-sum_i gain * (u_i - l_i)**2 / (4 * (q_i - l_i) * (u_i - q_i)))``.
    Each summand equals ``gain`` at the middle of joint i's range and grows
    without bound as ``q_i`` approaches ``l_i`` or ``u_i``, so the result is
    largest when every joint is mid-range and tends to 0 at the limits.

    Args:
        q: joint configuration, expected strictly inside (lower, upper); outside
            the limits the denominator turns negative and the value is meaningless.
        scale: multiplicative normalisation of the result.
        lower, upper: joint limits; default to the module-level ``lb`` / ``ub``.
        gain: weight of each per-joint barrier term.
    """
    lower = lb if lower is None else lower
    upper = ub if upper is None else upper
    value = jnp.sum(gain * (upper - lower)**2 / (4 * (q - lower) * (upper - q)))
    return jnp.exp(-value) * scale

# Normalise so that the neutral configuration has potential exactly 1.
k = _joint_limit_potential(panda_model.neutral[:7], 1.)
joint_limit_potential = partial(_joint_limit_potential, scale=1/k)

def get_manipulability_limit_penalized(q):
    """Manipulability sqrt(det(J @ J.T)) of ``q``, scaled by ``joint_limit_potential(q)``."""
    jac = panda_model.jac_fn(q)
    joint_limit_scale = joint_limit_potential(q)
    # J @ J.T is positive semi-definite, but round-off near singular
    # configurations can make its determinant slightly negative, which sqrt
    # would turn into NaN; clamp at 0 instead.
    gram_det = jnp.maximum(jnp.linalg.det(jac @ jac.T), 0.0)
    return jnp.sqrt(gram_det) * joint_limit_scale

In [8]:
# Position grid: a regular xyz_res**3 lattice over the bounding cube of the
# workspace sphere (radius ws_r around ws_center), keeping only points that lie
# strictly inside the sphere. Rotation grid: num_rot_grids quaternions from
# super_fibonacci_spiral.
ws_r = 1.
ws_center = jnp.array([0,0,0.5])
xyz_res = 41
num_rot_grids = 1000

# Cube bounds are derived from the sphere so the two can never disagree.
grid_lo = ws_center - ws_r
grid_hi = ws_center + ws_r
xx = jnp.linspace(grid_lo[0], grid_hi[0], xyz_res, endpoint=True)
yy = jnp.linspace(grid_lo[1], grid_hi[1], xyz_res, endpoint=True)
zz = jnp.linspace(grid_lo[2], grid_hi[2], xyz_res, endpoint=True)
X, Y, Z = jnp.meshgrid(xx, yy, zz)
xyz_lattice = jnp.vstack([X.flatten(), Y.flatten(), Z.flatten()]).T
in_ws_sphere = jnp.linalg.norm(xyz_lattice - ws_center, axis=-1) < ws_r
xyz_grids = xyz_lattice[in_ws_sphere]
qtn_grids = super_fibonacci_spiral(num_rot_grids)

In [9]:
def rot_distance(qtn1, qtn2):
    """Rotation angle (norm of the SO3 log) of the relative rotation qtn1^-1 * qtn2.

    Both quaternions are passed straight to ``SO3`` (jaxlie wxyz convention).
    """
    rot_diff = SO3(qtn1).inverse() @ SO3(qtn2)
    return jnp.linalg.norm(rot_diff.log())

def get_rot_grid_idx(qtn, grid=None):
    """Index of the quaternion in ``grid`` closest to ``qtn`` by ``rot_distance``.

    ``grid`` defaults to the module-level ``qtn_grids``.
    """
    grid = qtn_grids if grid is None else grid
    errors = jax.vmap(rot_distance, in_axes=(None, 0))(qtn, grid)
    return errors.argmin()

def get_xyz_grid_idx(xyz, grid=None):
    """Index of the row of ``grid`` closest to ``xyz`` in Euclidean distance.

    ``grid`` defaults to the module-level ``xyz_grids``.
    """
    grid = xyz_grids if grid is None else grid
    errors = jnp.linalg.norm(grid - xyz, axis=-1)
    return errors.argmin()

In [11]:
def get_sample(q):
    """Score one joint configuration and locate its end-effector pose on the grids.

    Returns a tuple ``(manipulability, rotation_grid_index, position_grid_index)``
    for the last pose returned by ``panda_model.fk_fn(q)``.
    """
    # Split as [:4] quaternion / [-3:] translation -- assumes a jaxlie-style
    # wxyz_xyz pose vector from fk_fn; TODO confirm.
    ee_pose = panda_model.fk_fn(q)[-1]
    quat, position = ee_pose[:4], ee_pose[-3:]
    manip = get_manipulability_limit_penalized(q)
    return manip, get_rot_grid_idx(quat), get_xyz_grid_idx(position)

# Vectorised over the leading (batch) axis of q and JIT-compiled.
get_samples = jax.jit(jax.vmap(get_sample))

In [16]:
# Enable the %%cython cell magic used by the next cell.
%load_ext cython
import Cython  # NOTE(review): not referenced anywhere below -- likely removable

In [25]:
%%cython -a
import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange

# Serial on purpose: under prange, two samples landing in the same cell race on
# the read-compare-write of manip_map[x, y] (a smaller value can overwrite a
# larger one, and the cell can be counted as newly filled twice). The %%cython
# magic above is not given OpenMP flags, so prange was running serially anyway.
# Bounds checking stays on because the indices come from outside this function.
def insert_values(
    int[:, ::1] idxs,
    np.float32_t[::1] manips,
    np.float32_t[:, ::1] manip_map,
):
    """Scatter-max ``manips`` into ``manip_map`` in place.

    For each sample ``i``, ``manip_map[idxs[i, 0], idxs[i, 1]]`` becomes
    ``max(current value, manips[i])``.

    Args:
        idxs: (n, 2) C-contiguous array of C ``int`` (row, column) indices.
        manips: (n,) C-contiguous float32 values.
        manip_map: 2-D C-contiguous float32 map, modified in place.

    Returns:
        Number of cells that were exactly 0 and received a larger value during
        this call (the count is also printed).

    Raises:
        ValueError: if ``idxs`` and ``manips`` have different lengths.
        IndexError: if an index falls outside ``manip_map``.
    """
    cdef Py_ssize_t n = idxs.shape[0]
    cdef Py_ssize_t i, row, col
    cdef int updated = 0
    cdef np.float32_t value

    if manips.shape[0] != n:
        raise ValueError("idxs and manips must have the same length")

    with nogil:
        for i in range(n):
            row = idxs[i, 0]
            col = idxs[i, 1]
            value = manips[i]
            if manip_map[row, col] < value:
                if manip_map[row, col] == 0.:
                    updated += 1
                manip_map[row, col] = value
    print(updated)
    return updated

In file included from /home/polde/miniconda3/envs/cu11/lib/python3.8/site-packages/numpy/core/include/numpy/ndarraytypes.h:1940,
                 from /home/polde/miniconda3/envs/cu11/lib/python3.8/site-packages/numpy/core/include/numpy/ndarrayobject.h:12,
                 from /home/polde/miniconda3/envs/cu11/lib/python3.8/site-packages/numpy/core/include/numpy/arrayobject.h:5,
                 from /home/polde/.cache/ipython/cython/_cython_magic_045e89b641c7d73343082418ffa3794b.c:769:
      |  ^~~~~~~


In [None]:
# Lookup table filled by the sampling loop: row = rotation-grid index,
# column = position-grid index, value = best manipulability seen so far.
num_batch = 500_000  # joint configurations sampled per epoch
map_shape = (len(qtn_grids), len(xyz_grids))
manip_map = np.zeros(map_shape, dtype=np.float32)

In [43]:
# Monte-Carlo fill of manip_map: sample arm configurations uniformly within the
# joint limits, bin each end-effector pose into (rotation, position) grid cells
# and keep the largest penalised manipulability seen per cell.
num_epochs = 1000
rng = np.random.default_rng(0)  # fixed seed so the map is reproducible
for epoch in range(num_epochs):
    # Sample only the 7 arm joints directly within [lb, ub).
    qs = rng.uniform(np.asarray(lb), np.asarray(ub), size=(num_batch, len(lb)))
    manips, qtn_idxs, xyz_idxs = get_samples(qs)
    # insert_values is typed int[:, ::1] / float32[::1]: hand it C-contiguous
    # C-int indices and float32 values regardless of JAX's dtype settings.
    idxs = np.ascontiguousarray(
        np.stack([np.asarray(qtn_idxs), np.asarray(xyz_idxs)], axis=1),
        dtype=np.intc,
    )
    manips = np.ascontiguousarray(manips, dtype=np.float32)
    clear_output(wait=True)
    print(f"i:{epoch}")
    insert_values(idxs, manips, manip_map)

i:999
67


In [54]:
# Persist the map together with the grids needed to interpret its indices.
np.savez_compressed(
    "manip_data.npz",
    **{
        "manip_map": manip_map,
        "qtn_grids": qtn_grids,
        "xyz_grids": xyz_grids,
    },
)

: 