PyTorch PPO implementation, dimension difference #384
Hi, Kevin. I regularly use that same algorithm, the PyTorch PPO, and have never had that problem. What error are you getting exactly? And on which environment(s)?
Best,
Alberto

On Wed, Feb 22, 2023, at 21:25, kevin ***@***.***> wrote:
Hello, apologies if I do this wrong; I don't contribute to open source often. I was attempting to run the PyTorch PPO implementation and kept getting several errors regarding the dimension of the observation 'o'. I believe this is because the return values from the environment are being mishandled. See my proposed changes below:
Original:
o, ep_ret, ep_len = env.reset(), 0, 0

# Main loop: collect experience in env and update/log each epoch
for epoch in range(epochs):
    for t in range(local_steps_per_epoch):
        a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))

        next_o, r, d, _ = env.step(a)
        ep_ret += r
        ep_len += 1

        # save and log
        buf.store(o, a, r, v, logp)
        logger.store(VVals=v)

        # Update obs (critical!)
        o = next_o

        timeout = ep_len == max_ep_len
        terminal = d or timeout
        epoch_ended = t==local_steps_per_epoch-1

        if terminal or epoch_ended:
            if epoch_ended and not(terminal):
                print('Warning: trajectory cut off by epoch at %d steps.'%ep_len, flush=True)
            # if trajectory didn't reach terminal state, bootstrap value target
            if timeout or epoch_ended:
                _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
            else:
                v = 0
            buf.finish_path(v)
            if terminal:
                # only save EpRet / EpLen if trajectory finished
                logger.store(EpRet=ep_ret, EpLen=ep_len)
            o, ep_ret, ep_len = env.reset(), 0, 0
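The dimension error described above is consistent with the newer Gym core API, where reset() returns an (obs, info) tuple rather than the bare observation. A minimal sketch of the failure mode, assuming Gym >= 0.26 and using CartPole-v1 purely for illustration:

import gym
import torch

env = gym.make("CartPole-v1")   # illustrative environment
o = env.reset()                 # Gym >= 0.26: o is a tuple (obs, info)

# torch.as_tensor cannot build a float tensor from an (ndarray, dict)
# tuple, so the first use of 'o' in the loop raises an error:
torch.as_tensor(o, dtype=torch.float32)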
To fix this I just get rid of the 'info' return value from env.reset(), but I imagine this might not work for all environments (for example, if they also have extra return values from env.step()).
Adjusted:
(o, _), ep_ret, ep_len = env.reset(), 0, 0

# Main loop: collect experience in env and update/log each epoch
for epoch in range(epochs):
    for t in range(local_steps_per_epoch):
        a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))

        next_o, r, d, _ = env.step(a)
        print(f"next_o: {next_o}")
        print(f"r: {r}")
        print(f"d: {d}")
        print(f"_: {_}")
        ep_ret += r
        ep_len += 1

        # save and log
        buf.store(o, a, r, v, logp)
        logger.store(VVals=v)

        # Update obs (critical!)
        o = next_o

        timeout = ep_len == max_ep_len
        terminal = d or timeout
        epoch_ended = t==local_steps_per_epoch-1

        if terminal or epoch_ended:
            if epoch_ended and not(terminal):
                print('Warning: trajectory cut off by epoch at %d steps.'%ep_len, flush=True)
            # if trajectory didn't reach terminal state, bootstrap value target
            if timeout or epoch_ended:
                _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
            else:
                v = 0
            buf.finish_path(v)
            if terminal:
                # only save EpRet / EpLen if trajectory finished
                logger.store(EpRet=ep_ret, EpLen=ep_len)
            (o, _), ep_ret, ep_len = env.reset(), 0, 0
Let me know if anyone else has run into this problem or if this fix works.
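One way to address the concern above about environments with extra return values is to normalize both signatures before the training loop touches them. This is a sketch only; the reset_env and step_env helpers are hypothetical and not part of the Spinning Up code:

def reset_env(env):
    # Normalize reset() across Gym versions: return only the observation.
    out = env.reset()
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
        return out[0]   # new API (Gym >= 0.26): (obs, info)
    return out          # old API: obs

def step_env(env, a):
    # Normalize step() across Gym versions to (obs, reward, done, info).
    out = env.step(a)
    if len(out) == 5:
        # New API: (obs, reward, terminated, truncated, info)
        next_o, r, terminated, truncated, info = out
        return next_o, r, terminated or truncated, info
    return out          # old API: (obs, reward, done, info)

The loop above would then call o = reset_env(env) and next_o, r, d, _ = step_env(env, a) with no other changes.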
Hello,
My reset and step functions are defined as follows:
I am a little confused how this could work for other people, when it seems like all the environments should return the same tuple (https://www.gymlibrary.dev/api/core/#gym.Env.reset).
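The confusion may come down to which Gym version is installed, since the core API changed in Gym 0.26. A side-by-side sketch of the two signatures (the linked gymlibrary.dev page documents the newer one):

# Old core API (Gym < 0.26):
o = env.reset()                           # observation only
next_o, r, d, info = env.step(a)          # 4-tuple

# New core API (Gym >= 0.26 / Gymnasium, as at the link above):
o, info = env.reset()                     # (obs, info)
next_o, r, terminated, truncated, info = env.step(a)  # 5-tuple
d = terminated or truncated               # 'done' is now split in two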