In [11]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
from functools import partial
import jax_dataclasses as jdc

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [2]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/


In [3]:
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7, 8], [0.04, 0.04])

In [1191]:
# ik
robot_dim = 7
horizon = 5
dim = robot_dim * horizon
dt = 0.1

to_mat = lambda x: x.reshape(-1, robot_dim)
to_vec = lambda x: x.flatten()

def to_vel(x, state):
    qmat = jnp.vstack([state.q0, to_mat(x)])
    return (qmat[1:] - qmat[:-1]).flatten() / dt

In [7]:
# Kinematics
def skew(v):
    v1, v2, v3 = v
    return jnp.array([[0, -v3, v2],
                      [v3, 0., -v1],
                      [-v2, v1, 0.]])

def get_rotvec_angvel_map(v):
    vmag = jnp.linalg.norm(v)
    vskew = skew(v)
    return jnp.eye(3) \
        - 1/2*skew(v) \
        + vskew@vskew * 1/vmag**2 * (1-vmag/2 * jnp.sin(vmag)/(1-jnp.cos(vmag)))

@jax.jit
def get_ee_fk_jac(q):
    # outputs ee_posevec and analytical jacobian
    fks = panda_model.fk_fn(q)
    p_ee = fks[-1][-3:]
    rotvec_ee = SO3(fks[-1][:4]).log()
    E = get_rotvec_angvel_map(rotvec_ee)
    jac = []
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        jac.append(jnp.hstack([lin_vel, rot_axis]))
    jac = jnp.array(jac).T
    jac = jac.at[3:, :].set(E @ jac[3:, :])
    return jnp.hstack([p_ee, rotvec_ee]), jac

In [17]:
frame = Frame(world.vis, "frame")
frame_ee = Frame(world.vis, "frame_ee")
def to_posevec(pose:SE3):
    return jnp.hstack([
        pose.translation(), pose.rotation().log()
    ])
def make_pose():
    pose = SE3.from_rotation_and_translation(
        SO3(np.random.random(4)).normalize(),
        np.random.uniform([-0.3,-0.5,0.3],[0.6, 0.5, 0.8])
    )
    return pose

In [1539]:
def barrier_fn(x, eps=0.01):
    # def fn_not_active(x):
    #     return 0.
    # def fn_in_transition(x):
    #     return (1/2/eps)*(x+eps)**2
    # def fn_active(x):
    #     return x + 1/2 * eps
    # x = -x
    # is_in_transition = (-eps < x) & (x < 0.)
    # is_active = 0 <= x
    # switch_var = is_active + is_in_transition*2
    # fns = [fn_not_active, fn_active, fn_in_transition]
    # result = jax.lax.switch(switch_var, fns, x)
    return jnp.where(x>=0, 0., -x)

In [1540]:
@jdc.pytree_dataclass
class State:
    q0: Array
    target: Array
line = DottedLine(world.vis, "line", jnp.zeros((5,3)))
def show_path(x, state:State):
    panda.set_joint_angles(to_mat(x)[-1])
    joints = jnp.vstack([state.q0, to_mat(x)])
    points = jax.vmap(panda_model.fk_fn)(joints)[:,-1,-3:]
    line.reload(points=points)

In [1919]:
def value_and_jacrev(x, state, f):
  y, pullback = jax.vjp(f, x, state)
  basis = jnp.eye(y.size, dtype=y.dtype)
  jac = jax.vmap(pullback)(basis)
  return y, jac[0] # jacobian for x

pose_weight = np.array([1, 1, 1, 0.3, 0.3, 0.3])
def get_vg_pose_residual(x, state):
    ee_pose, jac_kin = get_ee_fk_jac(to_mat(x)[-1])
    error = state.target - ee_pose
    return error, jnp.hstack([jnp.zeros((6, 28)), - jac_kin]), pose_weight

def joint_limit_violation(x, state):
    ub_viol = np.tile(panda.ub, horizon) - x # should be positive
    lb_viol = x - np.tile(panda.lb, horizon)
    #viols = jax.vmap(barrier_fn)(jnp.hstack([ub_viol, lb_viol]))
    viols = jnp.hstack([ub_viol, lb_viol])
    return viols

