# Simple RL code to teach the model to stand

## Setup jax enviroment and Load Model 

In [1]:
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt


import os
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".60"

# Optionally, force JAX to preallocate memory.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Setup environment variable for Nvidia GPU acceleration
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    # '--xla_gpu_enable_async_collectives=true '
    # '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)

backend = 'gpu'
# backend = 'METAL'
# backend = 'cpu'

import jax
# Enable compliation catch
os.environ["JAX_COMPILATION_CACHE_DIR"] = "./jax_cache"
jax.config.update("jax_compilation_cache_dir", "./jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 2)
jax.config.update("jax_explain_cache_misses", True)

from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("./jax_cache")
# Debug Nan
# jax.config.update("jax_debug_nans", True)

from jax import numpy as jp
# More legible printing from numpy.
jp.set_printoptions(precision=4, suppress=True, linewidth=100)

import mujoco
import mujoco.mjx as mjx
from mujoco.mjx._src import scan
from mujoco.mjx._src import types

# More legible printing from numpy.
np.set_printoptions(precision=4, suppress=True, linewidth=100)

from IPython.display import clear_output
clear_output()

device = jax.devices(backend=backend)[0]

model_path = '/home/bugman/Currentwork/biomujoco_converter/converted/mjc/Gait2354/gait2354_cvt1.xml'

# Single step
mjx_step = jax.jit(mjx.step, backend=backend)



# mjx_multiple_steps = jax.jit(multiple_steps, backend=backend, )

In [2]:
from mujoco.mjx._src.biomtu import acceleration_mtu

mj_model = mujoco.MjModel.from_xml_path(model_path)
mjx_model = mjx.put_model(mj_model,device=device)

# Disable tendon
opt = mjx_model.opt.replace(disableflags = mjx_model.opt.disableflags |mujoco.mjtDisableBit.mjDSBL_PASSIVE)
mjx_model = mjx_model.replace(opt=opt)

mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)

def print_all():
    print(mjx_model.biomtu_adr)
    print(mjx_model.mtu_wrap_objid)
    print(mjx_model.mtu_wrap_type)
    print(mjx_model.biomtu_fiso)
    print(mjx_model.biomtu_vmax)
    print(mjx_model.biomtu_ofl)
    print(mjx_model.biomtu_opa)
    print(mjx_model.biomtu_mass)
    print("-------Data--------")
    print("qpos:", mjx_data.qpos)
    print("mtu l:", mjx_data.biomtu.l)
    print("tendon l:", mjx_data.biomtu.tendon_l)
    print("fiber l :", mjx_data.biomtu.fiber_l)
    print("Muscle Bce:", mjx_data.biomtu.B_ce)
    print("Muscle vm:", mjx_data.biomtu.m)
    print("Fiber acc:", mjx_data.biomtu.fiber_acc)
    print("Fiber v:", mjx_data.biomtu.fiber_v)
    print("Biomtu h:", mjx_data.biomtu.h)
    print(mjx_data.biomtu.v)
    print(mjx_data.biomtu.h)  # The constant high of the muscle.
    print(mjx_data.biomtu.pennation_angle)
    print(mjx_data.biomtu.origin_body_id)
    print(mjx_data.biomtu.insertion_body_id)
    print("mtu act:", mjx_data.biomtu.act)
    # print(mjx_data.biomtu.j)
    print(mjx_data.qfrc_biomtu)
    print(mj_model.key_time)
    print(mj_model.key_qpos)
    print(mj_model.key_qvel)

print_all()

TRACING CACHE MISS at /tmp/ipykernel_75063/2268553419.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  never seen input type signature:
    args[0]: f32[46,3]
  closest seen input type signature has 1 mismatches, including:
    * at args[0], seen f32[4,3], but now given f32[46,3]
PERSISTENT COMPILATION CACHE MISS for 'jit_convert_element_type' with key 'jit_convert_element_type-e1a41599b5465e70a4219149512fb2a54525222ad37620128d994af5fcbd0131'
Not writing persistent cache entry for 'jit_convert_element_type' because it took < 2.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_75063/2268553419.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  never seen input type signature:
    args[0]: f32[74,5,3]
  closest seen input type signature has 1 mismatches, in

