In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [3]:
# Create the SDF world; constructing it prints the meshcat visualizer URL (see output below).
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/


In [4]:
# Load the Panda kinematic model and create its visual in the world
# (alpha=0.5 is presumably the visual's opacity — TODO confirm).
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
# Fix joints 7 and 8 at 0.04 — presumably the two gripper fingers — so the
# optimisation below works with 7 joints (robot_dim = 7). TODO confirm reduce_dim semantics.
panda.reduce_dim([7, 8], [0.04, 0.04])

In [3732]:
# configurations
robot_dim = 7          # joints per time step
horizon = 16           # number of MPC time steps
dim = robot_dim * horizon
dt1, dt2 = 0.05, 0.1   # step length of the first n_fine steps / of the remaining steps
n_fine = 6             # number of leading steps integrated with dt1
umax = 1.              # control magnitude limit used for qdot_max (and the input bounds)

# Per-entry step length of the stacked (horizon*robot_dim,) decision vector.
# The coarse-step count is derived from `horizon` so the two cannot drift apart.
assert 0 <= n_fine <= horizon, "n_fine must lie in [0, horizon]"
n_coarse = horizon - n_fine
dt_vec = jnp.hstack([jnp.ones(robot_dim*n_fine)*dt1, jnp.ones(robot_dim*n_coarse)*dt2])

lower_tri = np.tril(np.full((horizon,horizon), 1))
upper_tri = np.triu(np.full((horizon,horizon), 1))
eye = np.eye(robot_dim)
# (integration_mat @ x) block k = sum_{j<=k} dt_j * x_j  (cumulative Euler sum).
integration_mat = np.kron(lower_tri, eye) @ np.diag(dt_vec)
double_integration_mat = integration_mat@integration_mat
# Block k = umax * sum_{j>=k} dt_j: presumably the largest joint speed that can
# still be braked by the end of the horizon with |u| <= umax.
qdot_max = umax*np.kron(upper_tri, eye)@dt_vec

# Convert between the stacked vector and a (horizon, robot_dim) matrix.
to_mat = lambda x: x.reshape(-1, robot_dim)
to_vec = lambda x: x.flatten()

In [3733]:
# rollout
@jax.jit
def integration(q0, qdots):
    """Cumulatively integrate a stacked rate sequence, starting from q0.

    Block k of the result is q0 + sum_{j<=k} dt_j * qdots_j (via integration_mat).
    """
    return jnp.tile(q0, horizon) + integration_mat @ qdots

@jax.jit
def rollout(us, state):
    """Predict joint trajectories over the horizon from state = (q0, qdot0).

    Velocities are integrated from the inputs first, then positions from those
    velocities, so each position step uses the already-updated velocity.

    Returns:
        (qs, qdots): stacked position and velocity vectors.
    """
    q_start, qdot_start = state
    qdot_seq = integration(qdot_start, us)
    q_seq = integration(q_start, qdot_seq)
    return q_seq, qdot_seq

In [3734]:
# Kinematics
def skew(v):
    """Return the 3x3 skew-symmetric matrix of a 3-vector, so skew(v) @ w == cross(v, w)."""
    v1, v2, v3 = v
    return jnp.array([[0., -v3, v2],
                      [v3, 0., -v1],
                      [-v2, v1, 0.]])

def get_rotvec_angvel_map(v, eps=1e-3):
    """Map a world-frame angular velocity to the time derivative of rotation vector v.

    Returns the inverse left Jacobian of SO(3):
        E(v) = I - 1/2 [v]x + c(theta) [v]x^2,  theta = |v|,
        c(theta) = (1 - (theta/2) * sin(theta) / (1 - cos(theta))) / theta^2.

    The closed form is 0/0 at theta = 0, which produces NaN for the identity
    rotation. It also loses all precision for small theta because 1 - cos(theta)
    cancels catastrophically, especially in float32. Here
    sin/(1-cos) is evaluated as cot(theta/2), and for theta < eps the Taylor
    expansion c ~= 1/12 + theta^2/720 is used.

    Args:
        v: rotation vector, shape (3,).
        eps: angle below which the Taylor branch is used.
    Returns:
        (3, 3) matrix E.
    """
    theta = jnp.linalg.norm(v)
    small = theta < eps
    # jnp.where evaluates both branches; give the exact branch a harmless angle
    # when it is not selected so it never produces inf/NaN.
    t = jnp.where(small, 1.0, theta)
    half = 0.5 * t
    c_exact = (1.0 - half * jnp.cos(half) / jnp.sin(half)) / t**2
    c_taylor = 1.0 / 12.0 + theta**2 / 720.0
    c = jnp.where(small, c_taylor, c_exact)
    S = skew(v)
    return jnp.eye(3) - 0.5 * S + c * (S @ S)

