In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import cyipopt

from functools import partial
from typing import *
from dataclasses import dataclass, field
from jaxlie import SE3, SO3
import jax_dataclasses as jdc

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *
from sdf_world.sparse_ipopt import *

from flax import linen as nn
from flax.training import orbax_utils
import orbax
import pickle
import time

In [2]:
@dataclass
class Variable:
    """A named block of decision variables inside the flat vector ``x``."""
    name: str           # unique key used to look the variable up
    coord: np.ndarray   # indices this variable occupies in the flat vector
    lb: np.ndarray      # per-entry lower bounds
    ub: np.ndarray      # per-entry upper bounds

    @property
    def dim(self):
        """Number of entries, derived from the length of `coord`."""
        n_entries = len(self.coord)
        return n_entries

@dataclass
class Parameter:
    """A named block of the flat vector that carries no finite bounds."""
    name: str           # unique key used to look the parameter up
    coord: np.ndarray   # indices this parameter occupies in the flat vector
    dim: int            # number of entries

    @property
    def lb(self):
        """Lower bounds: ``-inf`` for every entry."""
        return np.full(self.dim, -np.inf)

    @property
    def ub(self):
        """Upper bounds: ``+inf`` for every entry."""
        return np.full(self.dim, np.inf)

@dataclass
class Constraint:
    """One constraint block: a registered function applied to chosen inputs."""
    name: str                   # unique key of the constraint
    coord: np.ndarray           # rows this constraint occupies in the stacked constraint vector
    inputs: List[Variable]      # inputs passed, in order, to `fn`
    fn: "Function"              # registered function providing eval_fn / jac_fn
    lb: np.ndarray              # per-row lower bounds
    ub: np.ndarray              # per-row upper bounds
    no_deriv_names: List[str]   # input names excluded from the Jacobian
    jac_indices: np.ndarray     # (row, col) index arrays of the Jacobian entries, one per differentiated input

    @property
    def dim(self):
        """Number of constraint rows, derived from the length of `coord`."""
        n_rows = len(self.coord)
        return n_rows

@dataclass
class Function:
    """A registered callable together with its Jacobian callable."""
    name: str                   # unique key of the function
    in_dims: List[int]          # dimension of each positional input
    out_dim: int                # size of the output
    eval_fn: Callable           # evaluates the function on the inputs
    jac_fn: Callable            # returns one Jacobian per input
    # Optional (row, col) index pairs, one per input, describing a custom
    # sparsity pattern; None means the Jacobian blocks are treated as dense.
    custom_jac_indices: Optional[List[np.ndarray]]
    # Constraints that were created from this function.
    constraints: List[Constraint] = field(default_factory=list)

