In [1]:
from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *

import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import jax_dataclasses as jdc

In [2]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7006/static/


In [3]:
#load learned functions
grasp_fn, grasp_reconst = load_grasp_fn("model/grasp_net_waffle")
inv_manip_fn_gen3 = load_manip_fn("model/manip_net_gen3", SE3.identity())

hidden_dim:32, out_dim:7


In [4]:
# refer to PREDEFINED_ROBOTS
gen3 = get_predefined_robot(world.vis, "gen3+hand_e")

In [5]:
obj_mesh = "./sdf_world/assets/waffle_box/waffle_box_centered.obj"
obj0 = Mesh(world.vis, "obj0", obj_mesh, color="green", alpha=0.3)
objg = Mesh(world.vis, "objg", obj_mesh, color="red", alpha=0.3)

In [6]:
def SE3_trans(xyz):
    return SE3.from_translation(jnp.array(xyz))

In [7]:
def object_placement(obj, upright_axis, xy, yaw=None):
    if yaw is None:
        yaw = np.random.uniform(-np.pi, np.pi)
    obj_extents = obj.mesh.bounding_box.primitive.extents
    if upright_axis == "z":
        rot = SO3.from_z_radians(yaw)
        z_offset = obj_extents[-1]/2
    if upright_axis == "x":
        rot = SO3.from_z_radians(yaw)@SO3.from_y_radians(-np.pi/2)
        z_offset = obj_extents[0]/2
    elif upright_axis == "y":
        rot = SO3.from_z_radians(yaw)@SO3.from_x_radians(np.pi/2)
        z_offset = obj_extents[1]/2
    xyz = [*xy, z_offset]
    pose = SE3.from_rotation_and_translation(rot, jnp.array(xyz))
    return pose

# xy, upright_axis= ground
yaw = np.random.uniform(-np.pi, np.pi)
upright_axis = "y"
pose0 = object_placement(obj0, "x", xy=[0.4,0.3])
poseg = object_placement(obj0, "z", xy=[0.4,-0.3])

obj0.set_pose(pose0)
objg.set_pose(poseg)

In [8]:
ground = Box(world.vis, "ground", [1,1.8,0.5], 'gray', 0.3)
ground.set_pose(SE3_trans([0.4, 0, -0.25]))
env = SDFContainer([ground]) #we can add more obstacles

In [9]:
hande = Gripper(world.vis, "hande", gen3.gripper.model, gen3.gripper.tool_pose_offset, max_width=0.05, scale=0.001, alpha=0.5, color='white')
hande_pc = hande.get_hand_pc_wrt_tool_pose()

In [10]:
def P_grasp(g):
    return grasp_fn(g)[0]

def M_inv(tool_pose:Array):
    zflip_SE3 = SE3.from_rotation(SO3.from_z_radians(jnp.pi))
    #pre
    tool_pose = SE3(tool_pose) # @ SE3_trans([0,0,-0.1])
    tool_pose_flip = (tool_pose @ zflip_SE3)
    candidates = jnp.vstack([pose.parameters() for pose in [tool_pose, tool_pose_flip]])
    return jax.vmap(inv_manip_fn_gen3)(candidates).max()

def D_hand(tool_pose:Array):
    assigned_pc = jax.vmap(SE3(tool_pose).apply)(hande_pc)
    return env.distances(assigned_pc).min()

def R_grasp(grasp, obj_pose:Array):
    grasp_pose = grasp_reconst(grasp)
    return (SE3(obj_pose) @ grasp_pose @ SE3_trans([0,0,-0.1])).parameters()

In [11]:
K = [0] #grasp numbers
KT = [(0,0), (0,1)] #grasp/object pair numbers
num_constrs = len(K) + len(KT)*2

num_vars = 1
var_names = ["g0"]
var_coords = [np.arange(0,3)]

num_params = 2
param_names = ["p0", "p1"]
poses = [pose0, poseg]
param_values = [pose.parameters() for pose in poses]

