Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move CrossReplicaSum from StableHLO to CHLO. #118

Closed
wants to merge 1 commit into from

Conversation

subhankarshah
Copy link
Member

Part 1 of Many

  • Create a copy of stablehlo.cross-replica-sum in CHLO, including changes in .td and .cc files.

CrossReplicaSum is a special case of AllReduce where the reduction is a summation operation.
It is therefore decomposable to AllReduce. We intend to move the op to CHLO and lower it to AllReduce.

stablehlo/dialect/ChloOps.td Outdated Show resolved Hide resolved
@burmako burmako added Spec RFC and removed Spec labels Sep 15, 2022
Copy link
Contributor

@burmako burmako left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your contribution! Since this PR proposes to modify the opset, it'll need to go through the RFC process before it can get merged. I'm finishing setting up the process and will follow up here later this week.

Also, I wanted to document our in-person discussion about what happens next if/when this pull request is approved. This is how I propose this should look like.

(M0) Current state:

  • There's stablehlo.cross-replica-sum.
  • There's mhlo.cross-replica.sum.
  • stablehlo.cross-replica-sum => [stablehlo-legalize-to-hlo] => mhlo.cross-replica.sum.
  • mhlo.cross-replica-sum => [hlo-legalize-to-stablehlo] => stablehlo.cross-replica.sum.

(M1) After this pull request (and a pending follow-on in MLIR-HLO):

  • There's chlo.cross-replica-sum.
  • chlo.cross-replica-sum => [chlo-legalize-to-hlo] => mhlo.all_reduce.
  • There's stablehlo.cross-replica-sum.
  • There's mhlo.cross-replica.sum.
  • stablehlo.cross-replica-sum => [stablehlo-legalize-to-hlo] => mhlo.cross-replica.sum.
  • mhlo.cross-replica-sum => [hlo-legalize-to-stablehlo] => stablehlo.cross-replica.sum.

(M2) Clean up MHLO (should be done shortly afterwards, because we don't provide stability guarantees for MHLO):

  • There's chlo.cross-replica-sum.
  • chlo.cross-replica-sum => [chlo-legalize-to-hlo] => mhlo.all_reduce.
  • There's stablehlo.cross-replica-sum.
  • stablehlo.cross-replica-sum => [stablehlo-legalize-to-hlo] => mhlo.all_reduce.

(M3) Clean up StableHLO (need to wait for the backward compatibility window)

  • There's chlo.cross-replica-sum.
  • chlo.cross-replica-sum => [chlo-legalize-to-hlo] => mhlo.all_reduce.

@GleasonK
Copy link
Member

From a compatibility perspective, Migration Change. We will need to keep both Ops around, unless we figure out a way to use a generic Operation * for upgrade pass, before verifiers are run. I'll be investigating this in the near future.

(M0)

  • Current state, no impact.

(M1)

  • Verifier (pre-downgrade) blocks serialization of stablehlo.cross-replica-sum
  • chlo.cross_replica_sum => [downgrade] => stablehlo.cross-replica-sum (remove after forward compatibility window)
  • stablehlo.cross-replica-sum => [upgrade] => chlo.cross_replica_sum

(M2)

  • No impact.

(M3)

  • Remove upgrade pass, delete stablehlo.cross-replica-sum

Part 1 of Many
 - Create a copy of stablehlo.cross-replica-sum in CHLO, including changes in .td and .cc files.
@burmako
Copy link
Contributor

burmako commented Sep 30, 2022

Closing this pull request for now as discussed in #3 (comment). Will revisit once we pick up the associated ticket again.

@burmako burmako closed this Sep 30, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants