# Simple RL code to teach the model to stand

## Setup jax enviroment and Load Model 

In [1]:
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt


import os
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".60"

# Optionally, force JAX to preallocate memory.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Setup environment variable for Nvidia GPU acceleration
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
)

backend = 'gpu'
# backend = 'METAL'
# backend = 'cpu'

import jax
# Enable compliation catch
jax.config.update("jax_compilation_cache_dir", "./jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# Debug Nan
# jax.config.update("jax_debug_nans", True)

from jax import numpy as jp
# More legible printing from numpy.
jp.set_printoptions(precision=4, suppress=True, linewidth=100)

import mujoco
import mujoco.mjx as mjx
from mujoco.mjx._src import scan
from mujoco.mjx._src import types

# More legible printing from numpy.
np.set_printoptions(precision=4, suppress=True, linewidth=100)

from IPython.display import clear_output
clear_output()

device = jax.devices(backend=backend)[0]

model_path = '/home/bugman/Currentwork/biomujoco_converter/converted/mjc/Gait2354/gait2354_cvt1.xml'

# Single step
mjx_step = jax.jit(mjx.step, backend=backend)



# mjx_multiple_steps = jax.jit(multiple_steps, backend=backend, )

In [2]:
from mujoco.mjx._src.biomtu.acceleration_mtu import acceleration_mtu

mj_model = mujoco.MjModel.from_xml_path(model_path)
mjx_model = mjx.put_model(mj_model,device=device)

# Disable tendon
opt = mjx_model.opt.replace(disableflags = mjx_model.opt.disableflags |mujoco.mjtDisableBit.mjDSBL_PASSIVE)
mjx_model = mjx_model.replace(opt=opt)

mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)

def print_all():
    print(mjx_model.biomtu_adr)
    print(mjx_model.mtu_wrap_objid)
    print(mjx_model.mtu_wrap_type)
    print(mjx_model.biomtu_fiso)
    print(mjx_model.biomtu_vmax)
    print(mjx_model.biomtu_ofl)
    print(mjx_model.biomtu_opa)
    print(mjx_model.biomtu_mass)
    print("-------Data--------")
    print("qpos:", mjx_data.qpos)
    print("mtu l:", mjx_data.biomtu.l)
    print("tendon l:", mjx_data.biomtu.tendon_l)
    print("fiber l :", mjx_data.biomtu.fiber_l)
    print("Muscle Bce:", mjx_data.biomtu.B_ce)
    print("Muscle vm:", mjx_data.biomtu.m)
    print("Fiber acc:", mjx_data.biomtu.fiber_acc)
    print("Fiber v:", mjx_data.biomtu.fiber_v)
    print("Biomtu h:", mjx_data.biomtu.h)
    print(mjx_data.biomtu.v)
    print(mjx_data.biomtu.h)  # The constant high of the muscle.
    print(mjx_data.biomtu.pennation_angle)
    print(mjx_data.biomtu.origin_body_id)
    print(mjx_data.biomtu.insertion_body_id)
    print("mtu act:", mjx_data.biomtu.act)
    # print(mjx_data.biomtu.j)
    print(mjx_data.qfrc_biomtu)
    print(mj_model.key_time)
    print(mj_model.key_qpos)
    print(mj_model.key_qvel)

print_all()

--------------------
BUGMAN: muscle name-glut_med1_r
--------------------
BUGMAN: muscle name-glut_med2_r
--------------------
BUGMAN: muscle name-glut_med3_r
--------------------
BUGMAN: muscle name-bifemlh_r
--------------------
BUGMAN: muscle name-bifemsh_r
--------------------
BUGMAN: muscle name-sar_r
--------------------
BUGMAN: muscle name-add_mag2_r
--------------------
BUGMAN: muscle name-tfl_r
--------------------
BUGMAN: muscle name-pect_r
--------------------
BUGMAN: muscle name-grac_r
--------------------
BUGMAN: muscle name-glut_max1_r
--------------------
BUGMAN: muscle name-glut_max2_r
--------------------
BUGMAN: muscle name-glut_max3_r
--------------------
BUGMAN: muscle name-iliacus_r
--------------------
BUGMAN: muscle name-psoas_r
--------------------
BUGMAN: muscle name-quad_fem_r
--------------------
BUGMAN: muscle name-gem_r
--------------------
BUGMAN: muscle name-peri_r
--------------------
BUGMAN: muscle name-rect_fem_r
--------------------
BUGMAN: muscle nam

