Skip to content

[mlir][Interfaces] Add ExecutionProgressOpInterface + folding pattern#180348

Open
matthias-springer wants to merge 1 commit intomainfrom
users/matthias-springer/must_progress_2
Open

[mlir][Interfaces] Add ExecutionProgressOpInterface + folding pattern#180348
matthias-springer wants to merge 1 commit intomainfrom
users/matthias-springer/must_progress_2

Conversation

@matthias-springer
Copy link
Member

Add the ExecutionProgressOpInterface with an interface method to check if an operation "must progress". Add mustProgress attributes to scf.for and scf.while (default value is "true").

mustProgress corresponds to the llvm.loop.mustprogress metadata.

Also add a canonicalization pattern to erase RegionBranchOpInterface ops that must progress but loop infinitely (and are non-side-effecting). This canonicalization pattern is enabled for scf.for and scf.while.

Registered operations are assumed to "must progress" by default.

RFC: https://discourse.llvm.org/t/infinite-loops-and-dead-code/89530

This PR is a re-upload of #179039.

…rn (#179039)

Add the `ExecutionProgressOpInterface` with an interface method to check
if an operation "must progress". Add `mustProgress` attributes to
`scf.for` and `scf.while` (default value is "true").

`mustProgress` corresponds to the [`llvm.loop.mustprogress`
metadata](https://llvm.org/docs/LangRef.html#langref-llvm-loop-mustprogress).

Also add a canonicalization pattern to erase `RegionBranchOpInterface`
ops that must progress but loop infinitely (and are non-side-effecting).
This canonicalization pattern is enabled for `scf.for` and `scf.while`.

RFC: https://discourse.llvm.org/t/infinite-loops-and-dead-code/89530

[mlir] Fix build after #179039 (#179180)

Fix build after #179039.
@llvmbot
Copy link
Member

llvmbot commented Feb 7, 2026

@llvm/pr-subscribers-mlir-ub

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add the ExecutionProgressOpInterface with an interface method to check if an operation "must progress". Add mustProgress attributes to scf.for and scf.while (default value is "true").

mustProgress corresponds to the llvm.loop.mustprogress metadata.

Also add a canonicalization pattern to erase RegionBranchOpInterface ops that must progress but loop infinitely (and are non-side-effecting). This canonicalization pattern is enabled for scf.for and scf.while.

Registered operations are assumed to "must progress" by default.

RFC: https://discourse.llvm.org/t/infinite-loops-and-dead-code/89530

This PR is a re-upload of #179039.


Patch is 31.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/180348.diff

19 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCF.h (+1)
  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+13-4)
  • (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.h (+10)
  • (modified) mlir/include/mlir/Interfaces/CMakeLists.txt (+1)
  • (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+12)
  • (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+1-1)
  • (added) mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h (+29)
  • (added) mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td (+48)
  • (modified) mlir/lib/Dialect/SCF/IR/CMakeLists.txt (+2)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+45-3)
  • (modified) mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp (+3-2)
  • (modified) mlir/lib/Dialect/UB/IR/CMakeLists.txt (+4)
  • (modified) mlir/lib/Dialect/UB/IR/UBOps.cpp (+39)
  • (modified) mlir/lib/Interfaces/CMakeLists.txt (+2)
  • (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+73-30)
  • (added) mlir/lib/Interfaces/ExecutionProgressOpInterface.cpp (+29)
  • (modified) mlir/test/CAPI/ir.c (+1-1)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+51)
  • (modified) mlir/test/Dialect/SCF/ops.mlir (+4-3)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..de60ed99dd336 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/RegionKindInterface.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a08cf3c95e6ce..b259e33f1d75f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -18,6 +18,7 @@ include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/RegionKindInterface.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/ExecutionProgressOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -40,7 +41,7 @@ def SCF_Dialect : Dialect {
     and then lowered to some final target like LLVM or SPIR-V.
   }];
 
-  let dependentDialects = ["arith::ArithDialect"];
+  let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
 }
 
 // Base class for SCF dialect ops.
