In [108]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import jax_dataclasses as jdc
from functools import partial
import PyCeres

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

from flax import linen as nn
from flax.training import orbax_utils
import orbax
import pickle
import time

In [2]:
# Create the SDF world (with its visualizer) and load the Panda arm into it.
world = SDFWorld()
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
# Presumably removes joints 7 and 8 (the fingers) from the configuration and
# pins them at 0.04 each -- the cost functions below optimize a 7-D q.
# TODO confirm against Robot.reduce_dim.
panda.reduce_dim([7, 8], [0.04, 0.04])

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7002/static/


### Grasp model and scene setup

In [45]:
# Learned grasp model
class GraspNet(nn.Module):
    """MLP used as the learned grasp model.

    Three hidden ``Dense`` + ReLU layers of width ``hidden_dim`` followed by a
    linear head with 5 outputs. Downstream code reads output[0] as the grasp
    logit and output[1:5] as a quaternion (normalized before use).
    """
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        h = x
        for _ in range(3):
            # Dense modules are created in the same order as an unrolled
            # version, so flax's auto-names (Dense_0..Dense_3) still match
            # the parameters stored in the checkpoint.
            h = nn.relu(nn.Dense(features=self.hidden_dim)(h))
        return nn.Dense(features=5)(h)

# Restore the trained grasp network from its orbax checkpoint.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
raw_restored = orbax_checkpointer.restore("model/grasp_net")
params = raw_restored["params"]
grasp_net = GraspNet(raw_restored["hidden_dim"])

def grasp_fn(x):
    """Raw network output for input ``x`` using the restored ``params``."""
    return grasp_net.apply(params, x)

# Object metadata. NOTE: pickle.load can execute arbitrary code -- only load
# trusted asset files.
with open("./sdf_world/assets/object/info.pkl", 'rb') as f:
    obj_data = pickle.load(f)
scale_to_norm = obj_data["scale_to_norm"]
def grasp_reconst(g:Array):
    """Decode grasp variable ``g`` into an SE3 grasp pose in the object frame.

    Rotation: network outputs [1:5] interpreted as a quaternion, normalized.
    Translation: ``g`` divided by ``scale_to_norm`` (presumably undoing the
    object normalization stored in info.pkl -- confirm).
    """
    quat = grasp_fn(g)[1:5]
    return SE3.from_rotation_and_translation(
        SO3(quat).normalize(), g / scale_to_norm
    )

def grasp_logit_fn(g):
    """Grasp-quality logit for ``g`` (network output 0)."""
    return grasp_fn(g)[0]

In [19]:
# Two identical tables: start (green) and goal (blue); 0.5 is presumably alpha.
table_lengths = [0.4, 0.4, 0.2]
table_start = Box(world.vis, "table_start", table_lengths, 'green', 0.5)
table_goal = Box(world.vis, "table_goal", table_lengths, 'blue', 0.5)
# Raise each box by half its height so that, if Box is centred on its
# translation, it stands on z = 0.
table_start.set_translate([0.5, -0.3, table_lengths[-1]/2])
table_goal.set_translate([0.5, 0.3, table_lengths[-1]/2])

In [38]:
# Two copies of the same object mesh: one shown at the start pose, one at the goal.
obj_start = Mesh(world.vis, "obj_start", "./sdf_world/assets/object/mesh.obj",
                 alpha=0.5)
obj_goal = Mesh(world.vis, "obj_goal", "./sdf_world/assets/object/mesh.obj",
                 alpha=0.5)
# Side lengths of the mesh's axis-aligned bounding box, indexed x, y, z below.
# `.mesh` looks like a trimesh object -- confirm.
obj_lengths = obj_start.mesh.bounding_box.primitive.extents

In [40]:
# Start: object upright on the start table. The z offset is half the object's
# z-extent plus the table height (assumes the mesh origin is at its
# bounding-box centre -- confirm).
obj_start.set_translate([0.5, -0.3, obj_lengths[-1]/2+table_lengths[-1]])
# Goal: object rolled by pi/2 about x on the goal table. That roll maps the
# object's local y-axis onto world z, so the y-extent (obj_lengths[-2]) sets
# the height above the table top.
trans_goal = jnp.array([0.5, 0.3, obj_lengths[-2]/2+table_lengths[-1]])
obj_goal_pose = SE3.from_rotation_and_translation(
    SO3.from_rpy_radians(jnp.pi/2, 0,0), trans_goal
)
obj_goal.set_pose(obj_goal_pose)

### Kinematics

In [115]:
#utility
def to_posevec(pose:SE3):
    """Flatten an SE3 pose into ``[translation (3), rotation vector (3)]``."""
    translation = pose.translation()
    rotvec = pose.rotation().log()
    return jnp.hstack([translation, rotvec])