In [3]:
print(mjx_model.nbiomtu)
print(mjx_model.nq)
print(mjx_data.qpos)
print(mjx_data.qvel)

54
27
[ 0.      0.9501  0.      0.0006  0.     -0.     -0.0013 -0.     -0.0009 -0.0037 -0.3957  0.0023
 -0.0199 -0.0032  0.0228 -0.0013 -0.     -0.0009 -0.0037 -0.3957  0.0023 -0.0199 -0.0032  0.0228
 -0.0007 -0.      0.    ]
[ 0.0149  0.0271  0.      0.3064  0.     -0.     -0.6693 -0.0194 -0.4365 -0.0074  0.0056  1.1327
 -9.9627 -1.625  11.377  -0.6693 -0.0194 -0.4365 -0.0074  0.0056  1.1327 -9.9628 -1.625  11.377
 -0.3256 -0.      0.    ]


## Neural Network

In [4]:
import jax
import jax.numpy as jp
import flax
import flax.linen as nn
import optax

class Controller_NN(nn.Module):
    # It will output the mean and variance for each muscle's activation
    nbiomtu = mjx_model.nbiomtu
    def setup(self):
        # Features means the output dimension# Single step

        self.linear1 = nn.Dense(features=400)
        self.linear2 = nn.Dense(features=400)
        self.linear3 = nn.Dense(features=400)
        self.linear4 = nn.Dense(features=400)
        # The last layer will output the mean and logstd
        self.linear5 = nn.Dense(features=self.nbiomtu*2)
        
    
    def __call__(self, x, key):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        x = nn.relu(x)
        x = self.linear3(x)
        x = nn.relu(x)
        x = self.linear4(x)
        x = nn.relu(x)
        x = self.linear5(x)
        # The last layer of the neural requires samping
        mean = x[0:self.nbiomtu]
        logstd = x[self.nbiomtu:self.nbiomtu*2]
        std = jp.exp(logstd)
        samples = jp.clip(jax.random.normal(key)*std*0.3 + mean, 0 ,1)
        
        return samples, mean, logstd


# Test the neural network
control_model = Controller_NN()

print(control_model)
# Init the model
key = jax.random.PRNGKey(66)
sub_keys = jax.random.split(key,1)
# The second parameter is the dommy input
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])
# print(params)
print(control_model.apply( params,jp.ones(mjx_model.nq*2), sub_keys[0]))
jit_nn_apply = jax.jit(lambda params,states,key : control_model.apply(params,states,key))

