# Learn AOT to speed up compilation

In [5]:
import jax
import jax.numpy as jp

def f(x,y):
    return x*2 +y
x, y = 2,3


i32_scalar = jax.ShapeDtypeStruct((), jp.dtype('int32'))
lowered=jax.jit(f).lower(i32_scalar, i32_scalar)

print(lowered.as_text())

compiled = lowered.compile()
print(compiled.cost_analysis())

compiled(x,y)

TRACING CACHE MISS at /tmp/ipykernel_591948/3584921380.py:10:8 (<module>) because:
  never seen function:
    f id=134866206508192 defined at /tmp/ipykernel_591948/3584921380.py:4
PERSISTENT COMPILATION CACHE MISS for 'jit_f' with key 'jit_f-a66f5fa0bcd7ec596c3d0e1de326e04658f2b01211742476aab810456c448992'
Not writing persistent cache entry for 'jit_f' because it took < 2.00 seconds to compile (0.02s)


module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.multiply %arg0, %c : tensor<i32>
    %1 = stablehlo.add %0, %arg1 : tensor<i32>
    return %1 : tensor<i32>
  }
}

[{'bytes accessed1{}': 4.0, 'flops': 2.0, 'bytes accessed0{}': 4.0, 'bytes accessed': 12.0, 'utilization1{}': 1.0, 'bytes accessedout{}': 4.0, 'utilization0{}': 1.0}]


Array(7, dtype=int32)

In [11]:
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt


import os
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".60"

# Optionally, force JAX to preallocate memory.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Setup environment variable for Nvidia GPU acceleration
os.environ['XLA_FLAGS'] = (
    # '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    # '--xla_gpu_enable_async_collectives=true '
    # '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
    # '--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=32'
)

backend = 'gpu'
# backend = 'METAL'
# backend = 'cpu'

