In [16]:
import os
os.environ['MUJOCO_GL'] = 'egl' 
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import model
from mujoco_playground.config import locomotion_params
from mujoco_playground import registry
import jax
from mujoco_playground._src.mjx_env import get_sensor_data
import mujoco
import time
import logging
from mujoco import mjx
from typing import List
from mujoco_playground._src.locomotion.go2 import go2_constants as consts
import mediapy as media


print("Setting up JAX for GPU usage...")
# Process-level switches: request the GPU platform and turn off XLA's
# up-front GPU memory preallocation.
# NOTE(review): `import jax` already ran earlier in this cell, so these
# environment variables may be read too late to influence JAX itself -- confirm.
os.environ.update({
	'JAX_PLATFORMS': 'gpu',
	'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
})
# Runtime JAX configuration: target the GPU and keep JIT compilation enabled.
jax.config.update('jax_platform_name', 'gpu')
jax.config.update('jax_disable_jit', False)
# Only let warnings and above through from JAX's logger.
logging.getLogger('jax').setLevel(logging.WARNING)
print("Defining functions")

def get_obs_sensors(model: mujoco.MjModel, data: mujoco.MjData) -> List:
	"""
	Gets the observation sensors for the Go2 robot

	Args:
		model (mujoco.MjModel): The Mujoco model
		data (mujoco.MjData): The Mujoco data holding the current simulation state
	Returns:
		List: [linvel, gyro, gravity, joint_angles, joint_velocities] where
			gravity is the world "down" direction expressed in the "imu" site frame
			and joint_angles are offsets from the "home" keyframe pose.
	"""
	linvel = get_sensor_data(model, data, consts.LOCAL_LINVEL_SENSOR)
	gyro = get_sensor_data(model, data, consts.GYRO_SENSOR)

	# site_xmat holds the site's local-to-world rotation (stored flat, hence the
	# reshape). Expressing the world-frame gravity direction in the IMU frame
	# therefore needs the transpose (inverse) of that rotation; multiplying by
	# the untransposed matrix would instead rotate an IMU-frame vector into the
	# world frame.
	# NOTE(review): confirm the training environment computes gravity the same
	# way, otherwise the policy sees a different observation than it was trained on.
	imu_rotation = jax.numpy.array(data.site_xmat[model.site("imu").id]).reshape((3, 3))
	gravity = imu_rotation.T @ jax.numpy.array([0, 0, -1])

	# Skip the first 7 qpos / 6 qvel entries (assumed to be the floating-base
	# pose and velocity -- TODO confirm against the XML) to keep only the joints.
	joint_angles = data.qpos[7:] - model.keyframe("home").qpos[7:]
	joint_velocities = data.qvel[6:]

	return [linvel, gyro, gravity, joint_angles, joint_velocities]


def import_model(env_name='Go2JoystickFlatTerrain', env_cfg=None,
		params_path='ppo_go2joystick_flatterrain_params_v1_ctrl_002_sim_0004_impratio_100'):
	"""
	Builds the deterministic inference function of a PPO policy trained for `env_name`.

	Args:
		env_name (str): Name of the environment whose PPO network configuration is used
		env_cfg: Environment config providing `action_size` and `observation_size`.
			When None, the registry default config for `env_name` is used.
		params_path (str): Path of the saved parameters to load
	Returns:
		The inference function produced by `make_inference_fn`, called as
		`inference(obs, rng)`.
	"""
	# Previously a missing env_cfg crashed on `None.action_size`; fall back to
	# the same default config the notebook itself loads.
	if env_cfg is None:
		env_cfg = registry.get_default_config(env_name)

	ppo_params = locomotion_params.brax_ppo_config(env_name)
	model_params = model.load_params(params_path)

	# NOTE(review): no preprocess_observations_fn is passed here; if the
	# checkpoint was trained with observation normalization, the policy would
	# receive unnormalized observations -- confirm against the training setup.
	ppo = ppo_networks.make_ppo_networks(
		action_size=env_cfg.action_size,
		observation_size=env_cfg.observation_size,
		**ppo_params.network_factory,
	)
	make_inference = ppo_networks.make_inference_fn(ppo)
	# deterministic=True is passed so the policy does not sample its actions.
	inference_fn = make_inference(model_params, deterministic=True)

	return inference_fn

Setting up JAX for GPU usage...
Defining functions


