Merged
53 commits
d38a671
Initial PPO commit
wrzadkow Sep 11, 2020
c0ff3ef
Use jax.nn.one_hot instead of list comprehension for speed
wrzadkow Sep 11, 2020
f576a76
Clarity: calculate only advantages in gae_advantages()
wrzadkow Sep 11, 2020
11bc593
jit-compile training step
wrzadkow Sep 11, 2020
5feeec7
Clarity: get rid of most [:-1] indexing
wrzadkow Sep 14, 2020
8be6677
jit & vmap Generalized Advantage Estimation
wrzadkow Sep 14, 2020
670978f
Add advantage normalization
wrzadkow Sep 14, 2020
3414100
Small code cleanup
wrzadkow Sep 14, 2020
f40b049
Add some asserts & debug info logging
wrzadkow Sep 14, 2020
2bd52d8
Add unit tests
wrzadkow Sep 15, 2020
b943afc
Add more debugging info
wrzadkow Sep 15, 2020
b0543a9
Add forward pass tests
wrzadkow Sep 16, 2020
6eedf84
Explicitly mention values shape being (batch,1), not (batch, ) (no in…
wrzadkow Sep 16, 2020
04763aa
Add more asserts, test more frequently
wrzadkow Sep 16, 2020
be01451
Use log_probs from the start
wrzadkow Sep 16, 2020
a99baac
Thread sync: wait for experience before starting the training
wrzadkow Sep 17, 2020
c06e8d7
Reduce amount of information printed when testing
wrzadkow Sep 17, 2020
21a3540
Clarity: use namedtuple instead of tuple
wrzadkow Sep 17, 2020
c18dd9d
Add README
wrzadkow Sep 17, 2020
d9ad5be
Enhance docstrings
wrzadkow Sep 17, 2020
d0ff2ae
Allow more flexible game choice (don't hardcode game-specific features)
wrzadkow Sep 18, 2020
1af5bbb
Correctly specify the number of frames
wrzadkow Sep 18, 2020
f88e45b
Add device_get() for speed as suggested by @jheek
wrzadkow Sep 18, 2020
690a9c8
Add requirements.txt
wrzadkow Sep 18, 2020
58c4ca0
Use absl.flags for better hyperparameter handling
wrzadkow Sep 18, 2020
f53c1df
Style improvement (comments by @lespeholt and @8bitmp3 & beyond)
wrzadkow Sep 21, 2020
2b10c33
Don't bin rewards during testing
wrzadkow Sep 21, 2020
da0ec77
Update testing requirements
wrzadkow Sep 21, 2020
9c72f00
Implement the decay of the clip parameter and learning rate
wrzadkow Sep 21, 2020
f398660
Models: jnp.maximum->nn.relu and use dtype everywhere
wrzadkow Sep 22, 2020
19dbbc2
Append and then reverse instead of pushing in front in GAE estimation
wrzadkow Sep 22, 2020
518a7f6
Unit & policy test improvements
wrzadkow Sep 23, 2020
8ef4493
Fix conflict in setup.py
wrzadkow Sep 23, 2020
e846aef
Add required packages to test requirements
wrzadkow Sep 23, 2020
399e9b2
Merge branch 'master' into rl-example-ppo
wrzadkow Sep 23, 2020
7b02ec0
Cleanup of main.py incl. variable rename
wrzadkow Sep 23, 2020
50b2b79
Streamline training: use one thread, divide code into smaller chunks
wrzadkow Sep 24, 2020
df3daa1
Avoid using global variables
wrzadkow Sep 24, 2020
7e036ae
Adhere to file naming standard
wrzadkow Sep 24, 2020
9ff33b9
Merge remote.py with agent.py due to similar function
wrzadkow Sep 25, 2020
08bd344
Use tensorboard for logging and add checkpointing
wrzadkow Sep 25, 2020
65faed8
Simplify and format code
wrzadkow Sep 28, 2020
68b8713
Save checkpoints less frequently
wrzadkow Sep 28, 2020
57dd0a3
Update the README
wrzadkow Sep 29, 2020
d7a8fa4
Don't send values and log probs to remote process and back
wrzadkow Sep 29, 2020
f9e37fe
Add tensorboard.dev trace
wrzadkow Sep 30, 2020
70d21f7
Remove unneeded function get_state()
wrzadkow Sep 30, 2020
342786b
Small type hints & docstrings enhancement
wrzadkow Sep 30, 2020
a4dade8
Use ml_collections for hyperparameter handling
wrzadkow Sep 30, 2020
315902b
Refactor a long statement
wrzadkow Oct 1, 2020
d2eae5c
Test: use assertEqual and clip rewards when testing them
wrzadkow Oct 1, 2020
d444075
Compile vectorized code instead of vectorizing compiled code
wrzadkow Oct 1, 2020
f3a9d03
Specify static_argnums with proper int
wrzadkow Oct 1, 2020
47 changes: 47 additions & 0 deletions examples/ppo/README.md
@@ -0,0 +1,47 @@
# Proximal Policy Optimization

Uses the Proximal Policy Optimization algorithm ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347))
to learn to play Atari games.
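
At its core, PPO maximizes a clipped surrogate objective. Below is a minimal sketch of that objective in JAX, assuming `log_probs`, `old_log_probs` and `advantages` are arrays of matching shape; the function and argument names are illustrative, not taken from this example's code.

```python
import jax.numpy as jnp

def clipped_surrogate(log_probs, old_log_probs, advantages, clip_param=0.1):
  # Probability ratio between the current policy and the policy that
  # collected the data.
  ratios = jnp.exp(log_probs - old_log_probs)
  # Clip the ratio to [1 - clip_param, 1 + clip_param] and take the
  # elementwise minimum of the clipped and unclipped terms.
  clipped = jnp.clip(ratios, 1. - clip_param, 1. + clip_param)
  return jnp.mean(jnp.minimum(ratios * advantages, clipped * advantages))
```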

## Requirements

This example depends on the `gym`, `opencv-python` and `atari-py` packages
in addition to `jax` and `flax`.

## Supported setups

The example should run with other configurations and hardware, but was explicitly
tested on the following:

| Hardware | Game | Training time | Total frames seen | TensorBoard.dev |
| --- | --- | --- | --- | --- |
| 1x V100 GPU | Qbert | 9h 27m 8s | 40M | [2020-09-30](https://tensorboard.dev/experiment/1pacpbxxRz2di3NIOFkHoA/#scalars) |

## How to run

Running `python ppo_main.py` will run the example with the default
(hyper)parameters, i.e. training for 40M frames on the game Pong.

By default, logging information and checkpoints are stored in the `/tmp/ppo_training`
directory. This can be overridden as follows:

```python ppo_main.py --logdir=/my_fav_directory```

You can also override the default (hyper)parameters, for example

```python ppo_main.py --config.game=Seaquest --config.total_frames=20000000 --config.decaying_lr_and_clip_param=False --logdir=/tmp/seaquest```

will train the model on 20M Seaquest frames with a constant (i.e. not linearly
decaying) learning rate and PPO clipping parameter. Checkpoints and TensorBoard
files will be saved in `/tmp/seaquest`.

Unit tests can be run using `python ppo_lib_test.py`.

## How to run on Google Cloud TPU

It is also possible to run this code on Google Cloud TPU. For detailed
instructions on the required setup, please refer to the [WMT example readme](https://github.com/google/flax/tree/master/examples/wmt).

## Owners

Jonathan Heek @jheek, Wojciech Rzadkowski @wrzadkow
57 changes: 57 additions & 0 deletions examples/ppo/agent.py
@@ -0,0 +1,57 @@
"""Agent utilities, incl. choosing the move and running in separate process."""

import multiprocessing
import collections
import jax
import numpy as onp

import env_utils

@jax.jit
def policy_action(model, state):
"""Forward pass of the network."""
out = model(state)
return out


ExpTuple = collections.namedtuple(
'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done'])


class RemoteSimulator:
"""Wrap functionality for an agent emulating Atari in a separate process.

An object of this class is created for every agent.
"""

def __init__(self, game: str):
"""Start the remote process and create Pipe() to communicate with it."""
parent_conn, child_conn = multiprocessing.Pipe()
self.proc = multiprocessing.Process(
target=rcv_action_send_exp, args=(child_conn, game))
self.conn = parent_conn
self.proc.start()


def rcv_action_send_exp(conn, game: str):
"""Run the remote agents.

Receive action from the main learner, perform one step of simulation and
send back collected experience.
"""
env = env_utils.create_env(game, clip_rewards=True)
while True:
obs = env.reset()
done = False
# Observations fetched from Atari env need additional batch dimension.
state = obs[None, ...]
while not done:
conn.send(state)
action = conn.recv()
obs, reward, done, _ = env.step(action)
next_state = obs[None, ...] if not done else None
experience = (state, action, reward, done)
conn.send(experience)
if done:
break
state = next_state
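
For orientation, here is a sketch of how a learner loop might drive these simulators over the pipe protocol above; `collect_step` and the action-sampling details are illustrative assumptions, not code from this example.

```python
import numpy as onp

def collect_step(simulators, model):
  """One synchronous environment step across all remote simulators (sketch)."""
  # Each remote process sends its current (batched) state first.
  states = [sim.conn.recv() for sim in simulators]
  for sim, state in zip(simulators, states):
    log_probs, _ = policy_action(model, state)
    # Sample an action from the policy's log-probabilities.
    probs = onp.exp(onp.asarray(log_probs[0]))
    action = onp.random.choice(probs.shape[0], p=probs)
    sim.conn.send(int(action))
  # Each remote replies with an experience tuple (state, action, reward, done).
  return [sim.conn.recv() for sim in simulators]
```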
40 changes: 40 additions & 0 deletions examples/ppo/default_config.py
@@ -0,0 +1,40 @@
"""Definitions of default hyperparameters."""

import ml_collections

def get_config():
"""Get the default configuration.

The default hyperparameters originate from PPO paper arXiv:1707.06347
and openAI baselines 2::
https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py
"""
config = ml_collections.ConfigDict()
# The Atari game used.
config.game = 'Pong'
# Total number of frames seen during training.
config.total_frames = 40000000
# The learning rate for the Adam optimizer.
config.learning_rate = 2.5e-4
# Batch size used in training.
config.batch_size = 256
# Number of agents playing in parallel.
config.num_agents = 8
# Number of steps each agent performs in one policy unroll.
config.actor_steps = 128
# Number of training epochs per each unroll of the policy.
config.num_epochs = 3
# RL discount parameter.
config.gamma = 0.99
# Generalized Advantage Estimation parameter.
config.lambda_ = 0.95
# The PPO clipping parameter used to clamp ratios in loss function.
config.clip_param = 0.1
# Weight of value function loss in the total loss.
config.vf_coeff = 0.5
# Weight of entropy bonus in the total loss.
config.entropy_coeff = 0.01
# Linearly decay learning rate and clipping parameter to zero during
# the training.
config.decaying_lr_and_clip_param = True
return config
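
A minimal sketch of how a training script could consume this config via `ml_collections.config_flags`, which is what enables the `--config.game=...` overrides shown in the README; the flag wiring here is an assumption, not copied from `ppo_main.py`.

```python
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
# Load get_config() from this file and allow --config.<field>=... overrides.
config_flags.DEFINE_config_file(
    'config', 'default_config.py', 'File path to the hyperparameter configuration.')

def main(argv):
  del argv  # Unused.
  config = FLAGS.config
  print(config.game, config.total_frames, config.learning_rate)

if __name__ == '__main__':
  app.run(main)
```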
67 changes: 67 additions & 0 deletions examples/ppo/env_utils.py
@@ -0,0 +1,67 @@
"""Utilities for handling the Atari environment."""

import collections
import gym
import numpy as onp

import seed_rl_atari_preprocessing

class ClipRewardEnv(gym.RewardWrapper):
"""Adapted from OpenAI baselines.

github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""

def __init__(self, env):
gym.RewardWrapper.__init__(self, env)

def reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return onp.sign(reward)

class FrameStack:
"""Implements stacking of `num_frames` last frames of the game.

Wraps an AtariPreprocessing object.
"""

def __init__(
self,
preproc: seed_rl_atari_preprocessing.AtariPreprocessing,
num_frames: int):
self.preproc = preproc
self.num_frames = num_frames
self.frames = collections.deque(maxlen=num_frames)

def reset(self):
ob = self.preproc.reset()
for _ in range(self.num_frames):
self.frames.append(ob)
return self._get_array()

def step(self, action: int):
ob, reward, done, info = self.preproc.step(action)
self.frames.append(ob)
return self._get_array(), reward, done, info

def _get_array(self):
assert len(self.frames) == self.num_frames
return onp.concatenate(self.frames, axis=-1)

def create_env(game: str, clip_rewards: bool):
"""Create a FrameStack object that serves as environment for the `game`."""
env = gym.make(game)
if clip_rewards:
env = ClipRewardEnv(env) # bin rewards to {-1., 0., 1.}
preproc = seed_rl_atari_preprocessing.AtariPreprocessing(env)
stack = FrameStack(preproc, num_frames=4)
return stack

def get_num_actions(game: str):
"""Get the number of possible actions of a given Atari game.

This determines the number of outputs in the actor part of the
actor-critic model.
"""
env = gym.make(game)
return env.action_space.n
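
A quick illustrative use of `create_env`; the gym game id and the resulting shapes are assumptions based on the preprocessing above, not taken from this example's tests.

```python
# Illustrative usage; the gym game id is an assumption.
env = create_env('PongNoFrameskip-v4', clip_rewards=True)
obs = env.reset()
print(obs.shape)  # expected (84, 84, 4): four stacked 84x84 grayscale frames
obs, reward, done, info = env.step(0)  # action 0 is NOOP in Atari
print(obs.shape, reward, done)
```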
53 changes: 53 additions & 0 deletions examples/ppo/models.py
@@ -0,0 +1,53 @@
"""Class and functions to define and initialize the actor-critic model."""

import numpy as onp
import flax
from flax import nn
import jax.numpy as jnp

class ActorCritic(flax.nn.Module):
"""Class defining the actor-critic model."""

def apply(self, x, num_outputs):
"""Define the convolutional network architecture.

Architecture originates from "Human-level control through deep reinforcement
learning.", Nature 518, no. 7540 (2015): 529-533.
Note that this is different than the one from "Playing atari with deep
reinforcement learning." arxiv.org/abs/1312.5602 (2013)
"""
dtype = jnp.float32
x = x.astype(dtype) / 255.
x = nn.Conv(x, features=32, kernel_size=(8, 8),
strides=(4, 4), name='conv1',
dtype=dtype)
x = nn.relu(x)
x = nn.Conv(x, features=64, kernel_size=(4, 4),
strides=(2, 2), name='conv2',
dtype=dtype)
x = nn.relu(x)
x = nn.Conv(x, features=64, kernel_size=(3, 3),
strides=(1, 1), name='conv3',
dtype=dtype)
x = nn.relu(x)
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(x, features=512, name='hidden', dtype=dtype)
x = nn.relu(x)
# Network used to both estimate policy (logits) and expected state value.
# See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype)
policy_log_probabilities = nn.log_softmax(logits)
value = nn.Dense(x, features=1, name='value', dtype=dtype)
return policy_log_probabilities, value

def create_model(key: onp.ndarray, num_outputs: int):
input_dims = (1, 84, 84, 4) # (minibatch, height, width, stacked frames)
module = ActorCritic.partial(num_outputs=num_outputs)
_, initial_par = module.init_by_shape(key, [(input_dims, jnp.float32)])
model = flax.nn.Model(module, initial_par)
return model

def create_optimizer(model: nn.base.Model, learning_rate: float):
optimizer_def = flax.optim.Adam(learning_rate)
optimizer = optimizer_def.create(model)
return optimizer