# Simple RL code to teach the model to stand

## Setup jax enviroment and Load Model 

In [1]:
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt


import os
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".60"

# Optionally, force JAX to preallocate memory.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Setup environment variable for Nvidia GPU acceleration
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
)

backend = 'gpu'
# backend = 'METAL'
# backend = 'cpu'

import jax
# Enable compliation catch
os.environ["JAX_COMPILATION_CACHE_DIR"] = "./jax_cache"
jax.config.update("jax_compilation_cache_dir", "./jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 5)
jax.config.update("jax_explain_cache_misses", True)

from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("./jax_cache")
# Debug Nan
# jax.config.update("jax_debug_nans", True)

from jax import numpy as jp
# More legible printing from numpy.
jp.set_printoptions(precision=4, suppress=True, linewidth=100)

import mujoco
import mujoco.mjx as mjx
from mujoco.mjx._src import scan
from mujoco.mjx._src import types

# More legible printing from numpy.
np.set_printoptions(precision=4, suppress=True, linewidth=100)

from IPython.display import clear_output
clear_output()

device = jax.devices(backend=backend)[0]

model_path = '/home/bugman/Currentwork/biomujoco_converter/converted/mjc/Gait2354/gait2354_cvt1.xml'

# Single step
mjx_step = jax.jit(mjx.step, backend=backend)



# mjx_multiple_steps = jax.jit(multiple_steps, backend=backend, )

In [2]:
from mujoco.mjx._src.biomtu.acceleration_mtu import acceleration_mtu

mj_model = mujoco.MjModel.from_xml_path(model_path)
mjx_model = mjx.put_model(mj_model,device=device)

# Disable tendon
opt = mjx_model.opt.replace(disableflags = mjx_model.opt.disableflags |mujoco.mjtDisableBit.mjDSBL_PASSIVE)
mjx_model = mjx_model.replace(opt=opt)

mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)

def print_all():
    print(mjx_model.biomtu_adr)
    print(mjx_model.mtu_wrap_objid)
    print(mjx_model.mtu_wrap_type)
    print(mjx_model.biomtu_fiso)
    print(mjx_model.biomtu_vmax)
    print(mjx_model.biomtu_ofl)
    print(mjx_model.biomtu_opa)
    print(mjx_model.biomtu_mass)
    print("-------Data--------")
    print("qpos:", mjx_data.qpos)
    print("mtu l:", mjx_data.biomtu.l)
    print("tendon l:", mjx_data.biomtu.tendon_l)
    print("fiber l :", mjx_data.biomtu.fiber_l)
    print("Muscle Bce:", mjx_data.biomtu.B_ce)
    print("Muscle vm:", mjx_data.biomtu.m)
    print("Fiber acc:", mjx_data.biomtu.fiber_acc)
    print("Fiber v:", mjx_data.biomtu.fiber_v)
    print("Biomtu h:", mjx_data.biomtu.h)
    print(mjx_data.biomtu.v)
    print(mjx_data.biomtu.h)  # The constant high of the muscle.
    print(mjx_data.biomtu.pennation_angle)
    print(mjx_data.biomtu.origin_body_id)
    print(mjx_data.biomtu.insertion_body_id)
    print("mtu act:", mjx_data.biomtu.act)
    # print(mjx_data.biomtu.j)
    print(mjx_data.qfrc_biomtu)
    print(mj_model.key_time)
    print(mj_model.key_qpos)
    print(mj_model.key_qvel)

print_all()

TRACING CACHE MISS at /tmp/ipykernel_57038/3517667697.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  never seen input type signature:
    args[0]: f32[46,3]
  closest seen input type signature has 1 mismatches, including:
    * at args[0], seen f32[4,3], but now given f32[46,3]
