In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [2]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/


In [3]:
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7, 8], [0.04, 0.04])

In [669]:
robot_dim = 7
horizon = 10
dim = robot_dim * horizon
u = jnp.zeros(dim)

In [884]:
to_mat = lambda x: x.reshape(-1, robot_dim)
to_vec = lambda x: x.flatten()
dt = 0.1
dt_vec = jnp.ones(dim) * dt

umax = 1.
lower_tri = np.tril(np.full((horizon,horizon), 1))
upper_tri = np.triu(np.full((horizon,horizon), 1))
eye = np.eye(robot_dim)
integration_mat = np.kron(lower_tri, eye) @ np.diag(dt_vec)
double_integration_mat = integration_mat@integration_mat
qdot_max = umax*np.kron(upper_tri, eye)@dt_vec

In [919]:
# @jax.jit
# def integration(q0, qdots):
#     result = []
#     q = q0
#     for qdot in to_mat(qdots):
#         q += qdot*dt
#         result.append(q)
#     return jnp.hstack(result)

@jax.jit
def integration(q0, qdots):
    return integration_mat@qdots + jnp.tile(q0, horizon)
    
@jax.jit
def rollout(us, state):
    q0, qdot0 = state
    qdots = integration(qdot0, us)
    qs = integration(q0, qdots)
    return qs, qdots

In [673]:
# # helper fn
# get_ee = lambda q: panda_model.fk_fn(q)[-1]
# get_pos_jac = lambda q: panda_model.jac_fn(q)[:3, :]
# get_ee_pos = lambda q: get_ee(q)[-3:]

In [920]:
# Kinematics
@jax.jit
def get_ee_fk_jac(q):
    fks = panda_model.fk_fn(q)
    p_ee = fks[-1][-3:]
    jac = []
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        jac.append(jnp.hstack([lin_vel, rot_axis]))
    jac = jnp.array(jac).T
    return fks[-1], jac

In [921]:
from scipy.optimize import minimize, Bounds
# optimization functions
# Objective / gradient
@jax.jit
def vg_pos_err(q, target):
    ee_pose, ee_jac = get_ee_fk_jac(q)
    err = target - ee_pose[-3:]
    pos_jac = ee_jac[:3, :]
    val = jnp.sum(err**2)
    grad_poserr = - 2*pos_jac.T@err
    return val, grad_poserr

def vg_objective(u, state, target):
    qs, _ = rollout(u, state)
    vals, grads = jax.vmap(vg_pos_err, in_axes=(0,None))(to_mat(qs), target)
    grads = grads.flatten() @ double_integration_mat
    return vals.sum(), grads

# Constraints
lb, ub = -np.ones(dim), np.ones(dim)
bounds = Bounds(lb, ub)
def constr_stop(u, state):
    _, qdot0 = state
    qdots = integration(qdot0, u)
    qdot_max_viol = qdot_max - jnp.abs(qdots)
    return jnp.min(qdot_max_viol)

In [936]:
# dynamics
def update_dynamics(u, state):
    q, qdot = state
    q_new = q + qdot*dt
    qdot_new = qdot + u*dt
    return q_new, qdot_new

In [2084]:
# problem initialize
target_point = jnp.array([0.4, 0.2, 0.6])
state = (panda.neutral, jnp.zeros(robot_dim))
u = jnp.zeros(dim)
point.set_translate(target_point)

In [1703]:
#compile
vg_objective_j = jax.jit(vg_objective).lower(u, state, target_point).compile()
constr_stop_j = jax.jit(constr_stop).lower(u, state).compile()
jac_constr_stop_j = jax.jit(jax.grad(constr_stop)).lower(u, state).compile()

In [2112]:
import time
tic = time.perf_counter()
constraints = tuple([
    {"type":"ineq", 
     "fun":constr_stop_j, "jac":jac_constr_stop_j, "args":(state,)}
])
res = minimize(
    fun=vg_objective_j,
    x0=u,
    args=(state,target_point),
    method="SLSQP",
    jac=True,
    bounds=bounds,
    constraints=constraints,
    options={'ftol':0.0001}
)
toc = time.perf_counter()
print(res.message, f"elapsed:{toc-tic}")

draw_ee_traj(res.x, state)
state_new = update_dynamics(res.x[:robot_dim], state)
qs, qdots = rollout(res.x, state)
last_qdot = to_mat(qdots)[-1]
unew = jnp.hstack([res.x[robot_dim:], - last_qdot/ dt])
u, state = unew, state_new

Optimization terminated successfully elapsed:0.002440276090055704


In [1800]:
vg_objective_j(u, state, target_point)

(Array(0.11125635, dtype=float32),
 Array([-8.7218604e-04,  1.4898673e-02, -5.2387547e-04,  2.1761764e-02,
        -1.3568599e-03,  2.4745822e-02, -1.1241400e-09, -7.1360671e-04,
         1.2189822e-02, -4.2862538e-04,  1.7805079e-02, -1.1101582e-03,
         2.0246582e-02, -9.1975089e-10, -5.7088537e-04,  9.7518582e-03,
        -3.4290028e-04,  1.4244063e-02, -8.8812655e-04,  1.6197266e-02,
        -7.3580075e-10, -4.4402198e-04,  7.5847786e-03, -2.6670026e-04,
         1.1078715e-02, -6.9076510e-04,  1.2597873e-02, -5.7228944e-10,
        -3.3301648e-04,  5.6885839e-03, -2.0002518e-04,  8.3090374e-03,
        -5.1807379e-04,  9.4484044e-03, -4.2921711e-10, -2.3786891e-04,
         4.0632742e-03, -1.4287513e-04,  5.9350263e-03, -3.7005273e-04,
         6.7488607e-03, -3.0658365e-10, -1.5857928e-04,  2.7088495e-03,
        -9.5250085e-05,  3.9566844e-03, -2.4670182e-04,  4.4992408e-03,
        -2.0438909e-10, -9.5147567e-05,  1.6253097e-03, -5.7150053e-05,
         2.3740106e-03, -1.48

In [1957]:
constr_stop(unew, state_new)

Array(0.1, dtype=float32)

In [933]:
import meshcat.geometry as g
get_ee_pos = lambda q: panda_model.fk_fn(q)[-1][-3:]
def draw_ee_traj(u, state):
    world.vis["line"].delete()
    world.vis["line_pc"].delete()

    q = state[0]
    panda.set_joint_angles(q)
    qs, _ = rollout(u, state)    
    vertices = jax.vmap(get_ee_pos)(to_mat(qs))
    color = Colors.read("red")
    colors = np.tile(Colors.read("red", return_rgb=True), 
                     vertices.shape[0]).reshape(-1, 3)
    point_obj = g.PointsGeometry(vertices.T, colors.T)
    line_material = g.MeshBasicMaterial(color=color)
    point_material = g.PointsMaterial(size=0.02)
    world.vis["line"].set_object(
        g.Line(point_obj, line_material)
    )
    world.vis["line_pc"].set_object(
        point_obj, point_material)

In [896]:
import time
constraints = tuple([
    {"type":"ineq", "fun":constr_stop_j, "jac":jac_constr_stop_j, "args":(state,)}
])
tic = time.perf_counter()
res = minimize(
    fun=vg_objective_j,
    x0=u,
    args=(state,),
    method="SLSQP",
    jac=True,
    bounds=bounds,
    constraints=constraints,
    options={'ftol':0.01}
)
toc = time.perf_counter()
print(res.message)
toc-tic

Optimization terminated successfully


0.004479089053347707

In [794]:
world.vis["line"].delete()

In [839]:
draw_ee_traj(res.x, state)

In [628]:
value, grads = vg_objective(unext, state)
unext = unext - grads

qs = rollout(state, unext)
panda.set_joint_angles(to_mat(qs)[0])
print(value)

0.20689468


Array(0.26821747, dtype=float32)

In [404]:
to_mat(qs)

Array([[ 0.01328348,  0.01220716,  0.013152  , -1.5656508 ,  0.01068715,
         1.8745259 ,  0.01      ],
       [ 0.0385356 ,  0.03588177,  0.03817638, -1.553735  ,  0.03175081,
         1.8895615 ,  0.03      ],
       [ 0.07481594,  0.07043503,  0.0741664 , -1.5337601 ,  0.06298567,
         1.9133978 ,  0.06      ],
       [ 0.12152295,  0.1154272 ,  0.12054996, -1.5047567 ,  0.10427634,
         1.9466316 ,  0.10000001],
       [ 0.1783392 ,  0.17056578,  0.17703146, -1.4660778 ,  0.15557502,
         1.9896635 ,  0.15      ]], dtype=float32)

In [395]:
unext

Array([1.3283477 , 1.2207156 , 1.3151999 , 0.5149158 , 1.0687147 ,
       0.7025927 , 1.        , 1.1968642 , 1.1467453 , 1.1872379 ,
       0.67666173, 1.0376508 , 0.8009746 , 1.        , 1.1028215 ,
       1.0878648 , 1.0965645 , 0.80591327, 1.0171207 , 0.88006043,
       1.        , 1.0426666 , 1.0438911 , 1.0393531 , 0.9028455 ,
       1.0055797 , 0.93974584, 1.        , 1.0109245 , 1.0146418 ,
       1.0097942 , 0.9675567 , 1.0008031 , 0.97982305, 1.        ],      dtype=float32)

In [131]:
val, grad = jax.vmap(cost_pos_vg, out_axes=(0,0))(qs)

In [139]:
grad.flatten() @ jac_rollout(state, u)

Array([-2.2193089e-02, -1.5408222e-02, -2.1307826e-02,  3.1913638e-02,
       -4.6295011e-03,  1.9339101e-02,  8.2798035e-10, -1.3543896e-02,
       -1.0370999e-02, -1.2890394e-02,  2.1388866e-02, -2.5997330e-03,
        1.2993046e-02,  5.8132033e-10, -7.2516296e-03, -6.2921420e-03,
       -6.8211695e-03,  1.2912054e-02, -1.2322738e-03,  7.8615453e-03,
        4.3733650e-10, -3.1223432e-03, -3.1866729e-03, -2.8897203e-03,
        6.5005664e-03, -4.3687542e-04,  3.9649447e-03,  2.6147917e-10,
       -8.4964902e-04, -1.0780505e-03, -7.6785660e-04,  2.1830455e-03,
       -8.1807942e-05,  1.3326881e-03,  7.6598526e-11], dtype=float32)

In [111]:
panda.set_joint_angles(q - 0.5 * grad)

In [92]:
get_pos_jac(q)

Array([[-2.7173555e-03,  4.8802570e-01, -2.6683933e-03, -1.7285609e-01,
        -2.7030234e-03,  2.1101448e-01, -3.6379788e-11],
       [ 9.3230933e-02,  4.8804199e-03,  8.8346086e-02, -3.4582587e-03,
         9.0075001e-02,  6.3232258e-03,  4.6566129e-10],
       [ 0.0000000e+00, -9.3253441e-02,  1.7847964e-05,  7.6109590e-03,
         9.0121484e-06,  9.0115853e-02, -1.7053026e-13]], dtype=float32)

In [84]:
jax.vmap(cost_pos)(qs)

Array([0.10589164, 0.10027977, 0.09383675, 0.08904492, 0.08927719],      dtype=float32)

In [86]:
panda.set_joint_angles(qs[0])

Array(0.3103109, dtype=float32)

In [57]:
point.set_translate(target_point)

In [56]:
point = Sphere(world.vis, "point", 0.02, "red")

In [27]:
integration2(q0, u)

Array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
       0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.4, 0.4, 0.4, 0.4, 0.4,
       0.4, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32)

In [36]:
%timeit integration(q0, u).block_until_ready()

27.2 µs ± 753 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [37]:
%timeit integration2(q0, u).block_until_ready()

24.4 µs ± 327 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:
integration_mat

Array([[0.1, 0. , 0. , ..., 0. , 0. , 0. ],
       [0. , 0.1, 0. , ..., 0. , 0. , 0. ],
       [0. , 0. , 0.1, ..., 0. , 0. , 0. ],
       ...,
       [0. , 0. , 0. , ..., 0.1, 0. , 0. ],
       [0. , 0. , 0. , ..., 0. , 0.1, 0. ],
       [0. , 0. , 0. , ..., 0. , 0. , 0.1]], dtype=float32)