In [12]:
def calculate_constraints(x):
    variables = {var_names[i]:x[var_coords[i]] for i in range(num_vars)}
    variables.update({param_names[i]:param_values[i] for i in range(num_params)})

    grasps = []
    for k in K:
        grasp = variables[f"g{k}"]
        grasps.append(grasp)
    grasps = jnp.vstack(grasps)
    
    tool_poses = []
    grasps2 = []
    for k, t in KT:
        grasp = variables[f"g{k}"]
        pose = variables[f"p{t}"]
        tool_pose = R_grasp(grasp, pose)
        tool_poses.append(tool_pose)
        grasps2.append(grasp)
    tool_poses = jnp.vstack(tool_poses)
    grasps2 = jnp.vstack(grasps2)
    
    logits = jax.vmap(P_grasp)(grasps)
    manips = jax.vmap(M_inv)(tool_poses)
    dists = jax.vmap(D_hand)(tool_poses)

    return jnp.hstack([logits, manips, dists])

In [13]:
import cyipopt
import time

In [85]:
class Prob:
    def __init__(self, obj_fn, constr_fn, compile=False, debug_fn=None):
        self.grads = {}
        self.jacs = {}
        self.debug_fn = debug_fn
        self.vg_obj_fn = jax.value_and_grad(obj_fn, argnums=0)
        self.vj_constr_fn = value_and_jacfwd(constr_fn, argnums=0)
        if compile:
            self.vg_obj_fn = jax.jit(self.vg_obj_fn)
            self.vj_constr_fn = jax.jit(self.vj_constr_fn)

    def hash(self, x):
        return hash(x.tobytes())

    def objective(self, x):
        value, grad = self.vg_obj_fn(x)
        key = self.hash(x)
        self.grads[key] = grad
        if self.debug_fn is not None:
            self.debug_fn(x)
        return value
    
    def gradient(self, x):
        key = self.hash(x)
        if key in self.grads:
            return self.grads[key]
        _, grad = self.vg_obj_fn(x)
        return grad
    
    def constraints(self, x):
        value, jac = self.vj_constr_fn(x)
        key = self.hash(x)
        self.jacs[key] = jac
        return value
    
    def jacobian(self, x):
        key = self.hash(x)
        if key in self.jacs:
            return self.jacs[key]
        _, jac = self.vj_constr_fn(x)
        return jac


In [20]:
hande_pc_vis = PointCloud(world.vis, "hande_pc", np.zeros((10,3)), 0.01, "red")

In [14]:
obj_fn = lambda x: 0.
def debug_fn(x):
    variables = {var_names[i]:x[var_coords[i]] for i in range(num_vars)}
    variables.update({param_names[i]:param_values[i] for i in range(num_params)})
    
    tool_poses = []
    for k, t in KT:
        grasp = variables[f"g{k}"]
        pose = variables[f"p{t}"]
        tool_pose = R_grasp(grasp, pose)
        tool_poses.append(tool_pose)
    tool_poses = jnp.vstack(tool_poses)

    pcs = []
    for tool_pose in tool_poses:
        pc = jax.vmap(SE3(tool_pose).apply)(hande_pc)
        pcs.append(pc)
    pcs = jnp.vstack(pcs)
    hande_pc_vis.reload(points=np.array(pcs, dtype=float))
    time.sleep(0.1)

constr_fn = calculate_constraints
prob = Prob(obj_fn, constr_fn, compile=True, debug_fn=None) #, debug_fn=debug_fn

NameError: name 'Prob' is not defined

In [15]:
no_upper_lim = np.inf
logit_thres = 1
manip_thres = 0.2
safe_dist = 0.03

ipopt = cyipopt.Problem(
    n=3, m=5, problem_obj=prob, 
    lb=-np.ones(3), ub=np.ones(3),
    cl=[logit_thres, manip_thres, manip_thres, safe_dist, safe_dist], cu=[no_upper_lim]*5)

NameError: name 'prob' is not defined