PERSISTENT COMPILATION CACHE MISS for 'jit_convert_element_type' with key 'jit_convert_element_type-83e5920c9ffc098718dc3d3313149072e8c5849ce4f510a7cf10e1f9bddcd9aa'
Not writing persistent cache entry for 'jit_convert_element_type' because it took < 5.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_57038/3517667697.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  never seen input type signature:
    args[0]: f32[74,5,3]
  closest seen input type signature has 1 mismatches, in

--------------------
BUGMAN: muscle name-glut_med1_r
--------------------
BUGMAN: muscle name-glut_med2_r
--------------------
BUGMAN: muscle name-glut_med3_r
--------------------
BUGMAN: muscle name-bifemlh_r
--------------------
BUGMAN: muscle name-bifemsh_r
--------------------
BUGMAN: muscle name-sar_r
--------------------
BUGMAN: muscle name-add_mag2_r
--------------------
BUGMAN: muscle name-tfl_r
--------------------
BUGMAN: muscle name-pect_r
--------------------
BUGMAN: muscle name-grac_r
--------------------
BUGMAN: muscle name-glut_max1_r
--------------------
BUGMAN: muscle name-glut_max2_r
--------------------
BUGMAN: muscle name-glut_max3_r
--------------------
BUGMAN: muscle name-iliacus_r
--------------------
BUGMAN: muscle name-psoas_r
--------------------
BUGMAN: muscle name-quad_fem_r
--------------------
BUGMAN: muscle name-gem_r
--------------------
BUGMAN: muscle name-peri_r
--------------------
BUGMAN: muscle name-rect_fem_r
--------------------
BUGMAN: muscle nam

Not writing persistent cache entry for 'jit_convert_element_type' because it took < 5.00 seconds to compile (0.00s)
TRACING CACHE MISS at /tmp/ipykernel_57038/3517667697.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  never seen input type signature:
    args[0]: f32[85,4,3]
  closest seen input type signature has 1 mismatches, including:
    * at args[0], seen f32[4,3], but now given f32[85,4,3]
PERSISTENT COMPILATION CACHE MISS for 'jit_convert_element_type' with key 'jit_convert_element_type-913d59af2f8854e77a2a2ce2a4fee781573669fe5f24ee57a082de99eae7fdbd'
Not writing persistent cache entry for 'jit_convert_element_type' because it took < 5.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_57038/3517667697.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py

[  0   2   4   6   9  12  17  19  23  25  29  33  37  41  46  51  53  55  58  61  65  68  70  74
  77  79  81  83  86  89  94  96 100 102 106 110 114 118 123 128 130 132 135 138 142 145 147 151
 154 156 158 160 162 164]
[  0  58   1  59   2  60   3  86  87  61  88  89   4  62  90  91  92   5  63   6  64  65  93   7
  66   8  94  95  96   9  10  67  68  11  12  69  70  13  14  71  72  15  16  17  73  74  18  19
  20  75  76  21  77  22  78  23  24  79  25  80  97  81  82  83  98  84  85 104  99 105 100 101
 106 107 102 103 108  26 109  27 110  28 111  29 137 138 112 139 140  30 113 141 142 143  31 114
  32 115 116 144  33 117  34 145 146 147  35  36 118 119  37  38 120 121  39  40 122 123  41  42
  43 124 125  44  45  46 126 127  47 128  48 129  49  50 130  51 131 148 132 133 134 149 135 136
 155 150 156 151 152 157 158 153 154 159  52 161  53 162  54 163  55 164  56 165  57 166]