--------------------
BUGMAN: muscle name-glut_med1_r
--------------------
BUGMAN: muscle name-glut_med2_r
--------------------
BUGMAN: muscle name-glut_med3_r
--------------------
BUGMAN: muscle name-bifemlh_r
--------------------
BUGMAN: muscle name-bifemsh_r
--------------------
BUGMAN: muscle name-sar_r
--------------------
BUGMAN: muscle name-add_mag2_r
--------------------
BUGMAN: muscle name-tfl_r
--------------------
BUGMAN: muscle name-pect_r
--------------------
BUGMAN: muscle name-grac_r
--------------------
BUGMAN: muscle name-glut_max1_r
--------------------
BUGMAN: muscle name-glut_max2_r
--------------------
BUGMAN: muscle name-glut_max3_r
--------------------
BUGMAN: muscle name-iliacus_r
--------------------
BUGMAN: muscle name-psoas_r
--------------------
BUGMAN: muscle name-quad_fem_r
--------------------
BUGMAN: muscle name-gem_r
--------------------
BUGMAN: muscle name-peri_r
--------------------
BUGMAN: muscle name-rect_fem_r
--------------------
BUGMAN: muscle nam

Not writing persistent cache entry for 'jit_convert_element_type' because it took < 2.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_75063/2268553419.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  never seen input type signature:
    args[0]: i32[120,2]
  closest seen input type signature has 1 mismatches, including:
    * at args[0], seen i32[118,2], but now given i32[120,2]
PERSISTENT COMPILATION CACHE MISS for 'jit_convert_element_type' with key 'jit_convert_element_type-9fd6f0aab73245ad68200eb7d90dd954a9a451d24c015132e94871db1dd029f8'
Not writing persistent cache entry for 'jit_convert_element_type' because it took < 2.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_75063/2268553419.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py

[  0   2   4   6   9  12  17  19  23  25  29  33  37  41  46  51  53  55  58  61  65  68  70  74
  77  79  81  83  86  89  94  96 100 102 106 110 114 118 123 128 130 132 135 138 142 145 147 151
 154 156 158 160 162 164]
[  0  58   1  59   2  60   3  86  87  61  88  89   4  62  90  91  92   5  63   6  64  65  93   7
  66   8  94  95  96   9  10  67  68  11  12  69  70  13  14  71  72  15  16  17  73  74  18  19
  20  75  76  21  77  22  78  23  24  79  25  80  97  81  82  83  98  84  85 104  99 105 100 101
 106 107 102 103 108  26 109  27 110  28 111  29 137 138 112 139 140  30 113 141 142 143  31 114
  32 115 116 144  33 117  34 145 146 147  35  36 118 119  37  38 120 121  39  40 122 123  41  42
  43 124 125  44  45  46 126 127  47 128  48 129  49  50 130  51 131 148 132 133 134 149 135 136
 155 150 156 151 152 157 158 153 154 159  52 161  53 162  54 163  55 164  56 165  57 166]
