In [45]:
import jax
from jax import Array
import jax.numpy as jnp
import jax_dataclasses as jdc
from typing import *
import numpy as np

import osqp
from scipy import sparse

In [189]:
from jax.experimental.sparse.csr import *
import scipy

@jdc.pytree_dataclass
class SparseCooCoordinates:
    rows: Array
    cols: Array

@jdc.pytree_dataclass
class SparseCooMatrix:
    values: Array
    coords: SparseCooCoordinates
    shape: Tuple[int, int] = jdc.static_field()

    def __matmul__(self, other: Array):
        assert other.shape == (
            self.shape[1],
        ), "Inner product only supported for 1D vectors!"
        return (
            jnp.zeros(self.shape[0], dtype=other.dtype)
            .at[self.coords.rows]
            .add(self.values * other[self.coords.cols])
        )

    def as_dense(self) -> jnp.ndarray:
        return (
            jnp.zeros(self.shape)
            .at[self.coords.rows, self.coords.cols]
            .set(self.values)
        )

    @staticmethod
    def from_scipy(matrix: scipy.sparse.coo_matrix) -> "SparseCooMatrix":
        return SparseCooMatrix(
            values=matrix.data,
            coords=SparseCooCoordinates(
                rows=matrix.row,
                cols=matrix.col,
            ),
            shape=matrix.shape,
        )

    def as_scipy_coo_matrix(self) -> scipy.sparse.coo_matrix:
        return scipy.sparse.coo_matrix(
            (self.values, (self.coords.rows, self.coords.cols)), shape=self.shape
        )

    @property
    def T(self):
        return SparseCooMatrix(
            values=self.values,
            coords=SparseCooCoordinates(
                rows=self.coords.cols,
                cols=self.coords.rows,
            ),
            shape=self.shape[::-1],
        )

In [194]:
class Problem:
    def __init__(
        self,
        dim,
        eq_dim,
        ineq_dim,
        obj_fn, 
        ceq_fn, 
        cineq_fn, 
        grad_fn=None, 
        hess_fn=None,
        ceq_jac_fn=None,
        cineq_jac_fn=None,
    ):
        self.dim = dim
        self.eq_dim = eq_dim
        self.ineq_dim = ineq_dim
        self.f = obj_fn
        self.grad_f = jax.grad(obj_fn)
        self.hess_f = jax.hessian(obj_fn)
        self.g = ceq_fn
        self.h = cineq_fn
        self.jac_g = jax.jacrev(ceq_fn)
        self.jac_h = jax.jacrev(cineq_fn)

obj_fn = lambda x: x[0]*x[3]*jnp.sum(x[:3]) + x[2]
ceq_fn = lambda x: jnp.sum(x**2) - 40
cineq_fn = lambda x: jnp.prod(x) - 25 # >= 0

prob = Problem(4, 1, 1, obj_fn, ceq_fn, cineq_fn)

In [197]:
dim = 4
eq_dim = 1
ineq_dim = 1
x = jnp.ones(4)

In [196]:
grad = prob.grad_f(x)
hess = prob.hess_f(x)
ceq = prob.g(x)
cineq = prob.h(x)
ceq_jac = prob.jac_g(x)
cineq_jac = prob.jac_h(x)

In [198]:
hess_coord = SparseCooCoordinates(np.arange(dim), np.arange(dim))
dim_aug = dim+eq_dim*2+ineq_dim
hess = SparseCooMatrix(hess, hess_coord, shape=(dim_aug,dim_aug))

In [180]:
mat = MyCSC.fromdense(jnp.eye(10)[2:])

In [181]:
mat

CSC(float32[8, 10], nse=8)

In [173]:
mat.as_scipy()

AttributeError: 'CSC' object has no attribute 'as_scipy'

In [144]:
mat = CSC.fromdense(jnp.eye(10))

In [146]:
mat.todense()

Array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)

In [137]:
prob.hess_f(jnp.ones(4))

Array([[2., 1., 1., 4.],
       [1., 0., 0., 1.],
       [1., 0., 0., 1.],
       [4., 1., 1., 0.]], dtype=float32)

In [107]:
@jdc.pytree_dataclass
class Param:
    q0: Array
    target: Array