@jax.jit
def get_ee_fk_jac(q):
    """End-effector pose vector [position, rotation vector] and its Jacobian w.r.t. q.

    For each frame in fks[1:8] the column is [z x (p_ee - p_frame); z], where z
    is the frame's local z axis (presumably the revolute joint axis — TODO
    confirm against the URDF). The angular rows are then mapped through
    get_rotvec_angvel_map so they give the rate of the end-effector rotation vector.

    Returns:
        (6,) pose vector and (6, len(fks[1:8])) Jacobian.
    """
    frame_poses = panda_model.fk_fn(q)
    ee_wxyz_xyz = frame_poses[-1]
    ee_pos = ee_wxyz_xyz[-3:]
    ee_rotvec = SO3(ee_wxyz_xyz[:4]).log()
    rotvec_rate_map = get_rotvec_angvel_map(ee_rotvec)

    def joint_column(frame_pose):
        z_axis = SE3(frame_pose).as_matrix()[:3, 2]
        linear = jnp.cross(z_axis, ee_pos - frame_pose[-3:])
        return jnp.hstack([linear, z_axis])

    geometric_jac = jnp.array([joint_column(p) for p in frame_poses[1:8]]).T
    ee_jac = geometric_jac.at[3:, :].set(rotvec_rate_map @ geometric_jac[3:, :])
    return jnp.hstack([ee_pos, ee_rotvec]), ee_jac

In [3735]:
from scipy.optimize import minimize, Bounds
# optimization functions
# Objective / gradient
def to_posevec(pose:SE3):
    """Flatten an SE3 pose into a 6-vector [translation (3), rotation vector (3)]."""
    translation = pose.translation()
    rotvec = pose.rotation().log()
    return jnp.hstack([translation, rotvec])
def to_SE3(posevec:Array):
    """Inverse of to_posevec: build an SE3 from [translation (3), rotation vector (3)]."""
    rotation = SO3.exp(posevec[3:])
    translation = posevec[:3]
    return SE3.from_rotation_and_translation(rotation, translation)

@jax.jit
def vg_pose_err(q, target:SE3):
    """Squared pose error to `target` at configuration q, and its gradient w.r.t. q.

    The error is the difference of [position, rotation vector] pose vectors;
    the gradient uses the end-effector Jacobian from get_ee_fk_jac.
    """
    current, jac = get_ee_fk_jac(q)
    residual = to_posevec(target) - current
    return jnp.sum(residual**2), -2 * jac.T @ residual

def vg_objective(u, state, target):
    """Pose-error cost summed over the rollout of u, and its gradient w.r.t. u."""
    qs, _ = rollout(u, state)
    per_step = jax.vmap(vg_pose_err, in_axes=(0, None))
    costs, dcost_dq = per_step(to_mat(qs), target)
    # qs is affine in u with Jacobian double_integration_mat (chain rule).
    return costs.sum(), dcost_dq.flatten() @ double_integration_mat

# Constraints
# Box bounds on every control entry, |u_i| <= umax, tied to the configured
# umax so they stay consistent with qdot_max.
lb, ub = -umax*np.ones(dim), umax*np.ones(dim)
bounds = Bounds(lb, ub)

def constr_fn(u, state):
    """Aggregated inequality constraints for the rollout of u (feasible iff all >= 0).

    Returns the worst-case margin of, in order: the joint-velocity limit
    qdot_max, the upper joint limit, and the lower joint limit.
    """
    qs, qdots = rollout(u, state)
    q_upper = jnp.tile(panda.ub, horizon)
    q_lower = jnp.tile(panda.lb, horizon)
    margins = [
        qdot_max - jnp.abs(qdots),
        q_upper - qs,
        qs - q_lower,
    ]
    return jnp.hstack([jnp.min(m) for m in margins])

