# Simple RL code to teach the model to stand

## Setup jax enviroment and Load Model 

In [6]:
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt


import os
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".60"

# Optionally, force JAX to preallocate memory.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Setup environment variable for Nvidia GPU acceleration
os.environ['XLA_FLAGS'] = (
    # '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    # '--xla_gpu_enable_async_collectives=true '
    # '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
    # '--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=32'
)

backend = 'gpu'
# backend = 'METAL'
# backend = 'cpu'

import jax
jax.config.update('jax_platform_name', 'gpu')
os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"
# Enable compliation catch
os.environ["JAX_COMPILATION_CACHE_DIR"] = "./jax_cache"
jax.config.update("jax_compilation_cache_dir", "./jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 2)
# jax.config.update("jax_explain_cache_misses", True)

from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("./jax_cache")
# Debug Nan
# jax.config.update("jax_debug_nans", True)

from jax import numpy as jp
# More legible printing from numpy.
jp.set_printoptions(precision=4, suppress=True, linewidth=100)

import mujoco
import mujoco.mjx as mjx
from mujoco.mjx._src import scan
from mujoco.mjx._src import types

# More legible printing from numpy.
np.set_printoptions(precision=4, suppress=True, linewidth=100)

from IPython.display import clear_output
clear_output()

device = jax.devices(backend=backend)[0]

model_path = '/home/bugman/Currentwork/biomujoco_converter/converted/mjc/Gait2354/gait2354_cvt1.xml'

# Single step
mjx_step = jax.jit(mjx.step, backend=backend)



# mjx_multiple_steps = jax.jit(multiple_steps, backend=backend, )

In [7]:
from mujoco.mjx._src.biomtu import acceleration_mtu

mj_model = mujoco.MjModel.from_xml_path(model_path)
mjx_model = mjx.put_model(mj_model,device=device)

# Disable tendon
opt = mjx_model.opt.replace(disableflags = mjx_model.opt.disableflags |mujoco.mjtDisableBit.mjDSBL_PASSIVE)
mjx_model = mjx_model.replace(opt=opt)

mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
for i in range(10):
    mjx_data = mjx_step(mjx_model, mjx_data)

def print_all():
    print(mjx_model.biomtu_adr)
    print(mjx_model.mtu_wrap_objid)
    print(mjx_model.mtu_wrap_type)
    print(mjx_model.biomtu_fiso)
    print(mjx_model.biomtu_vmax)
    print(mjx_model.biomtu_ofl)
    print(mjx_model.biomtu_opa)
    print(mjx_model.biomtu_mass)
    print("-------Data--------")
    print("qpos:", mjx_data.qpos)
    print("mtu l:", mjx_data.biomtu.l)
    print("tendon l:", mjx_data.biomtu.tendon_l)
    print("fiber l :", mjx_data.biomtu.fiber_l)
    print("Muscle Bce:", mjx_data.biomtu.B_ce)
    print("Muscle vm:", mjx_data.biomtu.m)
    print("Fiber acc:", mjx_data.biomtu.fiber_acc)
    print("Fiber v:", mjx_data.biomtu.fiber_v)
    print("Biomtu h:", mjx_data.biomtu.h)
    print(mjx_data.biomtu.v)
    print(mjx_data.biomtu.h)  # The constant high of the muscle.
    print(mjx_data.biomtu.pennation_angle)
    print(mjx_data.biomtu.origin_body_id)
    print(mjx_data.biomtu.insertion_body_id)
    print("mtu act:", mjx_data.biomtu.act)
    # print(mjx_data.biomtu.j)
    print(mjx_data.qfrc_biomtu)
    print(mj_model.key_time)
    print(mj_model.key_qpos)
    print(mj_model.key_qvel)

print_all()

[  0   2   4   6   9  12  17  19  23  25  29  33  37  41  46  51  53  55  58  61  65  68  70  74
  77  79  81  83  86  89  94  96 100 102 106 110 114 118 123 128 130 132 135 138 142 145 147 151
 154 156 158 160 162 164]
[  0  58   1  59   2  60   3  86  87  61  88  89   4  62  90  91  92   5  63   6  64  65  93   7
  66   8  94  95  96   9  10  67  68  11  12  69  70  13  14  71  72  15  16  17  73  74  18  19
  20  75  76  21  77  22  78  23  24  79  25  80  97  81  82  83  98  84  85 104  99 105 100 101
 106 107 102 103 108  26 109  27 110  28 111  29 137 138 112 139 140  30 113 141 142 143  31 114
  32 115 116 144  33 117  34 145 146 147  35  36 118 119  37  38 120 121  39  40 122 123  41  42
  43 124 125  44  45  46 126 127  47 128  48 129  49  50 130  51 131 148 132 133 134 149 135 136
 155 150 156 151 152 157 158 153 154 159  52 161  53 162  54 163  55 164  56 165  57 166]
[3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 

In [8]:
print(mjx_model.nbiomtu)
print(mjx_model.nq)
print(mjx_data.qpos)
print(mjx_data.qvel)

54
27
[ 0.0001  0.9481 -0.     -0.0019 -0.0002  0.0001 -0.0014 -0.0024 -0.0302 -0.0037 -0.3956  0.0119
 -0.1066  0.0592  0.2745 -0.0013 -0.0028 -0.0297 -0.0037 -0.3956  0.0118 -0.1065  0.0589  0.2741
  0.0025  0.0002 -0.0001]
[-0.0091 -0.2265 -0.0004 -0.51   -0.016   0.0073  0.6328 -0.1861 -2.0535  0.0016 -0.0005 -0.2853
  2.1858  6.1782  8.8636  0.6363 -0.2228 -2.0195  0.0017 -0.0005 -0.2925  2.1751  6.145   8.8511
  0.5844  0.018  -0.007 ]


## Neural Network

In [9]:
import jax
import jax.numpy as jp
import flax
import flax.linen as nn
import optax

class Controller_NN(nn.Module):
    # It will output the mean and variance for each muscle's activation
    
    def setup(self):
        # Features means the output dimension# Single step
        self.nbiomtu = 54
        self.linear1 = nn.Dense(features=400)
        self.linear2 = nn.Dense(features=400)
        self.linear3 = nn.Dense(features=400)
        self.linear4 = nn.Dense(features=400)
        # The last layer will output the mean and logstd
        self.linear5 = nn.Dense(features=self.nbiomtu*2)
        
    
    def __call__(self, x, key):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        x = nn.relu(x)
        x = self.linear3(x)
        x = nn.relu(x)
        x = self.linear4(x)
        x = nn.relu(x)
        x = self.linear5(x)
        # The last layer of the neural requires samping
        mean = x[0:self.nbiomtu]
        logstd = x[self.nbiomtu:self.nbiomtu*2]
        std = jp.exp(logstd)
        samples = jp.clip(jax.random.normal(key)*std*0.3 + mean, 0 ,1)
        
        return samples, mean, logstd


# Test the neural network
control_model = Controller_NN()

print(control_model)
# Init the model
key = jax.random.key(66)
sub_keys = jax.random.split(key,1)
# The second parameter is the dommy input
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])
# print(params)
print(control_model.apply( params, jp.ones(mjx_model.nq*2), sub_keys[0]))
jit_nn_apply = jax.jit(lambda params,states,key : control_model.apply(params,states,key))

