# Simple RL code to teach the model to stand

## Setup jax enviroment and Load Model 

In [1]:
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt


import os
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".60"

# Optionally, force JAX to preallocate memory.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Setup environment variable for Nvidia GPU acceleration
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
)

backend = 'gpu'
# backend = 'METAL'
# backend = 'cpu'

import jax
# Enable compliation catch
jax.config.update("jax_compilation_cache_dir", "./jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 5)
jax.config.update("jax_explain_cache_misses", True)

# Debug Nan
# jax.config.update("jax_debug_nans", True)

from jax import numpy as jp
# More legible printing from numpy.
jp.set_printoptions(precision=4, suppress=True, linewidth=100)

import mujoco
import mujoco.mjx as mjx
from mujoco.mjx._src import scan
from mujoco.mjx._src import types

# More legible printing from numpy.
np.set_printoptions(precision=4, suppress=True, linewidth=100)

from IPython.display import clear_output
clear_output()

device = jax.devices(backend=backend)[0]

model_path = '/home/bugman/Currentwork/biomujoco_converter/converted/mjc/Gait2354/gait2354_cvt1.xml'

# Single step
mjx_step = jax.jit(mjx.step, backend=backend)



# mjx_multiple_steps = jax.jit(multiple_steps, backend=backend, )

In [2]:
from mujoco.mjx._src.biomtu.acceleration_mtu import acceleration_mtu

mj_model = mujoco.MjModel.from_xml_path(model_path)
mjx_model = mjx.put_model(mj_model,device=device)

# Disable tendon
opt = mjx_model.opt.replace(disableflags = mjx_model.opt.disableflags |mujoco.mjtDisableBit.mjDSBL_PASSIVE)
mjx_model = mjx_model.replace(opt=opt)

mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)

def print_all():
    print(mjx_model.biomtu_adr)
    print(mjx_model.mtu_wrap_objid)
    print(mjx_model.mtu_wrap_type)
    print(mjx_model.biomtu_fiso)
    print(mjx_model.biomtu_vmax)
    print(mjx_model.biomtu_ofl)
    print(mjx_model.biomtu_opa)
    print(mjx_model.biomtu_mass)
    print("-------Data--------")
    print("qpos:", mjx_data.qpos)
    print("mtu l:", mjx_data.biomtu.l)
    print("tendon l:", mjx_data.biomtu.tendon_l)
    print("fiber l :", mjx_data.biomtu.fiber_l)
    print("Muscle Bce:", mjx_data.biomtu.B_ce)
    print("Muscle vm:", mjx_data.biomtu.m)
    print("Fiber acc:", mjx_data.biomtu.fiber_acc)
    print("Fiber v:", mjx_data.biomtu.fiber_v)
    print("Biomtu h:", mjx_data.biomtu.h)
    print(mjx_data.biomtu.v)
    print(mjx_data.biomtu.h)  # The constant high of the muscle.
    print(mjx_data.biomtu.pennation_angle)
    print(mjx_data.biomtu.origin_body_id)
    print(mjx_data.biomtu.insertion_body_id)
    print("mtu act:", mjx_data.biomtu.act)
    # print(mjx_data.biomtu.j)
    print(mjx_data.qfrc_biomtu)
    print(mj_model.key_time)
    print(mj_model.key_qpos)
    print(mj_model.key_qvel)

print_all()

TRACING CACHE MISS at /tmp/ipykernel_96018/3517667697.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:97
  never seen input type signature:
    args[0]: f32[46,3]
  closest seen input type signature has 1 mismatches, including:
    * at args[0], seen f32[4,3], but now given f32[46,3]
PERSISTENT COMPILATION CACHE MISS for 'jit_convert_element_type' with key 'jit_convert_element_type-69299ab131d3ee801e01e3f56b4197a59f3ff63fca746edcafe084cca7ab01a9'
Not writing persistent cache entry for 'jit_convert_element_type' because it took < 5.00 seconds to compile (0.00s)
TRACING CACHE MISS at /tmp/ipykernel_96018/3517667697.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:97
  never seen input type signature:
    args[0]: f32[74,5,3]
  closest seen input type signature has 1 mismatches, in

