diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h index 44b8c395c89e2..e464d29afdaca 100644 --- a/llvm/include/llvm/IR/ConstantRange.h +++ b/llvm/include/llvm/IR/ConstantRange.h @@ -383,6 +383,11 @@ class LLVM_NODISCARD ConstantRange { /// treating both this and \p Other as unsigned ranges. ConstantRange multiply(const ConstantRange &Other) const; + /// Return range of possible values for a signed multiplication of this and + /// \p Other. However, if overflow is possible always return a full range + /// rather than trying to determine a more precise result. + ConstantRange smul_fast(const ConstantRange &Other) const; + /// Return a new range representing the possible values resulting /// from a signed maximum of a value in this range and a value in \p Other. ConstantRange smax(const ConstantRange &Other) const; diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp index 3129da27053f2..865db9f798326 100644 --- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp +++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp @@ -1302,7 +1302,7 @@ AliasResult BasicAAResult::aliasGEP( computeConstantRange(Var.Val.V, true, &AC, Var.CxtI)); if (!R.isFullSet() && !R.isEmptySet()) VarIndexRange = R.sextOrTrunc(Var.Scale.getBitWidth()) - .multiply(ConstantRange(Var.Scale)); + .smul_fast(ConstantRange(Var.Scale)); } else if (DecompGEP1.VarIndices.size() == 2) { // VarIndex = Scale*V0 + (-Scale)*V1. // If V0 != V1 then abs(VarIndex) >= abs(Scale). diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp index d8b4262a81142..6877a5d278ac5 100644 --- a/llvm/lib/IR/ConstantRange.cpp +++ b/llvm/lib/IR/ConstantRange.cpp @@ -1054,6 +1054,25 @@ ConstantRange::multiply(const ConstantRange &Other) const { return UR.isSizeStrictlySmallerThan(SR) ? UR : SR; } +ConstantRange ConstantRange::smul_fast(const ConstantRange &Other) const { + if (isEmptySet() || Other.isEmptySet()) + return getEmpty(); + + APInt Min = getSignedMin(); + APInt Max = getSignedMax(); + APInt OtherMin = Other.getSignedMin(); + APInt OtherMax = Other.getSignedMax(); + + bool O1, O2, O3, O4; + auto Muls = {Min.smul_ov(OtherMin, O1), Min.smul_ov(OtherMax, O2), + Max.smul_ov(OtherMin, O3), Max.smul_ov(OtherMax, O4)}; + if (O1 || O2 || O3 || O4) + return getFull(); + + auto Compare = [](const APInt &A, const APInt &B) { return A.slt(B); }; + return getNonEmpty(std::min(Muls, Compare), std::max(Muls, Compare) + 1); +} + ConstantRange ConstantRange::smax(const ConstantRange &Other) const { // X smax Y is: range(smax(X_smin, Y_smin), diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp index bc78869f9c54d..21533652b11c2 100644 --- a/llvm/unittests/IR/ConstantRangeTest.cpp +++ b/llvm/unittests/IR/ConstantRangeTest.cpp @@ -1081,6 +1081,20 @@ TEST_F(ConstantRangeTest, Multiply) { ConstantRange(APInt(8, -2), APInt(8, 1))); } +TEST_F(ConstantRangeTest, smul_fast) { + TestBinaryOpExhaustive( + [](const ConstantRange &CR1, const ConstantRange &CR2) { + return CR1.smul_fast(CR2); + }, + [](const APInt &N1, const APInt &N2) { + return N1 * N2; + }, + PreferSmallest, + [](const ConstantRange &, const ConstantRange &) { + return false; // Check correctness only. + }); +} + TEST_F(ConstantRangeTest, UMax) { EXPECT_EQ(Full.umax(Full), Full); EXPECT_EQ(Full.umax(Empty), Empty);