In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [2]:
# Create the SDF world and its visualizer (URL printed below); `world.vis` is the
# handle every Frame/Robot below draws into.
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/


In [3]:
# Coordinate-frame markers: `frame` displays target poses (posevec_d) and
# `frame_curr` displays the pose being iterated toward the target (posevec_curr).
frame = Frame(world.vis, "frame")
frame_curr = Frame(world.vis, "frame_curr")

In [606]:
def make_pose(rng=None):
    """Sample a random SE3 pose to use as an end-effector target.

    Orientation is drawn uniformly over SO(3). Translation is drawn uniformly
    from the box x in [0, 0.5), y in [-0.5, 0.5), z in [0.3, 0.8).

    Args:
        rng: optional ``numpy.random.Generator`` for reproducible sampling.
            When None, the global ``np.random`` state is used.

    Returns:
        jaxlie ``SE3`` pose.
    """
    sampler = np.random if rng is None else rng
    # np.random.random(4) only yields non-negative components, which confines
    # the quaternion to one orthant of S^3 and biases the orientations. A
    # normalized isotropic Gaussian is uniform on S^3, and hence on SO(3).
    quat_wxyz = sampler.normal(size=4)
    translation = sampler.uniform([0.0, -0.5, 0.3], [0.5, 0.5, 0.8])
    return SE3.from_rotation_and_translation(
        SO3(quat_wxyz).normalize(),
        translation,
    )

In [5]:
# Visual sanity check of make_pose: show one random sample with `frame`.
pose_rand = make_pose()
frame.set_pose(pose_rand)

In [13]:
def to_posevec(pose:SE3):
    """Flatten an SE3 pose into the 6-vector ``[x, y, z, rx, ry, rz]``.

    The last three entries are the rotation vector returned by ``SO3.log``
    of the pose's rotation.
    """
    translation = pose.translation()
    rotvec = pose.rotation().log()
    return jnp.hstack([translation, rotvec])
def to_SE3(posevec:Array):
    """Rebuild an SE3 pose from a ``[x, y, z, rx, ry, rz]`` pose vector.

    This is the inverse of ``to_posevec``: the rotation is ``SO3.exp`` of the
    last three entries.
    """
    translation, rotvec = posevec[:3], posevec[3:]
    rotation = SO3.exp(rotvec)
    return SE3.from_rotation_and_translation(rotation, translation)

In [604]:
# Load the Panda kinematic model and add a rendered copy (alpha=0.5) to the visualizer.
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
# NOTE(review): presumably pins joints 7 and 8 (the gripper fingers) at 0.04 so
# only the 7 arm joints stay free. Confirm against Robot.reduce_dim.
panda.reduce_dim([7, 8], [0.04, 0.04])

In [1172]:
# Kinematics
def skew(v):
    """Return the 3x3 skew-symmetric matrix [v]_x, so that skew(v) @ w == cross(v, w)."""
    a, b, c = v
    top = [0, -c, b]
    middle = [c, 0., -a]
    bottom = [-b, a, 0.]
    return jnp.array([top, middle, bottom])

def get_rotvec_angvel_map(v, small_angle=0.1):
    """Map an angular velocity to the time derivative of a rotation vector.

    With theta = |v|, returns the inverse left Jacobian of SO(3):

        I - 1/2 [v]_x + c(theta) [v]_x^2,
        c(theta) = (1 - (theta/2) * cot(theta/2)) / theta^2,

    which is the same closed form as sin(theta)/(1-cos(theta)) = cot(theta/2).
    It is applied to a world-frame angular velocity to get the rotvec rate.

    The original expression divided by theta^2 and by (1 - cos(theta)), giving
    NaN at theta == 0. In float32 it also lost all precision for small theta,
    because 1 - cos(theta) cancels catastrophically. Here c(theta) uses its
    Taylor series 1/12 + theta^2/720 + theta^4/30240 below ``small_angle``.

    Args:
        v: rotation vector, shape (3,).
        small_angle: angle (rad) below which the Taylor series is used.

    Returns:
        (3, 3) matrix.
    """
    v = jnp.asarray(v)
    K = skew(v)
    theta_sq = jnp.dot(v, v)
    is_small = theta_sq < small_angle ** 2
    # Double-where: give the closed form a harmless angle on the small branch, so
    # neither the value nor its gradient (under jit/grad) picks up 0/0 NaNs.
    theta_sq_safe = jnp.where(is_small, 1.0, theta_sq)
    half_theta = 0.5 * jnp.sqrt(theta_sq_safe)
    # (theta/2) * cot(theta/2) avoids the cancellation in 1 - cos(theta).
    coeff_closed = (1.0 - half_theta * jnp.cos(half_theta) / jnp.sin(half_theta)) / theta_sq_safe
    coeff_series = 1.0 / 12.0 + theta_sq / 720.0 + theta_sq ** 2 / 30240.0
    coeff = jnp.where(is_small, coeff_series, coeff_closed)
    return jnp.eye(3) - 0.5 * K + coeff * (K @ K)

@jax.jit
def get_ee_fk_jac(q):
    """End-effector pose vector and its Jacobian at joint configuration ``q``.

    Assumes ``panda_model.fk_fn(q)`` returns one jaxlie ``wxyz_xyz`` pose per
    link, with the end effector last (TODO confirm).

    Returns:
        ``(ee_posevec, jac)``.
        - ``ee_posevec`` is ``[p_ee, rotvec_ee]``.
        - ``jac`` has 6 rows: the linear velocity (3), then the rotation-vector
          rate (3).
        - It has one column per pose in ``fks[1:8]``; each column assumes a
          revolute joint about that frame's local z axis.
    """
    link_poses = panda_model.fk_fn(q)
    ee_wxyz_xyz = link_poses[-1]
    p_ee = ee_wxyz_xyz[-3:]
    rotvec_ee = SO3(ee_wxyz_xyz[:4]).log()

    def joint_column(wxyz_xyz):
        # Revolute column: [z x (p_ee - p_joint); z], with z the joint frame's z axis.
        z_axis = SE3(wxyz_xyz).as_matrix()[:3, 2]
        origin = wxyz_xyz[-3:]
        return jnp.hstack([jnp.cross(z_axis, p_ee - origin), z_axis])

    columns = [joint_column(pose) for pose in link_poses[1:8]]
    jac = jnp.array(columns).T
    # Turn the angular-velocity rows into rotation-vector-rate rows.
    angular_rows = get_rotvec_angvel_map(rotvec_ee) @ jac[3:, :]
    jac = jac.at[3:, :].set(angular_rows)
    return jnp.hstack([p_ee, rotvec_ee]), jac

In [1681]:
# Sample a new IK target and show it with `frame`.
# Also (re)initialise posevec_curr / frame_curr. The twist-following and
# gradient-descent cells below read posevec_curr, and nothing else defines it,
# so they raise NameError on a fresh kernel without these lines.
posevec_d = to_posevec(make_pose())
posevec_curr = to_posevec(make_pose())

frame.set_pose(to_SE3(posevec_d))
frame_curr.set_pose(to_SE3(posevec_curr))

In [1682]:
# Reset the robot to its neutral configuration before iterating the IK cell below.
# NOTE(review): `q` is bound to `panda.neutral` itself, not a copy. If that is a
# NumPy array, an in-place update such as `q += ...` also rewrites panda.neutral.
# TODO confirm its type, or copy it here.
q = panda.neutral
panda.set_joint_angles(q)

In [1714]:
# One scaled pseudo-inverse IK step toward posevec_d; re-run the cell to iterate.
IK_STEP_SIZE = 0.1  # fraction of the least-squares joint update applied per run
ee_posevec, jac = get_ee_fk_jac(q)
# NOTE: posevec_d - ee_posevec subtracts rotation vectors. That only
# approximates the orientation error when the two orientations are close.
qupdate = jnp.linalg.pinv(jac) @ (posevec_d - ee_posevec)
# Rebind rather than `q += ...`: `q` starts out as `panda.neutral` itself, and an
# in-place add on a NumPy array would silently overwrite the neutral pose.
q = q + qupdate * IK_STEP_SIZE
panda.set_joint_angles(q)

