NOTE(review): this text precedes the first "diff --git" header and is not part
of any hunk.
NOTE(review): line structure and the C++/TableGen template arguments were
restored from a whitespace-collapsed copy of this patch. Context-line
indentation is reconstructed; apply with whitespace tolerance if a hunk is
rejected.
NOTE(review): the tensor types in the canonicalize.mlir hunk are still bare
`tensor`; their shape/element-type parameters were missing from the copy and
must be filled in before the test can parse.
NOTE(review): doc comments were added in the SCF.cpp hunk and its hunk header
count was updated to match; the "index" blob hashes are carried over unchanged.

diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index b342b6a057a04..8c9a1e3ad1d80 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -544,6 +544,8 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
       "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 7e743109b6e09..a160ba028928c 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1229,6 +1229,48 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+namespace {
+/// Canonicalization pattern for `parallel_insert_slice`: if at least one
+/// operand is matched by `matchConstantIndex()`, the op is rebuilt from its
+/// mixed offsets/sizes/strides once `canonicalizeSubViewPart` has run on them.
+class ParallelInsertSliceOpConstantArgumentFolder final
+    : public OpRewritePattern<ParallelInsertSliceOp> {
+public:
+  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // No constant operand, just return.
+    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, matchConstantIndex());
+        }))
+      return failure();
+
+    // At least one of offsets/sizes/strides is a new constant.
+    // Form the new list of operands and constant attributes from the
+    // existing.
+    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
+    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
+    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
+    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
+
+    // Create the new op in canonical form.
+    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
+        insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
+        mixedOffsets, mixedSizes, mixedStrides);
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the constant-argument folder to the op's canonicalization patterns.
+void ParallelInsertSliceOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // PerformConcurrentlyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8e087fc0f38a4..ad5afa9c36015 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1458,3 +1458,28 @@ func.func @func_execute_region_elim_multi_yield() {
 // CHECK: ^[[bb3]](%[[z:.+]]: i64):
 // CHECK: "test.bar"(%[[z]])
 // CHECK: return
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
+//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor,
+//  CHECK-SAME:     %[[arg1:[0-9a-z]*]]: tensor,
+//  CHECK-SAME:     %[[num_threads:[0-9a-z]*]]: index
+func.func @canonicalize_parallel_insert_slice_indices(
+    %arg0 : tensor, %arg1: tensor,
+    %num_threads : index) -> tensor
+{
+  %cst = arith.constant 4.200000e+01 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  //      CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor) {
+  // CHECK-NEXT:   scf.foreach_thread.perform_concurrently {
+  // CHECK-NEXT:     scf.foreach_thread.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
+  %2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor) {
+    scf.foreach_thread.perform_concurrently {
+      scf.foreach_thread.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor into tensor
+    }
+  }
+  return %2 : tensor
+}