In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import jax_dataclasses as jdc
from functools import partial

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [3]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7008/static/


In [7]:
#load object
from flax import linen as nn
from flax.training import orbax_utils
import orbax
import pickle

class GraspNet(nn.Module):
    hidden_dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        logit = nn.Dense(features=5)(x)
        return logit

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
raw_restored = orbax_checkpointer.restore("model/grasp_net")
params = raw_restored["params"]
grasp_net = GraspNet(raw_restored["hidden_dim"])
grasp_fn = lambda x: grasp_net.apply(params, x)

with open("./sdf_world/assets/object"+'/info.pkl', 'rb') as f:
    obj_data = pickle.load(f)
scale_to_norm = obj_data["scale_to_norm"]
def grasp_reconst(g:Array):
    rot = SO3(grasp_fn(g)[1:5]).normalize()
    trans = g/scale_to_norm
    return SE3.from_rotation_and_translation(rot, trans)
grasp_logit_fn = lambda g: grasp_fn(g)[0]

In [5]:
obj_start = Mesh(world.vis, "obj_start", "./sdf_world/assets/object/mesh.obj",
                 alpha=0.5)
frame = Frame(world.vis, "grasp_pose")
d, w, h = obj_start.mesh.bounding_box.primitive.extents

In [512]:
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7, 8], [0.04, 0.04])

In [514]:
# Kinematics
def get_rotvec_angvel_map(v):
    def skew(v):
        v1, v2, v3 = v
        return jnp.array([[0, -v3, v2],
                        [v3, 0., -v1],
                        [-v2, v1, 0.]])
    vmag = jnp.linalg.norm(v)
    vskew = skew(v)
    return jnp.eye(3) \
        - 1/2*skew(v) \
        + vskew@vskew * 1/vmag**2 * (1-vmag/2 * jnp.sin(vmag)/(1-jnp.cos(vmag)))

@jax.jit
def get_ee_fk_jac(q):
    # outputs ee_posevec and analytical jacobian
    fks = panda_model.fk_fn(q)
    p_ee = fks[-1][-3:]
    rotvec_ee = SO3(fks[-1][:4]).log()
    E = get_rotvec_angvel_map(rotvec_ee)
    jac = []
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        jac.append(jnp.hstack([lin_vel, rot_axis]))
    jac = jnp.array(jac).T
    jac = jac.at[3:, :].set(E @ jac[3:, :])
    return jnp.hstack([p_ee, rotvec_ee]), jac

def to_posevec(pose:SE3):
    return jnp.hstack([pose.translation(), pose.rotation().log()])

In [504]:
obj_start.set_translate([0.5,0, h/2])

In [1257]:
@jdc.pytree_dataclass
class Residual:
    coordinate: Array
    residual_and_jac_fn: Callable
    weight: Array

@jdc.pytree_dataclass
class Param:
    obj_pose: SE3
    pose_goal: SE3

In [784]:
x = jnp.hstack([jnp.array([-0.,-2.,0.]), panda.neutral])
param = Param(obj_start.pose)

In [829]:
def value_and_jacrev(x, state, f):
    y, pullback = jax.vjp(f, x, state)
    basis = jnp.eye(y.size, dtype=y.dtype)
    jac = jax.vmap(pullback)(basis)
    return y, jac[0]

In [830]:
grasp_residual_fn = lambda x, param: jnp.array([1. - grasp_logit_fn(x)])
vg_grasp_residual_fn = jax.jit(partial(value_and_jacrev, f=grasp_residual_fn))
grasp_feature = Residual(
    jnp.array([0, 1, 2]),
    vg_grasp_residual_fn,
    jnp.array(1.)
)

In [1259]:
def vg_kin_residual_fn(x, param:Param):
    def grasp_fk(grasp, param:Param):
        grasp_pose = param.obj_pose @ grasp_reconst(grasp)
        return to_posevec(grasp_pose)
    grasp, q = x[:3], x[3:]
    grasp_posevec, grasp_jac = value_and_jacrev(grasp, param, grasp_fk)
    ee_posevec, ee_jac = get_ee_fk_jac(q)
    residual = grasp_posevec - ee_posevec
    jac = jnp.hstack([grasp_jac, ee_jac])
    return residual, jac

kin_feature = Residual(
    jnp.arange(10),
    vg_kin_residual_fn,
    jnp.array([1, 1, 1, 0.3, 0.3, 0.3])
)

def vg_fk_residual_fn(x, param:Param):
    ee_posevec, ee_jac = get_ee_fk_jac(x)
    residual = to_posevec(param.pose_goal) - ee_posevec
    jac = - ee_jac
    return residual, jac
fk_feature = Residual(
    jnp.arange(7),
    vg_fk_residual_fn,
    jnp.array([1, 1, 1, 0.3, 0.3, 0.3])
)

In [1427]:
def get_value_jacs(x, param, features:List[Residual]):
    residuals, jacs = [], []
    for feature in features:
        input = x[feature.coordinate]
        residual, jac = feature.residual_and_jac_fn(input, param)
        jac_full = jnp.zeros((len(residual), len(x)))
        jac_full = jac_full.at[:,feature.coordinate].set(jac)
        residuals.append(residual)
        jacs.append(jac_full)
    return jnp.hstack(residuals), jnp.vstack(jacs)

