[mlir][shard, mpi] Allow more than one last axis to be "unsplit"#180754

Merged
fschlimb merged 3 commits into llvm:main from fschlimb:shard-unsplit-mult on Feb 11, 2026

Conversation

@fschlimb
Contributor

A resharding pattern allowed only a single trailing axis to be "unsplit".
This PR allows multiple trailing axes to be "unsplit".
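
To make the generalized pattern concrete, here is a minimal, standalone sketch of the detection rule, written against plain standard-library containers rather than the real MLIR types. `GridAxis`, `SplitAxes`, and `detectUnsplitTrailingAxes` are hypothetical stand-ins introduced for illustration only; the actual implementation is `detectUnsplitLastAxesInResharding` in Partition.cpp.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the MLIR types; not the real Shard dialect API.
using GridAxis = int16_t;
using SplitAxes = std::vector<std::vector<GridAxis>>; // one list per tensor dim

// Returns (tensor dim, removed trailing grid axes) for the first tensor
// dimension whose target split axes are a strict prefix of its source split
// axes. Returns std::nullopt if no dimension matches.
std::optional<std::pair<std::size_t, std::vector<GridAxis>>>
detectUnsplitTrailingAxes(const SplitAxes &source, const SplitAxes &target) {
  // A dimension that is absent from the target is treated as replicated.
  static const std::vector<GridAxis> kReplicated;
  for (std::size_t dim = 0; dim < source.size(); ++dim) {
    const std::vector<GridAxis> &src = source[dim];
    const std::vector<GridAxis> &tgt =
        dim < target.size() ? target[dim] : kReplicated;
    // The target must have strictly fewer split axes than the source.
    if (src.size() <= tgt.size())
      continue;
    // The target axes must equal the leading source axes.
    if (!std::equal(tgt.begin(), tgt.end(), src.begin()))
      continue;
    // Everything after the shared prefix is "unsplit".
    auto firstRemoved = src.begin() + static_cast<std::ptrdiff_t>(tgt.size());
    assert(firstRemoved != src.end());
    return std::make_pair(dim, std::vector<GridAxis>(firstRemoved, src.end()));
  }
  return std::nullopt;
}
```

Under these assumptions, a source of `[[], [0, 1, 2]]` and a target of `[[], [0]]` match on tensor dimension 1 with removed axes `[1, 2]`, while a target whose axes are not a prefix of the source's (for example `[[1]]` against `[[0, 1]]`) does not match.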

@llvmbot
Member

llvmbot commented Feb 10, 2026

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Full diff: https://github.com/llvm/llvm-project/pull/180754.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Shard/Transforms/Partition.cpp (+77-58)
  • (modified) mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (+12-1)
  • (modified) mlir/test/Dialect/Shard/partition.mlir (+24)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index e619c7073a8c4..8652d665e46bf 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -132,98 +132,117 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
   return std::nullopt;
 }
 
-// Detect if the resharding is of type e.g.
-// [[0, 1, 2]] -> [[0, 1]].
-// If detected, returns the corresponding tensor axis grid axis pair.
-static std::optional<std::tuple<int64_t, GridAxis>>
-detectUnsplitLastAxisInResharding(const Sharding &sourceSharding,
+// Detect if the resharding removes trailing split axes along a tensor
+// dimension, e.g.
+// [[0, 1, 2]] -> [[0, 1]], [[0, 1, 2]] -> [[0]] or [[0, 1, 2]] -> [[]].
+// If detected, returns the corresponding (tensor dim, grid axes) pair, where
+// the "grid axes" are the removed trailing split axes.
+static std::optional<std::tuple<int64_t, SmallVector<GridAxis>>>
+detectUnsplitLastAxesInResharding(const Sharding &sourceSharding,
                                   const Sharding &targetSharding) {
-  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
-       ++tensorAxis) {
-    if (targetSharding.getSplitAxes().size() > tensorAxis) {
-      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
-          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
+  for (size_t tensorDim = 0; tensorDim < sourceSharding.getSplitAxes().size();
+       ++tensorDim) {
+    if (targetSharding.getSplitAxes().size() > tensorDim) {
+      // No match if the target sharding does not have fewer split axes than the
+      // source sharding along the current tensor dimension.
+      if (sourceSharding.getSplitAxes()[tensorDim].size() <=
+          targetSharding.getSplitAxes()[tensorDim].size())
         continue;
-      if (!llvm::equal(
-              llvm::make_range(
-                  sourceSharding.getSplitAxes()[tensorAxis]
-                      .asArrayRef()
-                      .begin(),
-                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
-                      1),
-              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
+      // No match if the split axes of the target sharding are different from
+      // the first split axes of the source sharding.
+      if (!std::equal(
+              targetSharding.getSplitAxes()[tensorDim].asArrayRef().begin(),
+              targetSharding.getSplitAxes()[tensorDim].asArrayRef().end(),
+              sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin()))
         continue;
     } else {
-      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
+      // Here the target dimension is replicated; there is nothing to do if the
+      // source dimension is also replicated.
+      if (sourceSharding.getSplitAxes()[tensorDim].size() == 0)
         continue;
     }
-    return std::make_tuple(
-        tensorAxis,
-        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
+    // This is a match. Return the current tensor dimension and the trailing
+    // grid axes of the source sharding along this dimension.
+    SmallVector<GridAxis> unsplitAxes;
+    size_t dimOff = tensorDim >= targetSharding.getSplitAxes().size()
+                        ? 0
+                        : targetSharding.getSplitAxes()[tensorDim].size();
+    for (auto a =
+             sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin() +
+             dimOff;
+         a != sourceSharding.getSplitAxes()[tensorDim].asArrayRef().end(); ++a)
+      unsplitAxes.push_back(*a);
+    return std::make_tuple(tensorDim, unsplitAxes);
   }
   return std::nullopt;
 }
 
-static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+// Return the resulting Sharding if the unsplit last axes resharding is applied.
+static Sharding targetShardingInUnsplitLastAxes(MLIRContext *ctx,
                                                 const Sharding &sourceSharding,
-                                                int64_t splitTensorAxis) {
-  SmallVector<GridAxesAttr> targetShardingSplitAxes =
+                                                int64_t splitTensorDim,
+                                                size_t numUnsplitAxes) {
+  SmallVector<GridAxesAttr> resSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
-  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
-         splitTensorAxis);
-  auto targetSplitAxes =
-      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
-
-  targetSplitAxes.pop_back();
-  targetShardingSplitAxes[splitTensorAxis] =
-      GridAxesAttr::get(ctx, targetSplitAxes);
-  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
+  assert(static_cast<int64_t>(resSplitAxes.size()) > splitTensorDim);
+  ArrayRef<GridAxis> srcSplitAxes = resSplitAxes[splitTensorDim].asArrayRef();
+  assert(srcSplitAxes.size() >= numUnsplitAxes);
+  size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes;
+  SmallVector<GridAxis> newSplitAxes(srcSplitAxes.begin(),
+                                     srcSplitAxes.begin() + numSplitAxes);
+  resSplitAxes[splitTensorDim] = GridAxesAttr::get(ctx, newSplitAxes);
+  return Sharding::get(sourceSharding.getGridAttr(), resSplitAxes);
 }
 
-static ShapedType allGatherResultShapeInUnsplitLastAxis(
-    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
-  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
-  targetShape[splitTensorAxis] =
-      gatherDimension(targetShape[splitTensorAxis], splitCount);
-  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+// Return the resulting Tensor type after applying the unsplit last axes
+// resharding.
+static ShapedType allGatherResultTypeInUnsplitLastAxes(
+    ShapedType sourceType, int64_t splitTensorDim, ArrayRef<int64_t> gridShape,
+    ArrayRef<GridAxis> unsplitAxes) {
+  SmallVector<int64_t> targetShape = llvm::to_vector(sourceType.getShape());
+  for (GridAxis gridAxis : unsplitAxes)
+    targetShape[splitTensorDim] =
+        gatherDimension(targetShape[splitTensorDim], gridShape[gridAxis]);
+  return sourceType.cloneWith(targetShape, sourceType.getElementType());
 }
 
-static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
+// Perform the resharding for the unsplit last axes case.
+// This basically performs an all-gather along the unsplit grid axes.
+static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxesInResharding(
     ImplicitLocOpBuilder &builder, Sharding sourceSharding,
     ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
-    GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
+    GridOp grid, int64_t splitTensorDim, ArrayRef<GridAxis> unsplitAxes) {
   MLIRContext *ctx = builder.getContext();
   builder.setInsertionPointAfterValue(sourceShard);
 
-  Sharding targetSharding = targetShardingInUnsplitLastAxis(
-      ctx, std::move(sourceSharding), splitTensorAxis);
-  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
-      sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
+  Sharding targetSharding = targetShardingInUnsplitLastAxes(
+      ctx, std::move(sourceSharding), splitTensorDim, unsplitAxes.size());
+  ShapedType allGatherResultType = allGatherResultTypeInUnsplitLastAxes(
+      sourceShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes);
   Value allGatherResult = AllGatherOp::create(
       builder,
-      RankedTensorType::get(allGatherResultShape.getShape(),
-                            allGatherResultShape.getElementType()),
-      grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
-      APInt(64, splitTensorAxis));
-  ShapedType targetShape =
+      RankedTensorType::get(allGatherResultType.getShape(),
+                            allGatherResultType.getElementType()),
+      grid.getSymName(), unsplitAxes, sourceShard, APInt(64, splitTensorDim));
+  ShapedType targetType =
       shardShapedType(sourceUnshardedShape, grid, targetSharding);
   TypedValue<ShapedType> targetShard =
-      tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
+      tensor::CastOp::create(builder, targetType, allGatherResult).getResult();
   return {targetShard, targetSharding};
 }
 
 static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
-tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+tryUnsplitLastAxesInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                                const Sharding &sourceSharding,
                                Sharding targetSharding,
                                ShapedType sourceUnshardedShape,
                                TypedValue<ShapedType> sourceShard) {
-  if (auto detectRes = detectUnsplitLastAxisInResharding(
+  if (auto detectRes = detectUnsplitLastAxesInResharding(
           sourceSharding, std::move(targetSharding))) {
-    auto [tensorAxis, gridAxis] = detectRes.value();
-    return unsplitLastAxisInResharding(builder, sourceSharding,
+    auto [tensorDim, gridAxes] = detectRes.value();
+    return unsplitLastAxesInResharding(builder, sourceSharding,
                                        sourceUnshardedShape, sourceShard, grid,
-                                       tensorAxis, gridAxis);
+                                       tensorDim, gridAxes);
   }
 
   return std::nullopt;
@@ -477,7 +496,7 @@ reshard(ImplicitLocOpBuilder &builder, GridOp grid,
                    trySplitLastAxisInResharding(builder, grid, sourceSharding,
                                                 targetSharding, sourceShard)) {
       std::tie(targetShard, actualTargetSharding) = tryRes.value();
-    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
+    } else if (auto tryRes = tryUnsplitLastAxesInResharding(
                    builder, grid, sourceSharding, targetSharding,
                    sourceUnshardedValue.getType(), sourceShard)) {
       std::tie(targetShard, actualTargetSharding) = tryRes.value();
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 6161c131c8f50..f3da09d05e3b8 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -5,12 +5,23 @@
 shard.grid @grid0(shape = 3x4x5)
 func.func @process_multi_index() -> (index, index, index) {
   // CHECK: mpi.comm_rank
-  // CHECK: [[res:%.*]]:3 = affine.delinearize_index %1 into (3, 4, 5) : index, index, index 
+  // CHECK: [[v1:%.*]] = arith.index_cast
+  // CHECK: [[res:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index 
   %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
   // CHECK: return [[res]]#0, [[res]]#1, [[res]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
+// CHECK-LABEL: func @process_multi_index_reorder
+func.func @process_multi_index_reorder() -> (index, index) {
+  // CHECK: mpi.comm_rank
+  // CHECK: [[v1:%.*]] = arith.index_cast
+  // CHECK: [[v2:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index
+  %0:2 = shard.process_multi_index on @grid0 axes = [2, 0] : index, index
+  // CHECK: return [[v2]]#2, [[v2]]#0 : index, index
+  return %0#0, %0#1 : index, index
+}
+
 // CHECK-LABEL: func @process_linear_index
 func.func @process_linear_index() -> index {
   // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
index 4c8271aefcafc..d5db8073fcf2e 100644
--- a/mlir/test/Dialect/Shard/partition.mlir
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -5,6 +5,7 @@
 shard.grid @grid_1d(shape = 2)
 shard.grid @grid_1d_4(shape = 4)
 shard.grid @grid_2d_16(shape = 4x4)
+shard.grid @grid_4d(shape = 2x3x4x5)
 
 // CHECK-LABEL: func @return_sharding
 func.func @return_sharding(
@@ -52,6 +53,29 @@ func.func @sharding_triplet(
   return %sharded_1 : tensor<2xf32>
 }
 
+// CHECK-LABEL: func.func @unsplit_last_axes_some(
+// CHECK-SAME: [[varg0:%.*]]: tensor<6x2xi8>) -> tensor<6x24xi8> {
+func.func @unsplit_last_axes_some( %in2: tensor<6x48xi8>) -> tensor<6x48xi8> {
+  %sharding1 = shard.sharding @grid_4d split_axes = [[], [0,1,2]] : !shard.sharding
+  %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x48xi8>
+  %sharding2 = shard.sharding @grid_4d split_axes = [[], [0]] : !shard.sharding
+  %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x48xi8>
+  // CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [1, 2] gather_axis = 1 : tensor<6x2xi8> -> tensor<6x24xi8>
+  // CHECK: return [[vall_gather]] : tensor<6x24xi8>
+  return %in2_sharded : tensor<6x48xi8>
+}
+
+// CHECK-LABEL: func.func @unsplit_last_axes_all(
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x48xi8>) -> tensor<48x48xi8> {
+func.func @unsplit_last_axes_all(%in2: tensor<48x48xi8>) -> tensor<48x48xi8> {
+  %sharding1 = shard.sharding @grid_4d split_axes = [[0,1,2]] : !shard.sharding
+  %in2_replicated = shard.shard %in2 to %sharding1 : tensor<48x48xi8>
+  %sharding2 = shard.sharding @grid_4d split_axes = [[]] : !shard.sharding
+  %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<48x48xi8>
+  // CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [0, 1, 2] gather_axis = 0 : tensor<2x48xi8> -> tensor<48x48xi8>
+  // CHECK: return [[vall_gather]] : tensor<48x48xi8>
+  return %in2_sharded : tensor<48x48xi8>
+}
 
 // CHECK-LABEL: func @move_split_axis
 func.func @move_split_axis(

Contributor

Copilot AI left a comment

Pull request overview

This PR extends the resharding pattern in the MLIR Shard dialect to support removing multiple trailing split axes along a tensor dimension, rather than just a single trailing axis. The change enables more flexible resharding operations by allowing patterns like [[0, 1, 2]] -> [[0]] or [[0, 1, 2]] -> [[]] in addition to the previously supported [[0, 1, 2]] -> [[0, 1]].

Changes:

  • Generalized the unsplit last axis resharding pattern to handle multiple trailing axes
  • Updated function and variable names to reflect the plural "axes" instead of singular "axis"
  • Added test cases for the new multi-axis unsplit functionality

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| mlir/lib/Dialect/Shard/Transforms/Partition.cpp | Refactored detection and handling logic to support unsplitting multiple trailing axes instead of just one |
| mlir/test/Dialect/Shard/partition.mlir | Added test cases demonstrating unsplitting some and all trailing axes |
| mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir | Added test for process multi-index reordering and updated existing test expectations |
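
The shapes in the new partition.mlir tests follow from multiplying the per-shard size by the grid extent of each unsplit axis. Below is a small standalone sketch of that arithmetic, assuming that `gatherDimension` multiplies static sizes and propagates dynamic ones. `allGatherResultShape` and `kDynamicDim` are hypothetical names used for illustration, not the MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel for a dynamic dimension (a stand-in for the value
// MLIR uses for dynamic sizes; the concrete value is an assumption here).
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// Computes the shape of the all-gather result: the gathered tensor dimension
// grows by the grid extent of every unsplit grid axis.
std::vector<int64_t>
allGatherResultShape(std::vector<int64_t> shardShape, std::size_t tensorDim,
                     const std::vector<int64_t> &gridShape,
                     const std::vector<int16_t> &unsplitAxes) {
  assert(tensorDim < shardShape.size());
  for (int16_t axis : unsplitAxes) {
    assert(axis >= 0 && static_cast<std::size_t>(axis) < gridShape.size());
    int64_t &dim = shardShape[tensorDim];
    const int64_t count = gridShape[static_cast<std::size_t>(axis)];
    // A dynamic size on either side makes the result dynamic.
    dim = (dim == kDynamicDim || count == kDynamicDim) ? kDynamicDim
                                                       : dim * count;
  }
  return shardShape;
}
```

With the 2x3x4x5 grid from the tests, gathering a 6x2 shard along tensor dimension 1 over grid axes [1, 2] gives 6x24 (2 * 3 * 4), and gathering a 2x48 shard along dimension 0 over axes [0, 1, 2] gives 48x48 (2 * 2 * 3 * 4), matching the CHECK lines.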

Comment thread mlir/lib/Dialect/Shard/Transforms/Partition.cpp Outdated
Comment thread mlir/lib/Dialect/Shard/Transforms/Partition.cpp Outdated
Contributor

@tkarna tkarna left a comment

Looks good to me.

As I understand it, an immediate use case for this is to gather sharded computation results on rank 0 for verification purposes.

Comment thread mlir/test/Dialect/Shard/partition.mlir Outdated
@fschlimb fschlimb merged commit f5e5745 into llvm:main Feb 11, 2026
10 checks passed
kevinwkt pushed a commit to kevinwkt/llvm-project that referenced this pull request Feb 16, 2026