[3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 

In [3]:
print(mjx_model.nbiomtu)
print(mjx_model.nq)
print(mjx_data.qpos)
print(mjx_data.qvel)

54
27
[ 0.      0.9501  0.      0.0006  0.     -0.     -0.0013 -0.     -0.0009 -0.0037 -0.3957  0.0023
 -0.0199 -0.0032  0.0228 -0.0013 -0.     -0.0009 -0.0037 -0.3957  0.0023 -0.0199 -0.0032  0.0228
 -0.0007 -0.      0.    ]
[ 0.0149  0.0271  0.      0.3064  0.     -0.     -0.6693 -0.0194 -0.4365 -0.0074  0.0056  1.1327
 -9.9627 -1.625  11.377  -0.6693 -0.0194 -0.4365 -0.0074  0.0056  1.1327 -9.9628 -1.625  11.377
 -0.3256 -0.      0.    ]


## Neural Network

In [3]:
import jax
import jax.numpy as jp
import flax
import flax.linen as nn
import optax

class Controller_NN(nn.Module):
    # It will output the mean and variance for each muscle's activation
    nbiomtu = 54
    def setup(self):
        # Features means the output dimension# Single step

        self.linear1 = nn.Dense(features=400)
        self.linear2 = nn.Dense(features=400)
        self.linear3 = nn.Dense(features=400)
        self.linear4 = nn.Dense(features=400)
        # The last layer will output the mean and logstd
        self.linear5 = nn.Dense(features=self.nbiomtu*2)
        
    
    def __call__(self, x, key):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        x = nn.relu(x)
        x = self.linear3(x)
        x = nn.relu(x)
        x = self.linear4(x)
        x = nn.relu(x)
        x = self.linear5(x)
        # The last layer of the neural requires samping
        mean = x[0:self.nbiomtu]
        logstd = x[self.nbiomtu:self.nbiomtu*2]
        std = jp.exp(logstd)
        samples = jp.clip(jax.random.normal(key)*std*0.3 + mean, 0 ,1)
        
        return samples, mean, logstd


# Test the neural network
control_model = Controller_NN()

print(control_model)
# Init the model
key = jax.random.PRNGKey(66)
sub_keys = jax.random.split(key,1)
# The second parameter is the dommy input
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])
# print(params)
print(control_model.apply( params,jp.ones(mjx_model.nq*2), sub_keys[0]))
jit_nn_apply = jax.jit(lambda params,states,key : control_model.apply(params,states,key))

PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_seed' with key 'jit__threefry_seed-91ef0eb35b4eeb61b70409366ebf41fd7bcde007b213947867e7a4ea5d270d11'
Not writing persistent cache entry for 'jit__threefry_seed' because it took < 5.00 seconds to compile (0.02s)
TRACING CACHE MISS at /tmp/ipykernel_57038/1754672284.py:46:11 (<module>) because:
  never seen function:
    _threefry_split id=134369888628512 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/prng.py:1089
PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_split' with key 'jit__threefry_split-ef7af2667637d1d3fdcd9d9de71a352335a4ca1497908902e5a04bdc3ccd9469'
Not writing persistent cache entry for 'jit__threefry_split' because it took < 5.00 seconds to compile (0.06s)
TRACING CACHE MISS at /tmp/ipykernel_57038/1754672284.py:48:32 (<module>) because:
  never seen function:
    broadcast_in_dim id=134362890747744 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/

Controller_NN()


Not writing persistent cache entry for 'jit__threefry_fold_in' because it took < 5.00 seconds to compile (0.06s)
PERSISTENT COMPILATION CACHE MISS for 'jit_sqrt' with key 'jit_sqrt-fa363c78eb210675c8f20e01b4989e1b39cdc900eb109eafd24a2ed3dafd6426'
Not writing persistent cache entry for 'jit_sqrt' because it took < 5.00 seconds to compile (0.02s)
PERSISTENT COMPILATION CACHE MISS for 'jit_true_divide' with key 'jit_true_divide-e660470ad54bb12fce1ce71327d437fdd083625c7de12830b22d6ead90aced92'
Not writing persistent cache entry for 'jit_true_divide' because it took < 5.00 seconds to compile (0.02s)
TRACING CACHE MISS at /tmp/ipykernel_57038/1754672284.py:22:12 (Controller_NN.__call__) because:
  never seen function:
    _uniform id=134369888641376 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/random.py:405
  but seen another function defined on the same line; maybe the function is
  being re-defined repeatedly, preventing caching?
TRACING CACHE MISS

