[GPU] Use NCCL user buffers for collective permute and all-to-all #8874

Closed
trevor-m wants to merge 4 commits from trevor-m:p2p-user-buffers

Conversation

trevor-m
Contributor

@trevor-m trevor-m commented Jan 26, 2024

This PR enables XLA to use NCCL user buffers for ncclSend/ncclRecv when --xla_gpu_enable_nccl_user_buffers=true is set. Requires NCCL 2.20.

@trevor-m trevor-m changed the title WIP: [GPU] Use NCCL user buffers for ncclSend/ncclRecv ops (Requires NCCL 2.20) WIP: [GPU] Use NCCL user buffers for ncclSend/ncclRecv ops Jan 29, 2024
@trevor-m trevor-m changed the title WIP: [GPU] Use NCCL user buffers for ncclSend/ncclRecv ops WIP: [GPU] Use NCCL user buffers for collective permute and all-to-all Feb 9, 2024
@kamaljeeti
Contributor

Hi @cheshire, could you take a look at this? Thanks.

@trevor-m trevor-m changed the title WIP: [GPU] Use NCCL user buffers for collective permute and all-to-all [GPU] Use NCCL user buffers for collective permute and all-to-all Feb 15, 2024
@kamaljeeti
Contributor

Hi @cheshire, an internal CI build is failing. Could you take a look? Thanks.

@cheshire cheshire added the kokoro:force-run Forces CI to rerun label Mar 12, 2024
@trevor-m trevor-m force-pushed the p2p-user-buffers branch 2 times, most recently from 4cb441a to 6471175 Compare March 12, 2024 17:32
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 12, 2024
…to-all

Imported from GitHub PR openxla/xla#8874

This PR enables XLA to take advantage of NCCL user buffers for ncclSend/ncclRecv when `--xla_gpu_enable_nccl_user_buffers=true` is used. Requires NCCL 2.20

Copybara import of the project:

--
8de2786d3242c76bed385235b5655156ee187e5f by Trevor Morris <tmorris@nvidia.com>:

Use NCCL user buffers for ncclSend/ncclRecv ops

--
56ceecb1b7fc1606dd00b514bbdb7d039e787b8c by Trevor Morris <tmorris@nvidia.com>:

Include memory space in buffers for collective permute and send/recv

--
64711757e48b619b9e2d322fc49714a94194d8f1 by Trevor Morris <tmorris@nvidia.com>:

Don't offload send, recv

Merging this change closes #8874

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#8874 from trevor-m:p2p-user-buffers 64711757e48b619b9e2d322fc49714a94194d8f1
PiperOrigin-RevId: 615104094
// opcode or async wrapped opcode is in kSupportedOpcodes.
if (kSupportedOpcodes->contains(alias->instruction()->opcode()) ||
    (alias->instruction()->opcode() == HloOpcode::kAsyncStart ||
     alias->instruction()->opcode() == HloOpcode::kAsyncDone) &&
Member

This triggers a warning that we treat as an error:

error: '&&' within '||' [-Werror,-Wlogical-op-parentheses]

Contributor

CC @ddunl for reconciling warnings (given that we use Clang in both places now, why can't we have an identical set of warnings?)

Contributor Author

Thanks for letting me know, I fixed the conditional.
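
For readers unfamiliar with this warning: `&&` binds tighter than `||`, so `a || (b || c) && d` parses as `a || ((b || c) && d)`, and Clang's `-Wlogical-op-parentheses` asks for that grouping to be spelled out. The sketch below is only an illustration of the pattern, not the actual change in this PR. `Opcode`, `IsSupported`, and `MatchesSupportedOpcode` are hypothetical stand-ins, and the right-hand operand of `&&` is cut off in the excerpt above, so the wrapped-opcode check here is an assumption based on the code comment.

```cpp
// Hypothetical stand-ins for illustration; these are NOT XLA's HloOpcode or
// kSupportedOpcodes.
enum class Opcode { kAllReduce, kCollectivePermute, kAsyncStart, kAsyncDone, kAdd };

// Assumed membership test standing in for kSupportedOpcodes->contains(...).
bool IsSupported(Opcode op) {
  return op == Opcode::kAllReduce || op == Opcode::kCollectivePermute;
}

// True if `op` is supported directly, or if `op` is an async start/done whose
// wrapped opcode is supported. The explicit parentheses around the `&&`
// expression keep the original precedence and silence the warning.
bool MatchesSupportedOpcode(Opcode op, Opcode wrapped) {
  const bool is_async =
      op == Opcode::kAsyncStart || op == Opcode::kAsyncDone;
  return IsSupported(op) || (is_async && IsSupported(wrapped));
}
```

Naming the async test as a local (`is_async`) is a style choice that makes the intended grouping visible; adding parentheses directly to the original expression would work just as well.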

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 15, 2024
…to-all

Imported from GitHub PR openxla/xla#8874

This PR enables XLA to take advantage of NCCL user buffers for ncclSend/ncclRecv when `--xla_gpu_enable_nccl_user_buffers=true` is used. Requires NCCL 2.20

Copybara import of the project:

--
8de2786d3242c76bed385235b5655156ee187e5f by Trevor Morris <tmorris@nvidia.com>:

Use NCCL user buffers for ncclSend/ncclRecv ops

--
56ceecb1b7fc1606dd00b514bbdb7d039e787b8c by Trevor Morris <tmorris@nvidia.com>:

Include memory space in buffers for collective permute and send/recv

--
b3e776cb8486f2952dcb60a753dcea3c11da4d87 by Trevor Morris <tmorris@nvidia.com>:

Don't offload send, recv

Merging this change closes #8874

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#8874 from trevor-m:p2p-user-buffers b3e776cb8486f2952dcb60a753dcea3c11da4d87
PiperOrigin-RevId: 615104094
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 18, 2024
…to-all

Imported from GitHub PR openxla/xla#8874

This PR enables XLA to take advantage of NCCL user buffers for ncclSend/ncclRecv when `--xla_gpu_enable_nccl_user_buffers=true` is used. Requires NCCL 2.20

Copybara import of the project:

--
98acdf27d4eba6b19652a76d3f7dcd6630349fc5 by Trevor Morris <tmorris@nvidia.com>:

Use NCCL user buffers for ncclSend/ncclRecv ops

--
bcc289b49bcf2086b50a86a2381ea1b80acd3dd2 by Trevor Morris <tmorris@nvidia.com>:

Include memory space in buffers for collective permute and send/recv

--
4a83d8906b6b5e305dad23fc1d8b9a5069637279 by Trevor Morris <tmorris@nvidia.com>:

Don't offload send, recv

--
0083a418c4ab119ed5a0eb061113104980476943 by Trevor Morris <tmorris@nvidia.com>:

Fix conditional

Merging this change closes #8874

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#8874 from trevor-m:p2p-user-buffers 0083a418c4ab119ed5a0eb061113104980476943
PiperOrigin-RevId: 615104094
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 19, 2024
…to-all

Imported from GitHub PR openxla/xla#8874

This PR enables XLA to take advantage of NCCL user buffers for ncclSend/ncclRecv when `--xla_gpu_enable_nccl_user_buffers=true` is used. Requires NCCL 2.20

Copybara import of the project:

--
98acdf27d4eba6b19652a76d3f7dcd6630349fc5 by Trevor Morris <tmorris@nvidia.com>:

Use NCCL user buffers for ncclSend/ncclRecv ops

--
bcc289b49bcf2086b50a86a2381ea1b80acd3dd2 by Trevor Morris <tmorris@nvidia.com>:

Include memory space in buffers for collective permute and send/recv

--
4a83d8906b6b5e305dad23fc1d8b9a5069637279 by Trevor Morris <tmorris@nvidia.com>:

Don't offload send, recv

--
0083a418c4ab119ed5a0eb061113104980476943 by Trevor Morris <tmorris@nvidia.com>:

Fix conditional

Merging this change closes #8874

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#8874 from trevor-m:p2p-user-buffers 0083a418c4ab119ed5a0eb061113104980476943
PiperOrigin-RevId: 615104094
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 19, 2024
…to-all

Imported from GitHub PR openxla/xla#8874

This PR enables XLA to take advantage of NCCL user buffers for ncclSend/ncclRecv when `--xla_gpu_enable_nccl_user_buffers=true` is used. Requires NCCL 2.20

Copybara import of the project:

--
98acdf27d4eba6b19652a76d3f7dcd6630349fc5 by Trevor Morris <tmorris@nvidia.com>:

Use NCCL user buffers for ncclSend/ncclRecv ops

--
bcc289b49bcf2086b50a86a2381ea1b80acd3dd2 by Trevor Morris <tmorris@nvidia.com>:

Include memory space in buffers for collective permute and send/recv

--
4a83d8906b6b5e305dad23fc1d8b9a5069637279 by Trevor Morris <tmorris@nvidia.com>:

Don't offload send, recv

--
0083a418c4ab119ed5a0eb061113104980476943 by Trevor Morris <tmorris@nvidia.com>:

Fix conditional

Merging this change closes #8874

PiperOrigin-RevId: 617140675
@cheshire
Contributor

I'm actually seeing crashes from this: it checks the layout on recv, but the recv shape is a tuple, which doesn't have a layout.
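
A minimal sketch of the failure mode described above, and one possible way to guard against it. Everything here is hypothetical: `ToyShape`, `MakeArray`, `MakeTuple`, and `kToyCollectiveMemorySpace` are simplified stand-ins and not XLA's real `Shape` API, and this thread does not say how the crash was actually resolved. The idea is only that a layout query is valid on array shapes, so a tuple has to be traversed (or skipped) rather than asked for a layout directly.

```cpp
#include <utility>
#include <vector>

// Hypothetical, simplified shape type for illustration only.
struct ToyShape {
  bool is_tuple = false;
  int memory_space = 0;            // Meaningful only when !is_tuple.
  std::vector<ToyShape> elements;  // Meaningful only when is_tuple.
};

// Made-up id for the memory space used by collective (user-buffer) memory.
constexpr int kToyCollectiveMemorySpace = 1;

ToyShape MakeArray(int memory_space) {
  ToyShape shape;
  shape.memory_space = memory_space;
  return shape;
}

ToyShape MakeTuple(std::vector<ToyShape> elements) {
  ToyShape shape;
  shape.is_tuple = true;
  shape.elements = std::move(elements);
  return shape;
}

// Returns true if any array leaf of `shape` is in the collective memory
// space. Tuples are walked recursively instead of being queried for a layout
// they do not have.
bool UsesCollectiveMemorySpace(const ToyShape& shape) {
  if (shape.is_tuple) {
    for (const ToyShape& element : shape.elements) {
      if (UsesCollectiveMemorySpace(element)) return true;
    }
    return false;
  }
  return shape.memory_space == kToyCollectiveMemorySpace;
}
```

Whether a real fix should recurse into tuple elements, inspect only the data element, or skip tuple-shaped instructions entirely depends on how recv results are laid out, which this sketch does not try to model.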

steeve pushed a commit to zml/xla that referenced this pull request Aug 30, 2024
…nd all-to-all

Imported from GitHub PR openxla#8874

This PR enables XLA to take advantage of NCCL user buffers for ncclSend/ncclRecv when `--xla_gpu_enable_nccl_user_buffers=true` is used. Requires NCCL 2.20

Copybara import of the project:

--
98acdf2 by Trevor Morris <tmorris@nvidia.com>:

Use NCCL user buffers for ncclSend/ncclRecv ops

--
bcc289b by Trevor Morris <tmorris@nvidia.com>:

Include memory space in buffers for collective permute and send/recv

--
4a83d89 by Trevor Morris <tmorris@nvidia.com>:

Don't offload send, recv

--
0083a41 by Trevor Morris <tmorris@nvidia.com>:

Fix conditional

Merging this change closes openxla#8874

COPYBARA_INTEGRATE_REVIEW=openxla#8874 from trevor-m:p2p-user-buffers 0083a41
PiperOrigin-RevId: 617140675
7 participants