In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import jax_dataclasses as jdc
from functools import partial

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *
from sdf_world.sparse_ipopt import *

In [2]:
# Create the SDF world; its visualizer handle (`world.vis`) is passed to the
# robot and frame objects created in the following cells.
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7002/static/


In [3]:
# Robot setup: build the Panda model from its URDF/package and attach a
# semi-transparent (alpha=0.5) robot to the world's visualizer.
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
# NOTE(review): presumably this fixes joints 7 and 8 (the gripper fingers?) at
# 0.04 so only the arm joints remain free — confirm against Robot.reduce_dim.
panda.reduce_dim([7, 8], [0.04, 0.04])

In [4]:
# Frame marker in the visualizer; later cells set it to the target pose.
frame = Frame(world.vis, "frame")

In [699]:
# Second frame marker; later cells set it to the current end-effector pose.
frame2 = Frame(world.vis, "frame2")

In [949]:
def skew(v):
    """Return the 3x3 skew-symmetric (cross-product) matrix of a 3-vector.

    For any 3-vector ``u``, ``skew(v) @ u`` equals ``jnp.cross(v, u)``.

    Args:
        v: iterable of exactly three scalars.

    Returns:
        A (3, 3) jax array.
    """
    x, y, z = v
    top = [0, -z, y]
    middle = [z, 0., -x]
    bottom = [-y, x, 0.]
    return jnp.array([top, middle, bottom])

@jax.custom_jvp
def fk(q):
    """Forward kinematics: return the pose vector of the last frame.

    Evaluates ``panda_model.fk_fn`` at joint configuration ``q`` and returns
    its final entry. Other cells read the first four entries of that vector as
    a (w, x, y, z) quaternion and the last three as a position.

    Differentiation uses the hand-written rule registered via ``fk.defjvp``.
    """
    link_posevecs = panda_model.fk_fn(q)
    ee_posevec = link_posevecs[-1]
    return ee_posevec

@fk.defjvp
def fk_jvp(primals, tangents):
    """Forward-mode derivative rule for ``fk``.

    Builds the 6x7 geometric Jacobian (linear rows first, angular rows last)
    from the poses of frames 1..7, using the third column of each frame's
    rotation matrix as that joint's rotation axis. The angular rows are then
    mapped to quaternion rates so the tangent matches the layout of the primal
    output: four quaternion entries followed by three position entries.

    Args:
        primals: 1-tuple holding the joint configuration ``q``.
        tangents: 1-tuple holding the joint velocity ``q_dot``.

    Returns:
        ``(ee_posevec, ee_posevec_dot)``.
    """
    (q,) = primals
    (q_dot,) = tangents

    link_posevecs = panda_model.fk_fn(q)
    ee_posevec = link_posevecs[-1]
    quat, ee_position = ee_posevec[:4], ee_posevec[-3:]
    qw, qvec = quat[0], quat[1:]

    def joint_column(joint_posevec):
        # One Jacobian column: [axis x (p_ee - p_joint), axis].
        axis = SE3(joint_posevec).as_matrix()[:3, 2]
        lever = ee_position - joint_posevec[-3:]
        return jnp.hstack([jnp.cross(axis, lever), axis])

    geom_jac = jnp.array([joint_column(pv) for pv in link_posevecs[1:8]]).T
    lin_jac, ang_jac = geom_jac[:3, :], geom_jac[3:, :]

    # 3x4 map whose transpose takes an angular velocity to a quaternion rate
    # (up to the factor 0.5 applied below).
    quat_map = jnp.hstack([-qvec[:, None], skew(qvec) + jnp.eye(3) * qw])
    quat_jac = 0.5 * quat_map.T @ ang_jac

    # Quaternion rows first, then position rows, mirroring the posevec layout.
    pose_jac = jnp.vstack([quat_jac, lin_jac])
    return ee_posevec, pose_jac @ q_dot

In [2068]:
# Sample a random target pose: translation uniform inside an axis-aligned box,
# orientation from a normalized random 4-vector. No seed is set here, so the
# target differs on every run.
t = np.random.uniform([0.3,-0.4, 0.5], [0.7, 0.4, 0.7])
qtn = SO3(np.random.normal(size=4)).normalize()
target_pose = SE3.from_rotation_and_translation(qtn, t)
# Show the target with the first frame marker.
frame.set_pose(target_pose)

# Reset the robot to its neutral configuration and show the resulting
# end-effector pose with the second frame marker.
q = panda.neutral.copy()
panda.set_joint_angles(q)
ee = fk(q)
frame2.set_pose(SE3(ee))

In [1838]:
# Reset the joint configuration to a copy of the robot's neutral pose and
# push it to the visualizer.
q = panda.neutral.copy()
panda.set_joint_angles(q)