[3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 

In [None]:
print(mjx_model.nbiomtu)
print(mjx_model.nq)
print(mjx_data.qpos)
print(mjx_data.qvel)

## Neural Network

In [3]:
import jax
import jax.numpy as jp
import flax
import flax.linen as nn
import optax

class Controller_NN(nn.Module):
    # It will output the mean and variance for each muscle's activation
    
    def setup(self):
        # Features means the output dimension# Single step
        self.nbiomtu = 54
        self.linear1 = nn.Dense(features=400)
        self.linear2 = nn.Dense(features=400)
        self.linear3 = nn.Dense(features=400)
        self.linear4 = nn.Dense(features=400)
        # The last layer will output the mean and logstd
        self.linear5 = nn.Dense(features=self.nbiomtu*2)
        
    
    def __call__(self, x, key):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        x = nn.relu(x)
        x = self.linear3(x)
        x = nn.relu(x)
        x = self.linear4(x)
        x = nn.relu(x)
        x = self.linear5(x)
        # The last layer of the neural requires samping
        mean = x[0:self.nbiomtu]
        logstd = x[self.nbiomtu:self.nbiomtu*2]
        std = jp.exp(logstd)
        samples = jp.clip(jax.random.normal(key)*std*0.3 + mean, 0 ,1)
        
        return samples, mean, logstd


# Test the neural network
control_model = Controller_NN()

print(control_model)
# Init the model
key = jax.random.key(66)
sub_keys = jax.random.split(key,1)
# The second parameter is the dommy input
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])
# print(params)
print(control_model.apply( params, jp.ones(mjx_model.nq*2), sub_keys[0]))
jit_nn_apply = jax.jit(lambda params,states,key : control_model.apply(params,states,key))

PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_seed' with key 'jit__threefry_seed-4b02442836201a5049814d2c13dbd0d5d397aad9b7380caccab02069b0187537'
Not writing persistent cache entry for 'jit__threefry_seed' because it took < 2.00 seconds to compile (0.02s)
TRACING CACHE MISS at /tmp/ipykernel_75063/3932189320.py:46:11 (<module>) because:
  never seen function:
    _threefry_split id=135279509219712 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/prng.py:1089
PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_split' with key 'jit__threefry_split-bc786fcc0b21f80ba0954d01102ab63800431d5ccde71fd42930c70749933f1f'


Controller_NN()


Not writing persistent cache entry for 'jit__threefry_split' because it took < 2.00 seconds to compile (0.05s)
TRACING CACHE MISS at /tmp/ipykernel_75063/3932189320.py:48:32 (<module>) because:
  never seen function:
    broadcast_in_dim id=135277716270400 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
PERSISTENT COMPILATION CACHE MISS for 'jit_broadcast_in_dim' with key 'jit_broadcast_in_dim-15c52d050c0d510fada14079342f99aa7654ad6b6beeaf4a30f22ef90981f0bf'
Not writing persistent cache entry for 'jit_broadcast_in_dim' because it took < 2.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_75063/3932189320.py:48:61 (<module>) because:
  never seen function:
    dynamic_slice id=135277716376768 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
PERSISTENT COMPILATION CACHE MISS for 'jit_dynamic_slice' with key 'jit_dynamic_slice-2975fad560ec495afff8400641e610116d752b59

