[Question] Are the jits around pmaps intended? #8

Closed
MaximilienLC opened this issue Feb 17, 2022 · 1 comment
@MaximilienLC
self._train_rollout_fn = jax.jit(jax.pmap(

The docs seem to say that wrapping a pmap in a jit is unnecessary, since pmap already compiles the function: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#pmap-and-jit
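A minimal sketch of what the linked docs describe: `jax.pmap` compiles the mapped function with XLA on its own, so it can be called directly without an outer `jax.jit`. The function `step` and the array sizes are hypothetical placeholders, not code from this repository.

```python
import jax
import jax.numpy as jnp

# Number of local devices pmap will map over (1 on a plain CPU install).
n_devices = jax.local_device_count()

def step(x):
    # Hypothetical per-device computation; pmap compiles it with XLA itself.
    return jnp.sin(x) * 2.0

# pmap alone: no outer jax.jit is required for compilation.
p_step = jax.pmap(step)

# The leading axis must equal the number of devices being mapped over.
x = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
y = p_step(x)
print(y.shape)  # -> (n_devices, 4), e.g. (1, 4) on a single-CPU install
```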

While running experiments, I also often get the following warning, which suggests that this pattern might be problematic:
UserWarning: The jitted function <unnamed function> includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See [https://github.com/google/jax/issues/2926].

If the behaviour is intended, please disregard.

@lerrytang
Contributor

Hi, thanks for raising this.
I'm aware of this warning, and the outer jit is indeed intended.
If you run experiments on multiple devices (e.g. this one, which by default uses Colab TPUs), you will notice the difference with and without this outer jit.
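To check this on your own hardware, a rough timing sketch like the one below can compare `pmap` alone against the `jit(pmap(...))` pattern from the issue. `rollout_step`, the loop length, and the array sizes are hypothetical stand-ins, not the repository's actual rollout; which variant is faster depends on the workload and the device setup, and on a single-device CPU install the two will likely perform about the same.

```python
import time

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def rollout_step(x):
    # Hypothetical stand-in for a per-device rollout: 32 sequential updates.
    return jax.lax.fori_loop(0, 32, lambda i, c: jnp.tanh(c * 1.01 + 0.1), x)

pmapped = jax.pmap(rollout_step)
jit_of_pmap = jax.jit(jax.pmap(rollout_step))  # the pattern discussed above

x = jnp.ones((n_devices, 1024), dtype=jnp.float32)

def bench(fn, x, iters=20):
    # First call triggers compilation; exclude it from the timing.
    fn(x).block_until_ready()
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(x)
    out.block_until_ready()  # wait for async dispatch to finish
    return (time.perf_counter() - start) / iters, out

t_pmap, out_pmap = bench(pmapped, x)
t_jit, out_jit = bench(jit_of_pmap, x)
print(f"pmap only: {t_pmap * 1e3:.3f} ms/call")
print(f"jit(pmap): {t_jit * 1e3:.3f} ms/call")
```

Both variants compute the same values; only the compilation boundary and data placement differ, so any gap shows up in the timings rather than in the outputs.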

lerrytang self-assigned this Feb 22, 2022