## Demo for Grasp: Pick-and-Place with a robotic hand

In [None]:
#@title Install Brax and some helper modules
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'TPU'** in the dropdown.

from datetime import datetime
import functools
import os
from os import getcwd
from os.path import join
from IPython.display import HTML, Image, clear_output

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax import jumpy as jp
from brax.io import html
from brax.io import image
from brax.io import model
from brax.training.agents.ppo import train as ppo
from brax.training.agents.es import train as es
from brax.training.agents.ars import train as ars
import torch

from IPython.core.interactiveshell import InteractiveShell

# Echo every bare expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# COLAB_TPU_ADDR is presumably present only on Colab TPU runtimes; only then
# hook JAX up to the TPU.
tpu_address = os.environ.get('COLAB_TPU_ADDR')
if tpu_address is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

#### Setting up the environment

In [None]:
# Set up some constants
SEED = 0
ENV_NAME = "grasp"
VIS_ALGO = "ppo"  # which trained agent to visualize: "ppo", "es" or "ars"
VIS_STEPS = int(1e3)  # number of steps in the random-action rollout
# Torch device string for any torch code; falls back to CPU without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Create a Brax environment and reset it once for the initial-state render.
env_name = ENV_NAME
env = envs.get_environment(env_name=env_name)
state = env.reset(rng=jp.random_prngkey(seed=SEED))

### Details about the environment