(Array([0.3537, 0.4582, 0.    , 0.    , 0.3568, 0.    , 0.    , 0.    , 0.    , 0.0299, 0.1213,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.0104, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.4784, 0.    , 0.2003, 0.    , 0.    ,
       0.    , 0.    , 0.2286, 0.    , 0.    , 0.    , 0.    , 0.2364, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    ],      dtype=float32), Array([ 0.5507,  0.7117,  0.1663, -0.386 ,  0.572 ,  0.0428,  0.151 , -0.341 , -0.1592,  0.1468,
        0.3125, -0.506 , -0.098 , -0.1812, -0.2472, -0.1027, -0.2657, -0.0098,  0.1604,  0.1802,
       -0.4265, -0.2269, -0.2294, -0.0505,  0.1361,  0.0374,  0.0641, -0.0512,  0.6791, -0.3706,
        0.4092, -0.1719, -0.2406, -0.022 ,  0.1271,  0.4203,  0.0284,  0.0943, -0.4435,  0.1295,
        0.4245,  0.0444, -0.1952,  0.121 , -0.08  , -0.3598, -0.0973, -0.0134,  0.0238, -0.2089,
       -0.5425, -0.3392, 

## Combine Neural Net and Simulation into one Jax function

In [4]:
# Multiple steps
def step_fn(carry, _):
    data, model = carry
    new_data = mjx.step(model, data)
    new_carry = (new_data, model)
    return new_carry, _

def multiple_steps(model, data):
    init_carry = (data, model)
    y, _ = jax.lax.scan(step_fn, init_carry, None, length=10)
    new_data = y[0]
    return new_data

def nn_mjx_one_step(nn_params, model, data, key):
    states = jp.concatenate([data.qpos, data.qvel])
    act = jit_nn_apply(nn_params, states, key)[0]
    # Generate the next key
    new_key = jax.random.split(key,1)[0]
    # mtu = data.biomtu
    # mtu = mtu.replace(act = act)
    # data = data.replace(biomtu = mtu)
    new_data = mjx.step(model, data)
    return new_data, new_key

@jax.jit
def nn_mjx_multi_steps(nn_params, model, data, key):
    def nn_step_fn(carry, _):
        data, key = carry
        new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
        new_carry = (new_data, new_key)
        # head_hight = new_data.xpos[2,2]
        # jax.debug.print("Head Height {0}",head_hight)
        return new_carry, _
    init_carry = (data, key)
    y, _ = jax.lax.scan(nn_step_fn, init_carry, None, length=10)
    new_data = y[0]
    new_key = y[1]
    return new_data, new_key

## Testing control model with neural networks

In [None]:
import mujoco.viewer
import time

previous_frame_time = time.time()
i = 0
key = jax.random.key(334)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=jp.array(mj_data.xfrc_applied, dtype=jp.float32))
        mjx_data = mjx_data.replace(
            qpos= jp.array(mj_data.qpos, dtype=jp.float32),
            qvel= jp.array(mj_data.qvel, dtype=jp.float32),
            time = jp.array(mj_data.time, dtype=jp.float32))
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': jp.array(mj_model.opt.gravity, dtype=jp.float32),
            'opt.tolerance': jp.array(mj_model.opt.tolerance, dtype=jp.float32),
            'opt.ls_tolerance': jp.array(mj_model.opt.ls_tolerance, dtype=jp.float32),
            'opt.timestep': jp.array(mj_model.opt.timestep, dtype=jp.float32),
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.sensordata)
        # print(len(mjx_data.qvel))
        viewer.sync()

## Batched Random Init Model

In [5]:
def random_init(data, model, rng: jax.Array):
    nbiomtu = model.nq
    init_qpos = data.qpos
    init_qvel = data.qvel
    new_rng, rng1, rng2 = jax.random.split(rng, 3)
    # Qpos_1 is the vertical position
    random_qpos = init_qpos + jax.random.uniform(rng1, [nbiomtu], minval=jp.array(-1.0, dtype=jp.float32), maxval=jp.array(1.0, dtype=jp.float32))*0.01
    random_qvel = init_qvel + jax.random.uniform(rng2, [nbiomtu], minval=jp.array(-1.0, dtype=jp.float32), maxval=jp.array(1.0, dtype=jp.float32))*0.01
    newdata = data.replace(qpos=random_qpos)
    newdata = newdata.replace(qvel=random_qvel)
    newdata = mjx.forward(mjx_model, newdata)
    # print('data:',data.qpos, data.qvel)
    # Calculate equilibrum
    # newdata = acceleration_mtu.calc_equilibrium(mjx_model, newdata)
    # newdata = mjx_step(mjx_model, newdata)
    return newdata, new_rng

vrandom_init = jax.jit(jax.vmap(random_init, in_axes=(None, None, 0), out_axes=0))

## Test Init

In [None]:
batch_size = 20
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)

In [None]:
print(mjx_data_batch.biomtu)

## Reward Function
get reward at the same time calculate the mjx_data

In [7]:
# from functools import partial

def nn_mjx_one_step(nn_params, model, data, key):
    states = jp.concatenate([data.qpos, data.qvel])
    act = jit_nn_apply(nn_params, states, key)[0]
    # Generate the next key
    new_key = jax.random.split(key,1)[0]
    data = data.replace(biomtu = data.biomtu.replace(act=act))
    new_data = mjx.step(model, data)
    return new_data, new_key

