# Simple RL code to teach the model to stand

## Setup jax enviroment and Load Model 

In [1]:
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt


import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".60"

# Optionally, force JAX to preallocate memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Setup environment variable for Nvidia GPU acceleration
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
)

backend = 'gpu'
# backend = 'METAL'
# backend = 'cpu'

import jax
# Enable compliation catch
jax.config.update("jax_compilation_cache_dir", "./jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# Debug Nan
# jax.config.update("jax_debug_nans", True)

from jax import numpy as jp
# More legible printing from numpy.
jp.set_printoptions(precision=4, suppress=True, linewidth=100)

import mujoco
import mujoco.mjx as mjx
from mujoco.mjx._src import scan
from mujoco.mjx._src import types

# More legible printing from numpy.
np.set_printoptions(precision=4, suppress=True, linewidth=100)

from IPython.display import clear_output
clear_output()

device = jax.devices(backend=backend)[0]

model_path = '/home/bugman/Currentwork/biomujoco_converter/converted/mjc/Gait2354/gait2354_cvt1.xml'

# Single step
mjx_step = jax.jit(mjx.step, backend=backend)



# mjx_multiple_steps = jax.jit(multiple_steps, backend=backend, )

In [2]:
from mujoco.mjx._src.biomtu.acceleration_mtu import acceleration_mtu

mj_model = mujoco.MjModel.from_xml_path(model_path)
mjx_model = mjx.put_model(mj_model,device=device)

# Disable tendon
opt = mjx_model.opt.replace(disableflags = mjx_model.opt.disableflags |mujoco.mjtDisableBit.mjDSBL_PASSIVE)
mjx_model = mjx_model.replace(opt=opt)

mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)

def print_all():
    print(mjx_model.biomtu_adr)
    print(mjx_model.mtu_wrap_objid)
    print(mjx_model.mtu_wrap_type)
    print(mjx_model.biomtu_fiso)
    print(mjx_model.biomtu_vmax)
    print(mjx_model.biomtu_ofl)
    print(mjx_model.biomtu_opa)
    print(mjx_model.biomtu_mass)
    print("-------Data--------")
    print("qpos:", mjx_data.qpos)
    print("mtu l:", mjx_data.biomtu.l)
    print("tendon l:", mjx_data.biomtu.tendon_l)
    print("fiber l :", mjx_data.biomtu.fiber_l)
    print("Muscle Bce:", mjx_data.biomtu.B_ce)
    print("Muscle vm:", mjx_data.biomtu.m)
    print("Fiber acc:", mjx_data.biomtu.fiber_acc)
    print("Fiber v:", mjx_data.biomtu.fiber_v)
    print("Biomtu h:", mjx_data.biomtu.h)
    print(mjx_data.biomtu.v)
    print(mjx_data.biomtu.h)  # The constant high of the muscle.
    print(mjx_data.biomtu.pennation_angle)
    print(mjx_data.biomtu.origin_body_id)
    print(mjx_data.biomtu.insertion_body_id)
    print("mtu act:", mjx_data.biomtu.act)
    # print(mjx_data.biomtu.j)
    print(mjx_data.qfrc_biomtu)
    print(mj_model.key_time)
    print(mj_model.key_qpos)
    print(mj_model.key_qvel)

print_all()

--------------------
BUGMAN: muscle name-glut_med1_r
--------------------
BUGMAN: muscle name-glut_med2_r
--------------------
BUGMAN: muscle name-glut_med3_r
--------------------
BUGMAN: muscle name-bifemlh_r
--------------------
BUGMAN: muscle name-bifemsh_r
--------------------
BUGMAN: muscle name-sar_r
--------------------
BUGMAN: muscle name-add_mag2_r
--------------------
BUGMAN: muscle name-tfl_r
--------------------
BUGMAN: muscle name-pect_r
--------------------
BUGMAN: muscle name-grac_r
--------------------
BUGMAN: muscle name-glut_max1_r
--------------------
BUGMAN: muscle name-glut_max2_r
--------------------
BUGMAN: muscle name-glut_max3_r
--------------------
BUGMAN: muscle name-iliacus_r
--------------------
BUGMAN: muscle name-psoas_r
--------------------
BUGMAN: muscle name-quad_fem_r
--------------------
BUGMAN: muscle name-gem_r
--------------------
BUGMAN: muscle name-peri_r
--------------------
BUGMAN: muscle name-rect_fem_r
--------------------
BUGMAN: muscle nam

