In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import orbax
from flax.training import orbax_utils

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [3]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
net_dict = orbax_checkpointer.restore("model/grasp_net")

In [2]:
world = SDFWorld()
world.show_in_jupyter()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7007/static/


In [3]:
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)

concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts


In [4]:
box_start = Box(world.vis, "box_start", [0.06, 0.15, 0.25], "green", 0.5)
box_start.set_translate([0.4, -0.3, 0.25/2])
box_goal = Box(world.vis, "box_goal", [0.06, 0.15, 0.25], "blue", 0.5)
goal_pose = SE3.from_rotation_and_translation(
    rotation=SO3.from_x_radians(jnp.pi/2),
    translation=jnp.array([0.4, 0.3, 0.15/2])
)
box_goal.set_pose(goal_pose)
ground = Box(world.vis, "ground", [2, 2, 0.5], "white", alpha=0.1)
ground.set_translate([0,0,-ground.lengths[-1]/2])

In [5]:
hand_model = RobotModel(HAND_URDF, PANDA_PACKAGE, True)
hand1 = Robot(world.vis, "hand1", hand_model, color="yellow", alpha=0.5)
hand2 = Robot(world.vis, "hand2", hand_model, color="yellow", alpha=0.5)
frame = Frame(world.vis, "frame")
def set_hand_pose(pose, hand:Robot):
    ee_base_pose = pose @ SE3.from_translation(jnp.array([0,0,-0.105]))
    hand.set_pose(ee_base_pose)

concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts


In [30]:
scale_to_world = jnp.linalg.norm(box_start.lengths)/2/0.8
scale_to_obj = 1/scale_to_world
hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
box_x = box_start.pose.rotation().as_matrix()[:,0]
def grasp_reconst(g:Array):
    normalize = lambda v: v/safe_2norm(v)    
    z = normalize(-g) #box.pose.translation() 
    _y = box_x
    x = normalize(jnp.cross(_y, z))
    y = jnp.cross(z, x)
    rot_mat = jnp.vstack([x, y, z]).T
    grasp_pose = SE3.from_rotation_and_translation(
        SO3.from_matrix(rot_mat), g*scale_to_world)
    return grasp_pose
def get_grasp_dist(g):
    vec_to_center = - g*scale_to_world
    dist_surface = jnp.abs(box_start.distance(g, SE3.identity(), box_start.lengths/2))
    dist_center = safe_2norm(vec_to_center)
    p_surface = vec_to_center*dist_surface/dist_center + g
    p_surface_proj = p_surface.at[0].set(0.)
    return safe_2norm(p_surface_proj - g*scale_to_world)

In [31]:
import orbax
from flax.training import orbax_utils
import flax.linen as nn
class ManipNet(nn.Module):
    hidden_dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return nn.softplus(x)

#load
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
ckpt = orbax_checkpointer.restore('model/manip_net')
manip_net = ManipNet(64)
params = ckpt["params"]

ws_r = 1.
ws_center = jnp.array([0,0,0.5])
def get_manip_value(posevec):
    _, xyz = posevec[:4], posevec[4:]
    is_out_of_bound = jnp.linalg.norm(xyz - ws_center) > ws_r
    return jax.lax.cond(
        is_out_of_bound, lambda x:jnp.array(0.), 
        lambda x: manip_net.apply(params, x)[0], posevec)

In [32]:
#penetration
qhand = jnp.array([0.04, 0.04])
hand_pc = hand1.get_surface_points_fn(qhand)
env = SDFContainer([ground], 0.05)
def get_hand_pcs(g, obj1, obj2):
    result = []
    for obj in [obj1, obj2]:
        grasp_pose = grasp_reconst(g)
        hand_base_pose_wrt_world = obj.pose @ grasp_pose @ hand_pose_wrt_ee
        assigned_hand_pc = jax.vmap(hand_base_pose_wrt_world.apply)(hand_pc)
        result.append(assigned_hand_pc)
    return jnp.vstack(result)

def constr_penetration(g, obj1, obj2, env):
    points = get_hand_pcs(g, obj1, obj2)
    return env.penetration_sum(points)