def nn_step_fn(carry, _):
    data, key, nn_params, model = carry
    new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
    new_carry = (new_data, new_key, nn_params, model)
    head_hight = new_data.sensordata[2]
    # jax.debug.print("Head Height {0}",head_hight)
    # return new_carry, (head_hight-1.63)**2
    return new_carry, -head_hight

def decay_sum_scan(x, decay):
    def f(sxtm1, xt):
        b = xt + decay * sxtm1
        return b, b
    return jax.lax.scan(f, jp.zeros(x.shape[1:]), x)[1]

# This function return the loss as well as the new data
def reward_n_step(nn_params, model, data, key):
    repeat_n = 150   # Simulate for 0.1s
    carry, head_heights = jax.lax.scan(nn_step_fn, (data, key, nn_params, model), None, repeat_n)
    loss = decay_sum_scan(head_heights, 0.6)[repeat_n -1]
    new_data = carry[0]
    return loss, new_data

def batch_reward(nn_params, batched_data, keys, model):
    batch_loss, batch_new_data = jax.vmap(reward_n_step, (None, None, 0, 0))(nn_params, model, batched_data, keys)
    return jp.mean(batch_loss), batch_new_data 

jit_batch_reward = jax.jit(batch_reward)
jit_batch_reward_grad = jax.jit(jax.grad(batch_reward,has_aux=True))



In [8]:
# Generate batched data
batch_size = 20
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
r,d = jit_batch_reward(params, mjx_data_batch, rngs, mjx_model)
print(r)
print("Calculating reward grad")
g,d = jit_batch_reward_grad(params, mjx_data_batch, rngs, mjx_model)
print(g)

TRACING CACHE MISS at /tmp/ipykernel_75063/3187473002.py:5:7 (<module>) because:
  for _threefry_split defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/prng.py:1089
explanation unavailable! please open an issue at https://github.com/jax-ml/jax


PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_split' with key 'jit__threefry_split-e380d8b8d4f5377f46c2ec378fa9c27d852e0434811ab02e582bb12deb956f0e'
Not writing persistent cache entry for 'jit__threefry_split' because it took < 2.00 seconds to compile (0.12s)
TRACING CACHE MISS at /tmp/ipykernel_75063/1876140736.py:7:30 (random_init) because:
  for _uniform defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/random.py:405
explanation unavailable! please open an issue at https://github.com/jax-ml/jax
TRACING CACHE MISS at /tmp/ipykernel_75063/3187473002.py:6:23 (<module>) because:
  never seen function:
    random_init id=135277715293024 defined at /tmp/ipykernel_75063/1876140736.py:1
TRACING CACHE MISS at /tmp/ipykernel_75063/1876140736.py:5:26 (random_init) because:
  for _threefry_split defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/prng.py:1089
  tracing context doesn't match, e.g. due to config or con

-3.4703636
Calculating reward grad


TRACING CACHE MISS at /tmp/ipykernel_75063/3187473002.py:10:6 (<module>) because:
  never seen function:
    batch_reward id=135277715296064 defined at /tmp/ipykernel_75063/870905365.py:35


