In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

from network import *
from loss import *

In [2]:
# Build the Panda RobotModel that every later cell uses (joint limits, FK, Jacobian).
# PANDA_URDF / PANDA_PACKAGE are not defined in this notebook — presumably they
# come from the sdf_world star imports.
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)

concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts


In [3]:
# Create the SDF world and show its visualizer in the notebook; `world.vis` is
# the handle later passed to Robot(...) and Frame(...).
world = SDFWorld()
world.show_in_jupyter()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7004/static/


In [4]:
# Add the Panda to the visualizer, reusing the `panda_model` built in an earlier cell
# (alpha=0.5 is presumably the mesh opacity — confirm against Robot).
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
# NOTE(review): presumably pins joint indices 7 and 8 (the two entries beyond the
# 7 joints used everywhere below) at 0.04 — confirm against Robot.reduce_dim.
panda.reduce_dim([7,8], [0.04, 0.04])

In [5]:
# Draw one random configuration between the model's lower/upper bounds and keep
# the first 7 entries (the bounds appear to have 9 entries, cf. size=(num_batch, 9)
# in the sampling loop further down).
q = np.random.uniform(panda_model.lb, panda_model.ub)[:7]
# Jacobian at q; no visible cell reads this module-level `jac` afterwards
# (get_manipulability_limit_penalized computes its own local one).
jac = panda_model.jac_fn(q)
# Bounds of the 7 retained joints; read as globals by _joint_limit_potential.
lb, ub = panda_model.lb[:7], panda_model.ub[:7]

In [6]:
from functools import partial
def _joint_limit_potential(q, scale):
    """Joint-limit penalty factor for configuration `q`.

    Computes ``scale * exp(-sum_i 0.1 * (ub_i - lb_i)**2 / (4 * (q_i - lb_i) * (ub_i - q_i)))``
    using the module-level bounds `lb` / `ub`. Each summand equals 0.1 at the
    middle of a joint's range and grows without bound as the joint approaches
    either limit, so the factor decays to 0 at the limits.

    Args:
        q: joint values, same shape as `lb` / `ub`.
        scale: multiplicative normalisation applied to the result.

    Returns:
        Scalar array: the penalty factor, or 0 if any joint is on or outside
        its limits.
    """
    # Positive only when lb < q < ub; zero on a limit, negative outside it.
    margin = (q - lb) * (ub - q)
    value = jnp.sum(0.1 * (ub - lb) ** 2 / (4 * margin))
    # Outside the limits a summand turns negative, which would make
    # exp(-value) *larger* and reward the violation. Force the factor to 0
    # there instead (this also covers the exact-limit case, where value = inf).
    within_limits = jnp.all(margin > 0)
    return jnp.where(within_limits, jnp.exp(-value) * scale, 0.0)
# Normalise the potential against the model's neutral configuration: k is the raw
# (scale=1) value there, and the public helper fixes scale=1/k so that
# joint_limit_potential(neutral) == 1.
k = _joint_limit_potential(panda_model.neutral[:7], 1.)
joint_limit_potential = partial(_joint_limit_potential, scale=1/k)
def get_manipulability_limit_penalized(q):
    """Manipulability sqrt(det(J J^T)) at `q`, scaled by the joint-limit factor.

    Args:
        q: joint configuration accepted by `panda_model.jac_fn`.

    Returns:
        Scalar array: sqrt(det(J @ J.T)) * joint_limit_potential(q).
    """
    jac = panda_model.jac_fn(q)
    joint_limit_scale = joint_limit_potential(q)
    # J J^T is positive semi-definite, but near a singularity the floating-point
    # determinant can come out slightly negative, which would turn the square
    # root into NaN. Clamp at 0 so singular configurations score exactly 0.
    gram_det = jnp.maximum(jnp.linalg.det(jac @ jac.T), 0.0)
    return jnp.sqrt(gram_det) * joint_limit_scale

In [8]:
# Workspace discretisation.
# Positions: a regular xyz_res^3 lattice over [-1, 1] x [-1, 1] x [-0.5, 1.5],
# then trimmed to the open ball of radius ws_r around ws_center.
ws_r = 1.
ws_center = jnp.array([0,0,0.5])
xyz_res = 41
xx = jnp.linspace(-1, 1, num=xyz_res)
yy = jnp.linspace(-1, 1, num=xyz_res)
zz = jnp.linspace(-0.5, 1.5, num=xyz_res)
X, Y, Z = jnp.meshgrid(xx, yy, zz)
# One row per lattice point, columns (x, y, z).
xyz_grids = jnp.stack([X.ravel(), Y.ravel(), Z.ravel()], axis=-1)
in_ws_sphere = jnp.linalg.norm(xyz_grids - ws_center, axis=-1) < ws_r
xyz_grids = xyz_grids[in_ws_sphere]
# Orientations: quaternion samples from the super-Fibonacci spiral helper
# (presumably 1000 of them — confirm against super_fibonacci_spiral).
qtn_grids = super_fibonacci_spiral(1000)

In [12]:
def rot_distance(qtn1, qtn2):
    """Rotation angle separating two orientations given as SO3 quaternion parameters.

    Forms the relative rotation inverse(R1) * R2 and returns the norm of its
    log-map (tangent vector).
    """
    relative = SO3(qtn1).inverse() @ SO3(qtn2)
    tangent = relative.log()
    return jnp.linalg.norm(tangent)