In [3736]:
# dynamics
def update_dynamics(u, state, dt=None):
    """Advance the joint double integrator by one control step.

    Uses semi-implicit Euler: velocity first, then position with the *new*
    velocity. This is exactly the first step of `rollout`
    (qdot_1 = qdot_0 + dt*u_0, q_1 = q_0 + dt*qdot_1). An explicit-Euler update
    (q + qdot_old*dt) would differ from the MPC prediction by dt**2 * u, so the
    executed state would drift from the state the optimiser planned for.

    Args:
        u: control for this step (acts as joint acceleration).
        state: (q, qdot) tuple.
        dt: step length; defaults to the module-level dt1 (first rollout step).
    Returns:
        (q_new, qdot_new)
    """
    if dt is None:
        dt = dt1
    q, qdot = state
    qdot_new = qdot + u*dt
    q_new = q + qdot_new*dt
    return q_new, qdot_new

In [3737]:
#visualize
import meshcat.geometry as g
get_ee_pos = lambda q: panda_model.fk_fn(q)[-1][-3:]
def draw_ee_traj(u, state):
    """Show the robot at `state` and draw the end-effector path of rolling out `u`.

    Deletes the previous "line" / "line_pc" drawings, sets the robot visual to
    the joint positions state[0], rolls out `u` from `state`, and draws the
    end-effector position at every rollout step as a red line plus points.
    """
    world.vis["line"].delete()
    world.vis["line_pc"].delete()

    q = state[0]
    panda.set_joint_angles(q)
    qs, _ = rollout(u, state)

    # One end-effector position per rollout step: shape (horizon, 3).
    vertices = np.asarray(jax.vmap(get_ee_pos)(to_mat(qs)))
    color = Colors.read("red")
    # Per-vertex RGB, one row per vertex, defined locally so the function does
    # not depend on a leftover kernel variable. Assumes
    # Colors.read(..., return_rgb=True) returns an RGB triple — TODO confirm.
    colors = np.tile(Colors.read("red", return_rgb=True),
                     vertices.shape[0]).reshape(-1, 3)

    # meshcat expects (3, N) arrays for positions and colours.
    point_obj = g.PointsGeometry(vertices.T, colors.T)
    line_material = g.MeshBasicMaterial(color=color)
    point_material = g.PointsMaterial(size=0.02)
    world.vis["line"].set_object(
        g.Line(point_obj, line_material)
    )
    world.vis["line_pc"].set_object(
        point_obj, point_material)

def make_random_pose(rng=None):
    """Sample a random target pose for the end effector.

    The rotation is uniform over SO(3): normalising a 4-D standard normal sample
    gives a unit quaternion uniformly distributed on the 3-sphere. Normalising
    `random(4)` instead only produces quaternions with all components >= 0,
    which is a restricted, non-uniform subset of rotations.
    The translation is uniform in the box x in [0, 0.5], y in [-0.5, 0.5], z in [0.3, 0.8].

    Args:
        rng: optional numpy Generator/RandomState; defaults to the global np.random state.
    """
    if rng is None:
        rng = np.random
    wxyz = rng.normal(size=4)
    xyz = rng.uniform([-0., -0.5, 0.3], [0.5, 0.5, 0.8])
    return SE3.from_rotation_and_translation(
        SO3(wxyz).normalize(),
        xyz
    )

In [2270]:
# Visual marker for the current target pose (moved with frame.set_pose in the MPC loop).
frame = Frame(world.vis, "frame")

In [3738]:
# state, u initialize
# state = (joint positions, joint velocities): start at panda.neutral with zero velocity.
# u is the stacked control sequence for the whole horizon, initialised to zero.
state = (panda.neutral, jnp.zeros(robot_dim))
u = jnp.zeros(dim)
draw_ee_traj(u, state)

In [2629]:
#compile
# Ahead-of-time compile the objective and constraint functions. Lowering only
# needs example arguments with the right pytree structure, shapes and dtypes,
# so use a freshly sampled pose. `target_pose` is first defined by the MPC loop
# further down and would raise NameError on a fresh kernel.
example_target = make_random_pose()
vg_objective_j = jax.jit(vg_objective).lower(u, state, example_target).compile()
constraints_j = jax.jit(constr_fn).lower(u, state).compile()
jac_constraints_j = jax.jit(jax.jacrev(constr_fn)).lower(u, state).compile()

