
Enable opt-in autodetection of distributed configuration for mpi4py #20174

Merged
merged 1 commit into from
Jul 8, 2024

Conversation

coreyjadams

This PR is in response to the discussion on #19409.

It does the following:

  • First, it adds an additional cluster environment for jax.distributed that autodetects the rank, size, local_rank/devices, and coordinator address from mpi4py. The motivation is that on clusters with mpi4py, the parameters needed for jax.distributed.initialize can be inferred entirely from mpi4py, provided the job is launched in a way compatible with MPI.
  • Because of that constraint (compatibility with MPI), this autodetect method is exclusively opt-in: users must pass spec_detection_method="mpi4py" in the call to jax.distributed.initialize.
  • Consistent with the behavior of all other initialization methods, options passed to jax.distributed.initialize for coordinator_address, etc., will override any auto-detected settings from mpi4py.
  • Because of the new option in the arguments to jax.distributed.initialize, I have updated the documentation accordingly.
  • Lastly, related to using jax.distributed.initialize on HPC systems: there is sometimes a hang (see Usage of `jax.distributed.initialize` on HPC clusters #9582). I have included a warning if any of the suspect variables are detected, and updated the documentation to point out that users on an HPC cluster may want to unset these variables. I suspect the warning will be viewed as too noisy, since I don't know JAX's default log level off the top of my head. In that case, perhaps it could be emitted only when the timeout occurs, or at the very least kept in the documentation.
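
To illustrate the idea (this is a sketch, not the PR's actual code), the mapping from mpi4py quantities to the initialize parameters can be shown with a stand-in communicator; `FakeComm` and the port number are assumptions so the sketch runs without an MPI launcher:

```python
import socket

class FakeComm:
    """Stand-in for mpi4py's MPI.COMM_WORLD (hypothetical, for illustration)."""
    def __init__(self, rank, size):
        self._rank, self._size = rank, size
    def Get_rank(self):
        return self._rank
    def Get_size(self):
        return self._size
    def bcast(self, obj, root=0):
        # A real bcast ships `obj` from `root` to every rank; with a single
        # fake process we just return it unchanged.
        return obj

def infer_distributed_params(comm, port=12345):
    # Rank 0 advertises its hostname; every rank learns it via bcast.
    coordinator = comm.bcast(socket.gethostname(), root=0)
    return {
        "coordinator_address": f"{coordinator}:{port}",
        "num_processes": comm.Get_size(),
        "process_id": comm.Get_rank(),
    }

params = infer_distributed_params(FakeComm(rank=1, size=4))
# These are exactly the kwargs jax.distributed.initialize() needs.
```

With a real `MPI.COMM_WORLD` in place of `FakeComm`, every process gets a consistent coordinator address and its own process_id without any user input.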

I hope this is helpful; it would simplify our lives using JAX on supercomputers :).

Corey


google-cla bot commented Mar 11, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Collaborator

@nvcastet nvcastet left a comment

Sorry for the late review.
Thanks a lot for the contribution. I left some comments.


# First, we check the spec detection method, because it will ignore submitted values
# if it succeeds.
if spec_detection_method == "mpi4py":
Collaborator

Could we add a "name" attribute under ClusterEnv and set it in the derived classes (aka ompi, slurm, mpi4py, gketpu, multislicegcetpu, singleslicegcetpu), also add a new bool "opt_in_only_method", and have here:

if spec_detection_method is not None:
  env = next((env for env in cls._cluster_types if env.name == spec_detection_method), None)
  if env is None:
    ...  # fatal error: method {spec_detection_method} is not supported
  elif not env.is_env_present():
    ...  # fatal error: the requested env is not present
else:
  env = next((env for env in cls._cluster_types
              if not env.opt_in_only_method and env.is_env_present()), None)
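
A runnable sketch of that selection logic, with dummy environments standing in for the real ClusterEnv subclasses (the class names and module-level `_cluster_types` list here are assumptions for illustration):

```python
class ClusterEnv:
    name = "base"
    opt_in_only_method = False
    @classmethod
    def is_env_present(cls):
        return False

class SlurmCluster(ClusterEnv):
    name = "slurm"  # auto-detected whenever its environment is present

class Mpi4pyCluster(ClusterEnv):
    name = "mpi4py"
    opt_in_only_method = True  # never auto-selected; must be requested by name
    @classmethod
    def is_env_present(cls):
        return True

_cluster_types = [SlurmCluster, Mpi4pyCluster]

def pick_env(spec_detection_method=None):
    if spec_detection_method is not None:
        env = next((e for e in _cluster_types
                    if e.name == spec_detection_method), None)
        if env is None:
            raise ValueError(f"method {spec_detection_method!r} is not supported")
        if not env.is_env_present():
            raise RuntimeError("requested env is not present")
        return env
    # Auto-detection skips opt-in-only environments.
    return next((e for e in _cluster_types
                 if not e.opt_in_only_method and e.is_env_present()), None)
```

Here `pick_env("mpi4py")` returns `Mpi4pyCluster`, while `pick_env()` never considers it because of the `opt_in_only_method` flag.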

Author

Sure, good idea; this makes it more flexible to add other opt-in methods in the future. I've implemented this with the names you suggested: every type except mpi4py has opt_in_only_method = False as a class attribute, and mpi4py sets it to True.

import socket
from mpi4py import MPI
COMM_WORLD = MPI.COMM_WORLD

hostname = socket.gethostname()
Collaborator

Could you use COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED) instead of doing the hostname query?

Author

In principle, that seems possible. I am trying to imagine scenarios on some clusters where grouping processes by access to the same shared memory space might be over- or under-restrictive compared to matching on hostname... but at worst it seems like it might lead to NUMA-induced slowdowns, not incorrect behavior.

For example, on one system I am thinking of, each host has two CPUs with two sockets each, attached to 6 GPUs, and each CPU has its own set of shared memory. Using shared-memory accessibility here might be ambiguous, depending on the job launcher, and it's not a comm-split method I see used much.

If you prefer COMM_TYPE_SHARED I won't object, though I'm not sure what the downside of a hostname query is.

Collaborator

Even on multi-NUMA systems it works fine, since the processes can still communicate via shared memory. I am pretty confident that MPI.COMM_TYPE_SHARED works well for this use case. We have been using it in Horovod (a distributed layer that can be used on top of PyTorch/TF/MXNet) for years with a wide variety of users:
https://github.com/horovod/horovod/blob/master/horovod/common/mpi/mpi_context.cc#L126

I would prefer the split over implementing it ourselves by broadcasting the hostname, since MPI vendors may be able to optimize this specific comm split under the covers.
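
For reference, the hostname-based grouping being discussed can be simulated without an MPI launcher; in real code the one-liner `COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)` would replace this bookkeeping (the rank/hostname data below is made up):

```python
def local_ranks(hostnames):
    """Given hostnames indexed by global rank, compute each rank's local
    rank within its host -- what a per-node communicator split (e.g. via
    MPI.COMM_TYPE_SHARED) would report as that process's rank."""
    seen = {}   # hostname -> number of ranks assigned on that host so far
    out = []
    for host in hostnames:
        out.append(seen.get(host, 0))
        seen[host] = seen.get(host, 0) + 1
    return out

# Four global ranks spread over two hosts (illustrative data):
print(local_ranks(["nodeA", "nodeA", "nodeB", "nodeB"]))  # [0, 1, 0, 1]
```

Either approach yields the local rank used to pick local_device_ids; the split just delegates the grouping to the MPI implementation.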

(Resolved review threads on jax/_src/distributed.py and jax/_src/clusters/mpi4py_cluster.py)
Collaborator

@nvcastet nvcastet left a comment

Thanks, Corey, for following up.
I left a few comments.
Do you mind re-testing the final code on your cluster and adding a unit test similar to the ompi one https://github.com/google/jax/blob/main/tests/multiprocess_gpu_test.py#L184 ?

@@ -43,7 +44,11 @@ def auto_detect_unset_distributed_params(cls,
num_processes: int | None,
process_id: int | None,
local_device_ids: Sequence[int] | None,
<<<<<<< HEAD
Collaborator

some merge conflict left over here?

@@ -215,7 +225,11 @@ def initialize(coordinator_address: str | None = None,
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id,
<<<<<<< HEAD
Collaborator

same


@coreyjadams
Author
Author

Hey, I'm back; sorry for the delay. It's been a busy time here: I'm about to transition jobs and I'm wrapping up some open things.

I've made the changes requested, and added the test. I'll point out that on my system, mpirun does not allow the following arguments:

          '--oversubscribe',
          '--allow-run-as-root',

So I ran the test without these and it passed. I left them in the final code, however; I'm not sure what test system you have that allows launching mpirun jobs as root, but it sure isn't a DOE supercomputer :).

Let me know if you want any more changes. I'll be able to test things for about 1 month more.

Collaborator

@nvcastet nvcastet left a comment

@skye @hawkinsp Do you mind giving another look at this PR?

(Resolved review threads on jax/_src/clusters/mpi4py_cluster.py and jax/_src/distributed.py)
@coreyjadams
Author
Author

The requested changes to comments and documentation are done; thanks for the feedback :)

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jun 27, 2024
@nvcastet
Collaborator

@coreyjadams You would need to sign the CLA too: https://cla.developers.google.com/

@nvcastet
Collaborator

@coreyjadams Could you update the year at the top of your new file and also add it to https://github.com/google/jax/blob/98b87540a76bc3b4f9f414bb763304f1cc474544/jax/BUILD#L953 so that it gets properly packaged?

@coreyjadams
Author

I did the BUILD file change and I'm working on the CLA. It's signed (since April!), but it looks like my commits were coming in under a different email address. Working on it...

@coreyjadams
Author

> I did the BUILD file change and I'm working on the CLA. It's signed (since April!), but it looks like my commits were coming in under a different email address. Working on it...

OK, I see the green check mark now, just had to link my work email to github. Thanks!

@nvcastet nvcastet added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Jun 27, 2024
Collaborator

@nouiz nouiz left a comment

I didn't look at the code, but this looks useful.
Can you update the title to help convey what it does? Something like:
Enable opt-in autodetection of distributed configuration for mpi4py

Collaborator

@hawkinsp hawkinsp left a comment

This mostly looks good to me: one API naming suggestion.

(Resolved review thread on jax/_src/distributed.py)
@hawkinsp
Collaborator
Collaborator

Please fix the lint errors and squash your commits.

@coreyjadams
Author
Author

The lint errors are fixed locally; just waiting on the CI to confirm, then I'll deal with squashing the commits. Thanks for the feedback.

@coreyjadams coreyjadams changed the title Enable opt-in autodetection of distributed configuration Enable opt-in autodetection of distributed configuration for mpi4py Jul 1, 2024
@coreyjadams
Author
Author

coreyjadams commented Jul 2, 2024

I made an inadvertent force push here, so I created a branch from the commit just before that: #22235 . If you can force the squash from your side, please do so; then we can just close this one.

I appreciate working with you on this, and sorry my git skills weren't up to par on this one :).

Update: @felker below was able to salvage the git snafu I created, I think. Should be ok to merge?

commit 79b8cbf
Author: Corey Adams <corey.adams@anl.gov>
Date:   Mon Jul 1 14:14:15 2024 -0500

    Fix mypy issues; change variable name to more universally known name

commit 10edc86
Author: Corey Adams <corey.adams@anl.gov>
Date:   Thu Jun 27 13:25:32 2024 -0500

    Change copyright year to the year this was authored

commit f7086cb
Author: Corey Adams <corey.adams@anl.gov>
Date:   Thu Jun 27 13:15:32 2024 -0500

    Update build file to include mpi4py cluster.

commit 6235eb3
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:11:48 2024 -0500

    Update distributed.py

    Clean up documentation slightly.

commit ef3a2e2
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:09:37 2024 -0500

    Update mpi4py_cluster.py

    Further clean up unneeded comments.

commit 6cc07a9
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:08:38 2024 -0500

    Update mpi4py_cluster.py

    Remove unneeded commented code.

commit 6701bd1
Merge: 5a91ac3 98b8754
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:07:25 2024 -0500

    Merge branch 'google:main' into main

commit 5a91ac3
Merge: 301bbc6 6c51234
Author: Corey adams <coreyjadams@gmail.com>
Date:   Tue May 28 22:14:08 2024 -0500

    Merge branch 'google:main' into main

commit 301bbc6
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue May 28 11:34:51 2024 -0500

    Add test to verify mpi4py based distributed initialization

commit 19e6694
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue May 28 11:14:40 2024 -0500

    Unify variable naming and fix function argument ordering

commit 72fe093
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue May 28 10:56:25 2024 -0500

    Remove unmerged code

commit 3a96e73
Merge: e4fd97e ff3db9b
Author: Corey adams <coreyjadams@gmail.com>
Date:   Tue May 28 10:51:41 2024 -0500

    Merge branch 'google:main' into main

commit e4fd97e
Merge: a697299 72a81e5
Author: Corey adams <coreyjadams@gmail.com>
Date:   Mon May 13 16:01:35 2024 -0500

    Merge branch 'google:main' into main

commit a697299
Merge: 85bcf42 1e48adc
Author: Corey adams <coreyjadams@gmail.com>
Date:   Mon May 13 14:21:32 2024 -0500

    Merge branch 'google:main' into main

commit 85bcf42
Merge: af1a4f0 06cd05d
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue Apr 16 09:09:31 2024 -0500

    Merge branch 'main' of https://github.com/google/jax

commit af1a4f0
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue Apr 16 08:58:33 2024 -0500

    update documentation and elaborate on spec_detect_method variable

commit 01f4709
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue Apr 16 08:45:38 2024 -0500

    Address feedback and comments on PR 20174; fix typo in documentation.

commit 4f22d86
Merge: 900a037 71ec6e3
Author: Corey adams <coreyjadams@gmail.com>
Date:   Mon Mar 11 11:51:30 2024 -0500

    Merge branch 'google:main' into main

commit 900a037
Author: Corey Adams <corey.adams@anl.gov>
Date:   Mon Mar 11 11:50:48 2024 -0500

    Auto-detect of mpi4py-based configuration is now strictly opt-in.

commit 1992969
Author: Corey Adams <corey.adams@anl.gov>
Date:   Thu Mar 7 12:27:43 2024 -0600

    Enable automatic detection of distrbuted variables with any configuration of MPI, as long as mpi4py is available
@copybara-service copybara-service bot merged commit 0d57c72 into jax-ml:main Jul 8, 2024
15 of 16 checks passed
Labels
pull ready Ready for copybara import and testing

6 participants