# Brax: a differentiable physics engine

[Brax](https://github.com/google/brax) simulates physical systems made up of rigid bodies, joints, and actutators.  Brax provides the function:

$$
\text{state}_{t+1} = \text{step}(\text{system}, \text{state}_t, \text{act})
$$

where:
* $\text{system}$ is the static description of the physical system: each body in the world, its weight and size, and so on
* $\text{state}_t$ is the dynamic state of the system at time $t$: each body's position, rotation, velocity, and angular velocity
* $\text{act}$ is dynamic input to the system in the form of motor actuation

Brax simulations are differentiable: the gradient $\Delta \text{step}$ can be used for efficient trajectory optimization.  But Brax is also well-suited to derivative-free optimization methods such as evolutionary strategy or reinforcement learning.

Let's review how $\text{system}$, $\text{state}_t$, and $\text{act}$ are used:

In [1]:
import os
import pickle
import mediapy as media
from etils import epath

import mujoco
import mujoco.mjx as mjx

import jax
from jax import numpy as jnp

from brax import envs
from brax.envs.base import Env, PipelineEnv, State
from brax.io import html, mjcf, model

print(f"jax.devices(): {jax.devices()}")

MJCF_ROOT_PATH = epath.Path("../../mujoco")


def zscore(x, xmean, xstd, default=0):
    valid = jnp.greater(xstd, 0)
    return jnp.where(valid, (x - xmean) / xstd, default)


def l1_norm(x):
    return jnp.sum(jnp.abs(x))


def l2_norm(x):
    return jnp.sqrt(jnp.sum(x**2))


class SequentialReacher(PipelineEnv):

    def __init__(
        self,
        target_duration=3,
        num_targets=10,
        **kwargs,
    ):
        self.mj_model = mujoco.MjModel.from_xml_path(
            (MJCF_ROOT_PATH / "arm.xml").as_posix()
        )

        sys = mjcf.load_model(
            self.mj_model
        )  # system defining the kinematic tree and other properties
        kwargs["backend"] = "mjx"  # string specifying the physics pipeline
        kwargs["n_frames"] = (
            1  # the number of times to step the physics pipeline for each environment step
        )

        super().__init__(sys, **kwargs)

        # Get the site ID using the name of your end effector
        self.hand_id = self.mj_model.geom("hand").id
        self.target_id = self.mj_model.body_mocapid[
            mujoco.mj_name2id(self.mj_model, mujoco.mjtObj.mjOBJ_BODY, b"target")
        ]

        # Load sensor stats
        sensor_stats_path = os.path.join(MJCF_ROOT_PATH, "sensor_stats.pkl")
        with open(sensor_stats_path, "rb") as f:
            self.sensor_stats = pickle.load(f)

        # Load hand stats
        hand_position_stats_path = os.path.join(
            MJCF_ROOT_PATH, "hand_position_stats.pkl"
        )
        with open(hand_position_stats_path, "rb") as f:
            self.hand_position_stats = pickle.load(f)

        # Load candidate target positions
        candidate_targets_path = os.path.join(MJCF_ROOT_PATH, "candidate_targets.pkl")
        with open(candidate_targets_path, "rb") as f:
            self.candidate_targets = pickle.load(f)

        # Load candidate nail positions
        grid_positions_path = os.path.join(MJCF_ROOT_PATH, "grid_positions.pkl")
        with open(grid_positions_path, "rb") as f:
            self.grid_positions = pickle.load(f)

        # Convert stats to JAX arrays
        self.target_means = jnp.array(self.hand_position_stats["mean"].values)
        self.target_stds = jnp.array(self.hand_position_stats["std"].values)
        self.sensor_means = jnp.array(self.sensor_stats["mean"].values)
        self.sensor_stds = jnp.array(self.sensor_stats["std"].values)

        # Convert candidate_targets to JAX array
        self.candidate_target_positions = jnp.array(self.candidate_targets.values)

        self.target_duration = target_duration
        self.num_targets = num_targets

    def reset(self, rng: jnp.ndarray) -> State:
        """Resets the environment to an initial state."""
        qpos = jnp.zeros(self.sys.nq)
        qvel = jnp.zeros(self.sys.nv)
        data = self.pipeline_init(qpos, qvel)

        target_positions = self._sample_target_positions(rng)
        data = self._update_target(data, target_positions)

        obs = self._get_obs(data)
        reward, done = jnp.zeros(2)

        return State(
            data, obs, reward, done, info={"target_positions": target_positions}
        )

    def step(self, state: State, action: jnp.ndarray) -> State:
        """Runs one timestep of the environment's dynamics."""
        data = self.pipeline_step(state.pipeline_state, action)
        data = self._update_target(data, state.info["target_positions"])

        hand_position = self._get_hand_pos(data)
        target_position = self._get_target_pos(data)
        euclidean_distance = l2_norm(target_position - hand_position)

        obs = self._get_obs(data)
        reward = -euclidean_distance

        done = jnp.where(data.time > self.target_duration * self.num_targets, 1.0, 0.0)

        return state.replace(pipeline_state=data, obs=obs, reward=reward, done=done)

    def _get_obs(self, data: mjx.Data) -> jnp.ndarray:
        target_position = self._get_target_pos(data)
        sensor_data = data.sensordata.copy()
        norm_target_position = zscore(
            target_position,
            self.target_means,
            self.target_stds,
        )
        norm_sensor_data = zscore(
            sensor_data,
            self.sensor_means,
            self.sensor_stds,
        )
        obs = jnp.concatenate(
            [
                norm_target_position,
                norm_sensor_data,
            ]
        )
        # obs = jnp.reshape(obs, (1, -1))
        return obs

    def _sample_target_positions(self, key: jnp.ndarray):
        """Sample target positions (w/o replacement) from the candidate targets"""
        sample_idcs = jax.random.choice(
            key,
            self.candidate_target_positions.shape[0],
            shape=(self.num_targets,),
            replace=False,
        )
        return self.candidate_target_positions[sample_idcs]

    def _update_target(self, data: mjx.Data, target_positions) -> jnp.ndarray:
        """Update the target position"""
        target_idx = jnp.floor_divide(data.time, self.target_duration).astype(jnp.int32)
        mocap_position = data.mocap_pos.at[self.target_id].set(
            target_positions[target_idx]
        )
        return data.replace(mocap_pos=mocap_position)

    def _get_hand_pos(self, data: mjx.Data) -> jnp.ndarray:
        """Get the position of the end effector (hand)"""
        hand_position = data.geom_xpos[self.hand_id].copy()
        return hand_position

    def _get_target_pos(self, data: mjx.Data) -> jnp.ndarray:
        """Get the position of the target"""
        target_position = data.mocap_pos[self.target_id].copy()
        return target_position


envs.register_environment("sequential_reacher", SequentialReacher)

jax.devices(): [CudaDevice(id=0)]


In [None]:
from typing import Tuple
from typing import Callable
from functools import partial

from flax import linen as nn
from flax.linen import initializers
from flax.struct import dataclass

from evojax.policy.base import PolicyState
from evojax.policy.base import PolicyNetwork
from evojax.task.base import TaskState
from evojax.util import create_logger
from evojax.util import get_params_format_fn

logger = create_logger("RNNPolicy")


@dataclass
class RNNState(PolicyState):
    h: jnp.ndarray


class RNN(nn.Module):
    """A recurrent neural network."""

    input_dim: int
    hidden_dim: int
    output_dim: int
    hidden_act_fn: Callable = nn.tanh
    output_act_fn: Callable = nn.sigmoid
    use_bias: bool = False

    @nn.compact
    def __call__(self, inputs, carry):
        input_layer = partial(
            nn.Dense,
            features=self.hidden_dim,
            use_bias=False,
            kernel_init=initializers.xavier_normal(),
        )
        recurrent_layer = partial(
            nn.Dense,
            features=self.hidden_dim,
            use_bias=self.use_bias,
            kernel_init=initializers.xavier_normal(),
        )
        output_layer = partial(
            nn.Dense,
            features=self.output_dim,
            use_bias=False,
            kernel_init=initializers.xavier_normal(),
        )
        carry = self.hidden_act_fn(
            input_layer(name="i")(inputs) + recurrent_layer(name="h")(carry)
        )
        output = self.output_act_fn(output_layer(name="o")(carry))
        return output, carry


class RNNPolicy(PolicyNetwork):
    """The policy wraps the model and does some data formatting works."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        hidden_act_fn: Callable = nn.tanh,
        output_act_fn: Callable = nn.sigmoid,
        use_bias: bool = False,
    ):
        model = RNN(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            hidden_act_fn=hidden_act_fn,
            output_act_fn=output_act_fn,
            use_bias=use_bias,
        )
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        params = model.init(
            jax.random.PRNGKey(0),
            jnp.ones([1, input_dim]),
            (jnp.zeros([1, hidden_dim]), jnp.zeros([1, hidden_dim])),
        )
        self.num_params, format_params_fn = get_params_format_fn(params)
        print("RNNPolicy.num_params = {}".format(self.num_params))

        self._format_params_fn = jax.vmap(format_params_fn)
        self._forward_fn = jax.vmap(model.apply)

    def reset(self, states: TaskState) -> PolicyState:
        keys = jax.random.split(jax.random.PRNGKey(0), states.obs.shape[0])
        batch_size = states.obs.shape[0]
        h = jnp.zeros((batch_size, self.hidden_dim))
        return RNNState(keys=keys, h=h)

    def get_actions(
        self, t_states: TaskState, params: jnp.ndarray, p_states: RNNState
    ) -> Tuple[jnp.ndarray, RNNState]:
        # Calling `self._format_params_fn` unflattens the parameters so that
        # our Flax model can take that as an input.
        params = self._format_params_fn(params)

        # Now we return the actions and the updated `p_states`.
        actions, h = self._forward_fn(params, t_states.obs, p_states.h)
        return actions, RNNState(keys=p_states.keys, h=h)

In [None]:
import time
import numpy as np
from IPython.display import HTML

from brax import envs
from brax.io import html

import jax

from evojax import SimManager
from evojax.algo import PGPE
from evojax.algo import CMA
from evojax.policy import MLPPolicy
from evojax.task.brax_task import BraxTask

print(f"jax.devices(): {jax.devices()}")
print(f"jax.device_count(): {jax.device_count()}")
print(f"jax.local_device_count(): {jax.local_device_count()}")

env_name = "sequential_reacher"
env = envs.get_environment(env_name=env_name)
state = env.reset(rng=jax.random.PRNGKey(seed=0))

HTML(html.render(env.sys, [state.pipeline_state]))

num_devices = jax.local_device_count()

pop_size = 32
num_tests = 128
assert pop_size % num_devices == 0
assert num_tests % num_devices == 0

max_iters = 10000
center_lr = 0.01
init_std = 0.04
std_lr = 0.07

seed = 42

train_task = BraxTask(env_name=env_name, test=False)
test_task = BraxTask(env_name=env_name, test=True)

# policy = MLPPolicy(
#     input_dim=train_task.obs_shape[0],
#     output_dim=train_task.act_shape[0],
#     hidden_dims=[32, 32, 32, 32],
# )
policy = RNNPolicy(
    input_dim=train_task.obs_shape[0],
    hidden_dim=32,
    output_dim=train_task.act_shape[0],
    hidden_act_fn=jax.nn.tanh,
    output_act_fn=jax.nn.sigmoid,
    use_bias=True,
)
print("#params={}".format(policy.num_params))

# solver = PGPE(
#     pop_size=pop_size,
#     param_size=policy.num_params,
#     optimizer="adam",
#     center_learning_rate=center_lr,
#     stdev_learning_rate=std_lr,
#     init_stdev=init_std,
#     seed=seed,
# )
solver = CMA(
    pop_size=pop_size,
    param_size=policy.num_params,
    init_stdev=1.0,
    seed=seed,
)

sim_mgr = SimManager(
    n_repeats=1,
    test_n_repeats=1,
    pop_size=pop_size,
    n_evaluations=num_tests,
    policy_net=policy,
    train_vec_task=train_task,
    valid_vec_task=test_task,
    seed=seed,
)

print("Start training Brax ({}) for {} iterations.".format(env_name, max_iters))
start_time = time.perf_counter()

for train_iters in range(max_iters):

    # Training
    params = solver.ask()
    scores = sim_mgr.eval_params(params=params, test=False)
    solver.tell(fitness=scores)

    # Test periodically
    if train_iters > 0 and train_iters % 10 == 0:
        best_params = solver.best_params
        scores = np.array(sim_mgr.eval_params(params=best_params, test=True)[0])
        score_avg = np.mean(scores)
        score_std = np.std(scores)
        print(
            "Iter={0}/{1}, #tests={2}, score.avg={3:.2f}, score.std={4:.2f}".format(
                train_iters, max_iters, num_tests, score_avg, score_std
            )
        )

stop_time = time.perf_counter()

jax.devices(): [CudaDevice(id=0)]
jax.device_count(): 1
jax.local_device_count(): 1
RNNPolicy.num_params = 1632
#params=1632
(16_w,32)-aCMA-ES (mu_w=9.2,w_1=19%) in dimension 1632 (seed=42, Fri May  2 21:41:41 2025)
Start training Brax (sequential_reacher) for 10000 iterations.


TypeError: object of type 'numpy.float32' has no len()

In [11]:
# Final test.
best_params = solver.best_params
file = f"../../models/evojax_brax_mjx_cma_rnn.pkl"
with open(file, "wb") as f:
    pickle.dump(best_params, f)
scores = np.array(sim_mgr.eval_params(params=best_params, test=True)[0])
score_avg = np.mean(scores)
score_std = np.std(scores)
print(
    "Iter={0}, #tests={1}, score.avg={2:.2f}, score.std={3:.2f}".format(
        train_iters, num_tests, score_avg, score_std
    )
)
print("Training time: {}s".format(stop_time - start_time))

Iter=999, #tests=128, score.avg=-174.42, score.std=83.29
Training time: 4147.315433334999s


In [14]:
from evojax.policy import MLPPolicy
from evojax.task.brax_task import BraxTask

# @title Visualize the trained policy
env_name="sequential_reacher"
env = envs.get_environment(env_name=env_name)

train_task = BraxTask(env_name=env_name, test=False)
test_task = BraxTask(env_name=env_name, test=True)

# policy = MLPPolicy(
#     input_dim=train_task.obs_shape[0],
#     output_dim=train_task.act_shape[0],
#     hidden_dims=[32, 32, 32, 32],
# )
policy = RNNPolicy(
    input_dim=train_task.obs_shape[0],
    hidden_dim=32,
    output_dim=train_task.act_shape[0],
)

task_reset_fn = jax.jit(env.reset)
policy_reset_fn = jax.jit(policy.reset)
step_fn = jax.jit(env.step)
act_fn = jax.jit(policy.get_actions)

# Check if solver exists and get best_params directly if it does
try:
    best_params = solver.best_params
    print("Using best parameters from solver")
except NameError:
    print("Solver not found, loading parameters from file")

    # Load best params from file
    file = "../../models/evojax_brax_mjx_pgpe_rnn.pkl"
    with open(file, "rb") as f:
        best_params = pickle.load(f)

total_reward = 0
rng = jax.random.PRNGKey(seed=42)
task_state = task_reset_fn(rng=rng)
policy_state = policy_reset_fn(task_state)

rollout = [task_state.pipeline_state]

while not task_state.done:
    task_state = task_state.replace(obs=task_state.obs)
    act, policy_state = act_fn(task_state, best_params[None, :], policy_state)
    task_state = step_fn(task_state, act[0])
    total_reward = total_reward + task_state.reward
    rollout.append(task_state.pipeline_state)

print("rollout reward = {}".format(total_reward))

media.show_video(env.render(rollout), fps=1.0 / env.dt)

RNNPolicy.num_params = 1632
Using best parameters from solver
rollout reward = -2803.974609375


0
This browser does not support the video tag.


In [None]:
# import functools
# import matplotlib.pyplot as plt

# from datetime import datetime
# from brax.training.agents.ppo import train as ppo
# from brax.training.agents.es import train as es

# train_fn = functools.partial(
#     ppo.train,
#     num_timesteps=20_000_000,
#     num_evals=10,
#     reward_scaling=0.1,
#     episode_length=1000,
#     normalize_observations=True,
#     action_repeat=1,
#     unroll_length=10,
#     num_minibatches=24,
#     num_updates_per_batch=8,
#     discounting=0.97,
#     learning_rate=3e-4,
#     entropy_cost=1e-3,
#     num_envs=3072,
#     batch_size=512,
#     seed=0,
# )

# # train_fn = functools.partial(
# #     es.train,
# #     population_size=128,
# #     num_eval_envs=128,
# #     num_timesteps=15_000,
# #     episode_length=15_000,
# #     max_devices_per_host=8,
# #     seed=1,
# # )


# x_data = []
# y_data = []
# ydataerr = []
# times = [datetime.now()]

# max_y, min_y = 0, -1000


# def progress(num_steps, metrics):
#     times.append(datetime.now())
#     x_data.append(num_steps)
#     y_data.append(metrics["eval/episode_reward"])
#     ydataerr.append(metrics["eval/episode_reward_std"])

#     plt.xlim([-1, train_fn.keywords["num_timesteps"] * 1.25])
#     plt.ylim([min_y, max_y])

#     plt.xlabel("# environment steps")
#     plt.ylabel("reward per episode")
#     plt.title(f"y={y_data[-1]:.3f}")

#     plt.errorbar(x_data, y_data, yerr=ydataerr)
#     plt.show()


# make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

# print(f"time to jit: {times[1] - times[0]}")
# print(f"time to train: {times[-1] - times[1]}")

# model_path = 'mjx_brax_policy'
# model.save_params(model_path, params)

# params = model.load_params(model_path)

# inference_fn = make_inference_fn(params)
# jit_inference_fn = jax.jit(inference_fn)

# eval_env = envs.get_environment(env_name)

# jit_reset = jax.jit(eval_env.reset)
# jit_step = jax.jit(eval_env.step)

# # initialize the state
# rng = jax.random.PRNGKey(0)
# state = jit_reset(rng)
# rollout = [state.pipeline_state]

# # grab a trajectory
# n_steps = 500
# render_every = 1

# for i in range(n_steps):
#   act_rng, rng = jax.random.split(rng)
#   ctrl, _ = jit_inference_fn(state.obs, act_rng)
#   state = jit_step(state, ctrl)
#   rollout.append(state.pipeline_state)

#   if state.done:
#     break

# media.show_video(env.render(rollout[::render_every]), fps=1.0 / env.dt / render_every)

absl: 2025-04-30 16:29:26,229 [INFO] Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1


KeyboardInterrupt: 