diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 881125bd0da2f..c85f3b02c4a44 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -251,8 +251,7 @@ FailureOr parallelLoopUnrollByFactors( /// Get constant trip counts for each of the induction variables of the given /// loop operation. If any of the loop's trip counts is not constant, return an /// empty vector. -/// TODO(#178506): Should return SmallVector for correct signedness. -llvm::SmallVector +llvm::SmallVector getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp); } // namespace mlir diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 3892904646da7..f8a4f057c9f0d 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -391,18 +391,14 @@ FailureOr mlir::loopUnrollByFactor( int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value(); int64_t stepCst = getConstantIntValue(forOp.getStep()).value(); if (unrollFactor == 1) { - if (*constTripCount == 1 && + if (constTripCount->isOne() && failed(forOp.promoteIfSingleIteration(rewriter))) return failure(); return UnrolledLoopInfo{forOp, std::nullopt}; } - // TODO(#178506): This may overflow for large trip counts. Should use - // uint64_t. - int64_t tripCountEvenMultiple = - constTripCount->getZExtValue() - - (constTripCount->getZExtValue() % unrollFactor); - // TODO(#178506): This may overflow when computing upperBoundUnrolledCst. + uint64_t tripCount = constTripCount->getZExtValue(); + uint64_t tripCountEvenMultiple = tripCount - tripCount % unrollFactor; int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst; int64_t stepUnrolledCst = stepCst * unrollFactor; @@ -504,7 +500,7 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) { const APInt &tripCount = *mayBeConstantTripCount; if (tripCount.isZero()) return success(); - if (tripCount.getZExtValue() == 1) + if (tripCount.isOne()) return forOp.promoteIfSingleIteration(rewriter); return loopUnrollByFactor(forOp, tripCount.getZExtValue()); } @@ -552,12 +548,13 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp, LDBG() << "failed to unroll and jam: trip count could not be determined"; return failure(); } - if (unrollJamFactor > tripCount->getZExtValue()) { + uint64_t tripCountValue = tripCount->getZExtValue(); + if (unrollJamFactor > tripCountValue) { LDBG() << "unroll and jam factor is greater than trip count, set factor to " "trip " "count"; - unrollJamFactor = tripCount->getZExtValue(); - } else if (tripCount->getZExtValue() % unrollJamFactor != 0) { + unrollJamFactor = tripCountValue; + } else if (tripCountValue % unrollJamFactor != 0) { LDBG() << "failed to unroll and jam: unsupported trip count that is not a " "multiple of unroll jam factor"; return failure(); @@ -1563,23 +1560,21 @@ bool mlir::isPerfectlyNestedForLoops( return true; } -llvm::SmallVector +llvm::SmallVector mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) { std::optional> loBnds = loopOp.getLoopLowerBounds(); std::optional> upBnds = loopOp.getLoopUpperBounds(); std::optional> steps = loopOp.getLoopSteps(); if (!loBnds || !upBnds || !steps) return {}; - // TODO(#178506): The result should be SmallVector and use uint64_t - // for trip counts. - llvm::SmallVector tripCounts; + llvm::SmallVector tripCounts; for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) { // TODO(#178506): Signedness is not handled correctly here. std::optional numIter = constantTripCount( lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb); if (!numIter) return {}; - tripCounts.push_back(numIter->getZExtValue()); + tripCounts.push_back(*numIter); } return tripCounts; } @@ -1610,7 +1605,7 @@ FailureOr mlir::parallelLoopUnrollByFactors( // Make sure that the unroll factors divide the iteration space evenly // TODO: Support unrolling loops with dynamic iteration spaces. - const llvm::SmallVector tripCounts = getConstLoopTripCounts(op); + const llvm::SmallVector tripCounts = getConstLoopTripCounts(op); if (tripCounts.empty()) return rewriter.notifyMatchFailure( op, "Failed to compute constant trip counts for the loop. Note that " @@ -1618,7 +1613,7 @@ FailureOr mlir::parallelLoopUnrollByFactors( for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) { const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx]; - if (tripCounts[dimIdx] % unrollFactor) + if (tripCounts[dimIdx].urem(unrollFactor) != 0) return rewriter.notifyMatchFailure( op, "Unroll factors don't divide the iteration space evenly"); }