-
Notifications
You must be signed in to change notification settings - Fork 57
/
ppg.py
276 lines (231 loc) · 8.22 KB
/
ppg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
from copy import deepcopy
from . import ppo
from . import logger
import torch as th
import itertools
from . import torch_util as tu
from torch import distributions as td
from .distr_builder import distr_builder
from mpi4py import MPI
from .tree_util import tree_map, tree_reduce
import operator
def sum_nonbatch(logprob_tree):
    """
    Collapse a (possibly nested) tree of log-probabilities into one tensor.

    Every leaf is first reduced with ``tu.sum_nonbatch`` (summing over its
    non-batch dimensions); the per-leaf results are then added together.
    Intended for nested action spaces, whose Product distributions yield a
    tree of log-probabilities rather than a single tensor.
    """
    per_leaf_totals = tree_map(tu.sum_nonbatch, logprob_tree)
    return tree_reduce(operator.add, per_leaf_totals)
class PpoModel(th.nn.Module):
    """
    Base class for models driven by PPO rollouts.

    Subclasses implement ``forward(ob, first, state_in)`` returning
    ``(pd, vpred, aux, state_out)`` for inputs that carry a time axis at
    dim 1. ``act`` and ``v`` are single-timestep wrappers: they insert a
    length-1 time axis, call the model under ``tu.no_grad``, and index the
    time axis back out of the results.
    """

    def forward(self, ob, first, state_in) -> "pd, vpred, aux, state_out":
        raise NotImplementedError

    @tu.no_grad
    def act(self, ob, first, state_in):
        """
        Sample an action for a single timestep.

        Returns ``(ac, state_out, info)``: the sampled action tree with the
        time axis removed, the model's new state, and a dict holding the
        value prediction ("vpred") and the sampled action's log-probability
        reduced with ``sum_nonbatch`` ("logp").
        """
        ob_seq = tree_map(lambda leaf: leaf[:, None], ob)
        first_seq = first[:, None]
        pd, vpred, _, state_out = self(ob=ob_seq, first=first_seq, state_in=state_in)

        ac_seq = pd.sample()
        logp_seq = sum_nonbatch(pd.log_prob(ac_seq))

        ac = tree_map(lambda leaf: leaf[:, 0], ac_seq)
        info = {"vpred": vpred[:, 0], "logp": logp_seq[:, 0]}
        return ac, state_out, info

    @tu.no_grad
    def v(self, ob, first, state_in):
        """Return the value prediction for a single timestep (time axis removed)."""
        ob_seq = tree_map(lambda leaf: leaf[:, None], ob)
        first_seq = first[:, None]
        _, vpred, _, _ = self(ob=ob_seq, first=first_seq, state_in=state_in)
        return vpred[:, 0]
class PhasicModel(PpoModel):
    """
    Interface for models trained by the phasic (policy phase + auxiliary
    phase) loop in this module.

    Every method raises NotImplementedError except ``set_aux_phase``, which
    is a no-op hook subclasses may override.
    """

    def forward(self, ob, first, state_in) -> "pd, vpred, aux, state_out":
        """Return the policy distribution, value prediction, aux outputs and new state."""
        raise NotImplementedError

    def compute_aux_loss(self, aux, mb):
        """Return a dict mapping loss name -> loss, computed from ``aux`` and minibatch ``mb``."""
        raise NotImplementedError

    def initial_state(self, batchsize):
        """Return the starting state for ``batchsize`` parallel streams."""
        raise NotImplementedError

    def aux_keys(self) -> "list of keys needed in mb dict for compute_aux_loss":
        """Return the extra minibatch keys that ``compute_aux_loss`` reads."""
        raise NotImplementedError

    def set_aux_phase(self, is_aux_phase: bool):
        """
        Hook for changing the model between phases (for example, adding a
        stop gradient). The default implementation does nothing.
        """
