
fix support for multi-dim observations #22

Open
wants to merge 1 commit into base: main

Conversation

@danielgafni (Contributor) commented Mar 29, 2022

Hey! I found a bug in the observation normalization code. It occurs when the observations are not a flat array but a multi-dimensional one: the obs_normalizer parameters are stored as a flat array, so the code fails in this case. Here is the fix for this bug.
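
To illustrate what goes wrong (just a minimal sketch with made-up array names, not the actual evojax code): the normalizer keeps its statistics as a flat vector, so broadcasting them against a multi-dimensional observation fails unless one side is reshaped first.

import numpy as np

obs = np.zeros((28, 28, 1))       # multi-dim observation from the task
running_mean = np.zeros(784)      # normalizer stats stored as a flat array
running_std = np.ones(784)

# This broadcast fails: shapes (28, 28, 1) and (784,) are incompatible.
# normalized = (obs - running_mean) / running_std

# One way to fix it: reshape the flat stats to the observation shape before use.
normalized = (obs - running_mean.reshape(obs.shape)) / running_std.reshape(obs.shape)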

@lerrytang self-assigned this Mar 31, 2022
@lerrytang (Contributor)

Thanks for pointing out the bug!
I ran a test on train_mnist.py with your PR (with obs_normalization=True set in the trainer); however, I still got an error.
Did you not have problems in your tests?

@danielgafni (Contributor, Author)

Hmm, no, the code works correctly in my project. I didn't test it with the examples, though. I'll do that and fix the errors.

@danielgafni (Contributor, Author) commented Apr 4, 2022

OK, so two more dimensions are present in the obs_buffer in the MNIST example:

>>> obs_buffer.shape
(1, 64, 1024, 28, 28, 1)

Their meaning is:
[?, pop_size, batch_size, height, width, channels]

While

>>> running_mean.shape
(784,)

My code handles the last dimensions; they are expected. But the first two are causing the error.

  1. I don't know where the very first (1) dimension comes from. Is this the number of agents? ...
  2. It looks like the third dimension, the batch_size, is causing another error. The MNIST task has an obs_size of (28, 28, 1), but the actual obs_buffer also has the batch_size dimension introduced by the sample_batch function. The ObsNormalizer doesn't know anything about the batch size. It seems like it needs another argument, something like reduce_dims, where we would specify our custom batch_size dimension (see the sketch below). Maybe you can suggest another fix? I hope I explained the problem clearly enough.
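
Roughly what I mean by reduce_dims (a hypothetical sketch, not existing evojax API; the function and argument names are made up): collapse all the leading batch-like axes before computing per-feature statistics, so they come out with the same flat shape as running_mean.

import numpy as np

def batch_stats(obs_buffer, obs_size):
    # Flatten everything in front of the observation shape into a single batch axis.
    obs_dim = int(np.prod(obs_size))          # e.g. 28 * 28 * 1 = 784
    flat = obs_buffer.reshape(-1, obs_dim)    # (1 * 64 * 1024, 784) for the MNIST buffer
    return flat.mean(axis=0), flat.var(axis=0)  # both (784,), matching running_mean

mean, var = batch_stats(np.zeros((1, 4, 16, 28, 28, 1)), (28, 28, 1))  # smaller buffer for the demo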

@danielgafni (Contributor, Author)

Hey @lerrytang! Any update on this issue?

@danielgafni (Contributor, Author)

For example, take a look at the brax implementation:

https://github.com/google/brax/blob/main/brax/training/normalization.py

They have a num_leading_batch_dims parameter for the normalizer. It seems like evojax could do the same?
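
Something along these lines (again only a sketch, not brax's actual code; the names are illustrative): the normalizer is told how many leading axes are batch axes and reduces over exactly those.

import numpy as np

def batch_stats(obs_buffer, num_leading_batch_dims):
    # Reduce over the leading batch axes only, keeping per-feature statistics.
    axes = tuple(range(num_leading_batch_dims))
    return obs_buffer.mean(axis=axes), obs_buffer.var(axis=axes)

# For the MNIST buffer of shape (1, 64, 1024, 28, 28, 1) and num_leading_batch_dims=3,
# mean and var both have shape (28, 28, 1).
mean, var = batch_stats(np.zeros((1, 4, 16, 28, 28, 1)), num_leading_batch_dims=3)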
