[mlir][Interfaces] Add ExecutionProgressOpInterface + folding pattern#180348
[mlir][Interfaces] Add ExecutionProgressOpInterface + folding pattern#180348matthias-springer wants to merge 1 commit intomainfrom
ExecutionProgressOpInterface + folding pattern#180348Conversation
…rn (#179039) Add the `ExecutionProgressOpInterface` with an interface method to check if an operation "must progress". Add `mustProgress` attributes to `scf.for` and `scf.while` (default value is "true"). `mustProgress` corresponds to the [`llvm.loop.mustprogress` metadata](https://llvm.org/docs/LangRef.html#langref-llvm-loop-mustprogress). Also add a canonicalization pattern to erase `RegionBranchOpInterface` ops that must progress but loop infinitely (and are non-side-effecting). This canonicalization pattern is enabled for `scf.for` and `scf.while`. RFC: https://discourse.llvm.org/t/infinite-loops-and-dead-code/89530 [mlir] Fix build after #179039 (#179180) Fix build after #179039.
|
@llvm/pr-subscribers-mlir-ub @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd the
Also add a canonicalization pattern to erase Registered operations are assumed to "must progress" by default. RFC: https://discourse.llvm.org/t/infinite-loops-and-dead-code/89530 This PR is a re-upload of #179039. Patch is 31.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/180348.diff 19 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..de60ed99dd336 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -20,6 +20,7 @@
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a08cf3c95e6ce..b259e33f1d75f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -18,6 +18,7 @@ include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/ExecutionProgressOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -40,7 +41,7 @@ def SCF_Dialect : Dialect {
and then lowered to some final target like LLVM or SPIR-V.
}];
- let dependentDialects = ["arith::ArithDialect"];
+ let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
}
// Base class for SCF dialect ops.
@@ -161,6 +162,8 @@ def ForOp : SCF_Op<"for",
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
+ DeclareOpInterfaceMethods<ExecutionProgressOpInterface,
+ ["mustProgress"]>,
SingleBlockImplicitTerminator<"scf::YieldOp">,
RecursiveMemoryEffects]> {
let summary = "for operation";
@@ -265,7 +268,8 @@ def ForOp : SCF_Op<"for",
AnySignlessIntegerOrIndex:$upperBound,
AnySignlessIntegerOrIndex:$step,
Variadic<AnyType>:$initArgs,
- UnitAttr:$unsignedCmp);
+ UnitAttr:$unsignedCmp,
+ DefaultValuedAttr<BoolAttr, "true">:$mustProgress);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -986,6 +990,7 @@ def WhileOp : SCF_Op<"while",
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getRegionIterArgs", "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<ExecutionProgressOpInterface, ["mustProgress"]>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
@@ -1101,14 +1106,18 @@ def WhileOp : SCF_Op<"while",
```
}];
- let arguments = (ins Variadic<AnyType>:$inits);
+ let arguments = (ins Variadic<AnyType>:$inits,
+ DefaultValuedAttr<BoolAttr, "true">:$mustProgress);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
+ let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$inits,
"function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
- "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
+ "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>,
+ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$inits,
+ CArg<"bool", "true">:$mustProgress)>
];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..281bd3ed4e805 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
@@ -24,4 +25,13 @@
#include "mlir/Dialect/UB/IR/UBOpsDialect.h.inc"
+namespace mlir::ub {
+/// Populate a canonicalization pattern that erases "must progress" region
+/// branch ops that loop infinitely and replaces their results with poison
+/// values.
+void populateEraseInfiniteRegionBranchLoopPattern(RewritePatternSet &patterns,
+ StringRef opName,
+ PatternBenefit benefit = 1);
+} // namespace mlir::ub
+
#endif // MLIR_DIALECT_UB_IR_OPS_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index eb96a68861116..e0c75aee29c00 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_interface(CastInterfaces)
add_mlir_interface(ControlFlowInterfaces)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
+add_mlir_interface(ExecutionProgressOpInterface)
add_mlir_interface(FunctionInterfaces)
add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index a76dce6f2ffc5..33e139f6b0cea 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -314,6 +314,11 @@ Region *getEnclosingRepetitiveRegion(Operation *op);
/// exists.
Region *getEnclosingRepetitiveRegion(Value value);
+/// Return "true" if the given region branch op is guaranteed to loop
+/// infinitely. Every path starting from "parent" enters the region, but the
+/// "parent" is not reachable from there.
+bool isGuaranteedToLoopInfinitely(RegionBranchOpInterface op);
+
/// Populate canonicalization patterns that simplify successor operands/inputs
/// of region branch operations. Only operations with the given name are
/// matched.
@@ -359,6 +364,13 @@ void populateRegionBranchOpInterfaceInliningPattern(
PatternMatcherFn matcherFn = detail::defaultMatcherFn,
PatternBenefit benefit = 1);
+/// Return all successor regions when branching from the given region branch
+/// point. This helper functions extracts all constant operand values and
+/// passes them to the `RegionBranchOpInterface`.
+SmallVector<RegionSuccessor>
+getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
+ RegionBranchPoint point);
+
//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 8975b1235a7e3..1dacde297efa2 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -350,7 +350,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
"bool", "areTypesCompatible",
(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
/*defaultImplementation=*/[{ return lhs == rhs; }]
- >,
+ >
];
let verify = [{
diff --git a/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h
new file mode 100644
index 0000000000000..e395f909de092
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h
@@ -0,0 +1,29 @@
+//===- ExecutionProgressOpInterface.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
+#define MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h.inc"
+
+namespace mlir {
+/// Return "true" if the operation must progress.
+///
+/// Unregistered operations are treated conservatively: they may not
+/// necessarily progress (i.e., return "false"). Registered operations are
+/// assumed to progress by default. This can be overridden by the
+/// ExecutionProgressOpInterface.
+bool mustProgress(Operation *op);
+
+/// Return "true" if the operation might not progress.
+inline bool mightNotProgress(Operation *op) { return !mustProgress(op); }
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td
new file mode 100644
index 0000000000000..4b7923ce3612e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td
@@ -0,0 +1,48 @@
+//===- ExecutionProgressOpInterface.td - Interface Decl. -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the ExecutionProgressOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
+#define MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def ExecutionProgressOpInterface : OpInterface<"ExecutionProgressOpInterface"> {
+ let description = [{
+ This interface models execution progress properties of operations.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Operations that "must progress" are required to return normally (control
+ flow reaches the next operation) or interact with the environment in an
+ observable way (e.g., volatile memory access, I/O, synchronization or
+ program termination). If a "must progress" op executes indefinitely
+ without any observable interaction, it may be erased.
+
+ See LLVM "llvm.loop.mustprogress" / "mustprogress" function attribute
+ for more details.
+
+ Operations must progress by default.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"mustProgress",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return true;
+ }]
+ >
+ ];
+}
+
+#endif // MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index b111117410ba3..8c3b93b3c580b 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -13,11 +13,13 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRArithDialect
MLIRControlFlowDialect
MLIRDialectUtils
+ MLIRExecutionProgressOpInterface
MLIRFunctionInterfaces
MLIRIR
MLIRLoopLikeInterface
MLIRSideEffectInterfaces
MLIRTensorDialect
+ MLIRUBDialect
MLIRValueBoundsOpInterface
MLIRTransformUtils
)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c46a0577c4b96..0116620bdd3a3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -509,8 +510,10 @@ void ForOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/!getInitArgs().empty());
- p.printOptionalAttrDict((*this)->getAttrs(),
- /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
+ SmallVector<StringRef> elidedAttrs = {getUnsignedCmpAttrName().strref()};
+ if (getMustProgress()) // "true" is the default, elide attribute.
+ elidedAttrs.push_back(getMustProgressAttrName().strref());
+ p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
}
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -691,6 +694,24 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
}
}
+ // Infinite loops (lb < ub and step size 0) enter the loop body and never
+ // leave it.
+ std::optional<std::pair<APInt, bool>> lbCst =
+ getConstantAPIntValue(getLowerBound());
+ std::optional<std::pair<APInt, bool>> ubCst =
+ getConstantAPIntValue(getUpperBound());
+ std::optional<std::pair<APInt, bool>> stepCst =
+ getConstantAPIntValue(getStep());
+ if (lbCst.has_value() && ubCst.has_value() && stepCst.has_value()) {
+ bool atLeastOneIteration =
+ (getUnsignedCmp() && lbCst->first.ult(ubCst->first)) ||
+ (!getUnsignedCmp() && lbCst->first.slt(ubCst->first));
+ if (atLeastOneIteration && stepCst->first.isZero()) {
+ regions.push_back(RegionSuccessor(&getRegion()));
+ return;
+ }
+ }
+
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
// body.
@@ -703,6 +724,8 @@ ValueRange ForOp::getSuccessorInputs(RegionSuccessor successor) {
: ValueRange(getRegionIterArgs());
}
+bool ForOp::mustProgress() { return getMustProgress(); }
+
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
/// Promotes the loop body of a forallOp to its containing block if it can be
@@ -1004,6 +1027,8 @@ void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
return forOp.getLowerBound();
});
+ ub::populateEraseInfiniteRegionBranchLoopPattern(results,
+ ForOp::getOperationName());
}
std::optional<APInt> ForOp::getConstantStep() {
@@ -3210,6 +3235,16 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}
+void WhileOp::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState, TypeRange resultTypes,
+ ValueRange inits, bool mustProgress) {
+ odsState.addOperands(inits);
+ for (unsigned i = 0; i < 2; ++i)
+ (void)odsState.addRegion();
+ odsState.addTypes(resultTypes);
+ odsState.addAttribute("mustProgress", odsBuilder.getBoolAttr(mustProgress));
+}
+
ConditionOp WhileOp::getConditionOp() {
return cast<ConditionOp>(getBeforeBody()->getTerminator());
}
@@ -3273,6 +3308,8 @@ ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
llvm_unreachable("invalid region successor");
}
+bool WhileOp::mustProgress() { return getMustProgress(); }
+
SmallVector<Region *> WhileOp::getLoopRegions() {
return {&getBefore(), &getAfter()};
}
@@ -3332,7 +3369,10 @@ void scf::WhileOp::print(OpAsmPrinter &p) {
p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
p << " do ";
p.printRegion(getAfter());
- p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
+ SmallVector<StringRef> elidedAttrs;
+ if (getMustProgress()) // "true" is the default, elide attribute.
+ elidedAttrs.push_back(getMustProgressAttrName().strref());
+ p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
}
/// Verifies that two ranges of types match, i.e. have the same number of
@@ -3708,6 +3748,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results, WhileOp::getOperationName());
populateRegionBranchOpInterfaceInliningPattern(results,
WhileOp::getOperationName());
+ ub::populateEraseInfiniteRegionBranchLoopPattern(results,
+ WhileOp::getOperationName());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index ddcbda86cf1f3..152fb226993e9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -49,8 +49,9 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
SmallVector<Value> initArgs;
initArgs.push_back(forOp.getLowerBound());
llvm::append_range(initArgs, forOp.getInitArgs());
- auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs,
- forOp->getAttrs());
+ auto whileOp =
+ WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs);
+ whileOp->setAttrs(forOp->getAttrDictionary());
// 'before' region contains the loop condition and forwarding of iteration
// arguments to the 'after' region.
diff --git a/mlir/lib/Dialect/UB/IR/CMakeLists.txt b/mlir/lib/Dialect/UB/IR/CMakeLists.txt
index 84125ea0b5718..3baac5045b8db 100644
--- a/mlir/lib/Dialect/UB/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/UB/IR/CMakeLists.txt
@@ -5,9 +5,13 @@ add_mlir_dialect_library(MLIRUBDialect
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/UB
DEPENDS
+ MLIRControlFlowInterfaces
MLIRUBOpsIncGen
MLIRUBOpsInterfacesIncGen
LINK_LIBS PUBLIC
+ MLIRControlFlowInterfaces
MLIRIR
+ MLIRExecutionProgressOpInterface
+ MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..2310fc5af8cb8 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -8,6 +8,8 @@
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/Builders.h"
@@ -66,3 +68,40 @@ OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
#define GET_OP_CLASSES
#include "mlir/Dialect/UB/IR/UBOps.cpp.inc"
+
+namespace {
+/// Canonicalization pattern for RegionBranchOpInterface ops that loop
+/// infinitely. Such ops are replaced with poison values if they "must
+/// progress".
+struct EraseInfiniteRegionBranchLoop : public RewritePattern {
+ EraseInfiniteRegionBranchLoop(MLIRContext *context, StringRef name,
+ PatternBenefit benefit = 1)
+ : RewritePattern(name, benefit, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+ if (mightNotProgress(op))
+ return rewriter.notifyMatchFailure(
+ op, "only loops that must progress are removed");
+ if (!wouldOpBeTriviallyDead(op))
+ return rewriter.notifyMatchFailure(op,
+ "only trivially dead ops are removed");
+ if (!isGuaranteedToLoopInfinitely(regionBranchOp))
+ return rewriter.notifyMatchFailure(
+ op, "only loops that loop infinitely are removed");
+ SmallVector<Value> replacements =
+ llvm::map_to_vector(op->getResultTypes(), [&](Type type) {
+ return PoisonOp::create(rewriter, op->getLoc(), type).getResult();
+ });
+ rewriter.replaceOp(op, replacements);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::ub::populateEraseInfiniteRegionBranchLoopPattern(
+ RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) {
+ patterns.add<EraseInfiniteRegionBranchLoop>(patterns.getContext(), opName,
+ benefit);
+}
diff --git a/m...
[truncated]
|
zero9178
left a comment
There was a problem hiding this comment.
LGTM!
I am still on the side of mustProgress should default to false (even for registered ops), but the current state is an okay compromise to iterate on.
| Operations that "must progress" are required to return normally (control | ||
| flow reaches the next operation) or interact with the environment in an | ||
| observable way (e.g., volatile memory access, I/O, synchronization or | ||
| program termination). If a "must progress" op executes indefinitely | ||
| without any observable interaction, it may be erased. |
There was a problem hiding this comment.
Why do we need this part?
or interact with the environment in an observable way (e.g., volatile memory access, I/O, synchronization or program termination).
Why instead such operation that interacts with the environment wouldn't just return "false" here?
There was a problem hiding this comment.
Right now this means that mustProgress(op) does not say much about an operation with side effets.
Another problem is that we still are missing "volatile" effects in MLIR (unless I missed the addition).
joker-eph
left a comment
There was a problem hiding this comment.
We need a LangRef update.
Add the
ExecutionProgressOpInterfacewith an interface method to check if an operation "must progress". AddmustProgressattributes toscf.forandscf.while(default value is "true").mustProgresscorresponds to thellvm.loop.mustprogressmetadata.Also add a canonicalization pattern to erase
RegionBranchOpInterfaceops that must progress but loop infinitely (and are non-side-effecting). This canonicalization pattern is enabled forscf.forandscf.while.Registered operations are assumed to "must progress" by default.
RFC: https://discourse.llvm.org/t/infinite-loops-and-dead-code/89530
This PR is a re-upload of #179039.