In [3]:
print(mjx_model.nbiomtu)
print(mjx_model.nq)
print(mjx_data.qpos)
print(mjx_data.qvel)

54
27
[ 0.      0.95    0.      0.      0.      0.      0.      0.      0.     -0.0036 -0.3957  0.
  0.      0.      0.      0.      0.      0.     -0.0036 -0.3957  0.      0.      0.      0.
  0.      0.      0.    ]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


## Neural Network

In [4]:
import jax
import jax.numpy as jp
import flax
import flax.linen as nn
import optax

class Controller_NN(nn.Module):
    # It will output the mean and variance for each muscle's activation
    nbiomtu = mjx_model.nbiomtu
    def setup(self):
        # Features means the output dimension# Single step

        self.linear1 = nn.Dense(features=400)
        self.linear2 = nn.Dense(features=400)
        self.linear3 = nn.Dense(features=400)
        self.linear4 = nn.Dense(features=400)
        # The last layer will output the mean and logstd
        self.linear5 = nn.Dense(features=self.nbiomtu*2)
        
    
    def __call__(self, x, key):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        x = nn.relu(x)
        x = self.linear3(x)
        x = nn.relu(x)
        x = self.linear4(x)
        x = nn.relu(x)
        x = self.linear5(x)
        # The last layer of the neural requires samping
        mean = x[0:self.nbiomtu]
        logstd = x[self.nbiomtu:self.nbiomtu*2]
        std = jp.exp(logstd)
        samples = jp.clip(jax.random.normal(key)*std*0.3 + mean, 0 ,1)
        
        return samples, mean, logstd


# Test the neural network
control_model = Controller_NN()

print(control_model)
# Init the model
key = jax.random.PRNGKey(66)
sub_keys = jax.random.split(key,1)
# The second parameter is the dommy input
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])
# print(params)
print(control_model.apply( params,jp.ones(mjx_model.nq*2), sub_keys[0]))
jit_nn_apply = jax.jit(lambda params,states,key : control_model.apply(params,states,key))