@jdc.pytree_dataclass
class Residual:
    error_and_jac_fn: Callable[[Array, Param], Tuple[Array]]
    weights: Array

@jdc.pytree_dataclass
class Constr:
    value_and_jac_fn: Callable[[Array, Param], Tuple[Array]]
    lb: Array
    ub: Array
    
    @property
    def dim(self):
        return len(self.lb)
    
    def is_eq_constr(self):
        return jnp.array_equal(self.lb, self.ub)
    
    def value_and_jac_onesided(self, x, param):
        cval, cjac = self.value_and_jac_fn(x, param)
        cval_onesided = jnp.hstack([cval-self.lb, self.ub-cval])
        cjac_onesided = jnp.vstack([cjac, -cjac])
        return cval_onesided, cjac_onesided
    
    def get_bounds_onesided(self):
        lb = jnp.zeros(self.dim*2)
        if self.is_eq_constr():
            ub = jnp.zeros(self.dim*2)
        else:
            ub = jnp.full(self.dim*2, jnp.inf)
        return lb, ub


def get_residual_weights(res_fns:Tuple[Residual]):
    weights = []
    for res_fn in res_fns:
        weights.append(res_fn.weights)
    return jnp.hstack(weights)

# def get_constr_bounds(constr_fns:Tuple[ConstrFn]):
#     lbs, ubs = [], []
#     for constr_fn in constr_fns:
#         zeros = jnp.zeros(constr_fn.dim)
#         if constr_fn.is_eq_constr():
#             ubs.append(zeros)
#         else:
#             ubs.append(np.full(constr_fn.dim, np.inf))
#         lbs.append(zeros)
#         lbs.append(zeros)
#     return jnp.hstack(lbs), jnp.hstack(ubs)

def eval_residual(x:Array, param:Param, res_fns:Tuple[Residual]):
    errors, jacs = [], []
    for res_fn in res_fns:
        error, jac = res_fn.error_and_jac_fn(x, param)
        errors.append(error)
        jacs.append(jac)
    return jnp.hstack(errors), jnp.vstack(jacs)

def eval_constr(x, state, constr_fns:Tuple[Constr]):
    ceq, ceq_jacs = [], []
    cineq, cineq_jacs = [], []
    for constr_fn in constr_fns:
        val, jac = constr_fn.value_and_jac_onesided(x, state)
        if constr_fn.is_eq_constr():
            ceq.append(val)
            ceq_jacs.append(jac)
        else:
            cineq.append(val)
            cineq_jacs.append(jac)
    return jnp.hstack(ceq), jnp.vstack(ceq_jacs), jnp.hstack(cineq), jnp.vstack(cineq_jacs)

In [3]:
from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

world = SDFWorld()
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7, 8], [0.04, 0.04])

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7007/static/


In [108]:
robot_dim = 7
horizon = 1
dim = robot_dim * horizon
dt = 0.1

to_mat = lambda x: x.reshape(-1, robot_dim)
to_vec = lambda x: x.flatten()
def to_posevec(pose:SE3):
    return jnp.hstack([pose.translation(), pose.rotation().log()])

# Kinematics
def get_rotvec_angvel_map(v):
    def skew(v):
        v1, v2, v3 = v
        return jnp.array([[0, -v3, v2],
                        [v3, 0., -v1],
                        [-v2, v1, 0.]])
    vmag = jnp.linalg.norm(v)
    vskew = skew(v)
    return jnp.eye(3) \
        - 1/2*skew(v) \
        + vskew@vskew * 1/vmag**2 * (1-vmag/2 * jnp.sin(vmag)/(1-jnp.cos(vmag)))

@jax.jit
def get_ee_fk_jac(q):
    # outputs ee_posevec and analytical jacobian
    fks = panda_model.fk_fn(q)
    p_ee = fks[-1][-3:]
    rotvec_ee = SO3(fks[-1][:4]).log()
    E = get_rotvec_angvel_map(rotvec_ee)
    jac = []
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        jac.append(jnp.hstack([lin_vel, rot_axis]))
    jac = jnp.array(jac).T
    jac = jac.at[3:, :].set(E @ jac[3:, :])
    return jnp.hstack([p_ee, rotvec_ee]), jac