def make_pose(rng=None, low=(-0.3, -0.5, 0.3), high=(0.6, 0.5, 0.8)):
    """Sample a random SE3 pose.

    Rotation: a 4-D standard-normal sample normalized to a unit quaternion,
    which is uniformly distributed over SO(3). (Normalizing
    ``np.random.random(4)`` only reaches quaternions whose components are all
    non-negative -- a small, biased subset of rotations.)
    Translation: uniform per axis in ``[low, high)``.

    Args:
        rng: optional ``np.random.Generator``; if None, NumPy's global
            ``np.random`` state is used (seed it for reproducibility).
        low, high: per-axis translation bounds; the defaults are the box this
            helper always used.

    Returns:
        A random SE3 pose.
    """
    sampler = np.random if rng is None else rng
    quat = sampler.normal(size=4)
    trans = sampler.uniform(low, high)
    return SE3.from_rotation_and_translation(SO3(quat).normalize(), trans)

# Kinematics
def get_rotvec_angvel_map(v, eps=1e-2):
    """Matrix E(v) mapping world-frame angular velocity to rotation-vector rate.

    For a rotation vector ``v`` (axis * angle, 3 entries), returns the 3x3
    inverse left Jacobian of SO(3), so that ``dv/dt = E(v) @ omega``:

        E = I - 1/2 [v]x + c(|v|) [v]x^2,
        c(t) = (1 - (t/2) * cot(t/2)) / t^2

    The equivalent form using ``sin(t) / (1 - cos(t))`` (== ``cot(t/2)``)
    suffers float32 cancellation in ``1 - cos(t)``. It becomes inaccurate for
    small angles and inf/NaN once ``cos(t)`` rounds to 1, including ``v = 0``.
    The half-angle form is well conditioned on (0, pi]. For ``|v| < eps`` the
    Taylor expansion ``c = 1/12 + t^2/720`` is used, so ``v = 0`` gives I.

    Args:
        v: rotation vector (array-like with 3 entries).
        eps: angle below which the Taylor expansion is used.

    Returns:
        (3, 3) jax array.
    """
    def skew(v):
        v1, v2, v3 = v
        return jnp.array([[0, -v3, v2],
                          [v3, 0., -v1],
                          [-v2, v1, 0.]])
    v = jnp.asarray(v)
    vskew = skew(v)
    sq = jnp.dot(v, v)
    small = sq < eps**2
    # Feed a dummy angle into the unused closed-form branch so jnp.where never
    # evaluates (or differentiates through) a 0/0.
    t = jnp.sqrt(jnp.where(small, 1.0, sq))
    half = t / 2
    coeff_closed = (1 - half * jnp.cos(half) / jnp.sin(half)) / t**2
    coeff_taylor = 1/12 + sq / 720
    coeff = jnp.where(small, coeff_taylor, coeff_closed)
    return jnp.eye(3) - 0.5 * vskew + coeff * (vskew @ vskew)

@jax.jit
def get_ee_fk_jac(q):
    """End-effector pose vector and analytical Jacobian for joint vector ``q``.

    Returns:
        posevec: ``[p_ee (3), rotvec_ee (3)]`` -- same layout as ``to_posevec``.
        jac: (6, n) matrix with one column per frame in ``fks[1:8]``. Rows 0-2
            hold the linear velocity ``axis x (p_ee - p_frame)``. Rows 3-5 hold
            the frame's z-axis (world angular velocity), mapped to
            rotation-vector rates by ``get_rotvec_angvel_map``.

    Assumes every joint is revolute about its frame's local z-axis, and that
    each row of ``panda_model.fk_fn(q)`` is an SE3 parameter vector
    (quaternion first, translation in the last three entries) -- TODO confirm.
    """
    fks = panda_model.fk_fn(q)
    ee_params = fks[-1]
    p_ee = ee_params[-3:]
    rotvec_ee = SO3(ee_params[:4]).log()

    cols = []
    for frame_params in fks[1:8]:
        z_axis = SE3(frame_params).as_matrix()[:3, 2]
        v_lin = jnp.cross(z_axis, p_ee - frame_params[-3:])
        cols.append(jnp.hstack([v_lin, z_axis]))
    geometric = jnp.stack(cols, axis=1)

    E = get_rotvec_angvel_map(rotvec_ee)
    jac = jnp.vstack([geometric[:3, :], E @ geometric[3:, :]])
    return jnp.hstack([p_ee, rotvec_ee]), jac



### Problem: joint grasp selection and arm configuration for the pick pose

In [54]:
# Object poses used by the grasp/IK problem below, read from the meshes.
# obj_goal_pose was already built in the scene cell; this re-reads it from obj_goal.
obj_start_pose = obj_start.pose
obj_goal_pose = obj_goal.pose

