Initial support for Jax env #2955

Closed
wants to merge 29 commits into from

@pseudo-rnd-thoughts pseudo-rnd-thoughts commented Jul 7, 2022

In anticipation of adding Brax, this PR adds a new type of Env, JaxEnv, along with JaxState, following the functional API suggested in #2954.

This PR introduces three new classes:

  • JaxState - The dataclass object that contains all of the information the environment needs to take a step. The state is created by the reset function.
  • JaxEnv - A subclass of Env that takes functions for reset and step. Because these are plain functions, they can be JIT-compiled for hardware acceleration by default and are easier to vectorise.
  • VectoriseJaxEnv - Takes a JaxEnv, which contains the stateless reset and step functions, and applies jax.vmap or jax.pmap to them to parallelise reset and step. This implementation does not use the current VectorEnv because of the number of unnecessary components; however, this can be changed.
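
To illustrate the pattern the three classes rely on, here is a minimal sketch of stateless reset and step functions that are JIT-compiled and then vectorised with jax.vmap. It is not the code in this PR: `CounterState`, `reset`, `step`, the toy dynamics, and the return tuples are all hypothetical stand-ins chosen for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CounterState(NamedTuple):
    """Hypothetical stand-in for JaxState: everything step() needs."""

    position: jnp.ndarray
    steps: jnp.ndarray


def reset(key):
    """Stateless reset: build the initial state from a PRNG key."""
    position = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    state = CounterState(position=position, steps=jnp.zeros((), dtype=jnp.int32))
    return state, state.position  # (state, observation)


def step(state, action):
    """Stateless step: return a new state instead of mutating the old one."""
    position = jnp.clip(state.position + action, -1.0, 1.0)
    steps = state.steps + 1
    new_state = CounterState(position=position, steps=steps)
    reward = -jnp.abs(position)  # toy reward, assumed for illustration
    done = steps >= 10  # toy episode length, assumed for illustration
    return new_state, position, reward, done


# Single environment: the pure functions can be JIT-compiled directly.
jit_reset = jax.jit(reset)
jit_step = jax.jit(step)

# Vectorised: vmap maps the same functions over a leading batch axis.
num_envs = 4
vec_reset = jax.jit(jax.vmap(reset))
vec_step = jax.jit(jax.vmap(step))

keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
states, obs = vec_reset(keys)
actions = jnp.full((num_envs,), 0.1)
states, obs, rewards, dones = vec_step(states, actions)
```

Because the state is passed in and returned explicitly, the same two functions serve both the single and the batched case; swapping jax.vmap for jax.pmap would spread the batch across devices instead.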

… to fix bug if environment doesn't use np_random in reset
…n the opposite case than was intended to (openai#2871)"

This reverts commit 519dfd9.