In [17]:
print("Initializing environment and model...")
# PRNG stream, environment config and the MuJoCo model/state used for the rollout.
rng = jax.random.PRNGKey(0)
env_name = 'Go2JoystickFlatTerrain'
env_cfg = registry.get_default_config(env_name)
# NOTE(review): the model keeps the timestep from the XML; the rollout loop
# assumes each physics step advances env_cfg.sim_dt -- confirm they match.
m = mujoco.MjModel.from_xml_path('./xmls/scene_mjx_feetonly_flat_terrain.xml')
d = mujoco.MjData(m)

print("Setting up renderer and importing model...")
renderer = mujoco.Renderer(m)

print("Importing trained model...")
inference = import_model(env_name=env_name, env_cfg=env_cfg)
print("Model imported successfully.")
# The policy observation includes the previous action; start from zeros.
last_action = jax.numpy.zeros(env_cfg.action_size)

print("JIT compiling inference function...")
inference = jax.jit(inference)

print("Starting main simulation loop...")
# Rollout settings: number of control steps and how often a frame is captured.
n_steps, render_every = 500, 2
images = []
counter_control = counter_init = 0
# Command fed to the policy (presumably [vx, vy, yaw rate] -- confirm).
command = jax.numpy.array([0.0, 0.0, 3.14])
# Begin from the model's reset state, with the home pose as initial joint targets.
mujoco.mj_resetData(m, d)
timer_control = time.time()
motors_targets = m.keyframe("home").qpos[7:]
print("Starting simulation...")

Initializing environment and model...
Setting up renderer and importing model...
Importing trained model...
Model imported successfully.
JIT compiling inference function...
Starting main simulation loop...
Starting simulation...


In [18]:
# Closed-loop rollout: each iteration builds the observation, queries the
# policy, applies the resulting joint targets and advances the physics.

# Physics steps per control step. round() avoids float truncation: a ratio such
# as 0.06 / 0.02 evaluates to 2.9999999999999996, which int() alone would turn
# into 2 substeps instead of 3.
n_substeps = int(round(env_cfg.ctrl_dt / env_cfg.sim_dt))
# Joint angles of the "home" keyframe; policy actions are offsets from this pose.
default_pose = m.keyframe("home").qpos[7:]

for i in range(n_steps):
	step_start = time.time()
	act_rng, rng = jax.random.split(rng)
	obs_sensors = get_obs_sensors(m, d)
	obs = jax.numpy.concatenate(obs_sensors + [last_action, command])
	# Inference from command -> angle offset for each of the joints
	ctrl, _ = inference(obs, act_rng)
	motors_targets = default_pose + ctrl * env_cfg.action_scale
	timer_control = time.time()
	last_action = ctrl
	# NOTE(review): JAX dispatches asynchronously, so this may under-report the
	# time actually spent computing the action -- confirm before relying on it.
	print("Time used in control loop: {:.4f}s".format(time.time() - step_start))
	d.ctrl = motors_targets
	for _ in range(n_substeps):
		mujoco.mj_step(m, d)

	print(f"Timestep {i+1}/{n_steps} completed.")
	if i % render_every == 0:
		renderer.update_scene(d, camera='track')
		images.append(renderer.render())

# One frame is kept every `render_every` control steps, so this fps plays the
# video back at simulated real time.
media.show_video(images, fps=1.0 / env_cfg.ctrl_dt / render_every)

Time used in control loop: 0.1451s
Timestep 1/500 completed.
Time used in control loop: 0.0016s
Timestep 2/500 completed.
Time used in control loop: 0.0012s
Timestep 3/500 completed.
Time used in control loop: 0.0020s
Timestep 4/500 completed.
Time used in control loop: 0.0015s
Timestep 5/500 completed.
Time used in control loop: 0.0019s
Timestep 6/500 completed.
Time used in control loop: 0.0012s
Timestep 7/500 completed.
Time used in control loop: 0.0024s
Timestep 8/500 completed.
Time used in control loop: 0.0014s
Timestep 9/500 completed.
Time used in control loop: 0.0019s
Timestep 10/500 completed.
Time used in control loop: 0.0012s
Timestep 11/500 completed.
Time used in control loop: 0.0017s
Timestep 12/500 completed.
Time used in control loop: 0.0011s
Timestep 13/500 completed.
Time used in control loop: 0.0016s
Timestep 14/500 completed.
Time used in control loop: 0.0011s
Timestep 15/500 completed.
Time used in control loop: 0.0021s
Timestep 16/500 completed.
Time used in cont

0
This browser does not support the video tag.