Controller_NN()
(Array([0.3537, 0.4582, 0.    , 0.    , 0.3568, 0.    , 0.    , 0.    , 0.    , 0.0299, 0.1213,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.0104, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.4784, 0.    , 0.2003, 0.    , 0.    ,
       0.    , 0.    , 0.2286, 0.    , 0.    , 0.    , 0.    , 0.2364, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    ],      dtype=float32), Array([ 0.5507,  0.7117,  0.1663, -0.386 ,  0.572 ,  0.0428,  0.151 , -0.341 , -0.1592,  0.1468,
        0.3125, -0.506 , -0.098 , -0.1812, -0.2472, -0.1027, -0.2657, -0.0098,  0.1604,  0.1802,
       -0.4265, -0.2269, -0.2294, -0.0505,  0.1361,  0.0374,  0.0641, -0.0512,  0.6791, -0.3706,
        0.4092, -0.1719, -0.2406, -0.022 ,  0.1271,  0.4203,  0.0284,  0.0943, -0.4435,  0.1295,
        0.4245,  0.0444, -0.1952,  0.121 , -0.08  , -0.3598, -0.0973, -0.0134,  0.0238, -0.2089,
       -0

## Combine Neural Net and Simulation into one Jax function

In [5]:
# Multiple steps
def step_fn(carry, _):
    data, model = carry
    new_data = mjx.step(model, data)
    new_carry = (new_data, model)
    return new_carry, _

def multiple_steps(model, data):
    init_carry = (data, model)
    y, _ = jax.lax.scan(step_fn, init_carry, None, length=10)
    new_data = y[0]
    return new_data

def nn_mjx_one_step(nn_params, model, data, key):
    states = jp.concatenate([data.qpos, data.qvel])
    act = jit_nn_apply(nn_params, states, key)[0]
    # Generate the next key
    new_key = jax.random.split(key,1)[0]
    mtu = data.biomtu
    mtu = mtu.replace(act = act)
    data = data.replace(biomtu = mtu)
    new_data = mjx.step(model, data)
    return new_data, new_key

@jax.jit
def nn_mjx_multi_steps(nn_params, model, data, key):
    def nn_step_fn(carry, _):
        data, key = carry
        new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
        new_carry = (new_data, new_key)
        # head_hight = new_data.xpos[2,2]
        # jax.debug.print("Head Height {0}",head_hight)
        return new_carry, _
    init_carry = (data, key)
    y, _ = jax.lax.scan(nn_step_fn, init_carry, None, length=10)
    new_data = y[0]
    new_key = y[1]
    return new_data, new_key

## Testing control model with neural networks

In [6]:
import mujoco.viewer
import time

previous_frame_time = time.time()
i = 0
key = jax.random.key(387)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=mj_data.xfrc_applied)
        mjx_data = mjx_data.replace(qpos=mj_data.qpos, qvel=mj_data.qvel, time = mj_data.time)
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': mj_model.opt.gravity,
            'opt.tolerance': mj_model.opt.tolerance,
            'opt.ls_tolerance': mj_model.opt.ls_tolerance,
            'opt.timestep': mj_model.opt.timestep,
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.xpos)
        # print(len(mjx_data.qvel))
        viewer.sync()

Time between frames: 15.955093145370483 seconds
[[ 0.      0.      0.    ]
 [ 0.      0.      0.    ]
 [ 0.0007  0.0005  0.9483]
 [-0.0702 -0.0831  0.8825]
 [-0.0786 -0.0848  0.4871]
 [-0.0714 -0.0874  0.0572]
 [-0.1238 -0.0934  0.0195]
 [ 0.0539 -0.0884  0.0004]
 [-0.0701  0.0839  0.8822]
 [-0.0722  0.0847  0.4866]
 [-0.0765  0.0855  0.0566]
 [-0.1252  0.0907  0.0141]
 [ 0.0533  0.0811  0.0112]
 [-0.0999  0.0007  1.03  ]]
Time between frames: 13.496637344360352 seconds
[[ 0.      0.      0.    ]
 [ 0.      0.      0.    ]
 [ 0.0038  0.0019  0.943 ]
 [-0.0678 -0.0787  0.8743]
 [-0.0881 -0.086   0.4798]
 [-0.0768 -0.0979  0.0502]
 [-0.1224 -0.1015  0.0042]
 [ 0.0551 -0.0828  0.0147]
 [-0.0667  0.0882  0.8804]
 [-0.0698  0.0857  0.485 ]
 [-0.077   0.0823  0.0551]
 [-0.1258  0.082   0.0125]
 [ 0.0528  0.0731  0.0105]
 [-0.0964 -0.0004  1.025 ]]
Time between frames: 0.0372319221496582 seconds
[[ 0.      0.      0.    ]
 [ 0.      0.      0.    ]
 [ 0.01    0.003   0.9396]
 [-0.0601 -0.076 

## Batched Random Init Model

In [7]:
def random_init(data, model, rng: jax.Array):
    nbiomtu = model.nq
    init_qpos = data.qpos
    init_qvel = data.qvel
    new_rng, rng1, rng2 = jax.random.split(rng, 3)
    # Qpos_1 is the vertical position
    random_qpos = init_qpos + jax.random.uniform(rng1, [nbiomtu], minval=-1, maxval=1)*0.0001
    random_qvel = init_qvel + jax.random.uniform(rng2, [nbiomtu], minval=-1, maxval=1)*0.0001
    newdata = data.replace(qpos=random_qpos)
    newdata = newdata.replace(qvel=random_qvel)
    newdata = mjx.forward(mjx_model, newdata)
    # print('data:',data.qpos, data.qvel)
    return newdata, new_rng

vrandom_init = jax.jit(jax.vmap(random_init, in_axes=(None, None, 0),out_axes=(0,0)))

## Reward Function

In [8]:
def reward_n_step(nn_params, data, model, key):
    repeat_n = 250
    def nn_step_fn(carry, _):
        data, key = carry
        new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
        new_carry = (new_data, new_key)
        head_hight = new_data.xpos[2,2]
        # jax.debug.print("Head Height {0}",head_hight)
        return new_carry, head_hight
    
    def decay_sum_scan(x, decay):
        def f(sxtm1, xt):
            b = xt + decay * sxtm1
            return b, b
        return jax.lax.scan(f, jp.zeros(x.shape[1:]), x)[1]
    
    _, head_heights = jax.lax.scan(nn_step_fn, (data, key), None, repeat_n)
    # new_data, new_key = carry
    reward = decay_sum_scan(head_heights, 0.95)[repeat_n -1]
    return reward
    
reward_grad = jax.jit(jax.grad(reward_n_step))


def batch_reward(nn_params, batched_data, model, keys):
    return jp.mean(jax.vmap(reward_n_step, (None, 0, None, 0))(nn_params, batched_data, model, keys))

batch_reward_grad = jax.jit(jax.grad((batch_reward)))
jit_batch_reward = jax.jit(batch_reward)

In [8]:
# Generate batched data
batch_size = 400
seed = 2024
key = jax.random.key(seed)
rngs = jax.random.split(key, batch_size) 
mjx_data_batch, r = vrandom_init(mjx_data, mjx_model, rngs)

print("Calculating reward grad")
a = batch_reward_grad(params, mjx_data_batch, mjx_model, rngs)

Calculating reward grad


2024-10-22 18:08:21.723023: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 6.02GiB (6462182217 bytes) by rematerialization; only reduced to 76.87GiB (82542725536 bytes), down from 76.87GiB (82542728160 bytes) originally
2024-10-22 18:09:22.612988: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_0_bfc) ran out of memory trying to allocate 76.87GiB (rounded to 82540875776)requested by op 
2024-10-22 18:09:22.613223: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] *___________________________________________________________________________________________________
E1022 18:09:22.613279   10315 pjrt_stream_executor_client.cc:3084] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 82540875544 bytes.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 82540875544 bytes.