def calculate_step(residual, jac, weights):
    grad = jac.T@jnp.diag(weights)@residual
    hess_true = jac.T@jnp.diag(weights)@jac
    try:
        np.linalg.cholesky(hess_true)
        hess = hess_true
        print("hess")
    except:
        min_eig = -jnp.linalg.eigh(hess_true)[0][0]
        min_eig = jnp.maximum(min_eig, 1e-4)
        hess = hess_true + jnp.eye(len(x)) * min_eig * 1.1
        print("mod_hess")
    p = np.linalg.solve(hess, -grad)
    return p

In [1428]:
x = panda.neutral
param = Param(obj_start.pose, frame.pose)
weight_mat = jnp.array([1, 1, 1, 0.3, 0.3, 0.3])

In [1484]:
residual, jac = get_value_jacs(x, param, [fk_feature])
p = calculate_step(residual, jac, weight_mat)
x = x + p*0.1

#frame.set_pose(grasp_fk(x[:3], param))
panda.set_joint_angles(x)

mod_hess


: 

In [864]:
features = [grasp_feature, kin_feature]
residuals = []
jacs = []
for feature in features:
    residual, jac = get_value_jac(x, param, feature)
    residuals.append(residual)
    jacs.append(jac)
residual = jnp.hstack(residuals)
jac = jnp.vstack(jacs).shape

In [1057]:
x = jnp.hstack([jnp.ones(3), panda.neutral])

In [1054]:
def grasp_fk(grasp, param:Param):
    grasp_pose = param.obj_pose @ grasp_reconst(grasp)
    return grasp_pose
frame.set_pose(grasp_fk(x[:3], param))
panda.set_joint_angles(x[3:])

In [887]:
features = [grasp_feature, kin_feature]
residuals = []
jacs = []
for feature in features:
    residual, jac = get_value_jac(x, param, feature)
    residuals.append(residual)
    jacs.append(jac)
residual = jnp.hstack(residuals)
jac = jnp.vstack(jacs)

In [1117]:
weight_mat = jnp.diag(jnp.array([1, 1, 1, 0.3, 0.3, 0.3, 1.]))
features = [grasp_feature, kin_feature]
residuals = []
jacs = []
for feature in features:
    residual, jac = get_value_jac(x, param, feature)
    residuals.append(residual)
    jacs.append(jac)
residual = jnp.hstack(residuals)
jac = jnp.vstack(jacs)

grad = - jac.T@weight_mat@residual
hess_true = jac.T@weight_mat@jac

try:
    np.linalg.cholesky(hess_true)
    hess = hess_true
    print("hess")
except:
    min_eig = -jnp.linalg.eigh(hess_true)[0][0]
    min_eig = jnp.maximum(min_eig, 1e-4)
    hess = hess_true + jnp.eye(len(x)) * min_eig * 1.1
    print("mod_hess")
p = np.linalg.solve(hess, -grad)
x = x + p*0.01

frame.set_pose(grasp_fk(x[:3], param))
panda.set_joint_angles(x[3:])

mod_hess


In [761]:
grasp_fk_fn = lambda x:to_posevec(obj_start.pose@grasp_reconst(x))

In [1164]:
x = panda.neutral

In [1252]:
grasp = jnp.array([0,0,1])

In [1249]:
grasp_pose = grasp_fk(grasp, param)
frame.set_pose(grasp_pose)

In [1256]:
kin_feature.coordinate

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [1253]:
weight_mat = jnp.diag(jnp.array([1, 1, 1, 0.3, 0.3, 0.3]))
# ee, jac = get_ee_fk_jac(x)
# residual = to_posevec(grasp_pose) - ee
residual, jac = get_value_jac(x, param, kin_feature)

grad = - jac.T@weight_mat@residual 
hess_true = jac.T@weight_mat@jac

try:
    np.linalg.cholesky(hess_true)
    hess = hess_true
    print("hess")
except:
    min_eig = -jnp.linalg.eigh(hess_true)[0][0]
    min_eig = jnp.maximum(min_eig, 1e-4)
    hess = hess_true + jnp.eye(7) * min_eig * 1.1
    print("mod_hess")
p = np.linalg.solve(hess, -grad)
x = x + p*0.1

panda.set_joint_angles(x)

IndexError: index 7 is out of bounds for axis 0 with size 7

In [757]:
feature = grasp_feature

In [759]:
val, jac = feature.residual_and_jac_fn(x[feature.coordinate])

In [760]:
val

Array(326.8755, dtype=float32)

In [736]:
vg_grasp_residual_fn(x)

(Array(326.8755, dtype=float32),
 Array([-283.92578 , -253.43362 ,  -11.379831], dtype=float32))

In [511]:
residual = (1. - grasp_logit_fn(x))
jac = - jax.grad(grasp_logit_fn)(x)
p = - 1./(jac@jac) * jac*(residual)
x = x + p

grasp_pose = obj_start.pose@grasp_reconst(x)
frame.set_pose(grasp_pose)

In [683]:
x = panda.neutral

In [724]:
weight_mat = jnp.diag(jnp.array([1, 1, 1, 0.3, 0.3, 0.3]))
ee, jac = get_ee_fk_jac(x)
residual = to_posevec(grasp_pose) - ee

grad = - jac.T@weight_mat@residual 
hess_true = jac.T@weight_mat@jac

try:
    np.linalg.cholesky(hess_true)
    hess = hess_true
    print("hess")
except:
    min_eig = -jnp.linalg.eigh(hess_true)[0][0]
    min_eig = jnp.maximum(min_eig, 1e-4)
    hess = hess_true + jnp.eye(7) * min_eig * 1.1
    print("mod_hess")
p = np.linalg.solve(hess, -grad)
x = x + p*0.1

panda.set_joint_angles(x)

mod_hess
