In [1]:
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx
from mujoco_logger import SimLog
from robot_descriptions.skydio_x2_mj_description import MJCF_PATH

from mujoco_sysid.mjx.convert import logchol2theta, theta2logchol
from mujoco_sysid.mjx.parameters import get_dynamic_parameters, set_dynamic_parameters


### What is the idea behind this notebook?

- We have recorded data from the quadrotor running in the simulator.
- We have access to the full state (`qpos`, `qvel`, `qacc`) and the control inputs (`ctrl`).
- We want to use this data in MJX to identify the dynamic parameters of the quadrotor (mass, inertia, …) from the recorded trajectory.


In [2]:
# Reference MuJoCo model (Skydio X2 description) and a data buffer for it.
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
data = mujoco.MjData(model)

# Recorded quadrotor trajectory; provides qpos / qvel / qacc / ctrl samples.
log = SimLog("quadrotor.json")

# Body whose dynamic parameters are inspected throughout the notebook.
# NOTE(review): index 1 is assumed to be the quadrotor's main body — confirm against the MJCF.
QUADROTOR_BODY_ID = 1

# Ground-truth dynamic parameters of that body (index 0 is the mass).
true_parameters = get_dynamic_parameters(model, QUADROTOR_BODY_ID)
true_parameters

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Array([ 1.325   ,  0.      ,  0.      ,  0.0715  ,  0.04051 ,  0.      ,
        0.02927 , -0.0021  ,  0.      ,  0.060528], dtype=float32)

In [3]:
mjx_model = mjx.put_model(model)


def smart_step(acc, vel, pos, ctrl, parameters, body_id=1):
    """Advance the MJX model by one step using the given dynamic parameters.

    The dynamic parameters of body ``body_id`` in the module-level ``mjx_model``
    are replaced with ``parameters``. A fresh ``mjx.Data`` is then seeded with
    the given state and control, and a single ``mjx.step`` is taken.

    Args:
        acc: Generalized accelerations (``qacc``) of the starting state.
        vel: Generalized velocities (``qvel``) of the starting state.
        pos: Generalized positions (``qpos``) of the starting state.
        ctrl: Control inputs applied during the step.
        parameters: Dynamic-parameter vector for ``body_id``, in the layout
            returned by ``get_dynamic_parameters``.
        body_id: Index of the body whose parameters are overridden. Defaults to
            1, the body used elsewhere in this notebook. Must be a Python int:
            it is a static argument of the jitted function.

    Returns:
        Tuple ``(qpos, qvel)`` after one simulation step.
    """
    new_model = set_dynamic_parameters(mjx_model, body_id, parameters)

    # Start from a default Data for the modified model, then overwrite the state.
    # NOTE(review): mjx.step runs forward dynamics, which recomputes qacc, so
    # seeding it here probably does not affect the result — confirm.
    mjx_data = mjx.make_data(new_model).replace(qacc=acc, qvel=vel, qpos=pos, ctrl=ctrl)
    mjx_data = mjx.step(new_model, mjx_data)
    return mjx_data.qpos, mjx_data.qvel


# body_id selects which body gets modified, so it must stay a static Python int under jit.
smart_step = jax.jit(smart_step, static_argnames=("body_id",))

In [4]:
# Initial condition for the one-step check: the first logged sample of each
# signal, converted to JAX arrays (order: qpos, qvel, qacc, ctrl).
pos, vel, acc, ctrl = (
    jnp.array(log.data(signal)[0]) for signal in ("qpos", "qvel", "qacc", "ctrl")
)

print(pos, vel, acc, ctrl)

[-2.8813668e-04  2.7725861e-07  9.9750474e-02  9.9999642e-01
  2.5690015e-06  2.6697947e-03  2.6182667e-07] [-2.8813669e-02  2.7725859e-05 -2.4952721e-02  5.1380089e-04
  5.3395963e-01  5.2365394e-05] [-2.8813670e+00  2.7725860e-03 -2.4952722e+00  5.1380090e-02
  5.3395962e+01  5.2365395e-03] [ 4.8408070e+00  4.8512077e+00 -5.7623768e-09 -5.7623768e-09]


In [6]:
# One-step prediction check: step the MJX model once from the first logged
# sample and compare against the second logged qpos.
# Set MASS_OFFSET (kg) to a non-zero value, e.g. 1.0, to see how a wrong mass
# (index 0 of the parameter vector) changes the prediction error. A separate
# array is built so `true_parameters` is never overwritten and the cell can be
# re-run safely.
MASS_OFFSET = 0.0

mass = true_parameters[0]
candidate_parameters = true_parameters.at[0].set(mass + MASS_OFFSET)

next_pos, next_vel = smart_step(acc, vel, pos, ctrl, candidate_parameters)

logged_next_pos = log.data("qpos")[1]
print("pos", logged_next_pos)
print("next_pos", next_pos)

print("diff l2", jnp.linalg.norm(logged_next_pos - next_pos))

pos [-8.60458879e-04  8.27660234e-07  9.92544927e-02  9.99967927e-01
  7.70652416e-06  8.00899567e-03  7.87957120e-07]
next_pos [-6.9424306e-04  1.5458909e-05  9.9253602e-02  9.9997908e-01
  1.4327142e-04  6.4689084e-03 -7.6393530e-07]
diff l2 0.0015550618