In [33]:
g_init = jnp.array(np.random.normal(size=3))
grasp_pose = grasp_reconst(g_init)
set_hand_pose(box_start.pose @ grasp_pose, hand1)
set_hand_pose(box_goal.pose @ grasp_pose, hand2)

points = get_hand_pcs(g_init, box_start, box_goal)
#pc.reload(points=points)
g = g_init

In [34]:
obj1_pose = box_start.pose
obj2_pose = box_goal.pose

In [35]:
@jax.jit
def grasp_to_posevecs(g, apply_zflip=False):
    if apply_zflip:
        zflip = SE3.from_rotation(SO3.from_z_radians(jnp.pi))
    else:
        zflip = SE3.identity()
    grasp_pose_wrt_obj = grasp_reconst(g)
    poses = []
    poses += [obj1_pose @ grasp_pose_wrt_obj @ hand_pose_wrt_ee @ zflip]
    poses += [obj2_pose @ grasp_pose_wrt_obj @ hand_pose_wrt_ee @ zflip]
    poses = jnp.vstack([pose.parameters() for pose in poses])
    return poses

def constr_penet(posevecs):
    assigned_points_fn = \
        lambda posevec, points :jax.vmap(SE3(posevec).apply)(points)
    points = jax.vmap(assigned_points_fn, in_axes=(0, None))(posevecs, hand_pc)
    points = jnp.vstack(points)
    return env.penetration_sum(points)
def constr_kin(posevecs):
    manip_min = jax.vmap(get_manip_value)(posevecs).min()
    return manip_min

def constraints(g):
    posevecs = grasp_to_posevecs(g)
    dist = get_grasp_dist(g)
    penet = constr_penet(posevecs)
    manip = constr_kin(posevecs)
    return jnp.hstack([dist, penet, manip])

In [36]:
from sdf_world.nlp import NLP
prob = NLP()

In [37]:
prob.add_var("g", 3, -1*np.ones(3), np.ones(3))

In [38]:
prob.add_con("con", 3, ["g"], constraints, 
             lower=np.array([0., 0., 0.1]),
             upper=np.array([0., 0., 2.])
)

In [39]:
prob.add_objective(lambda g:jnp.array(0.))

In [40]:
prob.build()

In [53]:
xsol, info = prob.solve(jnp.zeros(3), viol_tol=0.001)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        6
Number of nonzeros in inequality constraint Jacobian.:        3
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:        3
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        3
                     variables with only upper bounds:        0
Total number of equality constraints.................:        2
Total number of inequality constraints...............:        1
        inequality constraints with only lower bounds:        0
   inequality constraints with lower and upper bounds:        1
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  0.0000000e+00 6.79e-06 0.00e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [55]:
posevecs = grasp_to_posevecs(xsol)
set_hand_pose(box_start.pose @ grasp_pose, hand1)
set_hand_pose(box_goal.pose @ grasp_pose, hand2)
hand1.set_pose(SE3(posevecs[0]))
hand2.set_pose(SE3(posevecs[1]))

In [57]:
grasp_pose = grasp_reconst(xsol)
set_hand_pose(box_start.pose@grasp_pose, hand1)

In [58]:
grasp_pose

SE3(wxyz=[-0.52173    -0.47728    -0.47151998  0.52694   ], xyz=[ 7.9999998e-05  6.9199996e-03 -6.8999996e-04])

In [59]:
box_start.pose

SE3(wxyz=[1. 0. 0. 0.], xyz=[ 0.39999998 -0.29999998  0.125     ])

In [56]:
get_grasp_dist(xsol)

Array(0.00085102, dtype=float32)

In [54]:
constraints(xsol)

Array([0.00085102, 0.        , 0.7391411 ], dtype=float32)

In [156]:
%time
grasp_to_posevecs(g)

CPU times: user 6 µs, sys: 15 µs, total: 21 µs
Wall time: 39.1 µs


Array([[-0.33871344, -0.62070376, -0.11753047,  0.6972708 ,  0.58595467,
        -0.16174097,  0.07724981],
       [ 0.19939728, -0.6784104 , -0.5761515 ,  0.40993834,  0.58595467,
         0.3477502 ,  0.21325906]], dtype=float32)