Source code for the env can be found [here](https://github.com/google/brax/blob/198dee3ac4/brax/envs/grasp.py).

In [None]:
# Render the freshly reset environment as a single-frame trajectory.
initial_frames = [state.qp]
HTML(html.render(env.sys, initial_frames))

#### System size and a random-action rollout

In [None]:
# Basic size of the simulated system. Because ast_node_interactivity is set
# to "all" in the first cell, each bare expression below is displayed.
env.sys.num_actuators
env.sys.num_bodies
env.sys.num_joints
env.sys.num_joint_dof

In [None]:
# Roll out a uniformly random policy (actions in [-1, 1]) for VIS_STEPS steps
# to sanity-check the environment. Actions are sampled with JAX rather than
# torch: the Brax env is a JAX program, and a hard-coded 'cuda' device fails
# on the TPU/CPU runtimes this colab targets.
random_env_reset = jax.jit(env.reset)
random_env_step = jax.jit(env.step)  # jit once; un-jitted stepping is slow

rollout = []
rng = jax.random.PRNGKey(seed=SEED)
rng, reset_rng = jax.random.split(rng)
state = random_env_reset(rng=reset_rng)
for _ in range(VIS_STEPS):
  # Fresh key per step so the actions are independent.
  rng, act_rng = jax.random.split(rng)
  action = jax.random.uniform(
      act_rng, (env.action_size,), minval=-1.0, maxval=1.0)
  state = random_env_step(state, action)
  rollout.append(state)
HTML(html.render(env.sys, [s.qp for s in rollout]))

### Objectives

1. In Grasp, the agent must pick up an object and carry it to a target. Grasp observes three bodies: `Hand`, `Object`, and `Target`. The agent is rewarded when `Object` reaches `Target`.
2. The `reward function` is the sum of the following terms: moving towards the object, being close to the object, touching the object, hitting the target, and moving towards the target.

$$
\begin{aligned}
\text{reward} &= \text{moving to object} + \text{close to object} + \text{touching object} + 5 \cdot \text{target hit} + \text{moving to target} \\
& \\
\text{reward} &: \text{total reward at each step.} \\
\text{moving to object} &: \text{small reward for moving towards the object.} \\
\text{close to object} &: \text{small reward for being close to the object.} \\
\text{touching object} &: \text{small reward for touching the object.} \\
\text{target hit} &: \text{high reward for hitting the target (max. reward).} \\
\text{moving to target} &: \text{high reward for moving towards the target.}
\end{aligned}
$$


### Training the agent

We will use the following algorithms to train our RL agent:

1. `Proximal policy optimization (PPO)`
2. `Evolution Strategy (ES)`
3. `Augmented Random Search (ARS)`


#### 1. Proximal policy optimization (PPO)

In [None]:
# PPO hyperparameters; training runs on `env` and reports each of the
# `num_evals` evaluations through `ppo_progress`.
ppo_train_fn = functools.partial(
    ppo.train,
    num_timesteps=600_000_000,
    num_evals=10,
    reward_scaling=10,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=20,
    num_minibatches=32,
    num_updates_per_batch=2,
    discounting=0.99,
    learning_rate=3e-4,
    entropy_cost=0.001,
    num_envs=2048,
    batch_size=256)

# Default y-axis range of the progress plot. It is widened to include every
# recorded reward so that no point is drawn off-axis.
max_y = 100
min_y = 0

ppo_xdata, ppo_ydata = [], []
ppo_times = [datetime.now()]


def ppo_progress(num_steps, metrics):
  """Records one PPO evaluation and redraws the reward curve.

  Args:
    num_steps: number of environment steps taken so far.
    metrics: evaluation metrics from `ppo.train`; the
      'eval/episode_reward' entry is plotted.
  """
  ppo_times.append(datetime.now())
  ppo_xdata.append(num_steps)
  # Store a host float so the history does not keep device arrays alive.
  ppo_ydata.append(float(metrics['eval/episode_reward']))
  clear_output(wait=True)
  plt.xlim([0, ppo_train_fn.keywords['num_timesteps']])
  plt.ylim([min(min_y, *ppo_ydata), max(max_y, *ppo_ydata)])
  plt.title('PPO')
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.plot(ppo_xdata, ppo_ydata)
  plt.show()


ppo_make_inference_fn, ppo_params, _ = ppo_train_fn(
    environment=env, progress_fn=ppo_progress)

# The first progress call follows compilation (plus the first evaluation), so
# times[1] - times[0] approximates the jit time.
print(f'for "PPO": time to jit: {ppo_times[1] - ppo_times[0]}')
print(f'for "PPO": time to train: {ppo_times[-1] - ppo_times[1]}')

#### 2. Evolution Strategy (ES)

In [None]:
# ES hyperparameters; training runs on `env` and reports each of the
# `num_evals` evaluations through `es_progress`.
es_train_fn = functools.partial(
    es.train,
    num_timesteps=600_000_000,
    num_evals=10,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    learning_rate=3e-4,
    population_size=1024)

# Default y-axis range of the progress plot. It is widened to include every
# recorded reward so that no point is drawn off-axis.
max_y = 100
min_y = 0

es_xdata, es_ydata = [], []
es_times = [datetime.now()]


def es_progress(num_steps, metrics):
  """Records one ES evaluation and redraws the reward curve.

  Args:
    num_steps: number of environment steps taken so far.
    metrics: evaluation metrics from `es.train`; the
      'eval/episode_reward' entry is plotted.
  """
  es_times.append(datetime.now())
  es_xdata.append(num_steps)
  # Store a host float so the history does not keep device arrays alive.
  es_ydata.append(float(metrics['eval/episode_reward']))
  clear_output(wait=True)
  plt.xlim([0, es_train_fn.keywords['num_timesteps']])
  plt.ylim([min(min_y, *es_ydata), max(max_y, *es_ydata)])
  plt.title('ES')
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.plot(es_xdata, es_ydata)
  plt.show()


es_make_inference_fn, es_params, _ = es_train_fn(
    environment=env, progress_fn=es_progress)

# The first progress call follows compilation (plus the first evaluation), so
# times[1] - times[0] approximates the jit time.
print(f'for "ES": time to jit: {es_times[1] - es_times[0]}')
print(f'for "ES": time to train: {es_times[-1] - es_times[1]}')

#### 3. Augmented Random Search (ARS)

In [None]:
# ARS hyperparameters; training runs on `env` and reports each of the
# `num_evals` evaluations through `ars_progress`.
ars_train_fn = functools.partial(
    ars.train,
    num_timesteps=600_000_000,
    num_evals=10,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    number_of_directions=1024)

# Default y-axis range of the progress plot. It is widened to include every
# recorded reward so that no point is drawn off-axis.
max_y = 100
min_y = 0

ars_xdata, ars_ydata = [], []
ars_times = [datetime.now()]


def ars_progress(num_steps, metrics):
  """Records one ARS evaluation and redraws the reward curve.

  Args:
    num_steps: number of environment steps taken so far.
    metrics: evaluation metrics from `ars.train`; the
      'eval/episode_reward' entry is plotted.
  """
  ars_times.append(datetime.now())
  ars_xdata.append(num_steps)
  # Store a host float so the history does not keep device arrays alive.
  ars_ydata.append(float(metrics['eval/episode_reward']))
  clear_output(wait=True)
  plt.xlim([0, ars_train_fn.keywords['num_timesteps']])
  plt.ylim([min(min_y, *ars_ydata), max(max_y, *ars_ydata)])
  plt.title('ARS')
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.plot(ars_xdata, ars_ydata)
  plt.show()


ars_make_inference_fn, ars_params, _ = ars_train_fn(
    environment=env, progress_fn=ars_progress)

# The first progress call follows compilation (plus the first evaluation), so
# times[1] - times[0] approximates the jit time.
print(f'for "ARS": time to jit: {ars_times[1] - ars_times[0]}')
print(f'for "ARS": time to train: {ars_times[-1] - ars_times[1]}')

### Visualize the trained agent

In [None]:
# Pick the trained policy named by VIS_ALGO. Only the chosen algorithm's
# training cell needs to have run; an unknown name is an error instead of
# silently falling back to ARS.
if VIS_ALGO == "ppo":
  make_inference_fn, params = ppo_make_inference_fn, ppo_params
elif VIS_ALGO == "es":
  make_inference_fn, params = es_make_inference_fn, es_params
elif VIS_ALGO == "ars":
  make_inference_fn, params = ars_make_inference_fn, ars_params
else:
  raise ValueError(
      f'Unknown VIS_ALGO {VIS_ALGO!r}; expected "ppo", "es" or "ars".')
inference_fn = make_inference_fn(params)

# Number of steps to roll out, matching the training episode length.
VIS_EPISODE_LENGTH = 1000
# Separate name so that `env` (used by the training cells) is not replaced by
# the env returned from `envs.create`, which may be wrapped
# (NOTE(review): episode / auto-reset wrappers in brax v1 - confirm).
vis_env = envs.create(env_name=env_name, episode_length=VIS_EPISODE_LENGTH)
jit_env_reset = jax.jit(vis_env.reset)
jit_env_step = jax.jit(vis_env.step)
jit_inference_fn = jax.jit(inference_fn)

rollout = []
rng = jax.random.PRNGKey(seed=SEED)
state = jit_env_reset(rng=rng)
for _ in range(VIS_EPISODE_LENGTH):
  rollout.append(state)
  act_rng, rng = jax.random.split(rng)
  act, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_env_step(state, act)

HTML(html.render(vis_env.sys, [s.qp for s in rollout]))