# Imports

In [1]:
import jax
from jax import numpy as jp
from matplotlib.lines import Line2D
from matplotlib.patches import Circle
import matplotlib.pyplot as plt

try:
  import brax
except (ImportError, ModuleNotFoundError):
  from IPython.display import clear_output 
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

# Load model

In [3]:
from brax.io import mjcf
import mujoco

MODEL_PATH = '../models/lelamp/scene.xml'

# Parse the MJCF scene into a native MuJoCo model, then switch the constraint
# solver to Conjugate Gradient with a small fixed budget for both the main
# solver iterations and its line search.
mj_model = mujoco.MjModel.from_xml_path(MODEL_PATH)
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
for option_name, option_value in (('iterations', 6), ('ls_iterations', 6)):
  setattr(mj_model.opt, option_name, option_value)

# Hand the tuned MuJoCo model to Brax's MJCF loader.
lelamp = mjcf.load_model(mj_model)



In [4]:
# Masses read from the loaded model's link inertia data (presumably one entry
# per link).
link_masses = lelamp.link.inertia.mass
print(link_masses)

[0.140913  0.0304303 0.0439213 0.040387  0.0271577 0.0121476]


In [None]:
#@title { run: "auto"}
from brax.generalized import pipeline as generalized_pipeline
from brax.positional import pipeline as positional_pipeline
from brax.spring import pipeline as spring_pipeline

# Form parameters. The selected pipeline *name* is kept separate from the
# pipeline *module* bound below, so `pipeline` always refers to a module.
pipeline_name = 'generalized'  #@param ["generalized", "positional", "spring"]
step_size = "1 ms" #@param [".4 ms", "1 ms", "5 ms"]

# Each 10 ms frame is split into `substeps` physics steps, so the per-step
# timestep 0.01 / substeps matches the label chosen in `step_size`.
substeps = {'.4 ms': 25, '1 ms': 10, '5 ms': 2}[step_size]
pipeline = {'generalized': generalized_pipeline,
            'positional': positional_pipeline,
            'spring': spring_pipeline}[pipeline_name]

pendulum = lelamp.tree_replace({'opt.timestep': 0.01 / substeps})


def visualize(ax, pos, alpha=1.0, radius=0.02):
  """Draw link positions projected onto the x-z plane.

  Each position is drawn as a hollow black circle, and consecutive positions
  are joined by a red line.

  Args:
    ax: matplotlib Axes to draw on.
    pos: per-link world positions; assumed shape (num_links, 3) with columns
      (x, y, z) — TODO confirm for this model.
    alpha: opacity in [0, 1] applied to every patch and line drawn.
    radius: circle radius in data units (cosmetic only, not a physical size).
  """
  # Pull device arrays to host memory once, rather than per element.
  pos = jax.device_get(pos)
  for i in range(len(pos)):
    x, z = pos[i][0], pos[i][2]
    ax.add_patch(Circle(xy=(x, z), radius=radius, fill=False,
                        color=(0, 0, 0, alpha)))
    if i + 1 < len(pos):
      next_x, next_z = pos[i + 1][0], pos[i + 1][2]
      ax.add_line(Line2D([x, next_x], [z, next_z], color=(1, 0, 0, alpha)))


init_q = jp.zeros(pendulum.q_size())
state = jax.jit(pipeline.init)(pendulum, init_q, jp.zeros(pendulum.qd_size()))
# Wrap once, outside the loop, instead of building a new jitted wrapper on
# every step.
jit_step = jax.jit(pipeline.step)

n_steps = 100 * substeps  # 100 frames of 10 ms each -> 1 s of simulated time

_, ax = plt.subplots()
# NOTE(review): these limits come from a hanging-pendulum example (everything
# below the origin). A lamp standing on a floor may fall outside this window;
# confirm against the model's actual link positions.
ax.set(xlim=(-3, 3), ylim=(-4, 0), xlabel='x', ylabel='z')

for i in range(n_steps):
  if i % substeps == 0:
    # Opacity grows with time, so later frames are drawn more solidly.
    visualize(ax, state.x.pos, i / n_steps)
  state = jit_step(pendulum, state, None)

ax.set_title('lelamp in motion')
plt.show()

In [None]:
# Re-initialise the state from the model's `init_q` (presumably its reference
# pose) with an all-zero generalized velocity vector `qd`.
zero_qd = jp.zeros(pendulum.qd_size())
state = pipeline.init(pendulum, pendulum.init_q, zero_qd)