diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 18f139c1bd54a..e7bce98c607df 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -483,7 +483,7 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) { std::optional mayBeConstantTripCount = forOp.getStaticTripCount(); if (!mayBeConstantTripCount.has_value()) return failure(); - APInt &tripCount = *mayBeConstantTripCount; + const APInt &tripCount = *mayBeConstantTripCount; if (tripCount.isZero()) return success(); if (tripCount.getSExtValue() == 1) diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 5048b19b2891f..8d3944f883963 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/APSInt.h" @@ -280,27 +281,28 @@ std::optional constantTripCount( computeUbMinusLb) { // This is the bitwidth used to return 0 when loop does not execute. // We infer it from the type of the bound if it isn't an index type. - bool isIndex = true; - auto getBitwidth = [&](OpFoldResult ofr) -> int { - if (auto attr = dyn_cast(ofr)) { - if (auto intAttr = dyn_cast(attr)) { - if (auto intType = dyn_cast(intAttr.getType())) { - isIndex = intType.isIndex(); - return intType.getWidth(); - } - } + auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple { + if (auto intAttr = + dyn_cast_or_null(dyn_cast(ofr))) { + if (auto intType = dyn_cast(intAttr.getType())) + return std::make_tuple(intType.getWidth(), intType.isIndex()); } else { auto val = cast(ofr); - if (auto intType = dyn_cast(val.getType())) { - isIndex = intType.isIndex(); - return intType.getWidth(); - } + if (auto intType = dyn_cast(val.getType())) + return std::make_tuple(intType.getWidth(), intType.isIndex()); } - return IndexType::kInternalStorageBitWidth; + return std::make_tuple(IndexType::kInternalStorageBitWidth, true); }; - int bitwidth = getBitwidth(lb); - assert(bitwidth == getBitwidth(ub) && - "lb and ub must have the same bitwidth"); + auto [bitwidth, isIndex] = getBitwidth(lb); + // This would better be an assert, but unfortunately it breaks scf.for_all + // which is missing attributes and SSA value optionally for its bounds, and + // uses Index type for the dynamic bounds but i64 for the static bounds. This + // is broken... + if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) { + LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs " + << lb; + return std::nullopt; + } if (lb == ub) return APInt(bitwidth, 0);