In [1]:
import numpy as np
import jax
from jax import Array
import jax.numpy as jnp
import jax_dataclasses as jdc
from typing import *
from jax.experimental.sparse import BCOO

In [None]:
from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

world = SDFWorld()
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
panda.reduce_dim([7, 8], [0.04, 0.04])

In [78]:
@jdc.pytree_dataclass
class Params:
    q0: Array
    target: Array
    
@jdc.pytree_dataclass
class Var:
    name: str
    dim: int

    def freeze(self, key):
        return FrozenVar(key, self.dim)

@jdc.pytree_dataclass
class FrozenVar:
    key: int
    dim: int

@jdc.pytree_dataclass
class Feature:
    name: str
    dim: int
    vars: List[Var]
    value_and_jac_fn: Callable
    weight: Array

    @property
    def x_dim(self):
        return np.sum([var.dim for var in self.vars])

    def freeze(self, key, x_coord, jac_indices): #frozen_var_map:Dict[str, FrozenVar]
        #frozen_vars = [frozen_var_map[var.name] for var in self.vars]
        return FrozenFeature(
            key, self.dim, self.x_dim, x_coord, jac_indices,
            jax.tree_util.Partial(self.value_and_jac_fn),
            self.weight
        )
    
@jdc.pytree_dataclass
class FrozenFeature:
    key: int
    dim: int
    x_dim: int
    x_coord: Array
    jac_indices: Array
    value_and_jac_fn: Callable
    weight: Array

In [89]:
@jdc.pytree_dataclass
class FactorGraph:
    x_dim:int
    feat_dim:int
    vars: Dict[str, FrozenVar]
    features: Dict[str, FrozenFeature]
    #coords_map: Dict[str, Array]
    # feature_var_map: Tuple
    # x_coords_map: List[Array]
    # x_offsets: List[int]
    # jac_indices: List[Array]
    # weight_mat: BCOO

class FactorGraphBuilder:
    def __init__(self):
        self.vars: Dict[str,Var] = dict()
        self.features: Dict[str, Feature] = dict()

    @property
    def x_dim(self):
        return np.sum([var.dim for var in self.vars.values()])
    @property
    def feat_dim(self):
        return np.sum([feat.dim for feat in self.features.values()])
    @property
    def weight_vec(self):
        weights = []
        for name, feature in self.features.items():
            self.jac_indices[name] = self.get_jac_indices(feature)
            weights.append(feature.weight)
        return np.hstack(weights)
    
    def add_var(self, var:Var):
        self.vars[var.name] = var

    def add_feature(self, feature:Feature):
        self.features[feature.name] = feature

    def get_jac_indices(self, feature):
        indices_list = []
        for var in feature.vars:
            indices = np.indices((feature.dim, var.dim)).reshape(2, -1).T
            indices[:,0] += self.feat_offset[feature.name]
            indices[:,1] += self.x_offsets[var.name]
            indices_list.append(indices)
        return np.vstack(indices_list)
    
    def build(self):
        self.x_coords_map = dict()
        self.x_offsets = dict()
        self.feat_offset = dict()
        self.jac_indices = dict()

        frozen_vars = {}
        for idx, var in enumerate(self.vars.values()):
            frozen_vars[idx] = var.freeze(idx)
        
        #calculate x indices(coordinate)
        offset = 0
        for name, var in self.vars.items():
            self.x_coords_map[name] = np.arange(var.dim) + offset
            self.x_offsets[name] = offset
            offset += var.dim

        #calculate feature indices
        offset = 0
        for name, feat in self.features.items():
            self.feat_offset[name] = offset
            offset += feat.dim

        frozen_features = {}
        for idx, feat in enumerate(self.features.values()):
            var_names = [var.name for var in feat.vars]
            x_coords = [self.x_coords_map[name] for name in var_names]
            jac_indices = self.get_jac_indices(feat)
            frozen_features[idx] =\
                feat.freeze(
                    idx, 
                    np.hstack(x_coords), 
                    jac_indices)

        return FactorGraph(
            self.x_dim, 
            self.feat_dim,
            frozen_vars, 
            frozen_features)

        # #calculate x_coords, offset
        # offset = 0
        # for name, var in self.vars.items():
        #     self.x_coords[name] = np.arange(var.dim) + offset
        #     self.x_offsets[name] = offset
        #     offset += var.dim
            
        # offset = 0
        # for name, feature in self.features.items():
        #     self.feat_offset[name] = offset
        #     offset += feature.dim

        # #freeze var, features
        # self.frozen_vars = []
        # self.var_idx_map = {}
        # for i, var in enumerate(self.vars.values()):
        #     self.frozen_vars.append(var.freeze(i))
        #     self.var_idx_map[var.name] = i
        # self.frozen_features = []
        # for i, feature in enumerate(self.features.values()):
        #     self.frozen_features.append(
        #         feature.freeze(i, self.var_idx_map))
        # eye_indices = np.tile(np.arange(self.feat_dim), 2).reshape(2,-1).T
        # weight_mat = BCOO((self.weight_vec, eye_indices), shape=(self.feat_dim,self.feat_dim))
        # var_indices = {var.name:i for i, var in enumerate(self.vars.values())}
        # feature_var_map = {i:list(var_indices[var.name] for var in feature.vars)
        #                    for i, feature in enumerate(self.features.values())}
        # x_coords_map = []
        # for feature in self.features.values():
        #     var_names = [var.name for var in feature.vars]
        #     coords = [self.x_coords[name] for name in var_names]
        #     x_coords_map.append(np.hstack(coords))
        #     self.jac_indices[feature.name] = self.get_jac_indices(feature)

        # return FactorGraph(
        #     self.x_dim,
        #     self.feat_dim,
        #     self.frozen_vars,
        #     self.frozen_features,
        #     feature_var_map,
        #     x_coords_map,
        #     self.x_offsets,
        #     list(self.jac_indices.values()),
        #     weight_mat
        # )