In [175]:
ipopt.add_option("acceptable_iter", 2)
ipopt.add_option("acceptable_tol", np.inf) #release
ipopt.add_option("acceptable_constr_viol_tol", 1.)
ipopt.add_option('mu_strategy', 'adaptive')
# ipopt.add_option("acceptable_obj_change_tol", 0.1)
sol, info = ipopt.solve(np.zeros(3))

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        0
Number of nonzeros in inequality constraint Jacobian.:       15
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:        3
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        3
                     variables with only upper bounds:        0
Total number of equality constraints.................:        0
Total number of inequality constraints...............:        5
        inequality constraints with only lower bounds:        5
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  0.0000000e+00 1.18e+02 1.11e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [176]:
debug_fn(sol)

In [92]:
from dataclasses import dataclass, field

@dataclass
class Variable:
    name: str
    coord: np.ndarray
    lb: np.ndarray
    ub: np.ndarray
    
    @property
    def dim(self): return len(self.coord)
    
@dataclass
class Parameter:
    name: str
    coord: np.ndarray
    dim: int
    value: np.ndarray

    @property
    def lb(self): return np.full(self.dim, -np.inf)
    @property
    def ub(self): return np.full(self.dim, np.inf)

In [22]:
K = [0] #grasp numbers
T = [1, 2, 3, 4, 5] # planning scene numbers
P = [0, 1, 2] #object pose numbers
T_h = [2] # H in T

# scene/grasp/isswitch/pose
Pairs = [(2,0,True,0), (3,0, False,1), (4,0, True,2)] #grasp/scene pair numbers at kinematic switch
#num_constrs = len(K) + len(KT)*2

num_grasp = 1
num_config = 5
num_vars = num_grasp + num_config
names_grasp = [f"g{i}" for i in range(num_grasp)]
names_config = [f"q{i+1}" for i in range(num_config)]
var_names = names_grasp + names_config
#var coords
var_coords = []
i = 0
for var_name in var_names:
    if "g" in var_name:
        var_len = 3
    if "q" in var_name:
        var_len = 7
    var_coords.append(np.arange(i,i+var_len))
    i += var_len
dim = var_coords[-1][-1]+1

q0 = np.array([0,0,0,np.pi/2, 0, np.pi/2, 0])
param_names = ["p0", "p2", "q0", f"q{num_config+1}"]
param_values = [pose0.parameters(), poseg.parameters(), q0, q0]
params_dict = {param_names[i]:param_values[i] for i in range(len(param_names))}

In [25]:
def get_var_dict(x):
    variables = {var_names[i]:x[var_coords[i]] for i in range(num_vars)}
    variables.update(params_dict)
    return variables

In [27]:
fk_ee_fn = gen3.get_fk_ee_fn()
def K_robot(tool_pose:Array, q:Array):
    tool_pose = SE3(tool_pose)
    ee_pose = SE3(fk_ee_fn(q))
    pos_err = tool_pose.translation() - ee_pose.translation()
    rot_err = (ee_pose.rotation().inverse() @ tool_pose.rotation()).log()
    return jnp.hstack([pos_err, 0.5*rot_err])

In [16]:
frame = Frame(world.vis, "frame")

In [23]:
i = 8
link_pose = SE3(gen3.fk_fn(jnp.zeros(7))[i])
points = jax.vmap(link_pose.apply)(gen3.pcs[i])
hande_pc_vis.reload(points=points)

In [34]:
#robot collision
robot_pc_fn = gen3.get_robot_pc_fn()
def D_robot(q:Array):
    assigned_pc = robot_pc_fn(q)
    return env.distances(assigned_pc).min()

#object collision
obj_pc = farthest_point_sampling(obj0.mesh.sample(100), 10)
def D_obj(q:Array, grasp:Array):
    ee_pose = SE3(fk_ee_fn(q))
    grasp_pose = grasp_reconst(grasp)
    obj_pose = ee_pose @ grasp_pose.inverse()
    pc = jax.vmap(obj_pose.apply)(obj_pc)
    return env.distances(pc).min()