In [3739]:
class MPC:
    """SLSQP-based MPC solver with an elastic (slack-variable) fallback.

    Solves
        min_u  f(u)               (value and gradient from vg_obj_fn)
        s.t.   g(u) >= 0          (v_constr_fn, Jacobian jac_constr_fn)
               lb <= u <= ub
    with scipy's SLSQP. When the warm start violates g(u) >= 0, the "elastic"
    problem over x = [u, t] is solved instead:
        min  f(u) + sum(t)   s.t.  g(u) + t >= 0,  t >= 0,  lb <= u <= ub,
    for which [u, max(-g(u), 0)] is always a feasible starting point.

    Args:
        vg_obj_fn: (u, state, target) -> (value, gradient w.r.t. u).
        v_constr_fn: (u, state) -> array of n_constr constraint values.
        jac_constr_fn: (u, state) -> (n_constr, len(u)) Jacobian of v_constr_fn.
        lb, ub: box bounds on u.
        n_constr: number of entries returned by v_constr_fn, i.e. the number of
            slack variables appended in elastic mode (default 3).
    """
    def __init__(
            self, 
            vg_obj_fn, 
            v_constr_fn,
            jac_constr_fn,
            lb, ub,
            n_constr=3):
        if n_constr < 1:
            # x[:-0] would be empty, so at least one slack is required.
            raise ValueError("n_constr must be >= 1")
        self.vg_obj_fn = vg_obj_fn
        self.v_constr_fn = v_constr_fn
        self.jac_constr_fn = jac_constr_fn
        self.n_constr = n_constr
        n = n_constr

        def vg_obj_fn_aug(x, state, target):
            # x = [u, t]; the slacks enter the cost linearly (L1 penalty).
            u, t = x[:-n], x[-n:]
            val, grads = self.vg_obj_fn(u, state, target)
            val_new = val + jnp.sum(t)
            grads_new = jnp.hstack([grads, jnp.ones(n)])
            return val_new, grads_new
        self.vg_obj_fn_aug = vg_obj_fn_aug
        self.v_constr_fn_aug = lambda x, state: v_constr_fn(x[:-n], state) + x[-n:]
        self.jac_constr_fn_aug = lambda x, state: jnp.hstack([self.jac_constr_fn(x[:-n], state), jnp.eye(n)])
        # Slacks are nonnegative and unbounded above.
        self.bounds_aug = Bounds(
            lb=np.hstack([lb, np.zeros(n)]), 
            ub=np.hstack([ub, np.full(n, np.inf)]))

        self.lb = lb
        self.ub = ub
        self.bounds = Bounds(self.lb, self.ub)
    
    def get_fixed_input_dict(self, elastic=False, maxiter=10, ftol=0.001):
        """Keyword arguments for scipy.optimize.minimize that do not depend on the state.

        Covers the objective (with gradient), the SLSQP method, the bounds and
        the options. `elastic` selects the slack-augmented objective and bounds.
        """
        return {
            "fun": self.vg_obj_fn_aug if elastic else self.vg_obj_fn,
            "method": "SLSQP",
            "jac": True,
            "bounds": self.bounds_aug if elastic else self.bounds,
            "options": {
                "ftol": ftol,
                "maxiter": maxiter,
            },
        }
    
    def get_constraint_tuple(self, state, elastic=False):
        """SLSQP inequality-constraint spec (fun(x, state) >= 0) bound to `state`.

        `elastic` selects the slack-augmented constraint and Jacobian.
        """
        if elastic:
            fun, jac = self.v_constr_fn_aug, self.jac_constr_fn_aug
        else:
            fun, jac = self.v_constr_fn, self.jac_constr_fn
        return ({
            "type": "ineq",
            "fun": fun,
            "jac": jac,
            "args": (state,),
        },)

    def is_infeasible(self, x, state, tol=0.):
        """True if some constraint value at x is below -tol (tol=0: any negative value)."""
        return self.v_constr_fn(x, state).min() < -tol
    
    def optimize(
            self, x, state, target_pose, force_elastic=False, 
            maxiter=10, ftol=0.001, feas_tol=0.):
        """Run one SLSQP solve warm-started at x.

        Uses the elastic formulation when x is infeasible or `force_elastic` is
        set, otherwise the plain problem.

        Args:
            x: warm-start input vector u.
            state: passed through to the objective and constraints.
            target_pose: passed through to the objective.
            force_elastic: always use the elastic formulation.
            maxiter, ftol: SLSQP options.
            feas_tol: tolerance of the feasibility checks (0 = strict).
        Returns:
            (solved, u): `solved` is False only if the elastic solve still ends
            infeasible; `u` is the solution with any slack entries removed.
        """
        if self.is_infeasible(x, state, feas_tol) or force_elastic:
            # infeasible start: elastic mode
            print("elastic mode")
            # Start each slack at the violation of its constraint at x (0 where
            # satisfied) so [x, t0] is feasible for the augmented problem.
            t0 = -jnp.minimum(self.v_constr_fn(x, state), 0.)
            res = minimize(
                x0=jnp.hstack([x, t0]),
                args=(state, target_pose),
                constraints=self.get_constraint_tuple(state, True),
                **self.get_fixed_input_dict(True, maxiter, ftol)
            )
            u_sol = res.x[:-self.n_constr]
            if self.is_infeasible(u_sol, state, feas_tol):
                print("cannot solve infeasibility")
                return False, u_sol
            print(res.message)
            return True, u_sol
        print("normal mode")
        # feasible start
        res = minimize(
            x0=x,
            args=(state, target_pose),
            constraints=self.get_constraint_tuple(state),
            **self.get_fixed_input_dict(maxiter=maxiter, ftol=ftol)
        )
        print(res.message)
        return True, res.x
        