In [11]:
@jax.jit
def get_feature_and_jac(x, param, graph:FactorGraph): #, param:Params
    vals = []
    data, indices = [], []
    for i, feature in enumerate(graph.features):
        xin = x[graph.x_coords_map[i]]
        val, jac = feature.value_and_jac_fn(xin, param)
        vals.append(val)
        data.append(jac.flatten())
        indices.append(graph.jac_indices[i])
    jac_data = jnp.hstack(data)
    jac_indices = jnp.vstack(indices)
    val = jnp.hstack(vals)
    return val, jac_data, jac_indices


In [55]:
# Kinematics
def get_rotvec_angvel_map(v):
    def skew(v):
        v1, v2, v3 = v
        return jnp.array([[0, -v3, v2],
                        [v3, 0., -v1],
                        [-v2, v1, 0.]])
    vmag = jnp.linalg.norm(v)
    vskew = skew(v)
    return jnp.eye(3) \
        - 1/2*skew(v) \
        + vskew@vskew * 1/vmag**2 * (1-vmag/2 * jnp.sin(vmag)/(1-jnp.cos(vmag)))

@jax.jit
def get_ee_fk_jac(q):
    # outputs ee_posevec and analytical jacobian
    fks = panda_model.fk_fn(q)
    p_ee = fks[-1][-3:]
    rotvec_ee = SO3(fks[-1][:4]).log()
    E = get_rotvec_angvel_map(rotvec_ee)
    jac = []
    for posevec in fks[1:8]:
        p_frame = posevec[-3:]
        rot_axis = SE3(posevec).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        jac.append(jnp.hstack([lin_vel, rot_axis]))
    jac = jnp.array(jac).T
    jac = jac.at[3:, :].set(E @ jac[3:, :])
    return jnp.hstack([p_ee, rotvec_ee]), jac

def to_posevec(pose:SE3):
    return jnp.hstack([pose.translation(), pose.rotation().log()])

In [154]:
# feature functions
pose_weight = np.array([1, 1, 1, 0.3, 0.3, 0.3])
@jax.jit
def vj_pose_error(x, param:Params):
    ee, robot_jac = get_ee_fk_jac(x)
    residual = param.target - ee
    jac = - robot_jac
    return residual, jac

In [155]:
builder = FactorGraphBuilder()
var_q1 = Var("q1", 7)
var_q2 = Var("q2", 7)
#var_pose = Var("pose", 6)
feat_pose_err1 = Feature(
    "pose_err1", 6, [var_q1], vj_pose_error, pose_weight)