class SparseIPOPT():
    """Incremental builder for a cyipopt problem with a sparse constraint Jacobian.

    Typical order of use: `add_variable` / `add_parameter`, `register_fn`,
    `set_objective` / `set_constr`, then `build`.
    Variables and parameters share one flat vector ``x``; every entry of
    `x_info` records which indices (`coord`) of ``x`` it owns.
    """
    def __init__(self, check_fn=False):
        """
        Args:
            check_fn: if True, `register_fn` evaluates the given callables on
                zero inputs and asserts that their output sizes match.
        """
        self.x_info: Dict[str,Variable] = {}
        self.c_info: Dict[str,Constraint] = {}
        self.fn_info: Dict[str,Function] = {}
        self.obj_info: Dict = {}
        # Parameters are stored in `x_info` (see `add_parameter`); this dict is
        # kept because `input_info` merges it.
        self.param_info: Dict[str, np.ndarray] = {}
        self.check_fn = check_fn

        # Next free index in the flat x vector / stacked constraint vector.
        self.x_idx, self.c_idx = 0, 0

    @property
    def xdim(self):
        """Total length of the flat vector ``x``."""
        return sum(x.dim for x in self.x_info.values())
    @property
    def cdim(self):
        """Total number of constraint rows."""
        return sum(c.dim for c in self.c_info.values())
    @property
    def input_info(self):
        """Merged view of `x_info` and `param_info`."""
        return {**self.x_info, **self.param_info}

    @staticmethod
    def _as_bound(bound, dim):
        """Expand a scalar bound to length `dim`; validate the length otherwise."""
        # np.ndim covers Python ints/floats and numpy scalars alike, so an
        # integer bound such as `lb=0` no longer crashes in len().
        if np.ndim(bound) == 0:
            return np.full(dim, float(bound))
        assert len(bound) == dim
        return bound

    @staticmethod
    def _hstack_np(arrays):
        """np.hstack that returns an empty array for an empty list."""
        return np.hstack(arrays) if arrays else np.zeros(0)

    def _split_x(self, x):
        """Slice the flat vector `x` into a {name: sub-vector} dict."""
        return {var.name: x[var.coord] for var in self.x_info.values()}

    def add_variable(self, name, dim, lb=-np.inf, ub=np.inf):
        """Append a decision variable of size `dim` to ``x``.

        Args:
            name: unique name of the variable.
            dim: number of entries.
            lb, ub: scalar bound (broadcast to `dim`) or a length-`dim` sequence.
        """
        assert name not in self.x_info
        lb = self._as_bound(lb, dim)
        ub = self._as_bound(ub, dim)

        coord = np.arange(self.x_idx, self.x_idx+dim)
        self.x_info[name] = Variable(
            name, coord, lb, ub)

        self.x_idx += dim

    def add_parameter(self, name, dim):
        """Append an unbounded, never-differentiated block of size `dim` to ``x``."""
        assert name not in self.param_info
        # Parameters live in `x_info`, so that is where a name clash would
        # silently overwrite an entry while still advancing `x_idx`.
        assert name not in self.x_info
        coord = np.arange(self.x_idx, self.x_idx+dim)
        self.x_info[name] = Parameter(name, coord, dim)
        self.x_idx += dim

    def set_objective(self, fn_name, input_x_names):
        """Use the registered function `fn_name` on the named inputs as objective."""
        self.obj_info["fn"] = self.fn_info[fn_name]
        self.obj_info["inputs"] = [self.x_info[name] for name in input_x_names]

    def set_debug_callback(self, debug_callback:Callable):
        """Register a callback invoked with the {name: value} dict on every objective call."""
        self.obj_info["debug_cb"] = debug_callback

    def set_constr(self, name, cfn_name, input_x_names, lb, ub, no_deriv_names=[]):
        """Add the constraint ``lb <= fn(inputs) <= ub``.

        Args:
            name: unique name of the constraint.
            cfn_name: name of a function registered with `register_fn`.
            input_x_names: names of the inputs, in the function's argument order.
            lb, ub: scalar bound (broadcast) or a sequence of length `out_dim`.
            no_deriv_names: input names to leave out of the Jacobian.
        """
        c_fn = self.fn_info[cfn_name]
        cdim = c_fn.out_dim
        assert name not in self.c_info
        lb = self._as_bound(lb, cdim)
        ub = self._as_bound(ub, cdim)
        # Copy so the stored list is never the shared default argument nor a
        # list the caller may mutate later.
        no_deriv_names = list(no_deriv_names)

        in_vars = [self.x_info[x_name] for x_name in input_x_names]
        c_coord = np.arange(self.c_idx, self.c_idx+cdim)

        jac_indices = []
        for i, var in enumerate(in_vars):
            if var.name in no_deriv_names or isinstance(var, Parameter):
                continue

            if c_fn.custom_jac_indices is not None:
                row, col = c_fn.custom_jac_indices[i]
            else:
                # dense block: every (row, col) pair of a cdim x var.dim matrix
                row, col = np.indices((cdim, var.dim)).reshape(2, -1)
            # shift block-local indices to their place in the full Jacobian
            jac_indices.append(np.vstack([
                np.asarray(row) + self.c_idx,
                np.asarray(col) + var.coord[0]]))

        self.c_info[name] = Constraint(
            name, c_coord, in_vars,
            c_fn, lb, ub,
            no_deriv_names, jac_indices)
        c_fn.constraints.append(self.c_info[name])
        self.c_idx += cdim

    def register_fn(self, name, in_dims, out_dim, eval_fn, jac_fn, custom_jac_indices=None):
        """Register a function and its Jacobian under `name`.

        Args:
            in_dims: dimension of each positional input.
            out_dim: output size.
            eval_fn: callable evaluating the function.
            jac_fn: callable returning one Jacobian per input.
            custom_jac_indices: optional per-input (row, col) sparsity indices.
        """
        if self.check_fn:
            xdummies = [jnp.zeros(dim) for dim in in_dims]
            assert eval_fn(*xdummies).size == out_dim
            assert len(jac_fn(*xdummies)) == len(in_dims)

        if custom_jac_indices is not None:
            assert len(custom_jac_indices) == len(in_dims)
        self.fn_info[name] = Function(
            name, in_dims, out_dim, eval_fn, jac_fn, custom_jac_indices)

    def get_objective_fn(self, compile=True):
        """Return ``objective(x)``; constant 0 when no objective is set.

        With a debug callback registered the returned function is not jitted.
        """
        has_obj = "fn" in self.obj_info
        if has_obj:
            def objective(x):
                xs = self._split_x(x)
                fn_input = [xs[var.name] for var in self.obj_info["inputs"]]
                return self.obj_info["fn"].eval_fn(*fn_input)
        else:
            def objective(x):
                return 0.

        if "debug_cb" in self.obj_info:
            def objective_debug(x):
                self.obj_info["debug_cb"](self._split_x(x))
                return objective(x)
            return objective_debug
        if compile and has_obj:
            return jax.jit(objective)
        return objective

    def get_gradient_fn(self, compile=True):
        """Return ``gradient(x)`` of the objective w.r.t. the full flat vector."""
        if "fn" not in self.obj_info:
            return lambda x: np.zeros(self.xdim)

        def gradient(x):
            xs = self._split_x(x)
            inputs = self.obj_info["inputs"]
            grads = self.obj_info["fn"].jac_fn(*[xs[var.name] for var in inputs])
            # Built per call: writing traced values into a dict shared across
            # calls would leak tracers out of the jitted function.
            grad_by_name = {var.name: np.zeros(var.dim) for var in self.x_info.values()}
            for var, grad in zip(inputs, grads):
                # accumulate, so an input passed twice gets the sum of both partials
                grad_by_name[var.name] = grad_by_name[var.name] + grad
            return jnp.hstack(list(grad_by_name.values()))

        if compile:
            return jax.jit(gradient)
        return gradient

    def get_constraint_fn(self, compile=True):
        """Return ``constraints(x)``: all constraint outputs stacked in insertion order."""
        def constraints(x):
            xs = self._split_x(x)
            result = [
                constr.fn.eval_fn(*[xs[var.name] for var in constr.inputs])
                for constr in self.c_info.values()]
            if not result:
                return jnp.zeros(0)
            return jnp.hstack(result)
        if compile:
            return jax.jit(constraints)
        return constraints

    def get_jacobian_fn(self, compile=True):
        """Return ``jacobian(x)``: flattened non-zero Jacobian values.

        The order matches `get_jacobian_structure`: constraint by constraint,
        and per constraint input by input, skipping parameters and inputs
        listed in `no_deriv_names`.
        """
        def jacobian(x):
            xs = self._split_x(x)
            result = []
            for constr in self.c_info.values():
                fn_input = [xs[var.name] for var in constr.inputs]
                jacs = constr.fn.jac_fn(*fn_input)
                for i, var in enumerate(constr.inputs):
                    if var.name in constr.no_deriv_names or isinstance(var, Parameter):
                        continue
                    result.append(jacs[i].flatten())
            if not result:
                return jnp.zeros(0)
            return jnp.hstack(result)
        if compile:
            return jax.jit(jacobian)
        return jacobian


    def get_jacobian_structure(self):
        """Return the (rows, cols) index arrays of all non-zero Jacobian entries."""
        blocks = [
            jac_idx
            for constr in self.c_info.values()
            for jac_idx in constr.jac_indices
            if jac_idx is not None]
        if not blocks:
            # no differentiated constraint input: empty structure
            return np.zeros(0, dtype=int), np.zeros(0, dtype=int)
        rows = np.hstack([block[0] for block in blocks])
        cols = np.hstack([block[1] for block in blocks])
        return rows, cols

    def print_sparsity(self):
        """Print the Jacobian pattern, one line per constraint row ('o' = entry, '-' = zero)."""
        row, col = self.get_jacobian_structure()
        jac_struct = np.full((self.cdim, self.xdim), -1, dtype=int)
        jac_struct[row, col] = 1
        for struct_row in jac_struct:
            print("".join("-" if val == -1 else "o" for val in struct_row))

    def build(self, compile=True):
        """Assemble and return the `cyipopt.Problem`.

        Every callback is called once on a zero vector before the problem is
        created (which triggers compilation of the jitted ones), default
        options are applied and the sparsity pattern is printed.
        """
        lb = self._hstack_np([x.lb for x in self.x_info.values()])
        ub = self._hstack_np([x.ub for x in self.x_info.values()])
        cl = self._hstack_np([c.lb for c in self.c_info.values()])
        cu = self._hstack_np([c.ub for c in self.c_info.values()])
        row, col = self.get_jacobian_structure()
        jac_struct_fn = lambda : (row, col)

        fns = {
            "objective": self.get_objective_fn(compile),
            "gradient": self.get_gradient_fn(compile),
            "constraints": self.get_constraint_fn(compile),
            "jacobian": self.get_jacobian_fn(compile),
        }
        # Plain attribute holder handed to cyipopt as the problem object.
        class Prob:
            pass
        prob = Prob()
        xdummy = jnp.zeros(self.xdim)
        for fn_name, fn in fns.items():
            print(f"compiling {fn_name} ...")
            fn(xdummy)
            setattr(prob, fn_name, fn)
        setattr(prob, "jacobianstructure", jac_struct_fn)

        ipopt = cyipopt.Problem(
            n=self.xdim, m=self.cdim,
            problem_obj=prob,
            lb=lb, ub=ub, cl=cl, cu=cu
        )
        # default option
        ipopt.add_option("acceptable_iter", 2)
        ipopt.add_option("acceptable_tol", np.inf) #release
        ipopt.add_option("acceptable_obj_change_tol", 0.1)
        ipopt.add_option("acceptable_constr_viol_tol", 1.)
        #ipopt.add_option("acceptable_dual_inf_tol", 1.) 
        ipopt.add_option('mu_strategy', 'adaptive')
        self.print_sparsity()
        return ipopt