In [3740]:
# MPC solver built from the ahead-of-time compiled objective/constraint
# functions and the box bounds on u.
mpc = MPC(vg_objective_j, constraints_j, jac_constraints_j, lb, ub)

In [3741]:
@jax.jit
def update(u_opt, state):
    """Advance one control step and build the shifted warm start.

    Applies the first control of u_opt to `state`, drops that control from the
    sequence and appends -qdot_last/dt2: the input that would bring the last
    predicted joint velocity to zero over one dt2 step.

    Returns:
        (u_new, state_new)
    """
    first_u = u_opt[:robot_dim]
    next_state = update_dynamics(first_u, state)
    _, qdot_traj = rollout(u_opt, state)
    terminal_qdot = to_mat(qdot_traj)[-1]
    shifted = u_opt[robot_dim:]
    braking_u = -terminal_qdot / dt2
    return jnp.hstack([shifted, braking_u]), next_state

In [3742]:
# Closed-loop MPC demo: 10 random targets, 100 control steps each.
np.random.seed(0)  # reproducible target sequence
for j in range(10):
    target_pose = make_random_pose()
    frame.set_pose(target_pose)
    for i in range(100):
        solved, u_opt = mpc.optimize(u, state, target_pose, maxiter=10, ftol=0.00001)
        if not solved:
            # Warm start could not be repaired: retry from an all-zero input.
            print("reset input")
            _, u_opt = mpc.optimize(jnp.zeros(dim), state, target_pose, maxiter=10, ftol=0.00001)
        u, state = update(u_opt, state)
        # Draw the shifted plan `u` from the new state. u_opt was planned from
        # the previous state, so rolling it out from `state` is off by one step.
        draw_ee_traj(u, state)


normal mode
Iteration limit reached
elastic mode
cannot solve infeasibility
reset input
normal mode
Iteration limit reached
elastic mode
cannot solve infeasibility
reset input
normal mode
Iteration limit reached
elastic mode
cannot solve infeasibility
reset input
elastic mode
cannot solve infeasibility
elastic mode
cannot solve infeasibility
reset input
elastic mode
cannot solve infeasibility
elastic mode
cannot solve infeasibility
reset input
elastic mode
cannot solve infeasibility
elastic mode
cannot solve infeasibility
reset input
elastic mode
cannot solve infeasibility
elastic mode
cannot solve infeasibility
reset input
elastic mode
cannot solve infeasibility
elastic mode
cannot solve infeasibility
reset input
elastic mode
cannot solve infeasibility
elastic mode
cannot solve infeasibility
reset input
elastic mode
cannot solve infeasibility
elastic mode
cannot solve infeasibility
reset input
elastic mode
cannot solve infeasibility
elastic mode
cannot solve infeasibility
reset input


KeyboardInterrupt: 

In [3599]:
# Shared SLSQP keyword arguments for the timing cell below; `x0` and
# `constraints` are passed separately at the call site.
inputs = {
    "fun":vg_objective_j,
    "args":(state, target_pose),
    "method":"SLSQP",
    "jac":True,
    "bounds":bounds,
    "options":{
        "ftol":0.0001,
        "maxiter":10
    }
}

In [1911]:
%timeit minimize(x0=u, constraints=constraints, **inputs)

64 ms ± 159 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [1903]:
# Single SLSQP inequality-constraint spec (fun(x, state) >= 0) bound to the current state.
constraints = ({
    "type": "ineq",
    "fun": constraints_j,
    "jac": jac_constraints_j,
    "args": (state,),
},)