def get_rot_grid_idx(qtn):
    """Index of the entry of `qtn_grids` with the smallest rot_distance to `qtn`."""
    # Evaluate the distance from `qtn` to every grid orientation in one batched call.
    angles = jax.vmap(lambda grid_qtn: rot_distance(qtn, grid_qtn))(qtn_grids)
    return jnp.argmin(angles)
def get_xyz_grid_idx(xyz):
    """Index of the row of `xyz_grids` closest (Euclidean distance) to `xyz`."""
    offsets = xyz_grids - xyz
    return jnp.argmin(jnp.linalg.norm(offsets, axis=-1))

In [34]:
# Manipulability map with one cell per (orientation bin, position bin) pair.
# It starts at 0 and is raised in place by the sampling loop below; re-running
# this cell resets everything accumulated so far.
map_shape = len(qtn_grids), len(xyz_grids)
manip_map = np.zeros(map_shape)

In [46]:
def get_sample(q):
    """Score configuration `q` and find the grid bins of its end-effector pose.

    Returns:
        (manip, qtn_idx, xyz_idx): the limit-penalised manipulability of `q`,
        the index into `qtn_grids` nearest to the end-effector orientation, and
        the index into `xyz_grids` nearest to the end-effector position.
    """
    # The last FK entry is used as the end-effector pose; its first 4 values are
    # treated as the quaternion and its last 3 as the position.
    pose = panda_model.fk_fn(q)[-1]
    orientation, position = pose[:4], pose[-3:]
    return (
        get_manipulability_limit_penalized(q),
        get_rot_grid_idx(orientation),
        get_xyz_grid_idx(position),
    )
# Batched, compiled variant: maps get_sample over the leading axis of its input.
get_samples = jax.jit(jax.vmap(get_sample))

In [66]:
# Monte-Carlo fill of manip_map: sample random configurations, bin each one by
# its end-effector pose, and keep the best manipulability seen per bin.
num_batch = 1000
for epoch in range(100):
    # Sample across all 9 bounds, then keep the first 7 joints (same recipe as `q` above).
    qs = np.random.uniform(panda_model.lb, panda_model.ub, size=(num_batch,9))[:,:7]
    # Copy the batched results to host NumPy once. Indexing the JAX arrays
    # element by element inside the Python loop below would issue a separate
    # device operation (and host sync) for every access.
    manips, qtn_idxs, xyz_idxs = (np.asarray(out) for out in get_samples(qs))
    updated = 0
    for i in range(num_batch):
        cell = (qtn_idxs[i], xyz_idxs[i])
        # NaN scores compare False here and are therefore never stored.
        if manip_map[cell] < manips[i]:
            manip_map[cell] = manips[i]
            updated += 1
    print(f"{epoch}: updated {updated}")

999: updated 995
999: updated 994
999: updated 997
999: updated 994
999: updated 994
999: updated 995
999: updated 996
999: updated 995
999: updated 994
999: updated 994
999: updated 996
999: updated 991
999: updated 995
999: updated 995
999: updated 992
999: updated 997
999: updated 995
999: updated 995
999: updated 990
999: updated 995
999: updated 996
999: updated 994
999: updated 996
999: updated 992
999: updated 997
999: updated 996
999: updated 993
999: updated 992
999: updated 997
999: updated 993
999: updated 991
999: updated 993
999: updated 991
999: updated 990
999: updated 996
999: updated 995
999: updated 992
999: updated 993
999: updated 996
999: updated 993
999: updated 995
999: updated 996
999: updated 993


KeyboardInterrupt: 

In [45]:
# Visual check of the binning: pose the robot at `q` and place a frame at the
# grid pose (nearest orientation bin + nearest position bin) that `q` maps to.
# The bin indices and the frame handle are created here so the cell does not
# depend on `qtn_idx` / `xyz_idx` / `frame` left over from an earlier session.
_manip, qtn_idx, xyz_idx = get_sample(q)
frame = Frame(world.vis, "frame")
panda.set_joint_angles(q)
frame.set_pose(SE3(jnp.hstack([qtn_grids[qtn_idx], xyz_grids[xyz_idx]])))

In [30]:
# Resample a random configuration (first 7 entries of a draw between the model bounds).
q = np.random.uniform(panda_model.lb, panda_model.ub)[:7]


# NOTE(review): `qtn_est` and `p_est` are not defined in any visible cell, so this
# line only works with leftover kernel state — define them or drop the cell.
frame.set_pose(SE3(jnp.hstack([qtn_est, p_est])))

In [19]:
# Create the "frame" object in the world visualizer.
# NOTE(review): in file order this definition comes after cells that already call
# frame.set_pose(...); the execution count (In [19]) shows it was run earlier.
frame = Frame(world.vis, "frame")

Array(7711, dtype=int32)

In [172]:
# NOTE(review): leftover debugging expression — in the visible cells `errors` exists
# only as a local inside get_rot_grid_idx / get_xyz_grid_idx, so this raises
# NameError on a fresh kernel unless it is defined elsewhere.
errors.argmin()

Array(37, dtype=int32)

In [116]:
# NOTE(review): leftover debugging expression — in the visible cells `ee_pose` exists
# only as a local inside get_sample, so this depends on leftover kernel state.
SE3(ee_pose).rotation()

SO3(wxyz=[ 0.00886    -0.26738998  0.22815    -0.93614995])

In [None]:
# NOTE(review): leftover debugging expression — no visible cell defines a
# module-level `ee_pose`; define it first or remove this cell.
ee_pose