In [3]:
# models: restore the trained network checkpoints from disk
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restored_grasp = orbax_checkpointer.restore("model/grasp_net_prob_dist")
restored_manip = orbax_checkpointer.restore("model/manip_net_posevec")

#grasp net
grasp_net = GraspNet(32)

def grasp_fn(x):
    """Apply the grasp network with the restored parameters."""
    return grasp_net.apply(restored_grasp["params"], x)

def grasp_logit_fn(g):
    """First output of the grasp network (the logit, going by the name)."""
    return grasp_fn(g)[0]

def grasp_dist_fn(g):
    """Second output of the grasp network (a distance, going by the name)."""
    return grasp_fn(g)[1]

#manip net
manip_net = ManipNet(64)

def manip_fn(x):
    """First output of the manipulability network with the restored parameters."""
    return manip_net.apply(restored_manip["params"], x)[0]

In [4]:
# Create the SDF world; its `vis` attribute is the visualizer handle that
# every scene object below is attached to.
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/


In [5]:
# Show the visualizer inside the notebook.
world.show_in_jupyter()

In [6]:
# robot, hand
# Two arms share one kinematic model; joints 7 and 8 are fixed at 0.04
# (presumably the gripper fingers held open — TODO confirm).
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda1 = Robot(world.vis, "panda1", panda_model, alpha=0.5)
panda1.reduce_dim([7, 8], [0.04, 0.04])
panda2 = Robot(world.vis, "panda2", panda_model, alpha=0.5)
panda2.reduce_dim([7, 8], [0.04, 0.04])