In [1912]:
# first iteration: one SLSQP solve from the current warm start, then advance
import time
tic = time.perf_counter()
constraints = tuple([
    {"type":"ineq", 
     "fun":constraints_j, "jac":jac_constraints_j, "args":(state,)}
])
res = minimize(
    fun=vg_objective_j,
    x0=u,
    args=(state,target_pose),
    method="SLSQP",
    jac=True,
    bounds=bounds,
    constraints=constraints,
    options={
        'ftol':0.0001,
        "maxiter":10
    }
)
toc = time.perf_counter()
print(res.message, f"elapsed:{toc-tic}")

draw_ee_traj(res.x, state)
# Apply the first control and shift the warm start. `update` already does
# exactly this (dynamics step, shifted inputs, terminal braking input), so
# reuse it instead of duplicating the logic here.
u, state = update(res.x, state)

Iteration limit reached elapsed:0.08584142685867846


In [1851]:
import time
tic = time.perf_counter()
constraints = tuple([
    {"type":"ineq", 
     "fun":constraints_j, "jac":jac_constraints_j, "args":(state,)}
])
# Report (but do not repair) an infeasible warm start.
if constraints_j(u, state).min() < 0.:
    print("infeasible!")
res = minimize(
    fun=vg_objective_j,
    x0=u,
    args=(state,target_pose),
    method="SLSQP",
    jac=True,
    bounds=bounds,
    constraints=constraints,
    options={
        'ftol':0.0001,
        'maxiter':10
    }
)
toc = time.perf_counter()
print(res.message, f"elapsed:{toc-tic}")

draw_ee_traj(res.x, state)
# Apply the first control and shift the warm start via `update` (same logic,
# not duplicated here).
u, state = update(res.x, state)

infeasible!
Iteration limit reached elapsed:0.07540052593685687


In [1855]:
# Initial elastic slacks: how much each constraint is violated at u (0 where satisfied).
# jnp.minimum avoids the deprecated `a_max` keyword of jnp.clip.
t0 = -jnp.minimum(constraints_j(u, state), 0.)

In [1885]:
# Augmented decision vector [u, t0]; evaluate the original objective on its u part.
x = jnp.hstack([u, t0])
vg_objective_j(x[:-3], state, target_pose)[0] 

Array(4.7472286, dtype=float32)

In [1893]:
def vg_objective_aug(x, state, target):
    """Elastic-mode objective on x = [u, t]: the compiled objective at u plus sum(t).

    Returns (value, gradient); the gradient w.r.t. each of the 3 slacks is 1.
    """
    u_part, slack = x[:-3], x[-3:]
    base_val, base_grad = vg_objective_j(u_part, state, target)
    return base_val + jnp.sum(slack), jnp.hstack([base_grad, 1., 1., 1.])

def constr_aug(x, state):
    """Elastic constraints: the original constraints at u shifted by the slacks."""
    return constraints_j(x[:-3], state) + x[-3:]

def jac_constr_aug(x, state):
    """Jacobian of constr_aug w.r.t. x = [u, t]: [d constraints/du | I]."""
    return jnp.hstack([jac_constraints_j(x[:-3], state), jnp.eye(3)])

# Controls keep their box bounds; the 3 slacks live in [0, inf).
bounds_aug = Bounds(
    lb=np.hstack([lb, 0, 0, 0]),
    ub=np.hstack([ub, np.inf, np.inf, np.inf]),
)

In [1895]:
#elastic mode
# Initial slacks must be the constraint violation at u (0 where satisfied).
# Using the raw constraint values would give negative slacks for violated
# constraints, breaking the t >= 0 bounds and the feasibility of [u, t0].
t0 = -jnp.minimum(constraints_j(u, state), 0.)
constraints = tuple([
    {"type":"ineq",
     "fun":constr_aug, "jac":jac_constr_aug, "args":(state,)}
])

res = minimize(
    fun = vg_objective_aug,
    x0=jnp.hstack([u, t0]),
    args=(state,target_pose),
    method="SLSQP",
    jac=True,
    bounds=bounds_aug,
    constraints=constraints,
    options={
        'ftol':0.0001,
        'maxiter':10
    }
)

In [1898]:
# Slack values after the elastic solve; values near 0 mean the original constraints hold.
res.x[-3:]

array([4.85976353e-16, 9.99200722e-16, 6.66133815e-16])

In [1890]:
# Inspect the current target pose.
target_pose

SE3(wxyz=[0.77458    0.31682998 0.21812    0.50206   ], xyz=[ 0.09892999 -0.38472998  0.49405998])