Array([0.23824543, 0.27298832], dtype=float32)

In [138]:
%time 
constr_penet(poses)

CPU times: user 2 µs, sys: 15 µs, total: 17 µs
Wall time: 33.9 µs


Array(0.00012368, dtype=float32)

In [125]:
pc.reload(points=points)

In [121]:
pc = PointCloud(world.vis, "pc", points, size=0.01, color="red")

In [79]:
g_grad = jax.grad(constr_penetration)(g, box_start, box_goal, env)
g = - g_grad*0.01 + g
grasp_pose = grasp_reconst(g, box_start)
set_hand_pose(box_start.pose @ grasp_pose, hand1)
set_hand_pose(box_goal.pose @ grasp_pose, hand2)

In [80]:
print(constr_penetration(g, box_start, box_goal, env))

0.00089821743


In [42]:
env.penetration_sum(points)

Array(11.353209, dtype=float32)

In [167]:
grad_p = penet_grad_fn(g, box_start)
g = - grad_p*0.05 + g
grasp_pose = grasp_reconst(g, box_start)
set_hand_pose(box_start.pose @ grasp_pose, hand1)
set_hand_pose(box_goal.pose @ grasp_pose, hand2)
print(penetrations(g, box_start))

4.9822407e-05


In [168]:
grasp_pose = grasp_reconst(g, box_start)
hand_base_pose_wrt_world = box_start.pose @ grasp_pose @ hand_pose_wrt_ee
assigned_hand_pc = jax.vmap(hand_base_pose_wrt_world.apply)(hand_pc)

In [175]:
assigned_hand_pc[8]

Array([ 0.24431565, -0.1587292 ,  0.04776791], dtype=float32)

In [174]:
jax.vmap(env.penetration)(assigned_hand_pc).argmax()

Array(8, dtype=int32)

In [170]:
pc_start = PointCloud(world.vis, "pc_start", 
                      assigned_hand_pc, size=0.01, color="red")

In [181]:
del pc_goal

In [61]:
penetrations(box_start.pose @ grasp_pose)

Array(0.80027175, dtype=float32)

In [24]:
def assign_hand_pc(ee_pose):
    hand_base_pose = ee_pose @ hand_pose_wrt_ee
    return jax.vmap(hand_base_pose.apply)(hand_pc)

In [52]:
hand_pc_start = assign_hand_pc(box_start.pose@grasp_pose)
hand_pc_goal = assign_hand_pc(box_goal.pose@grasp_pose)

In [59]:
jax.jit(env.penetration_sum)(hand_pc_start)

Array(3.8270435, dtype=float32)

In [55]:
jax.vmap(ground.penetration, in_axes=(0, None))(hand_pc_start, 0.05)

Array([0.0903221 , 0.05407002, 0.16405903, 0.02178621, 0.11948074,
       0.11161787, 0.07293419, 0.07485468, 0.15773171, 0.11410213,
       0.01828003, 0.04888446, 0.14840655, 0.05378092, 0.13687733,
       0.06560025, 0.10658304, 0.05655343, 0.14534341, 0.00621273,
       0.08871541, 0.09663767, 0.05944959, 0.10215046, 0.0749354 ,
       0.10418446, 0.06886864, 0.08655269, 0.08516046, 0.08507899,
       0.07272626, 0.10985497, 0.07414203, 0.07997886, 0.09959988,
       0.06137355, 0.09734934, 0.06699882, 0.09366957, 0.09004926,
       0.03192899, 0.00810587, 0.00903368, 0.02626381, 0.00995794,
       0.01350285, 0.0228666 , 0.03721107, 0.00403808, 0.02144613,
       0.02283757, 0.03633609, 0.01403445, 0.01210649, 0.00574847,
       0.027432  , 0.00623619, 0.02207015, 0.01279055, 0.0181395 ],      dtype=float32)

In [35]:
pc_start = PointCloud(world.vis, "pc_start", hand_pc_start, size=0.01, color="red")
pc_goal = PointCloud(world.vis, "pc_goal", hand_pc_goal, size=0.01, color="red")

In [34]:
del pc_start, pc_goal