@@ -161,6 +162,8 @@ def ForOp : SCF_Op<"for",
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
+       DeclareOpInterfaceMethods<ExecutionProgressOpInterface,
+        ["mustProgress"]>,
        SingleBlockImplicitTerminator<"scf::YieldOp">,
        RecursiveMemoryEffects]> {
   let summary = "for operation";
@@ -265,7 +268,8 @@ def ForOp : SCF_Op<"for",
                        AnySignlessIntegerOrIndex:$upperBound,
                        AnySignlessIntegerOrIndex:$step,
                        Variadic<AnyType>:$initArgs,
-                       UnitAttr:$unsignedCmp);
+                       UnitAttr:$unsignedCmp,
+                       DefaultValuedAttr<BoolAttr, "true">:$mustProgress);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
@@ -986,6 +990,7 @@ def WhileOp : SCF_Op<"while",
         ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
         ["getRegionIterArgs", "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<ExecutionProgressOpInterface, ["mustProgress"]>,
      RecursiveMemoryEffects, SingleBlock]> {
   let summary = "a generic 'while' loop";
   let description = [{
@@ -1101,14 +1106,18 @@ def WhileOp : SCF_Op<"while",
     ```
   }];
 
-  let arguments = (ins Variadic<AnyType>:$inits);
+  let arguments = (ins Variadic<AnyType>:$inits,
+                       DefaultValuedAttr<BoolAttr, "true">:$mustProgress);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
 
+  let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$inits,
       "function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
-      "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$inits,
+                   CArg<"bool", "true">:$mustProgress)>
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..281bd3ed4e805 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
@@ -24,4 +25,13 @@
 
 #include "mlir/Dialect/UB/IR/UBOpsDialect.h.inc"
 
+namespace mlir::ub {
+/// Populate a canonicalization pattern that erases "must progress" region
+/// branch ops that loop infinitely and replaces their results with poison
+/// values.
+void populateEraseInfiniteRegionBranchLoopPattern(RewritePatternSet &patterns,
+                                                  StringRef opName,
+                                                  PatternBenefit benefit = 1);
+} // namespace mlir::ub
+
 #endif // MLIR_DIALECT_UB_IR_OPS_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index eb96a68861116..e0c75aee29c00 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_interface(CastInterfaces)
 add_mlir_interface(ControlFlowInterfaces)
 add_mlir_interface(DerivedAttributeOpInterface)
 add_mlir_interface(DestinationStyleOpInterface)
+add_mlir_interface(ExecutionProgressOpInterface)
 add_mlir_interface(FunctionInterfaces)
 add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index a76dce6f2ffc5..33e139f6b0cea 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -314,6 +314,11 @@ Region *getEnclosingRepetitiveRegion(Operation *op);
 /// exists.
 Region *getEnclosingRepetitiveRegion(Value value);
 
+/// Return "true" if the given region branch op is guaranteed to loop
+/// infinitely. Every path starting from "parent" enters the region, but the
+/// "parent" is not reachable from there.
+bool isGuaranteedToLoopInfinitely(RegionBranchOpInterface op);
+
 /// Populate canonicalization patterns that simplify successor operands/inputs
 /// of region branch operations. Only operations with the given name are
 /// matched.
@@ -359,6 +364,13 @@ void populateRegionBranchOpInterfaceInliningPattern(
     PatternMatcherFn matcherFn = detail::defaultMatcherFn,
     PatternBenefit benefit = 1);
 
+/// Return all successor regions when branching from the given region branch
+/// point. This helper functions extracts all constant operand values and
+/// passes them to the `RegionBranchOpInterface`.
+SmallVector<RegionSuccessor>
+getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
+                             RegionBranchPoint point);
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 8975b1235a7e3..1dacde297efa2 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -350,7 +350,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       "bool", "areTypesCompatible",
       (ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
       /*defaultImplementation=*/[{ return lhs == rhs; }]
-    >,
+    >
   ];
 
   let verify = [{
diff --git a/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h
new file mode 100644
index 0000000000000..e395f909de092
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h
@@ -0,0 +1,29 @@
+//===- ExecutionProgressOpInterface.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
+#define MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h.inc"
+
+namespace mlir {
+/// Return "true" if the operation must progress.
+///
+/// Unregistered operations are treated conservatively: they may not
+/// necessarily progress (i.e., return "false"). Registered operations are
+/// assumed to progress by default. This can be overridden by the
+/// ExecutionProgressOpInterface.
+bool mustProgress(Operation *op);
+
+/// Return "true" if the operation might not progress.
+inline bool mightNotProgress(Operation *op) { return !mustProgress(op); }
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td
new file mode 100644
index 0000000000000..4b7923ce3612e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td
@@ -0,0 +1,48 @@
+//===- ExecutionProgressOpInterface.td - Interface Decl. -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the ExecutionProgressOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
+#define MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def ExecutionProgressOpInterface : OpInterface<"ExecutionProgressOpInterface"> {
+  let description = [{
+    This interface models execution progress properties of operations.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Operations that "must progress" are required to return normally (control
+        flow reaches the next operation) or interact with the environment in an
+        observable way (e.g., volatile memory access, I/O, synchronization or
+        program termination). If a "must progress" op executes indefinitely
+        without any observable interaction, it may be erased.
+        
+        See LLVM "llvm.loop.mustprogress" / "mustprogress" function attribute
+        for more details.
+
+        Operations must progress by default.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"mustProgress",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return true;
+      }]
+    >
+  ];
+}
+
+#endif // MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index b111117410ba3..8c3b93b3c580b 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -13,11 +13,13 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRArithDialect
   MLIRControlFlowDialect
   MLIRDialectUtils
+  MLIRExecutionProgressOpInterface
   MLIRFunctionInterfaces
   MLIRIR
   MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
   MLIRTensorDialect
+  MLIRUBDialect
   MLIRValueBoundsOpInterface
   MLIRTransformUtils
   )
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c46a0577c4b96..0116620bdd3a3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -509,8 +510,10 @@ void ForOp::print(OpAsmPrinter &p) {
   p.printRegion(getRegion(),
                 /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/!getInitArgs().empty());
-  p.printOptionalAttrDict((*this)->getAttrs(),
-                          /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
+  SmallVector<StringRef> elidedAttrs = {getUnsignedCmpAttrName().strref()};
+  if (getMustProgress()) // "true" is the default, elide attribute.
+    elidedAttrs.push_back(getMustProgressAttrName().strref());
+  p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
 }
 
 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -691,6 +694,24 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
     }
   }
 
+  // Infinite loops (lb < ub and step size 0) enter the loop body and never
+  // leave it.
+  std::optional<std::pair<APInt, bool>> lbCst =
+      getConstantAPIntValue(getLowerBound());
+  std::optional<std::pair<APInt, bool>> ubCst =
+      getConstantAPIntValue(getUpperBound());
+  std::optional<std::pair<APInt, bool>> stepCst =
+      getConstantAPIntValue(getStep());
+  if (lbCst.has_value() && ubCst.has_value() && stepCst.has_value()) {
+    bool atLeastOneIteration =
+        (getUnsignedCmp() && lbCst->first.ult(ubCst->first)) ||
+        (!getUnsignedCmp() && lbCst->first.slt(ubCst->first));
+    if (atLeastOneIteration && stepCst->first.isZero()) {
+      regions.push_back(RegionSuccessor(&getRegion()));
+      return;
+    }
+  }
+
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
   // body.
@@ -703,6 +724,8 @@ ValueRange ForOp::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange(getRegionIterArgs());
 }
 
+bool ForOp::mustProgress() { return getMustProgress(); }
+
 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
 
 /// Promotes the loop body of a forallOp to its containing block if it can be
@@ -1004,6 +1027,8 @@ void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
         auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
         return forOp.getLowerBound();
       });
+  ub::populateEraseInfiniteRegionBranchLoopPattern(results,
+                                                   ForOp::getOperationName());
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
@@ -3210,6 +3235,16 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
     afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
 }
 
+void WhileOp::build(::mlir::OpBuilder &odsBuilder,
+                    ::mlir::OperationState &odsState, TypeRange resultTypes,
+                    ValueRange inits, bool mustProgress) {
+  odsState.addOperands(inits);
+  for (unsigned i = 0; i < 2; ++i)
+    (void)odsState.addRegion();
+  odsState.addTypes(resultTypes);
+  odsState.addAttribute("mustProgress", odsBuilder.getBoolAttr(mustProgress));
+}
+
 ConditionOp WhileOp::getConditionOp() {
   return cast<ConditionOp>(getBeforeBody()->getTerminator());
 }
@@ -3273,6 +3308,8 @@ ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
   llvm_unreachable("invalid region successor");
 }
 
+bool WhileOp::mustProgress() { return getMustProgress(); }
+
 SmallVector<Region *> WhileOp::getLoopRegions() {
   return {&getBefore(), &getAfter()};
 }
@@ -3332,7 +3369,10 @@ void scf::WhileOp::print(OpAsmPrinter &p) {
   p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
   p << " do ";
   p.printRegion(getAfter());
-  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
+  SmallVector<StringRef> elidedAttrs;
+  if (getMustProgress()) // "true" is the default, elide attribute.
+    elidedAttrs.push_back(getMustProgressAttrName().strref());
+  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
 }
 
 /// Verifies that two ranges of types match, i.e. have the same number of
@@ -3708,6 +3748,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
       results, WhileOp::getOperationName());
   populateRegionBranchOpInterfaceInliningPattern(results,
                                                  WhileOp::getOperationName());
+  ub::populateEraseInfiniteRegionBranchLoopPattern(results,
+                                                   WhileOp::getOperationName());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index ddcbda86cf1f3..152fb226993e9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -49,8 +49,9 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
     SmallVector<Value> initArgs;
     initArgs.push_back(forOp.getLowerBound());
     llvm::append_range(initArgs, forOp.getInitArgs());
-    auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs,
-                                   forOp->getAttrs());
+    auto whileOp =
+        WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs);
+    whileOp->setAttrs(forOp->getAttrDictionary());
 
     // 'before' region contains the loop condition and forwarding of iteration
     // arguments to the 'after' region.
diff --git a/mlir/lib/Dialect/UB/IR/CMakeLists.txt b/mlir/lib/Dialect/UB/IR/CMakeLists.txt
index 84125ea0b5718..3baac5045b8db 100644
--- a/mlir/lib/Dialect/UB/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/UB/IR/CMakeLists.txt
@@ -5,9 +5,13 @@ add_mlir_dialect_library(MLIRUBDialect
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/UB
 
   DEPENDS
+  MLIRControlFlowInterfaces
   MLIRUBOpsIncGen
   MLIRUBOpsInterfacesIncGen
 
   LINK_LIBS PUBLIC
+  MLIRControlFlowInterfaces
   MLIRIR
+  MLIRExecutionProgressOpInterface
+  MLIRSideEffectInterfaces
   )
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..2310fc5af8cb8 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 #include "mlir/IR/Builders.h"
@@ -66,3 +68,40 @@ OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/UB/IR/UBOps.cpp.inc"
+
+namespace {
+/// Canonicalization pattern for RegionBranchOpInterface ops that loop
+/// infinitely. Such ops are replaced with poison values if they "must
+/// progress".
+struct EraseInfiniteRegionBranchLoop : public RewritePattern {
+  EraseInfiniteRegionBranchLoop(MLIRContext *context, StringRef name,
+                                PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    if (mightNotProgress(op))
+      return rewriter.notifyMatchFailure(
+          op, "only loops that must progress are removed");
+    if (!wouldOpBeTriviallyDead(op))
+      return rewriter.notifyMatchFailure(op,
+                                         "only trivially dead ops are removed");
+    if (!isGuaranteedToLoopInfinitely(regionBranchOp))
+      return rewriter.notifyMatchFailure(
+          op, "only loops that loop infinitely are removed");
+    SmallVector<Value> replacements =
+        llvm::map_to_vector(op->getResultTypes(), [&](Type type) {
+          return PoisonOp::create(rewriter, op->getLoc(), type).getResult();
+        });
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::ub::populateEraseInfiniteRegionBranchLoopPattern(
+    RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) {
+  patterns.add<EraseInfiniteRegionBranchLoop>(patterns.getContext(), opName,
+                                              benefit);
+}
diff --git a/m...
[truncated]

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Member

@linuxlonelyeagle linuxlonelyeagle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

I am still on the side of mustProgress should default to false (even for registered ops), but the current state is an okay compromise to iterate on.

Comment on lines +26 to +30
Operations that "must progress" are required to return normally (control
flow reaches the next operation) or interact with the environment in an
observable way (e.g., volatile memory access, I/O, synchronization or
program termination). If a "must progress" op executes indefinitely
without any observable interaction, it may be erased.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this part?

or interact with the environment in an observable way (e.g., volatile memory access, I/O, synchronization or program termination).

Why instead such operation that interacts with the environment wouldn't just return "false" here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now this means that mustProgress(op) does not say much about an operation with side effets.

Another problem is that we still are missing "volatile" effects in MLIR (unless I missed the addition).

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a LangRef update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants