In [140]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [141]:
import functools
import os
import time

from IPython.display import HTML, clear_output

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import brax
from brax import envs
from brax import jumpy as jp
from brax.training import ppo, sac
from brax.io import html
from brax.io import model

In [154]:
# Name of the environment to build; passed to envs.create_fn below.
env_name = "walker3d"
# Zero-argument factory for the environment. It is called directly below and
# again after training, and is handed to the trainer as `environment_fn`.
env_fn = envs.create_fn(env_name=env_name)
# Environment instance used by the interactive test cells that follow.
env = env_fn()

## Test Environment

In [132]:
# Smoke test: reset the environment, apply a single all-zeros action, and
# render the resulting pose.
act = jnp.zeros(env.action_size)

state = env.reset(rng=None)
state = env.step(state, act)

# A one-frame "trajectory" is enough for the HTML viewer.
HTML(html.render(env.sys, [state.qp]))

In [153]:
# Roll out a random-action policy until the episode terminates, then render it.
#
# The environment functions are used un-jitted here; swap in the jax.jit
# wrappers below to compile them instead.
# jit_env_reset = jax.jit(env.reset)
# jit_env_step = jax.jit(env.step)
jit_env_reset = env.reset
jit_env_step = env.step

rollout = []
rng = jax.random.PRNGKey(seed=0)
state = jit_env_reset(rng=rng)

while not state.done:
    rollout.append(state)
    # Thread the PRNG state through the loop: split off a fresh key for this
    # step and keep the other half for the next one. Splitting the same `rng`
    # every iteration would return the same key, and therefore the same
    # "random" action, at every step.
    act_rng, rng = jax.random.split(rng)
    # jax.random.uniform defaults to samples in [0, 1).
    # NOTE(review): actions are therefore never negative -- confirm intended.
    act = jax.random.uniform(act_rng, (env.action_size,))
    state = jit_env_step(state, act)

HTML(html.render(env.sys, [s.qp for s in rollout]))

## Train Controller

In [158]:
# PPO trainer with all hyperparameters bound up front, so the call site only
# needs to supply the environment factory and the progress callback.
train_fn = functools.partial(
    ppo.train,
    num_timesteps=50_000_000,
    log_frequency=20,
    reward_scaling=0.1,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=10,
    num_minibatches=32,
    num_update_epochs=8,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-3,
    num_envs=2048,
    batch_size=1024,
    seed=1,
)

# Wall-clock timestamps: index 0 is when this cell started running; every
# call to `progress` appends one more entry.
times = [time.time()]


def progress(num_steps, metrics):
    """Record a timestamp and print a one-line training status report.

    Args:
        num_steps: Number of environment steps completed so far.
        metrics: Mapping that must contain 'eval/episode_reward'.

    The FPS figure is measured from the first progress report (times[1]),
    not from cell start, and the elapsed time is clamped to at least one
    second so the first report does not divide by zero.
    """
    now = time.time()
    times.append(now)

    elapsed_time = now - times[1]
    if elapsed_time < 1:
        elapsed_time = 1

    fps = num_steps / elapsed_time
    reward = metrics['eval/episode_reward']
    print(f"Steps: {num_steps:d} | FPS: {fps:.0f} | Reward: {reward:.0f}")

# Run training; `progress` is supplied as the trainer's progress callback.
inference_fn, params, _ = train_fn(progress_fn=progress, environment_fn=env_fn)

# times[0] is cell start, times[1] the first progress report, times[-1] the
# last one.
jit_seconds = times[1] - times[0]
train_seconds = times[-1] - times[1]
print(f'time to jit: {jit_seconds}')
print(f'time to train: {train_seconds}')

Steps: 0 | FPS: 0 | Reward: 364
Steps: 2293760 | FPS: 74789 | Reward: 590
Steps: 4587520 | FPS: 94633 | Reward: 775
Steps: 6881280 | FPS: 103843 | Reward: 865
Steps: 9175040 | FPS: 109147 | Reward: 1125
Steps: 11468800 | FPS: 112568 | Reward: 1307
Steps: 13762560 | FPS: 114970 | Reward: 1496
Steps: 16056320 | FPS: 116757 | Reward: 1660
Steps: 18350080 | FPS: 118081 | Reward: 1715
Steps: 20643840 | FPS: 119142 | Reward: 1711
Steps: 22937600 | FPS: 120029 | Reward: 2211
Steps: 25231360 | FPS: 120733 | Reward: 2393
Steps: 27525120 | FPS: 121338 | Reward: 2652
Steps: 29818880 | FPS: 121863 | Reward: 2882
Steps: 32112640 | FPS: 122296 | Reward: 3180
Steps: 34406400 | FPS: 122676 | Reward: 3314
Steps: 36700160 | FPS: 123030 | Reward: 3557
Steps: 38993920 | FPS: 123319 | Reward: 3848
Steps: 41287680 | FPS: 123565 | Reward: 3983
Steps: 43581440 | FPS: 123782 | Reward: 4251
Steps: 45875200 | FPS: 123977 | Reward: 4483
time to jit: 24.854645013809204
time to train: 370.03103280067444


In [159]:
# Round-trip the trained parameters through disk and rebuild the inference
# function from scratch, independent of the objects returned by training.
params_path = '/tmp/params'

env = env_fn()
model.save_params(params_path, params)

# The trailing True is passed positionally.
# NOTE(review): presumably the observation-normalization flag, matching
# normalize_observations=True used for training -- confirm against the
# ppo.make_params_and_inference_fn signature.
empty_params, inference_fn = ppo.make_params_and_inference_fn(
    env.observation_size,
    env.action_size,
    True,
)

# `empty_params` is passed as the second argument to load_params.
params = model.load_params(params_path, empty_params)

## Render Trained Controller

In [160]:
# Roll out the trained policy until the episode terminates, then render it.
jit_env_reset, jit_env_step, jit_inference_fn = (
    jax.jit(fn) for fn in (env.reset, env.step, inference_fn)
)

rng = jax.random.PRNGKey(seed=0)
state = jit_env_reset(rng=rng)
rollout = []

while not state.done:
    rollout.append(state)
    # Split off a fresh key for this step's action and carry the rest forward.
    act_rng, rng = jax.random.split(rng)
    act = jit_inference_fn(params, state.obs, act_rng)
    state = jit_env_step(state, act)

# The viewer takes one `qp` entry per recorded state.
trajectory = [s.qp for s in rollout]
HTML(html.render(env.sys, trajectory))