# Simple RL code to teach the model to stand

## Setup jax enviroment and Load Model 

In [1]:
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt


import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

# Optionally, force JAX to preallocate memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Setup environment variable for Nvidia GPU acceleration

os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
)

backend = 'gpu'
# backend = 'METAL'
# backend = 'cpu'

import jax
# Enable compliation catch
jax.config.update("jax_compilation_cache_dir", "./jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

from jax import numpy as jp
# More legible printing from numpy.
jp.set_printoptions(precision=4, suppress=True, linewidth=100)

import mujoco
import mujoco.mjx as mjx
from mujoco.mjx._src import scan
from mujoco.mjx._src import types

# More legible printing from numpy.
np.set_printoptions(precision=4, suppress=True, linewidth=100)

from IPython.display import clear_output
clear_output()

device = jax.devices(backend=backend)[0]

model_path = '/home/bugman/Currentwork/biomujoco_converter/converted/mjc/Gait2354/gait2354_cvt1.xml'

# Single step
mjx_step = jax.jit(mjx.step, backend=backend)



# mjx_multiple_steps = jax.jit(multiple_steps, backend=backend, )

In [9]:
from mujoco.mjx._src.biomtu.acceleration_mtu import acceleration_mtu

mj_model = mujoco.MjModel.from_xml_path(model_path)
mjx_model = mjx.put_model(mj_model,device=device)

# Disable tendon
opt = mjx_model.opt.replace(disableflags = mjx_model.opt.disableflags |mujoco.mjtDisableBit.mjDSBL_PASSIVE)
mjx_model = mjx_model.replace(opt=opt)

mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)

def print_all():
    print(mjx_model.biomtu_adr)
    print(mjx_model.mtu_wrap_objid)
    print(mjx_model.mtu_wrap_type)
    print(mjx_model.biomtu_fiso)
    print(mjx_model.biomtu_vmax)
    print(mjx_model.biomtu_ofl)
    print(mjx_model.biomtu_opa)
    print(mjx_model.biomtu_mass)
    print("-------Data--------")
    print("qpos:", mjx_data.qpos)
    print("mtu l:", mjx_data.biomtu.l)
    print("tendon l:", mjx_data.biomtu.tendon_l)
    print("fiber l :", mjx_data.biomtu.fiber_l)
    print("Muscle Bce:", mjx_data.biomtu.B_ce)
    print("Muscle vm:", mjx_data.biomtu.m)
    print("Fiber acc:", mjx_data.biomtu.fiber_acc)
    print("Fiber v:", mjx_data.biomtu.fiber_v)
    print("Biomtu h:", mjx_data.biomtu.h)
    print(mjx_data.biomtu.v)
    print(mjx_data.biomtu.h)  # The constant high of the muscle.
    print(mjx_data.biomtu.pennation_angle)
    print(mjx_data.biomtu.origin_body_id)
    print(mjx_data.biomtu.insertion_body_id)
    print("mtu act:", mjx_data.biomtu.act)
    # print(mjx_data.biomtu.j)
    print(mjx_data.qfrc_biomtu)
    print(mj_model.key_time)
    print(mj_model.key_qpos)
    print(mj_model.key_qvel)

print_all()

--------------------
BUGMAN: muscle name-glut_med1_r
--------------------
BUGMAN: muscle name-glut_med2_r
--------------------
BUGMAN: muscle name-glut_med3_r
--------------------
BUGMAN: muscle name-bifemlh_r
--------------------
BUGMAN: muscle name-bifemsh_r
--------------------
BUGMAN: muscle name-sar_r
--------------------
BUGMAN: muscle name-add_mag2_r
--------------------
BUGMAN: muscle name-tfl_r
--------------------
BUGMAN: muscle name-pect_r
--------------------
BUGMAN: muscle name-grac_r
--------------------
BUGMAN: muscle name-glut_max1_r
--------------------
BUGMAN: muscle name-glut_max2_r
--------------------
BUGMAN: muscle name-glut_max3_r
--------------------
BUGMAN: muscle name-iliacus_r
--------------------
BUGMAN: muscle name-psoas_r
--------------------
BUGMAN: muscle name-quad_fem_r
--------------------
BUGMAN: muscle name-gem_r
--------------------
BUGMAN: muscle name-peri_r
--------------------
BUGMAN: muscle name-rect_fem_r
--------------------
BUGMAN: muscle nam

In [3]:
print(mjx_model.nbiomtu)
print(mjx_model.nq)
print(mjx_data.qpos.shape)

54
27
(27,)


## Neural Network

In [10]:
import jax
import jax.numpy as jp
import flax
import flax.linen as nn
import optax

class Controller_NN(nn.Module):
    # It will output the mean and variance for each muscle's activation
    nbiomtu = mjx_model.nbiomtu
    def setup(self):
        # Features means the output dimension# Single step

        self.linear1 = nn.Dense(features=400)
        self.linear2 = nn.Dense(features=400)
        self.linear3 = nn.Dense(features=400)
        self.linear4 = nn.Dense(features=400)
        # The last layer will output the mean and logstd
        self.linear5 = nn.Dense(features=self.nbiomtu*2)
        
    
    def __call__(self, x, key):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        x = nn.relu(x)
        x = self.linear3(x)
        x = nn.relu(x)
        x = self.linear4(x)
        x = nn.relu(x)
        x = self.linear5(x)
        # The last layer of the neural requires samping
        mean = x[0:self.nbiomtu]
        logstd = x[self.nbiomtu:self.nbiomtu*2]
        std = jp.exp(logstd)
        samples = jp.clip(jax.random.normal(key)*std*0.3 + mean, 0 ,1)
        
        return samples, mean, logstd


# Test the neural network
control_model = Controller_NN()

print(control_model)
# Init the model
key = jax.random.PRNGKey(66)
sub_keys = jax.random.split(key,1)
# The second parameter is the dommy input
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])
# print(params)
print(control_model.apply( params,jp.ones(mjx_model.nq*2), sub_keys[0]))
jit_nn_apply = jax.jit(lambda params,states,key : control_model.apply(params,states,key))

