In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import jax_dataclasses as jdc
import cyipopt
from functools import partial

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [2]:
# Scene container; its `.vis` handle is shared by the robot and all visual
# markers created in later cells (the printed URL opens the visualizer).
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


In [3]:
# Kinematic model (its fk_fn is used for the end-effector pose / Jacobian)
# and a semi-transparent visual instance of the robot.
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
# NOTE(review): presumably pins joints 7 and 8 (the gripper fingers) at 0.04 so
# only the 7 arm joints remain free (robot_dim = 7 below) -- confirm.
panda.reduce_dim([7, 8], [0.04, 0.04])

In [4]:
# Problem size: `horizon` waypoints of `robot_dim` joints each, stored as one
# flat decision vector of length `dim`. `dt` is the time step used to turn
# consecutive waypoint differences into velocities.
robot_dim = 7
horizon = 10
dim = robot_dim * horizon
dt = 0.1

def to_mat(x):
    """Reshape a flat decision vector into a matrix with one waypoint per row."""
    return x.reshape(-1, robot_dim)

def to_vec(x):
    """Flatten a waypoint matrix back into a 1-D decision vector."""
    return x.flatten()
def to_posevec(pose:SE3):
    """Flatten an SE3 pose into the vector [translation, rotation vector]."""
    translation = pose.translation()
    rotvec = pose.rotation().log()
    return jnp.concatenate([translation, rotvec])
def to_vel(x, state):
    """Finite-difference joint velocities along the trajectory.

    ``state.q0`` is prepended to the waypoints of ``x``; the result is
    ``(q[k+1] - q[k]) / dt`` for every consecutive pair, flattened, so there is
    one velocity vector per waypoint in ``x`` (the first measured from q0).
    """
    waypoints = jnp.vstack([state.q0, to_mat(x)])
    return (jnp.diff(waypoints, axis=0) / dt).ravel()

In [5]:
# Kinematics
def get_rotvec_angvel_map(v, small_angle=0.1):
    """Map a world-frame angular velocity to the rate of a rotation vector.

    For a rotation vector ``v`` with angle ``theta = |v|`` this returns

        E(v) = I - 1/2 [v]x + c(theta) [v]x^2,
        c(theta) = (1 - theta/2 * sin(theta) / (1 - cos(theta))) / theta^2,

    the inverse of the SO(3) left Jacobian, so that ``dv/dt = E(v) @ omega``.

    The closed form for ``c`` is 0/0 at ``theta = 0`` (an identity rotation
    would give NaN), and in float32 it suffers cancellation for small angles.
    Below ``small_angle`` it is therefore replaced by its Taylor series
    ``1/12 + theta^2/720 + theta^4/30240``.

    Args:
        v: rotation vector, shape (3,).
        small_angle: angle (rad) below which the Taylor series is used.

    Returns:
        (3, 3) matrix ``E(v)``.
    """
    def skew(w):
        w1, w2, w3 = w
        return jnp.array([[0., -w3, w2],
                          [w3, 0., -w1],
                          [-w2, w1, 0.]])

    v = jnp.asarray(v)
    vskew = skew(v)
    theta2 = v @ v  # squared angle; avoids differentiating the norm at 0
    small = theta2 < small_angle ** 2
    # Feed a harmless angle into the closed form when it is not selected, so
    # neither the value nor its gradient picks up inf/NaN through jnp.where.
    theta_safe = jnp.sqrt(jnp.where(small, 1.0, theta2))
    c_closed = (1. - theta_safe / 2 * jnp.sin(theta_safe)
                / (1. - jnp.cos(theta_safe))) / theta_safe ** 2
    c_series = 1. / 12 + theta2 / 720 + theta2 ** 2 / 30240
    c = jnp.where(small, c_series, c_closed)
    return jnp.eye(3) - 0.5 * vskew + c * (vskew @ vskew)

