In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import cyipopt

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [2]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/


In [3]:
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7, 8], [0.04, 0.04])

In [618]:
# configurations
robot_dim = 7
horizon = 20
dim = robot_dim * horizon
dt = 0.1

qdot_lb = -np.ones(dim)
qdot_ub = np.ones(dim)

#dt_vec = jnp.ones(dim) * dt
dt_vec = jnp.ones(horizon*robot_dim) * dt
#jnp.hstack([jnp.ones(robot_dim*6)*dt1, jnp.ones(robot_dim*10)*dt2])

lower_tri = np.tril(np.full((horizon,horizon), 1))
upper_tri = np.triu(np.full((horizon,horizon), 1))
eye = np.eye(robot_dim)
integration_mat = np.kron(lower_tri, eye) @ np.diag(dt_vec)
double_integration_mat = integration_mat@integration_mat
#qdot_max = qdot_max*np.kron(upper_tri, eye)@dt_vec

to_mat = lambda x: x.reshape(-1, robot_dim)
to_vec = lambda x: x.flatten()
def to_qdots_mat(q0, qs):
    qs = jnp.vstack([q0, to_mat(qs)])
    return (qs[1:] - qs[:-1]) /dt

In [619]:
#state initialize
q0 = panda.neutral
qs = jnp.tile(q0, horizon)

In [112]:
# # rollout
# @jax.jit
# def rollout(q0, qdots):
#     return integration_mat@qdots + jnp.tile(q0, horizon)

In [153]:
# state, u initialize
q0 = panda.neutral
qdots = jnp.zeros(dim)
#draw_ee_traj(u, state)

In [620]:
# Kinematics
def skew(v):
    v1, v2, v3 = v
    return jnp.array([[0, -v3, v2],
                      [v3, 0., -v1],
                      [-v2, v1, 0.]])

def get_rotvec_angvel_map(v):
    vmag = jnp.linalg.norm(v)
    return jnp.eye(3) \
        - 1/2*skew(v) \
        + skew(v)@skew(v) * 1/vmag**2 * (1-vmag/2 * jnp.sin(vmag)/(1-jnp.cos(vmag)))

@jax.jit
def get_ee_fk_jac(q):
    #TODO: rotation change with quaternion
    fks = panda_model.fk_fn(q)
    p_ee = fks[-1][-3:]
    rotvec_ee = SO3(fks[-1][:4]).log()
    E = get_rotvec_angvel_map(rotvec_ee)
    jac = []
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        jac.append(jnp.hstack([lin_vel, rot_axis]))
    jac = jnp.array(jac).T
    jac = jac.at[3:, :].set(E @ jac[3:, :])
    return jnp.hstack([p_ee, rotvec_ee]), jac


In [621]:
def to_posevec(pose:SE3):
    return jnp.hstack([
        pose.translation(), pose.rotation().log()
    ])
def to_SE3(posevec:Array):
    return SE3.from_rotation_and_translation(
        SO3.exp(posevec[3:]), posevec[:3]
    )

@jax.jit
def pose_err(q, target:SE3):
    ee_posevec, _ = get_ee_fk_jac(q)
    target_posevec = to_posevec(target)
    err = (target_posevec - ee_posevec)
    val = jnp.sum(err**2)
    return val

@jax.jit
def pose_err_grad(q, target:SE3):
    ee_posevec, jac = get_ee_fk_jac(q)
    target_posevec = to_posevec(target)
    err = (target_posevec - ee_posevec)
    grad_err = - 2*jac.T@err
    return grad_err

In [622]:
#decision vars: qs
def obj_fn(qs, q0, target):
    errs = jax.vmap(pose_err, in_axes=(0,None))(
        to_mat(qs), target)
    qdots = to_qdots_mat(q0, qs).flatten()
    draw_ee_traj(q0, qdots) #debug
    return errs.sum()

def grad_fn(qs, q0, target):
    err_grads = jax.vmap(pose_err_grad, in_axes=(0,None))(
        to_mat(qs), target)
    return err_grads.flatten() @ integration_mat

qdots_lb = -jnp.ones(dim)
qdots_ub = jnp.ones(dim)

def constr_fn(qs, q0):
    #qdots_mat = to_qdots_mat(q0, qs)
    ub_viol = qdots_ub - to_qdots_mat(q0, qs).flatten()
    lb_viol = to_qdots_mat(q0, qs).flatten() - qdots_lb
    return jnp.hstack([lb_viol, ub_viol])

constr_fn_jac = jax.jacrev(constr_fn, argnums=0)

In [590]:
# #decision vars: qdots
# def obj_fn(qdots, q0, target):
#     qs = rollout(q0, qdots)
#     errs = jax.vmap(pose_err, in_axes=(0,None))(to_mat(qs), target)