# Stand-alone hand model; each link gets surface points sampled with argument 10.
hand_model = RobotModel(HAND_URDF, PANDA_PACKAGE, True)
for link_name, link in hand_model.links.items():
    link.set_surface_points(10)
# NOTE(review): the single `hand` robot below is disabled; the PandaHand
# wrappers (hand1, hand2) appear to replace it.
#hand = Robot(world.vis, "hand1", hand_model, color="white", alpha=0.5)



In [7]:
class PandaHand:
    """Visualized gripper hand with a cached surface point cloud.

    Relies on the module-level `world` for the visualizer handle.
    """
    # Finger joint values used for every query in this class
    # (presumably the fully open gripper — TODO confirm).
    _FINGER_Q = (0.04, 0.04)

    def __init__(self, hand_model, name="hand"):
        """
        Args:
            hand_model: robot model of the hand.
            name: name of the hand inside the visualizer.
        """
        self.model = hand_model
        self.robot = Robot(world.vis, name, hand_model, color="white", alpha=0.5, visualized_mesh="visual")
        # fixed offset between the end-effector frame and the hand base frame
        self.hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
        self.hand_pc = self.robot.get_surface_points_fn(jnp.array(self._FINGER_Q))

    def get_bounding_box(self, name):
        """Return an axis-aligned `Box` enclosing all collision-mesh vertices.

        Args:
            name: name given to the created box.
        """
        fks = self.model.fk_fn(jnp.array(self._FINGER_Q))
        hand_points = []
        # Iterate this instance's own model rather than the global
        # `hand_model`, so the box always matches `self.model`.
        for i, link in enumerate(self.model.links.values()):
            pts = jax.vmap(SE3(fks[i]).apply)(link.collision_mesh.vertices)
            hand_points.append(pts)
        # astype instead of vstack's `dtype=` keyword, which needs NumPy >= 1.24
        hand_points = np.vstack(hand_points).astype(float)
        min_points = hand_points.min(axis=0)
        max_points = hand_points.max(axis=0)
        extents = max_points - min_points
        center = (max_points + min_points) / 2
        box = Box(world.vis, name, extents, alpha=0.5, visualize=False)
        box.set_translate(center)
        return box

    def set_pose(self, pose):
        """Place the hand so that its end-effector frame coincides with `pose`."""
        self.robot.set_pose(pose @ self.hand_pose_wrt_ee)