{'params': {'linear1': {'bias': Array([ 0.0001,  0.0011,  0.0057, -0.0009,  0.0154,  0.0223, -0.0011, -0.0015,  0.0041,  0.008 ,
        0.013 , -0.0077,  0.011 ,  0.0225,  0.0163, -0.0043,  0.0005, -0.0212, -0.0126,  0.0011,
       -0.0184,  0.004 ,  0.0095,  0.0032, -0.0038, -0.0064, -0.0048, -0.0004, -0.0136, -0.0358,
        0.0068, -0.0005, -0.0188,  0.0006,  0.0003, -0.0015,  0.005 ,  0.0104,  0.0021,  0.0002,
        0.0042,  0.0016,  0.0035, -0.0003, -0.0088,  0.0029,  0.0014,  0.0028, -0.0069,  0.0019,
        0.0031, -0.0001,  0.0003, -0.0109, -0.0036, -0.0032, -0.0004, -0.0061, -0.0002, -0.0015,
        0.0033, -0.0048, -0.0075,  0.0127,  0.0119,  0.0012, -0.0113,  0.0105,  0.0021, -0.0043,
       -0.0014, -0.0203, -0.0011,  0.0098,  0.0003,  0.0112, -0.0129,  0.0029, -0.0133,  0.0095,
        0.0097,  0.007 , -0.0101, -0.0181, -0.0016,  0.0158, -0.0014,  0.0114, -0.0012, -0.0046,
       -0.0001,  0.0021,  0.    ,  0.0046,  0.004 , -0.0135, -0.0009,  0.0004, -0.0007,  0.0015

## Reset part of the data based on the condition

In [None]:
print(r)
print(d.sensordata[:,2])

## Train the NN

In [9]:
seed = 2024
mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]
# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)

TRACING CACHE MISS at /tmp/ipykernel_75063/2935256365.py:9:11 (<module>) because:
  never seen function:
    calc_equilibrium.<locals>.callback_optimize_seperate id=135275499295168 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/mujoco/mjx/_src/biomtu/acceleration_mtu.py:241
PERSISTENT COMPILATION CACHE MISS for 'jit_callback_optimize_seperate' with key 'jit_callback_optimize_seperate-59fb60c6c41ea6ad181db1b2d6fcba568077feca583fb9429e151f8692d637cc'
Not writing persistent cache entry for 'jit_callback_optimize_seperate' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
TRACING CACHE MISS at /tmp/ipykernel_75063/2935256365.py:9:11 (<module>) because:
  never seen function:
    calc_equilibrium.<locals>.goal id=135275499296608 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/mujoco/mjx/_src/biomtu/acceleration_mtu.py:215
  but seen another function defined on the same line; maybe the function is
  being re-

In [10]:
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])


In [11]:
tx = optax.adam(learning_rate=0.0003)
opt_state = tx.init(params)

for i in range(400):
    # generate random mjx_data
    mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
    # print(rngs[0])
    r,d = jit_batch_reward(params, mjx_data_batch, rngs, mjx_model)
    print(r)
    g,d = jit_batch_reward_grad(params, mjx_data_batch, rngs, mjx_model)
    updates, opt_stats = tx.update(g, opt_state)
    params = optax.apply_updates(params, updates)
    print("params updated")

TRACING CACHE MISS at /tmp/ipykernel_75063/3541745262.py:2:12 (<module>) because:
  never seen function:
    broadcast_in_dim id=135276238397952 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
PERSISTENT COMPILATION CACHE MISS for 'jit_broadcast_in_dim' with key 'jit_broadcast_in_dim-06166a8828405e0b54f4aa19d96f152965eedc5f137f69f531acf76a4dc96295'
Not writing persistent cache entry for 'jit_broadcast_in_dim' because it took < 2.00 seconds to compile (0.04s)
TRACING CACHE MISS at /tmp/ipykernel_75063/3541745262.py:2:12 (<module>) because:
  never seen function:
    broadcast_in_dim id=135275499300608 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  but seen another function defined on the same line; maybe the function is
  being re-defined repeatedly, preventing caching?
PERSISTENT COMPILATION CACHE MISS for 'jit_broadcast_in_dim' with key 'jit_broadcast_in_dim-6538d6c2de526808b

-3.531659


TRACING CACHE MISS at /tmp/ipykernel_75063/3541745262.py:11:25 (<module>) because:
  never seen function:
    integer_pow id=135275351897920 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
PERSISTENT COMPILATION CACHE MISS for 'jit_integer_pow' with key 'jit_integer_pow-b5ae0d7c9f85b1c81c6867c54e9006e2fe7c5f88982f1dd7daa56d642a0196e2'
Not writing persistent cache entry for 'jit_integer_pow' because it took < 2.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__multiply' with key 'jit__multiply-632b88610bbee378933c5ea9cc245d7db252767fb648545aac7daf172960b0ec'
Not writing persistent cache entry for 'jit__multiply' because it took < 2.00 seconds to compile (0.02s)
PERSISTENT COMPILATION CACHE MISS for 'jit__multiply' with key 'jit__multiply-acff3f1934d7896f72ea0ab7ce37a13b9007cb3f5b3ac1d53a85bc5e3c6a0666'
Not writing persistent cache entry for 'jit__multiply' because it took < 2.00 seconds to compile (0.01s)
PERS

params updated
-4.028057


PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-945ac1a299deea283f1692e283792634b7682d032bdda23a2808c55c71c68108'
Not writing persistent cache entry for 'jit__add' because it took < 2.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-0a38ad05b3175d49c61c1dce6844af52e468e1904974e07bcaab6e9bef533428'
Not writing persistent cache entry for 'jit__add' because it took < 2.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-bc7afccc5988a3cc77f5df09e8efd0709700dcb7f61d0f678c70ff860821320e'
Not writing persistent cache entry for 'jit__add' because it took < 2.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-1c1ae83db20b28dfa80995ecb654ef2e283b9495339b678d599faf14d5a7a69b'
Not writing persistent cache entry for 'jit__add' because it took < 2.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__

params updated
-3.9641628
params updated
-3.8604589
params updated
-4.022585
params updated
-3.911784
params updated
-4.0586495
params updated
-3.9795787
params updated
-4.0736823
params updated
-4.0397162
params updated
-4.070127
params updated
-3.9950447
params updated
-4.0851407
params updated
-4.0834107
params updated
-4.0675435
params updated
-4.067733
params updated
-4.0724196
params updated
-4.074221
params updated
-4.080664
params updated
-4.0765233
params updated
-4.088707
params updated
-4.0771914
params updated
-4.0878325
params updated
-4.0858183
params updated
-4.087845
params updated
-4.0826335
params updated
-4.090716
params updated
-4.09813
params updated
-4.1089063
params updated
-4.1268663
params updated
-4.141212
params updated
-4.1717315
params updated
-4.1747627
params updated
-4.2200575
params updated
-4.2165017
params updated
-4.2269797
params updated
-4.2299595
params updated
-4.2555833
params updated
-4.252299
params updated
-4.2548537
params updated
-4.2623987

## Test Train

In [12]:
jit_nn_mjx_one_step = jax.jit(nn_mjx_one_step)


In [19]:
import mujoco.viewer
import time


mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)

previous_frame_time = time.time()
i = 0
key = jax.random.key(seed)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        i += 1
        if(i ==4):
            time.sleep(5)
            # pass
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=jp.array(mj_data.xfrc_applied, dtype=jp.float32))
        
        mjx_data = mjx_data.replace(
            qpos= jp.array(mj_data.qpos, dtype=jp.float32),
            qvel= jp.array(mj_data.qvel, dtype=jp.float32),
            time = jp.array(mj_data.time, dtype=jp.float32))
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': jp.array(mj_model.opt.gravity, dtype=jp.float32),
            'opt.tolerance': jp.array(mj_model.opt.tolerance, dtype=jp.float32),
            'opt.ls_tolerance': jp.array(mj_model.opt.ls_tolerance, dtype=jp.float32),
            'opt.timestep': jp.array(mj_model.opt.timestep, dtype=jp.float32),
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = jit_nn_mjx_one_step(params, mjx_model, mjx_data, key)
        # mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.sensordata[2])
        # print(len(mjx_data.qvel))
        viewer.sync()