class PhasicValueModel(PhasicModel):
    """
    PhasicModel with a policy head, value head(s) and an auxiliary value head.

    ``arch`` selects where the true value head reads its features:
      * "shared": from the policy encoder ("pi").
      * "detach": from the policy encoder, but detached, so value-loss
        gradients do not flow into the policy encoder.
      * "dual":   from a separate value encoder ("vf"); the policy encoder
        then only feeds the policy head and the auxiliary value head.

    Encoders are built with ``enc_fn(obtype)`` and stored as ``<key>_enc``;
    value heads are stored as ``<key>_vhead``. Each encoder is used here via
    ``codetype.size``, ``initial_state(batchsize)`` and a call
    ``(ob, first, state) -> (features, new_state)``.
    """

    def __init__(
        self,
        obtype,
        actype,
        enc_fn,
        arch="dual",  # shared, detach, dual
    ):
        super().__init__()

        pi_key = "pi"
        detach_value_head = False
        if arch == "shared":
            true_vf_key = "pi"
        elif arch == "detach":
            true_vf_key = "pi"
            detach_value_head = True
        elif arch == "dual":
            true_vf_key = "vf"
        else:
            # Explicit error instead of `assert`, which is stripped under -O.
            raise ValueError(
                f"unknown arch {arch!r}; expected 'shared', 'detach' or 'dual'"
            )
        vf_keys = [true_vf_key]

        self.pi_key = pi_key
        self.true_vf_key = true_vf_key
        self.vf_keys = vf_keys
        # Deduplicate ("shared"/"detach" use "pi" for both roles) while keeping
        # a fixed order with the policy key first. Iterating a set() of strings
        # depends on per-process hash randomization, which would make encoder
        # construction order differ between processes.
        self.enc_keys = list(dict.fromkeys([pi_key] + vf_keys))
        self.detach_value_head = detach_value_head
        pi_outsize, self.make_distr = distr_builder(actype)

        # One encoder per key, registered as "<key>_enc" (policy encoder first).
        for k in self.enc_keys:
            self.set_encoder(k, enc_fn(obtype))
        # Each value head maps encoder features to a single output unit.
        for k in self.vf_keys:
            lastsize = self.get_encoder(k).codetype.size
            self.set_vhead(k, tu.NormedLinear(lastsize, 1, scale=0.1))

        lastsize = self.get_encoder(self.pi_key).codetype.size
        self.pi_head = tu.NormedLinear(lastsize, pi_outsize, scale=0.1)
        # Auxiliary value head reading the policy encoder's features.
        self.aux_vf_head = tu.NormedLinear(lastsize, 1, scale=0.1)

    def compute_aux_loss(self, aux, seg):
        """
        Return the auxiliary-phase value losses.

        Both are 0.5 * mean squared error against ``seg["vtarg"]``:
        "vf_aux" for the auxiliary head on the policy features and
        "vf_true" for the true value head.
        """
        vtarg = seg["vtarg"]
        return {
            "vf_aux": 0.5 * ((aux["vpredaux"] - vtarg) ** 2).mean(),
            "vf_true": 0.5 * ((aux["vpredtrue"] - vtarg) ** 2).mean(),
        }

    def reshape_x(self, x):
        """Flatten every dim after the first two into a single feature dim."""
        b, t = x.shape[:2]
        x = x.reshape(b, t, -1)
        return x

    def get_encoder(self, key):
        """Return the encoder stored under ``<key>_enc``."""
        return getattr(self, key + "_enc")

    def set_encoder(self, key, enc):
        """Store ``enc`` as attribute ``<key>_enc``."""
        setattr(self, key + "_enc", enc)

    def get_vhead(self, key):
        """Return the value head stored under ``<key>_vhead``."""
        return getattr(self, key + "_vhead")

    def set_vhead(self, key, layer):
        """Store ``layer`` as attribute ``<key>_vhead``."""
        setattr(self, key + "_vhead", layer)

    def forward(self, ob, first, state_in):
        """
        Run every encoder, then the policy and value heads.

        ``state_in`` maps each key in ``self.enc_keys`` to that encoder's
        state. Returns ``(pd, vpred, aux, state_out)``: ``vpred`` is the true
        value head's output; ``aux`` holds each value head's output by key
        plus "vpredaux" (aux head) and "vpredtrue" (same as ``vpred``);
        ``state_out`` maps each encoder key to its new state.
        """
        state_out = {}
        x_out = {}

        for k in self.enc_keys:
            x_out[k], state_out[k] = self.get_encoder(k)(ob, first, state_in[k])
            x_out[k] = self.reshape_x(x_out[k])

        pi_x = x_out[self.pi_key]
        pivec = self.pi_head(pi_x)
        pd = self.make_distr(pivec)

        aux = {}
        for k in self.vf_keys:
            if self.detach_value_head:
                # pi_x was captured above, so the policy and aux heads still
                # receive gradients; only the value head sees detached features.
                x_out[k] = x_out[k].detach()
            aux[k] = self.get_vhead(k)(x_out[k])[..., 0]
        vfvec = aux[self.true_vf_key]
        aux.update({"vpredaux": self.aux_vf_head(pi_x)[..., 0], "vpredtrue": vfvec})

        return pd, vfvec, aux, state_out

    def initial_state(self, batchsize):
        """Return a dict of each encoder's initial state, keyed by encoder key."""
        return {k: self.get_encoder(k).initial_state(batchsize) for k in self.enc_keys}

    def aux_keys(self):
        """Extra minibatch keys read by ``compute_aux_loss``."""
        return ["vtarg"]
