In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [2]:
# Create the SDF scene and open its visualizer inline (prints the viewer URL).
world = SDFWorld()
world.show_in_jupyter()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7005/static/


In [3]:
# Load the Panda arm model from its URDF and add it to the viewer.
# alpha=0.5 presumably renders it semi-transparent; confirm against Robot's signature.
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)

concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts


In [8]:
# Box dimensions as passed to Box(...). They are presumably the full edge lengths
# along the box's local x, y, z axes (box.lengths/2 is later used as half-extents); confirm.
BOX_LENGTHS = [0.06, 0.15, 0.25]
box = Box(world.vis, "box", BOX_LENGTHS, "green", 0.5)
# Lift the box by half its height (derived from BOX_LENGTHS rather than a
# repeated literal), assuming the box origin is its centre so it rests on z = 0.
box.set_translate([0.4, 0, BOX_LENGTHS[2] / 2])

In [13]:
# Load a standalone Panda hand (gripper) model for visualizing grasp poses.
HAND_URDF = PANDA_PACKAGE / "hand.urdf"
# NOTE(review): the meaning of the third positional argument (True) is not visible
# here (panda_model above omits it); consider passing it by keyword.
hand_model = RobotModel(HAND_URDF, PANDA_PACKAGE, True)
hand = Robot(world.vis, "hand", hand_model, alpha=0.5)

concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts


Array([0.4  , 0.   , 0.125], dtype=float32)

In [267]:
# Box x-axis in world coordinates, captured when this cell runs. Kept for
# backward compatibility; grasp_point_to_pose re-reads the box pose on every
# call, so it does not go stale if the box is moved later.
box_x = box.pose.rotation().as_matrix()[:,0]
def grasp_point_to_pose(p):
    """Build a grasp pose at point ``p`` whose z-axis points at the box centre.

    Rotation columns (world coordinates):
      * z: unit vector from ``p`` toward the box centre (approach direction).
      * x: normalised cross(box_x_axis, z).
      * y: cross(z, x), which is approximately the box x-axis with its z-component
        removed, normalised.
    If z is (nearly) parallel to the box x-axis, the cross product vanishes and
    the orientation is ill-defined. The result then depends on how
    ``safe_2norm`` handles a zero vector.

    Args:
        p: 3-vector grasp point in world coordinates.

    Returns:
        SE3 pose with rotation [x | y | z] and translation ``p``.
    """
    def normalize(v):
        return v / safe_2norm(v)

    # Read the current box orientation instead of the module-level snapshot.
    box_x_axis = box.pose.rotation().as_matrix()[:, 0]
    z = normalize(box.pose.translation() - p)
    x = normalize(jnp.cross(box_x_axis, z))
    y = jnp.cross(z, x)
    rot_mat = jnp.stack([x, y, z], axis=1)  # x, y, z as columns
    return SE3.from_rotation_and_translation(SO3.from_matrix(rot_mat), p)

# Offset from the hand's base frame to the grasp frame along the grasp frame's
# local z-axis. This is presumably the Panda hand flange-to-fingertip (TCP)
# distance, in the URDF's length units; confirm.
HAND_TCP_OFFSET = 0.105
def set_hand_pose(pose, tcp_offset=HAND_TCP_OFFSET):
    """Display the hand so that its grasp frame coincides with ``pose``.

    The hand base is placed ``tcp_offset`` behind ``pose`` along the pose's
    own z-axis, and the debug ``frame`` marker is moved to ``pose``.

    Args:
        pose: SE3 target grasp pose in world coordinates.
        tcp_offset: base-to-grasp-frame distance along local z (default 0.105).

    NOTE(review): this uses the global ``frame``, which is created by the cell
    ``frame = Frame(world.vis, "frame")`` further down the notebook. On a fresh
    kernel run top-to-bottom this raises NameError unless that cell runs first.
    """
    ee_base_pose = pose @ SE3.from_translation(jnp.array([0., 0., -tcp_offset]))
    hand.set_pose(ee_base_pose)
    frame.set_pose(pose)

In [276]:
def get_grasp_dist(g):
    """Grasp-constraint residual: (approximately) zero for a valid grasp point.

    Steps:
      1. Move from ``g`` toward the box centre by |box.distance(g, ...)| to get
         an approximate surface point. Assuming box.distance is the box's signed
         distance, this matches the true nearest surface point only when that
         point lies on the ray to the centre, so it is an approximation.
      2. Express that point in the box frame and set its x-coordinate to 0,
         i.e. project it onto the box's local x = 0 mid-plane.
      3. Return the distance from ``g`` to the projected point.

    Args:
        g: 3-vector candidate grasp point in world coordinates.

    Returns:
        Scalar distance, used as an equality-constraint residual.
    """
    vec_to_center = box.pose.translation() - g
    dist_surface = jnp.abs(box.distance(g, box.pose, box.lengths/2))
    # Use safe_2norm (as elsewhere in this notebook) rather than jnp.linalg.norm.
    # The latter has a NaN gradient, and this division is 0/0, when g is the box centre.
    dist_center = safe_2norm(vec_to_center)
    p_surface = vec_to_center*dist_surface/dist_center + g
    p_surface_wrt_box = box.pose.inverse().apply(p_surface)
    p_surface_proj_wrt_box = p_surface_wrt_box.at[0].set(0.)
    p_surface_proj = box.pose.apply(p_surface_proj_wrt_box)
    return safe_2norm(p_surface_proj - g)

In [277]:
# Hand-picked initial grasp point (world frame); show the hand at that grasp pose.
p = np.array([0.47, -0.1, 0.4])
set_hand_pose(grasp_point_to_pose(p))