feat_pose_err2 = Feature(
    "pose_err2", 6, [var_q2], vj_pose_error, pose_weight)
builder.add_var(var_q1)
builder.add_var(var_q2)
builder.add_feature(feat_pose_err1)
builder.add_feature(feat_pose_err2)

In [156]:
fg = builder.build()

In [163]:
@jax.jit
def feature_and_jac(x, param:Params, graph:FactorGraph):
    features = []
    jacs = []
    jac_indices = []
    for feature in graph.features.values():
        xin = x[feature.x_coord]
        val, jac = feature.value_and_jac_fn(xin, param)
        features.append(val)
        jacs.append(jac.flatten())
        jac_indices.append(feature.jac_indices.T)
    return jnp.hstack(features), jnp.hstack(jacs), jnp.hstack(jac_indices)

In [173]:
jax.tree_map(lambda x, param: a+1 , {1:1, 2:2})

{1: 2, 2: 3}

In [179]:
xins = {}
params = {}
for key, feature in fg.features.items():
    xins[key] = x[feature.x_coord]
    params[key] = param


In [180]:
jax.tree_map(Feature.value_and_jac_fn, )

{0: Params(q0=array([ 0.    ,  0.    ,  0.    , -1.5708,  0.    ,  1.8675,  0.    ]), target=Array([-0.23464473,  0.35054806,  0.41998622,  1.4455944 ,  1.6146585 ,
         1.8347615 ], dtype=float32)),
 1: Params(q0=array([ 0.    ,  0.    ,  0.    , -1.5708,  0.    ,  1.8675,  0.    ]), target=Array([-0.23464473,  0.35054806,  0.41998622,  1.4455944 ,  1.6146585 ,
         1.8347615 ], dtype=float32))}

In [178]:
xins

{0: Array([ 0.    ,  0.    ,  0.    , -1.5708,  0.    ,  1.8675,  0.    ],      dtype=float32),
 1: Array([ 0.    ,  0.    ,  0.    , -1.5708,  0.    ,  1.8675,  0.    ],      dtype=float32)}

In [175]:
x

Array([ 0.    ,  0.    ,  0.    , -1.5708,  0.    ,  1.8675,  0.    ,
        0.    ,  0.    ,  0.    , -1.5708,  0.    ,  1.8675,  0.    ],      dtype=float32)

In [171]:
fg.features

{0: FrozenFeature(key=0, dim=6, x_dim=7, x_coord=array([0, 1, 2, 3, 4, 5, 6]), jac_indices=array([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [0, 5],
        [0, 6],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [1, 4],
        [1, 5],
        [1, 6],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3],
        [2, 4],
        [2, 5],
        [2, 6],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3],
        [3, 4],
        [3, 5],
        [3, 6],
        [4, 0],
        [4, 1],
        [4, 2],
        [4, 3],
        [4, 4],
        [4, 5],
        [4, 6],
        [5, 0],
        [5, 1],
        [5, 2],
        [5, 3],
        [5, 4],
        [5, 5],
        [5, 6]]), value_and_jac_fn=Partial(<PjitFunction of <function vj_pose_error at 0x7fdc705d7f70>>), weight=array([1. , 1. , 1. , 0.3, 0.3, 0.3])),
 1: FrozenFeature(key=1, dim=6, x_dim=7, x_coord=array([ 7,  8,  9, 10, 11, 12, 13]), jac_indices=arr

In [None]:
jax.tree_map()

In [158]:
x = jnp.hstack([panda.neutral, panda.neutral])
param = Params(panda.neutral, to_posevec(make_pose()))

In [170]:
%timeit feature_and_jac(x, param, fg)

255 µs ± 6.96 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [160]:
%timeit feature_and_jac(x, param, fg)

260 µs ± 3.12 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
builder.add_feature(
    Feature()
)

In [33]:
frame = Frame(world.vis, "frame")
def make_pose():
    return SE3.from_rotation_and_translation(
        SO3(np.random.random(4)).normalize(),
        np.random.uniform([-0.3,-0.5,0.3],[0.6, 0.5, 0.8])
    )