Controller_NN()
(Array([0.3537, 0.4582, 0.    , 0.    , 0.3568, 0.    , 0.    , 0.    , 0.    , 0.0299, 0.1213,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.0104, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.4784, 0.    , 0.2003, 0.    , 0.    ,
       0.    , 0.    , 0.2286, 0.    , 0.    , 0.    , 0.    , 0.2364, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    ],      dtype=float32), Array([ 0.5507,  0.7117,  0.1663, -0.386 ,  0.572 ,  0.0428,  0.151 , -0.341 , -0.1592,  0.1468,
        0.3125, -0.506 , -0.098 , -0.1812, -0.2472, -0.1027, -0.2657, -0.0098,  0.1604,  0.1802,
       -0.4265, -0.2269, -0.2294, -0.0505,  0.1361,  0.0374,  0.0641, -0.0512,  0.6791, -0.3706,
        0.4092, -0.1719, -0.2406, -0.022 ,  0.1271,  0.4203,  0.0284,  0.0943, -0.4435,  0.1295,
        0.4245,  0.0444, -0.1952,  0.121 , -0.08  , -0.3598, -0.0973, -0.0134,  0.0238, -0.2089,
       -0

## Combine Neural Net and Simulation into one Jax function

In [None]:
# Multiple steps
def step_fn(carry, _):
    data, model = carry
    new_data = mjx.step(model, data)
    new_carry = new_data, model
    return new_carry, _

def multiple_steps(model, data):
    init_carry = (data, model)
    y, _ = jax.lax.scan(step_fn, init_carry, None, length=10)
    new_data = y[0]
    return new_data

def nn_mjx_one_step(model, data, key):
    states = jp.concatenate([data.qpos, data.qvel])
    act = jit_nn_apply(params, states, key)[0]
    # Generate the next key
    
    mtu = data.biomtu
    mtu = mtu.replace(act = act)
    data = data.replace(biomtu = mtu)
    new_data = mjx.step(model, data)
    return new_data, key

def nn_step_fn(carry, _):
    data, model, key = carry
    new_data, new_key = nn_mjx_one_step(model, data, key)
    new_carry = new_data, model, new_key
    return new_carry, _

def nn_mjx_multi_steps(model, data, key):
    init_carry = (data, model, key)
    y, _ = jax.lax.scan(nn_step_fn, init_carry, None, length=10)
    new_data = y[0]
    new_key = y[2]
    return new_data, new_key

## Testing control model with neural networks

In [11]:
import mujoco.viewer
import time

previous_frame_time = time.time()
i = 0
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        key = jax.random.key(250)
        states = jp.concatenate([mjx_data.qpos, mjx_data.qvel])
        act = jit_nn_apply(params, states, key)[0]
        print(act)
        biomtu = mjx_data.biomtu
        biomtu = biomtu.replace(act = act)
        mjx_data = mjx_data.replace(biomtu = biomtu)
        # ctrl = jit_nn_apply(params, states)
        mjx_data = mjx_data.replace(xfrc_applied=mj_data.xfrc_applied)
        mjx_data = mjx_data.replace(qpos=mj_data.qpos, qvel=mj_data.qvel, time = mj_data.time)
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': mj_model.opt.gravity,
            'opt.tolerance': mj_model.opt.tolerance,
            'opt.ls_tolerance': mj_model.opt.ls_tolerance,
            'opt.timestep': mj_model.opt.timestep,
        })
        
        mjx_data = mjx_step(mjx_model, mjx_data)
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        print("ACT:", mjx_data.biomtu.act)
        print(mjx_data.qpos)
        # print(len(mjx_data.qvel))
        viewer.sync()

[0.0845 0.133  0.068  0.     0.0269 0.085  0.0284 0.0293 0.0279 0.0451 0.0429 0.     0.     0.0367
 0.0322 0.003  0.     0.     0.0365 0.023  0.     0.0237 0.     0.     0.     0.0251 0.0204 0.
 0.106  0.     0.0898 0.     0.     0.     0.0077 0.0878 0.0267 0.0008 0.     0.0197 0.0983 0.0527
 0.     0.0485 0.0145 0.     0.0182 0.0321 0.0158 0.     0.     0.0075 0.008  0.0344]
Time between frames: 0.22204113006591797 seconds
ACT: [0.0845 0.133  0.068  0.     0.0269 0.085  0.0284 0.0293 0.0279 0.0451 0.0429 0.     0.     0.0367
 0.0322 0.003  0.     0.     0.0365 0.023  0.     0.0237 0.     0.     0.     0.0251 0.0204 0.
 0.106  0.     0.0898 0.     0.     0.     0.0077 0.0878 0.0267 0.0008 0.     0.0197 0.0983 0.0527
 0.     0.0485 0.0145 0.     0.0182 0.0321 0.0158 0.     0.     0.0075 0.008  0.0344]
[ 0.      0.9501 -0.      0.0007  0.      0.0001 -0.0015 -0.0001 -0.0011 -0.0037 -0.3957  0.0024
 -0.0203 -0.0031  0.0233 -0.0014 -0.     -0.0009 -0.0037 -0.3957  0.0023 -0.0201 -0.0032  0

## Batched Random Init Model

In [None]:
def random_init(data:mujoco.mjx._src.types.Data, rng: jax.Array):
    init_qpos = data.qpos
    init_qvel = data.qvel
    new_rng, rng1, rng2, rng3, rng4 = jax.random.split(rng, 5)
    # Qpos_1 is the vertical position
    qpos_0 = init_qpos[0] + jax.random.uniform(rng1, [1], minval=-1, maxval=1)
    qpos_1 = init_qpos[1] + jax.random.uniform(rng2, [1], minval=-1.5, maxval=1.5)
    random_qpos = jp.array([*qpos_0,*qpos_1])
    qvel_0 = init_qvel[0] + jax.random.uniform(rng3, [1], minval=-0.4, maxval=0.4)
    qvel_1 = init_qvel[1] + jax.random.uniform(rng4, [1], minval=-0.4, maxval=0.4)
    random_qpos = jp.array([*qpos_0,*qpos_1])
    random_qvel = jp.array([*qvel_0,*qvel_1])
    newdata = data.replace(qpos=random_qpos)
    newdata = newdata.replace(qvel=random_qvel)
    newdata = mjx.forward(mjx_model, newdata)
    # print('data:',data.qpos, data.qvel)
    return newdata, new_rng

vrandom_init = jax.jit(jax.vmap(random_init, in_axes=(None,0),out_axes=(0,0)))