In [1]:
# @markdown ## ⚠️ PLEASE NOTE:
# @markdown This colab runs best using a GPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.

%pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
%pip install matplotlib
%pip install IPython
%pip install flax

import functools
import os
from datetime import datetime

import jax
import matplotlib.pyplot as plt
from IPython.display import HTML, clear_output
from jax import numpy as jp

try:
    import brax
except ImportError:
    %pip install git+https://github.com/google/brax.git@main
    import brax

import flax
from brax import envs
from brax import geometry as braxgeo
from brax.io import html, json, model
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [12]:
import sys

# Print the current working directory (what the comment originally promised)
# and the module search path; both are useful when a local module such as
# `rewards` cannot be imported. `os` is imported in the first cell.
print(f"cwd: {os.getcwd()}")
print(sys.path)

/root


In [11]:
from rewards import reward_flavors

ModuleNotFoundError: No module named 'Rewards'

In [3]:
# Build the Brax "ant" environment on the positional backend.
env = envs.create(
    backend="positional",
    env_name="ant",
)

In [None]:
# Reset the environment once through a jitted reset, using a fixed seed.
reset_key = jax.random.PRNGKey(seed=0)
jitted_reset = jax.jit(env.reset)
state = jitted_reset(rng=reset_key)

In [None]:
# Fixed y-axis range for the reward plot.
min_y = 0
max_y = 8000

# Histories appended to by `progress`: env steps, eval rewards and
# wall-clock timestamps. times[0] is recorded now, before training starts.
xdata = []
ydata = []
times = [datetime.now()]


def progress(num_steps, metrics, max_steps: int = 100_000_000):
    """Record one evaluation point and redraw the training reward curve.

    Used as the ``progress_fn`` callback of ``ppo.train``. It appends to the
    notebook-level ``times``, ``xdata`` and ``ydata`` lists, then clears the
    cell output and re-plots the full history.

    Args:
        num_steps: Environment steps taken so far (x-axis value).
        metrics: Mapping of metric names to values; must contain
            ``"eval/episode_reward"`` (assumed to be a scalar).
        max_steps: Right edge of the x-axis. The default equals the
            ``num_timesteps`` used in the training call; keep them in sync.
    """
    times.append(datetime.now())
    xdata.append(num_steps)
    # Store a plain Python float so the history does not hold device arrays.
    ydata.append(float(metrics["eval/episode_reward"]))

    clear_output(wait=True)
    fig, ax = plt.subplots()
    ax.set_xlim([0, max_steps])
    ax.set_ylim([min_y, max_y])
    ax.set_xlabel("# environment steps")
    ax.set_ylabel("reward per episode")
    ax.set_title("PPO training progress")
    ax.plot(xdata, ydata)
    plt.show()
    # Release the figure; this callback runs once per evaluation.
    plt.close(fig)


# PPO hyperparameters for the ant task, collected in one place.
ppo_kwargs = dict(
    num_timesteps=100_000_000,
    num_evals=10,
    reward_scaling=10,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=5,
    num_minibatches=32,
    num_updates_per_batch=4,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    num_envs=4096,
    batch_size=2048,
    seed=1,
)

make_inference_fn, params, _ = ppo.train(
    environment=env,
    progress_fn=progress,
    **ppo_kwargs,
)

# times[0] is set before training and times[1] at the first progress callback,
# so the first gap is compile time and the rest is training time.
jit_duration = times[1] - times[0]
train_duration = times[-1] - times[1]
print(f"time to jit: {jit_duration}")
print(f"time to train: {train_duration}")

In [None]:
# Round-trip the trained parameters through disk, then build the policy from them.
params_path = "/tmp/params"
model.save_params(params_path, params)
params = model.load_params(params_path)
inference_fn = make_inference_fn(params)

In [None]:
# @title Visualizing a trajectory of the learned inference function

# create an env with auto-reset
env = envs.create(env_name="ant", backend="positional")

# Compile reset, step and the policy once; the Python loop below reuses them.
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fn)

rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
rollout = []
for _ in range(1000):
    # Record the pipeline state before acting, so the first frame is the reset state.
    rollout.append(state.pipeline_state)
    key_pair = jax.random.split(rng)
    act_rng, rng = key_pair[0], key_pair[1]
    act, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_env_step(state, act)

# Render with the system's dt replaced by the env's dt.
render_sys = env.sys.replace(dt=env.dt)
HTML(html.render(render_sys, rollout))

# Collecting Rollouts to Rank

In [None]:
# Rollout collection settings: `num_rollouts` trajectories are gathered in
# passes of `num_envs` parallel environments.
num_envs = 4096
num_rollouts = 16384
trajectory_length = 1000
assert num_rollouts % num_envs == 0, "num_rollouts must be a multiple of num_envs"
num_vec_rollout_passes = num_rollouts // num_envs