In [2020]:
def error_vec(q, target_pose):
    """Return the 6-vector pose error between ``target_pose`` and ``fk(q)``.

    The first three entries are the position difference (target minus
    current). The last three are half of the log of the relative rotation
    ``R_current^{-1} @ R_target``.

    Args:
        q: joint configuration passed to ``fk``.
        target_pose: object whose ``parameters()`` has the same layout as
            ``fk(q)`` (four quaternion entries, then three position entries).

    Returns:
        Array of six entries: ``[position error, 0.5 * rotation error]``.
    """
    target_params = target_pose.parameters()
    ee_params = fk(q)

    delta_pos = target_params[-3:] - ee_params[-3:]

    rot_current = SO3(ee_params[:4])
    rot_target = SO3(target_params[:4])
    delta_rot = (rot_current.inverse() @ rot_target).log()

    return jnp.hstack([delta_pos, 0.5 * delta_rot])
# Jitted forward-mode Jacobian of error_vec with respect to its first argument (q).
error_jac = jax.jit(jax.jacfwd(error_vec))

In [2069]:
# One pseudo-inverse (Gauss-Newton style) IK step on error_vec.
#dq = fk_grad(q, target_pose)
d = -jnp.linalg.pinv(error_jac(q, target_pose))@error_vec(q, target_pose)


# Apply 30% of the step; this updates q in place, so the cell is not
# idempotent (presumably it is meant to be re-run to iterate).
q += d*0.3
# Refresh the robot and the end-effector frame marker in the visualizer.
panda.set_joint_angles(q)
ee = fk(q)
frame2.set_pose(SE3(ee))
#error(q, target_pose)

In [1624]:
# Print the raw end-effector pose vector, then the same pose wrapped in SE3
# and normalized, for comparison.
print(fk(q))
print(SE3(fk(q)).normalize())

[ 0.06209511  0.9251849   0.34851304  0.13680421  0.58107424 -0.01059722
  0.5536008 ]
SE3(wxyz=[0.0621     0.92519    0.34851    0.13679999], xyz=[ 0.58107 -0.0106   0.5536 ])


In [934]:
# NOTE(review): `error` is not defined in any visible cell — presumably it
# came from a deleted cell or one of the star imports. Confirm, otherwise this
# cell fails on a fresh kernel.
error(q, target_pose)

Array(0.00557508, dtype=float32)

In [38]:
# End-effector pose vector at the current joint configuration.
ee = fk(q)

In [7]:
# Overwrite q with the all-zero 7-joint configuration (used by the scratch
# Jacobian cell below).
q = jnp.zeros(7)

In [23]:
# Scratch computation of the pose Jacobian at the current q (same steps as
# fk_jvp, written out at top level).
fks = panda_model.fk_fn(q)
p_ee = fks[-1][-3:]
qtn = fks[-1][:4]
jac = []
# One column per frame 1..7: [axis x (p_ee - p_frame), axis], where the axis
# is the third column of that frame's rotation matrix.
for posevec in fks[1:8]:
    p_frame = posevec[-3:]
    rot_axis = SE3(posevec).as_matrix()[:3, 2]
    lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
    jac.append(jnp.hstack([lin_vel, rot_axis]))
jac = jnp.array(jac).T  #geom_jacobian
w, xyz = qtn[0], qtn[1:]
# H.T maps an angular velocity to a quaternion rate (with the 0.5 factor below).
H = jnp.hstack([-xyz[:,None], skew(xyz)+jnp.eye(3)*w])
rot_jac = 0.5*H.T@jac[3:,:]
# Rows here are [linear (3); quaternion rate (4)]. Note this is the opposite
# order to fk_jvp, which stacks the quaternion rows first.
jac = jnp.vstack([jac[:3,:], rot_jac])

In [25]:
# Sanity check: 3 linear rows + 4 quaternion-rate rows, one column per joint.
jac.shape

(7, 7)

Array([[-0.4619397 , -0.19134168, -0.        ],
       [ 0.        ,  0.        , -0.19134168],
       [ 0.        ,  0.        ,  0.4619397 ],
       [ 0.19134168, -0.4619397 ,  0.        ]], dtype=float32)

Array([[ 0.0000000e+00, -1.9134168e-01,  0.0000000e+00,  1.9134165e-01,
         0.0000000e+00,  1.9134165e-01,  0.0000000e+00],
       [-1.9134168e-01,  0.0000000e+00, -1.9134168e-01, -2.2809706e-08,
        -1.9134168e-01, -2.2809706e-08,  1.9134165e-01],
       [ 4.6193969e-01,  0.0000000e+00,  4.6193969e-01,  5.5067503e-08,
         4.6193969e-01,  5.5067503e-08, -4.6193963e-01],
       [ 0.0000000e+00, -4.6193969e-01,  0.0000000e+00,  4.6193963e-01,
         0.0000000e+00,  4.6193963e-01,  0.0000000e+00]], dtype=float32)