TRACING CACHE MISS at /tmp/ipykernel_75063/941529699.py:13:11 (<module>) because:
  never seen function:
    calc_equilibrium.<locals>.callback_optimize_seperate id=135275330044608 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/mujoco/mjx/_src/biomtu/acceleration_mtu.py:241
  but seen another function defined on the same line; maybe the function is
  being re-defined repeatedly, preventing caching?
PERSISTENT COMPILATION CACHE MISS for 'jit_callback_optimize_seperate' with key 'jit_callback_optimize_seperate-282b895711dd23c129b2749d2e5fe7bca82802f56a42781d14d2a598690321fa'
Not writing persistent cache entry for 'jit_callback_optimize_seperate' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
TRACING CACHE MISS at /tmp/ipykernel_75063/941529699.py:13:11 (<module>) because:
  never seen function:
    calc_equilibrium.<locals>.goal id=135275330046688 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/mujoco/

Time between frames: 0.1699833869934082 seconds
1.6315
Time between frames: 0.010989189147949219 seconds
1.6314881
Time between frames: 0.01815342903137207 seconds
1.6314712
Time between frames: 5.010533809661865 seconds
1.6314607
Time between frames: 0.016565322875976562 seconds
1.6314473
Time between frames: 0.01455068588256836 seconds
1.6313987
Time between frames: 0.01434779167175293 seconds
1.6312736
Time between frames: 0.01475834846496582 seconds
1.6310743
Time between frames: 0.013857603073120117 seconds
1.6308093
Time between frames: 0.01648426055908203 seconds
1.6304877
Time between frames: 0.013596057891845703 seconds
1.6301588
Time between frames: 0.014947175979614258 seconds
1.6298215
Time between frames: 0.013559818267822266 seconds
1.6294746
Time between frames: 0.015026569366455078 seconds
1.6291128
Time between frames: 0.015181779861450195 seconds
1.6287258
Time between frames: 0.014060735702514648 seconds
1.6283094
Time between frames: 0.018886089324951172 seconds
1.6