In [15]:
for i in range(100):
    r = jit_batch_reward(params, mjx_data_batch, mjx_model, rngs)

XlaRuntimeError: INTERNAL: cuSolver internal error

In [None]:
# print(r)
print(a)

## Train the NN

In [None]:
import jax
import jax.numpy as jp

seed = 123
key = jax.random.key(seed)
print(key)
new_key = jax.random.split(key,1)
print(new_key)
nn_key = jax.random.split(new_key[0],1)
print(nn_key)
print(jax.random.uniform(nn_key[0], [5], minval=-1, maxval=1)*0.001)

In [9]:
mjx_data.xipos[2,2]

Array(0.9484, dtype=float32)

In [None]:
mjx_data.qvel.shape

In [None]:
mjx_model.nq

## Test Gradient for activation

In [7]:
@jax.jit
def goal(act, data, model):
    # Multiple steps
    def one_step_fn(carry, _):
        in_data = carry[0]
        new_data = mjx_step(model, in_data)
        return [new_data], _
    
    in_data = data.replace(biomtu = data.biomtu.replace(act = act))
    y, _ = jax.lax.scan(one_step_fn, [in_data], None, length=2)
    out_data = y[0]
    
    # height = out_data.xipos[2,2]
    tendon_l = out_data.biomtu.tendon_l
    # act = jp.sum(out_data.biomtu.act)
    return jp.sum(tendon_l)

goal_grad = jax.grad(goal)
jit_goal_grad = jax.jit(goal_grad)
jit_goal = jax.jit(goal)