import jax
jax.config.update('jax_platform_name', 'gpu')
# os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"
# Enable compliation catch
os.environ["JAX_COMPILATION_CACHE_DIR"] = "./jax_cache"
jax.config.update("jax_compilation_cache_dir", "./jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 2)
jax.config.update("jax_explain_cache_misses", True)

from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("./jax_cache")
# Debug Nan
# jax.config.update("jax_debug_nans", True)

from jax import numpy as jp
# More legible printing from numpy.
jp.set_printoptions(precision=4, suppress=True, linewidth=100)

import mujoco
import mujoco.mjx as mjx
from mujoco.mjx._src import scan
from mujoco.mjx._src import types

# More legible printing from numpy.
np.set_printoptions(precision=4, suppress=True, linewidth=100)

from IPython.display import clear_output
clear_output()

device = jax.devices(backend=backend)[0]

model_path = '/home/bugman/Currentwork/biomujoco_converter/converted/mjc/Gait2354/gait2354_cvt1.xml'


mj_model = mujoco.MjModel.from_xml_path(model_path)
mjx_model = mjx.put_model(mj_model,device=device)

# Disable tendon
opt = mjx_model.opt.replace(disableflags = mjx_model.opt.disableflags |mujoco.mjtDisableBit.mjDSBL_PASSIVE)
mjx_model = mjx_model.replace(opt=opt)

mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Single step
mjx_step_lower = jax.jit(mjx.step, backend=backend).lower(mjx_model, mjx_data)
mjx_step = mjx_step_lower.compile()



In [7]:
print(mjx_step.cost_analysis()[0])
print(mjx_step.cost_analysis()[0]["flops"])

{'utilization92{}': 4.0, 'bytes accessed49{}': 56.0, 'utilization79{}': 5.0, 'optimal_seconds': -8.0, 'utilization78{}': 5.0, 'utilization31{}': 12.0, 'utilization101{}': 3.0, 'utilization146{}': 1.0, 'utilization107{}': 2.0, 'utilization44{}': 7.0, 'utilization136{}': 2.0, 'bytes accessed37{}': 180.0, 'utilization155{}': 1.0, 'utilization41{}': 7.0, 'utilization11{}': 67.0, 'utilization157{}': 1.0, 'utilization0{}': 1917.0714111328125, 'bytes accessed0{}': 565069.0, 'bytes accessed11{}': 7544.0, 'bytes accessed42{}': 196.0, 'bytes accessed36{}': 304.0, 'bytes accessed70{}': 268.0, 'utilization111{}': 2.0, 'utilization109{}': 2.0, 'bytes accessed26{}': 1212.0, 'bytes accessed20{}': 2524.0, 'utilization74{}': 5.0, 'utilization64{}': 5.0, 'bytes accessed52{}': 44.0, 'bytes accessedout{9}': 592.0, 'utilization19{}': 29.0, 'bytes accessed39{}': 228.0, 'utilization81{}': 4.0, 'utilization125{}': 2.0, 'bytes accessed48{}': 296.0, 'bytes accessed3{}': 47997.0, 'bytes accessed79{}': 292.0, 'by

In [12]:
from mujoco.mjx._src.biomtu import acceleration_mtu



# Load the Keyframe
mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrum
mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)

for i in range(10):
    mjx_data = mjx_step(mjx_model, mjx_data)

def print_all():
    print(mjx_model.biomtu_adr)
    print(mjx_model.mtu_wrap_objid)
    print(mjx_model.mtu_wrap_type)
    print(mjx_model.biomtu_fiso)
    print(mjx_model.biomtu_vmax)
    print(mjx_model.biomtu_ofl)
    print(mjx_model.biomtu_opa)
    print(mjx_model.biomtu_mass)
    print("-------Data--------")
    print("qpos:", mjx_data.qpos)
    print("mtu l:", mjx_data.biomtu.l)
    print("tendon l:", mjx_data.biomtu.tendon_l)
    print("fiber l :", mjx_data.biomtu.fiber_l)
    print("Muscle Bce:", mjx_data.biomtu.B_ce)
    print("Muscle vm:", mjx_data.biomtu.m)
    print("Fiber acc:", mjx_data.biomtu.fiber_acc)
    print("Fiber v:", mjx_data.biomtu.fiber_v)
    print("Biomtu h:", mjx_data.biomtu.h)
    print(mjx_data.biomtu.v)
    print(mjx_data.biomtu.h)  # The constant high of the muscle.
    print(mjx_data.biomtu.pennation_angle)
    print(mjx_data.biomtu.origin_body_id)
    print(mjx_data.biomtu.insertion_body_id)
    print("mtu act:", mjx_data.biomtu.act)
    # print(mjx_data.biomtu.j)
    print(mjx_data.qfrc_biomtu)
    print(mj_model.key_time)
    print(mj_model.key_qpos)
    print(mj_model.key_qvel)

print_all()

TRACING CACHE MISS at /tmp/ipykernel_591948/1040697251.py:10:11 (<module>) because:
  never seen function:
    calc_equilibrium.<locals>.callback_optimize_seperate id=134864463809248 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/mujoco/mjx/_src/biomtu/acceleration_mtu.py:241
  but seen another function defined on the same line; maybe the function is
  being re-defined repeatedly, preventing caching?
PERSISTENT COMPILATION CACHE MISS for 'jit_callback_optimize_seperate' with key 'jit_callback_optimize_seperate-ad54009f7380485982440e7e3e3ed6e53a03248877483855fb60d98ecc81f2c2'
Not writing persistent cache entry for 'jit_callback_optimize_seperate' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
TRACING CACHE MISS at /tmp/ipykernel_591948/1040697251.py:10:11 (<module>) because:
  never seen function:
    calc_equilibrium.<locals>.goal id=134864994582912 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/muj

[  0   2   4   6   9  12  17  19  23  25  29  33  37  41  46  51  53  55  58  61  65  68  70  74
  77  79  81  83  86  89  94  96 100 102 106 110 114 118 123 128 130 132 135 138 142 145 147 151
 154 156 158 160 162 164]
[  0  58   1  59   2  60   3  86  87  61  88  89   4  62  90  91  92   5  63   6  64  65  93   7
  66   8  94  95  96   9  10  67  68  11  12  69  70  13  14  71  72  15  16  17  73  74  18  19
  20  75  76  21  77  22  78  23  24  79  25  80  97  81  82  83  98  84  85 104  99 105 100 101
 106 107 102 103 108  26 109  27 110  28 111  29 137 138 112 139 140  30 113 141 142 143  31 114
  32 115 116 144  33 117  34 145 146 147  35  36 118 119  37  38 120 121  39  40 122 123  41  42
  43 124 125  44  45  46 126 127  47 128  48 129  49  50 130  51 131 148 132 133 134 149 135 136
 155 150 156 151 152 157 158 153 154 159  52 161  53 162  54 163  55 164  56 165  57 166]
[3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 

In [8]:
import jax
import jax.numpy as jp
import flax
import flax.linen as nn
import optax

class Controller_NN(nn.Module):
    # It will output the mean and variance for each muscle's activation
    
    def setup(self):
        # Features means the output dimension# Single step
        self.nbiomtu = 54
        self.linear1 = nn.Dense(features=400)
        self.linear2 = nn.Dense(features=400)
        self.linear3 = nn.Dense(features=400)
        self.linear4 = nn.Dense(features=400)
        # The last layer will output the mean and logstd
        self.linear5 = nn.Dense(features=self.nbiomtu*2)
        
    
    def __call__(self, x, key):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        x = nn.relu(x)
        x = self.linear3(x)
        x = nn.relu(x)
        x = self.linear4(x)
        x = nn.relu(x)
        x = self.linear5(x)
        # The last layer of the neural requires samping
        mean = x[0:self.nbiomtu]
        logstd = x[self.nbiomtu:self.nbiomtu*2]
        std = jp.exp(logstd)
        samples = jp.clip(jax.random.normal(key)*std*0.3 + mean, 0 ,1)
        
        return samples, mean, logstd


# Test the neural network
control_model = Controller_NN()

print(control_model)
# Init the model
key = jax.random.key(66)
sub_keys = jax.random.split(key,1)
# The second parameter is the dommy input
params = control_model.init(key,jp.empty([1,mjx_model.nq*2]),sub_keys[0])
# print(params)
print(control_model.apply( params, jp.ones(mjx_model.nq*2), sub_keys[0]))
jit_nn_apply = jax.jit(lambda params,states,key : control_model.apply(params,states,key))

PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_seed' with key 'jit__threefry_seed-b1f5e5394dcb69b2eb30d0f1c98bce4d599427c981b33c8c9c9352ff3f6f47d4'
Not writing persistent cache entry for 'jit__threefry_seed' because it took < 2.00 seconds to compile (0.02s)
TRACING CACHE MISS at /tmp/ipykernel_591948/3932189320.py:46:11 (<module>) because:
  never seen function:
    _threefry_split id=134873191543520 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/prng.py:1089
PERSISTENT COMPILATION CACHE MISS for 'jit__threefry_split' with key 'jit__threefry_split-733f3ae6f83f0386b1a477d21308ee9624bafb78870970a5645ac4ebb03f86c0'
Not writing persistent cache entry for 'jit__threefry_split' because it took < 2.00 seconds to compile (0.05s)
TRACING CACHE MISS at /tmp/ipykernel_591948/3932189320.py:48:32 (<module>) because:
  never seen function:
    broadcast_in_dim id=134865535368416 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-package

Controller_NN()


Not writing persistent cache entry for 'jit__threefry_fold_in' because it took < 2.00 seconds to compile (0.06s)
PERSISTENT COMPILATION CACHE MISS for 'jit_sqrt' with key 'jit_sqrt-5ba6e2d335bcce92235d48c3b0d6c00aba9acb98cffd58639e1c97a268c8660f'
Not writing persistent cache entry for 'jit_sqrt' because it took < 2.00 seconds to compile (0.02s)
PERSISTENT COMPILATION CACHE MISS for 'jit_true_divide' with key 'jit_true_divide-5f0cb230b9d483a302d65f5b48ff117b6cfdeb14a65d970935acd807407da935'
Not writing persistent cache entry for 'jit_true_divide' because it took < 2.00 seconds to compile (0.02s)
TRACING CACHE MISS at /tmp/ipykernel_591948/3932189320.py:22:12 (Controller_NN.__call__) because:
  never seen function:
    _uniform id=134873191572768 defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/random.py:405
  but seen another function defined on the same line; maybe the function is
  being re-defined repeatedly, preventing caching?
TRACING CACHE MIS

(Array([0.3537, 0.4582, 0.    , 0.    , 0.3568, 0.    , 0.    , 0.    , 0.    , 0.0299, 0.1213,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.0104, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.4784, 0.    , 0.2003, 0.    , 0.    ,
       0.    , 0.    , 0.2286, 0.    , 0.    , 0.    , 0.    , 0.2364, 0.    , 0.    , 0.    ,
       0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    , 0.    ],      dtype=float32), Array([ 0.5507,  0.7117,  0.1663, -0.386 ,  0.572 ,  0.0428,  0.151 , -0.341 , -0.1592,  0.1468,
        0.3125, -0.506 , -0.098 , -0.1812, -0.2472, -0.1027, -0.2657, -0.0098,  0.1604,  0.1802,
       -0.4265, -0.2269, -0.2294, -0.0505,  0.1361,  0.0374,  0.0641, -0.0512,  0.6791, -0.3706,
        0.4092, -0.1719, -0.2406, -0.022 ,  0.1271,  0.4203,  0.0284,  0.0943, -0.4435,  0.1295,
        0.4245,  0.0444, -0.1952,  0.121 , -0.08  , -0.3598, -0.0973, -0.0134,  0.0238, -0.2089,
       -0.5425, -0.3392, 

In [9]:
# Multiple steps
def step_fn(carry, _):
    data, model = carry
    new_data = mjx.step(model, data)
    new_carry = (new_data, model)
    return new_carry, _

def multiple_steps(model, data):
    init_carry = (data, model)
    y, _ = jax.lax.scan(step_fn, init_carry, None, length=10)
    new_data = y[0]
    return new_data

def nn_mjx_one_step(nn_params, model, data, key):
    states = jp.concatenate([data.qpos, data.qvel])
    act = jit_nn_apply(nn_params, states, key)[0]
    # Generate the next key
    new_key = jax.random.split(key,1)[0]
    # mtu = data.biomtu
    # mtu = mtu.replace(act = act)
    # data = data.replace(biomtu = mtu)
    new_data = mjx.step(model, data)
    return new_data, new_key

@jax.jit
def nn_mjx_multi_steps(nn_params, model, data, key):
    def nn_step_fn(carry, _):
        data, key = carry
        new_data, new_key = nn_mjx_one_step(nn_params, model, data, key)
        new_carry = (new_data, new_key)
        # head_hight = new_data.xpos[2,2]
        # jax.debug.print("Head Height {0}",head_hight)
        return new_carry, _
    init_carry = (data, key)
    y, _ = jax.lax.scan(nn_step_fn, init_carry, None, length=10)
    new_data = y[0]
    new_key = y[1]
    return new_data, new_key

In [10]:
import mujoco.viewer
import time

previous_frame_time = time.time()
i = 0
key = jax.random.key(334)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the nerual network to generate ctrl signal
        # Generate key
        
        mjx_data = mjx_data.replace(xfrc_applied=jp.array(mj_data.xfrc_applied, dtype=jp.float32))
        mjx_data = mjx_data.replace(
            qpos= jp.array(mj_data.qpos, dtype=jp.float32),
            qvel= jp.array(mj_data.qvel, dtype=jp.float32),
            time = jp.array(mj_data.time, dtype=jp.float32))
        
        # Update mjx_model from mj_model
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': jp.array(mj_model.opt.gravity, dtype=jp.float32),
            'opt.tolerance': jp.array(mj_model.opt.tolerance, dtype=jp.float32),
            'opt.ls_tolerance': jp.array(mj_model.opt.ls_tolerance, dtype=jp.float32),
            'opt.timestep': jp.array(mj_model.opt.timestep, dtype=jp.float32),
        })
        
        mjx_data = mjx_step(mjx_model, mjx_data)
        # mjx_data, key = nn_mjx_multi_steps(params, mjx_model, mjx_data, key)
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Record the current time at the start of this frame
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        print(mjx_data.sensordata)
        # print(len(mjx_data.qvel))
        viewer.sync()

TRACING CACHE MISS at /tmp/ipykernel_591948/1919104746.py:14:49 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  never seen input type signature:
    args[0]: f32[14,6]
  closest seen input type signature has 1 mismatches, including:
    * at args[0], seen f32[4,3], but now given f32[14,6]
PERSISTENT COMPILATION CACHE MISS for 'jit_convert_element_type' with key 'jit_convert_element_type-7609aadf16ca451a7b54974e989f7c00ddd7e21e36e84058a5238d3a9b80b11c'
Not writing persistent cache entry for 'jit_convert_element_type' because it took < 2.00 seconds to compile (0.01s)
TRACING CACHE MISS at /tmp/ipykernel_591948/1919104746.py:16:18 (<module>) because:
  for convert_element_type defined at /home/bugman/anaconda3/envs/biomujoco/lib/python3.11/site-packages/jax/_src/dispatch.py:98
  never seen input type signature:
    args[0]: f32[27]
  closest seen input type signature has 1 mismatches, in

Time between frames: 0.45810651779174805 seconds
[-0.1007 -0.0003  1.6315]
Time between frames: 0.013562440872192383 seconds
[-0.1007 -0.0003  1.6315]
Time between frames: 0.014281511306762695 seconds
[-0.1007 -0.0003  1.6314]
Time between frames: 0.013303518295288086 seconds
[-0.1007 -0.0003  1.6313]
Time between frames: 0.012742042541503906 seconds
[-0.1007 -0.0003  1.6312]
Time between frames: 0.013068199157714844 seconds
[-0.1007 -0.0003  1.631 ]
Time between frames: 0.014198064804077148 seconds
[-0.1008 -0.0003  1.6308]
Time between frames: 0.013084173202514648 seconds
[-0.1008 -0.0003  1.6305]
Time between frames: 0.013964414596557617 seconds
[-0.1008 -0.0003  1.6302]
Time between frames: 0.013312339782714844 seconds
[-0.1008 -0.0003  1.6299]
Time between frames: 0.012134790420532227 seconds
[-0.1008 -0.0003  1.6296]
Time between frames: 0.009987592697143555 seconds
[-0.1008 -0.0003  1.6292]
Time between frames: 0.012313604354858398 seconds
[-0.1009 -0.0003  1.6287]
Time between 