In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import jax_dataclasses as jdc
from functools import partial

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [2]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7006/static/


In [3]:
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7, 8], [0.04, 0.04])

In [4]:
#load object
from flax import linen as nn

class GraspNet(nn.Module):
    hidden_dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        logit = nn.Dense(features=5)(x)
        return logit

In [5]:
obj_start = Mesh(world.vis, "obj_start", "./sdf_world/assets/object/mesh.obj",
                 alpha=0.5)
d, w, h = obj_start.mesh.bounding_box.primitive.extents
obj_start.set_translate([0.4, -0.3, h/2])
frame = Frame(world.vis, "grasp_pose")

In [6]:
from flax.training import orbax_utils
import orbax
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
raw_restored = orbax_checkpointer.restore("model/grasp_net")
params = raw_restored["params"]
grasp_net = GraspNet(raw_restored["hidden_dim"])
grasp_fn = lambda x: grasp_net.apply(params, x)

In [7]:
import pickle
with open("./sdf_world/assets/object"+'/info.pkl', 'rb') as f:
    obj_data = pickle.load(f)
scale_to_norm = obj_data["scale_to_norm"]
def grasp_reconst(g:Array):
    rot = SO3(grasp_fn(g)[1:5]).normalize()
    trans = g/scale_to_norm
    return SE3.from_rotation_and_translation(rot, trans)
grasp_logit_fn = lambda g: grasp_fn(g)[0]

In [8]:
# CGN fns
class ResidualFn:
    def get_error_and_jac(self, x, state):
        pass
    def get_weight(self):
        pass

class ConstraintFn:
    def get_value_and_jac(self, x, state):
        pass
    def get_bounds(self):
        pass

@jax.jit
def residual_eval(x, state, res_fns:Tuple[ResidualFn]):
    errors, jacs = [], []
    for res_fn in res_fns:
        error, jac = res_fn.get_error_and_jac(x, state)
        errors.append(error)
        jacs.append(jac)
    return jnp.hstack(errors), jnp.vstack(jacs)

def residual_weights(res_fns:Tuple[ResidualFn]):
    weights = []
    for res_fn in res_fns:
        weights.append(res_fn.get_weight())
    return jnp.hstack(weights)

@jax.jit
def constr_eval(x, state, constr_fns:Tuple[ConstraintFn]):
    vals, jacs = [], []
    for constr_fn in constr_fns:
        val, jac = constr_fn.get_value_and_jac(x, state)
        vals.append(val)
        jacs.append(jac)
    return jnp.hstack(vals), jnp.vstack(jacs)

def constr_bounds(constr_fns:Tuple[ConstraintFn]):
    lbs, ubs = [], []
    for constr_fn in constr_fns:
        lb, ub = constr_fn.get_bounds()
        lbs.append(lb)
        ubs.append(ub)
    return jnp.hstack(lbs), jnp.hstack(ubs)

In [9]:
# Kinematics
def get_rotvec_angvel_map(v):
    def skew(v):
        v1, v2, v3 = v
        return jnp.array([[0, -v3, v2],
                        [v3, 0., -v1],
                        [-v2, v1, 0.]])
    vmag = jnp.linalg.norm(v)
    vskew = skew(v)
    return jnp.eye(3) \
        - 1/2*skew(v) \
        + vskew@vskew * 1/vmag**2 * (1-vmag/2 * jnp.sin(vmag)/(1-jnp.cos(vmag)))

@jax.jit
def get_ee_fk_jac(q):
    # outputs ee_posevec and analytical jacobian
    fks = panda_model.fk_fn(q)
    p_ee = fks[-1][-3:]
    rotvec_ee = SO3(fks[-1][:4]).log()
    E = get_rotvec_angvel_map(rotvec_ee)
    jac = []
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        jac.append(jnp.hstack([lin_vel, rot_axis]))
    jac = jnp.array(jac).T
    jac = jac.at[3:, :].set(E @ jac[3:, :])
    return jnp.hstack([p_ee, rotvec_ee]), jac

In [10]:
robot_dim = 7
grasp_dim = 3
dim = grasp_dim + robot_dim
def to_posevec(pose:SE3):
    return jnp.hstack([pose.translation(), pose.rotation().log()])

In [11]:
def value_and_jacrev2(f, *x):
    y, pullback = jax.vjp(f, *x)
    basis = jnp.eye(y.size, dtype=y.dtype)
    jac = jax.vmap(pullback)(basis)
    return y, jac

In [12]:
grasp_to_posevec = lambda g: to_posevec(grasp_reconst(g))

In [295]:
@jdc.pytree_dataclass
class KinError(ResidualFn):
    def get_error_and_jac(self, x, state):
        grasp = x[:grasp_dim]
        grasp_pose = grasp_reconst(grasp)
        grasp_posevec = to_posevec(obj_start.pose@grasp_pose)
        q = x[-robot_dim:]
        ee, jac = get_ee_fk_jac(q)
        error = grasp_posevec - ee
        return error, jnp.hstack([jnp.zeros((6,3)), -jac])
    
    def get_weight(self):
        return np.array([1,1,1,0.3,0.3,0.3])

