Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Add collective_broadcast to the StableHLO specification #1809

Merged

Conversation

chaserileyroberts
Copy link
Contributor

This RFC proposes adding collective_broadcast as one of the collective communication primitives.
Please provide any feedback you feel is valuable.

@google-cla
Copy link

google-cla bot commented Oct 17, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@burmako burmako added the RFC label Oct 17, 2023
@GleasonK
Copy link
Member

View this failed invocation of the CLA check for more information.

Also - could you follow this link and sign the CLA when you get a chance.

@chaserileyroberts
Copy link
Contributor Author

Also - could you follow this link and sign the CLA when you get a chance.

I work at Nvidia so I need to go through the separate channel. I'll take care of it don't worry.

rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Show resolved Hide resolved
Copy link
Contributor Author

@chaserileyroberts chaserileyroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made updates based on the new semantics.

rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Outdated Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Show resolved Hide resolved
rfcs/20231017-collective-broadcast.md Show resolved Hide resolved
@andydavis1
Copy link

Thank you for the RFC. The current set of collective operations are currently used in the SPMD context, where each device is running the same program. With the proposed collective_broadcast, do you have an MPMD use case for this, or were you expecting other devices which were not broadcasting to produce an empty buffer as a result of the broadcast?

@chaserileyroberts
Copy link
Contributor Author

With the proposed collective_broadcast, do you have an MPMD use case for this, or were you expecting other devices which were not broadcasting to produce an empty buffer as a result of the broadcast

No, the idea is for this to still be targeting exclusively SPMD. On the cuda side, the goal is to have this op lower to the exact same nccl.Broadcast operation on all devices. Devices that are not in any replica_group of the SPMD op return zeros, as per the specification and latest example.

@chaserileyroberts
Copy link
Contributor Author

chaserileyroberts commented Nov 8, 2023

To give a very concrete example of how this operation could be used in a realistic SPMD setting, let me show you some code.

Here, we have an example of how we could implement the 2D pgemm algorithm summa in JAX once this spec is in place.

@partial(shard_map, mesh=Mesh(...), 
    in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y'))
def summa_matrix_multiply(a, b):
  for i in range(N):
     abcast = pcollective_broadcast(a, 'y', root=i)
     bbcast = pcollective_broadcast(b, 'x', root=i)
     if i == 0:
        c = abcast @ bbcast
     c += abcast @ bbcast
  return  c

Copy link
Member

@GleasonK GleasonK left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RFC Approved. I'll send a follow-up with markdownlint fixes.

@GleasonK GleasonK merged commit 76e25a5 into openxla:main Nov 20, 2023
6 of 7 checks passed
@chaserileyroberts
Copy link
Contributor Author

🥳 🎉

Thanks again Kevin for getting this over the finish line! I'll get to work on getting this implemented in JAX and cuda xla backend.

GleasonK pushed a commit that referenced this pull request Nov 30, 2023
This is the first PR for RFC #1809. I did not add an interpreter
implementation as @GleasonK specifically asked me to leave that for new
staff joining his team.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants