diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 9c5880e0c3b64..57502fbf9e276 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -260,6 +260,17 @@ class UnsplitLastAxesPattern : public ReshardingPattern { } }; +// Compute the result shape of an all-to-all that gathers along srcTensorDim +// and scatters along tgtTensorDim with the given split count. +static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount, + int64_t srcTensorDim, + int64_t tgtTensorDim) { + SmallVector tgtShape = llvm::to_vector(srcShape.getShape()); + tgtShape[srcTensorDim] = gatherDimension(tgtShape[srcTensorDim], splitCount); + tgtShape[tgtTensorDim] = shardDimension(tgtShape[tgtTensorDim], splitCount); + return srcShape.cloneWith(tgtShape, srcShape.getElementType()); +} + /// Move a split axis between tensor dimensions: /// e.g. [[0], []] -> [[], [0]]. class MoveSplitAxisPattern : public ReshardingPattern { @@ -310,16 +321,6 @@ class MoveSplitAxisPattern : public ReshardingPattern { return Sharding::get(srcSharding.getGridAttr(), tgtShardingSplitAxes); } - static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount, - int64_t srcTensorDim, - int64_t tgtTensorDim) { - SmallVector tgtShape = llvm::to_vector(srcShape.getShape()); - tgtShape[srcTensorDim] = - gatherDimension(tgtShape[srcTensorDim], splitCount); - tgtShape[tgtTensorDim] = shardDimension(tgtShape[tgtTensorDim], splitCount); - return srcShape.cloneWith(tgtShape, srcShape.getElementType()); - } - static std::tuple, Sharding> apply(ImplicitLocOpBuilder &builder, GridOp grid, Sharding srcSharding, ShapedType srcUnshardedType, TypedValue srcShard, @@ -362,6 +363,130 @@ class MoveSplitAxisPattern : public ReshardingPattern { } }; +/// Move the last split axis of one tensor dimension to the front of another +/// tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]]. +class MoveLastSplitAxisPattern : public ReshardingPattern { + // Detect if the resharding moves the last grid axis of srcTensorDim to the + // front of another tensor dimension's split axes. If detected, returns + // (tgtTensorDim, movedGridAxis). + // + // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an] (n >= 2) + // tgt[srcTensorDim] = [a1,...,a(n-1)] + // src[tgtTensorDim] = [b1,...,bm] (m >= 0) + // tgt[tgtTensorDim] = [an, b1,...,bm] + static std::optional> + detect(const Sharding &srcSharding, const Sharding &tgtSharding, + int64_t srcTensorDim) { + if (static_cast(srcTensorDim) >= srcSharding.getSplitAxes().size()) + return std::nullopt; + auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef(); + // Need at least 2 axes to move the last one. + if (srcAxes.size() < 2) + return std::nullopt; + + // After the move the source tensor dim should lose its last axis. + if (static_cast(srcTensorDim) >= tgtSharding.getSplitAxes().size()) + return std::nullopt; + auto tgtSrcAxes = tgtSharding.getSplitAxes()[srcTensorDim].asArrayRef(); + if (tgtSrcAxes.size() + 1 != srcAxes.size()) + return std::nullopt; + // The remaining axes at srcTensorDim must be the same (prefix of source). + if (!llvm::equal(tgtSrcAxes, + llvm::make_range(srcAxes.begin(), srcAxes.end() - 1))) + return std::nullopt; + + GridAxis movedAxis = srcAxes.back(); + + // Find a target tensor dimension whose split axes start with movedAxis + // and whose remaining axes match the source sharding at that dimension. + for (size_t tgtTensorDim = 0; + tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) { + if (static_cast(tgtTensorDim) == srcTensorDim) + continue; + auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef(); + // The target dimension must start with the moved axis. + if (tgtAxes.empty() || tgtAxes.front() != movedAxis) + continue; + // The remainder of tgtAxes must equal the source sharding at + // tgtTensorDim. + ArrayRef srcTgtAxes = + static_cast(tgtTensorDim) < srcSharding.getSplitAxes().size() + ? srcSharding.getSplitAxes()[tgtTensorDim].asArrayRef() + : ArrayRef{}; + if (!llvm::equal(srcTgtAxes, + llvm::make_range(tgtAxes.begin() + 1, tgtAxes.end()))) + continue; + return std::make_tuple(static_cast(tgtTensorDim), movedAxis); + } + return std::nullopt; + } + + // Compute the result sharding after moving movedAxis from srcTensorDim + // to the front of tgtTensorDim. + static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding, + int64_t srcTensorDim, int64_t tgtTensorDim, + GridAxis movedAxis) { + SmallVector splitAxes = + llvm::to_vector(srcSharding.getSplitAxes()); + while (static_cast(splitAxes.size()) <= tgtTensorDim) + splitAxes.push_back(GridAxesAttr::get(ctx, {})); + + // Remove last axis from srcTensorDim. + auto srcSplitAxes = llvm::to_vector(splitAxes[srcTensorDim].asArrayRef()); + assert(!srcSplitAxes.empty() && srcSplitAxes.back() == movedAxis); + srcSplitAxes.pop_back(); + splitAxes[srcTensorDim] = GridAxesAttr::get(ctx, srcSplitAxes); + + // Prepend movedAxis to tgtTensorDim. + auto tgtSplitAxes = llvm::to_vector(splitAxes[tgtTensorDim].asArrayRef()); + tgtSplitAxes.insert(tgtSplitAxes.begin(), movedAxis); + splitAxes[tgtTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes); + + return Sharding::get(srcSharding.getGridAttr(), splitAxes); + } + + static std::tuple, Sharding> + apply(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &srcSharding, + ShapedType srcUnshardedType, TypedValue srcShard, + int64_t srcTensorDim, int64_t tgtTensorDim, GridAxis movedAxis) { + MLIRContext *ctx = builder.getContext(); + builder.setInsertionPointAfterValue(srcShard); + + Sharding resultSharding = + tgtSharding(ctx, srcSharding, srcTensorDim, tgtTensorDim, movedAxis); + ShapedType a2aResultShape = + allToAllResultShape(srcShard.getType(), grid.getShape()[movedAxis], + srcTensorDim, tgtTensorDim); + Value allToAllResult = AllToAllOp::create( + builder, + RankedTensorType::get(a2aResultShape.getShape(), + a2aResultShape.getElementType()), + grid.getSymName(), SmallVector({movedAxis}), srcShard, + APInt(64, tgtTensorDim), APInt(64, srcTensorDim)); + ShapedType tgtShape = + shardShapedType(srcUnshardedType, grid, resultSharding); + TypedValue tgtShard = + tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult(); + return {tgtShard, resultSharding}; + } + +public: + std::optional, Sharding>> + tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, + const Sharding &srcSharding, const Sharding &tgtSharding, + ShapedType srcUnshardedType, + TypedValue srcShard) override { + if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding)) + return std::nullopt; + if (auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) { + auto [tgtTensorDim, movedAxis] = detectRes.value(); + return apply(builder, grid, srcSharding, srcUnshardedType, srcShard, + tensorDim, tgtTensorDim, movedAxis); + } + return std::nullopt; + } +}; + /// Update halo sizes: handles cases where only the halo sizes differ between /// source and target sharding. Requires copying the "core" of the source tensor /// into the "core" of the destination tensor followed by an update halo op. @@ -460,12 +585,13 @@ static TypedValue reshard(ImplicitLocOpBuilder &builder, // Each pattern's tryApply checks its own applicability preconditions. static UpdateHaloPattern updateHaloPattern; + static MoveLastSplitAxisPattern moveLastSplitAxisPattern; static MoveSplitAxisPattern moveSplitAxisPattern; static SplitLastAxisPattern splitLastAxisPattern; static UnsplitLastAxesPattern unsplitLastAxesPattern; static ReshardingPattern *patterns[] = { - &updateHaloPattern, &moveSplitAxisPattern, &splitLastAxisPattern, - &unsplitLastAxesPattern}; + &updateHaloPattern, &moveLastSplitAxisPattern, &moveSplitAxisPattern, + &splitLastAxisPattern, &unsplitLastAxesPattern}; TypedValue currentShard = shardedSrc; Sharding currentSharding = srcSharding; for (int64_t dim = 0; diff --git a/mlir/test/Dialect/Shard/resharding-partition.mlir b/mlir/test/Dialect/Shard/resharding-partition.mlir index ff9e8408aa7fd..01c4733485678 100644 --- a/mlir/test/Dialect/Shard/resharding-partition.mlir +++ b/mlir/test/Dialect/Shard/resharding-partition.mlir @@ -2,6 +2,7 @@ shard.grid @grid_1d(shape = 2) shard.grid @grid_1d_dynamic(shape = ?) +shard.grid @grid_3d(shape = 2x2x2) // CHECK-LABEL: func @same_source_and_target_sharding func.func @same_source_and_target_sharding( @@ -153,7 +154,7 @@ func.func @unshard_dynamic_axis( // CHECK-LABEL: func @unshard_static_axis_on_dynamic_grid_axis func.func @unshard_static_axis_on_dynamic_grid_axis( -// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> +// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> %arg0: tensor<10x14xf32> ) -> tensor<10x14xf32> { // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor @@ -166,3 +167,59 @@ func.func @unshard_static_axis_on_dynamic_grid_axis( // CHECK: return %[[RES]] : tensor<10x14xf32> return %1 : tensor<10x14xf32> } + +// MoveLastSplitAxisPattern: [[0, 1], [2]] -> [[0], [1, 2]] +// Source shard: 8/(2*2) x 16/2 = 2x8; after all_to_all(axis=1): 4x4 +// CHECK-LABEL: func @move_last_split_axis_to_front_of_target_dim +func.func @move_last_split_axis_to_front_of_target_dim( + // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> + %arg0: tensor<8x16xf32> +) -> tensor<8x16xf32> { + // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<2x8xf32> + // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [1] split_axis = 1 concat_axis = 0 : tensor<2x8xf32> -> tensor<4x4xf32> + // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<4x4xf32> to tensor<8x16xf32> + %s0 = shard.sharding @grid_3d split_axes = [[0, 1], [2]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32> + %s1 = shard.sharding @grid_3d split_axes = [[0], [1, 2]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32> + // CHECK: return %[[RES]] : tensor<8x16xf32> + return %1 : tensor<8x16xf32> +} + +// MoveLastSplitAxisPattern with tgtTensorDim < srcTensorDim: +// [[0], [1, 2]] -> [[2, 0], [1]] (axis 2 moved from dim 1 to front of dim 0) +// Source shard: 8/2 x 16/(2*2) = 4x4; after all_to_all(axis=2): 2x8 +// CHECK-LABEL: func @move_last_split_axis_to_lower_dim +func.func @move_last_split_axis_to_lower_dim( + // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> + %arg0: tensor<8x16xf32> +) -> tensor<8x16xf32> { + // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<4x4xf32> + // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [2] split_axis = 0 concat_axis = 1 : tensor<4x4xf32> -> tensor<2x8xf32> + // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<2x8xf32> to tensor<8x16xf32> + %s0 = shard.sharding @grid_3d split_axes = [[0], [1, 2]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32> + %s1 = shard.sharding @grid_3d split_axes = [[2, 0], [1]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32> + // CHECK: return %[[RES]] : tensor<8x16xf32> + return %1 : tensor<8x16xf32> +} + +// MoveLastSplitAxisPattern where source has no axes at tgtTensorDim: +// [[0, 1]] -> [[0], [1]] (tgtTensorDim has empty source) +// Source shard: 8/(2*2) x 16 = 2x16; after all_to_all(axis=1): 4x8 +// CHECK-LABEL: func @move_last_split_axis_empty_source_at_target_dim +func.func @move_last_split_axis_empty_source_at_target_dim( + // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> + %arg0: tensor<8x16xf32> +) -> tensor<8x16xf32> { + // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<2x16xf32> + // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [1] split_axis = 1 concat_axis = 0 : tensor<2x16xf32> -> tensor<4x8xf32> + // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<4x8xf32> to tensor<8x16xf32> + %s0 = shard.sharding @grid_3d split_axes = [[0, 1]] : !shard.sharding + %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32> + %s1 = shard.sharding @grid_3d split_axes = [[0], [1]] : !shard.sharding + %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32> + // CHECK: return %[[RES]] : tensor<8x16xf32> + return %1 : tensor<8x16xf32> +}