--------------------
BUGMAN: muscle name-glut_med1_r
--------------------
BUGMAN: muscle name-glut_med2_r
--------------------
BUGMAN: muscle name-glut_med3_r
--------------------
BUGMAN: muscle name-bifemlh_r
--------------------
BUGMAN: muscle name-bifemsh_r
--------------------
BUGMAN: muscle name-sar_r
--------------------
BUGMAN: muscle name-add_mag2_r
--------------------
BUGMAN: muscle name-tfl_r
--------------------
BUGMAN: muscle name-pect_r
--------------------
BUGMAN: muscle name-grac_r
--------------------
BUGMAN: muscle name-glut_max1_r
--------------------
BUGMAN: muscle name-glut_max2_r
--------------------
BUGMAN: muscle name-glut_max3_r
--------------------
BUGMAN: muscle name-iliacus_r
--------------------
BUGMAN: muscle name-psoas_r
--------------------
BUGMAN: muscle name-quad_fem_r
--------------------
BUGMAN: muscle name-gem_r
--------------------
BUGMAN: muscle name-peri_r
--------------------
BUGMAN: muscle name-rect_fem_r
--------------------
BUGMAN: muscle nam

Not writing persistent cache entry for 'jit_convert_element_type' because it took < 5.00 seconds to compile (0.00s)
TRACING CACHE MISS at /tmp/ipykernel_96018/3517667697.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:97
  never seen input type signature:
    args[0]: f32[85,4,3]
  closest seen input type signature has 1 mismatches, including:
    * at args[0], seen f32[4,3], but now given f32[85,4,3]
PERSISTENT COMPILATION CACHE MISS for 'jit_convert_element_type' with key 'jit_convert_element_type-dd3984c1dd5349d87f65cbd3e0a016bd08832b9ae07aa38da6a3d92965ef29aa'
Not writing persistent cache entry for 'jit_convert_element_type' because it took < 5.00 seconds to compile (0.00s)
TRACING CACHE MISS at /tmp/ipykernel_96018/3517667697.py:4:12 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py

[  0   2   4   6   9  12  17  19  23  25  29  33  37  41  46  51  53  55  58  61  65  68  70  74
  77  79  81  83  86  89  94  96 100 102 106 110 114 118 123 128 130 132 135 138 142 145 147 151
 154 156 158 160 162 164]
[  0  58   1  59   2  60   3  86  87  61  88  89   4  62  90  91  92   5  63   6  64  65  93   7
  66   8  94  95  96   9  10  67  68  11  12  69  70  13  14  71  72  15  16  17  73  74  18  19
  20  75  76  21  77  22  78  23  24  79  25  80  97  81  82  83  98  84  85 104  99 105 100 101
 106 107 102 103 108  26 109  27 110  28 111  29 137 138 112 139 140  30 113 141 142 143  31 114
  32 115 116 144  33 117  34 145 146 147  35  36 118 119  37  38 120 121  39  40 122 123  41  42
  43 124 125  44  45  46 126 127  47 128  48 129  49  50 130  51 131 148 132 133 134 149 135 136
 155 150 156 151 152 157 158 153 154 159  52 160  53 161  54 162  55 163  56 164  57 165]
