Running different environments in parallel #143

Closed
jbgaya opened this issue Jan 3, 2022 · 2 comments

Comments

@jbgaya

jbgaya commented Jan 3, 2022

Hi, and thanks for this great simulator!
I would like to know whether it is possible to run multiple environments in parallel (for example, Halfcheetah with different gravity coefficients). From what I understand, this is quite difficult because the brax.System() of an environment is fixed and hard to change.

Ideally, the step() function of an environment would take not only the state and action, but also a system as an input. But I don't think it is feasible to vectorize that object in JAX, right?

Any ideas?
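A minimal sketch of the interface described above, using a toy 1-D point mass instead of Brax. The names `Params`, `State` and `step` are hypothetical: `Params` stands in for a `brax.System`, and `step` takes it as an input alongside the state and action. Because a NamedTuple of arrays is a pytree, `jax.vmap` can map over a batch of parameter sets that differ only in gravity.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    # Hypothetical stand-in for a brax.System: per-environment physics constants.
    gravity: jnp.ndarray
    dt: jnp.ndarray


class State(NamedTuple):
    # Toy 1-D point mass: height and vertical velocity.
    pos: jnp.ndarray
    vel: jnp.ndarray


def step(params: Params, state: State, action: jnp.ndarray) -> State:
    """Semi-implicit Euler step; the action is an extra vertical acceleration."""
    vel = state.vel + (params.gravity + action) * params.dt
    pos = state.pos + vel * params.dt
    return State(pos=pos, vel=vel)


# Two environments that differ only in gravity (leading axis = environment).
params = Params(gravity=jnp.array([-9.81, -1.62]), dt=jnp.array([0.01, 0.01]))
state = State(pos=jnp.zeros(2), vel=jnp.zeros(2))
action = jnp.zeros(2)

# Map over the leading axis of params, state and action simultaneously.
batched_step = jax.jit(jax.vmap(step, in_axes=(0, 0, 0)))
next_state = batched_step(params, state, action)
print(next_state.vel)  # one velocity per environment
```

The system-like argument is just another batched input here; whether a real `brax.System` can be batched the same way depends on how its fields are allocated.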

@erikfrey
Collaborator

erikfrey commented Jan 6, 2022

Hi @jbgaya. System is a pytree:

https://github.com/google/brax/blob/main/brax/physics/system.py#L33

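So in theory you could vmap over multiple systems. For example, this code runs:

```python
import copy
import jax
import jax.numpy as jnp
from brax import envs
sys1 = envs.create('ant').sys
sys2 = copy.deepcopy(sys1)  # deep copy, so editing sys2 cannot alter sys1
sys2.integrator.dt = 0.002
sys = jax.tree_util.tree_map(lambda x, y: jnp.stack([x, y]), sys1, sys2)  # stack leaves
```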
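The stacking line above works on any pytree. Here is a self-contained sketch of the same pattern on a toy pytree (a hypothetical `Config` NamedTuple, not a real `brax.System`), plus the inverse operation that slices one environment back out. The helper names `stack_trees` and `unstack_tree` are illustrative only. Stacking requires every instance to have the same tree structure and matching leaf shapes.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Config(NamedTuple):
    # Hypothetical per-environment configuration (stand-in for brax.System).
    dt: jnp.ndarray
    gravity: jnp.ndarray  # 3-vector


def stack_trees(trees):
    """Stack a list of same-structured pytrees along a new leading axis."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


def unstack_tree(tree, i):
    """Select environment i from a stacked pytree."""
    return jax.tree_util.tree_map(lambda x: x[i], tree)


cfg1 = Config(dt=jnp.array(0.01), gravity=jnp.array([0.0, 0.0, -9.81]))
# NamedTuples are immutable; _replace returns a modified copy.
cfg2 = cfg1._replace(dt=jnp.array(0.002))

stacked = stack_trees([cfg1, cfg2])
print(stacked.dt.shape, stacked.gravity.shape)  # (2,) (2, 3)
second = unstack_tree(stacked, 1)
```

Using an immutable container with `_replace` avoids the aliasing pitfalls of mutating a shallow copy.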

But it won't work out of the box, because default_qp won't produce the right results; we'd have to change the way System allocates its internal fields.

That said, JAX executes asynchronously, so have you tried simply running the step() functions of several different brax.System() instances serially and then stacking the results? I would expect that to work OK, although it may be slow to JIT if you're dealing with hundreds or thousands of Systems.
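A minimal sketch of this serial approach, again with a toy point-mass step rather than Brax. The factory `make_step` is hypothetical: each call returns a separately jitted step function with its constants closed over, mirroring one step function per `brax.System()` instance. The functions are called one after another and the per-environment results are stacked at the end. Because JAX dispatches work asynchronously, the Python loop mostly enqueues computations; the main cost is one compilation per distinct function, which is where the JIT overhead for many Systems comes from.

```python
import jax
import jax.numpy as jnp


def make_step(gravity: float, dt: float):
    """Hypothetical factory: a jitted step with constants baked in,
    analogous to env.step for one fixed brax.System."""

    @jax.jit
    def step(pos, vel, action):
        vel = vel + (gravity + action) * dt
        pos = pos + vel * dt
        return pos, vel

    return step


gravities = [-9.81, -3.71, -1.62]
steps = [make_step(g, 0.01) for g in gravities]  # compiled separately on first call

pos = [jnp.zeros(()) for _ in steps]
vel = [jnp.zeros(()) for _ in steps]
action = jnp.zeros(())

for _ in range(100):
    # Serial calls; each may return before its computation has finished,
    # since JAX dispatches work asynchronously.
    results = [f(p, v, action) for f, p, v in zip(steps, pos, vel)]
    pos = [r[0] for r in results]
    vel = [r[1] for r in results]

# Stack the per-environment results into batched arrays.
all_pos = jnp.stack(pos)
all_vel = jnp.stack(vel)
```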

@jbgaya
Author

jbgaya commented Jan 10, 2022

> That said, jax executes asynchronously - so have you tried just running multiple step() functions from multiple different brax.System() instances serially and then stacking the results? I would expect that to work OK, although it may be slow to JIT if you're doing hundreds or thousands of Systems.

Indeed, this is the easiest option since acquisition is fast. Will try. Thanks!
