In [1]:
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx
from mujoco_logger import SimLog
from robot_descriptions.skydio_x2_mj_description import MJCF_PATH

from mujoco_sysid.mjx.convert import logchol2theta, theta2logchol
from mujoco_sysid.mjx.parameters import get_dynamic_parameters, set_dynamic_parameters

import matplotlib.pyplot as plt

In [2]:
model = mujoco.MjModel.from_xml_path(MJCF_PATH)

# Turn off collisions: zero the contact type/affinity entries of every geom
# in one shot instead of visiting each geom id.
model.geom_contype[:] = 0
model.geom_conaffinity[:] = 0

data = mujoco.MjData(model)
log = SimLog("quadrotor.json")

# Dynamic parameters of body 1, read from the unmodified model description;
# kept as the reference ("true") parameter vector for the rest of the notebook.
true_parameters = get_dynamic_parameters(model, 1)
true_parameters

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Array([ 1.325   ,  0.      ,  0.      ,  0.0715  ,  0.04051 ,  0.      ,
        0.02927 , -0.0021  ,  0.      ,  0.060528], dtype=float32)

In [3]:
# Pull the three logged signals out of the simulation log as JAX arrays.
log_qpos, log_qvel, log_ctrl = (
    jnp.array(log.data(signal)) for signal in ("qpos", "qvel", "ctrl")
)

# Sanity check: all three should share the same leading (sample) dimension.
tuple(signal.shape for signal in (log_qpos, log_qvel, log_ctrl))

((1001, 7), (1001, 6), (1001, 4))

In [4]:
# Full state per sample: positions followed by velocities along the last axis.
log_x = jnp.concatenate((log_qpos, log_qvel), axis=-1)

log_x.shape

(1001, 13)

In [5]:
from mujoco_sysid.mjx.model import rollout

In [6]:
# Device-side (MJX) copy of the MuJoCo model used by the rollout.
mjx_model = mjx.put_model(model)

# Rollout inputs: the first logged state and the first 30 logged controls.
x, ctrls = log_x[0], log_ctrl[:30]

(x.shape, ctrls.shape)

((13,), (30, 4))

In [7]:
# rollout = jax.jit(rollout)

# rollout(mjx_model, x, ctrls, true_parameters)

In [8]:
# %timeit rollout(mjx_model, x, ctrls, true_parameters)

In [9]:
# Per-body masses straight from the MuJoCo model; entry 1 (1.325) matches the
# first element of `true_parameters` obtained above for body 1.
model.body_mass

array([0.   , 1.325])

In [10]:
# Jit-compile under a new name so the imported `rollout` stays un-jitted.
rollout2 = jax.jit(rollout)

# Warm-up call: keeps the one-off jit tracing/compilation cost out of the
# timing cells below; the trailing `;` suppresses the cell output.
rollout2(mjx_model, x, ctrls, true_parameters);

In [11]:
%timeit rollout2(mjx_model, x, ctrls, true_parameters)

3.79 ms ± 188 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
%timeit rollout2(mjx_model, x, ctrls, true_parameters)

3.67 ms ± 178 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
# Logged state one sample after the initial state `x = log_x[0]`
# (presumably the reference for the first rollout step — TODO confirm).
log_x[1]

Array([-8.6045888e-04,  8.2766024e-07,  9.9254496e-02,  9.9996793e-01,
        7.7065242e-06,  8.0089960e-03,  7.8795711e-07, -5.7232220e-02,
        5.5040164e-05, -4.9598008e-02,  1.0275191e-03,  1.0678567e+00,
        1.0522765e-04], dtype=float32)