In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
#from sdf_world.network import *

In [2]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/


In [3]:
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7,8], [0.04, 0.04])

In [4]:
frame = Frame(world.vis, "frame")

In [240]:
robot_dim = 7
n = 20
dt = 1/10
dim = n * robot_dim

x0 = jnp.zeros(robot_dim*n)
to_mat = lambda vec:vec.reshape(-1, robot_dim)
to_vec = lambda vec:vec.flatten()

def vector_integrate(x_init, xdots):
    dt_vec = jnp.full(n, dt)    
    lower_tri = jnp.tril(jnp.ones((n,n)))
    lower_tri = jnp.kron(lower_tri @ jnp.diag(dt_vec), jnp.eye(7))
    xs = to_vec(x_init + to_mat(lower_tri @ xdots))
    return xs
def rollout(state, u):
    q, qdot = state
    qdots = vector_integrate(qdot, u)
    qs = vector_integrate(q, qdots)
    return qs
def goal_pose_reaching_cost(q, target_pose):
    pose_curr = SE3(panda_model.fk_fn(q)[-1])
    pos_error = pose_curr.translation() - target_pose.translation()
    pos_error = jnp.sum(pos_error **2)
    orn_error = (target_pose.inverse()@pose_curr).rotation().log() /3
    orn_error = jnp.sum(orn_error**2)
    return pos_error + orn_error

In [648]:
from scipy.optimize import minimize

In [653]:
def objective(x):
    qs = rollout(state, x)
    return jax.vmap(goal_pose_reaching_cost, in_axes=(0, None))(
        to_mat(qs), goal_pose).sum()
def cineq_fn(x):
    qs = rollout(state, x)
    ub_viol = to_mat(qs) - panda.ub
    lb_viol = panda.lb - to_mat(qs)
    return jnp.hstack([ub_viol.flatten(), lb_viol.flatten()])
lb = -np.ones(dim)
ub = np.ones(dim)

In [654]:
from functools import partial
def value_and_jacrev(x, f):
    y, pullback = jax.vjp(f, x)
    basis = jnp.eye(y.size, dtype=y.dtype)
    jac = jax.vmap(pullback)(basis)[0]
    return y, jac

In [655]:
#precompile
obj_val_grad = jax.jit(jax.value_and_grad(objective)).lower(x0).compile()
cineq_fn_jit = jax.jit(cineq_fn).lower(x0).compile()
jac_cineq_fn_jit = jax.jit(jax.jacrev(cineq_fn)).lower(x0).compile()

In [678]:
constr = ({"type":'ineq', 'fun':cineq_fn_jit})
bounds = [(l, u) for l, u in zip(lb, ub)]

In [None]:
, 'jac':jac_cineq_fn_jit

In [679]:
res = minimize(
    obj_val_grad, x0, 
    method="SLSQP", jac=True, bounds=bounds,
    constraints=constr) #options={"ftol":0.01}

In [684]:
bounds

