diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 9bd13f3236cfc..744a5951330a3 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -27,6 +27,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" #include "llvm/Support/DebugLog.h" @@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern { } }; +// Pattern to eliminate ExecuteRegionOp results which forward external +// values from the region. In case there are multiple yield operations, +// all of them must have the same operands in order for the pattern to be +// applicable. +struct ExecuteRegionForwardingEliminator + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (op.getNumResults() == 0) + return failure(); + + SmallVector yieldOps; + for (Block &block : op.getRegion()) { + if (auto yield = dyn_cast(block.getTerminator())) + yieldOps.push_back(yield.getOperation()); + } + + if (yieldOps.empty()) + return failure(); + + // Check if all yield operations have the same operands. + auto yieldOpsOperands = yieldOps[0]->getOperands(); + for (auto *yieldOp : yieldOps) { + if (yieldOp->getOperands() != yieldOpsOperands) + return failure(); + } + + SmallVector externalValues; + SmallVector internalValues; + SmallVector opResultsToReplaceWithExternalValues; + SmallVector opResultsToKeep; + for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) { + if (isValueFromInsideRegion(yieldedValue, op)) { + internalValues.push_back(yieldedValue); + opResultsToKeep.push_back(op.getResult(index)); + } else { + externalValues.push_back(yieldedValue); + opResultsToReplaceWithExternalValues.push_back(op.getResult(index)); + } + } + // No yielded external values - nothing to do. + if (externalValues.empty()) + return failure(); + + // There are yielded external values - create a new execute_region returning + // just the internal values. + SmallVector resultTypes; + for (Value value : internalValues) + resultTypes.push_back(value.getType()); + auto newOp = + ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes)); + newOp->setAttrs(op->getAttrs()); + + // Move old op's region to the new operation. + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Replace all yield operations with a new yield operation with updated + // results. scf.execute_region must have at least one yield operation. + for (auto *yieldOp : yieldOps) { + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp(yieldOp, + ValueRange(internalValues)); + } + + // Replace the old operation with the external values directly. + rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues, + externalValues); + // Replace the old operation's remaining results with the new operation's + // results. + rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults()); + rewriter.eraseOp(op); + return success(); + } + +private: + bool isValueFromInsideRegion(Value value, + ExecuteRegionOp executeRegionOp) const { + // Check if the value is defined within the execute_region + if (Operation *defOp = value.getDefiningOp()) + return &executeRegionOp.getRegion() == defOp->getParentRegion(); + + // If it's a block argument, check if it's from within the region + if (BlockArgument blockArg = dyn_cast(value)) + return &executeRegionOp.getRegion() == blockArg.getParentRegion(); + + return false; // Value is from outside the region + } +}; + void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } void ExecuteRegionOp::getSuccessorRegions( diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 2bec63672e783..084c3fc065de3 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1604,6 +1604,148 @@ func.func @func_execute_region_inline_multi_yield() { // ----- +// Test case with single scf.yield op inside execute_region and its operand is defined outside the execute_region op. +// Make scf.execute_region not to return anything. + +// CHECK: scf.execute_region no_inline { +// CHECK: func.call @foo() : () -> () +// CHECK: scf.yield +// CHECK: } + +module { +func.func private @foo()->() +func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> + %1 = scf.execute_region -> memref<1x60xui8> no_inline { + func.call @foo():()->() + scf.yield %alloc: memref<1x60xui8> + } + return %1 : memref<1x60xui8> +} +} + +// ----- + +// Test case with scf.yield op inside execute_region with multiple operands. +// One of operands is defined outside the execute_region op. +// Remove just this operand from the op results. + +// CHECK: %[[VAL_1:.*]] = scf.execute_region -> memref<1x120xui8> no_inline { +// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8> +// CHECK: func.call @foo() : () -> () +// CHECK: scf.yield %[[VAL_2]] : memref<1x120xui8> +// CHECK: } +module { +func.func private @foo()->() +func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> + %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8> + func.call @foo():()->() + scf.yield %alloc, %alloc_1: memref<1x60xui8>, memref<1x120xui8> + } + return %1, %2 : memref<1x60xui8>, memref<1x120xui8> +} +} + +// ----- + +// Test case with multiple scf.yield ops inside execute_region with same operands and those operands are defined outside the execute_region op.. +// Make scf.execute_region not to return anything. +// scf.yield must remain, cause scf.execute_region can't be empty. + +// CHECK: scf.execute_region no_inline { +// CHECK: %[[VAL_3:.*]] = "test.cmp"() : () -> i1 +// CHECK: cf.cond_br %[[VAL_3]], ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: scf.yield +// CHECK: ^bb2: +// CHECK: scf.yield +// CHECK: } + +module { + func.func private @foo()->() + func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8> + %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline { + %c = "test.cmp"() : () -> i1 + cf.cond_br %c, ^bb2, ^bb3 + ^bb2: + func.call @foo():()->() + scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8> + ^bb3: + func.call @foo():()->() + scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8> + } + return %1, %2 : memref<1x60xui8>, memref<1x120xui8> + } +} + +// ----- + +// Test case with multiple scf.yield ops with at least one different operand, then no change. + +// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline { +// CHECK: ^bb1: +// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8> +// CHECK: ^bb2: +// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8> +// CHECK: } + +module { + func.func private @foo()->() + func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8> + %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline { + %c = "test.cmp"() : () -> i1 + cf.cond_br %c, ^bb2, ^bb3 + ^bb2: + func.call @foo():()->() + scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8> + ^bb3: + func.call @foo():()->() + scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8> + } + return %1, %2 : memref<1x60xui8>, memref<1x120xui8> + } +} + +// ----- + +// Test case with multiple scf.yield ops each has different operand. +// In this case scf.execute_region isn't changed. + +// CHECK: %[[VAL_2:.*]] = scf.execute_region -> memref<1x60xui8> no_inline { +// CHECK: ^bb1: +// CHECK: scf.yield %{{.*}} : memref<1x60xui8> +// CHECK: ^bb2: +// CHECK: scf.yield %{{.*}} : memref<1x60xui8> +// CHECK: } + +module { +func.func private @foo()->() +func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> + %1 = scf.execute_region -> (memref<1x60xui8>) no_inline { + %c = "test.cmp"() : () -> i1 + cf.cond_br %c, ^bb2, ^bb3 + ^bb2: + func.call @foo():()->() + scf.yield %alloc : memref<1x60xui8> + ^bb3: + func.call @foo():()->() + scf.yield %alloc_1 : memref<1x60xui8> + } + return %1 : memref<1x60xui8> +} +} + +// ----- + // CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices( // CHECK-SAME: %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor func.func @canonicalize_parallel_insert_slice_indices(