In [1]:
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx
from mujoco_logger import SimLog
from robot_descriptions.skydio_x2_mj_description import MJCF_PATH
import jaxlie
from mujoco_sysid.mjx.convert import logchol2theta, theta2logchol
from mujoco_sysid.mjx.parameters import get_dynamic_parameters, set_dynamic_parameters

import matplotlib.pyplot as plt

key = jax.random.PRNGKey(0)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [2]:
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
mjx_model = mjx.put_model(model)

data = mujoco.MjData(model)
log = SimLog("quadrotor.json")

true_parameters = get_dynamic_parameters(model, 1)
true_parameters


Array([ 1.325   ,  0.      ,  0.      ,  0.0715  ,  0.04051 ,  0.      ,
        0.02927 , -0.0021  ,  0.      ,  0.060528], dtype=float32)

In [3]:
log_qpos = jnp.array(log.data("qpos"))
log_qvel = jnp.array(log.data("qvel"))
log_x = jnp.concatenate([log_qpos, log_qvel], axis=-1)
log_ctrl = jnp.array(log.data("ctrl"))

log_x.shape, log_ctrl.shape


((1001, 13), (1001, 4))

In [4]:
def diff_x(x1, x2):
    # qpos = [x, y, z, qw, qx, qy, qz]
    qpos1 = x1[:7]
    qpos2 = x2[:7]
    quat1 = qpos1[3:][jnp.array([3, 0, 1, 2])]
    quat2 = qpos2[3:][jnp.array([3, 0, 1, 2])]
    q1 = jaxlie.SO3.from_quaternion_xyzw(quat1)
    q2 = jaxlie.SO3.from_quaternion_xyzw(quat2)

    vel1 = x1[7:]
    vel2 = x2[7:]

    return jnp.concatenate(
        [
            qpos1[:3] - qpos2[:3],
            jaxlie.SO3.log(q1.inverse() @ q2),
            vel1 - vel2,
        ]
    )

In [5]:
from mujoco_sysid.mjx.loss import create_compute_loss
from mujoco_sysid.mjx.model import rollout, step

step = jax.jit(step)

# rollout = jax.jit(rollout)
# compute_loss = jax.jit(compute_loss)

In [6]:
horizon = 1  # horizon: we know all the measured data i - N to i

theta_estimate = get_dynamic_parameters(model, 1)
logchol = theta2logchol(theta_estimate)
logchol += jax.random.normal(key, logchol.shape) * 0.3
logchol_estimate = logchol

# estimate = logchol # we optimize on logcholesky space

logchol_estimate


Array([ 0.02907295, -2.0457456 , -1.8591964 , -3.9478104 , -0.13209113,
       -0.04564326, -0.13556199, -0.17725924,  0.21950667,  0.22415304],      dtype=float32)

In [7]:
def error_step(logchol_estimate, mjx_model, x, ctrl):
    theta_estimate = logchol2theta(logchol_estimate)
    x_next = step(theta_estimate, mjx_model, x, ctrl)
    return jnp.linalg.norm(diff_x(x_next, x), 2)


jax.value_and_grad(error_step)(logchol_estimate, mjx_model, log_x[0], log_ctrl[0])

(Array(0.21685232, dtype=float32),
 Array([-4.2706367e-01, -9.6768722e-02, -1.3716984e-01,  1.3185101e-02,
         9.8154247e-01, -1.3037524e-01,  6.5588957e-01,  1.1920929e-07,
        -2.3841858e-07,  0.0000000e+00], dtype=float32))

In [8]:
def error_rollout(logchol_estimate, mjx_model, x, ctrls, expected_x):
    theta_estimate = logchol2theta(logchol_estimate)
    x_hist = rollout(theta_estimate, mjx_model, x, ctrls)
    return jnp.linalg.norm(diff_x(x_hist[-1], expected_x), 2)


error_rollout_compiled = jax.jit(jax.value_and_grad(error_rollout))

error_rollout_compiled(logchol_estimate, mjx_model, log_x[0], log_ctrl[:3], log_x[3])

(Array(0.9799221, dtype=float32),
 Array([ 1.1851026e+00,  2.5155231e-01,  3.5932395e-01, -4.7413106e-03,
        -3.3092239e+00,  8.9998448e-01, -1.6065929e+00, -2.3841858e-07,
         0.0000000e+00,  0.0000000e+00], dtype=float32))

In [9]:
error_rollout_compiled(logchol_estimate, mjx_model, log_x[0], log_ctrl[:3], log_x[3])


(Array(0.9799221, dtype=float32),
 Array([ 1.1851026e+00,  2.5155231e-01,  3.5932395e-01, -4.7413106e-03,
        -3.3092239e+00,  8.9998448e-01, -1.6065929e+00, -2.3841858e-07,
         0.0000000e+00,  0.0000000e+00], dtype=float32))

In [10]:
error_rollout_compiled(logchol_estimate, mjx_model, log_x[0], log_ctrl[:100], log_x[100])


(Array(4.398834, dtype=float32),
 Array([-1.1211424e+01,  9.0352692e-02,  1.1775113e+00, -6.3438274e-02,
        -7.4513948e-01, -2.0146542e+00, -5.2254528e-01, -6.4373016e-06,
         6.6757202e-06,  5.7220459e-06], dtype=float32))

In [11]:
import optax

start_learning_rate = 1e-3
optimizer = optax.fromage(learning_rate=start_learning_rate)

parameters = logchol_estimate
opt_state = optimizer.init(parameters)

In [17]:
horizon = 100

for i in range(horizon, len(log)):
    x_start = log_x[i - horizon]
    ctrls = log_ctrl[i - horizon : i]

    loss, grad = error_rollout_compiled(parameters, mjx_model, x_start, ctrls, log_x[i])

    updates, opt_state = optimizer.update(grad, opt_state, parameters)
    parameters = optax.apply_updates(parameters, updates)

    # print(f"Step {i}, loss: {loss}")

    # if i % 10 == 0:
    #     print(f"Step {i}, loss: {loss}")
    #     print(logchol2theta(parameters))

In [18]:
print(f"Optimized theta parameters: {logchol2theta(parameters)}")
print(f"True theta parameters: {true_parameters}")

print(f"LogChol distance: {jnp.linalg.norm(parameters - theta2logchol(true_parameters), 2)}")

Optimized theta parameters: [ 1.3125805  -0.23245792  0.28787327  0.29397175  0.16218095  0.05153608
  0.14185844  0.05439965 -0.06343957  0.17137577]
True theta parameters: [ 1.325     0.        0.        0.0715    0.04051   0.        0.02927
 -0.0021    0.        0.060528]
LogChol distance: 0.4429807960987091
