
This issue was moved to a discussion.

Reacher Task and AutoResetWrapper #443

Closed
esraaelelimy opened this issue Jan 18, 2024 · 1 comment

Comments

@esraaelelimy

I am looking to use the Brax Reacher task as an alternative to MuJoCo Reacher for some RL tasks, but I have some concerns:
In the MuJoCo Reacher task, if the fingertip reaches the target, a new random target appears. Also, at the beginning of each new episode, the target position changes. In Brax, I see that the target position is only generated when the environment is reset. Moreover, when using the `AutoResetWrapper`, a reset fetches the cached `first_state`, which means that the random target is generated once at the very beginning and never changes. Does this make the Brax version of Reacher easier to solve than MuJoCo's Reacher? And how can we make the `AutoResetWrapper` actually change the target on every reset without sacrificing speed?
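A minimal pure-JAX sketch may make the caching behavior described above concrete. It does not use Brax: `ToyState`, `toy_reset`, `toy_step`, `cached_auto_reset_step` and `EPISODE_LENGTH` are hypothetical names, and selecting between the cached and current fields with `jnp.where` is only an approximation of what an auto-reset wrapper that caches the first state does. Because the target is drawn only inside `toy_reset`, every episode after the first starts from the same cached target.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical Reacher-like state (not Brax's State class).
    target: jax.Array  # 2-D goal position
    tip: jax.Array     # 2-D fingertip position
    t: jax.Array       # step counter within the episode
    done: jax.Array


EPISODE_LENGTH = 3  # hypothetical, kept tiny for illustration


def toy_reset(key: jax.Array) -> ToyState:
    # The target is drawn from the key, so it only changes when reset is
    # called again with a different key.
    target = jax.random.uniform(key, (2,), minval=-0.2, maxval=0.2)
    return ToyState(target, jnp.zeros(2), jnp.array(0), jnp.array(False))


def toy_step(state: ToyState, action: jax.Array) -> ToyState:
    t = state.t + 1
    return state._replace(tip=state.tip + action, t=t, done=t >= EPISODE_LENGTH)


@jax.jit
def cached_auto_reset_step(state: ToyState, first_state: ToyState, action):
    # Caching auto-reset: when the episode ends, swap the fields back to the
    # cached first_state instead of calling toy_reset with a new key.
    state = toy_step(state, action)
    pick = lambda first, cur: jnp.where(state.done, first, cur)
    return state._replace(target=pick(first_state.target, state.target),
                          tip=pick(first_state.tip, state.tip),
                          t=pick(first_state.t, state.t))


first_state = toy_reset(jax.random.PRNGKey(0))
state = first_state
episode_start_targets = [first_state.target]
for _ in range(3 * EPISODE_LENGTH):
    state = cached_auto_reset_step(state, first_state, jnp.full(2, 0.1))
    if bool(state.done):
        episode_start_targets.append(state.target)
```

Every entry of `episode_start_targets` equals the target sampled by the single call to `toy_reset`, even though the fingertip position is restored correctly at each episode boundary.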

@btaba
Collaborator

btaba commented Feb 8, 2024

Hi @esraaelelimy, indeed `AutoResetWrapper` caches `first_state`, but `first_state` is sampled with a different rng for each environment, so as the number of parallel environments goes up, the diversity of the sampled first states increases. You're right that this is done for performance reasons. To actually get new initial states (and therefore new targets), you have to call `reset` again with a different rng. See the example here:

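```python
import jax

# Excerpt from Brax's PPO training loop; names such as `training_epoch_with_timing`,
# `_unpmap`, `reset_fn`, `key_envs` and `num_resets_per_eval` are defined elsewhere there.
for _ in range(max(num_resets_per_eval, 1)):
  # optimization
  epoch_key, local_key = jax.random.split(local_key)
  epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
  (training_state, env_state, training_metrics) = (
      training_epoch_with_timing(training_state, env_state, epoch_keys)
  )
  current_step = int(_unpmap(training_state.env_steps))

  # key_envs holds one row of per-environment keys per local device: for each
  # device, split the first key of its row into a full row of fresh keys.
  key_envs = jax.vmap(
      lambda x, s: jax.random.split(x[0], s),
      in_axes=(0, None))(key_envs, key_envs.shape[1])
  # Re-run reset with the fresh keys so every environment samples a new first state.
  # TODO: move extra reset logic to the AutoResetWrapper.
  env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state
```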
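The excerpt depends on the rest of the training script, so here is a self-contained sketch of the two ideas in the reply, with a hypothetical `toy_reset_target` standing in for the real Reacher reset. Vmapping a reset over per-environment keys gives each environment its own first state (here, its own target), and re-splitting the keys exactly as in the excerpt and resetting again yields new targets. The device and environment counts are made up, and no `pmap` is involved.

```python
import jax
import jax.numpy as jnp


def toy_reset_target(key: jax.Array) -> jax.Array:
    # Hypothetical stand-in for the target sampling inside a Reacher reset.
    return jax.random.uniform(key, (2,), minval=-0.2, maxval=0.2)


num_devices, envs_per_device = 2, 4  # hypothetical, tiny sizes

# One row of per-environment keys per device, as in the excerpt.
key_envs = jax.random.split(jax.random.PRNGKey(0), num_devices * envs_per_device)
key_envs = jnp.reshape(key_envs, (num_devices, envs_per_device) + key_envs.shape[1:])

# vmap over devices, then over environments: each environment gets its own
# key, so the first states (targets) differ across environments.
reset_all = jax.jit(jax.vmap(jax.vmap(toy_reset_target)))
first_targets = reset_all(key_envs)

# Same re-splitting as in the excerpt, followed by an explicit re-reset.
key_envs = jax.vmap(
    lambda x, s: jax.random.split(x[0], s),
    in_axes=(0, None))(key_envs, key_envs.shape[1])
new_targets = reset_all(key_envs)
```

The more environments there are, the more distinct cached targets the training batch sees; the periodic re-reset is what replaces them with new ones.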

But we have not done an in-depth analysis of some of these hyperparameters (e.g. `num_resets_per_eval`).
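On the second question (changing the target on every reset without a full reset), one possible direction, which is not part of Brax as far as this thread shows, is to carry a per-environment key in the state and re-sample only the target when an episode ends, while still restoring everything else from the cached first state. The sketch below is single-environment and uses hypothetical names (`ToyState.rng`, `sample_target`, `resampling_auto_reset_step`); it rests on the assumption that drawing a target is much cheaper than a full reset. Under `jax.vmap`, each environment would carry its own key.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical state; `rng` is an extra per-environment key carried along.
    target: jax.Array
    tip: jax.Array
    t: jax.Array
    done: jax.Array
    rng: jax.Array


EPISODE_LENGTH = 3  # hypothetical


def sample_target(key: jax.Array) -> jax.Array:
    return jax.random.uniform(key, (2,), minval=-0.2, maxval=0.2)


def toy_reset(key: jax.Array) -> ToyState:
    key, target_key = jax.random.split(key)
    return ToyState(sample_target(target_key), jnp.zeros(2), jnp.array(0),
                    jnp.array(False), key)


def toy_step(state: ToyState, action: jax.Array) -> ToyState:
    t = state.t + 1
    return state._replace(tip=state.tip + action, t=t, done=t >= EPISODE_LENGTH)


@jax.jit
def resampling_auto_reset_step(state: ToyState, first_state: ToyState, action):
    state = toy_step(state, action)
    # Assumption: sampling a target is cheap compared with a full reset, so it
    # is computed every step and selected only where the episode ended.
    rng, target_key = jax.random.split(state.rng)
    pick = lambda new, cur: jnp.where(state.done, new, cur)
    return state._replace(
        target=pick(sample_target(target_key), state.target),
        tip=pick(first_state.tip, state.tip),  # other fields still come from the cache
        t=pick(first_state.t, state.t),
        rng=rng)


first_state = toy_reset(jax.random.PRNGKey(0))
state = first_state
episode_start_targets = [first_state.target]
for _ in range(3 * EPISODE_LENGTH):
    state = resampling_auto_reset_step(state, first_state, jnp.full(2, 0.1))
    if bool(state.done):
        episode_start_targets.append(state.target)
```

Unlike the caching version, each episode here starts from a different target, at the cost of one `jax.random.split` and one uniform draw per step.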

google locked and limited conversation to collaborators on Feb 8, 2024
btaba converted this issue into discussion #452 on Feb 8, 2024
