In [1]:
import jax
import jax.numpy as jnp

In [2]:
import mujoco
from mujoco import mjx

In [3]:
host_model = mujoco.MjModel.from_xml_path("../../latch/env/finger/assets/finger.xml")

In [4]:
device_model = mjx.put_model(host_model)

In [5]:
def step_new_data(state, action):
    data = mjx.make_data(device_model)
    data = data.replace(qpos=state[:3])
    data = data.replace(qvel=state[3:])
    data = data.replace(ctrl=action)
    data = mjx.step(device_model, data)
    return jnp.concatenate([data.qpos, data.qvel])

In [6]:
def step_existing_data(state, action, data):
    data = data.replace(qpos=state[:3])
    data = data.replace(qvel=state[3:])
    data = data.replace(ctrl=action)
    data = mjx.step(device_model, data)
    return jnp.concatenate([data.qpos, data.qvel])

In [7]:
@jax.jit
def collect_traj_new_data(key):
    def scanf(
        state,
        key,
    ):
        action = jax.random.normal(key, shape=(2,))
        next_state = step_new_data(state, action)
        return next_state, next_state

    rng, key = jax.random.split(key)
    rngs = jax.random.split(rng, 100)

    initial_state = jnp.zeros(6)

    return jax.lax.scan(scanf, initial_state, rngs)

@jax.jit
def collect_traj_existing_data(key):
    data = mjx.make_data(device_model)

    def scanf(
        state,
        key,
    ):
        action = jax.random.normal(key, shape=(2,))
        next_state = step_existing_data(state, action, data)
        return next_state, next_state

    rng, key = jax.random.split(key)
    rngs = jax.random.split(rng, 100)

    initial_state = jnp.zeros(6)

    return jax.lax.scan(scanf, initial_state, rngs)

# Now benchmark the two

In [8]:
# Benchmark
import timeit
import numpy as np

# Compile
_ = collect_traj_new_data(jax.random.PRNGKey(0))
_ = collect_traj_existing_data(jax.random.PRNGKey(0))

# Run
n = 32768 * 4

def new_bench():
    key = jax.random.PRNGKey(0)
    rng, key = jax.random.split(key)
    rngs = jax.random.split(rng, n)

    result = jax.vmap(collect_traj_new_data)(rngs)

    return result

def existing_bench():
    key = jax.random.PRNGKey(0)
    rng, key = jax.random.split(key)
    rngs = jax.random.split(rng, n)

    result = jax.vmap(collect_traj_existing_data)(rngs)

    return result

print("New data")
print(timeit.timeit(new_bench, number=1))
print("Existing data")
print(timeit.timeit(existing_bench, number=1))


New data
16.364803396998468
Existing data
15.631491056999948


In [9]:
new_bench()

(Array([[ 3.9901916e-02, -8.9636352e-03,  8.4578723e-02, -6.5824330e-01,
         -1.8807429e+00,  2.6832926e+00],
        [-3.3659067e-02,  2.4751777e-02, -7.5314417e-02, -1.9174249e-01,
          1.3330293e+00,  7.1752272e+00],
        [ 4.6820652e-02,  5.2667715e-02, -6.4203888e-02,  1.7702432e-01,
         -3.7799972e-01, -3.6240232e+00],
        ...,
        [ 1.9170301e-02,  2.0148201e-02,  2.3057230e-02,  1.9279389e-01,
         -2.0671678e-01, -3.6449184e+00],
        [-1.8309964e-03, -3.3346865e-02, -2.4301572e-01, -8.8702962e-03,
         -1.3158377e+00,  5.0260949e+00],
        [-9.1674760e-02, -5.3285724e-03,  6.9978476e-02, -4.9782768e-02,
         -1.8344868e+00,  1.5397359e+00]], dtype=float32),
 Array([[[ 0.0000000e+00, -3.1847646e-03,  4.8278603e-03,  0.0000000e+00,
          -1.5923822e+00,  2.4139299e+00],
         [ 0.0000000e+00, -8.9949630e-03,  1.3375195e-02,  0.0000000e+00,
          -2.9050994e+00,  4.2736669e+00],
         [ 0.0000000e+00, -1.5725939e-02,  1.6