In [14]:
@jax.jit
def get_ee_fk_jac(q):
    """Return the end-effector pose and the geometric Jacobian at ``q``.

    The Jacobian has one column per frame 1..7 of ``panda_model.fk_fn(q)``;
    each column is ``[axis x (p_ee - p_frame), axis]``, where ``axis`` is the
    third column of that frame's rotation matrix. No conversion of the angular
    rows to rotation-vector or quaternion rates is applied here.

    Args:
        q: joint configuration.

    Returns:
        ``(ee, jac)``: ``ee`` is the last frame wrapped as an ``SE3``; ``jac``
        is the 6-row geometric Jacobian (linear rows first, angular rows last).
    """
    link_posevecs = panda_model.fk_fn(q)
    ee_posevec = link_posevecs[-1]
    ee_position = ee_posevec[-3:]

    def joint_column(joint_posevec):
        # Linear part is axis x lever arm; angular part is the axis itself.
        axis = SE3(joint_posevec).as_matrix()[:3, 2]
        lever = ee_position - joint_posevec[-3:]
        return jnp.hstack([jnp.cross(axis, lever), axis])

    columns = [joint_column(pv) for pv in link_posevecs[1:8]]
    geometric_jac = jnp.array(columns).T
    return SE3(ee_posevec), geometric_jac

In [353]:
def get_rotvec_angvel_map(v):
    """Return ``I + 1/2 [v]x + c(|v|) [v]x^2`` for a rotation vector ``v``.

    ``[v]x`` is the skew-symmetric cross-product matrix of ``v`` and
    ``c(t) = (1 - (t/2) * cot(t/2)) / t**2``.

    The closed form divides by ``sin(|v|/2)`` and ``|v|**2``, so it is NaN at
    ``v = 0`` and loses all precision just above it (``1 - alpha`` cancels).
    For small ``|v|`` the series ``c(t) = 1/12 + t**2/720 + O(t**4)`` is used
    instead, so the result tends smoothly to the identity as ``v -> 0``.

    The variant with ``-1/2 [v]x`` is simply the transpose of this matrix.
    NOTE(review): this form is usually called the inverse right Jacobian of
    SO(3) — confirm that is the variant the caller needs.

    Args:
        v: rotation vector (three scalars, axis times angle).

    Returns:
        A (3, 3) jax array.
    """
    def skew(v):
        v1, v2, v3 = v
        return jnp.array([[0, -v3, v2],
                        [v3, 0., -v1],
                        [-v2, v1, 0.]])
    v = jnp.asarray(v)
    vskew = skew(v)
    vmag_sq = jnp.dot(v, v)

    # Below |v| ~ 1e-2 switch to the series; its truncation error is far
    # smaller than the float32 cancellation error of the closed form there.
    small = vmag_sq < 1e-4
    # Guard the denominator so the unused branch never produces NaN/inf
    # (keeps values and gradients through jnp.where finite).
    safe_sq = jnp.where(small, 1., vmag_sq)
    half = jnp.sqrt(safe_sq) / 2
    alpha = half * jnp.cos(half) / jnp.sin(half)
    coeff = jnp.where(small, 1 / 12 + vmag_sq / 720, (1 - alpha) / safe_sq)

    return jnp.eye(3) + 0.5 * vskew + coeff * (vskew @ vskew)

In [637]:
# Sample a new random target pose (a larger box than the earlier sampling
# cell) and show it with the first frame marker. No seed is set.
t = np.random.uniform([0.3,-0.5, 0.5], [0.8, 0.5, 0.8])
qtn = SO3(np.random.normal(size=4)).normalize()
target_pose = SE3.from_rotation_and_translation(qtn, t)
frame.set_pose(target_pose)

# Reset the robot to its neutral configuration.
q = panda.neutral.copy()
panda.set_joint_angles(q)

In [689]:
# One Gauss-Newton IK step with the error expressed in the end-effector frame.
ee_pose, geom_jac = get_ee_fk_jac(q)
R_ee = ee_pose.rotation().as_matrix()
# Position error rotated into the end-effector frame; rotation error is the
# log of R_ee^{-1} @ R_target.
err_t = R_ee.T@(target_pose.translation() - ee_pose.translation()) #body
err_r = ee_pose.rotation().inverse() @ target_pose.rotation()
err = jnp.hstack([err_t, err_r.log()])

# Jacobian of the error w.r.t. q: the geometric Jacobian rotated into the
# end-effector frame and negated; the angular rows additionally go through B.
# NOTE(review): a left-perturbation derivation of d log(R_ee^T R_target)
# suggests the map with -1/2 [v]x (the transpose of B) — confirm the sign.
B = get_rotvec_angvel_map(err_r.log())
jac_t = - R_ee.T @ geom_jac[:3]
jac_r = - B @ R_ee.T @ geom_jac[3:]
jac = jnp.vstack([jac_t, jac_r]) #, jac_r
# Normal-equations step d = -(J^T J)^+ J^T err, applied with step size 0.05.
hess = jac.T@jac
d = - jnp.linalg.pinv(hess)@jac.T@err
# Rebinds q, so re-running this cell advances the iteration by one step.
q = q + d * 0.05
panda.set_joint_angles(q)