In [1]:
import functools
import os
import time

from IPython.display import HTML, clear_output

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import brax
from brax import envs
from brax import jumpy as jp
from brax.training import ppo, sac
from brax.io import html
from brax.io import model

In [2]:
# Environment factory for the task being trained on.
env_name = "humanoid"
env_fn = envs.create_fn(env_name=env_name)

# Keyword arguments pre-bound onto ppo.train. Everything except the
# environment factory and the progress callback is fixed here, so the
# training cell only has to supply those two.
ppo_hparams = dict(
    num_timesteps=50000000,
    log_frequency=20,
    reward_scaling=0.1,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=10,
    num_minibatches=32,
    num_update_epochs=8,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-3,
    num_envs=2048,
    batch_size=1024,
    seed=1,
)
train_fn = functools.partial(ppo.train, **ppo_hparams)

# Wall-clock timestamps. Slot 0 is taken here, before training is launched;
# progress() pushes one more entry every time it is invoked.
times = [time.time()]


def progress(num_steps, metrics):
    """Record a timestamp and print throughput and evaluation reward.

    Args:
        num_steps: step count supplied by the caller; divided by the elapsed
            seconds to produce the printed FPS figure.
        metrics: mapping that must contain the key 'eval/episode_reward'.
    """
    now = time.time()
    times.append(now)
    # Elapsed time is measured from the first callback timestamp (times[1]),
    # not from times[0], and is floored at one second so the division below
    # can never be by zero (the very first call therefore divides by 1).
    seconds_running = max(now - times[1], 1)
    fps = num_steps / seconds_running
    print(f"FPS: {fps:.0f} | Reward: {metrics['eval/episode_reward']:.0f}")

# Run training. train_fn returns three values; the first two are kept as the
# inference function and the trained parameters, the third is discarded.
# `progress` is passed as the callback, so it appends to `times` during the run.
inference_fn, params, _ = train_fn(environment_fn=env_fn, progress_fn=progress)

# times[0] was recorded before train_fn was called and times[1] on the first
# progress callback; that gap is what is reported as "time to jit"
# (presumably dominated by compilation -- not verifiable from this cell).
print(f'time to jit: {times[1] - times[0]}')
# Span between the first and the last progress callback.
print(f'time to train: {times[-1] - times[1]}')

FPS: 0 | Reward: 285
FPS: 78543 | Reward: 478
FPS: 99977 | Reward: 663
FPS: 109194 | Reward: 870
FPS: 115328 | Reward: 1144
FPS: 118876 | Reward: 1521
FPS: 121782 | Reward: 1886
FPS: 123850 | Reward: 2326
FPS: 125609 | Reward: 2715
FPS: 126553 | Reward: 2806
FPS: 127560 | Reward: 3337
FPS: 128295 | Reward: 3506
FPS: 129007 | Reward: 3788
FPS: 129688 | Reward: 4560
FPS: 130317 | Reward: 7792
FPS: 130879 | Reward: 7740
FPS: 131288 | Reward: 7026
FPS: 131754 | Reward: 7697
FPS: 132077 | Reward: 7393
FPS: 132282 | Reward: 7431
FPS: 132497 | Reward: 6638
time to jit: 32.39626383781433
time to train: 346.23615741729736


In [3]:
# Instantiate the environment; its observation/action sizes are needed to
# rebuild the policy below.
env = env_fn()

# Round-trip the trained parameters through disk: write them out, build an
# empty parameter structure plus a matching inference function, then read the
# saved values back into that structure.
params_path = '/tmp/params'
model.save_params(params_path, params)

# The third positional argument (True) is passed through unchanged --
# presumably the observation-normalization flag used at training time; confirm
# against ppo.make_params_and_inference_fn.
empty_params, inference_fn = ppo.make_params_and_inference_fn(
    env.observation_size, env.action_size, True)

params = model.load_params(params_path, empty_params)

In [7]:
# Wrap the environment's reset/step functions and the policy with jax.jit.
jit_env_reset, jit_env_step, jit_inference_fn = map(
    jax.jit, (env.reset, env.step, inference_fn))

# Start from a fixed PRNG key so the rollout uses the same seed on every run.
rng = jax.random.PRNGKey(seed=0)
state = jit_env_reset(rng=rng)

# Step the policy until the environment flags `done`. Each state is recorded
# *before* it is stepped, so the final (done) state is not part of `rollout`.
rollout = []
while True:
    if state.done:
        break
    rollout.append(state)
    # Split off a fresh key for this action; keep the other half for later steps.
    act_rng, rng = jax.random.split(rng)
    act = jit_inference_fn(params, state.obs, act_rng)
    state = jit_env_step(state, act)

# Render the `qp` field of every recorded state; as the cell's last expression
# the HTML object becomes the cell output.
HTML(html.render(env.sys, [snapshot.qp for snapshot in rollout]))

In [10]:
# Display the recorded rollout. Relies on `env` and `rollout` already existing
# in the kernel from the earlier cells.
HTML(html.render(env.sys, [snapshot.qp for snapshot in rollout]))