In [8]:
# One visualized hand per arm; both wrap the same hand model.
hand1 = PandaHand(hand_model, "hand1")
hand2 = PandaHand(hand_model, "hand2")

In [9]:
#load sdf meshes
table_lengths = [0.4, 0.4, 0.2]
table_start = Box(world.vis, "table_start", table_lengths, 'white', 0.5)
table_goal = Box(world.vis, "table_goal", table_lengths, 'white', 0.5)
obstacle = Box(world.vis, "obstacle", [0.4, 0.2, 0.35], 'white', 0.5)

# All object instances load the same mesh file and differ only in name,
# color and opacity.
_OBJECT_MESH_PATH = "./sdf_world/assets/object/mesh.obj"

def _load_object_mesh(name, color, alpha):
    """Create one visualized instance of the object mesh."""
    return Mesh(world.vis, name, _OBJECT_MESH_PATH, color=color, alpha=alpha)

obj_start = _load_object_mesh("obj_start", "blue", 0.5)
obj_goal = _load_object_mesh("obj_goal", "green", 0.5)
obj_ho = _load_object_mesh("obj_ho", "yellow", 0.5)
obj = _load_object_mesh("obj", "white", 0.8)

In [10]:
# The two arms, with one table each, sit symmetrically at y = -ydev / +ydev.
ydev = 0.4
panda1.set_translate([0, -ydev, 0])
panda2.set_translate([0, ydev, 0])