In [8]:
n = mjx_model.nbiomtu
act = jp.ones(n)*0.8

In [11]:
# jax.config.update("jax_debug_nans", True)
for i in range(20):
    g = jit_goal_grad(act, mjx_data, mjx_model)
    # g = goal_grad(act, mjx_data, mjx_model)
print(g)

[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan]


In [9]:
reward = jit_goal(act, mjx_data, mjx_model)
print(reward)

8.002221


In [10]:
act = jp.ones(n)*0.0
mtu = mjx_data.biomtu
mtu = mtu.replace(act = act)
test_d = mjx_data
test_d = test_d.replace(biomtu = mtu)
for i in range(100):
    test_d = mjx_step(mjx_model, test_d)
print(test_d.xipos)

penn angle [0.173  0.     0.3802 0.     0.4443 0.     0.0884 0.0457 0.     0.0587 0.166  0.     0.1401 0.1142
 0.1384 0.     0.     0.2136 0.0704 0.0507 0.295  0.5362 0.1863 0.106  0.173  0.     0.3802 0.
 0.4443 0.     0.0884 0.0457 0.     0.0587 0.166  0.     0.1401 0.1142 0.1384 0.     0.     0.2136
 0.0704 0.0507 0.295  0.5362 0.1863 0.106  0.     0.     0.     0.     0.     0.    ]
penn angle [0.173  0.     0.3789 0.     0.4431 0.     0.0884 0.0457 0.     0.0587 0.1659 0.     0.1401 0.1142
 0.1384 0.     0.     0.2136 0.0704 0.0507 0.295  0.5353 0.1863 0.106  0.173  0.     0.3789 0.
 0.4431 0.     0.0884 0.0457 0.     0.0587 0.1659 0.     0.1401 0.1142 0.1384 0.     0.     0.2136
 0.0704 0.0507 0.295  0.5353 0.1863 0.106  0.     0.     0.     0.     0.     0.    ]
penn angle [0.1729 0.     0.3774 0.     0.4419 0.     0.0884 0.0457 0.     0.0587 0.1659 0.     0.1401 0.1142
 0.1384 0.     0.     0.2136 0.0704 0.0507 0.2949 0.534  0.1863 0.106  0.1729 0.     0.3774 0.
 0.4419 0.     

In [17]:
print(test_d.biomtu.f_se)
print(test_d.biomtu.f_ce)
print(test_d.biomtu.l)
print(test_d.biomtu.tendon_l - mjx_model.biomtu_tendon_slack_l)
print(test_d.biomtu.fiber_l - mjx_model.biomtu_ofl)
print(test_d.biomtu.fiber_l)
print(test_d.biomtu.act)

[  0.       0.       0.       0.       0.       0.       0.       9.3215   0.4286   0.       0.
   0.       0.      49.6185  26.9067   1.7476   1.875    0.     145.9819   0.       7.6348   0.
  10.7178   0.       0.       0.       0.       0.       0.       0.       0.       9.5163   0.3567
   0.       0.       0.       0.      50.1384  27.6669   2.487    2.175    0.     145.9139   0.
   9.9323   0.      13.1721   0.       6.8466   7.1436   0.       0.       0.       0.    ]
[  0.       0.      -0.0007   0.0042  -1.0216   0.4952  -0.1357   9.9354   1.9685   0.357    0.1604
   0.1219   0.0315  56.4589  27.4374   5.1075   2.1128   0.     148.3752  -0.2723   0.9375   0.0013
   9.1853   0.0125   0.       0.      -0.0007   0.0045  -1.0216   0.4982  -0.1176   9.9645   1.9453
   0.2691   0.1604   0.1219   0.0315  56.2305  27.5229   6.4554   3.1281   0.     148.2817  -0.2365
   2.7676   0.0011  10.0197   0.0125  33.4216  33.7595   0.0009   0.001   -0.6498  -0.6465]
[0.1139 0.1214 0.1001 0.4047

In [13]:
print(jax.devices())

[CudaDevice(id=0)]
