Skip to content
96 changes: 95 additions & 1 deletion mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DebugLog.h"
Expand Down Expand Up @@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
}
};

// Canonicalization pattern that drops scf.execute_region results which merely
// forward values defined outside of the region. Uses of such results are
// rewired to the forwarded values directly, and the op is rebuilt with only
// the results whose yielded values are defined inside the region. When the
// region contains several scf.yield terminators, the pattern applies only if
// every one of them yields exactly the same operands.
struct ExecuteRegionForwardingEliminator
    : public OpRewritePattern<ExecuteRegionOp> {
  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExecuteRegionOp op,
                                PatternRewriter &rewriter) const override {
    // An op without results has nothing that could be forwarded.
    if (op.getNumResults() == 0)
      return failure();

    // Gather the scf.yield terminator of every block that ends in one.
    SmallVector<Operation *> terminators;
    for (Block &blk : op.getRegion())
      if (auto yieldTerminator = dyn_cast<scf::YieldOp>(blk.getTerminator()))
        terminators.push_back(yieldTerminator.getOperation());
    if (terminators.empty())
      return failure();

    // Bail out unless every yield carries the same operands as the first one.
    auto commonOperands = terminators.front()->getOperands();
    if (llvm::any_of(terminators, [&](Operation *terminator) {
          return terminator->getOperands() != commonOperands;
        }))
      return failure();

    // Split the yielded values (and the op results they feed) into those
    // defined outside the region and those defined inside it.
    SmallVector<Value> forwardedValues;
    SmallVector<Value> forwardedResults;
    SmallVector<Value> regionValues;
    SmallVector<Value> retainedResults;
    for (unsigned idx = 0, numOperands = commonOperands.size();
         idx != numOperands; ++idx) {
      Value yielded = commonOperands[idx];
      if (!isValueFromInsideRegion(yielded, op)) {
        forwardedValues.push_back(yielded);
        forwardedResults.push_back(op.getResult(idx));
        continue;
      }
      regionValues.push_back(yielded);
      retainedResults.push_back(op.getResult(idx));
    }

    // Every result is produced inside the region - nothing to simplify.
    if (forwardedValues.empty())
      return failure();

    // Build the replacement execute_region that returns only the values
    // defined inside the region, carrying over the original attributes.
    SmallVector<Type> retainedTypes;
    retainedTypes.reserve(regionValues.size());
    for (Value regionValue : regionValues)
      retainedTypes.push_back(regionValue.getType());
    auto replacement = ExecuteRegionOp::create(rewriter, op.getLoc(),
                                               TypeRange(retainedTypes));
    replacement->setAttrs(op->getAttrs());

    // Transfer the body of the old op into the replacement.
    Region &targetRegion = replacement.getRegion();
    rewriter.inlineRegionBefore(op.getRegion(), targetRegion,
                                targetRegion.end());

    // Swap every collected yield for one that yields only the in-region
    // values. The yields themselves stay, possibly with no operands.
    for (Operation *terminator : terminators) {
      rewriter.setInsertionPoint(terminator);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(terminator,
                                                ValueRange(regionValues));
    }

    // Users of forwarded results now read the external values directly.
    rewriter.replaceAllUsesWith(forwardedResults, forwardedValues);
    // Users of the remaining results are redirected to the replacement op.
    rewriter.replaceAllUsesWith(retainedResults, replacement.getResults());
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Returns true when `value` is an op result or a block argument whose
  // parent region is the region of `executeRegionOp`; returns false for
  // anything else.
  bool isValueFromInsideRegion(Value value,
                               ExecuteRegionOp executeRegionOp) const {
    Region *owningRegion = nullptr;
    if (Operation *producer = value.getDefiningOp())
      owningRegion = producer->getParentRegion();
    else if (BlockArgument argument = dyn_cast<BlockArgument>(value))
      owningRegion = argument.getParentRegion();
    // A null `owningRegion` never equals the address of the op's region.
    return owningRegion == &executeRegionOp.getRegion();
  }
};

// Adds the canonicalization patterns of scf.execute_region to `results`:
// SingleBlockExecuteInliner, MultiBlockExecuteInliner and
// ExecuteRegionForwardingEliminator, all constructed with `context`.
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  // Each pattern is registered exactly once; a second `add` call for the two
  // inliner patterns would only insert redundant duplicates into the set.
  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
              ExecuteRegionForwardingEliminator>(context);
}

void ExecuteRegionOp::getSuccessorRegions(
Expand Down
142 changes: 142 additions & 0 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,148 @@ func.func @func_execute_region_inline_multi_yield() {

// -----

// Test case with a single scf.yield op inside execute_region whose operand is
// defined outside the execute_region op.
// The scf.execute_region is expected not to return anything afterwards.

// CHECK: scf.execute_region no_inline {
// CHECK: func.call @foo() : () -> ()
// CHECK: scf.yield
// CHECK: }

module {
func.func private @foo()->()
func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%1 = scf.execute_region -> memref<1x60xui8> no_inline {
func.call @foo():()->()
scf.yield %alloc: memref<1x60xui8>
}
return %1 : memref<1x60xui8>
}
}

// -----

// Test case with an scf.yield op inside execute_region that has multiple
// operands. One of the operands is defined outside the execute_region op.
// Only the result fed by this operand is expected to be removed from the op
// results.

// CHECK: %[[VAL_1:.*]] = scf.execute_region -> memref<1x120xui8> no_inline {
// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
// CHECK: func.call @foo() : () -> ()
// CHECK: scf.yield %[[VAL_2]] : memref<1x120xui8>
// CHECK: }
module {
func.func private @foo()->()
func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
func.call @foo():()->()
scf.yield %alloc, %alloc_1: memref<1x60xui8>, memref<1x120xui8>
}
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}

// -----

// Test case with multiple scf.yield ops inside execute_region that all have
// the same operands, and those operands are defined outside the
// execute_region op.
// The scf.execute_region is expected not to return anything afterwards.
// The scf.yield ops must remain, because scf.execute_region can't be empty.

// CHECK: scf.execute_region no_inline {
// CHECK: %[[VAL_3:.*]] = "test.cmp"() : () -> i1
// CHECK: cf.cond_br %[[VAL_3]], ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: scf.yield
// CHECK: ^bb2:
// CHECK: scf.yield
// CHECK: }

module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
^bb2:
func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
^bb3:
func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
}
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}

// -----

// Test case with multiple scf.yield ops that differ in at least one operand.
// In this case scf.execute_region is expected to stay unchanged.

// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
// CHECK: ^bb1:
// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
// CHECK: ^bb2:
// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
// CHECK: }

module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
^bb2:
func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
^bb3:
func.call @foo():()->()
scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
}
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}

// -----

// Test case with multiple scf.yield ops, each of which yields a different
// operand.
// In this case scf.execute_region is expected to stay unchanged.

// CHECK: %[[VAL_2:.*]] = scf.execute_region -> memref<1x60xui8> no_inline {
// CHECK: ^bb1:
// CHECK: scf.yield %{{.*}} : memref<1x60xui8>
// CHECK: ^bb2:
// CHECK: scf.yield %{{.*}} : memref<1x60xui8>
// CHECK: }

module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
^bb2:
func.call @foo():()->()
scf.yield %alloc : memref<1x60xui8>
^bb3:
func.call @foo():()->()
scf.yield %alloc_1 : memref<1x60xui8>
}
return %1 : memref<1x60xui8>
}
}

// -----

// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices(
// CHECK-SAME: %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32>
func.func @canonicalize_parallel_insert_slice_indices(
Expand Down