(Array([0.3537, 0.4582, 0.    , 0.    , 0.3568, 0.    , 0.    , 0.    , 0.    , 0.0299, 0.1213,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.0104, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.4784, 0.    , 0.2003, 0.    , 0.    ,
       0.    , 0.    , 0.2286, 0.    , 0.    , 0.    , 0.    , 0.2364, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    ],      dtype=float32), Array([ 0.5507,  0.7117,  0.1663, -0.386 ,  0.572 ,  0.0428,  0.151 , -0.341 , -0.1592,  0.1468,
        0.3125, -0.506 , -0.098 , -0.1812, -0.2472, -0.1027, -0.2657, -0.0098,  0.1604,  0.1802,
       -0.4265, -0.2269, -0.2294, -0.0505,  0.1361,  0.0374,  0.0641, -0.0512,  0.6791, -0.3706,
        0.4092, -0.1719, -0.2406, -0.022 ,  0.1271,  0.4203,  0.0284,  0.0943, -0.4435,  0.1295,
        0.4245,  0.0444, -0.1952,  0.121 , -0.08  , -0.3598, -0.0973, -0.0134,  0.0238, -0.2089,
       -0.5425, -0.3392, 

## Combine Neural Net and Simulation into one Jax function

In [4]:
# Multiple steps
def step_fn(carry, _):
    data, model = carry
    new_data = mjx.step(model, data)
    new_carry = (new_data, model)
    return new_carry, _

def multiple_steps(model, data):
    init_carry = (data, model)
    y, _ = jax.lax.scan(step_fn, init_carry, None, length=10)
    new_data = y[0]
    return new_data

def nn_mjx_one_step(nn_params, model, data, key):
    states = jp.concatenate([data.qpos, data.qvel])
    act = jit_nn_apply(nn_params, states, key)[0]
    # Generate the next key
    new_key = jax.random.split(key,1)[0]
    mtu = data.biomtu
    mtu = mtu.replace(act = act)
    data = data.replace(biomtu = mtu)
    new_data = mjx.step(model, data)
    return new_data, new_key

@jax.jit
def nn_mjx_multi_steps(nn_params, model, data, key):
    def nn_step_fn(carry, _):
        data, key = carry
        new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
        new_carry = (new_data, new_key)
        # head_hight = new_data.xpos[2,2]
        # jax.debug.print("Head Height {0}",head_hight)
        return new_carry, _
    init_carry = (data, key)
    y, _ = jax.lax.scan(nn_step_fn, init_carry, None, length=10)
    new_data = y[0]
    new_key = y[1]
    return new_data, new_key

## Testing control model with neural networks

In [6]:
import mujoco.viewer
import time

previous_frame_time = time.time()
i = 0
key = jax.random.key(334)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=jp.array(mj_data.xfrc_applied, dtype=jp.float32))
        mjx_data = mjx_data.replace(
            qpos= jp.array(mj_data.qpos, dtype=jp.float32),
            qvel= jp.array(mj_data.qvel, dtype=jp.float32),
            time = jp.array(mj_data.time, dtype=jp.float32))
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': jp.array(mj_model.opt.gravity, dtype=jp.float32),
            'opt.tolerance': jp.array(mj_model.opt.tolerance, dtype=jp.float32),
            'opt.ls_tolerance': jp.array(mj_model.opt.ls_tolerance, dtype=jp.float32),
            'opt.timestep': jp.array(mj_model.opt.timestep, dtype=jp.float32),
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.sensordata)
        # print(len(mjx_data.qvel))
        viewer.sync()

TRACING CACHE MISS at /tmp/ipykernel_37373/366203590.py:14:49 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  never seen input type signature:
    args[0]: f32[14,6]
  closest seen input type signature has 1 mismatches, including:
    * at args[0], seen f32[4,3], but now given f32[14,6]
PERSISTENT COMPILATION CACHE MISS for 'jit_convert_element_type' with key 'jit_convert_element_type-29ce4cc958dc08194a2efedbb735f0ccfdf94e7887751ae02bdf8b5bdb1916db'
Not writing persistent cache entry for 'jit_convert_element_type' because it took < 5.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_37373/366203590.py:16:18 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  never seen input type signature:
    args[0]: f32[27]
  closest seen input type signature has 1 mismatches, includ

Time between frames: 19.109546422958374 seconds
[-0.1023 -0.0011  1.63  ]


PERSISTENT COMPILATION CACHE MISS for 'jit_nn_mjx_multi_steps' with key 'jit_nn_mjx_multi_steps-cfd6de580c7d5a527f44582fc2616e7201296b374d891a1d67087ee4600d754d'
Writing jit_nn_mjx_multi_steps to persistent compilation cache with key 'jit_nn_mjx_multi_steps-cfd6de580c7d5a527f44582fc2616e7201296b374d891a1d67087ee4600d754d'


Time between frames: 16.604049682617188 seconds
[-0.1085 -0.0009  1.6253]
Time between frames: 0.04529523849487305 seconds
[-0.1169  0.0003  1.6192]
Time between frames: 0.041046142578125 seconds
[-0.1264  0.0008  1.6158]
Time between frames: 0.041832685470581055 seconds
[-0.1389 -0.0008  1.6128]
Time between frames: 0.034471750259399414 seconds
[-0.1544 -0.005   1.6075]
Time between frames: 0.0346980094909668 seconds
[-0.1728 -0.0096  1.5996]
Time between frames: 0.03499436378479004 seconds
[-0.1935 -0.0136  1.5892]
Time between frames: 0.036148786544799805 seconds
[-0.2155 -0.018   1.5753]
Time between frames: 0.033162593841552734 seconds
[-0.24   -0.0229  1.5577]
Time between frames: 0.03453779220581055 seconds
[-0.267  -0.0285  1.5363]
Time between frames: 0.03453516960144043 seconds
[-0.296  -0.0346  1.5103]
Time between frames: 0.026415109634399414 seconds
[-0.3255 -0.0405  1.4797]
Time between frames: 0.036164283752441406 seconds
[-0.3547 -0.0457  1.4433]
Time between frames: 0.

## Batched Random Init Model

In [5]:
def random_init(data, model, rng: jax.Array):
    nbiomtu = model.nq
    init_qpos = data.qpos
    init_qvel = data.qvel
    new_rng, rng1, rng2 = jax.random.split(rng, 3)
    # Qpos_1 is the vertical position
    random_qpos = init_qpos + jax.random.uniform(rng1, [nbiomtu], minval=jp.array(-1.0), maxval=jp.array(1.0))*0.01
    random_qvel = init_qvel + jax.random.uniform(rng2, [nbiomtu], minval=jp.array(-1.0), maxval=jp.array(1.0))*0.01
    newdata = data.replace(qpos=random_qpos)
    newdata = newdata.replace(qvel=random_qvel)
    newdata = mjx.forward(mjx_model, newdata)
    # print('data:',data.qpos, data.qvel)
    # Calculate equilibrum
    # newdata = acceleration_mtu.calc_equilibrium(mjx_model, newdata)
    # newdata = mjx_step(mjx_model, newdata)
    return newdata, new_rng

vrandom_init = jax.jit(jax.vmap(random_init, in_axes=(None, None, 0),out_axes=(0,0)))

## Reward Function

In [6]:

def nn_step_fn(carry, _):
    data, key, nn_params, model = carry
    new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
    new_carry = (new_data, new_key, nn_params, model)
    head_hight = new_data.sensordata[2]
    # jax.debug.print("Head Height {0}",head_hight)
    return new_carry, (head_hight-1.63)**2
    
def decay_sum_scan(x, decay):
    def f(sxtm1, xt):
        b = xt + decay * sxtm1
        return b, b
    return jax.lax.scan(f, jp.zeros(x.shape[1:]), x)[1]

def reward_n_step(nn_params, model, data, key):
    repeat_n = 150
    _, head_heights = jax.lax.scan(nn_step_fn, (data, key, nn_params, model), None, repeat_n)
    reward = decay_sum_scan(head_heights, 0.95)[repeat_n -1]
    return reward
    
# reward_grad = jax.jit(jax.grad(reward_n_step))


def batch_reward(nn_params,model, batched_data, keys):
    return jp.mean(jax.vmap(reward_n_step, (None, None, 0, 0))(nn_params, model, batched_data, keys))

jit_batch_reward_grad = jax.jit(jax.grad(batch_reward))
jit_batch_reward = jax.jit(batch_reward)

In [7]:
# Generate batched data
batch_size = 40
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
r = jit_batch_reward(params,mjx_model, mjx_data_batch, rngs)
print(r)
print("Calculating reward grad")
g = jit_batch_reward_grad(params, mjx_model, mjx_data_batch, rngs)
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
g = jit_batch_reward_grad(params, mjx_model, mjx_data_batch, rngs)


TRACING CACHE MISS at /tmp/ipykernel_57038/3863415242.py:5:7 (<module>) because:
  for _threefry_split defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/prng.py:1089
explanation unavailable! please open an issue at https://github.com/jax-ml/jax


PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_split' with key 'jit__threefry_split-54245a166cda035886c2cfb90644e0c7d4b3fe774ab6e9aff9abb4c7729787d9'
Not writing persistent cache entry for 'jit__threefry_split' because it took < 5.00 seconds to compile (0.07s)
TRACING CACHE MISS at /tmp/ipykernel_57038/3885834202.py:7:30 (random_init) because:
  for _uniform defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/random.py:405
  never seen input type signature:
    key: key<fry>[],  minval: f32[],  maxval: f32[]
  closest seen input type signature has 2 mismatches, including:
    * at minval, seen f32[]{weak_type=False}, but now given f32[]{weak_type=True}
    * at maxval, seen f32[]{weak_type=False}, but now given f32[]{weak_type=True}
where weak_type=True often means a Python builtin numeric value, and 
weak_type=False means a jax.Array.
See https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types
TRACING CACHE MISS at /tmp/ipykernel_57

0.6107473
Calculating reward grad


TRACING CACHE MISS at /tmp/ipykernel_57038/3863415242.py:10:4 (<module>) because:
  never seen function:
    batch_reward id=134362887627744 defined at /tmp/ipykernel_57038/467779202.py:24


In [11]:
print(r)
print(g)

0.6248491
{'params': {'linear1': {'bias': Array([ 0.0113, -0.0031,  0.0193, -0.0159,  0.0188,  0.0618, -0.0013,  0.0017,  0.0206,  0.0405,
        0.0384, -0.0051,  0.0245,  0.0581,  0.0464, -0.0008, -0.0062, -0.0645, -0.0437,  0.0187,
       -0.0798,  0.0065,  0.0208,  0.0083, -0.0185, -0.0087,  0.0019, -0.0015, -0.0039, -0.0863,
        0.0261,  0.0025, -0.0402,  0.0024, -0.0063,  0.0027,  0.0142,  0.0559,  0.0038, -0.0009,
        0.0037,  0.0047, -0.0024, -0.0058, -0.021 , -0.0021,  0.0076,  0.0282, -0.0099,  0.0153,
        0.0053, -0.0023,  0.0004, -0.0269, -0.0025, -0.0263, -0.0044, -0.0323,  0.0036, -0.0034,
        0.0079, -0.0089, -0.0066,  0.0199,  0.0166, -0.0039, -0.0569,  0.0087,  0.0141, -0.002 ,
       -0.0068, -0.0739, -0.0019,  0.0113,  0.0035,  0.035 , -0.0364,  0.0187, -0.0426,  0.0295,
        0.0047,  0.01  , -0.0425, -0.0551, -0.0104,  0.0183,  0.0057,  0.023 , -0.0059, -0.0228,
       -0.0001,  0.0243,  0.0178,  0.0091,  0.0141, -0.0666, -0.0012,  0.0024, -0.019

## Train the NN

In [8]:
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)

In [9]:
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])


In [None]:
tx = optax.adam(learning_rate=0.0003)
opt_state = tx.init(params)

for i in range(400):
    # generate random mjx_data
    mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
    # print(rngs[0])
    r = jit_batch_reward(params,mjx_model, mjx_data_batch, rngs)
    print(r)
    g = jit_batch_reward_grad(params, mjx_model, mjx_data_batch, rngs)
    updates, opt_stats = tx.update(g, opt_state)
    params = optax.apply_updates(params, updates)
    print("params updated")

TRACING CACHE MISS at /tmp/ipykernel_57038/556588624.py:2:12 (<module>) because:
  never seen function:
    broadcast_in_dim id=134355267734720 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
PERSISTENT COMPILATION CACHE MISS for 'jit_broadcast_in_dim' with key 'jit_broadcast_in_dim-842b0b11f2b82d19136d29f19721e0c0ef79d4653409364adec42bc9daafdb29'
Not writing persistent cache entry for 'jit_broadcast_in_dim' because it took < 5.00 seconds to compile (0.02s)
TRACING CACHE MISS at /tmp/ipykernel_57038/556588624.py:2:12 (<module>) because:
  never seen function:
    broadcast_in_dim id=134372204104736 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  but seen another function defined on the same line; maybe the function is
  being re-defined repeatedly, preventing caching?
PERSISTENT COMPILATION CACHE MISS for 'jit_broadcast_in_dim' with key 'jit_broadcast_in_dim-43a81176cbc2d4bffd3

0.56852835


TRACING CACHE MISS at /tmp/ipykernel_57038/556588624.py:11:25 (<module>) because:
  never seen function:
    integer_pow id=134355319782144 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
PERSISTENT COMPILATION CACHE MISS for 'jit_integer_pow' with key 'jit_integer_pow-a2fc0e4f7e615dace26a72c94da56f86daa798ff1fcd797b07f0549f1d200672'
Not writing persistent cache entry for 'jit_integer_pow' because it took < 5.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__multiply' with key 'jit__multiply-93bd4d6a8f702625430d5f1915361a3f0399e239163ec3a60eee0f4455580ca4'
Not writing persistent cache entry for 'jit__multiply' because it took < 5.00 seconds to compile (0.02s)
PERSISTENT COMPILATION CACHE MISS for 'jit__multiply' with key 'jit__multiply-17150907eed9444c6d938723b37acad7a303ec401b215fb001350d3325cfd3aa'
Not writing persistent cache entry for 'jit__multiply' because it took < 5.00 seconds to compile (0.01s)
PERSI

params updated
0.030625448


PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-12d26b1b4f80e3dca5466173c35a4f35effe8aab62c4a46e04444b1ad9ac1363'
Not writing persistent cache entry for 'jit__add' because it took < 5.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-7b2e1b57c81fe9f806f2a3d68303f26a37e7c0b3def65a30dbc680091a94bf23'
Not writing persistent cache entry for 'jit__add' because it took < 5.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-97e131b1dcc009d8b002386393e85b42600944bcfe0649ae96389058ab7e73cd'
Not writing persistent cache entry for 'jit__add' because it took < 5.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-9973d82a9b58c492540220189967a2f88b575694caa21e8f0979e63481abfae4'
Not writing persistent cache entry for 'jit__add' because it took < 5.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__

params updated
0.112797454
params updated
0.031182308
params updated
0.0022083796
params updated
0.03267385
params updated
0.0005803454
params updated
0.0007010314
params updated
0.0307364
params updated
0.00031705803
params updated
0.0082477005
params updated
0.00074461894
params updated
0.0011942253
params updated
0.00035470075
params updated
0.00019800871
params updated
0.0020576608
params updated
0.00023886915
params updated
0.00036383056
params updated
0.00033621697
params updated
0.0020750028
params updated
0.00020769837
params updated
0.0006475648
params updated
0.00027855087
params updated
0.0015251801
params updated
8.68871e-05
params updated
0.00019563285
params updated
0.0008496833
params updated
6.422465e-05
params updated
0.0008825621


## Test Train

In [20]:
jit_nn_mjx_one_step = jax.jit(nn_mjx_one_step)


In [26]:
import mujoco.viewer
import time


mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)

previous_frame_time = time.time()
i = 0
key = jax.random.key(seed)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        i += 1
        if(i ==4):
            time.sleep(0)
            # pass
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=mj_data.xfrc_applied)
        mjx_data = mjx_data.replace(qpos=mj_data.qpos, qvel=mj_data.qvel, time = mj_data.time)
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': mj_model.opt.gravity,
            'opt.tolerance': mj_model.opt.tolerance,
            'opt.ls_tolerance': mj_model.opt.ls_tolerance,
            'opt.timestep': mj_model.opt.timestep,
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = jit_nn_mjx_one_step(params, mjx_model, mjx_data, key)
        # mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.sensordata[2])
        # print(len(mjx_data.qvel))
        viewer.sync()

Time between frames: 0.16742444038391113 seconds
1.6315
Time between frames: 0.012678384780883789 seconds
1.6314864
Time between frames: 0.013202905654907227 seconds
1.6314447
Time between frames: 0.012907743453979492 seconds
1.6313658
Time between frames: 0.01244211196899414 seconds
1.6312408
Time between frames: 0.010648965835571289 seconds
1.6310611
Time between frames: 0.012441873550415039 seconds
1.6308208
Time between frames: 0.013338565826416016 seconds
1.6305243
Time between frames: 0.012471437454223633 seconds
1.6301804
Time between frames: 0.01353311538696289 seconds
1.6297884
Time between frames: 0.014764785766601562 seconds
1.6293757
Time between frames: 0.013852596282958984 seconds
1.6289992
Time between frames: 0.012926101684570312 seconds
1.6286018
Time between frames: 0.010102272033691406 seconds
1.6282041
Time between frames: 0.011306524276733398 seconds
1.6277759
Time between frames: 0.012139558792114258 seconds
1.627289
Time between frames: 0.0093841552734375 seconds

## Test Gradient for activation

In [14]:
n = mjx_model.nbiomtu
act = jp.ones(n)*0.1

In [None]:
# jax.config.update("jax_debug_nans", True)
for i in range(10):
    g = jit_goal_grad(act, mjx_data, mjx_model)
    # g = goal_grad(act, mjx_data, mjx_model)
print(g)

In [None]:
v_act = jp.ones((50,n))*0.3
for i in range(5):
    v_g = jit_v_goal_grad(v_act,mjx_data,mjx_model)
print(v_g)

In [None]:
for i in range(10):
    reward = jit_goal(act, mjx_data, mjx_model)
print(reward)

In [None]:
act = jp.ones(n)*0.0
mtu = mjx_data.biomtu
mtu = mtu.replace(act = act)
test_d = mjx_data
test_d = test_d.replace(biomtu = mtu)
for i in range(100):
    test_d = mjx_step(mjx_model, test_d)
print(test_d.xipos)

In [None]:
print(test_d.biomtu.f_se)
print(test_d.biomtu.f_ce)
print(test_d.biomtu.l)
print(test_d.biomtu.tendon_l - mjx_model.biomtu_tendon_slack_l)
print(test_d.biomtu.fiber_l - mjx_model.biomtu_ofl)
print(test_d.biomtu.fiber_l)
print(test_d.biomtu.act)

In [None]:
print(mjx_data.xpos)

In [5]:
mjx_model.nbiomtu

54

In [19]:
batch_size = 40
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs2 = vrandom_init(mjx_data, mjx_model, rngs)

In [24]:
print(rngs.shape)
print(rngs2.shape)
print(mjx_data_batch.biomtu.l.shape)

(40,)
(40,)
(40, 54)