In [278]:
# One manual gradient-descent step on the grasp residual, then show the result.
# Intentionally non-idempotent: each re-run steps again from the current p.
step_size = 1e-3
grad_p = jax.grad(get_grasp_dist)(p)
p = p - step_size * grad_p
grasp_pose = grasp_point_to_pose(p)
set_hand_pose(grasp_pose)
print(p)

[ 0.46946636 -0.09985061  0.39913225]


In [279]:
# `orbax` is not bound by any visible import (flax's orbax_utils does not bind
# it), so import it explicitly for a fresh-kernel run.
import orbax.checkpoint
from flax.training import orbax_utils
from flax import linen as nn
# Load the pretrained manipulability-network parameters.
MANIP_NET_CKPT = 'model/manip_net'  # relative to the notebook's working directory
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
ckpt = orbax_checkpointer.restore(MANIP_NET_CKPT)
params = ckpt["params"]

class ManipNet(nn.Module):
    """MLP scoring a pose vector.

    Three ReLU hidden layers of width ``hidden_dim``, then a 1-unit Dense layer
    passed through softplus, so the output is non-negative. Layers are created
    in the same order as before (Dense_0..Dense_3), so restored params match.
    """
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        h = x
        for _ in range(3):
            h = nn.relu(nn.Dense(features=self.hidden_dim)(h))
        score = nn.Dense(features=1)(h)
        return nn.softplus(score)
manip_net = ManipNet(64)
ws_r = 1.0  # radius of the spherical region where the network is queried
ws_center = jnp.array([0.0, 0.0, 0.5])  # centre of that sphere, world frame
def get_manip_value(pose:SE3):
    """Learned manipulability score of ``pose``.

    Returns 0 when the pose's translation lies farther than ``ws_r`` from
    ``ws_center``. Otherwise returns the network output evaluated on
    ``pose.parameters()``, as a scalar.
    """
    outside_workspace = jnp.linalg.norm(pose.translation() - ws_center) > ws_r

    def _zero_score(_vec):
        return jnp.array([0.])

    def _net_score(vec):
        return manip_net.apply(params, vec)

    scores = jax.lax.cond(outside_workspace, _zero_score, _net_score, pose.parameters())
    return scores[0]

In [280]:
from sdf_world.nlp import NLP

In [293]:
# Decision variable "g": a 3-D grasp point bounded to [-1, 1] on each axis.
# Constraint "grasp": get_grasp_dist(g) with target 0. The Ipopt log reports it as one equality constraint.
prob = NLP()
prob.add_var("g", 3, -np.ones(3), np.ones(3))
prob.add_con("grasp", 1, ["g"], get_grasp_dist, 0.)

In [294]:
# Offset 0.05 back along the grasp frame's local z-axis (the pre-grasp standoff).
pre_pose = SE3.from_translation(jnp.array([0,0,-0.05]))
def neg_manipulability(g):
    """Objective to minimise: negated best manipulability over two grasp variants.

    The grasp pose at ``g`` and the same pose rotated by pi about its own
    z-axis are both offset by ``pre_pose``. Each is scored with
    get_manip_value; the negated maximum is returned.
    """
    flip_about_z = SE3.from_rotation(SO3.from_z_radians(jnp.pi))
    pose_a = grasp_point_to_pose(g)
    pose_b = pose_a @ flip_about_z
    best = jnp.maximum(
        get_manip_value(pose_a @ pre_pose),
        get_manip_value(pose_b @ pre_pose),
    )
    return -best
prob.add_objective(neg_manipulability)

In [295]:
# Finalise the NLP (variables, constraint, objective) before solving.
prob.build()

In [296]:
# Solve with Ipopt from the initial guess g = (0, 0, 0); xsol is the grasp point.
xsol, info = prob.solve(jnp.zeros(3))

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        3
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:        3
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        3
                     variables with only upper bounds:        0
Total number of equality constraints.................:        1
Total number of inequality constraints...............:        0
        inequality constraints with only lower bounds:        0
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0 -5.5888420e-01 4.15e-01 7.49e-01   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [297]:
# Visualise the hand at the optimised grasp point.
set_hand_pose(grasp_point_to_pose(xsol))

In [298]:
# Objective value at the solution (more negative = higher manipulability).
neg_manipulability(xsol)

Array(-0.8098788, dtype=float32)

In [230]:
# NOTE(review): scratch cell executed out of order (In [230] ran before the
# In [267]+ definitions above); its saved output may not reproduce on re-run.
get_manip_value(grasp_point_to_pose(p))

Array(0.5928991, dtype=float32)

In [139]:
# NOTE(review): leftover scratch cell (In [139]); output is from an earlier run.
grad_p

Array([0.        , 0.01050212, 0.        ], dtype=float32)

In [91]:
# NOTE(review): duplicate leftover scratch cell (In [91]); output is from an earlier run.
grad_p

Array([ 0.       , -2.5078523, -5.0157065], dtype=float32)

In [78]:
# Scratch: box distance value at the current grasp point p (out-of-order cell In [78]).
box.distance(p, box.pose, box.lengths/2)

Array(0.35493147, dtype=float32)

In [77]:
# NOTE(review): get_grasp_prob is not defined anywhere in this notebook (stale
# cell In [77]); this raises NameError on a fresh kernel. Remove or restore its definition.
get_grasp_prob(p)

Array(2.374961e-05, dtype=float32)

In [22]:
# Debug frame marker moved by set_hand_pose. NOTE(review): set_hand_pose reads
# this global, so this cell must run before any set_hand_pose call.
frame = Frame(world.vis, "frame")

In [12]:
# Scratch: box distance value at the world origin (out-of-order cell In [12]).
box.distance(jnp.zeros(3), box.pose, box.lengths/2)

Array(0.37, dtype=float32)