[(-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0, 1.0),
 (-1.0

In [520]:
import nlopt

opt = nlopt.opt(nlopt.LD_AUGLAG, dim)

def f(x, grad):
    val, grads = obj_val_grad(x)
    if grad.size > 0:
        grad[:] = grads
    return val.item()
def cineq(result, x, grad):
    val, grads = cineq_fn_val_grad(x)
    if grad.size > 0:
       grad[:] = grads
    result[:] = val
opt.set_min_objective(f)
opt.add_inequality_mconstraint(cineq, np.full(dim*2, 0.01))
opt.set_lower_bounds(lb)
opt.set_upper_bounds(ub)
xtol = 1e-2


In [521]:
opt2 = nlopt.opt(nlopt.LD_SLSQP, dim)
opt2.set_xtol_rel(xtol)

opt.set_xtol_rel(xtol*10)
opt.set_local_optimizer(opt2)

In [522]:
xopt = opt.optimize(x0)

In [627]:
import time
tic = time.perf_counter()
for i in range(10):
    xopt = opt.optimize(xopt)
    print(time.perf_counter()-tic)
toc = time.perf_counter()
elapsed = (toc-tic)/100
print(elapsed)

0.0719231809489429
0.12245099293068051
0.17294700792990625
0.2233892648946494
0.27373430295847356
0.3257537828758359
0.37680774996988475
0.42739091999828815
0.4775787799153477
0.529890367994085
0.005299723159987479


In [647]:
%time 
opt.optimize(xopt)

CPU times: user 9 µs, sys: 0 ns, total: 9 µs
Wall time: 19.1 µs


array([ 1.        , -1.        ,  1.        , -1.        , -1.        ,
        1.        , -1.        ,  1.        , -1.        ,  0.59301524,
       -1.        , -1.        ,  1.        , -1.        ,  1.        ,
       -1.        , -1.        , -1.        , -1.        ,  1.        ,
       -1.        ,  1.        , -1.        , -1.        , -1.        ,
       -0.90145735,  1.        , -1.        ,  1.        , -1.        ,
       -1.        , -1.        ,  0.2810863 ,  1.        , -1.        ,
        1.        , -0.42073285, -1.        , -1.        ,  1.        ,
        1.        , -1.        ,  1.        , -0.01279864, -1.        ,
       -1.        ,  1.        ,  1.        , -1.        ,  0.32739543,
        0.02293431, -1.        , -1.        ,  1.        ,  1.        ,
       -1.        , -0.55936623, -0.05741769, -1.        , -1.        ,
        1.        ,  1.        , -1.        , -0.62991175, -0.10965685,
       -1.        , -1.        ,  1.        ,  1.        , -1.  

In [625]:
%time
xopt = opt.optimize(xopt)

CPU times: user 8 µs, sys: 0 ns, total: 8 µs
Wall time: 16.5 µs


In [667]:
qs = rollout(state, res.x)

In [668]:
i = 0

In [673]:
panda.set_joint_angles(to_mat(qs)[i])
i += 1

In [382]:
goal_pose_reaching_cost(to_mat(qs)[-1], goal_pose)

Array(0.00224686, dtype=float32)

In [351]:
jax.vmap(goal_pose_reaching_cost, in_axes=(0, None))(
        to_mat(qs), goal_pose)

Array([0.2801563 , 0.26153243, 0.23993082, 0.21622306, 0.19183062,
       0.1678482 , 0.14501941, 0.12335581, 0.103462  , 0.08421466,
       0.06530553, 0.04805544, 0.0331141 , 0.02121489, 0.01230958,
       0.00614255, 0.00258136, 0.00103878, 0.00105415, 0.00224686],      dtype=float32)

In [129]:
cineq_fn(u).shape

(140,)

In [107]:

cineq_fn_val_grad = jax.jit(jax.value_and_grad(cineq_fn)).lower(x0).compile()

(70,)

In [None]:
opt.set_lower_bounds(lb)
opt.set_upper_bounds(ub)
def f(x, grad):
    val, grads = jax.value_and_grad(objective)(x)
    if grad.size > 0:
        grad[:] = grads
    return val.item()
opt.set_min_objective(f)

In [93]:
state = panda.neutral, jnp.zeros(robot_dim)
u = jnp.zeros(robot_dim*n)

In [100]:
qs = rollout(state, u)
jax.vmap(goal_pose_reaching_cost, in_axes=(0, None))(
    to_mat(qs), goal_pose)

Array([0.29052478, 0.29052478, 0.29052478, 0.29052478, 0.29052478,
       0.29052478, 0.29052478, 0.29052478, 0.29052478, 0.29052478],      dtype=float32)

In [97]:
# problem
qtn = np.random.random(4)
xyz = np.random.uniform([0.3, -1, 0.3], [0.6, 1, 0.6])
goal_pose = SE3.from_rotation_and_translation(SO3(qtn).normalize(), xyz)
frame.set_pose(goal_pose)