def make_minibatches(segs, mbsize):
    """
    Yield one shuffled epoch of minibatches over the data in ``segs``.

    Every (env index, seg index) pair is visited exactly once: the pairs are
    permuted with ``th.randperm`` and taken ``mbsize`` at a time (the last
    chunk may be smaller), so a single minibatch mixes rows from different
    segs. The env count is read from ``segs[0]``; all segs are presumably
    expected to have the same number of envs.
    """
    num_envs = tu.batch_len(segs[0])
    num_segs = len(segs)
    pair_table = th.tensor(list(itertools.product(range(num_envs), range(num_segs))))
    shuffled = th.randperm(len(pair_table))
    for chunk in shuffled.split(mbsize):
        picked_pairs = pair_table[chunk]
        rows = [tu.tree_slice(segs[seg_i], env_i) for (env_i, seg_i) in picked_pairs]
        yield tu.tree_stack(rows)
def aux_train(*, model, segs, opt, mbsize, name2coef):
    """
    Run one epoch of auxiliary-phase training over ``segs``.

    For each minibatch from ``make_minibatches`` the loss is
    ``sum(name2coef.get(name, 1) * loss)`` over:
      * "pol_distance": mean KL(oldpd || pd) between the stored reference
        policy ``mb["oldpd"]`` and the current policy, and
      * every entry returned by ``model.compute_aux_loss(aux, mb)``.
    Each unscaled and scaled loss is logged with ``logger.logkv_mean``.

    Args:
        model: PhasicModel being trained.
        segs: list of dict segments containing "ob", "first", "state_in",
            "oldpd" and every key in ``model.aux_keys()``.
        opt: optimizer stepped once per minibatch.
        mbsize: minibatch size passed to ``make_minibatches``.
        name2coef: per-loss weights; losses not listed get weight 1.
            ``None`` is treated as ``{}``.

    Raises:
        ValueError: if ``name2coef`` names a loss that is not computed.
    """
    name2coef = name2coef or {}
    needed_keys = {"ob", "first", "state_in", "oldpd"}.union(model.aux_keys())
    # Keep only the fields the aux phase reads, so minibatching copies less.
    segs = [{k: seg[k] for k in needed_keys} for seg in segs]

    def _for_log(val):
        # Log detached values so the logger never holds references into the
        # autograd graph of this minibatch.
        return val.detach() if th.is_tensor(val) else val

    for mb in make_minibatches(segs, mbsize):
        mb = tree_map(lambda x: x.to(tu.dev()), mb)
        pd, _, aux, _state_out = model(mb["ob"], mb["first"], mb["state_in"])
        name2loss = {"pol_distance": td.kl_divergence(mb["oldpd"], pd).mean()}
        name2loss.update(model.compute_aux_loss(aux, mb))

        # Explicit check instead of `assert`, which is stripped under -O.
        unknown = set(name2coef) - set(name2loss)
        if unknown:
            raise ValueError(
                f"name2coef has coefficients for unknown losses {sorted(unknown)}; "
                f"computed losses are {sorted(name2loss)}"
            )

        loss = 0
        for name, unscaled_loss in name2loss.items():
            scaled_loss = unscaled_loss * name2coef.get(name, 1)
            logger.logkv_mean("unscaled/" + name, _for_log(unscaled_loss))
            logger.logkv_mean("scaled/" + name, _for_log(scaled_loss))
            loss += scaled_loss

        opt.zero_grad()
        loss.backward()
        # Average gradients across workers before stepping.
        tu.sync_grads(model.parameters())
        opt.step()
