<a href="https://colab.research.google.com/github/google/evojax/blob/main/examples/notebooks/BraxTasks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# @title Install Packages

from IPython.display import clear_output

!pip install evojax
!pip install git+https://github.com/google/brax.git@main

clear_output()

In [2]:
# @title Import Libraries

import time
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML

from brax import envs
from brax import jumpy as jp
from brax.io import html

import jax
import jax.numpy as jnp
from jax import random

from evojax import SimManager
from evojax import ObsNormalizer
from evojax.algo import PGPE
from evojax.policy import MLPPolicy
from evojax.task.brax_task import BraxTask

import os
# When the COLAB_TPU_ADDR environment variable is present (presumably a Colab
# TPU runtime -- confirm), run JAX's Colab TPU setup helper before listing devices.
has_colab_tpu_addr = 'COLAB_TPU_ADDR' in os.environ
if has_colab_tpu_addr:
    from jax.tools import colab_tpu
    colab_tpu.setup_tpu()

# The bare last expression is rendered as the cell output.
print('jax.devices():')
jax.devices()

jax.devices():


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
#@title Preview a Brax environment { run: "auto" }
#@markdown Select the environment to train:

env_name = "ant"  # @param ['ant', 'humanoid', 'halfcheetah', 'fetch']

# Build the environment through its factory function, then reset it with a
# key derived from a fixed seed (0).
env_fn = envs.create_fn(env_name=env_name)
env = env_fn()
preview_key = jp.random_prngkey(seed=0)
state = env.reset(rng=preview_key)

# Render the single state obtained from the reset as an HTML view.
HTML(html.render(env.sys, [state.qp]))

# Learning to solve the control task

In [4]:
# @title Set hyper-parameters
# @markdown PLEASE NOTE: `pop_size` and `num_tests` should be multiples of `jax.local_device_count()`.

# Number of devices visible to this process; both sizes below must divide
# evenly by it (checked by the assertions).
n_devices = jax.local_device_count()

pop_size = 1024  # @param
num_tests = 128  # @param
assert pop_size % n_devices == 0, (
    'pop_size ({}) must be a multiple of jax.local_device_count() ({})'.format(
        pop_size, n_devices))
assert num_tests % n_devices == 0, (
    'num_tests ({}) must be a multiple of jax.local_device_count() ({})'.format(
        num_tests, n_devices))

# Settings consumed by the training cell: iteration budget and the PGPE
# learning rates / initial standard deviation.
max_iters = 300  # @param
center_lr = 0.01  # @param
init_std = 0.04  # @param
std_lr = 0.07  # @param

# Seed passed to both the solver and the simulation manager.
seed = 42  # @param

In [5]:
# @title Training

# Separate task instances for training and for testing (same environment,
# differing only in the `test` flag).
train_task = BraxTask(env_name=env_name, test=False)
test_task = BraxTask(env_name=env_name, test=True)

# MLP policy sized from the task's observation and action shapes.
policy = MLPPolicy(
    input_dim=train_task.obs_shape[0],
    output_dim=train_task.act_shape[0],
    hidden_dims=[32, 32, 32, 32],
)
print('#params={}'.format(policy.num_params))

solver = PGPE(
    pop_size=pop_size,
    param_size=policy.num_params,
    optimizer='adam',
    center_learning_rate=center_lr,
    stdev_learning_rate=std_lr,
    init_stdev=init_std,
    seed=seed,
)
obs_normalizer = ObsNormalizer(obs_shape=train_task.obs_shape)
sim_mgr = SimManager(
    n_repeats=1,
    test_n_repeats=1,
    pop_size=pop_size,
    n_evaluations=num_tests,
    policy_net=policy,
    train_vec_task=train_task,
    valid_vec_task=test_task,
    seed=seed,
    obs_normalizer=obs_normalizer,
)

# Run a test evaluation every `test_interval` training iterations.
test_interval = 10