initial_pose = get_ee_fk_jac(panda.neutral)[0]
euc_dist_to_grasp = lambda g: to_posevec(obj_start.pose@grasp_reconst(g))[:3]
@jdc.pytree_dataclass
class GraspDistance(ResidualFn):
    def get_error_and_jac(self, x, state):
        grasp = x[:grasp_dim]
        grasp_pose = grasp_reconst(grasp)
        grasp_pose_wrt_world = obj_start.pose@grasp_pose
        grasp_posevec = to_posevec(grasp_pose_wrt_world)
        error = grasp_posevec[:3]
        jac = 1 / scale_to_norm * jnp.eye(3)
        return error, jnp.hstack([jac, jnp.zeros((3,7))])
    
    def get_weight(self):
        return np.array([1,1,1.])

# @jdc.pytree_dataclass
# class Velocity(ResidualFn):
#     def get_error_and_jac(self, x, state:State):
#         vel_fn = lambda x, state: to_vel(x, state)
#         return value_and_jacrev(vel_fn, x, state)
    
#     def get_weight(self):
#         return np.ones(dim)*0.00001/horizon

@jdc.pytree_dataclass
class JointLimit(ConstraintFn):
    """ lb < val < ub. dval/dx = jac"""
    robot_lb: Array
    robot_ub: Array
    def get_value_and_jac(self, x, state):
        val = x[-robot_dim:]
        jac = jnp.eye(robot_dim)
        return val, jnp.hstack([jnp.zeros((7,3)), jac])
    
    def get_bounds(self):
        lb = jnp.tile(self.robot_lb, 1)
        ub = jnp.tile(self.robot_ub, 1)
        return lb, ub

@jdc.pytree_dataclass
class GraspLogit(ConstraintFn):
    """ lb < val < ub. dval/dx = jac"""
    def get_value_and_jac(self, x, state):
        grasp_logit_fn = lambda x: grasp_fn(x)[0]
        val, jac = jax.value_and_grad(grasp_logit_fn)(x[:grasp_dim])
        return val, jnp.hstack([jac, jnp.zeros(7)])
    
    def get_bounds(self):
        lb = 1.
        ub = jnp.inf
        return lb, ub

In [296]:
@jdc.pytree_dataclass
class State:
    q0: Array
    target: Array
state = State(jnp.array(panda.neutral), jnp.zeros(3))

In [297]:
res_fns = [KinError()] #, , 
constr_fns = [JointLimit(panda.lb, panda.ub), GraspLogit()] #, Penetration()

In [298]:
x = jnp.hstack([0,0,0.5,panda.neutral])
panda.set_joint_angles(x[-7:])

In [299]:
import osqp
from scipy import sparse

prob = osqp.OSQP()
is_qp_init = False
weights = residual_weights(res_fns)
lb, ub = constr_bounds(constr_fns)
W = jnp.diag(weights)

tr_length = 0.1 # trust-region length
max_tr_length = 1.

err, jac = residual_eval(x, state, res_fns)
val = err@W@err

In [302]:
cval, cjac = constr_eval(x, state, constr_fns)
P = sparse.csc_matrix(jac.T@W@jac)  # hess = jac.T@W@jac
q = np.asarray(jac.T@W@err)         # grad = jac.T@W@err
A = sparse.csc_matrix(np.vstack([cjac, np.eye(dim)]))
l = np.hstack([lb-cval, np.full(dim, -tr_length)])
u = np.hstack([ub-cval, np.full(dim, tr_length)])

if not is_qp_init:
    qp_settings = dict(check_termination=10, verbose=False)
    prob.setup(P, q, A, l, u, **qp_settings)
    is_qp_init = True
else:
    prob.update(
        Px=sparse.triu(P).data, Ax=A.data,
        q=q, l=l, u=u)
res = prob.solve()
print(res.info.status)
p = res.x

err_new, jac_new = residual_eval(x+p, state, res_fns)
val_new = err_new@W@err_new
pred_reduction = 0.5*(val - (err+jac@p)@W@(err+jac@p))
true_reduction = val - val_new
ratio = true_reduction/pred_reduction

if ratio < 0.25:
    tr_length /= 4
    print("ratio is small")
elif ratio > 0.75 and np.abs(np.linalg.norm(p, np.inf) - tr_length) < 1e-3:
    tr_length = np.minimum(2*tr_length, max_tr_length)
    print("ratio getting bigger")
else: pass

if ratio > 0.2:
    print("updated!")
    x = x + p
    err, jac, val = err_new, jac_new, val_new

#draw_traj(x, state)
frame.set_pose(obj_start.pose@grasp_reconst(x[:3]))
panda.set_joint_angles(x[-robot_dim:])
print(p)

primal infeasible


TypeError: Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

: 

In [42]:
np.eye(dim)

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])