Proximal Policy Optimization example #470
Merged
Changes from all commits
53 commits
All commits authored by wrzadkow:

- d38a671 Initial PPO commit
- c0ff3ef Use jax.nn.one_hot instead of list comprehension for speed
- f576a76 Clarity: calculate only advantages in gae_advantages()
- 11bc593 jit-compile training step
- 5feeec7 Clarity: get rid of most [:-1] indexing
- 8be6677 jit & vmap Generalized Advantage Estimation
- 670978f Add advantage normalization
- 3414100 Small code cleanup
- f40b049 Add some asserts & debug info logging
- 2bd52d8 Add unit tests
- b943afc Add more debugging info
- b0543a9 Add forward pass tests
- 6eedf84 Explicitly mention values shape being (batch,1), not (batch, ) (no in…
- 04763aa Add more asserts, test more frequently
- be01451 Use log_probs from the start
- a99baac Thread sync: wait for experience before starting the training
- c06e8d7 Reduce amount of information printed when testing
- 21a3540 Clarity: use namedtuple instead of tuple
- c18dd9d Add README
- d9ad5be Enhance docstrings
- d0ff2ae Allow more flexible game choice (don't hardcode game-specific features)
- 1af5bbb Correctly specify the number of frames
- f88e45b Add device_get() for speed as suggested by @jheek
- 690a9c8 Add requirements.txt
- 58c4ca0 Use absl.flags for better hyperparameter handling
- f53c1df Style improvement (comments by @lespeholt and @8bitmp3 & beyond)
- 2b10c33 Don't bin rewards during testing
- da0ec77 Update testing requirements
- 9c72f00 Implement the decay of the clip parameter and learning rate
- f398660 Models: jnp.maximum->nn.relu and use dtype everywhere
- 19dbbc2 Append and then reverse instead of pushing in front in GAE estimation
- 518a7f6 Unit & policy test improvements
- 8ef4493 Fix conflict in setup.py
- e846aef Add required packages to test requirements
- 399e9b2 Merge branch 'master' into rl-example-ppo
- 7b02ec0 Cleanup of main.py incl. variable rename
- 50b2b79 Streamline training: use one thread, divide code into smaller chunks
- df3daa1 Avoid using global variables
- 7e036ae Adhere to file naming standard
- 9ff33b9 Merge remote.py with agent.py due to similar function
- 08bd344 Use tensorboard for logging and add checkpointing
- 65faed8 Simplify and format code
- 68b8713 Save checkpoints less frequently
- 57dd0a3 Update the README
- d7a8fa4 Don't send values and log probs to remote process and back
- f9e37fe Add tensorboard.dev trace
- 70d21f7 Remove unneeded function get_state()
- 342786b Small type hints & docstrings enhancement
- a4dade8 Use ml_collections for hyperparameter handling
- 315902b Refactor a long statement
- d2eae5c Test: use assertEqual and clip rewards when testing them
- d444075 Compile vectorized code instead of vectorizing compiled code
- f3a9d03 Specify static_argnums with proper int
wrzadkow File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
@@ -0,0 +1,47 @@

# Proximal Policy Optimization

Uses the Proximal Policy Optimization algorithm ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347))
to learn to play Atari games.

## Requirements

This example depends on the `gym`, `opencv-python` and `atari-py` packages
in addition to `jax` and `flax`.

## Supported setups

The example should run with other configurations and hardware, but was explicitly
tested on the following:

| Hardware | Game | Training time | Total frames seen | TensorBoard.dev |
| --- | --- | --- | --- | --- |
| 1x V100 GPU | Qbert | 9h 27m 8s | 40M | [2020-09-30](https://tensorboard.dev/experiment/1pacpbxxRz2di3NIOFkHoA/#scalars) |

## How to run

Running `python ppo_main.py` will run the example with default
(hyper)parameters, i.e., for 40M frames on the Pong game.

By default, logging info and checkpoints are stored in the `/tmp/ppo_training`
directory. This can be overridden as follows:

```python ppo_main.py --logdir=/my_fav_directory```

You can also override the default (hyper)parameters, for example

```python ppo_main.py --config.game=Seaquest --config.total_frames=20000000 --config.decaying_lr_and_clip_param=False --logdir=/tmp/seaquest```

will train the model on 20M Seaquest frames with a constant (i.e., not linearly
decaying) learning rate and PPO clipping parameter. Checkpoints and TensorBoard
files will be saved in `/tmp/seaquest`.

Unit tests can be run using `python ppo_lib_test.py`.

## How to run on Google Cloud TPU

It is also possible to run this code on Google Cloud TPU. For detailed
instructions on the required setup, please refer to the [WMT example readme](https://github.com/google/flax/tree/master/examples/wmt).

## Owners

Jonathan Heek @jheek, Wojciech Rzadkowski @wrzadkow
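
The `--config.*` overrides in the README rely on `ml_collections` config flags (cf. the commit "Use ml_collections for hyperparameter handling"). As a rough illustration only, the sketch below shows the usual wiring; the config path `configs/default.py` and the commented-out `train(...)` call are assumptions for illustration, not necessarily how `ppo_main.py` is actually organized.

```python
# Hypothetical sketch of an entry point wiring absl flags + ml_collections.
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
flags.DEFINE_string('logdir', '/tmp/ppo_training',
                    'Directory for checkpoints and TensorBoard files.')
# Exposes every field of the config file as a --config.<field> command-line flag.
config_flags.DEFINE_config_file(
    'config', 'configs/default.py', 'File path to the training configuration.')


def main(argv):
  del argv  # Unused.
  config = FLAGS.config
  print(f'Training {config.game} for {config.total_frames} frames, '
        f'logging to {FLAGS.logdir}')
  # train(config, FLAGS.logdir)  # Actual training entry point not shown here.


if __name__ == '__main__':
  app.run(main)
```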
@@ -0,0 +1,57 @@

```python
"""Agent utilities, incl. choosing the move and running in separate process."""

import multiprocessing
import collections

import jax
import numpy as onp

import env_utils


@jax.jit
def policy_action(model, state):
  """Forward pass of the network."""
  out = model(state)
  return out


ExpTuple = collections.namedtuple(
    'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done'])


class RemoteSimulator:
  """Wrap functionality for an agent emulating Atari in a separate process.

  An object of this class is created for every agent.
  """

  def __init__(self, game: str):
    """Start the remote process and create Pipe() to communicate with it."""
    parent_conn, child_conn = multiprocessing.Pipe()
    self.proc = multiprocessing.Process(
        target=rcv_action_send_exp, args=(child_conn, game))
    self.conn = parent_conn
    self.proc.start()


def rcv_action_send_exp(conn, game: str):
  """Run the remote agents.

  Receive action from the main learner, perform one step of simulation and
  send back collected experience.
  """
  env = env_utils.create_env(game, clip_rewards=True)
  while True:
    obs = env.reset()
    done = False
    # Observations fetched from Atari env need additional batch dimension.
    state = obs[None, ...]
    while not done:
      conn.send(state)
      action = conn.recv()
      obs, reward, done, _ = env.step(action)
      next_state = obs[None, ...] if not done else None
      experience = (state, action, reward, done)
      conn.send(experience)
      if done:
        break
      state = next_state
```
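
For orientation, here is a minimal sketch of the learner's side of the Pipe protocol above: receive a state from each simulator, answer with an action, then receive the resulting experience. It reuses `policy_action` and `ExpTuple` from this file; the function name `collect_one_step` and the sampling details are illustrative assumptions, not the PR's actual training loop.

```python
import jax
import numpy as onp


def collect_one_step(simulators, model):
  """One synchronous step: states in, actions out, experiences back."""
  # Each RemoteSimulator first sends its current (1, 84, 84, 4) state.
  states = onp.concatenate([sim.conn.recv() for sim in simulators], axis=0)
  log_probs, values = policy_action(model, states)
  log_probs, values = jax.device_get((log_probs, values))
  probs = onp.exp(onp.asarray(log_probs))
  experiences = []
  for i, sim in enumerate(simulators):
    # Sample an action from the current policy and send it to the simulator.
    action = onp.random.choice(probs.shape[1], p=probs[i])
    sim.conn.send(action)
    # The simulator answers with (state, action, reward, done).
    state, action, reward, done = sim.conn.recv()
    experiences.append(ExpTuple(state, action, reward,
                                values[i, 0], log_probs[i][action], done))
  return experiences
```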
@@ -0,0 +1,40 @@

```python
"""Definitions of default hyperparameters."""

import ml_collections


def get_config():
  """Get the default configuration.

  The default hyperparameters originate from PPO paper arXiv:1707.06347
  and openAI baselines 2:
  https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py
  """
  config = ml_collections.ConfigDict()
  # The Atari game used.
  config.game = 'Pong'
  # Total number of frames seen during training.
  config.total_frames = 40000000
  # The learning rate for the Adam optimizer.
  config.learning_rate = 2.5e-4
  # Batch size used in training.
  config.batch_size = 256
  # Number of agents playing in parallel.
  config.num_agents = 8
  # Number of steps each agent performs in one policy unroll.
  config.actor_steps = 128
  # Number of training epochs per each unroll of the policy.
  config.num_epochs = 3
  # RL discount parameter.
  config.gamma = 0.99
  # Generalized Advantage Estimation parameter.
  config.lambda_ = 0.95
  # The PPO clipping parameter used to clamp ratios in loss function.
  config.clip_param = 0.1
  # Weight of value function loss in the total loss.
  config.vf_coeff = 0.5
  # Weight of entropy bonus in the total loss.
  config.entropy_coeff = 0.01
  # Linearly decay learning rate and clipping parameter to zero during
  # the training.
  config.decaying_lr_and_clip_param = True
  return config
```
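
To make the roles of `clip_param`, `vf_coeff` and `entropy_coeff` concrete, below is a sketch of the standard PPO objective from Schulman et al. (2017) that such hyperparameters typically feed. The function name and exact signature are illustrative assumptions and may differ from the loss actually implemented in this PR.

```python
import jax.numpy as jnp


def ppo_loss(log_probs_all, values, actions, old_log_probs, returns, advantages,
             clip_param, vf_coeff, entropy_coeff):
  """Clipped surrogate + weighted value loss - entropy bonus (all averaged).

  `log_probs_all` has shape (batch, num_actions); `values`, `actions`,
  `old_log_probs`, `returns` and `advantages` have shape (batch,).
  """
  # Log-probabilities of the actions that were actually taken.
  log_probs = jnp.take_along_axis(log_probs_all, actions[:, None], axis=1)[:, 0]
  ratios = jnp.exp(log_probs - old_log_probs)              # pi_new / pi_old
  clipped = jnp.clip(ratios, 1. - clip_param, 1. + clip_param)
  pg_objective = jnp.mean(jnp.minimum(ratios * advantages, clipped * advantages))
  value_loss = jnp.mean(jnp.square(returns - values))
  entropy = jnp.mean(-jnp.sum(jnp.exp(log_probs_all) * log_probs_all, axis=1))
  # Minimized total: maximize the clipped surrogate and the entropy bonus,
  # minimize the value-function error.
  return -pg_objective + vf_coeff * value_loss - entropy_coeff * entropy
```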
@@ -0,0 +1,67 @@

```python
"""Utilities for handling the Atari environment."""

import collections
import gym
import numpy as onp

import seed_rl_atari_preprocessing


class ClipRewardEnv(gym.RewardWrapper):
  """Adapted from OpenAI baselines.

  github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
  """

  def __init__(self, env):
    gym.RewardWrapper.__init__(self, env)

  def reward(self, reward):
    """Bin reward to {+1, 0, -1} by its sign."""
    return onp.sign(reward)


class FrameStack:
  """Implements stacking of `num_frames` last frames of the game.

  Wraps an AtariPreprocessing object.
  """

  def __init__(
      self,
      preproc: seed_rl_atari_preprocessing.AtariPreprocessing,
      num_frames: int):
    self.preproc = preproc
    self.num_frames = num_frames
    self.frames = collections.deque(maxlen=num_frames)

  def reset(self):
    ob = self.preproc.reset()
    for _ in range(self.num_frames):
      self.frames.append(ob)
    return self._get_array()

  def step(self, action: int):
    ob, reward, done, info = self.preproc.step(action)
    self.frames.append(ob)
    return self._get_array(), reward, done, info

  def _get_array(self):
    assert len(self.frames) == self.num_frames
    return onp.concatenate(self.frames, axis=-1)


def create_env(game: str, clip_rewards: bool):
  """Create a FrameStack object that serves as environment for the `game`."""
  env = gym.make(game)
  if clip_rewards:
    env = ClipRewardEnv(env)  # bin rewards to {-1., 0., 1.}
  preproc = seed_rl_atari_preprocessing.AtariPreprocessing(env)
  stack = FrameStack(preproc, num_frames=4)
  return stack


def get_num_actions(game: str):
  """Get the number of possible actions of a given Atari game.

  This determines the number of outputs in the actor part of the
  actor-critic model.
  """
  env = gym.make(game)
  return env.action_space.n
```
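
A quick usage sketch of the two helpers above. Note that `gym.make(game)` needs a valid Gym environment id; whether callers pass a bare name like `Pong` or a full id such as `PongNoFrameskip-v4` depends on code not shown in this diff, so the id below is an assumption for illustration.

```python
env = create_env('PongNoFrameskip-v4', clip_rewards=True)
num_actions = get_num_actions('PongNoFrameskip-v4')

obs = env.reset()                   # stacked frames, shape (84, 84, 4)
obs, reward, done, info = env.step(0)
print(obs.shape, reward, done, num_actions)
```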
@@ -0,0 +1,53 @@

```python
"""Class and functions to define and initialize the actor-critic model."""

import numpy as onp
import flax
from flax import nn
import jax.numpy as jnp


class ActorCritic(flax.nn.Module):
  """Class defining the actor-critic model."""

  def apply(self, x, num_outputs):
    """Define the convolutional network architecture.

    Architecture originates from "Human-level control through deep reinforcement
    learning.", Nature 518, no. 7540 (2015): 529-533.
    Note that this is different than the one from "Playing atari with deep
    reinforcement learning." arxiv.org/abs/1312.5602 (2013)
    """
    dtype = jnp.float32
    x = x.astype(dtype) / 255.
    x = nn.Conv(x, features=32, kernel_size=(8, 8),
                strides=(4, 4), name='conv1',
                dtype=dtype)
    x = nn.relu(x)
    x = nn.Conv(x, features=64, kernel_size=(4, 4),
                strides=(2, 2), name='conv2',
                dtype=dtype)
    x = nn.relu(x)
    x = nn.Conv(x, features=64, kernel_size=(3, 3),
                strides=(1, 1), name='conv3',
                dtype=dtype)
    x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(x, features=512, name='hidden', dtype=dtype)
    x = nn.relu(x)
    # Network used to both estimate policy (logits) and expected state value.
    # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
    logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype)
    policy_log_probabilities = nn.log_softmax(logits)
    value = nn.Dense(x, features=1, name='value', dtype=dtype)
    return policy_log_probabilities, value


def create_model(key: onp.ndarray, num_outputs: int):
  input_dims = (1, 84, 84, 4)  # (minibatch, height, width, stacked frames)
  module = ActorCritic.partial(num_outputs=num_outputs)
  _, initial_par = module.init_by_shape(key, [(input_dims, jnp.float32)])
  model = flax.nn.Model(module, initial_par)
  return model


def create_optimizer(model: nn.base.Model, learning_rate: float):
  optimizer_def = flax.optim.Adam(learning_rate)
  optimizer = optimizer_def.create(model)
  return optimizer
```
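
A short sketch tying the helpers above together with the pre-Linen `flax.nn`/`flax.optim` API used throughout this PR; the dummy input, the hard-coded `num_actions`, and the printed shapes are illustrative assumptions only.

```python
import jax
import numpy as onp

num_actions = 6                     # e.g. env_utils.get_num_actions(game)
key = jax.random.PRNGKey(0)
model = create_model(key, num_outputs=num_actions)
optimizer = create_optimizer(model, learning_rate=2.5e-4)

# One forward pass on a dummy batch of stacked, preprocessed Atari frames.
dummy_states = onp.zeros((1, 84, 84, 4), dtype=onp.uint8)
log_probs, values = optimizer.target(dummy_states)  # optimizer.target is the model
print(log_probs.shape, values.shape)                # (1, 6) and (1, 1)
```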