In [73]:
frame_pick = Frame(world.vis, "g_pick")
frame_place = Frame(world.vis, "g_place")
def show_grasp(grasp):
    """Draw grasp ``grasp`` on the object at both its start and goal poses.

    The same object-relative grasp pose is composed with ``obj_start_pose``
    (drawn with ``frame_pick``) and with ``obj_goal_pose`` (``frame_place``).
    """
    grasp_in_obj = grasp_reconst(grasp)
    for frame, obj_pose in ((frame_pick, obj_start_pose),
                            (frame_place, obj_goal_pose)):
        frame.set_pose(obj_pose @ grasp_in_obj)

In [150]:
# to_posevec (SE3 -> [translation, rotvec]) is defined once, in the utility
# cell of the kinematics section; this cell only adds its inverse.
def to_SE3(posevec:Array):
    """Build an SE3 from a 6-vector ``[translation (3), rotation vector (3)]``.

    Inverse of ``to_posevec``: the translation is ``posevec[:3]`` and the
    rotation is ``SO3.exp(posevec[3:])``.
    """
    return SE3.from_rotation_and_translation(
        SO3.exp(posevec[3:]), posevec[:3]
    )

In [363]:
def clipped_logit_viol(g, threshold=1.):
    """Hinge penalty ``max(0, threshold - logit(g))`` on the grasp logit."""
    return jnp.maximum(threshold - grasp_logit_fn(g), 0.)

class GraspProb(PyCeres.CostFunction):
    """Ceres cost keeping the grasp logit above ``logit_threshold``.

    One residual, ``weight * max(0, logit_threshold - logit(g))``, over a
    single 3-D parameter block (the grasp variable ``g``).

    Args:
        visualize: if true (default), every evaluation draws the candidate
            grasp and sleeps ``vis_delay`` seconds so the optimization can be
            watched. Pass False to run the solver at full speed.
        vis_delay: seconds to sleep after each visualized evaluation.
    """
    def __init__(self, visualize=True, vis_delay=0.1):
        super().__init__()
        self.set_num_residuals(1)
        self.set_parameter_block_sizes([3])
        self.weight = 1.
        self.logit_threshold = 1.
        self.visualize = visualize
        self.vis_delay = vis_delay

    def set_weight(self, weight):
        self.weight = weight

    def Evaluate(self, parameters, residuals, jacobians):
        grasp = parameters[0]
        viol, grad = jax.value_and_grad(clipped_logit_viol)(
            grasp, self.logit_threshold
        )
        residuals[0] = self.weight * viol
        # Ceres passes no jacobians when only residuals are needed. Individual
        # blocks may also be null (e.g. a constant block); guard in case the
        # binding forwards that as None.
        if jacobians is not None and jacobians[0] is not None:
            jacobians[0][:] = self.weight * grad
        if self.visualize:
            show_grasp(grasp)
            time.sleep(self.vis_delay)
        return True

def grasp_fk(g, obj_pose):
    """World-frame pose vector of grasp ``g`` on an object at ``obj_pose``."""
    return to_posevec(obj_pose @ grasp_reconst(g))

class KinError(PyCeres.CostFunction):
    """Ceres cost tying the arm's end-effector position to the grasp target.

    Parameter blocks: the grasp variable ``g`` (3) and joint angles ``q`` (7).
    Residual: the first three (position) rows of
    ``W @ (grasp_fk(g, obj_pose) - ee_posevec(q))``, where ``W = diag(sqrt(w))``
    comes from ``set_weight`` with a 6-D weight ``w``.
    ``set_weight`` and ``set_obj_pose`` must be called before solving.

    Args:
        visualize: if true (default), every evaluation draws the grasp target
            and the arm, then sleeps ``vis_delay`` seconds.
        vis_delay: seconds to sleep after each visualized evaluation.
    """
    def __init__(self, visualize=True, vis_delay=0.1):
        super().__init__()
        self.set_num_residuals(3)
        self.set_parameter_block_sizes([3, 7])
        self.weight_mat = None
        self.obj_pose = None
        self.visualize = visualize
        self.vis_delay = vis_delay

    def set_obj_pose(self, obj_pose):
        self.obj_pose = obj_pose

    def set_weight(self, weight):
        self.weight_mat = np.diag(np.sqrt(weight))

    def Evaluate(self, parameters, residuals, jacobians):
        if self.weight_mat is None or self.obj_pose is None:
            raise RuntimeError(
                "KinError: call set_weight() and set_obj_pose() before solving"
            )
        grasp, q = parameters[0], parameters[1]
        target, jac_grasp = value_and_jacrev(grasp_fk, grasp, self.obj_pose)
        ee, jac_q = get_ee_fk_jac(q)
        # Weight the full 6-D error, then keep the 3 position rows -- the same
        # weight-then-slice order as the jacobians below. (Slicing first would
        # multiply the 6x6 weight matrix by a 3-vector.)
        residuals[:] = (self.weight_mat @ (target - ee))[:3]
        if jacobians is not None:
            # Ceres jacobian blocks are row-major (num_residuals x block_size).
            if jacobians[0] is not None:
                jacobians[0][:] = (self.weight_mat @ jac_grasp)[:3, :].flatten()
            if jacobians[1] is not None:
                jacobians[1][:] = -(self.weight_mat @ jac_q)[:3, :].flatten()
        if self.visualize:
            frame_pick.set_pose(to_SE3(target))
            panda.set_joint_angles(q)
            time.sleep(self.vis_delay)
        return True

