In [None]:
from brax import envs
from brax.training.agents.ppo import train as ppo
from prefacc.training.agents.prefppo import train as prefppo

In [None]:
# Environment handed to the trainer: the 'ant' task on the 'positional' backend.
env = envs.get_environment(env_name='ant', backend='positional')

In [None]:
# Train on `env` with the prefppo trainer. The three results are reused by
# later cells: `make_inference_fn` builds the policy function, `params` is
# saved to disk, and `metrics` is printed.
# NOTE(review): the grouping of keyword arguments below follows their names
# only; `prefppo.train` is project code not visible here, so confirm each
# argument's meaning against its signature.
make_inference_fn, params, metrics = prefppo.train(
  env,
  # Run budget, evaluation cadence and seeding.
  num_timesteps=50_000_000,
  num_evals=20,
  seed=1,
  # Episode / environment handling.
  episode_length=1000,
  action_repeat=1,
  num_envs=4096,
  normalize_observations=True,
  reward_scaling=10,
  # Optimisation settings.
  learning_rate=3e-4,
  entropy_cost=1e-2,
  discounting=0.97,
  unroll_length=5,
  batch_size=2048,
  num_minibatches=32,
  num_updates_per_batch=4,
  # Preference-specific settings (presumably consumed only by prefppo -- TODO confirm).
  num_prefill_iterations=10,
  num_prefs=2000,
)

print("Training complete")

In [None]:
# Show the `metrics` object returned by the training call as plain text.
print(metrics)

In [None]:
from brax.io import model
from brax.io import json
from brax.io import html

In [None]:
# Persist the trained parameters to /tmp/params so they can be reloaded below.
model.save_params(path='/tmp/params', params=params)

In [None]:
# Reload the parameters from disk (rebinding `params`) and build the policy
# function from the reloaded copy rather than the in-memory training result.
params = model.load_params(path='/tmp/params')
inference_fn = make_inference_fn(params)

In [None]:
# Environment used for the visualisation rollout. This rebinds `env`, which an
# earlier cell bound to the training environment.
# NOTE(review): `envs.create` presumably differs from `envs.get_environment`
# by applying brax's default wrappers -- confirm this is intended here.
env = envs.create(env_name='ant', backend='positional')

In [None]:
import jax

In [None]:
# JIT-wrap the environment's reset/step and the policy function for the
# rollout loop.
jit_env_reset, jit_env_step, jit_inference_fn = map(
  jax.jit, (env.reset, env.step, inference_fn)
)

In [None]:
# Collect 200 pipeline states by running the policy in the environment.
rng = jax.random.PRNGKey(seed=42)
state = jit_env_reset(rng=rng)

rollout = []
for _ in range(200):
  # Record before stepping, so the first frame is the post-reset state.
  rollout.append(state.pipeline_state)
  # Fresh key for this action; the second half is carried to the next iteration.
  act_rng, rng = jax.random.split(rng, 2)
  # The policy returns a pair; only the action (first element) is used.
  act, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_env_step(state, act)

In [None]:
from IPython.display import HTML, clear_output

In [None]:
# Render the recorded rollout as HTML and display it as the cell output.
# The system passed to the renderer has 'opt.timestep' replaced by `env.dt`
# (presumably so playback matches the environment's step duration -- TODO confirm).
render_sys = env.sys.tree_replace({'opt.timestep': env.dt})
HTML(html.render(render_sys, rollout))