In [1]:
import jax
import jax.numpy as jnp
import jax.typing as jpt
import mujoco as mj
import mujoco.mjx as mjx
from mujoco.mjx._src import scan
from mujoco.mjx._src.math import transform_motion
from mujoco.mjx._src.types import DisableBit
from robot_descriptions.z1_mj_description import MJCF_PATH

jnp.set_printoptions(precision=5, suppress=True, linewidth=500)

key = jax.random.PRNGKey(0)

2024-07-04 10:15:46.022376: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
mjmodel = mj.MjModel.from_xml_path(MJCF_PATH)
mjdata = mj.MjData(mjmodel)

# alter the model so it becomes mjx compatible
mjmodel.dof_frictionloss = 0
mjmodel.opt.integrator = 1

mjxmodel = mjx.put_model(mjmodel)
mjxdata = mjx.put_data(mjmodel, mjdata)

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


In [3]:
q, v, dv = jax.random.normal(key, (3, mjmodel.nq))

mj_inverse + mj_rnePostConstraint

In [4]:
mjdata.qpos = q
mjdata.qvel = v
mjdata.qacc = dv

mj.mj_inverse(mjmodel, mjdata)
mjdata.cacc

array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]])

In [5]:
mj.mj_rnePostConstraint(mjmodel, mjdata)

mjdata.cacc

array([[ 0.     ,  0.     ,  0.     , -0.     , -0.     ,  9.81   ],
       [ 0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  9.81   ],
       [ 0.     ,  0.     , -0.24564,  0.04451,  0.03984,  9.81   ],
       [ 0.26177,  0.17445, -0.24564,  0.05323,  0.02676,  9.88573],
       [ 2.22556, -0.4742 , -0.24564, -0.02487, -0.14671,  9.75662],
       [ 3.04146, -1.35822, -0.24564, -0.14381,  0.16344, 10.17372],
       [ 3.97377, -0.57563,  3.77315,  0.23265,  0.7747 , 10.4696 ],
       [ 4.84202,  0.20971,  3.80627,  0.14318,  0.83105, 10.06915]])

In [6]:
mjdata.cdof, mjdata.cdof_dot

(array([[ 0.     ,  0.     ,  1.     , -0.1812 , -0.1622 ,  0.     ],
        [ 0.7655 ,  0.64344,  0.     ,  0.03215, -0.03825,  0.24308],
        [ 0.7655 ,  0.64344,  0.     ,  0.06266, -0.07455, -0.1037 ],
        [ 0.7655 ,  0.64344,  0.     , -0.0581 ,  0.06912, -0.22839],
        [-0.63989,  0.76128, -0.10492, -0.14118, -0.11763,  0.00753],
        [-0.75908, -0.60487,  0.24068,  0.14133, -0.08628,  0.22892]]),
 array([[ 0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ],
        [-0.61642,  0.73335,  0.     ,  0.03664,  0.0308 ,  0.00726],
        [-0.61642,  0.73335,  0.     ,  0.07954,  0.05036,  0.00898],
        [-0.61642,  0.73335,  0.     ,  0.09224, -0.24418, -0.34267],
        [-0.51814, -0.86424, -3.1108 , -0.24798, -0.43955, -0.23407],
        [-0.05795, -0.44941, -1.31224, -0.30063,  0.1761 ,  0.61546]]))

mjx_rne + mjx_rnePostConstraint

In [7]:
mjxdata = mjxdata.replace(qpos=q, qvel=v, qacc=dv)


def mjx_invPosition(m: mjx.Model, d: mjx.Data) -> mjx.Data:
    d = mjx.kinematics(m, d)
    d = mjx.com_pos(m, d)
    d = mjx.camlight(m, d)
    # flex is missing
    # tendon is missing

    d = mjx.crb(m, d)
    d = mjx.factor_m(m, d)
    d = mjx.collision(m, d)
    d = mjx.make_constraint(m, d)
    d = mjx.transmission(m, d)

    return d


def mjx_invVelocity(m: mjx.Model, d: mjx.Data) -> mjx.Data:
    d = mjx.fwd_velocity(m, d)

    return d


def mjx_invConstraint(m: mjx.Model, d: mjx.Data) -> mjx.Data:
    # return data if there are no constraints
    if d.nefc == 0:
        return d

    # jar = Jac*qacc - aref
    jar = d.efc_J @ d.qacc - d.efc_aref

In [None]:
def mjx_inverse(m: mjx.Model, d: mjx.Data) -> mjx.Data:
    d = mjx_invPosition(m, d)
    d = mjx_invVelocity(m, d)

    # acceleration dependent
    mjx

    d = mjx.rne(m, d)

    return d


mjxdata = mjx.rne(mjxmodel, mjxdata)

In [8]:
def com_acc(m: mjx.Model, d: mjx.Data) -> jpt.ArrayLike:
    # forward scan over tree: accumulate link center of mass acceleration
    def cacc_fn(cacc, cdof_dot, qvel):
        if cacc is None:
            if m.opt.disableflags & DisableBit.GRAVITY:
                cacc = jnp.zeros((6,))
            else:
                cacc = jnp.concatenate((jnp.zeros((3,)), -m.opt.gravity))

        cacc += jnp.sum(jax.vmap(jnp.multiply)(cdof_dot, qvel), axis=0)

        return cacc

    return scan.body_tree(m, cacc_fn, "vv", "b", d.cdof_dot, d.qvel)


com_acc(mjxmodel, mjxdata)

Array([[0.  , 0.  , 0.  , 0.  , 0.  , 9.81],
       [0.  , 0.  , 0.  , 0.  , 0.  , 9.81],
       [0.  , 0.  , 0.  , 0.  , 0.  , 9.81],
       [0.  , 0.  , 0.  , 0.  , 0.  , 9.81],
       [0.  , 0.  , 0.  , 0.  , 0.  , 9.81],
       [0.  , 0.  , 0.  , 0.  , 0.  , 9.81],
       [0.  , 0.  , 0.  , 0.  , 0.  , 9.81],
       [0.  , 0.  , 0.  , 0.  , 0.  , 9.81]], dtype=float32)

In [9]:
mjxdata.cdof, mjxdata.cdof_dot

(Array([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], dtype=float32),
 Array([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], dtype=float32))

In [10]:
def mjx_rnePostConstraint(m: mjx.Model, d: mjx.Data):
    nbody = m.nbody

    all_cacc = jnp.zeros((nbody, 6))

    # clear cacc, set world acceleration to -gravity
    if not m.opt.disableflags & DisableBit.GRAVITY:
        all_cacc = all_cacc.at[0, 3:].set(-m.opt.gravity)

    # FIXME: assumption that xfrc_applied is zero
    # FIXME: assumption that contacts are zero
    # FIXME: assumption that connect and weld constraints are zero

    # forward pass over bodies: compute acc
    for j in range(nbody):
        bda = m.body_dofadr[j]

        # cacc = cacc_parent + cdofdot * qvel + cdof * qacc
        cacc_j = all_cacc[m.body_parentid[j]] + d.cdof_dot[bda] * d.qvel[bda] + d.cdof[bda] * d.qacc[bda]
        all_cacc = all_cacc.at[j].set(cacc_j)

    return all_cacc


mjx_rnePostConstraint(mjxmodel, mjxdata)

Array([[ 0.  ,  0.  ,  0.  , -0.  , -0.  ,  9.81],
       [ 0.  ,  0.  ,  0.  , -0.  , -0.  ,  9.81],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  9.81],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  9.81],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  9.81],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  9.81],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  9.81],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  9.81]], dtype=float32)