In [364]:
# Weights for the 6-D pose error [x, y, z, rx, ry, rz] (rotation weighted less).
# KinError slices its weighted error and jacobians to the first three (position) rows.
pose_weight = np.array([1, 1, 1, 0.3, 0.3, 0.3])
feature_grasp_prob = GraspProb()
feature_fk_error = KinError()
# set_weight builds diag(sqrt(weight)), so the squared cost scales by 0.1 * pose_weight.
feature_fk_error.set_weight(pose_weight*0.1)
# Pick problem: the grasp is placed on the object at its start pose.
feature_fk_error.set_obj_pose(obj_start_pose)

In [365]:
# Initial guesses. Ceres writes the solution back into these arrays, so
# re-running the problem/solve cells without re-running this one warm-starts
# from the previous result.
g_pick = np.ones(3)*0.8
q_pick = panda.neutral.copy()

In [371]:
# Solver configuration.
options = PyCeres.SolverOptions()
options.minimizer_type = PyCeres.MinimizerType.TRUST_REGION
options.linear_solver_type = PyCeres.LinearSolverType.SPARSE_NORMAL_CHOLESKY
options.minimizer_progress_to_stdout = True
summary = PyCeres.Summary()

# Joint problem over the grasp variable g_pick and the arm configuration
# q_pick; `None` means no robust loss (plain squared residuals).
problem = PyCeres.Problem()
problem.AddResidualBlock(feature_grasp_prob, None, g_pick)
problem.AddResidualBlock(feature_fk_error, None, g_pick, q_pick)


In [372]:
# Run the optimization; g_pick and q_pick hold the solution afterwards.
# The cost functions draw to the visualizer and call time.sleep() during
# evaluation, so each iteration takes noticeably longer than the math alone.
PyCeres.Solve(options, problem, summary)

iter      cost      cost_change  |gradient|   |step|    tr_ratio  tr_radius  ls_iter  iter_time  total_time
   0  3.993054e-01    0.00e+00    1.47e-01   0.00e+00   0.00e+00  1.00e+04        0    3.25e-01    3.25e-01
   1  3.292985e-01    7.00e-02    1.16e-01   6.05e+00   1.75e-01  7.85e+03        1    7.25e-01    1.05e+00
   2  5.797845e-01   -2.50e-01    1.16e-01   6.24e+00  -7.61e-01  3.93e+03        1    3.63e-01    1.41e+00
   3  5.795436e-01   -2.50e-01    1.16e-01   6.24e+00  -7.60e-01  9.81e+02        1    3.42e-01    1.76e+00
   4  5.780406e-01   -2.49e-01    1.16e-01   6.19e+00  -7.55e-01  1.23e+02        1    3.44e-01    2.10e+00
   5  5.600733e-01   -2.31e-01    1.16e-01   5.83e+00  -7.02e-01  7.67e+00        1    3.74e-01    2.47e+00
   6  2.529881e-01    7.63e-02    1.05e-01   3.46e+00   2.47e-01  6.79e+00        1    7.51e-01    3.23e+00
   7  2.445811e-01    8.41e-03    1.01e-01   3.40e+00   3.62e-02  3.78e+00        1    7.25e-01    3.95e+00
   8  7.628743e-02    1.68e-

KeyboardInterrupt: 

In [340]:
# Display the current grasp (on both object poses) and the current arm configuration.
show_grasp(g_pick)
panda.set_joint_angles(q_pick)

In [157]:
# Re-display only the arm configuration (duplicates the last line of the previous cell).
panda.set_joint_angles(q_pick)

In [145]:
# Inline equivalent of show_grasp(g_pick); additionally leaves `grasp_pose` as a global.
grasp_pose = grasp_reconst(g_pick)
frame_pick.set_pose(obj_start_pose@grasp_pose)
frame_place.set_pose(obj_goal_pose@grasp_pose)

In [197]:
# Brief solver report. NOTE: the output recorded below comes from execution
# count 197, earlier than the solve cell (372), so it describes an older run.
print(summary.BriefReport())

Ceres Solver Report: Iterations: 42, Initial cost: 1.826630e+02, Final cost: 1.700756e-02, Termination: CONVERGENCE
