Skip to content

Conversation

joker-eph
Copy link
Collaborator

Add post-merge review comments on #158679

Add post-merge review comments on llvm#158679
@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2025

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes

Add post-merge review comments on #158679


Full diff: https://github.com/llvm/llvm-project/pull/159307.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+19-17)
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 18f139c1bd54a..e7bce98c607df 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -483,7 +483,7 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
   std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
   if (!mayBeConstantTripCount.has_value())
     return failure();
-  APInt &tripCount = *mayBeConstantTripCount;
+  const APInt &tripCount = *mayBeConstantTripCount;
   if (tripCount.isZero())
     return success();
   if (tripCount.getSExtValue() == 1)
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5048b19b2891f..8d3944f883963 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APSInt.h"
@@ -280,27 +281,28 @@ std::optional<APInt> constantTripCount(
         computeUbMinusLb) {
   // This is the bitwidth used to return 0 when loop does not execute.
   // We infer it from the type of the bound if it isn't an index type.
-  bool isIndex = true;
-  auto getBitwidth = [&](OpFoldResult ofr) -> int {
-    if (auto attr = dyn_cast<Attribute>(ofr)) {
-      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-        if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) {
-          isIndex = intType.isIndex();
-          return intType.getWidth();
-        }
-      }
+  auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> {
+    if (auto intAttr =
+            dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
+      if (auto intType = dyn_cast<IntegerType>(intAttr.getType()))
+        return std::make_tuple(intType.getWidth(), intType.isIndex());
     } else {
       auto val = cast<Value>(ofr);
-      if (auto intType = dyn_cast<IntegerType>(val.getType())) {
-        isIndex = intType.isIndex();
-        return intType.getWidth();
-      }
+      if (auto intType = dyn_cast<IntegerType>(val.getType()))
+        return std::make_tuple(intType.getWidth(), intType.isIndex());
     }
-    return IndexType::kInternalStorageBitWidth;
+    return std::make_tuple(IndexType::kInternalStorageBitWidth, true);
   };
-  int bitwidth = getBitwidth(lb);
-  assert(bitwidth == getBitwidth(ub) &&
-         "lb and ub must have the same bitwidth");
+  auto [bitwidth, isIndex] = getBitwidth(lb);
+  // This would better be an assert, but unfortunately it breaks scf.for_all
+  // which is missing attributes and SSA value optionally for its bounds, and
+  // uses Index type for the dynamic bounds but i64 for the static bounds. This
+  // is broken...
+  if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) {
+    LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs "
+           << lb;
+    return std::nullopt;
+  }
   if (lb == ub)
     return APInt(bitwidth, 0);
 

@joker-eph joker-eph merged commit 385c9f5 into llvm:main Sep 17, 2025
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants