# The Biomujoco model viewer

In [1]:
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt


import os
# JAX/XLA read these variables when the backend is initialised, so they must
# be set before `import jax` below; changing them afterwards needs a kernel
# restart to take effect.

# Fraction of GPU memory JAX may use. NOTE(review): per the JAX GPU-memory
# docs this sizes the *preallocated* pool — confirm it still matters with
# preallocation disabled below.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

# Disable JAX's GPU memory preallocation; memory is allocated on demand instead.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Extra XLA flag for NVIDIA GPU acceleration. Append it to any XLA_FLAGS the
# user already exported (instead of overwriting them), and only once, so that
# re-running this cell does not duplicate the flag.
_xla_flags = os.environ.get('XLA_FLAGS', '').split()
if '--xla_gpu_triton_gemm_any=True' not in _xla_flags:
    _xla_flags.append('--xla_gpu_triton_gemm_any=True')
os.environ['XLA_FLAGS'] = ' '.join(_xla_flags)
del _xla_flags

# Platform name passed to jax.devices() / jax.jit further down.
backend = 'gpu'
# backend = 'METAL'
# backend = 'cpu'

import jax
# Enable JAX's persistent compilation cache so compiled MJX programs can be
# reused across kernel restarts. The two thresholds below are lowered so that
# even small / fast-compiling programs are written to the cache.
# NOTE: "./jax_cache" is relative to the kernel's current working directory.
jax.config.update("jax_compilation_cache_dir", "./jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

from jax import numpy as jp

import mujoco
import mujoco.mjx as mjx
from mujoco.mjx._src import scan
from mujoco.mjx._src import types

# More legible array printing. jax.numpy re-exports NumPy's set_printoptions,
# so this single call configures both NumPy and jax.numpy output.
np.set_printoptions(precision=4, suppress=True, linewidth=100)

from IPython.display import clear_output
# Clear whatever the imports above printed into this cell's output.
clear_output()

# First device of the selected platform; the MJX model is placed on it below.
device = jax.devices(backend=backend)[0]

# MJCF model to load. Set the BIOMUJOCO_MODEL_PATH environment variable to use
# another file without editing the notebook; the default is the original
# author's local checkout.
model_path = os.environ.get(
    'BIOMUJOCO_MODEL_PATH',
    '/home/bugman/Currentwork/biomujoco_converter/converted/mjc/Gait2354/gait2354_cvt1.xml',
)

# Single-step MJX update, jit-compiled for the selected backend.
mjx_step = jax.jit(mjx.step, backend=backend)

# Multiple steps
def step_fn(carry, _):
    """``jax.lax.scan`` body that applies one ``mjx.step`` to the carried state.

    Args:
      carry: ``(data, model)`` tuple; only ``data`` is replaced each iteration.
      _: per-iteration scan input; ignored and passed straight back out.

    Returns:
      ``((stepped_data, model), _)`` — the new carry and the untouched input.
    """
    sim_data, sim_model = carry
    return (mjx.step(sim_model, sim_data), sim_model), _

def multiple_steps(model, data, n_steps=10):
    """Advance MJX ``data`` by ``n_steps`` physics steps with ``jax.lax.scan``.

    Args:
      model: MJX model; ``step_fn`` carries it through the scan unchanged.
      data: MJX data to start from.
      n_steps: number of ``mjx.step`` calls (default 10). It sets the scan
        length, so it must be a Python int; ``mjx_multiple_steps`` therefore
        marks it static.

    Returns:
      The MJX data after the final step.
    """
    (final_data, _unused_model), _ = jax.lax.scan(
        step_fn, (data, model), None, length=n_steps)
    return final_data

# Jit-compiled variant. Each distinct n_steps value triggers one recompilation.
mjx_multiple_steps = jax.jit(
    multiple_steps, backend=backend, static_argnames=('n_steps',))


## Make data, load the keyframe, and calculate the MTU equilibrium

In [3]:

from mujoco.mjx._src.biomtu.acceleration_mtu import acceleration_mtu

# Index of the keyframe the simulation starts from.
KEYFRAME_ID = 0

mj_model = mujoco.MjModel.from_xml_path(model_path)
mjx_model = mjx.put_model(mj_model,device=device)

# Disable passive forces (mjDSBL_PASSIVE, which covers e.g. tendon/joint
# spring-damper forces) in the MJX model only; mj_model keeps its flags.
opt = mjx_model.opt.replace(disableflags = mjx_model.opt.disableflags |mujoco.mjtDisableBit.mjDSBL_PASSIVE)
mjx_model = mjx_model.replace(opt=opt)

mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the keyframe pose into both the MJX and the host-side data. Fail with
# a clear message rather than an IndexError if the model lacks the keyframe.
if not 0 <= KEYFRAME_ID < mj_model.nkey:
    raise ValueError(
        f"Keyframe {KEYFRAME_ID} requested, but {model_path} defines "
        f"{mj_model.nkey} keyframe(s).")
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[KEYFRAME_ID])
mj_data.qpos = mj_model.key_qpos[KEYFRAME_ID]

# Calculate the MTU equilibrium for the loaded pose.
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)