#     draw_ee_traj(q0, qdots) #debug
#     return errs.sum()

# def grad_fn(qdots, q0, target):
#     qs = rollout(q0, qdots)
#     err_grads = jax.vmap(pose_err_grad, in_axes=(0,None))(to_mat(qs), target)
#     return err_grads.flatten() @ integration_mat

In [591]:
import cyipopt
from functools import partial

In [584]:
frame = Frame(world.vis, "frame")
frame_ee = Frame(world.vis, "frame_ee")

In [583]:
del frame, frame_ee

In [623]:
def make_random_pose():
    return SE3.from_rotation_and_translation(
        SO3(np.random.random(4)).normalize(),
        np.random.uniform([-0.,-0.5,0.3],[0.5,0.5,0.8])
    )
target = make_random_pose()
frame.set_pose(target)

In [624]:
objective = partial(obj_fn, q0=q0, target=target)
gradient = partial(grad_fn, q0=q0, target=target)
constraints = partial(constr_fn, q0=q0)
jacobian = partial(constr_fn_jac, q0=q0)
class Prob:
    pass
prob = Prob()
setattr(prob, "objective", objective)
setattr(prob, "gradient", gradient)
setattr(prob, "constraints", constraints)
setattr(prob, "jacobian", jacobian)

In [625]:
ipopt = cyipopt.Problem(
    n=dim, m=0,
    problem_obj=prob,
    lb=qdot_lb, ub=qdot_ub,
    cl=np.zeros(dim), cu=np.full(dim, np.inf)
)

In [626]:
#ipopt.add_option("acceptable_obj_change_tol", 1.)
ipopt.add_option("acceptable_tol", 0.0001)
ipopt.add_option("acceptable_iter", 2)
ipopt.add_option("print_level", 5)
ipopt.add_option("max_iter", 100)

qsol, info = ipopt.solve(qdots)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        0
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:      140
                     variables with only lower bounds:        0
                variables with lower and upper bounds:      140
                     variables with only upper bounds:        0
Total number of equality constraints.................:        0
Total number of inequality constraints...............:        0
        inequality constraints with only lower bounds:        0
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  5.6487805e+02 0.00e+00 1.95e+01   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [629]:
#qs = rollout(q0, to_qdots_mat(q0, qsol)) 
draw_ee_traj(q0, to_qdots_mat(q0, qsol).flatten())

In [643]:
i=0

In [665]:
panda.set_joint_angles(to_mat(qsol)[i])
frame_ee.set_pose(SE3(panda_model.fk_fn(to_mat(qsol)[i])[-1]))
i+= 1

IndexError: index 20 is out of bounds for axis 0 with size 20

In [543]:
pose_err(to_mat(qs)[-1], target)

Array(0.0190942, dtype=float32)

In [544]:
pose_err_grad(to_mat(qs)[-1], target)

Array([-0.00276927,  0.00031458,  0.08400258, -0.00164529, -0.00213077,
       -0.00996618,  0.02344448], dtype=float32)

Array([ 0.18797217, -1.5620015 , -1.7517192 , -2.6781812 , -1.1485922 ,
        3.8666735 , -1.8566442 ], dtype=float32)

SE3(wxyz=[-0.49978998 -0.51574    -0.41063    -0.56179   ], xyz=[ 0.08882    -0.15044999  0.35022998])

In [41]:
pose_err_grad(panda.neutral, target)

Array([-0.63412786,  1.6975825 , -0.63412786, -1.4206374 ,  3.382577  ,
       -1.1811363 ,  1.8599132 ], dtype=float32)

In [18]:
#visualize
import meshcat.geometry as g
get_ee_pos = lambda q: panda_model.fk_fn(q)[-1][-3:]
def draw_ee_traj(q0, qdots):
    world.vis["line"].delete()
    world.vis["line_pc"].delete()

    panda.set_joint_angles(q0)
    qs = rollout(q0, qdots)    
    vertices = jax.vmap(get_ee_pos)(to_mat(qs))
    color = Colors.read("red")
    colors = np.tile(Colors.read("red", return_rgb=True), 
                     vertices.shape[0]).reshape(-1, 3)
    point_obj = g.PointsGeometry(vertices.T, colors.T)
    line_material = g.MeshBasicMaterial(color=color)
    point_material = g.PointsMaterial(size=0.02)
    world.vis["line"].set_object(
        g.Line(point_obj, line_material)
    )
    world.vis["line_pc"].set_object(
        point_obj, point_material)

def make_random_pose():
    return SE3.from_rotation_and_translation(
        SO3(np.random.random(4)).normalize(),
        np.random.uniform([-0.,-0.5,0.3],[0.5,0.5,0.8])
    )

In [30]:
draw_ee_traj(q0, rollout(q0, qdots))