[NFC][mlir][shard] Unify MoveSplitAxisPattern/MoveLastSplitAxisPattern #192295

Merged
fschlimb merged 1 commit into llvm:main from fschlimb:movelast on Apr 17, 2026

@fschlimb
Contributor

Made MoveLastSplitAxisPattern more general so that it also covers MoveSplitAxisPattern, which is removed.
Less code, same functionality.
Assisted by Claude.

@fschlimb fschlimb requested review from joker-eph and sogartar April 15, 2026 17:24
@llvmbot llvmbot added the mlir label Apr 15, 2026
@llvmbot
Member

llvmbot commented Apr 15, 2026

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/192295.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Shard/Transforms/Partition.cpp (+7-99)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 57502fbf9e276..05b864dc5a29d 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -271,106 +271,15 @@ static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount,
   return srcShape.cloneWith(tgtShape, srcShape.getElementType());
 }
 
-/// Move a split axis between tensor dimensions:
-/// e.g. [[0], []] -> [[], [0]].
-class MoveSplitAxisPattern : public ReshardingPattern {
-  // Detect if the resharding moves a single split axis from one tensor
-  // dimension to another tensor dimension. If detected, returns the
-  // corresponding (tgt_tensor_dim, grid_axis) pair.
-  static std::optional<std::tuple<int64_t, GridAxis>>
-  detect(const Sharding &srcSharding, const Sharding &tgtSharding,
-         int64_t srcTensorDim) {
-    if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
-      return std::nullopt;
-    auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
-    if (srcAxes.size() != 1)
-      return std::nullopt;
-    for (size_t tgtTensorDim = 0;
-         tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) {
-      if (static_cast<int64_t>(tgtTensorDim) == srcTensorDim)
-        continue;
-      auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef();
-      if (tgtAxes.size() != 1 || srcAxes.front() != tgtAxes.front())
-        continue;
-      return std::make_tuple(static_cast<int64_t>(tgtTensorDim),
-                             srcAxes.front());
-    }
-    return std::nullopt;
-  }
-
-  static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
-                              int64_t srcTensorDim, int64_t tgtTensorDim) {
-    SmallVector<GridAxesAttr> tgtShardingSplitAxes =
-        llvm::to_vector(srcSharding.getSplitAxes());
-    while (static_cast<int64_t>(tgtShardingSplitAxes.size()) <= tgtTensorDim) {
-      tgtShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
-    }
-
-    auto srcSplitAxes =
-        llvm::to_vector(tgtShardingSplitAxes[srcTensorDim].asArrayRef());
-    assert(srcSplitAxes.size() == 1);
-    auto gridAxis = srcSplitAxes.back();
-    srcSplitAxes.pop_back();
-    tgtShardingSplitAxes[srcTensorDim] = GridAxesAttr::get(ctx, srcSplitAxes);
-
-    auto tgtSplitAxes =
-        llvm::to_vector(tgtShardingSplitAxes[tgtTensorDim].asArrayRef());
-    tgtSplitAxes.push_back(gridAxis);
-    tgtShardingSplitAxes[tgtTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes);
-
-    return Sharding::get(srcSharding.getGridAttr(), tgtShardingSplitAxes);
-  }
-
-  static std::tuple<TypedValue<ShapedType>, Sharding>
-  apply(ImplicitLocOpBuilder &builder, GridOp grid, Sharding srcSharding,
-        ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
-        int64_t srcTensorDim, int64_t tgtTensorDim, GridAxis gridAxis) {
-    MLIRContext *ctx = builder.getContext();
-    builder.setInsertionPointAfterValue(srcShard);
-
-    Sharding resultSharding =
-        tgtSharding(ctx, std::move(srcSharding), srcTensorDim, tgtTensorDim);
-    ShapedType a2aResultShape =
-        allToAllResultShape(srcShard.getType(), grid.getShape()[gridAxis],
-                            srcTensorDim, tgtTensorDim);
-    Value allToAllResult = AllToAllOp::create(
-        builder,
-        RankedTensorType::get(a2aResultShape.getShape(),
-                              a2aResultShape.getElementType()),
-        grid.getSymName(), SmallVector<GridAxis>({gridAxis}), srcShard,
-        APInt(64, tgtTensorDim), APInt(64, srcTensorDim));
-    ShapedType tgtShape =
-        shardShapedType(srcUnshardedType, grid, resultSharding);
-    TypedValue<ShapedType> tgtShard =
-        tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult();
-    return {tgtShard, resultSharding};
-  }
-
-public:
-  std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
-  tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
-           const Sharding &srcSharding, const Sharding &tgtSharding,
-           ShapedType srcUnshardedType,
-           TypedValue<ShapedType> srcShard) override {
-    if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
-      return std::nullopt;
-    if (auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) {
-      auto [tgtTensorDim, gridAxis] = detectRes.value();
-      return apply(builder, grid, srcSharding, srcUnshardedType, srcShard,
-                   tensorDim, tgtTensorDim, gridAxis);
-    }
-    return std::nullopt;
-  }
-};
-
 /// Move the last split axis of one tensor dimension to the front of another
-/// tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]].
+/// tensor dimension's split axes, e.g. [[0], []] -> [[], [0]] or
+/// [[0, 1], [2]] -> [[0], [1, 2]].
 class MoveLastSplitAxisPattern : public ReshardingPattern {
   // Detect if the resharding moves the last grid axis of srcTensorDim to the
   // front of another tensor dimension's split axes. If detected, returns
   // (tgtTensorDim, movedGridAxis).
   //
-  // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an]  (n >= 2)
+  // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an]  (n >= 1)
   //          tgt[srcTensorDim] = [a1,...,a(n-1)]
   //          src[tgtTensorDim] = [b1,...,bm]          (m >= 0)
   //          tgt[tgtTensorDim] = [an, b1,...,bm]
@@ -380,8 +289,8 @@ class MoveLastSplitAxisPattern : public ReshardingPattern {
     if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
       return std::nullopt;
     auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
-    // Need at least 2 axes to move the last one.
-    if (srcAxes.size() < 2)
+    // Need at least 1 axis to move.
+    if (srcAxes.empty())
       return std::nullopt;
 
     // After the move the source tensor dim should lose its last axis.
@@ -586,12 +495,11 @@ static TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder,
   // Each pattern's tryApply checks its own applicability preconditions.
   static UpdateHaloPattern updateHaloPattern;
   static MoveLastSplitAxisPattern moveLastSplitAxisPattern;
-  static MoveSplitAxisPattern moveSplitAxisPattern;
   static SplitLastAxisPattern splitLastAxisPattern;
   static UnsplitLastAxesPattern unsplitLastAxesPattern;
   static ReshardingPattern *patterns[] = {
-      &updateHaloPattern, &moveLastSplitAxisPattern, &moveSplitAxisPattern,
-      &splitLastAxisPattern, &unsplitLastAxesPattern};
+      &updateHaloPattern, &moveLastSplitAxisPattern, &splitLastAxisPattern,
+      &unsplitLastAxesPattern};
   TypedValue<ShapedType> currentShard = shardedSrc;
   Sharding currentSharding = srcSharding;
   for (int64_t dim = 0;
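
The generalized detection rule described in the pattern comment above (n >= 1 instead of n >= 2) can be illustrated with a small standalone sketch. This is not the code from Partition.cpp: `SplitAxes`, `axesAt`, and `detectMoveLastSplitAxis` are hypothetical names, and a sharding is modelled as a plain vector of axis lists rather than the MLIR `Sharding` type. It only follows the four conditions listed in the comment.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the MLIR types (assumption): a sharding is
// modelled as one list of grid axes per tensor dimension.
using GridAxis = int16_t;
using SplitAxes = std::vector<std::vector<GridAxis>>;

// Returns the axes of tensor dimension `dim`, or an empty list if the
// sharding has no entry for that dimension.
static std::vector<GridAxis> axesAt(const SplitAxes &s, size_t dim) {
  return dim < s.size() ? s[dim] : std::vector<GridAxis>{};
}

// Returns (tgtTensorDim, movedGridAxis) if `tgt` equals `src` with the last
// axis of srcTensorDim moved to the front of another dimension's axes.
static std::optional<std::pair<int64_t, GridAxis>>
detectMoveLastSplitAxis(const SplitAxes &src, const SplitAxes &tgt,
                        int64_t srcTensorDim) {
  assert(srcTensorDim >= 0 && "tensor dimension must be non-negative");
  const size_t srcDim = static_cast<size_t>(srcTensorDim);
  if (srcDim >= src.size())
    return std::nullopt;
  const std::vector<GridAxis> &srcAxes = src[srcDim];
  // n >= 1: a single split axis is enough to move.
  if (srcAxes.empty())
    return std::nullopt;

  // tgt[srcTensorDim] must be src[srcTensorDim] without its last axis.
  std::vector<GridAxis> expectedSrcDim(srcAxes.begin(), srcAxes.end() - 1);
  if (axesAt(tgt, srcDim) != expectedSrcDim)
    return std::nullopt;

  const GridAxis moved = srcAxes.back();
  for (size_t tgtDim = 0; tgtDim < tgt.size(); ++tgtDim) {
    if (tgtDim == srcDim)
      continue;
    // tgt[tgtDim] must be [moved, b1, ..., bm] where src[tgtDim] = [b1..bm].
    std::vector<GridAxis> expected{moved};
    const std::vector<GridAxis> srcTgtAxes = axesAt(src, tgtDim);
    expected.insert(expected.end(), srcTgtAxes.begin(), srcTgtAxes.end());
    if (tgt[tgtDim] == expected)
      return std::make_pair(static_cast<int64_t>(tgtDim), moved);
  }
  return std::nullopt;
}
```

Under these assumptions, src = [[0], []] with tgt = [[], [0]] and srcTensorDim = 0 yields (1, 0), the case formerly handled by the removed pattern, while src = [[0, 1], [2]] with tgt = [[0], [1, 2]] yields (1, 1).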

Member

@rengolin rengolin left a comment

were there no users? No tests broke?

@joker-eph
Contributor

joker-eph commented Apr 16, 2026

> were there no users? No tests broke?

This is NFC: the pattern modified can now cover the pattern that got removed.

@fschlimb can you add NFC to the PR description to make it clear?

Contributor

@joker-eph joker-eph left a comment

Nice! Impressive that there was so much duplication here in the first place...

@fschlimb fschlimb changed the title [mlir][shard] Unify MoveLastSplitAxisPattern/MoveLastSplitAxisPattern [NFC][mlir][shard] Unify MoveLastSplitAxisPattern/MoveLastSplitAxisPattern Apr 17, 2026
@fschlimb
Contributor Author

> were there no users? No tests broke?

Some of the older shard code has no tests and unclear documentation/description. That's why I changed the original pattern over time to cover only the cases that were correct and tested. Recently @joker-eph added what was probably the originally intended functionality, including the necessary tests (#189241).

@fschlimb fschlimb merged commit bcc606c into llvm:main Apr 17, 2026
12 checks passed
alexfh pushed a commit to alexfh/llvm-project that referenced this pull request Apr 18, 2026
…ttern (llvm#192295)


4 participants