env = envs.create(env_name="ant", backend="positional")
# Batched environment: `vec_env` was previously used without being defined.
# NOTE(review): relies on brax's envs.create(batch_size=...) vmapping
# reset/step so reset takes one key and returns a batch of `num_envs`
# states — confirm against the installed brax version.
vec_env = envs.create(env_name="ant", backend="positional", batch_size=num_envs)

jit_vec_env_reset = jax.jit(vec_env.reset)
jit_vec_env_step = jax.jit(vec_env.step)

rollouts = []
rng = jax.random.PRNGKey(seed=1)
# Batched initial states (previously reset with the single-env jit_env_reset).
states = jit_vec_env_reset(rng=rng)

In [None]:
# Batched versions of the jitted single-env functions; each maps over the
# leading axis of its inputs (so callers pass one PRNG key per environment).
vmapped_reset = jax.vmap(jit_env_reset, in_axes=0)
vmapped_step = jax.vmap(jit_env_step, in_axes=0)
vmapped_inference_fn = jax.vmap(jit_inference_fn, in_axes=0)

In [None]:
# Empty buffers (batch axis of length 0) that per-pass rollouts are concatenated onto.
# States hold trajectory_length + 1 entries (including the one after the last
# action); actions hold trajectory_length.
# NOTE(review): per-step "state" is represented by the observation vector,
# since pipeline states are pytrees rather than flat arrays — confirm intent.
rollout_states = jp.zeros((0, trajectory_length + 1, env.observation_size))
# Use the env's action size instead of a hard-coded 8.
rollout_actions = jp.zeros((0, trajectory_length, env.action_size))

In [None]:
def collect_new_rollouts(
    parallel_envs: int,
    rng: jax.Array,
    trajectory_length: int,
    pipeline_states=None,
):
    """Roll out the learned policy in ``parallel_envs`` environments at once.

    Relies on the notebook-level ``vmapped_reset``, ``vmapped_step`` and
    ``vmapped_inference_fn`` helpers.

    Args:
        parallel_envs: Number of environments simulated side by side.
        rng: A single PRNG key; it is split into one key per environment.
        trajectory_length: Number of policy steps per trajectory (>= 1).
        pipeline_states: Optional list. If given, the batched
            ``pipeline_state`` of every step (before each action) is
            appended to it, e.g. for rendering.

    Returns:
        ``(rollout_states, rollout_acts)``. ``rollout_states`` stacks the
        observations along axis 1 and has ``trajectory_length + 1`` time
        entries (the last is the observation after the final action).
        ``rollout_acts`` stacks the actions and has ``trajectory_length``
        entries. The batch axis (``parallel_envs``) is first in both.

    Raises:
        ValueError: If ``trajectory_length`` is less than 1.
    """
    if trajectory_length < 1:
        raise ValueError("trajectory_length must be >= 1")

    # One independent key per environment.
    env_rngs = jax.random.split(rng, parallel_envs)
    states = vmapped_reset(rng=env_rngs)

    obs_steps = [states.obs]
    act_steps = []
    for _ in range(trajectory_length):
        if pipeline_states is not None:
            pipeline_states.append(states.pipeline_state)
        # jax.random.split expects a single key, so vmap it over the batch of
        # per-env keys: each env gets a fresh carry key and an action key.
        split_keys = jax.vmap(jax.random.split)(env_rngs)
        env_rngs, act_rngs = split_keys[:, 0], split_keys[:, 1]
        acts, _ = vmapped_inference_fn(states.obs, act_rngs)
        states = vmapped_step(states, acts)
        obs_steps.append(states.obs)
        act_steps.append(acts)

    # Stack along a new time axis (axis=1) so the batch axis stays first.
    rollout_states = jp.stack(obs_steps, axis=1)
    rollout_acts = jp.stack(act_steps, axis=1)
    return rollout_states, rollout_acts

In [None]:
# Collect `num_vec_rollout_passes` batches of `num_envs` trajectories each.
# NOTE(review): every step's batched pipeline state is kept in `rollouts`; at
# num_envs=4096 and trajectory_length=1000 this may exhaust device memory.
for _pass in range(num_vec_rollout_passes):
    # Start every pass from a freshly reset batch of environments.
    rng, reset_rng = jax.random.split(rng)
    states = vmapped_reset(rng=jax.random.split(reset_rng, num_envs))
    for _step in range(trajectory_length):
        rollouts.append(states.pipeline_state)
        rng, act_rng = jax.random.split(rng)
        # vmapped_inference_fn maps over axis 0 of every argument, so it needs
        # one key per environment rather than a single key.
        act_rngs = jax.random.split(act_rng, num_envs)
        acts, _ = vmapped_inference_fn(states.obs, act_rngs)
        states = vmapped_step(states, acts)