# Brax: a differentiable physics engine

[Brax](https://github.com/google/brax) simulates physical systems made up of rigid bodies, joints, and actuators.  Brax provides the function:

$$
\text{state}_{t+1} = \text{step}(\text{system}, \text{state}_t, \text{act})
$$

where:
* $\text{system}$ is the static description of the physical system: each body in the world, its weight and size, and so on
* $\text{state}_t$ is the dynamic state of the system at time $t$: each body's position, rotation, velocity, and angular velocity
* $\text{act}$ is dynamic input to the system in the form of motor actuation

Brax simulations are differentiable: the gradient $\nabla \text{step}$ can be used for efficient trajectory optimization.  But Brax is also well-suited to derivative-free optimization methods such as evolution strategies or reinforcement learning.

Let's review how $\text{system}$, $\text{state}_t$, and $\text{act}$ are used:

In [None]:
import os
import pickle
import mediapy as media
from etils import epath

import mujoco
import mujoco.mjx as mjx

import jax
from jax import numpy as jnp

from brax import envs
from brax.envs.base import Env, PipelineEnv, State
from brax.io import html, mjcf, model

# Report which devices JAX can see before anything is built on top of them.
print("jax.devices(): {}".format(jax.devices()))

# Directory holding arm.xml and the pickled statistics read by the environment.
MJCF_ROOT_PATH = epath.Path("../../mujoco")


def zscore(x, xmean, xstd, default=0):
    """Standardise ``x`` element-wise, using ``default`` where ``xstd`` is not positive.

    Args:
        x: Values to standardise.
        xmean: Per-element mean subtracted from ``x``.
        xstd: Per-element standard deviation used as the divisor.
        default: Value emitted for entries whose ``xstd`` is not strictly
            greater than zero.

    Returns:
        ``(x - xmean) / xstd`` where ``xstd > 0`` and ``default`` elsewhere.
    """
    valid = jnp.greater(xstd, 0)
    # Divide by 1 in the masked-out entries: `jnp.where` evaluates both
    # branches, and a division by zero there turns the gradient into NaN even
    # though the forward value is discarded.
    safe_std = jnp.where(valid, xstd, 1)
    return jnp.where(valid, (x - xmean) / safe_std, default)


def l1_norm(x):
    """Return the sum of the absolute values of all entries of ``x``."""
    magnitudes = jnp.abs(x)
    return magnitudes.sum()


def l2_norm(x):
    """Return the square root of the sum of squared entries of ``x``."""
    sum_of_squares = jnp.sum(jnp.square(x))
    return jnp.sqrt(sum_of_squares)


class SequentialReacher(PipelineEnv):
    """Reaching task in which the target steps through a sampled sequence.

    The model is loaded from ``arm.xml`` under ``MJCF_ROOT_PATH`` and always
    simulated with the ``mjx`` backend. On reset, ``num_targets`` positions are
    drawn without replacement from the candidate targets; the mocap target
    body is moved to the next one every ``target_duration`` units of
    simulation time. The reward is the negative Euclidean distance between the
    hand geom and the current target.

    Args:
        target_duration: Simulation time spent on each target.
        num_targets: Number of targets presented per episode.
        **kwargs: Forwarded to ``PipelineEnv``; ``backend`` is overridden with
            ``"mjx"``.
    """

    def __init__(
        self,
        target_duration=3,
        num_targets=10,
        **kwargs,
    ):
        self.mj_model = mujoco.MjModel.from_xml_path(
            (MJCF_ROOT_PATH / "arm.xml").as_posix()
        )

        sys = mjcf.load_model(self.mj_model)
        kwargs["backend"] = "mjx"

        super().__init__(sys, **kwargs)

        # ID of the geom named "hand", used to read the end-effector position
        self.hand_id = self.mj_model.geom("hand").id
        # Mocap index of the body named "target", used to index `mocap_pos`
        self.target_id = self.mj_model.body_mocapid[
            mujoco.mj_name2id(self.mj_model, mujoco.mjtObj.mjOBJ_BODY, b"target")
        ]

        # Pickled statistics and candidate positions shipped next to the model
        self.sensor_stats = self._load_pickle("sensor_stats.pkl")
        self.hand_position_stats = self._load_pickle("hand_position_stats.pkl")
        self.candidate_targets = self._load_pickle("candidate_targets.pkl")
        self.grid_positions = self._load_pickle("grid_positions.pkl")

        # Convert stats to JAX arrays
        self.target_means = jnp.array(self.hand_position_stats["mean"].values)
        self.target_stds = jnp.array(self.hand_position_stats["std"].values)
        self.sensor_means = jnp.array(self.sensor_stats["mean"].values)
        self.sensor_stds = jnp.array(self.sensor_stats["std"].values)

        # Convert candidate_targets to JAX array
        self.candidate_target_positions = jnp.array(self.candidate_targets.values)

        self.target_duration = target_duration
        self.num_targets = num_targets

    @staticmethod
    def _load_pickle(filename):
        """Unpickle ``filename`` from ``MJCF_ROOT_PATH`` and return the object."""
        # NOTE(review): pickle can execute arbitrary code while loading; these
        # files are assumed to be trusted local artefacts.
        with open(os.path.join(MJCF_ROOT_PATH, filename), "rb") as f:
            return pickle.load(f)

    def reset(self, key: jnp.ndarray) -> State:
        """Resets the environment to an initial state.

        Args:
            key: PRNG key used to sample the target sequence.

        Returns:
            A fresh state with zero reward and ``done``; the sampled sequence
            is stored in ``info["target_positions"]``.
        """
        # Imported locally because a later cell of this file defines another
        # class called `State`, which would otherwise be picked up here once
        # that cell has run.
        from brax.envs.base import State as BraxState

        qpos = jnp.zeros(self.sys.nq)
        qvel = jnp.zeros(self.sys.nv)
        data = self.pipeline_init(qpos, qvel)

        target_positions = self._sample_target_positions(key)
        data = self._update_target(data, target_positions)

        obs = self._get_obs(data)
        reward, done = jnp.zeros(2)

        return BraxState(
            data, obs, reward, done, info={"target_positions": target_positions}
        )

    def step(self, state: State, action: jnp.ndarray) -> State:
        """Runs one timestep of the environment's dynamics.

        Args:
            state: Current environment state.
            action: Control vector passed to the physics pipeline.

        Returns:
            The state with updated pipeline state, observation, reward and
            ``done`` flag.
        """
        data = self.pipeline_step(state.pipeline_state, action)
        data = self._update_target(data, state.info["target_positions"])

        hand_position = self._get_hand_pos(data)
        target_position = self._get_target_pos(data)
        euclidean_distance = l2_norm(target_position - hand_position)

        obs = self._get_obs(data)
        reward = -euclidean_distance

        # The episode ends once every target has been shown for its full duration
        done = jnp.where(data.time > self.target_duration * self.num_targets, 1.0, 0.0)

        return state.replace(pipeline_state=data, obs=obs, reward=reward, done=done)

    def _get_obs(self, data: mjx.Data) -> jnp.ndarray:
        """Build the observation: z-scored target position then z-scored sensors.

        Returns:
            The concatenated vector reshaped to a single row, ``(1, -1)``.
        """
        norm_target_position = zscore(
            self._get_target_pos(data),
            self.target_means,
            self.target_stds,
        )
        norm_sensor_data = zscore(
            data.sensordata,
            self.sensor_means,
            self.sensor_stds,
        )
        obs = jnp.concatenate(
            [
                norm_target_position,
                norm_sensor_data,
            ]
        )
        return jnp.reshape(obs, (1, -1))

    def _sample_target_positions(self, key: jnp.ndarray):
        """Sample target positions (w/o replacement) from the candidate targets"""
        sample_idcs = jax.random.choice(
            key,
            self.candidate_target_positions.shape[0],
            shape=(self.num_targets,),
            replace=False,
        )
        return self.candidate_target_positions[sample_idcs]

    def _update_target(self, data: mjx.Data, target_positions) -> mjx.Data:
        """Move the mocap target to the position scheduled for ``data.time``.

        Args:
            data: Current pipeline data.
            target_positions: Sequence of target positions for this episode.

        Returns:
            ``data`` with ``mocap_pos`` updated for the target body.
        """
        target_idx = jnp.floor_divide(data.time, self.target_duration).astype(jnp.int32)
        # At (and after) the final boundary the raw index equals the number of
        # targets; hold the last target explicitly instead of relying on
        # implicit out-of-bounds index handling.
        target_idx = jnp.clip(target_idx, 0, target_positions.shape[0] - 1)
        mocap_position = data.mocap_pos.at[self.target_id].set(
            target_positions[target_idx]
        )
        return data.replace(mocap_pos=mocap_position)

    def _get_hand_pos(self, data: mjx.Data) -> jnp.ndarray:
        """Get the position of the end effector (hand)"""
        return data.geom_xpos[self.hand_id]

    def _get_target_pos(self, data: mjx.Data) -> jnp.ndarray:
        """Get the position of the target"""
        return data.mocap_pos[self.target_id]


# Register the class under a string name so it can be looked up by name
# (e.g. via `envs.get_environment`).
envs.register_environment("sequential_reacher", SequentialReacher)

In [None]:
# instantiate the environment
env_name = "sequential_reacher"
env = envs.get_environment(env_name, target_duration=3, num_targets=10)

# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# initialize the state
seed = 42
state = jit_reset(jax.random.PRNGKey(seed))
rollout = [state.pipeline_state]
# Number of simulation steps needed to show every target for its full
# duration; converted to a plain int because it drives a host-side loop.
num_frames = int(jnp.round(env.target_duration * env.num_targets / env.dt))

print(f"Number of frames: {num_frames}")
print(f"Number of targets: {env.num_targets}")
print(f"Target duration: {env.target_duration}")
print(f"Time step: {env.dt}")

# grab a trajectory under a constant control signal (built once, outside the loop)
ctrl = jnp.array([0.0, 0.01, 0.0, 0.1])
for _ in range(num_frames):
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

media.show_video(env.render(rollout), fps=1.0 / env.dt)

Number of frames: 15000
Number of targets: 10
Target duration: 3
Time step: 0.0020000000949949026


In [None]:
import functools
import matplotlib.pyplot as plt

from datetime import datetime
from brax.training.agents.ppo import train as ppo
from brax.training.agents.es import train as es

# PPO trainer with its hyper-parameters pre-bound; the environment and the
# progress callback are supplied when it is called.
train_fn = functools.partial(
    ppo.train,
    **{
        "num_timesteps": 20_000_000,
        "num_evals": 10,
        "reward_scaling": 0.1,
        "episode_length": 1000,
        "normalize_observations": True,
        "action_repeat": 1,
        "unroll_length": 10,
        "num_minibatches": 24,
        "num_updates_per_batch": 8,
        "discounting": 0.97,
        "learning_rate": 3e-4,
        "entropy_cost": 1e-3,
        "num_envs": 3072,
        "batch_size": 512,
        "seed": 0,
    },
)

# train_fn = functools.partial(
#     es.train,
#     population_size=128,
#     num_eval_envs=128,
#     num_timesteps=15_000,
#     episode_length=15_000,
#     max_devices_per_host=8,
#     seed=1,
# )


# Buffers filled by the progress callback: step counts, mean rewards, reward stds.
x_data, y_data, ydataerr = [], [], []
# Timestamps; the first entry is taken here, before training is started.
times = [datetime.now()]

# Fixed y-axis limits for the reward plot.
min_y = -1000
max_y = 0


def progress(num_steps, metrics):
    """Record one evaluation point and redraw the reward curve.

    Args:
        num_steps: Number of environment steps completed so far.
        metrics: Mapping that provides ``"eval/episode_reward"`` and
            ``"eval/episode_reward_std"``.
    """
    times.append(datetime.now())
    x_data.append(num_steps)
    for series, metric_key in (
        (y_data, "eval/episode_reward"),
        (ydataerr, "eval/episode_reward_std"),
    ):
        series.append(metrics[metric_key])

    # Leave a 25% margin beyond the configured number of timesteps
    x_upper = train_fn.keywords["num_timesteps"] * 1.25
    plt.xlim([-1, x_upper])
    plt.ylim([min_y, max_y])

    plt.xlabel("# environment steps")
    plt.ylabel("reward per episode")
    latest_reward = y_data[-1]
    plt.title(f"y={latest_reward:.3f}")

    plt.errorbar(x_data, y_data, yerr=ydataerr)
    plt.show()


make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

# times[0] is taken before training starts; times[1] is the first progress callback
print(f"time to jit: {times[1] - times[0]}")
print(f"time to train: {times[-1] - times[1]}")

# Round-trip the parameters through disk so the rollout below uses the saved policy
model_path = 'mjx_brax_policy'
model.save_params(model_path, params)

params = model.load_params(model_path)

inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

eval_env = envs.get_environment(env_name)

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

# initialize the state
key = jax.random.PRNGKey(0)
state = jit_reset(key)
rollout = [state.pipeline_state]

# grab a trajectory
n_steps = 500
render_every = 1

for _ in range(n_steps):
    act_rng, key = jax.random.split(key)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

    if state.done:
        break

# Render with the same environment that produced the rollout
media.show_video(
    eval_env.render(rollout[::render_every]),
    fps=1.0 / eval_env.dt / render_every,
)

In [None]:
from typing import Sequence, Tuple
from flax import linen as nn
from flax.struct import dataclass
from typing import Callable, Any, Optional

from evojax.policy.base import PolicyState
from evojax.policy.base import PolicyNetwork
from evojax.task.base import TaskState
from evojax.util import get_params_format_fn


@dataclass
class State(PolicyState):
    # Policy state that additionally carries the recurrent hidden state.
    # NOTE(review): this rebinds the name `State`, which earlier in this file is
    # imported from `brax.envs.base` — confirm the shadowing is intended.

    # Hidden state passed into, and returned from, each policy call.
    hx: jnp.ndarray


class RNN(nn.Module):
    """A single recurrent cell followed by a dense read-out layer.

    Attributes:
        input_dim: Size of the input vector (kept for the caller's interface;
            the layers infer it from the input).
        hidden_dim: Number of features in the recurrent hidden state.
        output_dim: Size of the output vector.
        hidden_act_fn: Activation applied inside the recurrent cell.
        output_act_fn: Activation applied to the read-out layer.
    """

    input_dim: int
    hidden_dim: int
    output_dim: int
    hidden_act_fn: Callable = nn.tanh
    output_act_fn: Callable = nn.sigmoid

    @nn.compact
    def __call__(self, x, hx):
        """Advance the cell by one step.

        Args:
            x: Input for this step.
            hx: Hidden state from the previous step.

        Returns:
            A tuple ``(y, new_hx)`` of the output and the updated hidden state.
        """
        # `SimpleCell` needs its feature count and returns `(new_carry, output)`
        cell = nn.SimpleCell(
            features=self.hidden_dim, activation_fn=self.hidden_act_fn
        )
        new_hx, cell_out = cell(hx, x)
        # Project the hidden features to the requested output size
        y = self.output_act_fn(nn.Dense(self.output_dim)(cell_out))
        return y, new_hx  # In addition to the output, we return hx too.


class RNNPolicy(PolicyNetwork):
    """The policy wraps the model and does some data formatting works.

    Args:
        input_dim: Size of one observation vector.
        hidden_dim: Size of the recurrent hidden state.
        output_dim: Size of one action vector.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        # Stored first: it is needed below to shape the dummy hidden state and
        # again in `reset`.
        self.hidden_dim = hidden_dim

        model = RNN(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            hidden_act_fn=nn.tanh,
            output_act_fn=nn.sigmoid,
        )
        params = model.init(
            jax.random.PRNGKey(0),
            jnp.ones([1, input_dim]),
            jnp.zeros([1, hidden_dim]),
        )
        self.num_params, format_params_fn = get_params_format_fn(params)
        print("RNNPolicy.num_params = {}".format(self.num_params))

        self._format_params_fn = jax.vmap(format_params_fn)
        self._forward_fn = jax.vmap(model.apply)

    def reset(self, states: TaskState) -> PolicyState:
        """Create the initial policy state: PRNG keys and a zero hidden state.

        Args:
            states: Batched task states; the leading axis of ``obs`` is taken
                as the batch size.
        """
        b_size = states.obs.shape[0]
        keys = jax.random.split(jax.random.PRNGKey(0), b_size)
        # The model's simple recurrent cell carries a single hidden array
        # (not an LSTM-style pair).
        hx = jnp.zeros([b_size, self.hidden_dim])
        return State(keys=keys, hx=hx)

    def get_actions(
        self, t_states: TaskState, params: jnp.ndarray, p_states: State
    ) -> Tuple[jnp.ndarray, State]:
        """Compute actions for a batch and return them with the new policy state.

        Args:
            t_states: Batched task states providing ``obs``.
            params: Batched flat parameter vectors.
            p_states: Policy state holding the PRNG keys and hidden state.
        """
        # Calling `self._format_params_fn` unflattens the parameters so that
        # our Flax model can take that as an input.
        params = self._format_params_fn(params)

        # Now we return the actions and the updated `p_states`.
        actions, hx = self._forward_fn(params, t_states.obs, p_states.hx)
        return actions, State(keys=p_states.keys, hx=hx)

In [None]:
# @title Import Libraries
import time
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML

from brax import envs
from brax.io import html

import jax
import jax.numpy as jnp
from jax import random

from evojax import SimManager
from evojax import ObsNormalizer
from evojax.algo import PGPE
from evojax.algo import CMA_ES
from evojax.policy import MLPPolicy
from evojax.task.brax_task import BraxTask

import os

# @title Preview a Brax environment { run: "auto" }
# @markdown Select the environment to train:
env_name = "sequential_reacher"  # @param ['ant', 'humanoid', 'halfcheetah', 'fetch']
env = envs.create(env_name=env_name)
# env = env_fn()
state = env.reset(rng=jax.random.PRNGKey(seed=0))

# Display explicitly: a bare expression is only rendered when it is the last
# statement of a cell, and more code follows this one.
from IPython.display import display

display(HTML(html.render(env.sys, [state.pipeline_state])))

# @title Set hyper-parameters
# @markdown PLEASE NOTE: `pop_size` and `num_tests` should be multiples of `jax.local_device_count()`.

n_devices = jax.local_device_count()

pop_size = 1024  # @param
num_tests = 128  # @param
# Per the note above, both counts must be multiples of the local device count.
assert all(count % n_devices == 0 for count in (pop_size, num_tests))

# Solver settings
max_iters = 300  # @param
center_lr = 0.01  # @param
init_std = 0.04  # @param
std_lr = 0.07  # @param

seed = 42  # @param

# @title Training
# Separate task instances for training rollouts and for test evaluation
train_task = BraxTask(env_name=env_name, test=False)
test_task = BraxTask(env_name=env_name, test=True)

# Feed-forward policy sized from the task's observation and action shapes
obs_dim = train_task.obs_shape[0]
act_dim = train_task.act_shape[0]
policy = MLPPolicy(
    input_dim=obs_dim,
    output_dim=act_dim,
    hidden_dims=[32] * 4,
)
print(f"#params={policy.num_params}")

solver = PGPE(
    pop_size=pop_size,
    param_size=policy.num_params,
    optimizer="adam",
    center_learning_rate=center_lr,
    stdev_learning_rate=std_lr,
    init_stdev=init_std,
    seed=seed,
)

# solver = CMA_ES(
#     pop_size=pop_size,
#     param_size=policy.num_params,
#     init_stdev=1.0,
#     seed=seed,
# )

obs_normalizer = ObsNormalizer(obs_shape=train_task.obs_shape)

# Simulation manager wiring the policy, both tasks and the normalizer together
sim_mgr = SimManager(
    n_repeats=1,
    test_n_repeats=1,
    pop_size=pop_size,
    n_evaluations=num_tests,
    policy_net=policy,
    train_vec_task=train_task,
    valid_vec_task=test_task,
    seed=seed,
    obs_normalizer=obs_normalizer,
)

def _evaluate_best(iteration):
    """Evaluate the solver's current best parameters on the test task and print a summary.

    Args:
        iteration: Iteration number shown in the printed line.

    Returns:
        A tuple ``(best_params, scores, score_avg, score_std)``.
    """
    best = solver.best_params
    test_scores = np.array(sim_mgr.eval_params(params=best, test=True)[0])
    avg = np.mean(test_scores)
    std = np.std(test_scores)
    print(
        "Iter={0}, #tests={1}, score.avg={2:.2f}, score.std={3:.2f}".format(
            iteration, num_tests, avg, std
        )
    )
    return best, test_scores, avg, std


print("Start training Brax ({}) for {} iterations.".format(env_name, max_iters))
start_time = time.perf_counter()
for train_iters in range(max_iters):

    # Training
    params = solver.ask()
    scores, _ = sim_mgr.eval_params(params=params, test=False)
    solver.tell(fitness=scores)

    # Test periodically.
    if train_iters > 0 and train_iters % 10 == 0:
        best_params, scores, score_avg, score_std = _evaluate_best(train_iters)

# Final test. The last loop index is computed from `max_iters` so that this
# also works when the loop body never ran.
best_params, scores, score_avg, score_std = _evaluate_best(max_iters - 1)
print("time cost: {}s".format(time.perf_counter() - start_time))

In [None]:
# @title Visualize the trained policy
env = envs.create(env_name="sequential_reacher", num_targets=10, target_duration=3)

task_reset_fn = jax.jit(env.reset)
policy_reset_fn = jax.jit(policy.reset)
step_fn = jax.jit(env.step)
act_fn = jax.jit(policy.get_actions)
obs_norm_fn = jax.jit(obs_normalizer.normalize_obs)

best_params = solver.best_params
obs_params = sim_mgr.obs_params

total_reward = 0
rng = jax.random.PRNGKey(seed=42)
task_state = task_reset_fn(rng=rng)
policy_state = policy_reset_fn(task_state)

rollout = [task_state.pipeline_state]

while not task_state.done:
    # Training ran with the observation normalizer attached, so the policy
    # must see normalized observations here as well. The normalized copy is
    # used only as policy input; the environment is stepped from the raw state.
    policy_input = task_state.replace(
        obs=obs_norm_fn(task_state.obs.reshape(1, -1), obs_params)
    )
    act, policy_state = act_fn(policy_input, best_params[None, :], policy_state)
    task_state = step_fn(task_state, act[0])
    total_reward = total_reward + task_state.reward
    rollout.append(task_state.pipeline_state)

print("rollout reward = {}".format(total_reward))

media.show_video(env.render(rollout), fps=1.0 / env.dt)