In [1]:
from brax import envs
from prefacc.training.agents.prefppo import train as prefppo

In [70]:
# Look up the 'hopper' environment on the 'positional' backend; this is the env handed to training.
env = envs.get_environment(env_name='hopper', backend='positional')

In [126]:
# All hyperparameters live in one mapping and are forwarded unchanged as
# keyword arguments to prefppo.train.
train_kwargs = dict(
    num_timesteps=100_000_000,
    num_evals=20,
    reward_scaling=30,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=5,
    num_minibatches=32,
    num_updates_per_batch=4,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    num_envs=4096,
    batch_size=2048,
    num_prefill_iterations=10,
    seed=123,
    num_prefs=50000,
)
# train returns a policy factory, the learned parameters, and a metrics dict.
make_inference_fn, params, metrics = prefppo.train(env, **train_kwargs)

print("Training complete")

Training complete


In [116]:
# Print one metric per line, sorted by name, instead of dumping the whole dict
# as a single unreadable line.
for metric_name, metric_value in sorted(metrics.items()):
  print(f'{metric_name}: {metric_value}')

{'eval/walltime': 19.561452865600586, 'training/sps': np.float64(658266.4356909072), 'training/walltime': 172.35624265670776, 'training/entropy_loss': Array(0.05913162, dtype=float32), 'training/policy_loss': Array(-0.00039917, dtype=float32), 'training/total_loss': Array(23.174294, dtype=float32), 'training/v_loss': Array(23.11556, dtype=float32), 'eval/episode_reward': Array(474.9343, dtype=float32), 'eval/episode_reward_ctrl': Array(-0.39988792, dtype=float32), 'eval/episode_reward_forward': Array(179.74045, dtype=float32), 'eval/episode_reward_healthy': Array(295.59375, dtype=float32), 'eval/episode_x_position': Array(232.03638, dtype=float32), 'eval/episode_x_velocity': Array(179.74045, dtype=float32), 'eval/episode_reward_std': Array(36.36413, dtype=float32), 'eval/episode_reward_ctrl_std': Array(0.0288454, dtype=float32), 'eval/episode_reward_forward_std': Array(13.214758, dtype=float32), 'eval/episode_reward_healthy_std': Array(23.309212, dtype=float32), 'eval/episode_x_positio

In [117]:
from brax.io import model
from brax.io import json
from brax.io import html

In [118]:
# Persist the trained parameters to disk so they can be reloaded below.
model.save_params(path='/tmp/params', params=params)

In [119]:
# Read the parameters back from disk (rebinding `params`) and build the policy
# function from them with the factory returned by training.
params = model.load_params(path='/tmp/params')
inference_fn = make_inference_fn(params=params)

In [120]:
# Rebind `env` to a hopper instance built with envs.create for the rollout.
# NOTE(review): envs.create presumably applies Brax's default episode/auto-reset
# wrappers, unlike get_environment — confirm against the installed Brax version.
env = envs.create(env_name='hopper', backend='positional')

In [121]:
import jax

In [122]:
# Wrap reset, step and the policy in jax.jit; the rollout loop calls these
# jitted versions.
jit_env_reset, jit_env_step, jit_inference_fn = map(
    jax.jit, (env.reset, env.step, inference_fn))

In [123]:
# Roll the trained policy forward and record each pipeline state for rendering.
ROLLOUT_STEPS = 200  # number of environment steps to record
rollout = []
# Split before first use: the key consumed by reset must not also be the key
# that later action keys are derived from (JAX keys should never be reused).
reset_rng, rng = jax.random.split(jax.random.PRNGKey(seed=42))
state = jit_env_reset(rng=reset_rng)
for _ in range(ROLLOUT_STEPS):
  rollout.append(state.pipeline_state)
  # Fresh sub-key for this step's action; carry the other half forward.
  act_rng, rng = jax.random.split(rng)
  act, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_env_step(state, act)

In [124]:
from IPython.display import HTML, clear_output

In [125]:
# Copy the system with 'opt.timestep' replaced by env.dt, then render the
# recorded states to HTML.
# NOTE(review): presumably this makes playback advance one control step per frame — confirm.
render_sys = env.sys.tree_replace({'opt.timestep': env.dt})
HTML(html.render(render_sys, rollout))