def print_all(model=None, data=None, host_model=None):
    """Print MTU model parameters, MTU/MJX state and the keyframe arrays.

    Debug helper. Called with no arguments it reports on the notebook globals
    ``mjx_model``, ``mjx_data`` and ``mj_model`` (looked up at call time, so
    it always reflects their latest values).

    Args:
      model: MJX model carrying the ``biomtu_*`` and ``mtu_wrap_*`` fields.
        Defaults to the global ``mjx_model``.
      data: MJX data carrying ``qpos``, ``qfrc_biomtu`` and the ``biomtu``
        state. Defaults to the global ``mjx_data``.
      host_model: ``mujoco.MjModel`` whose ``key_*`` arrays are printed.
        Defaults to the global ``mj_model``.
    """
    if model is None:
        model = mjx_model
    if data is None:
        data = mjx_data
    if host_model is None:
        host_model = mj_model
    mtu = data.biomtu

    print("-------Model-------")
    print("biomtu_adr:", model.biomtu_adr)
    print("mtu_wrap_objid:", model.mtu_wrap_objid)
    print("mtu_wrap_type:", model.mtu_wrap_type)
    print("biomtu_fiso:", model.biomtu_fiso)
    print("biomtu_vmax:", model.biomtu_vmax)
    print("biomtu_ofl:", model.biomtu_ofl)
    print("biomtu_opa:", model.biomtu_opa)
    print("biomtu_mass:", model.biomtu_mass)
    print("-------Data--------")
    print("qpos:", data.qpos)
    print("mtu l:", mtu.l)
    print("tendon l:", mtu.tendon_l)
    print("fiber l :", mtu.fiber_l)
    print("Muscle Bce:", mtu.B_ce)
    # NOTE(review): label says "vm" but the field is `m` — confirm intent.
    print("Muscle vm:", mtu.m)
    print("Fiber acc:", mtu.fiber_acc)
    print("Fiber v:", mtu.fiber_v)
    # The constant height of the muscle (previously printed twice).
    print("Biomtu h:", mtu.h)
    print("mtu v:", mtu.v)
    print("pennation angle:", mtu.pennation_angle)
    print("origin body id:", mtu.origin_body_id)
    print("insertion body id:", mtu.insertion_body_id)
    print("mtu act:", mtu.act)
    print("qfrc_biomtu:", data.qfrc_biomtu)
    print("key_time:", host_model.key_time)
    print("key_qpos:", host_model.key_qpos)
    print("key_qvel:", host_model.key_qvel)

# Dump the loaded model parameters, the MTU state after calc_equilibrium, and
# the keyframe arrays.
print_all()

--------------------
BUGMAN: muscle name-glut_med1_r
--------------------
BUGMAN: muscle name-glut_med2_r
--------------------
BUGMAN: muscle name-glut_med3_r
--------------------
BUGMAN: muscle name-bifemlh_r
--------------------
BUGMAN: muscle name-bifemsh_r
--------------------
BUGMAN: muscle name-sar_r
--------------------
BUGMAN: muscle name-add_mag2_r
--------------------
BUGMAN: muscle name-tfl_r
--------------------
BUGMAN: muscle name-pect_r
--------------------
BUGMAN: muscle name-grac_r
--------------------
BUGMAN: muscle name-glut_max1_r
--------------------
BUGMAN: muscle name-glut_max2_r
--------------------
BUGMAN: muscle name-glut_max3_r
--------------------
BUGMAN: muscle name-iliacus_r
--------------------
BUGMAN: muscle name-psoas_r
--------------------
BUGMAN: muscle name-quad_fem_r
--------------------
BUGMAN: muscle name-gem_r
--------------------
BUGMAN: muscle name-peri_r
--------------------
BUGMAN: muscle name-rect_fem_r
--------------------
BUGMAN: muscle nam

## View and Play with the Model

In [5]:
import mujoco.viewer
import time

previous_frame_time = time.time()
i = 0  # NOTE(review): not used below; kept in case later cells rely on it.
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        # The passive viewer renders and handles GUI input on its own thread,
        # so mj_model / mj_data are only touched while holding viewer.lock().
        # Inputs are copied while locked so the MJX step below can run
        # without blocking the viewer.
        with viewer.lock():
            # Host -> MJX: applied perturbation forces and the current state.
            mjx_data = mjx_data.replace(
                xfrc_applied=np.array(mj_data.xfrc_applied),
                qpos=np.array(mj_data.qpos),
                qvel=np.array(mj_data.qvel),
                time=mj_data.time,
            )
            # Propagate option edits made in the viewer GUI to the MJX model.
            mjx_model = mjx_model.tree_replace({
                'opt.gravity': np.array(mj_model.opt.gravity),
                'opt.tolerance': mj_model.opt.tolerance,
                'opt.ls_tolerance': mj_model.opt.ls_tolerance,
                'opt.timestep': mj_model.opt.timestep,
            })

        # Advance several physics steps with the jit-compiled helper rather
        # than the plain, un-jitted `multiple_steps`.
        mjx_data = mjx_multiple_steps(mjx_model, mjx_data)

        # MJX -> host, so the viewer draws the new state.
        with viewer.lock():
            mjx.get_data_into(mj_data, mj_model, mjx_data)

        # Wall-clock time elapsed since the previous frame.
        current_frame_time = time.time()
        time_between_frames = current_frame_time - previous_frame_time
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time

        viewer.sync()

Time between frames: 0.31905221939086914 seconds
Time between frames: 0.05259990692138672 seconds
Time between frames: 0.04004669189453125 seconds
Time between frames: 0.056181907653808594 seconds
Time between frames: 0.03765535354614258 seconds
Time between frames: 0.0370631217956543 seconds
Time between frames: 0.0364227294921875 seconds
Time between frames: 0.035539865493774414 seconds
Time between frames: 0.033997297286987305 seconds
Time between frames: 0.03330230712890625 seconds
Time between frames: 0.035349369049072266 seconds
Time between frames: 0.041129112243652344 seconds
Time between frames: 0.04765796661376953 seconds
Time between frames: 0.04915571212768555 seconds
Time between frames: 0.03883504867553711 seconds
Time between frames: 0.036749839782714844 seconds
Time between frames: 0.04660964012145996 seconds
Time between frames: 0.03838801383972168 seconds
Time between frames: 0.03890490531921387 seconds
Time between frames: 0.03710579872131348 seconds
Time between fra