@jax.jit
def get_ee_fk_jac(q):
    """End-effector pose vector and its analytical Jacobian.

    Each row of ``panda_model.fk_fn(q)`` is read as ``[quaternion(4), position(3)]``
    (it is passed to jaxlie ``SE3``/``SO3``); the last row is the end effector.

    Returns:
        posevec: ``[p_ee, rotvec_ee]``.
        jac: one column per frame in ``fk[1:8]`` (presumably the 7 arm joints).
            The top 3 rows are the linear velocity; the bottom 3 rows are the
            angular part mapped through ``get_rotvec_angvel_map`` so they match
            the rotation-vector part of ``posevec``.
    """
    frames = panda_model.fk_fn(q)
    ee = frames[-1]
    ee_pos = ee[-3:]
    ee_rotvec = SO3(ee[:4]).log()

    def column(frame_vec):
        # Assumes the joint rotates about the z axis of its frame.
        z_axis = SE3(frame_vec).as_matrix()[:3, 2]
        linear = jnp.cross(z_axis, ee_pos - frame_vec[-3:])
        return jnp.concatenate([linear, z_axis])

    geometric = jnp.stack([column(f) for f in frames[1:8]], axis=1)
    rot_rate_map = get_rotvec_angvel_map(ee_rotvec)
    analytic = jnp.concatenate(
        [geometric[:3], rot_rate_map @ geometric[3:]], axis=0)
    return jnp.concatenate([ee_pos, ee_rotvec]), analytic

In [None]:
@jdc.pytree_dataclass
class State:
    """Per-problem data passed to every residual/constraint alongside ``x``."""
    q0: Array  # starting joint configuration; prepended to the waypoints of x (to_vel, draw_traj)
    target: Array  # target pose vector [translation, rotation vector], built with to_posevec

In [6]:
# Visual markers: `frame` shows the sampled target pose and `line` the
# end-effector path (horizon + 1 points: the start configuration plus every waypoint).
# NOTE(review): `frame_ee` is not referenced in the cells visible here.
frame = Frame(world.vis, "frame")
frame_ee = Frame(world.vis, "frame_ee")
line = DottedLine(world.vis, "line", np.zeros([horizon+1, 3]))
def make_pose(rng=None, xyz_low=(-0.3, -0.5, 0.3), xyz_high=(0.6, 0.5, 0.8)):
    """Sample a random target pose.

    The rotation is uniform on SO(3): a 4-D standard-normal vector, normalized,
    is a uniformly distributed unit quaternion. (Sampling components from
    [0, 1) would only cover the positive orthant, i.e. a biased subset of
    rotations.) The translation is uniform in the box ``[xyz_low, xyz_high]``.

    Args:
        rng: optional ``np.random.Generator`` for reproducible sampling; ``None``
            uses the global ``np.random`` state (seed it with ``np.random.seed``).
        xyz_low, xyz_high: per-axis bounds of the translation box.

    Returns:
        jaxlie ``SE3`` pose.
    """
    sampler = np.random if rng is None else rng
    quat = sampler.standard_normal(4)
    xyz = sampler.uniform(xyz_low, xyz_high)
    return SE3.from_rotation_and_translation(
        SO3(quat).normalize(),
        xyz,
    )
def draw_traj(x, state):
    """Draw the end-effector path of a trajectory and pose the robot at its end.

    The path starts at ``state.q0`` and then follows every waypoint of ``x``.
    """
    def ee_position(q):
        return panda_model.fk_fn(q)[-1, -3:]

    joint_path = jnp.vstack([state.q0, to_mat(x)])
    ee_path = jax.vmap(ee_position)(joint_path)
    line.reload(points=ee_path)
    panda.set_joint_angles(joint_path[-1])

In [8]:
# CGN fns
class ResidualFn:
    """Interface for one weighted least-squares residual term.

    Subclasses must implement both methods. ``residual_eval`` stacks the
    ``(error, jac)`` pairs of all terms, and ``residual_weights`` concatenates
    their weights into the diagonal of the weight matrix.
    """
    def get_error_and_jac(self, x, state):
        """Return ``(error, jac)`` at ``x``, where ``jac = d(error)/dx``."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_error_and_jac")

    def get_weight(self):
        """Return one weight per entry of ``error``."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_weight")