Controller_NN()
(Array([0.3537, 0.4582, 0.    , 0.    , 0.3568, 0.    , 0.    , 0.    , 0.    , 0.0299, 0.1213,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.0104, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.4784, 0.    , 0.2003, 0.    , 0.    ,
       0.    , 0.    , 0.2286, 0.    , 0.    , 0.    , 0.    , 0.2364, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    ],      dtype=float32), Array([ 0.5507,  0.7117,  0.1663, -0.386 ,  0.572 ,  0.0428,  0.151 , -0.341 , -0.1592,  0.1468,
        0.3125, -0.506 , -0.098 , -0.1812, -0.2472, -0.1027, -0.2657, -0.0098,  0.1604,  0.1802,
       -0.4265, -0.2269, -0.2294, -0.0505,  0.1361,  0.0374,  0.0641, -0.0512,  0.6791, -0.3706,
        0.4092, -0.1719, -0.2406, -0.022 ,  0.1271,  0.4203,  0.0284,  0.0943, -0.4435,  0.1295,
        0.4245,  0.0444, -0.1952,  0.121 , -0.08  , -0.3598, -0.0973, -0.0134,  0.0238, -0.2089,
       -0

## Combine Neural Net and Simulation into one Jax function

In [5]:
# Multiple steps
def step_fn(carry, _):
    data, model = carry
    new_data = mjx.step(model, data)
    new_carry = (new_data, model)
    return new_carry, _

def multiple_steps(model, data):
    init_carry = (data, model)
    y, _ = jax.lax.scan(step_fn, init_carry, None, length=10)
    new_data = y[0]
    return new_data

def nn_mjx_one_step(nn_params, model, data, key):
    states = jp.concatenate([data.qpos, data.qvel])
    act = jit_nn_apply(nn_params, states, key)[0]
    # Generate the next key
    new_key = jax.random.split(key,1)[0]
    mtu = data.biomtu
    mtu = mtu.replace(act = act)
    data = data.replace(biomtu = mtu)
    new_data = mjx.step(model, data)
    return new_data, new_key

@jax.jit
def nn_mjx_multi_steps(nn_params, model, data, key):
    def nn_step_fn(carry, _):
        data, key = carry
        new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
        new_carry = (new_data, new_key)
        # head_hight = new_data.xpos[2,2]
        # jax.debug.print("Head Height {0}",head_hight)
        return new_carry, _
    init_carry = (data, key)
    y, _ = jax.lax.scan(nn_step_fn, init_carry, None, length=10)
    new_data = y[0]
    new_key = y[1]
    return new_data, new_key

## Testing control model with neural networks

In [6]:
import mujoco.viewer
import time

previous_frame_time = time.time()
i = 0
key = jax.random.key(387)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=mj_data.xfrc_applied)
        mjx_data = mjx_data.replace(qpos=mj_data.qpos, qvel=mj_data.qvel, time = mj_data.time)
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': mj_model.opt.gravity,
            'opt.tolerance': mj_model.opt.tolerance,
            'opt.ls_tolerance': mj_model.opt.ls_tolerance,
            'opt.timestep': mj_model.opt.timestep,
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.sensordata)
        # print(len(mjx_data.qvel))
        viewer.sync()

Time between frames: 5.010230779647827 seconds
[-0.1015 -0.0011  1.6299]
Time between frames: 2.173220634460449 seconds
[-0.1047 -0.0019  1.6246]
Time between frames: 0.04962801933288574 seconds
[-0.1129 -0.0027  1.6188]
Time between frames: 0.04738759994506836 seconds
[-0.1225 -0.004   1.6153]
Time between frames: 0.033091068267822266 seconds
[-0.134  -0.0056  1.6137]
Time between frames: 0.030276775360107422 seconds
[-0.1495 -0.0068  1.611 ]
Time between frames: 0.03369450569152832 seconds
[-0.1683 -0.0085  1.6057]
Time between frames: 0.03244209289550781 seconds
[-0.1885 -0.0112  1.5974]
Time between frames: 0.031516313552856445 seconds
[-0.2088 -0.0145  1.5856]
Time between frames: 0.03200340270996094 seconds
[-0.2306 -0.0187  1.5699]
Time between frames: 0.03071451187133789 seconds
[-0.2542 -0.0234  1.5505]
Time between frames: 0.0272674560546875 seconds
[-0.2793 -0.0287  1.5263]
Time between frames: 0.03175783157348633 seconds
[-0.3057 -0.0343  1.4976]
Time between frames: 0.0288

## Batched Random Init Model

In [7]:
def random_init(data, model, rng: jax.Array):
    nbiomtu = model.nq
    init_qpos = data.qpos
    init_qvel = data.qvel
    new_rng, rng1, rng2 = jax.random.split(rng, 3)
    # Qpos_1 is the vertical position
    random_qpos = init_qpos + jax.random.uniform(rng1, [nbiomtu], minval=-1, maxval=1)*0.01
    random_qvel = init_qvel + jax.random.uniform(rng2, [nbiomtu], minval=-1, maxval=1)*0.01
    newdata = data.replace(qpos=random_qpos)
    newdata = newdata.replace(qvel=random_qvel)
    newdata = mjx.forward(mjx_model, newdata)
    # print('data:',data.qpos, data.qvel)
    # Calculate equilibrum
    # newdata = acceleration_mtu.calc_equilibrium(mjx_model, newdata)
    # newdata = mjx_step(mjx_model, newdata)
    return newdata, new_rng

vrandom_init = jax.jit(jax.vmap(random_init, in_axes=(None, None, 0),out_axes=(0,0)))

## Reward Function

In [8]:
def reward_n_step(nn_params, data, key):
    repeat_n = 150
    def nn_step_fn(carry, _):
        data, key = carry
        new_data, new_key = nn_mjx_one_step(nn_params, mjx_model, data, key)
        new_carry = (new_data, new_key)
        head_hight = new_data.sensordata[2]
        # jax.debug.print("Head Height {0}",head_hight)
        return new_carry, (head_hight-1.63)**2
    
    def decay_sum_scan(x, decay):
        def f(sxtm1, xt):
            b = xt + decay * sxtm1
            return b, b
        return jax.lax.scan(f, jp.zeros(x.shape[1:]), x)[1]
    
    _, head_heights = jax.lax.scan(nn_step_fn, (data, key), None, repeat_n)
    # new_data, new_key = carry
    reward = decay_sum_scan(head_heights, 0.95)[repeat_n -1]
    return reward
    
reward_grad = jax.jit(jax.grad(reward_n_step))


def batch_reward(nn_params, batched_data, keys):
    return jp.mean(jax.vmap(reward_n_step, (None, 0, 0))(nn_params, batched_data, keys))

batch_reward_grad = jax.jit(jax.grad((batch_reward)))
jit_batch_reward = jax.jit(batch_reward)

In [9]:
# Generate batched data
batch_size = 40
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
r = jit_batch_reward(params, mjx_data_batch, rngs)
print(r)
print("Calculating reward grad")
g = batch_reward_grad(params, mjx_data_batch, rngs)
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
print("Calculating reward grad")
g = batch_reward_grad(params, mjx_data_batch, rngs)

48.429073
Calculating reward grad
Calculating reward grad


In [10]:
for i in range(10):
    g = batch_reward_grad(params, mjx_data_batch, rngs)

In [11]:
for i in range(10):
    r = jit_batch_reward(params, mjx_data_batch, rngs)

In [12]:
print(r)
print(g)

48.415497
{'params': {'linear1': {'bias': Array([ 0.0104, -0.0004, -0.0017, -0.0086, -0.0006,  0.0004, -0.0004, -0.0083, -0.0038, -0.0022,
       -0.0017,  0.0026,  0.0026, -0.0053,  0.0012,  0.005 , -0.0004,  0.0045, -0.0045,  0.0035,
       -0.0001,  0.0075, -0.0027,  0.0039, -0.0052, -0.0162,  0.0002,  0.002 , -0.0023,  0.0109,
        0.0084, -0.0004, -0.0002,  0.0022, -0.0008,  0.0153, -0.0006,  0.0065, -0.0045,  0.0002,
       -0.0002, -0.0028,  0.0047, -0.0039, -0.0022, -0.0129, -0.0002,  0.0019, -0.0004, -0.0027,
       -0.0018,  0.001 , -0.0056,  0.0016, -0.0013,  0.0041,  0.004 ,  0.0004,  0.0009,  0.0019,
        0.0027, -0.002 ,  0.0002, -0.0046, -0.0005,  0.0021,  0.0036, -0.0001,  0.0012, -0.0038,
       -0.0003,  0.0112,  0.0025, -0.0114, -0.0048,  0.0002,  0.0149,  0.0011,  0.0016, -0.0004,
        0.0009, -0.0011, -0.0001, -0.0059,  0.    ,  0.004 ,  0.0033, -0.0079, -0.0049,  0.0041,
       -0.001 , -0.002 , -0.0045, -0.0022, -0.0052,  0.009 , -0.0065, -0.0062,  0.004

## Train the NN

In [13]:
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)

In [14]:
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])


In [None]:
tx = optax.adam(learning_rate=0.0003)
opt_state = tx.init(params)

for i in range(400):
    # generate random mjx_data
    mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
    # print(rngs[0])
    r = jit_batch_reward(params, mjx_data_batch, rngs)
    print(r)
    g = batch_reward_grad(params, mjx_data_batch, rngs)
    updates, opt_stats = tx.update(g, opt_state)
    params = optax.apply_updates(params, updates)
    print("params updated")

48.442734
params updated
48.33086
params updated
47.283268
params updated
46.56735
params updated
46.55094
params updated
46.27307
params updated
46.08646
params updated
45.9296
params updated
46.04953
params updated
46.099174
params updated
46.111816
params updated
46.089466
params updated
45.963337
params updated
45.83628
params updated
45.78471
params updated
45.588955
params updated
45.268383
params updated
45.020596
params updated
44.522923
params updated
44.552345
params updated
44.73208
params updated
44.28784
params updated
44.04806
params updated
43.834034
params updated
44.036953
params updated
43.728428
params updated
44.30366
params updated
43.553883
params updated
43.892494
params updated
43.327023
params updated
43.122494
params updated
43.26449
params updated
43.29893
params updated
43.327442
params updated
43.032936
params updated
43.264893
params updated
42.97
params updated
43.053814
params updated
43.052914
params updated
43.205505
params updated
43.08517
params upda

## Test Train

In [20]:
jit_nn_mjx_one_step = jax.jit(nn_mjx_one_step)


In [26]:
import mujoco.viewer
import time


mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)

previous_frame_time = time.time()
i = 0
key = jax.random.key(seed)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        i += 1
        if(i ==4):
            time.sleep(0)
            # pass
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=mj_data.xfrc_applied)
        mjx_data = mjx_data.replace(qpos=mj_data.qpos, qvel=mj_data.qvel, time = mj_data.time)
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': mj_model.opt.gravity,
            'opt.tolerance': mj_model.opt.tolerance,
            'opt.ls_tolerance': mj_model.opt.ls_tolerance,
            'opt.timestep': mj_model.opt.timestep,
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = jit_nn_mjx_one_step(params, mjx_model, mjx_data, key)
        # mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.sensordata[2])
        # print(len(mjx_data.qvel))
        viewer.sync()

Time between frames: 0.16742444038391113 seconds
1.6315
Time between frames: 0.012678384780883789 seconds
1.6314864
Time between frames: 0.013202905654907227 seconds
1.6314447
Time between frames: 0.012907743453979492 seconds
1.6313658
Time between frames: 0.01244211196899414 seconds
1.6312408
Time between frames: 0.010648965835571289 seconds
1.6310611
Time between frames: 0.012441873550415039 seconds
1.6308208
Time between frames: 0.013338565826416016 seconds
1.6305243
Time between frames: 0.012471437454223633 seconds
1.6301804
Time between frames: 0.01353311538696289 seconds
1.6297884
Time between frames: 0.014764785766601562 seconds
1.6293757
Time between frames: 0.013852596282958984 seconds
1.6289992
Time between frames: 0.012926101684570312 seconds
1.6286018
Time between frames: 0.010102272033691406 seconds
1.6282041
Time between frames: 0.011306524276733398 seconds
1.6277759
Time between frames: 0.012139558792114258 seconds
1.627289
Time between frames: 0.0093841552734375 seconds

## Test Gradient for activation

In [14]:
n = mjx_model.nbiomtu
act = jp.ones(n)*0.1

In [None]:
# jax.config.update("jax_debug_nans", True)
for i in range(10):
    g = jit_goal_grad(act, mjx_data, mjx_model)
    # g = goal_grad(act, mjx_data, mjx_model)
print(g)

In [None]:
v_act = jp.ones((50,n))*0.3
for i in range(5):
    v_g = jit_v_goal_grad(v_act,mjx_data,mjx_model)
print(v_g)

In [None]:
for i in range(10):
    reward = jit_goal(act, mjx_data, mjx_model)
print(reward)

In [None]:
act = jp.ones(n)*0.0
mtu = mjx_data.biomtu
mtu = mtu.replace(act = act)
test_d = mjx_data
test_d = test_d.replace(biomtu = mtu)
for i in range(100):
    test_d = mjx_step(mjx_model, test_d)
print(test_d.xipos)

In [None]:
print(test_d.biomtu.f_se)
print(test_d.biomtu.f_ce)
print(test_d.biomtu.l)
print(test_d.biomtu.tendon_l - mjx_model.biomtu_tendon_slack_l)
print(test_d.biomtu.fiber_l - mjx_model.biomtu_ofl)
print(test_d.biomtu.fiber_l)
print(test_d.biomtu.act)

In [None]:
print(mjx_data.xpos)

In [None]:
mjx_data.sensordata