def report_test_scores(iteration):
    """Evaluate the solver's current best parameters in test mode and print a summary.

    Args:
        iteration: Iteration index shown in the printed line.

    Returns:
        A tuple `(best_params, scores)`: the parameters that were evaluated and
        the test scores as a NumPy array.
    """
    best = solver.best_params
    test_scores = np.array(sim_mgr.eval_params(params=best, test=True))
    print('Iter={0}, #tests={1}, score.avg={2:.2f}, score.std={3:.2f}'.format(
        iteration, num_tests, np.mean(test_scores), np.std(test_scores)))
    return best, test_scores


print('Start training Brax ({}) for {} iterations.'.format(env_name, max_iters))
start_time = time.perf_counter()
for train_iters in range(max_iters):

    # Training: sample a population, score it, and feed the scores back.
    params = solver.ask()
    scores = sim_mgr.eval_params(params=params, test=False)
    solver.tell(fitness=scores)

    # Test periodically (iteration 0 is skipped).
    if train_iters > 0 and train_iters % test_interval == 0:
        best_params, scores = report_test_scores(train_iters)

# Final test. The label is the index of the last training iteration; it is
# derived from `max_iters` rather than the loop variable so that the cell does
# not fail with a NameError when `max_iters` is 0.
best_params, scores = report_test_scores(max(max_iters - 1, 0))
# The reported time includes the periodic and final test evaluations.
print('time cost: {}s'.format(time.perf_counter() - start_time))

#params=6248
Start training Brax (ant) for 300 iterations.


  warn(f"The jitted function {name} includes a pmap. Using "
  warn(f"The jitted function {name} includes a pmap. Using "


Iter=10, #tests=128, score.avg=998.41, score.std=0.75
Iter=20, #tests=128, score.avg=1000.74, score.std=1.00
Iter=30, #tests=128, score.avg=1183.64, score.std=103.71
Iter=40, #tests=128, score.avg=1547.04, score.std=74.12
Iter=50, #tests=128, score.avg=1849.80, score.std=124.95
Iter=60, #tests=128, score.avg=2476.96, score.std=236.69
Iter=70, #tests=128, score.avg=3206.10, score.std=296.84
Iter=80, #tests=128, score.avg=3581.43, score.std=55.88
Iter=90, #tests=128, score.avg=3735.85, score.std=48.15
Iter=100, #tests=128, score.avg=3873.76, score.std=56.37
Iter=110, #tests=128, score.avg=4100.58, score.std=32.77
Iter=120, #tests=128, score.avg=4282.35, score.std=27.74
Iter=130, #tests=128, score.avg=4390.00, score.std=26.75
Iter=140, #tests=128, score.avg=4547.78, score.std=30.85
Iter=150, #tests=128, score.avg=4645.33, score.std=33.28
Iter=160, #tests=128, score.avg=4763.96, score.std=34.70
Iter=170, #tests=128, score.avg=4807.52, score.std=29.54
Iter=180, #tests=128, score.avg=4993.32

In [6]:
# @title Visualize the trained policy

# Inputs taken from training: the solver's best parameters and the
# observation-normalizer parameters held by the simulation manager.
best_params = solver.best_params
obs_params = sim_mgr.obs_params

# JIT-compile every callable used in the rollout once, outside the loop.
obs_norm_fn = jax.jit(obs_normalizer.normalize_obs)
act_fn = jax.jit(policy.get_actions)
step_fn = jax.jit(env.step)
task_reset_fn = jax.jit(env.reset)
policy_reset_fn = jax.jit(policy.reset)

# Initial task and policy state for the rollout.
rng = jax.random.PRNGKey(seed=42)
task_state = task_reset_fn(rng=rng)
policy_state = policy_reset_fn(task_state)

rollout = []
total_reward = 0

while not task_state.done:
    # Record the state before its observation is replaced by the normalized one.
    rollout.append(task_state)
    # `[None, :]` adds a leading axis of size 1 for the normalizer and policy;
    # `act[0]` below strips it again before stepping the environment.
    normalized_obs = obs_norm_fn(task_state.obs[None, :], obs_params)
    task_state = task_state.replace(obs=normalized_obs)
    act, policy_state = act_fn(task_state, best_params[None, :], policy_state)
    task_state = step_fn(task_state, act[0])
    total_reward += task_state.reward

print('rollout reward = {}'.format(total_reward))
HTML(html.render(env.sys, [s.qp for s in rollout]))

rollout reward = 5616.42041015625