# Tables rest on the ground, so their centers are at half the table height.
table_height = table_lengths[-1]
table_start.set_translate([0.45, -ydev, table_height/2])
table_goal.set_translate([0.45, ydev, table_height/2])

# Start object sits on the start table (center half its z-extent above the top).
obj_lengths = obj_start.mesh.bounding_box.primitive.extents
obj_start.set_translate([0.45, -ydev, obj_lengths[-1]/2 + table_height])

# The goal uses the y-extent for the height offset because the goal pose is
# rolled by pi/2.
trans_goal = jnp.array([0.45, ydev, obj_lengths[-2]/2 + table_height])
obstacle.set_translate([0.45, 0., 0.35/2])
obj_goal_pose = SE3.from_rotation_and_translation(
    SO3.from_rpy_radians(jnp.pi/2, 0, 0),
    trans_goal,
)
obj_goal.set_pose(obj_goal_pose)

# The movable object starts at the same place as obj_start.
obj.set_translate([0.45, -ydev, obj_lengths[-1]/2 + table_height])

In [11]:
# Place the yellow `obj_ho` instance at a fixed position
# (presumably the hand-over pose — TODO confirm).
obj_ho.set_translate([0., 0.4, 0.4])

In [14]:
#visualization
# Debug drawing handles, each created with a (10, 3) array of zeros as
# placeholder points; later cells replace the points (e.g. via `reload`).
pc_hands = PointCloud(world.vis, "pc_hands", np.zeros((10,3)), 0.01, "red")
dotted_line = DottedLine(world.vis, "traj", np.zeros((10,3)), 0.01, "blue")
pc_body = PointCloud(world.vis, "body_pc", np.zeros((10,3)), 0.01, "blue")
pc_obj = PointCloud(world.vis, "pc_obj", np.zeros((10,3)), 0.01, "green")

In [15]:
# kinematics
def skew(v):
    """Return the 3x3 skew-symmetric matrix S of a 3-vector, so that S @ u == cross(v, u)."""
    x, y, z = v
    top = [0, -z, y]
    mid = [z, 0., -x]
    bottom = [-y, x, 0.]
    return jnp.array([top, mid, bottom])

@jax.custom_jvp
def fk(q):
    """Forward kinematics: pose vector of the last frame returned by the Panda model."""
    return panda_model.fk_fn(q)[-1]

@fk.defjvp
def fk_jvp(primals, tangents):
    """Custom JVP rule for `fk`, built from a geometric Jacobian.

    Returns the last-frame pose vector and its directional derivative
    along `q_dot`. The pose vector is read as [quaternion (4), position (3)].
    """
    q, = primals
    q_dot, = tangents
    fks = panda_model.fk_fn(q)
    # last frame: first 4 entries = quaternion, last 3 = position
    qtn, p_ee = fks[-1][:4], fks[-1][-3:]
    w, xyz = qtn[0], qtn[1:]  # scalar part first
    geom_jac = []
    # NOTE(review): frames 1..7 are presumably the seven joint frames — confirm
    # against the model's frame ordering.
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        # joint axis taken as the z-axis (third column) of the frame rotation
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        geom_jac.append(jnp.hstack([lin_vel, rot_axis]))
    # shape (6, n_joints): rows 0-2 linear part, rows 3-5 angular part
    geom_jac = jnp.array(geom_jac).T  #geom_jacobian
    # H.T maps an angular velocity to twice the quaternion time derivative
    H = jnp.hstack([-xyz[:,None], skew(xyz)+jnp.eye(3)*w])
    rot_jac = 0.5*H.T@geom_jac[3:,:]
    # stack in the same [quaternion, position] order as the pose vector
    jac = jnp.vstack([rot_jac, geom_jac[:3,:]])
    return fks[-1], jac@q_dot