[3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 

In [3]:
print(mjx_model.nbiomtu)
print(mjx_model.nq)
print(mjx_data.qpos)
print(mjx_data.qvel)

54
27
[ 0.      0.9501  0.      0.0006  0.     -0.     -0.0013 -0.     -0.0009 -0.0037 -0.3957  0.0023
 -0.0199 -0.0033  0.0228 -0.0013 -0.     -0.0009 -0.0037 -0.3957  0.0023 -0.02   -0.0033  0.0228
 -0.0006 -0.      0.    ]
[ 0.0149  0.0271  0.0001  0.3058  0.0001 -0.0008 -0.6684 -0.0194 -0.4337 -0.0074  0.0056  1.1325
 -9.9726 -1.6341 11.3898 -0.6699 -0.0193 -0.4359 -0.0074  0.0056  1.1354 -9.9759 -1.6319 11.3917
 -0.325  -0.0001  0.0008]


## Neural Network

In [4]:
import jax
import jax.numpy as jp
import flax
import flax.linen as nn
import optax

class Controller_NN(nn.Module):
    # It will output the mean and variance for each muscle's activation
    nbiomtu = 54
    def setup(self):
        # Features means the output dimension# Single step

        self.linear1 = nn.Dense(features=400)
        self.linear2 = nn.Dense(features=400)
        self.linear3 = nn.Dense(features=400)
        self.linear4 = nn.Dense(features=400)
        # The last layer will output the mean and logstd
        self.linear5 = nn.Dense(features=self.nbiomtu*2)
        
    
    def __call__(self, x, key):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        x = nn.relu(x)
        x = self.linear3(x)
        x = nn.relu(x)
        x = self.linear4(x)
        x = nn.relu(x)
        x = self.linear5(x)
        # The last layer of the neural requires samping
        mean = x[0:self.nbiomtu]
        logstd = x[self.nbiomtu:self.nbiomtu*2]
        std = jp.exp(logstd)
        samples = jp.clip(jax.random.normal(key)*std*0.3 + mean, 0 ,1)
        
        return samples, mean, logstd


# Test the neural network
control_model = Controller_NN()

print(control_model)
# Init the model
key = jax.random.PRNGKey(66)
sub_keys = jax.random.split(key,1)
# The second parameter is the dommy input
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])
# print(params)
print(control_model.apply( params,jp.ones(mjx_model.nq*2), sub_keys[0]))
jit_nn_apply = jax.jit(lambda params,states,key : control_model.apply(params,states,key))

PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_seed' with key 'jit__threefry_seed-f53b1596a7a0fef0e144b325edde2bb71cf806ffcdbc50f0ebe41f8c4d8354b8'
Not writing persistent cache entry for 'jit__threefry_seed' because it took < 5.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_96018/1754672284.py:46:11 (<module>) because:
  never seen function:
    _threefry_split id=130795143702752 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/prng.py:1085
PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_split' with key 'jit__threefry_split-4d258143e8ef85ee3942d85c53921955611d45d2a27cc72a934a9f0bbf5bb0e6'
Not writing persistent cache entry for 'jit__threefry_split' because it took < 5.00 seconds to compile (0.04s)
TRACING CACHE MISS at /tmp/ipykernel_96018/1754672284.py:48:32 (<module>) because:
  never seen function:
    broadcast_in_dim id=130792568034368 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/

Controller_NN()


Not writing persistent cache entry for 'jit_dynamic_slice' because it took < 5.00 seconds to compile (0.00s)
TRACING CACHE MISS at /tmp/ipykernel_96018/1754672284.py:48:61 (<module>) because:
  for squeeze defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:97
  never seen input type signature:
    args[0]: u32[1,2]
  closest seen input type signature has 1 mismatches, including:
    * at args[0], seen f32[1,4], but now given u32[1,2]
PERSISTENT COMPILATION CACHE MISS for 'jit_squeeze' with key 'jit_squeeze-f1a448b2ae7803941ebbade835e9a3da0ceb2a4326f90a0c073282e1808d9983'
Not writing persistent cache entry for 'jit_squeeze' because it took < 5.00 seconds to compile (0.00s)
TRACING CACHE MISS at /tmp/ipykernel_96018/1754672284.py:48:9 (<module>) because:
  never seen function:
    _is_valid_rng.<locals>.<lambda> id=130792568044608 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/flax/core/scope.py:1230
TRACING C