In [1634]:
# NOTE(review): later cells still call frame_curr.set_pose(...). After this cell
# runs they raise NameError until the cell defining frame_curr is re-run.
del frame_curr

In [None]:
def get_error_twist(posevec_curr, posevec_target=None):
    """Pose error from ``posevec_curr`` to a target, as a 6-vector twist.

    Both poses are ``[x, y, z, rx, ry, rz]`` vectors (translation + rotation
    vector, as produced by ``to_posevec``).

    Args:
        posevec_curr: current pose vector.
        posevec_target: target pose vector. Defaults to the notebook-global
            ``posevec_d``, looked up at call time.

    Returns:
        ``[pos_err, rot_err]``.
        - ``pos_err = target[:3] - current[:3]``.
        - ``rot_err = R_curr @ log(R_curr^-1 R_target)``: the body-frame
          rotation error re-expressed in the world frame.
    """
    if posevec_target is None:
        posevec_target = posevec_d
    R_curr = SO3.exp(posevec_curr[3:])
    R_target = SO3.exp(posevec_target[3:])
    pos_err = posevec_target[:3] - posevec_curr[:3]
    # Rotation error as seen from the current (body) frame...
    rot_err_body = (R_curr.inverse() @ R_target).log()
    # ...rotated into the world frame, so it can be used as an angular velocity.
    rot_err_world = R_curr.apply(rot_err_body)
    return jnp.hstack([pos_err, rot_err_world])

In [597]:
def pose_difference(posevec_curr):
    """Squared pose error between ``posevec_curr`` and the global target ``posevec_d``.

    Inputs are ``[x, y, z, rx, ry, rz]`` pose vectors. Returns the scalar
    ``|p_d - p|^2 + |log(R^-1 R_d)|^2``. ``posevec_curr`` is the differentiated
    argument; ``posevec_d`` is read from the enclosing namespace.
    """
    translation_gap = posevec_d[:3] - posevec_curr[:3]
    rotation_gap = (SO3.exp(posevec_curr[3:]).inverse() @ SO3.exp(posevec_d[3:])).log()
    return jnp.sum(translation_gap**2) + jnp.sum(rotation_gap**2)

In [601]:
# Value and gradient of pose_difference with respect to posevec_curr.
# Deliberately NOT wrapped in jax.jit: pose_difference reads the global
# posevec_d. jit would bake in the value posevec_d had at the first trace, so
# re-sampling the target later would be silently ignored.
vg_pose_difference = jax.value_and_grad(pose_difference)

In [1169]:
# Follow the error twist for one explicit-Euler step of size 0.1 and redraw frame_curr.
twist = get_error_twist(posevec_curr)
# twist[3:] is a world-frame angular velocity. Map it to a rotation-vector rate
# so it can be integrated directly on the [rx, ry, rz] half of posevec_curr.
rotvec_dot = get_rotvec_angvel_map(posevec_curr[3:]) @ twist[3:]
update = jnp.hstack([twist[:3], rotvec_dot])

posevec_curr = posevec_curr + 0.1 * update
frame_curr.set_pose(to_SE3(posevec_curr))
#print(val)

In [1145]:
# Inspect the squared pose error and its gradient at the current pose vector.
vg_pose_difference(posevec_curr)

(Array(2.703032, dtype=float32),
 Array([-0.51846766, -0.8305994 , -0.01259148, -0.58684206,  1.4605045 ,
         2.165997  ], dtype=float32))

In [594]:
# One gradient-descent step (step size 0.1) on pose_difference, then redraw
# frame_curr. The printed value is the error measured before the step.
val, grads = vg_pose_difference(posevec_curr)
posevec_curr = posevec_curr - 0.1 * grads
frame_curr.set_pose(to_SE3(posevec_curr))
print(val)

0.0072116824
