ppo_lib.py
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library file which executes the PPO training."""
import functools
from typing import Tuple, List
import jax
import jax.random
import jax.numpy as jnp
import numpy as onp
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
import ml_collections
import agent
import models
import test_episodes
@jax.jit
@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1)
def gae_advantages(
rewards: onp.ndarray,
terminal_masks: onp.ndarray,
values: onp.ndarray,
discount: float,
gae_param: float):
"""Use Generalized Advantage Estimation (GAE) to compute advantages.
As defined by eqs. (11-12) in PPO paper arXiv: 1707.06347. Implementation uses
key observation that A_{t} = delta_t + gamma*lambda*A_{t+1}.
Args:
rewards: array shaped (actor_steps, num_agents), rewards from the game
terminal_masks: array shaped (actor_steps, num_agents), zeros for terminal
and ones for non-terminal states
values: array shaped (actor_steps, num_agents), values estimated by critic
discount: RL discount usually denoted with gamma
gae_param: GAE parameter usually denoted with lambda
Returns:
advantages: calculated advantages shaped (actor_steps, num_agents)
"""
assert rewards.shape[0] + 1 == values.shape[0], ('One more value needed; Eq. '
'(12) in PPO paper requires '
'V(s_{t+1}) for delta_t')
advantages = []
gae = 0.
for t in reversed(range(len(rewards))):
# Masks used to set next state value to 0 for terminal states.
value_diff = discount * values[t + 1] * terminal_masks[t] - values[t]
delta = rewards[t] + value_diff
# Masks[t] used to ensure that values before and after a terminal state
# are independent of each other.
gae = delta + discount * gae_param * terminal_masks[t] * gae
advantages.append(gae)
advantages = advantages[::-1]
return jnp.array(advantages)
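

# Shape sketch (illustrative numbers, not from the original file): with
# actor_steps=128 and num_agents=8, `rewards` and `terminal_masks` are
# (128, 8) and `values` is (129, 8); the vmap over axis 1 runs the reversed
# time loop independently per agent and returns advantages shaped (128, 8).

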
@functools.partial(jax.jit, static_argnums=1)
def loss_fn(
params: flax.core.frozen_dict.FrozenDict,
module: models.ActorCritic,
minibatch: Tuple,
clip_param: float,
vf_coeff: float,
entropy_coeff: float):
"""Evaluate the loss function.
Compute loss as a sum of three components: the negative of the PPO clipped
surrogate objective, the value function loss and the negative of the entropy
bonus.
Args:
params: the parameters of the actor-critic model
module: the actor-critic model
minibatch: Tuple of five elements forming one experience batch:
states: shape (batch_size, 84, 84, 4)
actions: shape (batch_size, 84, 84, 4)
old_log_probs: shape (batch_size,)
returns: shape (batch_size,)
advantages: shape (batch_size,)
clip_param: the PPO clipping parameter used to clamp ratios in loss function
vf_coeff: weighs value function loss in total loss
entropy_coeff: weighs entropy bonus in the total loss
Returns:
loss: the PPO loss, scalar quantity
"""
states, actions, old_log_probs, returns, advantages = minibatch
log_probs, values = agent.policy_action(params, module, states)
values = values[:, 0] # Convert shapes: (batch, 1) to (batch, ).
probs = jnp.exp(log_probs)
value_loss = jnp.mean(jnp.square(returns - values), axis=0)
entropy = jnp.sum(-probs*log_probs, axis=1).mean()
log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions)
ratios = jnp.exp(log_probs_act_taken - old_log_probs)
# Advantage normalization (following the OpenAI baselines).
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
PG_loss = ratios * advantages
clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios,
1. + clip_param)
PPO_loss = -jnp.mean(jnp.minimum(PG_loss, clipped_loss), axis=0)
return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy
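

# Worked example of the clipping (illustrative numbers, not from the original
# file): with clip_param=0.2, ratio=1.5 and advantage=+1, the unclipped term
# is 1.5 while the clipped term is clamp(0.8, 1.5, 1.2) * 1 = 1.2; the
# minimum selects 1.2, so the gradient through the ratio vanishes and the
# update cannot push the policy further outside the trust region.

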
@functools.partial(jax.jit, static_argnums=(0, 7))
def train_step(
module: models.ActorCritic,
optimizer: flax.optim.base.Optimizer,
trajectories: Tuple,
clip_param: float,
vf_coeff: float,
entropy_coeff: float,
lr: float,
batch_size: int):
"""Compilable train step.
Runs an entire epoch of training (i.e. the loop over minibatches within
an epoch is included here for performance reasons).
Args:
module: the actor-critic model
optimizer: optimizer for the actor-critic model
trajectories: Tuple of the following five elements forming the experience:
states: shape (steps_per_agent*num_agents, 84, 84, 4)
actions: shape (steps_per_agent*num_agents, 84, 84, 4)
old_log_probs: shape (steps_per_agent*num_agents, )
returns: shape (steps_per_agent*num_agents, )
advantages: (steps_per_agent*num_agents, )
clip_param: the PPO clipping parameter used to clamp ratios in loss function
vf_coeff: weighs value function loss in total loss
entropy_coeff: weighs entropy bonus in the total loss
lr: learning rate, varies between optimization steps
if decaying_lr_and_clip_param is set to true
batch_size: the minibatch size, static argument
Returns:
optimizer: new optimizer after the parameters update
loss: loss summed over training steps
"""
iterations = trajectories[0].shape[0] // batch_size
trajectories = jax.tree_map(
lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories)
loss = 0.
for batch in zip(*trajectories):
grad_fn = jax.value_and_grad(loss_fn)
l, grad = grad_fn(optimizer.target, module, batch, clip_param, vf_coeff,
entropy_coeff)
loss += l
optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
return optimizer, loss
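

# Minibatching sketch (hypothetical sizes, not from the original file): 8
# agents * 128 actor steps give 1024 samples; with batch_size=256 the reshape
# above yields 4 minibatches, so one call to `train_step` performs 4 gradient
# updates. Because `batch_size` is a static argument, changing it triggers
# recompilation.

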
def get_experience(
params: flax.core.frozen_dict.FrozenDict,
module: models.ActorCritic,
simulators: List[agent.RemoteSimulator],
steps_per_actor: int):
"""Collect experience from agents.
Runs `steps_per_actor` time steps of the game for each of the `simulators`.
"""
all_experience = []
# Range up to steps_per_actor + 1 to get one more value needed for GAE.
for _ in range(steps_per_actor + 1):
states = []
for sim in simulators:
state = sim.conn.recv()
states.append(state)
states = onp.concatenate(states, axis=0)
log_probs, values = agent.policy_action(params, module, states)
log_probs, values = jax.device_get((log_probs, values))
probs = onp.exp(onp.array(log_probs))
for i, sim in enumerate(simulators):
probabilities = probs[i]
action = onp.random.choice(probs.shape[1], p=probabilities)
sim.conn.send(action)
experiences = []
for i, sim in enumerate(simulators):
state, action, reward, done = sim.conn.recv()
value = values[i, 0]
log_prob = log_probs[i][action]
sample = agent.ExpTuple(state, action, reward, value, log_prob, done)
experiences.append(sample)
all_experience.append(experiences)
return all_experience
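

# Protocol sketch: each `RemoteSimulator` connection alternates
# recv(state) -> send(action) -> recv((state, action, reward, done)), so the
# one extra loop iteration (steps_per_actor + 1) supplies the trailing value
# V(s_{t+1}) required by the GAE computation above.

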
def process_experience(
experience: List[List[agent.ExpTuple]],
actor_steps: int,
num_agents: int,
gamma: float,
lambda_: float):
"""Process experience for training, including advantage estimation.
Args:
experience: collected from agents in the form of nested lists/namedtuple
actor_steps: number of steps each agent has completed
num_agents: number of agents that collected experience
gamma: dicount parameter
lambda_: GAE parameter
Returns:
trajectories: trajectories readily accessible for `train_step()` function
"""
obs_shape = (84, 84, 4)
exp_dims = (actor_steps, num_agents)
values_dims = (actor_steps + 1, num_agents)
states = onp.zeros(exp_dims + obs_shape, dtype=onp.float32)
actions = onp.zeros(exp_dims, dtype=onp.int32)
rewards = onp.zeros(exp_dims, dtype=onp.float32)
values = onp.zeros(values_dims, dtype=onp.float32)
log_probs = onp.zeros(exp_dims, dtype=onp.float32)
dones = onp.zeros(exp_dims, dtype=onp.float32)
for t in range(len(experience) - 1): # experience[-1] only for next_values
for agent_id, exp_agent in enumerate(experience[t]):
states[t, agent_id, ...] = exp_agent.state
actions[t, agent_id] = exp_agent.action
rewards[t, agent_id] = exp_agent.reward
values[t, agent_id] = exp_agent.value
log_probs[t, agent_id] = exp_agent.log_prob
# Dones need to be 0 for terminal states.
dones[t, agent_id] = float(not exp_agent.done)
for a in range(num_agents):
values[-1, a] = experience[-1][a].value
advantages = gae_advantages(rewards, dones, values, gamma, lambda_)
returns = advantages + values[:-1, :]
# After preprocessing, concatenate data from all agents.
trajectories = (states, actions, log_probs, returns, advantages)
trajectory_len = num_agents * actor_steps
trajectories = tuple(map(
lambda x: onp.reshape(x, (trajectory_len,) + x.shape[2:]), trajectories))
return trajectories
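

# Shape sketch (illustrative numbers, not from the original file): with
# actor_steps=128 and num_agents=8, each array above is flattened from
# (128, 8, ...) to (1024, ...), matching the flat batch dimension that
# `train_step` expects before its own minibatch reshape.

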
def train(
module: models.ActorCritic,
optimizer: flax.optim.base.Optimizer,
config: ml_collections.ConfigDict,
model_dir: str):
"""Main training loop.
Args:
module: the actor-critic model
optimizer: optimizer for the actor-critic model
config: object holding hyperparameters and the training information
model_dir: path to dictionary where checkpoints and logging info are stored
Returns:
optimizer: the trained optimizer
"""
game = config.game + 'NoFrameskip-v4'
simulators = [agent.RemoteSimulator(game)
for _ in range(config.num_agents)]
summary_writer = tensorboard.SummaryWriter(model_dir)
summary_writer.hparams(dict(config))
loop_steps = config.total_frames // (config.num_agents * config.actor_steps)
log_frequency = 40
checkpoint_frequency = 500
for s in range(loop_steps):
# Bookkeeping and testing.
if s % log_frequency == 0:
score = test_episodes.policy_test(1, module, optimizer.target, game)
frames = s * config.num_agents * config.actor_steps
summary_writer.scalar('game_score', score, frames)
print(f'Step {s}:\nframes seen {frames}\nscore {score}\n\n')
if s % checkpoint_frequency == 0:
checkpoints.save_checkpoint(model_dir, optimizer, s)
# Core training code.
alpha = 1. - s/loop_steps if config.decaying_lr_and_clip_param else 1.
all_experiences = get_experience(
optimizer.target, module, simulators, config.actor_steps)
trajectories = process_experience(
all_experiences, config.actor_steps, config.num_agents, config.gamma,
config.lambda_)
lr = config.learning_rate * alpha
clip_param = config.clip_param * alpha
for e in range(config.num_epochs):
permutation = onp.random.permutation(
config.num_agents * config.actor_steps)
trajectories = tuple(map(lambda x: x[permutation], trajectories))
optimizer, loss = train_step(
module, optimizer, trajectories, clip_param, config.vf_coeff,
config.entropy_coeff, lr, config.batch_size)
return optimizer
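

# Minimal usage sketch (not part of the original file; all names and values
# are illustrative). Assumes `models.ActorCritic` takes the number of actions
# and that initial parameters are created elsewhere; the optimizer uses the
# old `flax.optim` API consistent with the type hints above:
#
#   config = ml_collections.ConfigDict({
#       'game': 'Pong', 'num_agents': 8, 'actor_steps': 128,
#       'total_frames': 40_000_000, 'batch_size': 256, 'num_epochs': 3,
#       'gamma': 0.99, 'lambda_': 0.95, 'clip_param': 0.1,
#       'vf_coeff': 0.5, 'entropy_coeff': 0.01, 'learning_rate': 2.5e-4,
#       'decaying_lr_and_clip_param': True,
#   })
#   module = models.ActorCritic(num_outputs=num_actions)
#   optimizer = flax.optim.Adam(config.learning_rate).create(initial_params)
#   optimizer = train(module, optimizer, config, '/tmp/ppo_training')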