class ConstraintFn:
    """Interface for a constraint ``lb <= value(x) <= ub``.

    Subclasses must implement both methods. ``constr_eval`` stacks the
    ``(value, jac)`` pairs of all constraints, and ``constr_bounds`` concatenates
    their bounds; the optimizer linearizes ``value(x + p) ~= value + jac @ p``.
    """
    def get_value_and_jac(self, x, state):
        """Return ``(value, jac)`` at ``x``, where ``jac = d(value)/dx``."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_value_and_jac")

    def get_bounds(self):
        """Return ``(lb, ub)`` with one entry per element of ``value``."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_bounds")

@jax.jit
def residual_eval(x, state, res_fns:Tuple[ResidualFn]):
    """Stack every residual term into one error vector and one Jacobian.

    Terms are evaluated in the order of ``res_fns``. Errors are joined with
    ``jnp.hstack`` and Jacobians are stacked row-wise with ``jnp.vstack``.
    """
    pairs = [fn.get_error_and_jac(x, state) for fn in res_fns]
    stacked_error = jnp.hstack([error for error, _ in pairs])
    stacked_jac = jnp.vstack([jac for _, jac in pairs])
    return stacked_error, stacked_jac

def residual_weights(res_fns:Tuple[ResidualFn]):
    """Concatenate the per-entry weights of all residual terms, in order."""
    return jnp.hstack([fn.get_weight() for fn in res_fns])

@jax.jit
def constr_eval(x, state, constr_fns:Tuple[ConstraintFn]):
    """Stack every constraint into one value vector and one Jacobian.

    Constraints are evaluated in the order of ``constr_fns``. Values are joined
    with ``jnp.hstack`` and Jacobians are stacked row-wise with ``jnp.vstack``.
    """
    pairs = [fn.get_value_and_jac(x, state) for fn in constr_fns]
    stacked_val = jnp.hstack([val for val, _ in pairs])
    stacked_jac = jnp.vstack([jac for _, jac in pairs])
    return stacked_val, stacked_jac

def constr_bounds(constr_fns:Tuple[ConstraintFn]):
    """Concatenate the lower and upper bounds of all constraints, in order."""
    bounds = [fn.get_bounds() for fn in constr_fns]
    lower = jnp.hstack([lb for lb, _ in bounds])
    upper = jnp.hstack([ub for _, ub in bounds])
    return lower, upper

In [9]:
@jdc.pytree_dataclass
class FinalPoseError(ResidualFn):
    """Target pose minus the end-effector pose at the final waypoint.

    ``error = state.target - [p_ee, rotvec_ee](q_final)``, where ``q_final`` is
    the last ``robot_dim`` entries of ``x``. Only those entries influence the
    error, so every other Jacobian column is zero.
    """
    def get_error_and_jac(self, x, state:State):
        ee_posevec, ee_jac = get_ee_fk_jac(x[-robot_dim:])
        # d(error)/dx: earlier waypoints contribute nothing; the sign flips
        # because the end-effector pose enters the error negatively.
        leading = jnp.zeros([6, dim - robot_dim])
        jac_full = jnp.concatenate([leading, -ee_jac], axis=1)
        return state.target - ee_posevec, jac_full

    def get_weight(self):
        # Translation entries weighted 1, rotation-vector entries 0.3.
        return np.array([1.0] * 3 + [0.3] * 3)

@jdc.pytree_dataclass
class Velocity(ResidualFn):
    """Penalize finite-difference joint velocities along the whole trajectory."""
    def get_error_and_jac(self, x, state:State):
        # to_vel already takes (x, state); differentiate it directly w.r.t. x.
        return value_and_jacrev(to_vel, x, state)

    def get_weight(self):
        # Small uniform weight, scaled down by the number of waypoints.
        return np.full(dim, 0.00001 / horizon)

@jdc.pytree_dataclass
class JointLimit(ConstraintFn):
    """Box constraint ``robot_lb <= q_k <= robot_ub`` on every waypoint.

    The constraint value is ``x`` itself, so its Jacobian is the identity. The
    bounds are the per-joint limits repeated ``horizon`` times.
    """
    robot_lb: Array
    robot_ub: Array

    def get_value_and_jac(self, x, state):
        return x, jnp.eye(dim)

    def get_bounds(self):
        return tuple(jnp.tile(limit, horizon)
                     for limit in (self.robot_lb, self.robot_ub))