In [8]:
frame = Frame(world.vis, "frame")
def make_pose():
    return SE3.from_rotation_and_translation(
        SO3(np.random.random(4)).normalize(),
        np.random.uniform([-0.3,-0.5,0.3],[0.6, 0.5, 0.8])
    )

In [109]:
@jax.jit
def vg_pose_error(x:Array, param:Param):
    ee, jac = get_ee_fk_jac(x)
    error = param.target - ee
    return error, -jac

@jax.jit
def vj_joint_limit(x:Array, param:Param):
    cval = x
    cjac = jnp.eye(robot_dim)
    return cval, cjac

In [13]:
pose_d = make_pose()
frame.set_pose(pose_d)
posevec_d = to_posevec(pose_d)

In [110]:
x = panda.neutral
param = Param(panda.neutral, posevec_d)

In [111]:
pose_error = Residual(vg_pose_error, jnp.array([1, 1, 1, 0.3, 0.3, 0.3]))
joint_limit = Constr(vj_joint_limit, panda.lb, panda.ub)
joint_limit_ = Constr(vj_joint_limit, panda.lb, panda.lb)
res_fns = [pose_error]
constr_fns = [joint_limit, joint_limit_]


weight = get_residual_weights(res_fns)
W = jnp.diag(weight)
dim_x = 7
dim_error = 6
dim_eq = 14
dim_ineq = 14

In [117]:
class Problem:
    def __init__(self, dim, res_fns:List[Residual], constr_fns:List[Constr]):
        self.res_fns = res_fns
        self.constr_fns = constr_fns
        self.dim_x = dim
        self.dim_error = 0
        for res_fn in res_fns:
            self.dim_error += len(res_fn.weights)
        self.dim_eq = 0
        self.dim_ineq = 0
        for constr_fn in constr_fns:
            if constr_fn.is_eq_constr():
                self.dim_eq += len(constr_fn.lb)
            else:
                self.dim_ineq += len(constr_fn.lb)

    
prob = Problem(7, res_fns, constr_fns)

In [116]:
prob.dim_error

6

In [118]:
error, jac = eval_residual(x, param, res_fns)
ceq, ceq_jac, cineq, cineq_jac = eval_constr(x, param, constr_fns)

grad = jac.T @ W @ error
hess = jac.T @ W @ jac

In [119]:
mu = 1.

In [126]:
dim_aug = prob.dim_x + prob.dim_eq *2 + prob.dim_ineq
P = jnp.zeros((dim_aug, dim_aug))
P = P.at[:dim_x, :dim_x].set(hess)
q = jnp.hstack([grad, jnp.full(prob.dim_eq*2+prob.dim_ineq, mu)])


In [None]:
A = jnp.hstack(
    [ceq_jac, 
     -jnp.eye(prob.dim_eq), 
     jnp.eye(prob.dim_eq), 
     jnp.zeros((prob.dim_ineq, prob.dim_ineq)) ])

In [129]:
ceq_jac

Array([[ 1.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  1.],
       [-1., -0., -0., -0., -0., -0., -0.],
       [-0., -1., -0., -0., -0., -0., -0.],
       [-0., -0., -1., -0., -0., -0., -0.],
       [-0., -0., -0., -1., -0., -0., -0.],
       [-0., -0., -0., -0., -1., -0., -0.],
       [-0., -0., -0., -0., -0., -1., -0.],
       [-0., -0., -0., -0., -0., -0., -1.]], dtype=float32)

In [124]:
dim_aug

28

In [None]:
jnp.vstack([ceq_jac, jnp.eye(dim_eq)])

In [105]:
dim_eq

14

In [104]:
ceq_jac

Array([[ 1.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  1.],
       [-1., -0., -0., -0., -0., -0., -0.],
       [-0., -1., -0., -0., -0., -0., -0.],
       [-0., -0., -1., -0., -0., -0., -0.],
       [-0., -0., -0., -1., -0., -0., -0.],
       [-0., -0., -0., -0., -1., -0., -0.],
       [-0., -0., -0., -0., -0., -1., -0.],
       [-0., -0., -0., -0., -0., -0., -1.]], dtype=float32)

In [43]:
val = error@jnp.diag(weight)@error
grad = jac.T @ error
hess = jac.T @ jnp.diag(weight) @ jac

In [None]:
# merit fn
merit = 
for constr_fn in constr_fns:


In [None]:
#f

#g

#h