Controller_NN()
(Array([0.3537, 0.4582, 0.    , 0.    , 0.3568, 0.    , 0.    , 0.    , 0.    , 0.0299, 0.1213,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.0104, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.4784, 0.    , 0.2003, 0.    , 0.    ,
       0.    , 0.    , 0.2286, 0.    , 0.    , 0.    , 0.    , 0.2364, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    ],      dtype=float32), Array([ 0.5507,  0.7117,  0.1663, -0.386 ,  0.572 ,  0.0428,  0.151 , -0.341 , -0.1592,  0.1468,
        0.3125, -0.506 , -0.098 , -0.1812, -0.2472, -0.1027, -0.2657, -0.0098,  0.1604,  0.1802,
       -0.4265, -0.2269, -0.2294, -0.0505,  0.1361,  0.0374,  0.0641, -0.0512,  0.6791, -0.3706,
        0.4092, -0.1719, -0.2406, -0.022 ,  0.1271,  0.4203,  0.0284,  0.0943, -0.4435,  0.1295,
        0.4245,  0.0444, -0.1952,  0.121 , -0.08  , -0.3598, -0.0973, -0.0134,  0.0238, -0.2089,
       -0

## Combine Neural Net and Simulation into one Jax function

In [10]:
# Multiple steps
def step_fn(carry, _):
    data, model = carry
    new_data = mjx.step(model, data)
    new_carry = (new_data, model)
    return new_carry, _

def multiple_steps(model, data):
    init_carry = (data, model)
    y, _ = jax.lax.scan(step_fn, init_carry, None, length=10)
    new_data = y[0]
    return new_data

def nn_mjx_one_step(nn_params, model, data, key):
    states = jp.concatenate([data.qpos, data.qvel])
    act = jit_nn_apply(nn_params, states, key)[0]
    # Generate the next key
    new_key = jax.random.split(key,1)[0]
    # mtu = data.biomtu
    # mtu = mtu.replace(act = act)
    # data = data.replace(biomtu = mtu)
    new_data = mjx.step(model, data)
    return new_data, new_key

@jax.jit
def nn_mjx_multi_steps(nn_params, model, data, key):
    def nn_step_fn(carry, _):
        data, key = carry
        new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
        new_carry = (new_data, new_key)
        # head_hight = new_data.xpos[2,2]
        # jax.debug.print("Head Height {0}",head_hight)
        return new_carry, _
    init_carry = (data, key)
    y, _ = jax.lax.scan(nn_step_fn, init_carry, None, length=10)
    new_data = y[0]
    new_key = y[1]
    return new_data, new_key

## Testing control model with neural networks

In [11]:
import mujoco.viewer
import time

previous_frame_time = time.time()
i = 0
key = jax.random.key(334)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=jp.array(mj_data.xfrc_applied, dtype=jp.float32))
        mjx_data = mjx_data.replace(
            qpos= jp.array(mj_data.qpos, dtype=jp.float32),
            qvel= jp.array(mj_data.qvel, dtype=jp.float32),
            time = jp.array(mj_data.time, dtype=jp.float32))
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': jp.array(mj_model.opt.gravity, dtype=jp.float32),
            'opt.tolerance': jp.array(mj_model.opt.tolerance, dtype=jp.float32),
            'opt.ls_tolerance': jp.array(mj_model.opt.ls_tolerance, dtype=jp.float32),
            'opt.timestep': jp.array(mj_model.opt.timestep, dtype=jp.float32),
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.sensordata)
        # print(len(mjx_data.qvel))
        viewer.sync()

Time between frames: 6.814907550811768 seconds
[-0.1008 -0.0003  1.6299]


## Batched Random Init Model

In [12]:
def random_init(data, model, rng: jax.Array):
    nbiomtu = model.nq
    init_qpos = data.qpos
    init_qvel = data.qvel
    new_rng, rng1, rng2 = jax.random.split(rng, 3)
    # Qpos_1 is the vertical position
    random_qpos = init_qpos + jax.random.uniform(rng1, [nbiomtu], minval=jp.array(-1.0, dtype=jp.float32), maxval=jp.array(1.0, dtype=jp.float32))*0.01
    random_qvel = init_qvel + jax.random.uniform(rng2, [nbiomtu], minval=jp.array(-1.0, dtype=jp.float32), maxval=jp.array(1.0, dtype=jp.float32))*0.01
    newdata = data.replace(qpos=random_qpos)
    newdata = newdata.replace(qvel=random_qvel)
    newdata = mjx.forward(mjx_model, newdata)
    # print('data:',data.qpos, data.qvel)
    # Calculate equilibrum
    # newdata = acceleration_mtu.calc_equilibrium(mjx_model, newdata)
    # newdata = mjx_step(mjx_model, newdata)
    return newdata, new_rng



## Test Init

In [13]:
batch_size = 20
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
vrandom_init_lower = jax.jit(jax.vmap(random_init, in_axes=(None, None, 0), out_axes=0)).lower(mjx_data, mjx_model, rngs)
vrandom_init = vrandom_init_lower.compile()


In [14]:
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
print(mjx_data_batch.biomtu)

Biomtu(act=Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), l=Array([[0.1209, 0.1308, 0.1092, ..., 0.1997, 0.2741, 0.2748],
       [0.1214, 0.131 , 0.1091, ..., 0.1997, 0.2743, 0.2748],
       [0.1213, 0.1312, 0.1095, ..., 0.1994, 0.2754, 0.2747],
       ...,
       [0.1217, 0.1318, 0.1101, ..., 0.1995, 0.2739, 0.2741],
       [0.121 , 0.1311, 0.1096, ..., 0.2001, 0.2743, 0.2755],
       [0.1215, 0.1311, 0.1092, ..., 0.1999, 0.2737, 0.2742]], dtype=float32), v=Array([[ 0.0141,  0.004 , -0.0062, ...,  0.0276,  0.0339,  0.0324],
       [ 0.0153,  0.0048, -0.0058, ...,  0.029 ,  0.0327,  0.0341],
       [ 0.0143,  0.0039, -0.0064, ...,  0.0288,  0.0318,  0.0335],
       ...,
       [ 0.0143,  0.0044, -0.0056, ...,  0.0278,  0.0333,  0.0335],
       [ 0.0144,  0.0041, -0.0062, ...,  0.

## Reward Function
get reward at the same time calculate the mjx_data

In [15]:
# from functools import partial

def nn_mjx_one_step(nn_params, model, data, key):
    states = jp.concatenate([data.qpos, data.qvel])
    act = jit_nn_apply(nn_params, states, key)[0]
    # Generate the next key
    new_key = jax.random.split(key,1)[0]
    data = data.replace(biomtu = data.biomtu.replace(act=act))
    new_data = mjx.step(model, data)
    return new_data, new_key

def nn_step_fn(carry, _):
    data, key, nn_params, model = carry
    new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
    new_carry = (new_data, new_key, nn_params, model)
    head_hight = new_data.sensordata[2]
    # jax.debug.print("Head Height {0}",head_hight)
    # return new_carry, (head_hight-1.63)**2
    return new_carry, -head_hight

def decay_sum_scan(x, decay):
    def f(sxtm1, xt):
        b = xt + decay * sxtm1
        return b, b
    return jax.lax.scan(f, jp.zeros(x.shape[1:]), x)[1]

# This function return the loss as well as the new data
def reward_n_step(nn_params, model, data, key):
    repeat_n = 150   # Simulate for 0.1s
    carry, head_heights = jax.lax.scan(nn_step_fn, (data, key, nn_params, model), None, repeat_n)
    loss = decay_sum_scan(head_heights, 0.6)[repeat_n -1]
    new_data = carry[0]
    return loss, new_data

def batch_reward(nn_params, batched_data, keys, model):
    batch_loss, batch_new_data = jax.vmap(reward_n_step, (None, None, 0, 0))(nn_params, model, batched_data, keys)
    return jp.mean(batch_loss), batch_new_data 


batch_size = 20
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)

print("start lower")

jit_batch_reward_lower = jax.jit(batch_reward).lower(params, mjx_data_batch, rngs, mjx_model)
jit_batch_reward_grad_lower = jax.jit(jax.grad(batch_reward,has_aux=True)).lower(params, mjx_data_batch, rngs, mjx_model)


start lower


In [16]:
print("start compiling ")

jit_batch_reward = jit_batch_reward_lower.compile()
jit_batch_reward_grad = jit_batch_reward_grad_lower.compile()

start compiling 


In [17]:
# Generate batched data
batch_size = 20
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)

r,d = jit_batch_reward(params, mjx_data_batch, rngs, mjx_model)
print(r)
print("Calculating reward grad")
# pr = jax.make_jaxpr(jit_batch_reward_grad)(params, mjx_data_batch, rngs, mjx_model)
g,d = jit_batch_reward_grad(params, mjx_data_batch, rngs, mjx_model)
print(g)

-3.5399919
Calculating reward grad
{'params': {'linear1': {'bias': Array([ 0.0019, -0.0002,  0.0131,  0.0005,  0.0093,  0.0225, -0.0033, -0.0056, -0.0013,  0.0026,
        0.0008, -0.0082,  0.001 ,  0.0255,  0.01  , -0.0062, -0.0014, -0.0185, -0.0082,  0.0029,
       -0.0229, -0.0008,  0.0058, -0.0008,  0.0063, -0.0051, -0.0167, -0.001 , -0.0108, -0.0347,
       -0.0056, -0.0029, -0.0084,  0.0011, -0.0132,  0.002 ,  0.0012,  0.0026,  0.0011, -0.0028,
        0.0042,  0.0004,  0.0033, -0.0036, -0.0042,  0.0008, -0.0002, -0.0077, -0.0077,  0.0078,
        0.0033,  0.0016,  0.0021, -0.0122,  0.0012,  0.0031,  0.0002, -0.0039, -0.0116,  0.012 ,
        0.004 ,  0.0015, -0.0059,  0.0122,  0.0041, -0.0023,  0.0029,  0.0136,  0.0024,  0.0003,
       -0.003 , -0.0088, -0.0015,  0.0085, -0.0022,  0.0311, -0.0149, -0.0026, -0.0172,  0.016 ,
        0.0111,  0.0049, -0.017 , -0.0213, -0.0055,  0.0207, -0.003 , -0.0016, -0.0019, -0.0028,
        0.0002,  0.0085,  0.0002,  0.0066,  0.0072,  0.0103,

## Reset part of the data based on the condition

## Train the NN

In [18]:
seed = 2024
mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]
# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)

In [19]:
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])


In [None]:
tx = optax.adam(learning_rate=0.0003)
opt_state = tx.init(params)

for i in range(400):
    print(i)
    # generate random mjx_data
    mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
    # print(rngs[0])
    r,d = jit_batch_reward(params, mjx_data_batch, rngs, mjx_model)
    print(r)
    g,d = jit_batch_reward_grad(params, mjx_data_batch, rngs, mjx_model)
    updates, opt_stats = tx.update(g, opt_state)
    params = optax.apply_updates(params, updates)
    print("params updated")

-3.5281446
params updated
-4.0209203
params updated
-3.7786527
params updated
-3.942047
params updated
-3.935949
params updated
-3.9823956
params updated
-4.0152564
params updated
-4.01274
params updated
-4.0288286
params updated
-4.0533924
params updated
-3.9708288
params updated
-4.075292
params updated
-4.002211
params updated
-4.062696
params updated
-4.071758
params updated
-4.01374
params updated
-4.0753207
params updated
-4.0449924
params updated
-4.0877795
params updated
-4.088748
params updated
-4.0683517
params updated
-4.1258464
params updated
-4.1218753
params updated
-4.1411395
params updated
-4.130592
params updated
-4.1707473
params updated
-4.182951
params updated
-4.191001
params updated
-4.19615
params updated
-4.1989636
params updated
-4.214529
params updated
-4.218633
params updated
-4.21264
params updated
-4.234961
params updated
-4.2385287
params updated
-4.2108254
params updated
-4.242253
params updated
-4.219259
params updated
-4.246497
params updated
-4.24369
p

KeyboardInterrupt: 

## Test Train

In [21]:
jit_nn_mjx_one_step = jax.jit(nn_mjx_one_step)


In [25]:
import mujoco.viewer
import time


mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)

previous_frame_time = time.time()
i = 0
key = jax.random.key(seed)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        i += 1
        if(i ==4):
            time.sleep(5)
            # pass
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=jp.array(mj_data.xfrc_applied, dtype=jp.float32))
        
        mjx_data = mjx_data.replace(
            qpos= jp.array(mj_data.qpos, dtype=jp.float32),
            qvel= jp.array(mj_data.qvel, dtype=jp.float32),
            time = jp.array(mj_data.time, dtype=jp.float32))
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': jp.array(mj_model.opt.gravity, dtype=jp.float32),
            'opt.tolerance': jp.array(mj_model.opt.tolerance, dtype=jp.float32),
            'opt.ls_tolerance': jp.array(mj_model.opt.ls_tolerance, dtype=jp.float32),
            'opt.timestep': jp.array(mj_model.opt.timestep, dtype=jp.float32),
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = jit_nn_mjx_one_step(params, mjx_model, mjx_data, key)
        # mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.sensordata[2])
        # print(len(mjx_data.qvel))
        viewer.sync()

Time between frames: 0.24607539176940918 seconds
1.6315
Time between frames: 0.016235828399658203 seconds
1.6314892
Time between frames: 0.017040729522705078 seconds
1.6314808
Time between frames: 5.015024423599243 seconds
1.6314725
Time between frames: 0.014901161193847656 seconds
1.6314596
Time between frames: 0.014194726943969727 seconds
1.631412
Time between frames: 0.014748573303222656 seconds
1.6313057
Time between frames: 0.014504194259643555 seconds
1.6311425
Time between frames: 0.014104366302490234 seconds
1.6309278
Time between frames: 0.015198707580566406 seconds
1.6306683
Time between frames: 0.0168154239654541 seconds
1.6303751
Time between frames: 0.01676177978515625 seconds
1.6300918
Time between frames: 0.016523122787475586 seconds
1.6298199
Time between frames: 0.016750335693359375 seconds
1.6295555
Time between frames: 0.016635417938232422 seconds
1.629287
Time between frames: 0.016657114028930664 seconds
1.629
Time between frames: 0.01719975471496582 seconds
1.62867

## Test Gradient for activation

In [None]:
n = mjx_model.nbiomtu
act = jp.ones(n)*0.1

In [None]:
# jax.config.update("jax_debug_nans", True)
for i in range(10):
    g = jit_goal_grad(act, mjx_data, mjx_model)
    # g = goal_grad(act, mjx_data, mjx_model)
print(g)

In [None]:
v_act = jp.ones((50,n))*0.3
for i in range(5):
    v_g = jit_v_goal_grad(v_act,mjx_data,mjx_model)
print(v_g)

In [None]:
for i in range(10):
    reward = jit_goal(act, mjx_data, mjx_model)
print(reward)

In [None]:
act = jp.ones(n)*0.0
mtu = mjx_data.biomtu
mtu = mtu.replace(act = act)
test_d = mjx_data
test_d = test_d.replace(biomtu = mtu)
for i in range(100):
    test_d = mjx_step(mjx_model, test_d)
print(test_d.xipos)

In [None]:
print(test_d.biomtu.f_se)
print(test_d.biomtu.f_ce)
print(test_d.biomtu.l)
print(test_d.biomtu.tendon_l - mjx_model.biomtu_tendon_slack_l)
print(test_d.biomtu.fiber_l - mjx_model.biomtu_ofl)
print(test_d.biomtu.fiber_l)
print(test_d.biomtu.act)

In [None]:
print(mjx_data.xpos)

In [None]:
mjx_model.nbiomtu

In [None]:
batch_size = 40
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs2 = vrandom_init(mjx_data, mjx_model, rngs)

In [None]:
print(rngs.shape)
print(rngs2.shape)
print(mjx_data_batch.biomtu.l.shape)

In [None]:
from functools import *

# A normal function
def test_f(input):
    a,b,c = input
    return 100 * a + 10 * b + c, a*10

test_g = jax.jit(jax.grad(test_f, has_aux=True))


r = test_g((1.,3.,4.))

print(r)

In [None]:
mjx_data.xfrc_applied.shape

In [None]:
a = jp.ones(54)
b = jp.zeros(30)
c = jp.pad(a,15)
print(c)

In [None]:
print([0]*30)