In [62]:
def calculate_constraints2(x):
    var_dict = get_var_dict(x)
    grasps = []
    for k in K:
        grasp = var_dict[f"g{k}"]
        grasps.append(grasp)
    grasps = jnp.vstack(grasps)
    logits = jax.vmap(P_grasp)(grasps)

    #kinematic error scene/grasp/isswitch/pose
    tool_poses, configs = [], []
    for (t, k, isswitch, p) in Pairs:
        if isswitch:
            grasp = var_dict[f"g{k}"]
            pose = var_dict[f"p{p}"]
            q = var_dict[f"q{t}"]
            tool_poses.append(R_grasp(grasp, pose))
            configs.append(q)
    tool_poses = jnp.vstack(tool_poses)
    configs = jnp.vstack(configs)
    kin_errors = jax.vmap(kin_error, in_axes=(0,0))(tool_poses, configs)

    #robot collision
    configs = []
    for t in T:
        q = var_dict[f"q{t}"]
        configs.append(q)
    configs = jnp.vstack(configs)
    col_dist = jax.vmap(D_robot)(configs)

    #object collision
    grasps, configs = [], []
    for (t,k,isswitch, p) in Pairs:
        if not isswitch:
            grasp = var_dict[f"g{k}"]
            q = var_dict[f"q{t}"]
            grasps.append(grasp)
            configs.append(q)
    configs = jnp.vstack(configs)
    grasps = jnp.vstack(grasps)
    obj_col_dist = jax.vmap(D_obj)(configs, grasps)
    constraints = [logits, kin_errors, col_dist, obj_col_dist]
    constraints = jnp.hstack([constr.flatten() for constr in constraints])
    return constraints

In [86]:
def obj_fn(x):
    var_dict = get_var_dict(jnp.zeros(dim))
    configs = [var_dict[f"q{i}"] for i in range(7)]
    configs = jnp.vstack(configs)
    diff = configs[1:] - configs[:-1]
    return jnp.sum(diff**2)

In [90]:
from sdf_world.sparse_ipopt import *

In [91]:
var_names

['g0', 'q1', 'q2', 'q3', 'q4', 'q5']

In [None]:
bdr = SparseIPOPT()
# grasp
bdr.add_variable()

In [87]:
prob = Prob(obj_fn, calculate_constraints2)

In [89]:
calculate_constraints2(jnp.zeros(dim)).shape

(19,)

In [None]:
no_upper_lim = np.inf
logit_thres = 1
manip_thres = 0.2
safe_dist = 0.03

ipopt = cyipopt.Problem(
    n=dim, m=19, problem_obj=prob, 
    lb=-np.ones(3), ub=np.ones(3),
    cl=[logit_thres, manip_thres, manip_thres, safe_dist, safe_dist], cu=[no_upper_lim]*5)

In [None]:
var_dict = get_var_dict(jnp.zeros(dim))
grasps = []
for k in K:
    grasp = var_dict[f"g{k}"]
    grasps.append(grasp)
grasps = jnp.vstack(grasps)
logits = jax.vmap(P_grasp)(grasps)

#kinematic error scene/grasp/isswitch/pose
tool_poses, configs = [], []
for (t, k, isswitch, p) in Pairs:
    if isswitch:
        grasp = var_dict[f"g{k}"]
        pose = var_dict[f"p{p}"]
        q = var_dict[f"q{t}"]
        tool_poses.append(R_grasp(grasp, pose))
        configs.append(q)
tool_poses = jnp.vstack(tool_poses)
configs = jnp.vstack(configs)
kin_errors = jax.vmap(kin_error, in_axes=(0,0))(tool_poses, configs)

#robot collision
configs = []
for t in T:
    q = var_dict[f"q{t}"]
configs = jnp.vstack(configs)
col_dist = jax.vmap(D_robot)(configs)


ValueError: Need at least one array to concatenate.

[Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)]

In [228]:
SE3(tool_poses[1])

SE3(wxyz=[ 0.45220998 -0.25160998  0.85379     0.05696   ], xyz=[ 0.32565    -0.33247998  0.09344999])