joint_vel_ub = jnp.ones(dim)
joint_vel_lb = -jnp.ones(dim)
def joint_vel_limit_violation(x, state):
    vel = to_vel(x, state)
    ub_viol = joint_vel_ub - vel # should be positive
    lb_viol = vel - joint_vel_lb
    #viols = jax.vmap(barrier_fn)(jnp.hstack([ub_viol, lb_viol]))
    viols = jnp.hstack([ub_viol, lb_viol])
    return viols #.reshape(2, -1).sum(axis=0)
_get_vg_joint_limit_viol = jax.jit(partial(value_and_jacrev, f=joint_limit_violation))
_get_vg_joint_vel_limit_viol = jax.jit(partial(value_and_jacrev, f=joint_vel_limit_violation))

def get_vg_joint_limit_viol(x, state):
    val, jac = _get_vg_joint_limit_viol(x, state)
    return val, jac, jnp.ones(dim)
def get_vg_joint_vel_limit_viol(x, state):
    val, jac = _get_vg_joint_vel_limit_viol(x, state)
    return val, jac, jnp.ones(dim*2)

In [1920]:
import osqp

In [1713]:
vg_fns = [get_vg_pose_residual, get_vg_joint_limit_viol, get_vg_joint_vel_limit_viol]
@jax.jit
def eval(x, state):
    residuals = []
    jacs = []
    weights = []
    for vg_fn in vg_fns:
        r, jac, weight = vg_fn(x, state)
        residuals.append(r)
        jacs.append(jac)
        weights.append(weight)
    return jnp.hstack(residuals), jnp.vstack(jacs), jnp.hstack(weights)

In [1714]:
#set problem
target_pose = make_pose()
target = to_posevec(target_pose)
frame.set_pose(target_pose)

In [1869]:
#initialize
state = State(panda.neutral, target)
x = np.tile(panda.neutral, 5)
show_path(x, state)

In [1870]:
mu = 0.01
factor = 2.

In [1871]:
#given x, 
error, jac, weight = eval(x, state)
# weight = weight.at[-dim:].set(jnp.zeros(dim))
W = jnp.diag(weight)

d = jnp.linalg.solve(jac.T@W@jac, - jac.T@W@error) # + mu*jnp.eye(dim)

error_new, _, _ = eval(x+d, state)
x = x + d*1.

# model_improvement = 1/2*d@(mu*d - jac.T@error)
# actual_improvement = error@error - error_new@error_new
# gain_factor = actual_improvement/model_improvement
# if gain_factor > 0.:
#     x = x + d
#     mu = np.max([mu/factor, 1e-5])
# else: 
#     mu = np.min([mu/factor, 1e5])

show_path(x, state)
#print(gain_factor)

In [1885]:
q, r = np.linalg.qr(jac)

In [1893]:
eps = 1e-6
error > eps

In [1891]:
jnp.linalg.lstsq()

array([[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
        -1.6653345e-15, -1.7700882e-01, -8.9062968e-10],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
        -1.7700827e-01,  1.1160222e-08, -7.4505806e-09],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
        -4.6566123e-10, -1.4613657e-01, -2.7228692e-10],
       ...,
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00]], dtype=float32)

In [1879]:
jac_r = jac + jnp.eye(jac.shape[1])
jac_r.T@jac_r

TypeError: add got incompatible shapes for broadcasting: (111, 35), (35, 35).

In [1806]:
joint_limit_violation(x, state)

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0.], dtype=float32)

In [1816]:
joint_vel_limit_violation(x, state).reshape(2, -1)