In [10]:
res_fns = [
    FinalPoseError(),  # drive the final waypoint's end-effector pose to the target
    Velocity(),        # penalize joint velocities along the trajectory
]
constr_fns = [
    JointLimit(panda.lb, panda.ub),  # joint limits on every waypoint
]

In [2048]:
# problem setting
pose_d = make_pose()
frame.set_pose(pose_d)
# Target as a 6-vector [translation, rotation vector], the same layout as the
# end-effector pose vector compared against it in FinalPoseError.
posevec_d = to_posevec(pose_d)
state = State(q0=panda.neutral, target=posevec_d)

In [2049]:
# Initial guess: every waypoint set to panda.neutral (also used as state.q0);
# show the starting configuration in the visualizer.
x = jnp.tile(panda.neutral, horizon)
panda.set_joint_angles(state.q0)

In [2050]:
import osqp
from scipy import sparse

# QP solver for the trust-region subproblem; set up lazily on the first step.
prob = osqp.OSQP()
is_qp_init = False
weights = residual_weights(res_fns)
lb, ub = constr_bounds(constr_fns)
W = jnp.diag(weights)  # diagonal weight matrix for the stacked residual

tr_length = 0.1 # trust-region length (bound on every |p_i| of the step p)
max_tr_length = 1.  # cap applied when the trust region is enlarged

# Residual, Jacobian and cost at the initial guess. The cost is the full
# weighted squared norm err^T W err (no 1/2 factor).
err, jac = residual_eval(x, state, res_fns)
val = err@W@err

In [2088]:
# One trust-region Gauss-Newton step; re-run this cell to iterate.
# With err(x + p) ~= err + jac @ p the QP minimizes
#   0.5 * p^T (jac^T W jac) p + (jac^T W err)^T p,
# which is half of (model cost - val). It is subject to the linearized
# constraints and the box trust region |p_i| <= tr_length.
cval, cjac = constr_eval(x, state, constr_fns)
P = sparse.csc_matrix(jac.T@W@jac)  # hess = jac.T@W@jac
q = np.asarray(jac.T@W@err)         # grad = jac.T@W@err
A = sparse.csc_matrix(np.vstack([cjac, np.eye(dim)]))
l = np.hstack([lb-cval, np.full(dim, -tr_length)])
u = np.hstack([ub-cval, np.full(dim, tr_length)])

if not is_qp_init:
    qp_settings = dict(check_termination=10, verbose=False)
    prob.setup(P, q, A, l, u, **qp_settings)
    is_qp_init = True
else:
    # NOTE: update() only swaps numerical values, so P and A must keep the
    # sparsity pattern they had at setup(). csc_matrix() built from a dense
    # array stores only nonzeros, so an entry that becomes exactly 0 (or stops
    # being 0) would misalign Px/Ax.
    prob.update(
        Px=sparse.triu(P).data, Ax=A.data,
        q=q, l=l, u=u)
res = prob.solve()
print(res.info.status)
if res.info.status not in ("solved", "solved inaccurate"):
    # res.x is not a usable step otherwise (e.g. an infeasible QP).
    raise RuntimeError(f"trust-region QP not solved: {res.info.status}")
p = res.x

err_new, jac_new = residual_eval(x+p, state, res_fns)
val_new = err_new@W@err_new
err_model = err + jac@p
# Predicted and actual reductions must use the same cost, err^T W err (val and
# val_new carry no 1/2). Scaling only the predicted side would double the ratio.
pred_reduction = val - err_model@W@err_model
true_reduction = val - val_new
ratio = true_reduction/pred_reduction

if ratio < 0.25:
    tr_length /= 4
elif ratio > 0.75 and np.abs(np.linalg.norm(p, np.inf) - tr_length) < 1e-3:
    # Good model agreement and the step reached the boundary: enlarge the region.
    tr_length = np.minimum(2*tr_length, max_tr_length)

if ratio > 0.2:
    x = x + p
    err, jac, val = err_new, jac_new, val_new

draw_traj(x, state)
# panda.set_joint_angles(x)

solved


In [2101]:
# Show the robot at the final waypoint (last row of the waypoint matrix, valid
# for any `horizon`); change the row index to inspect other waypoints.
panda.set_joint_angles(to_mat(x)[-1])

: 