diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index b9b39f3b9dfbc..bd1716219ee5f 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -1743,6 +1743,9 @@ class [[nodiscard]] APInt { /// \returns the multiplicative inverse for a given modulo. APInt multiplicativeInverse(const APInt &modulo) const; + /// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth. + APInt multiplicativeInverse() const; + /// @} /// \name Building-block Operations for APInt and APFloat /// @{ diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 515b9d0744f6e..e030b9fc7dac4 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -944,10 +944,7 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, // Calculate the multiplicative inverse of K! / 2^T; // this multiplication factor will perform the exact division by // K! / 2^T. - APInt Mod = APInt::getSignedMinValue(W+1); - APInt MultiplyFactor = OddFactorial.zext(W+1); - MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod); - MultiplyFactor = MultiplyFactor.trunc(W); + APInt MultiplyFactor = OddFactorial.multiplicativeInverse(); // Calculate the product, at width T+W IntegerType *CalculationTy = IntegerType::get(SE.getContext(), @@ -10086,10 +10083,8 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent // (N / D) in general. The inverse itself always fits into BW bits, though, // so we immediately truncate it. - APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D - APInt Mod(BW + 1, 0); - Mod.setBit(BW - Mult2); // Mod = N / D - APInt I = AD.multiplicativeInverse(Mod).trunc(BW); + APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D + APInt I = AD.multiplicativeInverse().zext(BW); // 4. Compute the minimum unsigned root of the equation: // I * (B / D) mod (N / D) diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp index 062132c8304b0..719209e0edd5f 100644 --- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp @@ -5201,10 +5201,7 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) { // Calculate the multiplicative inverse modulo BW. // 2^W requires W + 1 bits, so we have to extend and then truncate. - unsigned W = Divisor.getBitWidth(); - APInt Factor = Divisor.zext(W + 1) - .multiplicativeInverse(APInt::getSignedMinValue(W + 1)) - .trunc(W); + APInt Factor = Divisor.multiplicativeInverse(); Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0)); Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0)); return true; diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 5e053f97675d7..409d66adfd67d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -6071,11 +6071,7 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N, Divisor.ashrInPlace(Shift); UseSRA = true; } - // Calculate the multiplicative inverse, using Newton's method. - APInt t; - APInt Factor = Divisor; - while ((t = Divisor * Factor) != 1) - Factor *= APInt(Divisor.getBitWidth(), 2) - t; + APInt Factor = Divisor.multiplicativeInverse(); Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT)); Factors.push_back(DAG.getConstant(Factor, dl, SVT)); return true; @@ -6664,10 +6660,7 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode, // P = inv(D0, 2^W) // 2^W requires W + 1 bits, so we have to extend and then truncate. unsigned W = D.getBitWidth(); - APInt P = D0.zext(W + 1) - .multiplicativeInverse(APInt::getSignedMinValue(W + 1)) - .trunc(W); - assert(!P.isZero() && "No multiplicative inverse!"); // unreachable + APInt P = D0.multiplicativeInverse(); assert((D0 * P).isOne() && "Multiplicative inverse basic check failed."); // Q = floor((2^W - 1) u/ D) @@ -6922,10 +6915,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode, // P = inv(D0, 2^W) // 2^W requires W + 1 bits, so we have to extend and then truncate. unsigned W = D.getBitWidth(); - APInt P = D0.zext(W + 1) - .multiplicativeInverse(APInt::getSignedMinValue(W + 1)) - .trunc(W); - assert(!P.isZero() && "No multiplicative inverse!"); // unreachable + APInt P = D0.multiplicativeInverse(); assert((D0 * P).isOne() && "Multiplicative inverse basic check failed."); // A = floor((2^(W - 1) - 1) / D0) & -2^K @@ -7651,7 +7641,7 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT, // // For division, we can compute the remainder using the algorithm described // above, subtract it from the dividend to get an exact multiple of Constant. -// Then multiply that extact multiply by the multiplicative inverse modulo +// Then multiply that exact multiply by the multiplicative inverse modulo // (1 << (BitWidth / 2)) to get the quotient. // If Constant is even, we can shift right the dividend and the divisor by the @@ -7786,10 +7776,7 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N, // Multiply by the multiplicative inverse of the divisor modulo // (1 << BitWidth). - APInt Mod = APInt::getSignedMinValue(BitWidth + 1); - APInt MulFactor = Divisor.zext(BitWidth + 1); - MulFactor = MulFactor.multiplicativeInverse(Mod); - MulFactor = MulFactor.trunc(BitWidth); + APInt MulFactor = Divisor.multiplicativeInverse(); SDValue Quotient = DAG.getNode(ISD::MUL, dl, VT, Dividend, DAG.getConstant(MulFactor, dl, VT)); diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index c20609748dc97..f8f699f8f6ccd 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -1289,6 +1289,19 @@ APInt APInt::multiplicativeInverse(const APInt& modulo) const { return std::move(t[i]); } +/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth. +APInt APInt::multiplicativeInverse() const { + assert((*this)[0] && + "multiplicative inverse is only defined for odd numbers!"); + + // Use Newton's method. + APInt Factor = *this; + APInt T; + while (!(T = *this * Factor).isOne()) + Factor *= 2 - T; + return Factor; +} + /// Implementation of Knuth's Algorithm D (Division of nonnegative integers) /// from "Art of Computer Programming, Volume 2", section 4.3.1, p. 272. The /// variables here have the same names as in the algorithm. Comments explain diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index d5ef63e38e279..23f9ee2d39c44 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -3257,9 +3257,10 @@ TEST(APIntTest, MultiplicativeInverseExaustive) { .multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1)) .trunc(BitWidth); APInt One = V * MulInv; - if (!V.isZero() && V.countr_zero() == 0) { + if (V[0]) { // Multiplicative inverse exists for all odd numbers. EXPECT_TRUE(One.isOne()); + EXPECT_TRUE((V * V.multiplicativeInverse()).isOne()); } else { // Multiplicative inverse does not exist for even numbers (and 0). EXPECT_TRUE(MulInv.isZero());