## Test Gradient for activation

In [None]:
n = mjx_model.nbiomtu
act = jp.ones(n)*0.1

In [None]:
# jax.config.update("jax_debug_nans", True)
for i in range(10):
    g = jit_goal_grad(act, mjx_data, mjx_model)
    # g = goal_grad(act, mjx_data, mjx_model)
print(g)

In [None]:
v_act = jp.ones((50,n))*0.3
for i in range(5):
    v_g = jit_v_goal_grad(v_act,mjx_data,mjx_model)
print(v_g)

In [None]:
for i in range(10):
    reward = jit_goal(act, mjx_data, mjx_model)
print(reward)

In [None]:
act = jp.ones(n)*0.0
mtu = mjx_data.biomtu
mtu = mtu.replace(act = act)
test_d = mjx_data
test_d = test_d.replace(biomtu = mtu)
for i in range(100):
    test_d = mjx_step(mjx_model, test_d)
print(test_d.xipos)

In [None]:
print(test_d.biomtu.f_se)
print(test_d.biomtu.f_ce)
print(test_d.biomtu.l)
print(test_d.biomtu.tendon_l - mjx_model.biomtu_tendon_slack_l)
print(test_d.biomtu.fiber_l - mjx_model.biomtu_ofl)
print(test_d.biomtu.fiber_l)
print(test_d.biomtu.act)

In [None]:
print(mjx_data.xpos)

In [None]:
mjx_model.nbiomtu

In [None]:
batch_size = 40
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs2 = vrandom_init(mjx_data, mjx_model, rngs)

In [None]:
print(rngs.shape)
print(rngs2.shape)
print(mjx_data_batch.biomtu.l.shape)

In [None]:
from functools import *

# A normal function
def test_f(input):
    a,b,c = input
    return 100 * a + 10 * b + c, a*10

test_g = jax.jit(jax.grad(test_f, has_aux=True))


r = test_g((1.,3.,4.))

print(r)

In [18]:
mjx_data.xfrc_applied.shape

(14, 6)

In [29]:
a = jp.ones(54)
b = jp.zeros(30)
c = jp.pad(a,15)
print(c)

TRACING CACHE MISS at /tmp/ipykernel_123049/1407306211.py:3:4 (<module>) because:
  for _pad defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4203
explanation unavailable! please open an issue at https://github.com/jax-ml/jax
PERSISTENT COMPILATION CACHE MISS for 'jit__pad' with key 'jit__pad-bf6f6a5add0d571c16c6e1c3eadd63df0f1b02f9602e1b78c7f03fa15ed56bbb'


Not writing persistent cache entry for 'jit__pad' because it took < 2.00 seconds to compile (0.02s)


[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


In [28]:
print([0]*30)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