(Array([0.3537, 0.4582, 0.    , 0.    , 0.3568, 0.    , 0.    , 0.    , 0.    , 0.0299, 0.1213,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.0104, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.4784, 0.    , 0.2003, 0.    , 0.    ,
       0.    , 0.    , 0.2286, 0.    , 0.    , 0.    , 0.    , 0.2364, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    ],      dtype=float32), Array([ 0.5507,  0.7117,  0.1663, -0.386 ,  0.572 ,  0.0428,  0.151 , -0.341 , -0.1592,  0.1468,
        0.3125, -0.506 , -0.098 , -0.1812, -0.2472, -0.1027, -0.2657, -0.0098,  0.1604,  0.1802,
       -0.4265, -0.2269, -0.2294, -0.0505,  0.1361,  0.0374,  0.0641, -0.0512,  0.6791, -0.3706,
        0.4092, -0.1719, -0.2406, -0.022 ,  0.1271,  0.4203,  0.0284,  0.0943, -0.4435,  0.1295,
        0.4245,  0.0444, -0.1952,  0.121 , -0.08  , -0.3598, -0.0973, -0.0134,  0.0238, -0.2089,
       -0.5425, -0.3392, 

## Combine Neural Net and Simulation into one Jax function

In [5]:
# Multiple steps
def step_fn(carry, _):
    data, model = carry
    new_data = mjx.step(model, data)
    new_carry = (new_data, model)
    return new_carry, _

def multiple_steps(model, data):
    init_carry = (data, model)
    y, _ = jax.lax.scan(step_fn, init_carry, None, length=10)
    new_data = y[0]
    return new_data

def nn_mjx_one_step(nn_params, model, data, key):
    states = jp.concatenate([data.qpos, data.qvel])
    act = jit_nn_apply(nn_params, states, key)[0]
    # Generate the next key
    new_key = jax.random.split(key,1)[0]
    mtu = data.biomtu
    mtu = mtu.replace(act = act)
    data = data.replace(biomtu = mtu)
    new_data = mjx.step(model, data)
    return new_data, new_key

@jax.jit
def nn_mjx_multi_steps(nn_params, model, data, key):
    def nn_step_fn(carry, _):
        data, key = carry
        new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
        new_carry = (new_data, new_key)
        # head_hight = new_data.xpos[2,2]
        # jax.debug.print("Head Height {0}",head_hight)
        return new_carry, _
    init_carry = (data, key)
    y, _ = jax.lax.scan(nn_step_fn, init_carry, None, length=10)
    new_data = y[0]
    new_key = y[1]
    return new_data, new_key

## Testing control model with neural networks

In [6]:
import mujoco.viewer
import time

previous_frame_time = time.time()
i = 0
key = jax.random.key(334)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=jp.array(mj_data.xfrc_applied, dtype=jp.float32))
        mjx_data = mjx_data.replace(
            qpos= jp.array(mj_data.qpos, dtype=jp.float32),
            qvel= jp.array(mj_data.qvel, dtype=jp.float32),
            time = jp.array(mj_data.time, dtype=jp.float32))
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': jp.array(mj_model.opt.gravity, dtype=jp.float32),
            'opt.tolerance': jp.array(mj_model.opt.tolerance, dtype=jp.float32),
            'opt.ls_tolerance': jp.array(mj_model.opt.ls_tolerance, dtype=jp.float32),
            'opt.timestep': jp.array(mj_model.opt.timestep, dtype=jp.float32),
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.xpos)
        # print(len(mjx_data.qvel))
        viewer.sync()

TRACING CACHE MISS at /tmp/ipykernel_47096/3218733631.py:14:49 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:97
  never seen input type signature:
    args[0]: f32[14,6]
  closest seen input type signature has 1 mismatches, including:
    * at args[0], seen f32[4,3], but now given f32[14,6]
PERSISTENT COMPILATION CACHE MISS for 'jit_convert_element_type' with key 'jit_convert_element_type-73b1baa832c2654eb280b01b3205d6dc47f33bcae6f45a4a79d9b1007d4844af'
Not writing persistent cache entry for 'jit_convert_element_type' because it took < 2.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_47096/3218733631.py:16:18 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:97
  never seen input type signature:
    args[0]: f32[27]
  closest seen input type signature has 1 mismatches, incl

Time between frames: 4.256472110748291 seconds
[[ 0.      0.      0.    ]
 [ 0.      0.      0.    ]
 [ 0.0013  0.0001  0.9477]
 [-0.0704 -0.0827  0.8819]
 [-0.0789 -0.0851  0.4866]
 [-0.0716 -0.0889  0.0567]
 [-0.1221 -0.0937  0.0163]
 [ 0.0561 -0.0832  0.0059]
 [-0.0695  0.0842  0.8825]
 [-0.0766  0.0846  0.4871]
 [-0.0731  0.0858  0.0572]
 [-0.1228  0.0904  0.0157]
 [ 0.0557  0.0826  0.0091]
 [-0.0987  0.0004  1.0301]]
Time between frames: 1.8994770050048828 seconds
[[ 0.      0.      0.    ]
 [ 0.      0.      0.    ]
 [ 0.007   0.0008  0.9418]
 [-0.0665 -0.0788  0.874 ]
 [-0.0915 -0.0848  0.4799]
 [-0.0752 -0.0983  0.0504]
 [-0.1203 -0.1001  0.0039]
 [ 0.0567 -0.0786  0.0172]
 [-0.0639  0.088   0.881 ]
 [-0.0843  0.0837  0.4867]
 [-0.0725  0.0866  0.0569]
 [-0.1211  0.084   0.014 ]
 [ 0.0577  0.0815  0.0132]
 [-0.0918 -0.0012  1.0256]]
Time between frames: 0.0184781551361084 seconds
[[ 0.      0.      0.    ]
 [ 0.      0.      0.    ]
 [ 0.0146  0.003   0.9371]
 [-0.0562 -0.077  

## Batched Random Init Model

In [6]:
def random_init(data, model, rng: jax.Array):
    nbiomtu = model.nq
    init_qpos = data.qpos
    init_qvel = data.qvel
    new_rng, rng1, rng2 = jax.random.split(rng, 3)
    # Qpos_1 is the vertical position
    random_qpos = init_qpos + jax.random.uniform(rng1, [nbiomtu], minval=-1.0, maxval=1.0)*0.01
    random_qvel = init_qvel + jax.random.uniform(rng2, [nbiomtu], minval=-1.0, maxval=1.0)*0.01
    newdata = data.replace(qpos=random_qpos)
    newdata = newdata.replace(qvel=random_qvel)
    newdata = mjx.forward(mjx_model, newdata)
    # print('data:',data.qpos, data.qvel)
    return newdata, new_rng

vrandom_init = jax.jit(jax.vmap(random_init, in_axes=(None, None, 0),out_axes=(0,0)))

## Reward Function

In [7]:
def nn_step_fn(carry, _):
        data, key, nn_params, model = carry
        new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
        new_carry = (new_data, new_key, nn_params, model)
        head_hight = new_data.xpos[2,2]
        # jax.debug.print("Head Height {0}",head_hight)
        return new_carry, (head_hight-0.94)**2
    
def decay_sum_scan(x, decay):
    def f(sxtm1, xt):
        b = xt + decay * sxtm1
        return b, b
    return jax.lax.scan(f, jp.zeros(x.shape[1:]), x)[1]

def reward_n_step(nn_params, model, data, key):
    repeat_n = 150
    _, head_heights = jax.lax.scan(nn_step_fn, (data, key, nn_params, model), None, repeat_n)
    reward = decay_sum_scan(head_heights, 0.95)[repeat_n -1]
    return reward
    
# reward_grad = jax.jit(jax.grad(reward_n_step))


def batch_reward(nn_params,model, batched_data, keys):
    return jp.mean(jax.vmap(reward_n_step, (None, None, 0, 0))(nn_params, model, batched_data, keys))

jit_batch_reward_grad = jax.jit(jax.grad((batch_reward)))
jit_batch_reward = jax.jit(batch_reward)

In [8]:
# Generate batched data
batch_size = 40
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)

print("Calculating reward grad")
r = jit_batch_reward(params,mjx_model, mjx_data_batch, rngs)
print(r)
g = jit_batch_reward_grad(params, mjx_model, mjx_data_batch, rngs)
mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
g = jit_batch_reward_grad(params, mjx_model, mjx_data_batch, rngs)


TRACING CACHE MISS at /tmp/ipykernel_96018/3492614587.py:5:7 (<module>) because:
  for _threefry_split defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/prng.py:1085
explanation unavailable! please open an issue at https://github.com/jax-ml/jax


PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_split' with key 'jit__threefry_split-c0a3e0515d2dc54a05f94aa3b992cc376fd3ebd48763c45ef61d602bacc3daf0'
Not writing persistent cache entry for 'jit__threefry_split' because it took < 5.00 seconds to compile (0.08s)
TRACING CACHE MISS at /tmp/ipykernel_96018/1679496683.py:7:30 (random_init) because:
  for _uniform defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/random.py:405
  never seen input type signature:
    key: key<fry>[],  minval: f32[],  maxval: f32[]
  closest seen input type signature has 2 mismatches, including:
    * at minval, seen f32[]{weak_type=False}, but now given f32[]{weak_type=True}
    * at maxval, seen f32[]{weak_type=False}, but now given f32[]{weak_type=True}
where weak_type=True often means a Python builtin numeric value, and 
weak_type=False means a jax.Array.
See https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types
TRACING CACHE MISS at /tmp/ipykernel_96

Calculating reward grad


TRACING CACHE MISS at /tmp/ipykernel_96018/3492614587.py:9:4 (<module>) because:
  never seen function:
    batch_reward id=130792566846336 defined at /tmp/ipykernel_96018/1928993877.py:24
TRACING CACHE MISS at /tmp/ipykernel_96018/2944119200.py:18:14 (nn_mjx_one_step) because:
  for _threefry_split defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/prng.py:1085
  tracing context doesn't match, e.g. due to config or context manager
  closest seen context tuple differs at positions:
    3
  compare to tuple returned by config._trace_context() in jax/_src/config.py.
TRACING CACHE MISS at /tmp/ipykernel_96018/1754672284.py:23:12 (Controller_NN.__call__) because:
  for relu defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/nn/functions.py:50
  tracing context doesn't match, e.g. due to config or context manager
  closest seen context tuple differs at positions:
    3
  compare to tuple returned by config._trace_contex

0.0011116982


TRACING CACHE MISS at /tmp/ipykernel_96018/3492614587.py:11:4 (<module>) because:
  never seen function:
    batch_reward id=130792566847136 defined at /tmp/ipykernel_96018/1928993877.py:24


In [8]:
for i in range(10):
    g = batch_reward_grad(params, mjx_data_batch, rngs)

In [12]:
for i in range(10):
    r = jit_batch_reward(params, mjx_data_batch, rngs)

TypeError: batch_reward() missing 1 required positional argument: 'keys'

In [9]:
print(r)
print(g)

0.0011116982
{'params': {'linear1': {'bias': Array([-0.0001,  0.    , -0.0002,  0.0001, -0.    ,  0.    ,  0.    , -0.0001, -0.0004, -0.0001,
       -0.0002, -0.0003,  0.    ,  0.0002, -0.0003, -0.    , -0.    ,  0.0002,  0.0003, -0.0001,
        0.0002, -0.0001, -0.0001,  0.0001,  0.0002, -0.0001, -0.0004,  0.    , -0.0002,  0.0001,
       -0.0003, -0.    ,  0.0002,  0.0001, -0.0001, -0.0002,  0.    , -0.0002,  0.    , -0.    ,
        0.0002, -0.    ,  0.    , -0.    ,  0.    ,  0.0002, -0.    , -0.0004, -0.    ,  0.    ,
        0.0002,  0.0002,  0.    ,  0.0001, -0.    ,  0.0002,  0.0001,  0.0001, -0.0001,  0.0005,
        0.0001,  0.0001,  0.0002,  0.0002, -0.0003,  0.    ,  0.0005,  0.0002, -0.    ,  0.0001,
       -0.    ,  0.0003,  0.    ,  0.0001,  0.0002,  0.0006,  0.0001, -0.0001,  0.    ,  0.0001,
        0.0003,  0.0001, -0.    , -0.0004, -0.    ,  0.0003,  0.    , -0.0001, -0.0001, -0.0001,
        0.    , -0.0001, -0.0002, -0.0001, -0.0001,  0.0007, -0.    , -0.    , -0.

## Train the NN

In [10]:
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])


In [None]:
tx = optax.adam(learning_rate=0.0005)
opt_state = tx.init(params)

for i in range(200):
    # generate random mjx_data
    mjx_data_batch, rngs = vrandom_init(mjx_data, mjx_model, rngs)
    # print(rngs[0])
    r = jit_batch_reward(params,mjx_model, mjx_data_batch, rngs)
    print(r)
    g = jit_batch_reward_grad(params, mjx_model, mjx_data_batch, rngs)
    updates, opt_stats = tx.update(g, opt_state)
    params = optax.apply_updates(params, updates)
    print("params updated")

TRACING CACHE MISS at /tmp/ipykernel_96018/1710814893.py:2:12 (<module>) because:
  never seen function:
    broadcast_in_dim id=130793931337024 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:97
PERSISTENT COMPILATION CACHE MISS for 'jit_broadcast_in_dim' with key 'jit_broadcast_in_dim-00d312b05aa0bb26fb316548c337fc14038e8b7117a405aa136c1b375ac27e05'
Not writing persistent cache entry for 'jit_broadcast_in_dim' because it took < 5.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_96018/1710814893.py:2:12 (<module>) because:
  never seen function:
    broadcast_in_dim id=130790094130368 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:97
  but seen another function defined on the same line; maybe the function is
  being re-defined repeatedly, preventing caching?
PERSISTENT COMPILATION CACHE MISS for 'jit_broadcast_in_dim' with key 'jit_broadcast_in_dim-231e697849ae6a1d3

0.0048969346


TRACING CACHE MISS at /tmp/ipykernel_96018/1710814893.py:11:25 (<module>) because:
  never seen function:
    integer_pow id=130790094136608 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:97
PERSISTENT COMPILATION CACHE MISS for 'jit_integer_pow' with key 'jit_integer_pow-2e2ff22179104b0ccbe90b614e2ca76b4470eb3e9d794572b242935ab21c5fb5'
Not writing persistent cache entry for 'jit_integer_pow' because it took < 5.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__multiply' with key 'jit__multiply-15324d8f3ee0132774e5767e945abb9154285c0d7362a75bae4977e8f0200719'
Not writing persistent cache entry for 'jit__multiply' because it took < 5.00 seconds to compile (0.02s)
PERSISTENT COMPILATION CACHE MISS for 'jit__multiply' with key 'jit__multiply-c62223e1411a54d08ecddbeb1b46eb02739a22e8e88f2b20d90aad3483630469'
Not writing persistent cache entry for 'jit__multiply' because it took < 5.00 seconds to compile (0.01s)
PERS

params updated
0.02267608


PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-aeedc25992dd0c08112239692eede1e0d6fbc219b41c1560191279447199a626'
Not writing persistent cache entry for 'jit__add' because it took < 5.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-148bf9ef75f39556ebc84344a5bac3c817c122d27a566ca3d6e6bb9358d3a2c7'
Not writing persistent cache entry for 'jit__add' because it took < 5.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-336c211901bfc000e8b8fd19f04b8019514a0437862afc5a81f6c588c346bacb'
Not writing persistent cache entry for 'jit__add' because it took < 5.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__add-808a03a8940657b3cd7e21029473914a2303fd2165f23b9310ab07699fcb8265'
Not writing persistent cache entry for 'jit__add' because it took < 5.00 seconds to compile (0.01s)
PERSISTENT COMPILATION CACHE MISS for 'jit__add' with key 'jit__

params updated
0.030924765
params updated
0.0012783896
params updated
0.0023813196
params updated
0.002797855
params updated
0.006820481
params updated
0.0029855794
params updated
0.00057281525
params updated
0.009325516
params updated
0.002782182
params updated
0.0034432642
params updated
0.012482496
params updated
0.0015386792
params updated
0.005845227
params updated
0.00049473584
params updated
0.0005971844


## Test Train

In [13]:
jit_nn_mjx_one_step = jax.jit(nn_mjx_one_step)


In [24]:
import mujoco.viewer
import time


mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
mjx_data = mjx_step(mjx_model, mjx_data)

previous_frame_time = time.time()
i = 0
key = jax.random.key(seed)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        i += 1
        if(i ==4):
            time.sleep(5)
            # pass
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=mj_data.xfrc_applied)
        mjx_data = mjx_data.replace(qpos=mj_data.qpos, qvel=mj_data.qvel, time = mj_data.time)
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': mj_model.opt.gravity,
            'opt.tolerance': mj_model.opt.tolerance,
            'opt.ls_tolerance': mj_model.opt.ls_tolerance,
            'opt.timestep': mj_model.opt.timestep,
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = jit_nn_mjx_one_step(params, mjx_model, mjx_data, key)
        # mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.xpos[2,2])
        # print(len(mjx_data.qvel))
        viewer.sync()

Time between frames: 0.2678225040435791 seconds
0.9499999
Time between frames: 0.04675865173339844 seconds
0.9500231
Time between frames: 0.011263132095336914 seconds
0.9499774
Time between frames: 5.012398958206177 seconds
0.94981
Time between frames: 0.011791467666625977 seconds
0.94967103
Time between frames: 0.011698007583618164 seconds
0.9495986
Time between frames: 0.009348154067993164 seconds
0.949608
Time between frames: 0.009460687637329102 seconds
0.9496927
Time between frames: 0.011826515197753906 seconds
0.9498546
Time between frames: 0.010198831558227539 seconds
0.9499677
Time between frames: 0.010564327239990234 seconds
0.9496804
Time between frames: 0.008799314498901367 seconds
0.9490548
Time between frames: 0.01104426383972168 seconds
0.94851875
Time between frames: 0.012903451919555664 seconds
0.9478325
Time between frames: 0.010254383087158203 seconds
0.94725096
Time between frames: 0.010094881057739258 seconds
0.94683576
Time between frames: 0.012192964553833008 seco

## Test Gradient for activation

In [14]:
n = mjx_model.nbiomtu
act = jp.ones(n)*0.1

In [28]:
# jax.config.update("jax_debug_nans", True)
for i in range(10):
    g = jit_goal_grad(act, mjx_data, mjx_model)
    # g = goal_grad(act, mjx_data, mjx_model)
print(g)

[ 0.0029  0.0177  0.0011  0.0183  0.0053  0.0064  0.0115  0.02    0.0023  0.0084  0.0302  0.0048
  0.0306  0.0018  0.0038  0.0019  0.0006  0.0057  0.0181  0.0067 -0.0016  0.0128 -0.0306  0.0814
  0.0015  0.0042  0.0169  0.1057  0.0211  0.005   0.0255  0.006   0.021   0.0105  0.0301 -0.0111
  0.0027  0.0128  0.0167  0.0024  0.005   0.0058 -0.0179 -0.1854  0.0519  0.0019  0.0187  0.0343
  0.0252  0.0251 -0.0008  0.0025  0.0002  0.0114]


In [36]:
v_act = jp.ones((50,n))*0.3
for i in range(5):
    v_g = jit_v_goal_grad(v_act,mjx_data,mjx_model)
print(v_g)

2024-10-23 15:14:45.445644: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 6.02GiB (6467407612 bytes) by rematerialization; only reduced to 12.31GiB (13220408549 bytes), down from 12.31GiB (13220408549 bytes) originally
2024-10-23 15:15:46.492778: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_0_bfc) ran out of memory trying to allocate 12.31GiB (rounded to 13220816896)requested by op 
2024-10-23 15:15:46.493013: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] *___________________________________________________________________________________________________
E1023 15:15:46.493063 2100274 pjrt_stream_executor_client.cc:3084] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13220816736 bytes.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13220816736 bytes.

In [27]:
for i in range(10):
    reward = jit_goal(act, mjx_data, mjx_model)
print(reward)

8.035784


In [None]:
act = jp.ones(n)*0.0
mtu = mjx_data.biomtu
mtu = mtu.replace(act = act)
test_d = mjx_data
test_d = test_d.replace(biomtu = mtu)
for i in range(100):
    test_d = mjx_step(mjx_model, test_d)
print(test_d.xipos)

In [None]:
print(test_d.biomtu.f_se)
print(test_d.biomtu.f_ce)
print(test_d.biomtu.l)
print(test_d.biomtu.tendon_l - mjx_model.biomtu_tendon_slack_l)
print(test_d.biomtu.fiber_l - mjx_model.biomtu_ofl)
print(test_d.biomtu.fiber_l)
print(test_d.biomtu.act)

In [64]:
print(mjx_data.xpos)

[[ 0.      0.      0.    ]
 [ 0.      0.      0.    ]
 [ 0.     -0.0002  0.95  ]
 [-0.0707 -0.0837  0.8839]
 [-0.0743 -0.0836  0.4882]
 [-0.0743 -0.0835  0.0582]
 [-0.1231 -0.0914  0.0162]
 [ 0.0557 -0.0925  0.0142]
 [-0.0707  0.0833  0.8839]
 [-0.0743  0.0834  0.4882]
 [-0.0743  0.0835  0.0582]
 [-0.1231  0.0914  0.0163]
 [ 0.0557  0.0925  0.0143]
 [-0.1007 -0.0002  1.0315]]


In [5]:
mjx_model.nbiomtu

54

In [19]:
batch_size = 40
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, rngs2 = vrandom_init(mjx_data, mjx_model, rngs)

In [24]:
print(rngs.shape)
print(rngs2.shape)
print(mjx_data_batch.biomtu.l.shape)

(40,)
(40,)
(40, 54)
