<a href="https://colab.research.google.com/github/ellipticalcurves/aibjj/blob/main/Aibjj.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax

Collecting mujoco
  Downloading mujoco-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting glfw (from mujoco)
  Downloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl.metadata (5.4 kB)
Downloading mujoco-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m22.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl (211 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.8/211.8 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected p

In [2]:
#@title Check if MuJoCo installation was successful

from google.colab import files

import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags


Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl
Checking that the installation succeeded:
Installation successful.


NOTE: you may have to connect to a GPU instance of collab by changing run time type and selecting T4 GPU


In [3]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

Installing mediapy:
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model


In [64]:
#@title 2_humanoids xml string

xml ="""<mujoco model="22 Humanoids">
  <option timestep="0.005"/>

  <size memory="100M"/>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1=".3 .5 .7" rgb2="0 0 0" width="512" height="512"/>
    <texture name="body" type="cube" builtin="flat" mark="cross" width="128" height="128"
             rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01"/>
    <material name="body" texture="body" texuniform="true" rgba="0.8 0.6 .4 1"/>
    <texture name="grid" type="2d" builtin="checker" width="512" height="512" rgb1=".1 .2 .3" rgb2=".2 .3 .4"/>
    <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
  </asset>

  <default>
    <motor ctrlrange="-1 1" ctrllimited="true"/>
    <default class="body">

      <!-- geoms -->
      <geom type="capsule" condim="1" friction=".7" solimp=".9 .99 .003" solref=".015 1" material="body" group="1"/>
      <default class="thigh">
        <geom size=".06"/>
      </default>
      <default class="shin">
        <geom fromto="0 0 0 0 0 -.3"  size=".049"/>
      </default>
      <default class="foot">
        <geom size=".027"/>
        <default class="foot1">
          <geom fromto="-.07 -.01 0 .14 -.03 0"/>
        </default>
        <default class="foot2">
          <geom fromto="-.07 .01 0 .14  .03 0"/>
        </default>
      </default>
      <default class="arm_upper">
        <geom size=".04"/>
      </default>
      <default class="arm_lower">
        <geom size=".031"/>
      </default>
      <default class="hand">
        <geom type="sphere" size=".04"/>
      </default>

      <!-- joints -->
      <joint type="hinge" damping=".2" stiffness="1" armature=".01" limited="true" solimplimit="0 .99 .01"/>
      <default class="joint_big">
        <joint damping="5" stiffness="10"/>
        <default class="hip_x">
          <joint range="-30 10"/>
        </default>
        <default class="hip_z">
          <joint range="-60 35"/>
        </default>
        <default class="hip_y">
          <joint axis="0 1 0" range="-150 20"/>
        </default>
        <default class="joint_big_stiff">
          <joint stiffness="20"/>
        </default>
      </default>
      <default class="knee">
        <joint pos="0 0 .02" axis="0 -1 0" range="-160 2"/>
      </default>
      <default class="ankle">
        <joint range="-50 50"/>
        <default class="ankle_y">
          <joint pos="0 0 .08" axis="0 1 0" stiffness="6"/>
        </default>
        <default class="ankle_x">
          <joint pos="0 0 .04" stiffness="3"/>
        </default>
      </default>
      <default class="shoulder">
        <joint range="-85 60"/>
      </default>
      <default class="elbow">
        <joint range="-100 50" stiffness="0"/>
      </default>
    </default>
  </default>

  <visual>
    <map force="0.1" zfar="30"/>
    <rgba haze="0.15 0.25 0.35 1"/>
    <quality shadowsize="4096"/>
    <global offwidth="800" offheight="800"/>
  </visual>

  <worldbody>
    <geom name="floor" size="10 10 .05" type="plane" material="grid" condim="3"/>
    <light directional="true" diffuse=".4 .4 .4" specular="0.1 0.1 0.1" pos="0 0 5" dir="0 0 -1" castshadow="false"/>
    <light name="spotlight" mode="targetbodycom" target="world" diffuse="1 1 1" specular="0.3 0.3 0.3" pos="-6 -6 4" cutoff="60"/>
    <replicate count="2" euler="0 0 16.36" sep="-">
      <body name="torso" pos="-5 0 1.282" childclass="body">
        <camera name="back" pos="-3 0 1" xyaxes="0 -1 0 1 0 2" mode="trackcom"/>
        <camera name="side" pos="0 -3 1" xyaxes="1 0 0 0 1 2" mode="trackcom"/>
        <freejoint name="root"/>
        <geom name="torso" fromto="0 -.07 0 0 .07 0" size=".07"/>
        <geom name="waist_upper" fromto="-.01 -.06 -.12 -.01 .06 -.12" size=".06"/>
        <body name="head" pos="0 0 .19">
          <geom name="head" type="sphere" size=".09"/>
          <camera name="egocentric" pos=".09 0 0" xyaxes="0 -1 0 .1 0 1" fovy="80"/>
        </body>
        <body name="waist_lower" pos="-.01 0 -.26">
          <geom name="waist_lower" fromto="0 -.06 0 0 .06 0" size=".06"/>
          <joint name="abdomen_z" pos="0 0 .065" axis="0 0 1" range="-45 45" class="joint_big_stiff"/>
          <joint name="abdomen_y" pos="0 0 .065" axis="0 1 0" range="-75 30" class="joint_big"/>
          <body name="pelvis" pos="0 0 -.165">
            <joint name="abdomen_x" pos="0 0 .1" axis="1 0 0" range="-35 35" class="joint_big"/>
            <geom name="butt" fromto="-.02 -.07 0 -.02 .07 0" size=".09"/>
            <body name="thigh_right" pos="0 -.1 -.04">
              <joint name="hip_x_right" axis="1 0 0" class="hip_x"/>
              <joint name="hip_z_right" axis="0 0 1" class="hip_z"/>
              <joint name="hip_y_right" class="hip_y"/>
              <geom name="thigh_right" fromto="0 0 0 0 .01 -.34" class="thigh"/>
              <body name="shin_right" pos="0 .01 -.4">
                <joint name="knee_right" class="knee"/>
                <geom name="shin_right" class="shin"/>
                <body name="foot_right" pos="0 0 -.39">
                  <joint name="ankle_y_right" class="ankle_y"/>
                  <joint name="ankle_x_right" class="ankle_x" axis="1 0 .5"/>
                  <geom name="foot1_right" class="foot1"/>
                  <geom name="foot2_right" class="foot2"/>
                </body>
              </body>
            </body>
            <body name="thigh_left" pos="0 .1 -.04">
              <joint name="hip_x_left" axis="-1 0 0" class="hip_x"/>
              <joint name="hip_z_left" axis="0 0 -1" class="hip_z"/>
              <joint name="hip_y_left" class="hip_y"/>
              <geom name="thigh_left" fromto="0 0 0 0 -.01 -.34" class="thigh"/>
              <body name="shin_left" pos="0 -.01 -.4">
                <joint name="knee_left" class="knee"/>
                <geom name="shin_left" fromto="0 0 0 0 0 -.3" class="shin"/>
                <body name="foot_left" pos="0 0 -.39">
                  <joint name="ankle_y_left" class="ankle_y"/>
                  <joint name="ankle_x_left" class="ankle_x" axis="-1 0 -.5"/>
                  <geom name="foot1_left" class="foot1"/>
                  <geom name="foot2_left" class="foot2"/>
                </body>
              </body>
            </body>
          </body>
        </body>
        <body name="upper_arm_right" pos="0 -.17 .06">
          <joint name="shoulder1_right" axis="2 1 1"  class="shoulder"/>
          <joint name="shoulder2_right" axis="0 -1 1" class="shoulder"/>
          <geom name="upper_arm_right" fromto="0 0 0 .16 -.16 -.16" class="arm_upper"/>
          <body name="lower_arm_right" pos=".18 -.18 -.18">
            <joint name="elbow_right" axis="0 -1 1" class="elbow"/>
            <geom name="lower_arm_right" fromto=".01 .01 .01 .17 .17 .17" class="arm_lower"/>
            <body name="hand_right" pos=".18 .18 .18">
              <geom name="hand_right" zaxis="1 1 1" class="hand"/>
            </body>
          </body>
        </body>
        <body name="upper_arm_left" pos="0 .17 .06">
          <joint name="shoulder1_left" axis="-2 1 -1" class="shoulder"/>
          <joint name="shoulder2_left" axis="0 -1 -1"  class="shoulder"/>
          <geom name="upper_arm_left" fromto="0 0 0 .16 .16 -.16" class="arm_upper"/>
          <body name="lower_arm_left" pos=".18 .18 -.18">
            <joint name="elbow_left" axis="0 -1 -1" class="elbow"/>
            <geom name="lower_arm_left" fromto=".01 -.01 .01 .17 -.17 .17" class="arm_lower"/>
            <body name="hand_left" pos=".18 -.18 .18">
              <geom name="hand_left" zaxis="1 -1 1" class="hand"/>
            </body>
          </body>
        </body>
      </body>
    </replicate>
  </worldbody>

  <contact>
    <exclude body1="waist_lower" body2="thigh_right"/>
    <exclude body1="waist_lower" body2="thigh_left"/>
  </contact>

  <!--<tendon>
    <fixed name="hamstring_right" limited="true" range="-0.3 2">
      <joint joint="hip_y_right" coef=".5"/>
      <joint joint="knee_right" coef="-.5"/>
    </fixed>
    <fixed name="hamstring_left" limited="true" range="-0.3 2">
      <joint joint="hip_y_left" coef=".5"/>
      <joint joint="knee_left" coef="-.5"/>
    </fixed>
  </tendon> -->

  <actuator>
    <motor name="abdomen_z"       gear="40"  joint="abdomen_z"/>
    <motor name="abdomen_y"       gear="40"  joint="abdomen_y"/>
    <motor name="abdomen_x"       gear="40"  joint="abdomen_x"/>
    <motor name="hip_x_right"     gear="40"  joint="hip_x_right"/>
    <motor name="hip_z_right"     gear="40"  joint="hip_z_right"/>
    <motor name="hip_y_right"     gear="120" joint="hip_y_right"/>
    <motor name="knee_right"      gear="80"  joint="knee_right"/>
    <motor name="ankle_y_right"   gear="20"  joint="ankle_y_right"/>
    <motor name="ankle_x_right"   gear="20"  joint="ankle_x_right"/>
    <motor name="hip_x_left"      gear="40"  joint="hip_x_left"/>
    <motor name="hip_z_left"      gear="40"  joint="hip_z_left"/>
    <motor name="hip_y_left"      gear="120" joint="hip_y_left"/>
    <motor name="knee_left"       gear="80"  joint="knee_left"/>
    <motor name="ankle_y_left"    gear="20"  joint="ankle_y_left"/>
    <motor name="ankle_x_left"    gear="20"  joint="ankle_x_left"/>
    <motor name="shoulder1_right" gear="20"  joint="shoulder1_right"/>
    <motor name="shoulder2_right" gear="20"  joint="shoulder2_right"/>
    <motor name="elbow_right"     gear="40"  joint="elbow_right"/>
    <motor name="shoulder1_left"  gear="20"  joint="shoulder1_left"/>
    <motor name="shoulder2_left"  gear="20"  joint="shoulder2_left"/>
    <motor name="elbow_left"      gear="40"  joint="elbow_left"/>
  </actuator>
  <keyframe>
  <key qpos='-5 0 1.282 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -4.1204 0.0439149 1.2821 0.0357475 0.0169831 0.00416343 0.999208 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0'/>
  <key qpos='-4.9369 0.0104954 0.877604 0.999617 0.0267022 0.00647118 -0.00319783 -0.0104185 -0.48433 -0.0602882 0.05204 -0.00947736 0.0303459 -2.28398 -0.202255 -0.095124 -0.0517436 0.0101542 0.045474 -2.28398 -0.196754 0.112707 0.10916 -0.0657979 -0.0187167 0.117478 -0.10947 0.00278909 -4.13095 0.0451804 1.27917 0.0337869 -0.0293599 0.00431884 0.998988 0.000878712 -0.0797496 -0.0055062 -0.0454015 0.00122102 -0.0133842 0.0420628 0.0748343 -0.0163969 -0.0123584 -0.000779432 -0.0139995 0.0429348 0.0763261 0.0147094 0.407914 -0.28512 -0.224151 0.410719 -0.277123 -0.221295'/>
  <key qpos='-4.9369 0.0104954 0.877604 0.999617 0.0267022 0.00647118 -0.00319783 -0.0104185 -0.420189 -0.0602882 -0.062854 -0.00947736 0.022719 -2.28398 -0.202255 -0.095124 -0.0517436 0.0101542 0.045474 -2.28398 -0.196754 0.112707 0.10916 -0.0657979 -0.0187167 0.117478 -0.10947 0.00278909 -4.39301 -0.0489197 0.970054 0.0329077 -0.0294438 0.0348755 0.998416 0.000878712 -1.309 -0.0055062 -0.279265 -1.047 -0.288826 -1.33663 0.0748343 -0.0163969 -0.5236 -1.047 0.185909 -1.11039 0.0763261 0.0147094 0.407914 -0.28512 -0.224151 0.410719 -0.277123 -0.221295'/>
  <key name="closed guard" qpos='-4.93912 0.00574695 0.781246 0.983692 0.00932319 0.174587 0.042218 -0.0116104 -0.827671 -0.028291 -0.0956173 -0.0029346 -0.0989541 -2.42802 0.0654537 0.0787425 -0.013588 0.0691285 -0.10043 -2.41905 0.0792197 0.0596308 0.173107 -0.229405 -0.508331 0.529554 -0.214771 -0.367333 -4.3503 -0.0448199 0.684852 0.0230275 0.0811853 0.00085943 0.996433 0.00598573 -1.12713 -0.0227058 -0.236974 -0.874702 -0.380878 -1.18933 0.206359 -0.142285 -0.429261 -0.885197 0.0282188 -0.832516 0.206467 -0.0937541 0.757062 -0.298457 -0.0624962 0.0231864 -0.458937 -0.568735'/>
  </keyframe>
</mujoco>"""

NOTE: This model is tendonless (I had to comment out the part where it says tendon because they still haven't implemented tendons...)

In [76]:
model = mujoco.MjModel.from_xml_string(xml)
data = mujoco.MjData(model)
renderer = mujoco.Renderer(model)
mujoco.mj_forward(model, data)
renderer.update_scene(data)
media.show_image(renderer.render())
mujoco.mj_resetDataKeyframe(model, data, 3)
print(len(data.qpos))

56


In [84]:
po = '-4.93912 0.00574695 0.781246 0.983692 0.00932319 0.174587 0.042218 -0.0116104 -0.827671 -0.028291 -0.0956173 -0.0029346 -0.0989541 -2.42802 0.0654537 0.0787425 -0.013588 0.0691285 -0.10043 -2.41905 0.0792197 0.0596308 0.173107 -0.229405 -0.508331 0.529554 -0.214771 -0.367333 -4.3503 -0.0448199 0.684852 0.0230275 0.0811853 0.00085943 0.996433 0.00598573 -1.12713 -0.0227058 -0.236974 -0.874702 -0.380878 -1.18933 0.206359 -0.142285 -0.429261 -0.885197 0.0282188 -0.832516 0.206467 -0.0937541 0.757062 -0.298457 -0.0624962 0.0231864 -0.458937 -0.568735'
print(po.split())
qpos0 = [float(i) for i in po.split()]
print(qpos0)

['-5', '0', '1.282', '1', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '-4.1204', '0.0439149', '1.2821', '0.0357475', '0.0169831', '0.00416343', '0.999208', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0']
[-5.0, 0.0, 1.282, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -4.1204, 0.0439149, 1.2821, 0.0357475, 0.0169831, 0.00416343, 0.999208, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


  and should_run_async(code)


In [80]:
print(model.nkey)

0


  and should_run_async(code)


In [74]:
for key in range(model.nkey):
  mujoco.mj_resetDataKeyframe(model, data, key)
  mujoco.mj_forward(model, data)
  renderer.update_scene(data)
  media.show_image(renderer.render())

Main code starts from here


In [61]:
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

In [62]:
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
# print(mj_data.qpos, type(mj_data.qpos))
# print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())

NotImplementedError: tendons are not supported

In [43]:
print(mj_model.nkey)
for key in range(mj_model.nkey):
  mujoco.mj_resetDataKeyframe(mj_model, mj_data, key)
  mujoco.mj_forward(mj_model, mj_data)
  renderer.update_scene(mj_data)
  media.show_image(renderer.render())

0


In [86]:
# enable joint visualization option:
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 3.8  # (seconds)
framerate = 60  # (Hz)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
mj_model. = qpos0
mujoco.mj_resetDataKeyframe(mj_model, mj_data, 3)
while mj_data.time < duration:
  mujoco.mj_step(mj_model, mj_data)
  if len(frames) < mj_data.time * framerate:
    renderer.update_scene(mj_data, scene_option=scene_option)
    pixels = renderer.render()
    frames.append(pixels)

# Simulate and display video.
media.show_video(frames, fps=framerate)

  and should_run_async(code)


ValueError: could not broadcast input array from shape (56,) into shape (0,3)

In [26]:

jit_step = jax.jit(mjx.step)

frames = []
mujoco.mj_resetData(mj_model, mj_data)

mjx_data = mjx.put_data(mj_model, mj_data)
while mjx_data.time < duration:
  mjx_data = jit_step(mjx_model, mjx_data)
  if len(frames) < mjx_data.time * framerate:
    mj_data = mjx.get_data(mj_model, mjx_data)
    renderer.update_scene(mj_data, scene_option=scene_option)
    pixels = renderer.render()
    frames.append(pixels)

media.show_video(frames, fps=framerate)

0
This browser does not support the video tag.
