diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp index 525ff007bc2f6..bc11d5d3fcf93 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp @@ -131,6 +131,66 @@ struct AllReduceAllSliceSimplification : OpRewritePattern { } }; +// Simplify AllSliceOp(AllGatherOp) -> input when both ops share the same grid, +// grid_axes and gather/slice axis. +// +// AllGather concatenates in-group slices along gather_axis and replicates the +// concatenated result. AllSlice on the same axis then takes each device-local +// in-group slice from that replicated tensor, i.e. exactly the original input. +struct AllGatherAllSliceSimplification : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllSliceOp sliceOp, + PatternRewriter &rewriter) const override { + auto gatherOp = sliceOp.getInput().getDefiningOp(); + if (!gatherOp) + return failure(); + + if (gatherOp.getGrid() != sliceOp.getGrid() || + gatherOp.getGridAxes() != sliceOp.getGridAxes()) + return failure(); + + if (gatherOp.getGatherAxis() != sliceOp.getSliceAxis()) + return failure(); + + if (gatherOp.getInput().getType() != sliceOp.getResult().getType()) + return failure(); + + rewriter.replaceOp(sliceOp, gatherOp.getInput()); + return success(); + } +}; + +// Simplify AllGatherOp(ReduceScatterOp) -> AllReduceOp when both ops share the +// same grid, grid_axes and gather/scatter axis. +// +// ReduceScatter computes an element-wise reduction and scatters along a tensor +// axis. AllGather along the same axis reassembles that full reduced tensor and +// replicates it to all participants, which is exactly AllReduce. +struct ReduceScatterAllGatherSimplification : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllGatherOp gatherOp, + PatternRewriter &rewriter) const override { + auto reduceScatterOp = gatherOp.getInput().getDefiningOp(); + if (!reduceScatterOp) + return failure(); + + if (reduceScatterOp.getGrid() != gatherOp.getGrid() || + reduceScatterOp.getGridAxes() != gatherOp.getGridAxes()) + return failure(); + + if (reduceScatterOp.getScatterDim() != gatherOp.getGatherAxis()) + return failure(); + + rewriter.replaceOpWithNewOp( + gatherOp, gatherOp.getResult().getType(), gatherOp.getGridAttr(), + gatherOp.getGridAxesAttr(), reduceScatterOp.getInput(), + reduceScatterOp.getReductionAttr()); + return success(); + } +}; + } // namespace void populateSimplifyPatterns(RewritePatternSet &patterns, @@ -154,7 +214,8 @@ void populateSimplifyPatterns(RewritePatternSet &patterns, populateAllReduceEndomorphismSimplifyPatterns( patterns, ReductionKind::Max); - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); // TODO: add simplify patterns for all-gather and other collectives. diff --git a/mlir/test/Dialect/Shard/simplify.mlir b/mlir/test/Dialect/Shard/simplify.mlir index e5693a288fda6..5f8b1b0fac83a 100644 --- a/mlir/test/Dialect/Shard/simplify.mlir +++ b/mlir/test/Dialect/Shard/simplify.mlir @@ -260,3 +260,150 @@ func.func @all_reduce_all_slice_type_promotion( // CHECK: return %[[RS]] return %1 : tensor<1x8xf64> } + +// ----- +// AllGatherOp + AllSliceOp -> input tests +// ----- + +// Basic inverse case: all_slice(all_gather(x)) with matching grid/axes/axis. +// CHECK-LABEL: func.func @all_gather_all_slice_to_input +func.func @all_gather_all_slice_to_input( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<1x8xf32> + %arg0: tensor<1x8xf32>) -> tensor<1x8xf32> { + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 + : tensor<1x8xf32> -> tensor<4x8xf32> + %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 + : tensor<4x8xf32> -> tensor<1x8xf32> + // CHECK-NOT: shard.all_gather + // CHECK-NOT: shard.all_slice + // CHECK: return %[[ARG0]] + return %1 : tensor<1x8xf32> +} + +// Do not fold if gather/slice grid axes differ. +// CHECK-LABEL: func.func @all_gather_all_slice_different_grid_axes +func.func @all_gather_all_slice_different_grid_axes( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<1x8xf32> + %arg0: tensor<1x8xf32>) -> tensor<2x8xf32> { + // CHECK: %[[AG:.*]] = shard.all_gather %[[ARG0]] on @grid0 grid_axes = [0] gather_axis = 0 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 + : tensor<1x8xf32> -> tensor<4x8xf32> + // CHECK: %[[AS:.*]] = shard.all_slice %[[AG]] on @grid0 grid_axes = [1] slice_axis = 0 + %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 0 + : tensor<4x8xf32> -> tensor<2x8xf32> + // CHECK: return %[[AS]] + return %1 : tensor<2x8xf32> +} + +// Do not fold if gather/slice grids differ. +// CHECK-LABEL: func.func @all_gather_all_slice_different_grid +func.func @all_gather_all_slice_different_grid( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<1x8xf32> + %arg0: tensor<1x8xf32>) -> tensor<1x8xf32> { + // CHECK: %[[AG:.*]] = shard.all_gather %[[ARG0]] on @grid0 grid_axes = [0] gather_axis = 0 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 + : tensor<1x8xf32> -> tensor<4x8xf32> + // CHECK: %[[AS:.*]] = shard.all_slice %[[AG]] on @grid1 grid_axes = [0] slice_axis = 0 + %1 = shard.all_slice %0 on @grid1 grid_axes = [0] slice_axis = 0 + : tensor<4x8xf32> -> tensor<1x8xf32> + // CHECK: return %[[AS]] + return %1 : tensor<1x8xf32> +} + +// Do not fold if gather/slice tensor axes differ. +// CHECK-LABEL: func.func @all_gather_all_slice_different_tensor_axes +func.func @all_gather_all_slice_different_tensor_axes( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<2x2xf32> + %arg0: tensor<2x2xf32>) -> tensor<4x1xf32> { + // CHECK: %[[AG:.*]] = shard.all_gather %[[ARG0]] on @grid0 grid_axes = [1] gather_axis = 0 + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 0 + : tensor<2x2xf32> -> tensor<4x2xf32> + // CHECK: %[[AS:.*]] = shard.all_slice %[[AG]] on @grid0 grid_axes = [1] slice_axis = 1 + %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 + : tensor<4x2xf32> -> tensor<4x1xf32> + // CHECK: return %[[AS]] + return %1 : tensor<4x1xf32> +} + +// ----- +// ReduceScatterOp + AllGatherOp -> AllReduceOp tests +// ----- + +// Basic case: all_gather(reduce_scatter(x)) with matching grid/axes/axis folds +// into all_reduce. +// CHECK-LABEL: func.func @reduce_scatter_all_gather_to_all_reduce +func.func @reduce_scatter_all_gather_to_all_reduce( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0 + : tensor<4x8xf32> -> tensor<1x8xf32> + %1 = shard.all_gather %0 on @grid0 grid_axes = [0] gather_axis = 0 + : tensor<1x8xf32> -> tensor<4x8xf32> + // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] + // CHECK-SAME: : tensor<4x8xf32> -> tensor<4x8xf32> + // CHECK: return %[[AR]] + return %1 : tensor<4x8xf32> +} + +// Verify reduction kind is preserved through the rewrite. +// CHECK-LABEL: func.func @reduce_scatter_all_gather_preserve_reduction +func.func @reduce_scatter_all_gather_preserve_reduction( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] reduction = max scatter_dim = 0 + : tensor<4x8xf32> -> tensor<1x8xf32> + %1 = shard.all_gather %0 on @grid0 grid_axes = [0] gather_axis = 0 + : tensor<1x8xf32> -> tensor<4x8xf32> + // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] reduction = max + // CHECK-SAME: : tensor<4x8xf32> -> tensor<4x8xf32> + // CHECK: return %[[AR]] + return %1 : tensor<4x8xf32> +} + +// Do not fold if reduce-scatter/all-gather grid axes differ. +// CHECK-LABEL: func.func @reduce_scatter_all_gather_different_grid_axes +func.func @reduce_scatter_all_gather_different_grid_axes( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<2x8xf32> { + // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0 + : tensor<4x8xf32> -> tensor<1x8xf32> + // CHECK: %[[AG:.*]] = shard.all_gather %[[RS]] on @grid0 grid_axes = [1] gather_axis = 0 + %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 0 + : tensor<1x8xf32> -> tensor<2x8xf32> + // CHECK: return %[[AG]] + return %1 : tensor<2x8xf32> +} + +// Do not fold if reduce-scatter/all-gather grids differ. +// CHECK-LABEL: func.func @reduce_scatter_all_gather_different_grid +func.func @reduce_scatter_all_gather_different_grid( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<2x8xf32> { + // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid1 grid_axes = [0] scatter_dim = 0 + %0 = shard.reduce_scatter %arg0 on @grid1 grid_axes = [0] scatter_dim = 0 + : tensor<4x8xf32> -> tensor<1x8xf32> + // CHECK: %[[AG:.*]] = shard.all_gather %[[RS]] on @grid0 grid_axes = [1] gather_axis = 0 + %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 0 + : tensor<1x8xf32> -> tensor<2x8xf32> + // CHECK: return %[[AG]] + return %1 : tensor<2x8xf32> +} + +// Do not fold if scatter/gather tensor axes differ. +// CHECK-LABEL: func.func @reduce_scatter_all_gather_different_tensor_axes +func.func @reduce_scatter_all_gather_different_tensor_axes( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<2x8xf32> { + // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [1] scatter_dim = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [1] scatter_dim = 0 + : tensor<4x8xf32> -> tensor<2x8xf32> + // CHECK: %[[AG:.*]] = shard.all_gather %[[RS]] on @grid0 grid_axes = [1] gather_axis = 1 + %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1 + : tensor<2x8xf32> -> tensor<2x16xf32> + // Keep function result type simple by slicing back. + %2 = tensor.extract_slice %1[0, 0] [2, 8] [1, 1] + : tensor<2x16xf32> to tensor<2x8xf32> + // CHECK: return %{{.*}} + return %2 : tensor<2x8xf32> +}