Enable opt-in autodetection of distributed configuration for mpi4py #20174
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Sorry for the late review.
Thanks a lot for the contribution. I left some comments.
jax/_src/clusters/cluster.py
Outdated
# First, we check the spec detection method because it will ignore submitted values
# if it succeeds.
if spec_detection_method == "mpi4py":
Could we add a "name" attribute under ClusterEnv and set it in the derived classes: aka ompi, slurm, mpi4py, gketpu, multislicegcetpu, singleslicegcetpu
and also adding a new bool "opt_in_only_method" and have here:
if spec_detection_method is not None:
    env = next((env for env in cls._cluster_types if env.name == spec_detection_method), None)
    if env is None:
        # print fatal error that method {spec_detection_method} is not supported.
    if not env.is_env_present():
        # print fatal error requested env is not present.
else:
    env = next((env for env in cls._cluster_types if env.opt_in_only_method == False and env.is_env_present()), None)
Sure, good idea, this makes it more flexible to add other opt-in methods in the future. I've implemented this with the names you suggested; every type except mpi4py has opt_in_only_method = False as a class attribute.
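For illustration, a minimal sketch of what those class attributes could look like on the derived classes (illustrative only: the import path and class names are assumptions based on the files discussed in this review, not the final PR code):

from jax._src.clusters.cluster import ClusterEnv  # import path assumed from the diff above

class Mpi4pyCluster(ClusterEnv):
    # Opt-in only: never auto-selected, used only when
    # spec_detection_method == "mpi4py" is passed explicitly.
    name: str = "mpi4py"
    opt_in_only_method: bool = True

class SlurmCluster(ClusterEnv):
    # Still eligible for automatic detection.
    name: str = "slurm"
    opt_in_only_method: bool = False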
jax/_src/clusters/mpi4py_cluster.py
Outdated
from mpi4py import MPI
COMM_WORLD = MPI.COMM_WORLD

hostname = socket.gethostname()
Could you use COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED) instead of doing the hostname query?
In principle, that seems possible. I am trying to imagine scenarios on some clusters where the set of processes that can access the same shared memory space might be over- or under-restrictive compared to matching on hostname... but at worst it seems like it might lead to NUMA-induced slowdowns, not anything incorrect.
For example, on one system I am thinking of, the hosts have two CPUs with two sockets each, attached to 6 GPUs, and each CPU has its own set of shared memory. Using shared-memory domains here might be ambiguous, depending on the job launcher, and it's not a comm-split method I see used much.
If you prefer COMM_TYPE_SHARED I won't object, though I'm not sure what the downside of a hostname query is.
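For reference, a minimal mpi4py sketch of the hostname-matching approach being discussed (illustrative only; variable names are mine, not the PR's):

import socket
from mpi4py import MPI

COMM_WORLD = MPI.COMM_WORLD
hostname = socket.gethostname()

# Gather every rank's hostname, then derive local rank/size by matching on it.
all_hostnames = COMM_WORLD.allgather(hostname)
local_rank = sum(1 for h in all_hostnames[:COMM_WORLD.Get_rank()] if h == hostname)
local_process_count = all_hostnames.count(hostname)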
Even in multi-NUMA systems, it works fine since the processes can still communicate via shared-memory.
I am pretty confident that MPI.COMM_TYPE_SHARED works well for this use case.
We have been using it in Horovod (distributed layer that can be used on top of PyTorch/TF/MXNet) for years with a wide variety of users:
https://github.com/horovod/horovod/blob/master/horovod/common/mpi/mpi_context.cc#L126
I would prefer the split over implementing it ourselves by broadcasting the hostname, since MPI vendors may be able to optimize this specific comm split under the covers.
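A minimal mpi4py sketch of the split-based alternative (illustrative only; variable names are mine, not the PR's):

from mpi4py import MPI

COMM_WORLD = MPI.COMM_WORLD

# One sub-communicator per shared-memory domain, typically one per node.
local_comm = COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
local_rank = local_comm.Get_rank()
local_process_count = local_comm.Get_size()
local_comm.Free()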
Thanks Corey for following up.
I left a few comments.
Do you mind re-testing the final code on your cluster and adding a unit test similar to the ompi one (https://github.com/google/jax/blob/main/tests/multiprocess_gpu_test.py#L184)?
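Roughly along these lines, as a sketch only (the test name, launcher flags, and timeout are assumptions, not the actual JAX test):

import subprocess

def test_distributed_init_via_mpi4py(num_procs: int = 2):
    # Launch a tiny JAX program under mpirun and check that the distributed
    # configuration is inferred from mpi4py via the opt-in detection method.
    worker = (
        "import jax; "
        "jax.distributed.initialize(spec_detection_method='mpi4py'); "
        "print(jax.process_index(), jax.process_count())"
    )
    result = subprocess.run(
        ["mpirun", "-n", str(num_procs), "python", "-c", worker],
        capture_output=True, text=True, timeout=300,
    )
    assert result.returncode == 0, result.stderr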
jax/_src/clusters/cluster.py
Outdated
@@ -43,7 +44,11 @@ def auto_detect_unset_distributed_params(cls,
num_processes: int | None,
process_id: int | None,
local_device_ids: Sequence[int] | None,
<<<<<<< HEAD
some merge conflict left over here?
jax/_src/distributed.py
Outdated
@@ -215,7 +225,11 @@ def initialize(coordinator_address: str | None = None,
raise RuntimeError("jax.distributed.initialize() must be called before "
                   "any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id,
<<<<<<< HEAD
same
Hey, I'm back, sorry for the delay. It's been a busy time here - I'm about to transition jobs and I'm wrapping up some open things. I've made the changes requested and added the test. I'll point out that on my system, mpirun does not accept the arguments '--oversubscribe' and '--allow-run-as-root', so I ran the test without them and it passes. I left them in the final code, however - I'm not sure what test system you have that allows launching mpirun jobs as root, but it sure isn't a DOE supercomputer :). Let me know if you want any more changes. I'll be able to test things for about 1 month more.
The requested changes to comments and documentation are done - thanks for the feedback :)
@coreyjadams You would need to sign the CLA too: https://cla.developers.google.com/
@coreyjadams could you update the year at the top of your new file and also add it to https://github.com/google/jax/blob/98b87540a76bc3b4f9f414bb763304f1cc474544/jax/BUILD#L953 so that it gets properly packaged.
I did the
OK, I see the green check mark now, just had to link my work email to github. Thanks!
I didn't look at the code, but this looks useful.
Can you update the title to make it clearer what this does? Something like:
Enable opt-in autodetection of distributed configuration for mpi4py
This mostly looks good to me; one API naming suggestion.
Please fix the lint errors and squash your commits.
The lint errors are fixed locally - just waiting on the CI to confirm, then I'll deal with squashing the commits. Thanks for the feedback.
I made an inadvertent force push here, so I created a branch on the commit just before that: #22235. My suggestion: if you can force the squash, please do so, and then we can just close this one. I appreciate working with you on this, and sorry my git skills aren't up to par on this one :). Update: @felker below was able to salvage the git snafu I created, I think. Should be OK to merge?
commit 79b8cbf (Corey Adams <corey.adams@anl.gov>, Mon Jul 1 14:14:15 2024 -0500): Fix mypy issues; change variable name to more universally known name
commit 10edc86 (Corey Adams <corey.adams@anl.gov>, Thu Jun 27 13:25:32 2024 -0500): Change copyright year to the year this was authored
commit f7086cb (Corey Adams <corey.adams@anl.gov>, Thu Jun 27 13:15:32 2024 -0500): Update build file to include mpi4py cluster.
commit 6235eb3 (Corey adams <coreyjadams@gmail.com>, Thu Jun 27 12:11:48 2024 -0500): Update distributed.py: Clean up documentation slightly.
commit ef3a2e2 (Corey adams <coreyjadams@gmail.com>, Thu Jun 27 12:09:37 2024 -0500): Update mpi4py_cluster.py: Further clean up unneeded comments.
commit 6cc07a9 (Corey adams <coreyjadams@gmail.com>, Thu Jun 27 12:08:38 2024 -0500): Update mpi4py_cluster.py: Remove unneeded commented code.
commit 6701bd1 (merge of 5a91ac3 and 98b8754; Corey adams <coreyjadams@gmail.com>, Thu Jun 27 12:07:25 2024 -0500): Merge branch 'google:main' into main
commit 5a91ac3 (merge of 301bbc6 and 6c51234; Corey adams <coreyjadams@gmail.com>, Tue May 28 22:14:08 2024 -0500): Merge branch 'google:main' into main
commit 301bbc6 (Corey Adams <corey.adams@anl.gov>, Tue May 28 11:34:51 2024 -0500): Add test to verify mpi4py based distributed initialization
commit 19e6694 (Corey Adams <corey.adams@anl.gov>, Tue May 28 11:14:40 2024 -0500): Unify variable naming and fix function argument ordering
commit 72fe093 (Corey Adams <corey.adams@anl.gov>, Tue May 28 10:56:25 2024 -0500): Remove unmerged code
commit 3a96e73 (merge of e4fd97e and ff3db9b; Corey adams <coreyjadams@gmail.com>, Tue May 28 10:51:41 2024 -0500): Merge branch 'google:main' into main
commit e4fd97e (merge of a697299 and 72a81e5; Corey adams <coreyjadams@gmail.com>, Mon May 13 16:01:35 2024 -0500): Merge branch 'google:main' into main
commit a697299 (merge of 85bcf42 and 1e48adc; Corey adams <coreyjadams@gmail.com>, Mon May 13 14:21:32 2024 -0500): Merge branch 'google:main' into main
commit 85bcf42 (merge of af1a4f0 and 06cd05d; Corey Adams <corey.adams@anl.gov>, Tue Apr 16 09:09:31 2024 -0500): Merge branch 'main' of https://github.com/google/jax
commit af1a4f0 (Corey Adams <corey.adams@anl.gov>, Tue Apr 16 08:58:33 2024 -0500): update documentation and elaborate on spec_detect_method variable
commit 01f4709 (Corey Adams <corey.adams@anl.gov>, Tue Apr 16 08:45:38 2024 -0500): Address feedback and comments on PR 20174; fix typo in documentation.
commit 4f22d86 (merge of 900a037 and 71ec6e3; Corey adams <coreyjadams@gmail.com>, Mon Mar 11 11:51:30 2024 -0500): Merge branch 'google:main' into main
commit 900a037 (Corey Adams <corey.adams@anl.gov>, Mon Mar 11 11:50:48 2024 -0500): Auto-detect of mpi4py-based configuration is now strictly opt-in.
commit 1992969 (Corey Adams <corey.adams@anl.gov>, Thu Mar 7 12:27:43 2024 -0600): Enable automatic detection of distributed variables with any configuration of MPI, as long as mpi4py is available
This PR is in response to the discussion on #19409.
It does the following:
- The arguments to `jax.distributed.initialize` can be entirely inferred from `mpi4py`, provided the job is launched in a way compatible with `MPI`.
- Since not every cluster uses `MPI`, this autodetect method is exclusively opt-in: users must pass `spec_detection_method="mpi4py"` in a call to `jax.distributed.initialize`.
- Any arguments submitted to `jax.distributed.initialize` for coordinator_address, etc., will override any auto-detected settings from `mpi4py`.
- Since this changes the interface of `jax.distributed.initialize`, I have updated the documentation accordingly.
- When calling `jax.distributed.initialize` on HPC systems, there is sometimes a hang (see "Usage of `jax.distributed.initialize` on HPC clusters", #9582). I have included a warning if any of the suspect variables are detected, and updated the documentation to point out that the user may want to unset these variables if they are on an HPC cluster. I suspect the warning will be viewed as too noisy, since I don't know the default log level in JAX off the top of my head; in that case, perhaps it can be emitted only if the timeout occurs, or at the very least kept in the documentation.

I hope this is helpful, and it would simplify our lives using JAX on supercomputers :).
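For illustration, a minimal usage sketch of the opt-in path described above (a sketch, assuming mpi4py is installed and the job is launched with a compatible MPI launcher such as mpirun or srun):

import jax

# The coordinator address, process count, and process id are inferred from
# mpi4py; any explicitly passed arguments would still take precedence.
jax.distributed.initialize(spec_detection_method="mpi4py")
print(f"process {jax.process_index()} of {jax.process_count()}")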
Corey