Array([[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.3841858e-07,
        0.0000000e+00, 1.2257099e-03, 3.8635731e-04, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 2.3841858e-07, 0.0000000e+00,
        1.2245178e-03, 3.8635731e-04, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.2269020e-03,
        3.8635731e-04, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 1.2245178e-03, 3.8647652e-04,
        0.0000000e+00, 0.0000000e+00, 5.2452087e-06, 0.0000000e+00,
        0.0000000e+00, 1.2245178e-03, 3.8623810e-04],
       [0.0000000e+00, 4.0414333e-03, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        4.0414333e-03, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.0414333e-03,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 1.1920929e-07, 4.0414333e-03, 0.0000000

In [1190]:
joint_vel_limit_violation(x, state)

Array([ 0.       ,  0.       ,  0.       ,  0.       ,  0.       ,
        0.       ,  0.       ,  0.       ,  0.       ,  0.       ,
        0.       ,  0.       ,  0.       ,  0.       ,  0.       ,
        0.       ,  0.       ,  0.       ,  0.       ,  0.       ,
        0.       ,  0.       ,  0.       ,  0.       ,  0.       ,
        0.       ,  0.       ,  0.       ,  6.1204395,  3.5883708,
        3.6641312,  5.5315666,  2.392091 , 17.528088 ,  4.63465  ],      dtype=float32)

Array([0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.6483364,
       0.       , 0.       , 0.       , 0.       , 0.       ],      dtype=float32)

In [1086]:
jnp.hstack([ub_viol, lb_viol])

Array([[ 2.9671   ,  1.8326   ,  2.9671   ,  1.5708   ,  2.9671   ,
         1.9548   ,  2.9671   ,  2.9671   ,  1.8326   ,  2.9671   ,
         1.5708   ,  2.9671   ,  1.9547999,  2.9671   ],
       [ 2.9671   ,  1.8326   ,  2.9671   ,  1.5708   ,  2.9671   ,
         1.9548   ,  2.9671   ,  2.9671   ,  1.8326   ,  2.9671   ,
         1.5708   ,  2.9671   ,  1.9547999,  2.9671   ],
       [ 2.9671   ,  1.8326   ,  2.9671   ,  1.5708   ,  2.9671   ,
         1.9548   ,  2.9671   ,  2.9671   ,  1.8326   ,  2.9671   ,
         1.5708   ,  2.9671   ,  1.9547999,  2.9671   ],
       [ 2.9671   ,  1.8326   ,  2.9671   ,  1.5708   ,  2.9671   ,
         1.9548   ,  2.9671   ,  2.9671   ,  1.8326   ,  2.9671   ,
         1.5708   ,  2.9671   ,  1.9547999,  2.9671   ],
       [ 2.8619504,  4.3085365,  3.4635532,  1.9386004,  3.6764345,
         2.6547718,  4.8591447,  3.0722494, -0.6433364,  2.4706466,
         1.2029995,  2.2577653,  1.2548282,  1.0750551]], dtype=float32)

In [1083]:
viols

NameError: name 'viols' is not defined

In [1074]:
panda.ub

array([2.9671, 1.8326, 2.9671, 0.    , 2.9671, 3.8223, 2.9671])

In [1079]:
to_mat(x)[-1] - panda.lb

Array([ 3.0722494, -0.6433364,  2.4706466,  1.2029995,  2.2577653,
        1.2548282,  1.0750551], dtype=float32)

In [647]:
gain_factor

Array(nan, dtype=float32)

In [1917]:
%timeit run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h))

2.97 ms ± 52.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [572]:
gain_factor

Array(0.67475086, dtype=float32)

In [1902]:
jac.T@error

Array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  1.2239296 ,  2.429865  ,
        1.2239296 , -2.3352067 , -0.47887915, -2.2273946 , -1.3044058 ],      dtype=float32)

In [557]:
error@error

Array(2.4639332e-05, dtype=float32)

In [1899]:
jnp.linalg.matrix_rank(jac.T@jac)

Array(6, dtype=int32)

In [1900]:
jac.T@jac

Array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 1.242499  , 0.45108017,
        0.5097498 ],
       [0.        , 0.        , 0.        , ..., 0.45108017, 2.1660151 ,
        0.06741326],
       [0.        , 0.        , 0.        , ..., 0.5097498 , 0.06741326,
        2.275847  ]], dtype=float32)

In [1896]:
jnp.linalg.eigh(jac.T@jac)

(Array([-4.4572613e-07,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  2.2168718e-02,  1.5427074e-01,  1.8027227e-01,
         1.2539967e+00,  7.0149894e+00,  7.3020172e+00], dtype=float32),
 Array([[ 0.0000000e+00,  1.0000000e+00,  0.0000000e+00, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  1.0000000e+00, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
          0.0000000e+00,  0.0000000e+00,  0.00000

In [290]:
sdir = jnp.eye(35)

Array([ 0.00000000e+00, -7.02036743e-07, -8.96019571e-07,  3.66521391e-08,
       -2.18523169e-07,  1.08757284e-07,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
       -5.76895475e-01,  1.64767504e+00, -5.76895475e-01,  2.45629406e+00,
        1.89921856e+00, -2.24523520e+00, -1.13092446e+00], dtype=float32)