<a href="https://colab.research.google.com/github/mohammadzainabbas/Reinforcement-Learning-CS/blob/dev/notebooks/demo_ppo_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Demo for step-by-step training with PPO

In [1]:
import functools
import os
from datetime import datetime
from os import getcwd
from os.path import join

from IPython.display import HTML, clear_output, display

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax import jumpy as jp
from brax.io import html
from brax.io import model
from brax.training.agents.ppo import train as ppo

from IPython.core.interactiveshell import InteractiveShell
# Ask IPython to display the value of every expression statement in a cell,
# not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# When the COLAB_TPU_ADDR environment variable is present (presumably set by
# Colab TPU runtimes -- confirm), run JAX's Colab TPU setup helper.
if 'COLAB_TPU_ADDR' in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

#### Environment

In [2]:
# Seed used for the PRNG keys created in this notebook (environment reset here
# and in the helper cells below).
SEED = 0
# Name of the Brax environment; later cells pass it to train_ppo and visual_rollout.
env_name = "grasp"
env = envs.get_environment(env_name=env_name)
# Reset once to obtain an initial state for the preview below.
state = env.reset(rng=jp.random_prngkey(seed=SEED))

# Render the initial state as a one-frame trajectory in an HTML viewer.
HTML(html.render(env.sys, [state.qp]))

#### Helper functions

In [3]:
def train_ppo(num_timesteps, env_name):
	"""Train a PPO agent on a Brax environment and record evaluation progress.

	Args:
		num_timesteps: Value forwarded to ``ppo.train`` as ``num_timesteps``.
		env_name: Environment name handed to ``envs.get_environment``.

	Returns:
		A tuple ``(make_inference_fn, params, times, xdata, ydata)``:
		``make_inference_fn`` and ``params`` are the first two values returned
		by ``ppo.train``; ``times`` is a list of ``datetime`` stamps (one taken
		before training starts, then one per progress callback); ``xdata`` and
		``ydata`` hold the step counts and ``metrics['eval/episode_reward']``
		values reported to the progress callback.
	"""
	print(f"Training PPO for '{num_timesteps}' timesteps")

	env = envs.get_environment(env_name=env_name)

	train_fn = functools.partial(
		ppo.train,
		num_timesteps=num_timesteps,
		num_evals=10,
		reward_scaling=10,
		episode_length=1000,
		normalize_observations=True,
		action_repeat=1,
		unroll_length=20,
		num_minibatches=32,
		num_updates_per_batch=2,
		discounting=0.99,
		learning_rate=3e-4,
		entropy_cost=0.001,
		num_envs=2048,
		batch_size=256,
	)

	xdata, ydata = [], []
	# times[0] marks the moment just before training is launched.
	times = [datetime.now()]

	def progress(num_steps, metrics):
		"""Record a timestamp, the step count and the evaluation reward."""
		times.append(datetime.now())
		xdata.append(num_steps)
		ydata.append(metrics['eval/episode_reward'])
		# Replace the previous progress output instead of accumulating it.
		clear_output(wait=True)

	make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

	# The timing summary needs at least one progress callback; skip it rather
	# than raising IndexError if the trainer never reported progress.
	if len(times) > 1:
		print(f'time to jit: {times[1] - times[0]}')
		print(f'time to train: {times[-1] - times[1]}')

	return make_inference_fn, params, times, xdata, ydata

def visual_rollout(inference_fn, env_name, steps=100, seed=0):
	"""Run a policy in a freshly created environment and collect states to render.

	Args:
		inference_fn: Callable invoked as ``inference_fn(obs, rng)``; the first
			element of its result is used as the action.
		env_name: Environment name handed to ``envs.create``.
		steps: Number of environment steps to simulate.
		seed: Seed for the ``jax.random.PRNGKey`` driving the reset and the policy.

	Returns:
		A tuple ``(env.sys, qps)`` where ``qps`` is the list of ``state.qp``
		values observed before each of the ``steps`` environment steps.
	"""
	env = envs.create(env_name=env_name)
	policy_fn = jax.jit(inference_fn)
	step_fn = jax.jit(env.step)
	reset_fn = jax.jit(env.reset)

	key = jax.random.PRNGKey(seed=seed)
	state = reset_fn(rng=key)

	qps = []
	for _ in range(steps):
		# Record the state first, so the state produced by the last step is not included.
		qps.append(state.qp)
		policy_key, key = jax.random.split(key)
		action, _ = policy_fn(state.obs, policy_key)
		state = step_fn(state, action)

	return env.sys, qps

#### Training (step-by-step)

In [4]:
# num_timesteps value for each successive training run.
training_num_timesteps = [1_000, 5_000_000, 400_000_000]

# One inference function per training run, in the same order as above.
inference_fns = []
# (times, xdata, ydata) of every run, kept so earlier runs' histories are not
# overwritten by later ones.
training_histories = []

for num_timesteps in training_num_timesteps:
	make_inference_fn, params, times, xdata, ydata = train_ppo(num_timesteps, env_name)
	inference_fns.append(make_inference_fn(params))
	training_histories.append((times, xdata, ydata))

time to jit: 0:01:25.988931
time to train: 0:08:41.862544


#### Visualise learning

In [7]:
# Number of rollout steps for each trained policy, in the same order as inference_fns.
vis_steps = [300, 500, 750]

# Fail before any rollout is computed if a policy has no matching step count.
if len(vis_steps) < len(inference_fns):
	raise ValueError(
		f"vis_steps has {len(vis_steps)} entries but there are {len(inference_fns)} trained policies"
	)

env_sys = []
rollouts = []

for inference_fn, num_steps in zip(inference_fns, vis_steps):
	rollout_sys, rollout_qps = visual_rollout(inference_fn, env_name, steps=num_steps, seed=SEED)
	env_sys.append(rollout_sys)
	rollouts.append(rollout_qps)

In [8]:
# Render every recorded rollout. display() is called explicitly so each viewer
# is emitted from inside the loop regardless of the shell's
# ast_node_interactivity setting.
for rollout_sys, rollout_qps in zip(env_sys, rollouts):
	display(HTML(html.render(rollout_sys, rollout_qps)))