In [None]:
# constants, data
safe_dist = 0.05
# workspace bounds: 3 position entries followed by 3 angles in [-pi, pi]
ws_lb = np.array([-1,-1,-0.5, -np.pi, -np.pi, -np.pi])
ws_ub = np.array([1,1,1.5, np.pi, np.pi, np.pi])
hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
# `hand` is not defined anywhere in this notebook (its creation is commented
# out), so the old `hand.get_surface_points_fn(...)` raised NameError on a
# fresh kernel. PandaHand caches the identical call as `hand_pc`.
hand_pc = hand1.hand_pc

#links_points_mat
# surface points of every meshed link; the root link is skipped unless the
# model is floating
links_points_mat = []
for link in panda_model.links.values():
    if not link.has_mesh: continue
    if not panda_model.is_floating and link == panda_model.root_link: continue
    links_points_mat.append(link.surface_points)
links_points_mat = np.array(links_points_mat)

#object points
# 200 sampled surface points reduced to 20 by farthest point sampling
obj_points = farthest_point_sampling(obj_start.mesh.sample(200), 20)

# environment sdf
env = SDFContainer([table_start, table_goal, obstacle], safe_dist)

In [16]:
# Move the white `obj` instance to a new position in the scene.
obj.set_translate([0.5, 0, 0.5])

In [18]:
def grasp_reconst(grasp:Array):
    """Build the grasp pose (SE3) belonging to a normalized grasp point.

    The rotation comes from the grasp network output (entries from index 2
    on, normalized); the translation is the grasp point divided by the
    stored "scale_to_norm" factor.
    """
    net_out = grasp_fn(grasp)
    rotation = SO3(net_out[2:]).normalize()
    translation = grasp / restored_grasp["scale_to_norm"]
    return SE3.from_rotation_and_translation(rotation, translation)
def grasp_embedding(grasp_point):
    """Scale a grasp point by the stored "scale_to_norm" factor (inverse of the division in `grasp_reconst`)."""
    return grasp_point * restored_grasp["scale_to_norm"]

In [25]:
#utility functions
def to_posevec(x):
    """Convert [quaternion (4), position (3)] into [position (3), rotation vector (3)]."""
    return jnp.hstack([x[4:], SO3(x[:4]).log()])

def to_wxyzxyz(x):
    """Convert [position (3), rotation vector (3)] into [quaternion (4), position (3)]."""
    return jnp.hstack([SO3.exp(x[3:]).parameters(), x[:3]])

hand_pose_wrt_ee = SE3.from_translation(jnp.array([0,0,-0.105]))
# `hand` is not defined anywhere in this notebook (its creation is commented
# out); PandaHand caches the identical surface-point call as `hand_pc`.
hand_pc = hand1.hand_pc

def get_hand_pc(grasp, wxyzxyz):
    """Return the hand surface points in the world frame for a given grasp.

    Args:
        grasp: grasp point passed to `grasp_reconst`.
        wxyzxyz: pose parameters wrapped in `SE3` (the caller passes the object pose).
    """
    # world <- object <- grasp <- hand base
    hand_base_pose_wrt_world = SE3(wxyzxyz) @ grasp_reconst(grasp) @ hand_pose_wrt_ee
    return jax.vmap(hand_base_pose_wrt_world.apply)(hand_pc)

In [27]:
# Two candidate grasp points, mirrored along the second axis.
grasp1 = jnp.array([0,-0.3,0])
grasp2 = jnp.array([0,0.3,0])  # not used in this cell
#grasp_pose1 = obj.pose @ grasp_reconst(grasp1)
# Show the hand point cloud for grasp1 at the current pose of `obj`.
hand_pc1 = get_hand_pc(grasp1, obj.pose.parameters())
pc_hands.reload(points=hand_pc1)

In [None]:
# NOTE(review): incomplete scratch input — `squ` is not defined in the visible
# notebook and raises NameError if this cell is run; finish or delete it.
squ