def compute_presleep_outputs(
    *, model, segs, mbsize, pdkey="oldpd", vpredkey="oldvpred"
):
    """
    Record the model's current outputs on every segment, in place.

    For each seg, ``model.forward`` is evaluated over seg["ob"], seg["first"]
    and seg["state_in"] in chunks of ``mbsize`` via ``tu.minibatched_call``.
    The policy distribution is written to ``seg[pdkey]`` and the value
    prediction to ``seg[vpredkey]``. ``aux_train`` reads ``seg["oldpd"]`` as
    the fixed reference policy for its "pol_distance" KL term.
    """

    def _to_dev(x):
        return x.to(tu.dev())

    # The recorded outputs are fixed targets, so no autograd graph is built
    # or kept alive alongside the stored segments.
    @tu.no_grad
    def forward(ob, first, state_in):
        # Move every input to the model device (as aux_train does for its
        # minibatches); tree_map also handles tree-structured observations.
        pd, vpred, _aux, _state_out = model.forward(
            tree_map(_to_dev, ob), _to_dev(first), tree_map(_to_dev, state_in)
        )
        return pd, vpred

    for seg in segs:
        seg[pdkey], seg[vpredkey] = tu.minibatched_call(
            forward, mbsize, ob=seg["ob"], first=seg["first"], state_in=seg["state_in"]
        )
def learn(
    *,
    model,
    venv,
    ppo_hps,
    aux_lr,
    aux_mbsize,
    n_aux_epochs=6,
    n_pi=32,
    kl_ewma_decay=None,
    interacts_total=float("inf"),
    name2coef=None,
    comm=None,
):
    """
    Train ``model`` with Phasic Policy Gradient.

    Repeats two phases until ``ppo_state["curr_interact_count"]`` reaches
    ``interacts_total``:
      1. Policy phase: ``ppo.learn`` is resumed from the previous state, with
         a callback asking it to stop once ``curr_iteration >= n_pi``. When
         the aux phase is enabled, it stores its segments in
         ``ppo_state["seg_buf"]``.
      2. Auxiliary phase (only when ``n_pi > 0`` and ``n_aux_epochs > 0``):
         record reference outputs on those segments with
         ``compute_presleep_outputs``, run ``n_aux_epochs`` epochs of
         ``aux_train`` with a dedicated Adam optimizer, then empty the buffer.

    Args:
        model: PhasicModel to train.
        venv: environment passed to ``ppo.learn``.
        ppo_hps: extra keyword arguments forwarded to ``ppo.learn``.
        aux_lr: learning rate of the aux-phase Adam optimizer.
        aux_mbsize: minibatch size for the aux phase.
        n_aux_epochs: aux epochs per aux phase; <= 0 disables the aux phase.
        n_pi: policy-phase iterations between aux phases; <= 0 means the
            policy phase is never interrupted (no aux phase).
        kl_ewma_decay: accepted for interface compatibility; not used here.
        interacts_total: total environment interactions before returning.
        name2coef: per-loss weights forwarded to ``aux_train``.
        comm: MPI communicator for ``ppo.learn``; defaults to MPI.COMM_WORLD.
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    name2coef = name2coef or {}
    aux_opt = th.optim.Adam(model.parameters(), lr=aux_lr)
    # One flag decides both whether segments are stored and whether they are
    # consumed; otherwise a buffer could fill up without ever being cleared.
    run_aux_phase = n_pi > 0 and n_aux_epochs > 0
    ppo_state = None

    while True:
        # Policy phase
        ppo_state = ppo.learn(
            venv=venv,
            model=model,
            learn_state=ppo_state,
            callbacks=[
                lambda _l: n_pi > 0 and _l["curr_iteration"] >= n_pi,
            ],
            interacts_total=interacts_total,
            store_segs=run_aux_phase,
            comm=comm,
            **ppo_hps,
        )
        if ppo_state["curr_interact_count"] >= interacts_total:
            break
        if not run_aux_phase:
            continue

        # Auxiliary phase
        segs = ppo_state["seg_buf"]
        compute_presleep_outputs(model=model, segs=segs, mbsize=aux_mbsize)
        for i in range(n_aux_epochs):
            logger.log(f"Aux epoch {i}")
            aux_train(
                model=model,
                segs=segs,
                opt=aux_opt,
                mbsize=aux_mbsize,
                name2coef=name2coef,
            )
            logger.dumpkvs()
        segs.clear()