From 7ec937bd1e8a8ae56d481e43688efb6010688913 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 10 Feb 2026 07:25:14 -0800 Subject: [PATCH 1/3] allow more than one last axis to be "unsplit" --- .../Dialect/Shard/Transforms/Partition.cpp | 135 ++++++++++-------- .../ShardToMPI/convert-shard-to-mpi.mlir | 13 +- mlir/test/Dialect/Shard/partition.mlir | 24 ++++ 3 files changed, 113 insertions(+), 59 deletions(-) diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index e619c7073a8c4..8652d665e46bf 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -132,98 +132,117 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, return std::nullopt; } -// Detect if the resharding is of type e.g. -// [[0, 1, 2]] -> [[0, 1]]. -// If detected, returns the corresponding tensor axis grid axis pair. -static std::optional> -detectUnsplitLastAxisInResharding(const Sharding &sourceSharding, +// Detect if the resharding removes trailing split axes along a tensor +// dimension, e.g. +// [[0, 1, 2]] -> [[0, 1]], [[0, 1, 2]] -> [[0]] or [[0, 1, 2]] -> []. +// If detected, returns the corresponding (tensor dim, grid axes) pair, where +// the "grid axes" are the removed trailing split axes. 
+static std::optional>> +detectUnsplitLastAxesInResharding(const Sharding &sourceSharding, const Sharding &targetSharding) { - for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size(); - ++tensorAxis) { - if (targetSharding.getSplitAxes().size() > tensorAxis) { - if (sourceSharding.getSplitAxes()[tensorAxis].size() != - targetSharding.getSplitAxes()[tensorAxis].size() + 1) + for (size_t tensorDim = 0; tensorDim < sourceSharding.getSplitAxes().size(); + ++tensorDim) { + if (targetSharding.getSplitAxes().size() > tensorDim) { + // No match if the target sharding does not have less split axes than the + // source sharding along the current tensor dimension. + if (sourceSharding.getSplitAxes()[tensorDim].size() <= + targetSharding.getSplitAxes()[tensorDim].size()) continue; - if (!llvm::equal( - llvm::make_range( - sourceSharding.getSplitAxes()[tensorAxis] - .asArrayRef() - .begin(), - sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() - - 1), - targetSharding.getSplitAxes()[tensorAxis].asArrayRef())) + // No match if the split axes of the target sharding are different from + // the first split axes of the source sharding. + if (!std::equal( + targetSharding.getSplitAxes()[tensorDim].asArrayRef().begin(), + targetSharding.getSplitAxes()[tensorDim].asArrayRef().end(), + sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin())) continue; } else { - if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1) + // Here the target dimension is replicated; there is nothing to do if the + // source dimension is also replicated. + if (sourceSharding.getSplitAxes()[tensorDim].size() == 0) continue; } - return std::make_tuple( - tensorAxis, - sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back()); + // This is a match. Return the current tensor dimension and the trailing + // grid axis of the source sharding along this dimension. + SmallVector unsplitAxes; + size_t dimOff = tensorDim >= targetSharding.getSplitAxes().size() + ? 
0 + : targetSharding.getSplitAxes()[tensorDim].size(); + for (auto a = + sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin() + + dimOff; + a != sourceSharding.getSplitAxes()[tensorDim].asArrayRef().end(); ++a) + unsplitAxes.push_back(*a); + return std::make_tuple(tensorDim, unsplitAxes); } return std::nullopt; } -static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, +// Return the resulting Sharding if the unsplit last axes resharding is applied. +static Sharding targetShardingInUnsplitLastAxes(MLIRContext *ctx, const Sharding &sourceSharding, - int64_t splitTensorAxis) { - SmallVector targetShardingSplitAxes = + int64_t splitTensorDim, + size_t numUnsplitAxes) { + SmallVector resSplitAxes = llvm::to_vector(sourceSharding.getSplitAxes()); - assert(static_cast(targetShardingSplitAxes.size()) > - splitTensorAxis); - auto targetSplitAxes = - llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); - - targetSplitAxes.pop_back(); - targetShardingSplitAxes[splitTensorAxis] = - GridAxesAttr::get(ctx, targetSplitAxes); - return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes); + assert(static_cast(resSplitAxes.size()) > splitTensorDim); + ArrayRef srcSplitAxes = resSplitAxes[splitTensorDim].asArrayRef(); + assert(srcSplitAxes.size() >= numUnsplitAxes); + size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes; + SmallVector newSplitAxes(srcSplitAxes.begin(), + srcSplitAxes.begin() + numSplitAxes); + resSplitAxes[splitTensorDim] = GridAxesAttr::get(ctx, newSplitAxes); + return Sharding::get(sourceSharding.getGridAttr(), resSplitAxes); } -static ShapedType allGatherResultShapeInUnsplitLastAxis( - ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) { - SmallVector targetShape = llvm::to_vector(sourceShape.getShape()); - targetShape[splitTensorAxis] = - gatherDimension(targetShape[splitTensorAxis], splitCount); - return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); +// Return the 
resulting Tensor type after applying the unsplit last axes +// resharding. +static ShapedType allGatherResultTypeInUnsplitLastAxes( + ShapedType sourceType, int64_t splitTensorDim, ArrayRef gridShape, + ArrayRef unsplitAxes) { + SmallVector targetShape = llvm::to_vector(sourceType.getShape()); + for (GridAxis gridAxis : unsplitAxes) + targetShape[splitTensorDim] = + gatherDimension(targetShape[splitTensorDim], gridShape[gridAxis]); + return sourceType.cloneWith(targetShape, sourceType.getElementType()); } -static std::tuple, Sharding> unsplitLastAxisInResharding( +// Perform the resharding for the unsplit last axes case. +// This basically performs an all-gather along the unsplit grid axes. +static std::tuple, Sharding> unsplitLastAxesInResharding( ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard, - GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) { + GridOp grid, int64_t splitTensorDim, ArrayRef unsplitAxes) { MLIRContext *ctx = builder.getContext(); builder.setInsertionPointAfterValue(sourceShard); - Sharding targetSharding = targetShardingInUnsplitLastAxis( - ctx, std::move(sourceSharding), splitTensorAxis); - ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis( - sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis); + Sharding targetSharding = targetShardingInUnsplitLastAxes( + ctx, std::move(sourceSharding), splitTensorDim, unsplitAxes.size()); + ShapedType allGatherResultType = allGatherResultTypeInUnsplitLastAxes( + sourceShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes); Value allGatherResult = AllGatherOp::create( builder, - RankedTensorType::get(allGatherResultShape.getShape(), - allGatherResultShape.getElementType()), - grid.getSymName(), SmallVector({splitGridAxis}), sourceShard, - APInt(64, splitTensorAxis)); - ShapedType targetShape = + RankedTensorType::get(allGatherResultType.getShape(), + 
allGatherResultType.getElementType()), + grid.getSymName(), unsplitAxes, sourceShard, APInt(64, splitTensorDim)); + ShapedType targetType = shardShapedType(sourceUnshardedShape, grid, targetSharding); TypedValue targetShard = - tensor::CastOp::create(builder, targetShape, allGatherResult).getResult(); + tensor::CastOp::create(builder, targetType, allGatherResult).getResult(); return {targetShard, targetSharding}; } static std::optional, Sharding>> -tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, +tryUnsplitLastAxesInResharding(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue sourceShard) { - if (auto detectRes = detectUnsplitLastAxisInResharding( + if (auto detectRes = detectUnsplitLastAxesInResharding( sourceSharding, std::move(targetSharding))) { - auto [tensorAxis, gridAxis] = detectRes.value(); - return unsplitLastAxisInResharding(builder, sourceSharding, + auto [tensorDim, gridAxes] = detectRes.value(); + return unsplitLastAxesInResharding(builder, sourceSharding, sourceUnshardedShape, sourceShard, grid, - tensorAxis, gridAxis); + tensorDim, gridAxes); } return std::nullopt; @@ -477,7 +496,7 @@ reshard(ImplicitLocOpBuilder &builder, GridOp grid, trySplitLastAxisInResharding(builder, grid, sourceSharding, targetSharding, sourceShard)) { std::tie(targetShard, actualTargetSharding) = tryRes.value(); - } else if (auto tryRes = tryUnsplitLastAxisInResharding( + } else if (auto tryRes = tryUnsplitLastAxesInResharding( builder, grid, sourceSharding, targetSharding, sourceUnshardedValue.getType(), sourceShard)) { std::tie(targetShard, actualTargetSharding) = tryRes.value(); diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir index 6161c131c8f50..f3da09d05e3b8 100644 --- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir +++ 
b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir @@ -5,12 +5,23 @@ shard.grid @grid0(shape = 3x4x5) func.func @process_multi_index() -> (index, index, index) { // CHECK: mpi.comm_rank - // CHECK: [[res:%.*]]:3 = affine.delinearize_index %1 into (3, 4, 5) : index, index, index + // CHECK: [[v1:%.*]] = arith.index_cast + // CHECK: [[res:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index // CHECK: return [[res]]#0, [[res]]#1, [[res]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } +// CHECK-LABEL: func @process_multi_index_reorder +func.func @process_multi_index_reorder() -> (index, index) { + // CHECK: mpi.comm_rank + // CHECK: [[v1:%.*]] = arith.index_cast + // CHECK: [[v2:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index + %0:2 = shard.process_multi_index on @grid0 axes = [2, 0] : index, index + // CHECK: return [[v2]]#2, [[v2]]#0 : index, index + return %0#0, %0#1 : index, index +} + // CHECK-LABEL: func @process_linear_index func.func @process_linear_index() -> index { // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir index 4c8271aefcafc..d5db8073fcf2e 100644 --- a/mlir/test/Dialect/Shard/partition.mlir +++ b/mlir/test/Dialect/Shard/partition.mlir @@ -5,6 +5,7 @@ shard.grid @grid_1d(shape = 2) shard.grid @grid_1d_4(shape = 4) shard.grid @grid_2d_16(shape = 4x4) +shard.grid @grid_4d(shape = 2x3x4x5) // CHECK-LABEL: func @return_sharding func.func @return_sharding( @@ -52,6 +53,29 @@ func.func @sharding_triplet( return %sharded_1 : tensor<2xf32> } +// CHECK-LABEL: func.func @unsplit_last_axes_some( +// CHECK-SAME: [[varg0:%.*]]: tensor<6x2xi8>) -> tensor<6x24xi8> { +func.func @unsplit_last_axes_some( %in2: tensor<6x48xi8>) -> tensor<6x48xi8> { + %sharding1 = shard.sharding @grid_4d split_axes = [[], [0,1,2]] 
: !shard.sharding + %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x48xi8> + %sharding2 = shard.sharding @grid_4d split_axes = [[], [0]] : !shard.sharding + %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x48xi8> + // CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [1, 2] gather_axis = 1 : tensor<6x2xi8> -> tensor<6x24xi8> + // CHECK: return [[vall_gather]] : tensor<6x24xi8> + return %in2_sharded : tensor<6x48xi8> +} + +// CHECK-LABEL: func.func @unsplit_last_axes_all( +// CHECK-SAME: [[varg0:%.*]]: tensor<2x48xi8>) -> tensor<48x48xi8> { +func.func @unsplit_last_axes_all(%in2: tensor<48x48xi8>) -> tensor<48x48xi8> { + %sharding1 = shard.sharding @grid_4d split_axes = [[0,1,2]] : !shard.sharding + %in2_replicated = shard.shard %in2 to %sharding1 : tensor<48x48xi8> + %sharding2 = shard.sharding @grid_4d split_axes = [[]] : !shard.sharding + %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<48x48xi8> + // CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [0, 1, 2] gather_axis = 0 : tensor<2x48xi8> -> tensor<48x48xi8> + // CHECK: return [[vall_gather]] : tensor<48x48xi8> + return %in2_sharded : tensor<48x48xi8> +} // CHECK-LABEL: func @move_split_axis func.func @move_split_axis( From 3cc6a5394abdb3970e5f1f957c683ef0c03c95c5 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 10 Feb 2026 07:54:18 -0800 Subject: [PATCH 2/3] easier to read; idiomatic use of ArrayRef --- .../Dialect/Shard/Transforms/Partition.cpp | 37 ++++++++----------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 8652d665e46bf..a2b3c86cac28d 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -138,40 +138,35 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder 
&builder, GridOp grid, // If detected, returns the corresponding (tensor dim, grid axes) pair, where // the "grid axes" are the removed trailing split axes. static std::optional>> -detectUnsplitLastAxesInResharding(const Sharding &sourceSharding, - const Sharding &targetSharding) { - for (size_t tensorDim = 0; tensorDim < sourceSharding.getSplitAxes().size(); - ++tensorDim) { - if (targetSharding.getSplitAxes().size() > tensorDim) { +detectUnsplitLastAxesInResharding(const Sharding &srcSharding, + const Sharding &tgtSharding) { + size_t dimOff = 0; + size_t srcSize = srcSharding.getSplitAxes().size(); + for (size_t tensorDim = 0; tensorDim < srcSize; ++tensorDim) { + auto srcSplitAxes = srcSharding.getSplitAxes()[tensorDim].asArrayRef(); + if (tgtSharding.getSplitAxes().size() > tensorDim) { + auto tgtSplitAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef(); // No match if the target sharding does not have less split axes than the // source sharding along the current tensor dimension. - if (sourceSharding.getSplitAxes()[tensorDim].size() <= - targetSharding.getSplitAxes()[tensorDim].size()) + if (srcSplitAxes.size() <= tgtSplitAxes.size()) continue; // No match if the split axes of the target sharding are different from // the first split axes of the source sharding. - if (!std::equal( - targetSharding.getSplitAxes()[tensorDim].asArrayRef().begin(), - targetSharding.getSplitAxes()[tensorDim].asArrayRef().end(), - sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin())) + if (!std::equal(tgtSplitAxes.begin(), tgtSplitAxes.end(), + srcSplitAxes.begin())) continue; + dimOff = tgtSplitAxes.size(); } else { // Here the target dimension is replicated; there is nothing to do if the // source dimension is also replicated. - if (sourceSharding.getSplitAxes()[tensorDim].size() == 0) + if (srcSplitAxes.size() == 0) continue; + dimOff = 0; } // This is a match. 
Return the current tensor dimension and the trailing // grid axis of the source sharding along this dimension. - SmallVector unsplitAxes; - size_t dimOff = tensorDim >= targetSharding.getSplitAxes().size() - ? 0 - : targetSharding.getSplitAxes()[tensorDim].size(); - for (auto a = - sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin() + - dimOff; - a != sourceSharding.getSplitAxes()[tensorDim].asArrayRef().end(); ++a) - unsplitAxes.push_back(*a); + ArrayRef trailingAxes = srcSplitAxes.drop_front(dimOff); + SmallVector unsplitAxes(trailingAxes.begin(), trailingAxes.end()); return std::make_tuple(tensorDim, unsplitAxes); } return std::nullopt; From abd14a753f835797afa70635c8bd0fd105ad5b86 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Wed, 11 Feb 2026 07:10:49 -0800 Subject: [PATCH 3/3] fixing names --- mlir/test/Dialect/Shard/partition.mlir | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir index d5db8073fcf2e..c289c43cc3172 100644 --- a/mlir/test/Dialect/Shard/partition.mlir +++ b/mlir/test/Dialect/Shard/partition.mlir @@ -56,25 +56,25 @@ func.func @sharding_triplet( // CHECK-LABEL: func.func @unsplit_last_axes_some( // CHECK-SAME: [[varg0:%.*]]: tensor<6x2xi8>) -> tensor<6x24xi8> { func.func @unsplit_last_axes_some( %in2: tensor<6x48xi8>) -> tensor<6x48xi8> { - %sharding1 = shard.sharding @grid_4d split_axes = [[], [0,1,2]] : !shard.sharding - %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x48xi8> + %sharding0 = shard.sharding @grid_4d split_axes = [[], [0,1,2]] : !shard.sharding + %sharding1 = shard.shard %in2 to %sharding0 : tensor<6x48xi8> %sharding2 = shard.sharding @grid_4d split_axes = [[], [0]] : !shard.sharding - %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x48xi8> + %sharding3 = shard.shard %sharding1 to %sharding2 annotate_for_users : tensor<6x48xi8> // CHECK: 
[[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [1, 2] gather_axis = 1 : tensor<6x2xi8> -> tensor<6x24xi8> // CHECK: return [[vall_gather]] : tensor<6x24xi8> - return %in2_sharded : tensor<6x48xi8> + return %sharding3 : tensor<6x48xi8> } // CHECK-LABEL: func.func @unsplit_last_axes_all( // CHECK-SAME: [[varg0:%.*]]: tensor<2x48xi8>) -> tensor<48x48xi8> { func.func @unsplit_last_axes_all(%in2: tensor<48x48xi8>) -> tensor<48x48xi8> { - %sharding1 = shard.sharding @grid_4d split_axes = [[0,1,2]] : !shard.sharding - %in2_replicated = shard.shard %in2 to %sharding1 : tensor<48x48xi8> + %sharding0 = shard.sharding @grid_4d split_axes = [[0,1,2]] : !shard.sharding + %sharding1 = shard.shard %in2 to %sharding0 : tensor<48x48xi8> %sharding2 = shard.sharding @grid_4d split_axes = [[]] : !shard.sharding - %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<48x48xi8> + %sharding3 = shard.shard %sharding1 to %sharding2 annotate_for_users : tensor<48x48xi8> // CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [0, 1, 2] gather_axis = 0 : tensor<2x48xi8> -> tensor<48x48xi8> // CHECK: return [[vall_gather]] : tensor<48x48xi8> - return %in2_sharded : tensor<48x48xi8> + return %sharding3 : tensor<48x48xi8> } // CHECK-LABEL: func @move_split_axis