diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 1fc3c7b2236a1..cbe6c2e91f2d9 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -2193,6 +2193,16 @@ inline const APInt absdiff(const APInt &A, const APInt &B) { return A.uge(B) ? (A - B) : (B - A); } +/// Compute the higher order bits of unsigned multiplication of two APInts. +/// Mathematically, this computes the value: `(C1 * C2) >> C2.getBitWidth()` +/// where `(C1 * C2)` has double the bit width of the original values. +APInt mulhu(const APInt &C1, const APInt &C2); + +/// Compute the higher order bits of signed multiplication of two APInts. +/// Mathematically, this is `(C1 * C2) >> C2.getBitWidth()` while preserving +/// the signed bit. Example: `mulhs(-2097152, 524288) == -256` +APInt mulhs(const APInt &C1, const APInt &C2); + /// Compute GCD of two unsigned APInt values. /// /// This function returns the greatest common divisor of the two APInt values diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 50f53bbb04b62..e1fcb6f84ede2 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6015,18 +6015,10 @@ static std::optional FoldValue(unsigned Opcode, const APInt &C1, if (!C2.getBoolValue()) break; return C1.srem(C2); - case ISD::MULHS: { - unsigned FullWidth = C1.getBitWidth() * 2; - APInt C1Ext = C1.sext(FullWidth); - APInt C2Ext = C2.sext(FullWidth); - return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); - } - case ISD::MULHU: { - unsigned FullWidth = C1.getBitWidth() * 2; - APInt C1Ext = C1.zext(FullWidth); - APInt C2Ext = C2.zext(FullWidth); - return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); - } + case ISD::MULHU: + return APIntOps::mulhu(C1, C2); + case ISD::MULHS: + return APIntOps::mulhs(C1, C2); case ISD::AVGFLOORS: { unsigned FullWidth = C1.getBitWidth() + 1; APInt C1Ext = C1.sext(FullWidth); diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index e686b97652330..9ce10ada67e9e 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -3067,6 +3067,22 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst, } } +APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) { + // Return higher order bits for unsigned (C1 * C2) + unsigned FullWidth = C1.getBitWidth() * 2; + APInt C1Ext = C1.zext(FullWidth); + APInt C2Ext = C2.zext(FullWidth); + return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); +} + +APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) { + // Return higher order bits for signed (C1 * C2) + unsigned FullWidth = C1.getBitWidth() * 2; + APInt C1Ext = C1.sext(FullWidth); + APInt C2Ext = C2.sext(FullWidth); + return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); +} + /// LoadIntFromMemory - Loads the integer stored in the LoadBytes bytes starting /// from Src into IntVal, which is assumed to be wide enough and to hold zero. void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src, diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index 24324822356bf..9995aeaff9287 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -2805,6 +2805,26 @@ TEST(APIntTest, multiply) { EXPECT_EQ(64U, i96.countr_zero()); } +TEST(APIntTest, Hmultiply) { + APInt i1048576(32, 1048576); + + EXPECT_EQ(APInt(32, 256), APIntOps::mulhu(i1048576, i1048576)); + + APInt i16777216(32, 16777216); + APInt i32768(32, 32768); + + EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i16777216, i32768)); + EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i32768, i16777216)); + + APInt i2097152(32, -2097152); + APInt i524288(32, 524288); + + EXPECT_EQ(APInt(32, 1024), APIntOps::mulhs(i2097152, i2097152)); + + EXPECT_EQ(APInt(32, -256), APIntOps::mulhs(i2097152, i524288)); + EXPECT_EQ(APInt(32, -256), APIntOps::mulhs(i524288, i2097152)); +} + TEST(APIntTest, RoundingUDiv) { for (uint64_t Ai = 1; Ai <= 255; Ai++) { APInt A(8, Ai); diff --git a/llvm/unittests/Support/DivisionByConstantTest.cpp b/llvm/unittests/Support/DivisionByConstantTest.cpp index 2b17f98bb75b2..8e0c78fe85654 100644 --- a/llvm/unittests/Support/DivisionByConstantTest.cpp +++ b/llvm/unittests/Support/DivisionByConstantTest.cpp @@ -21,12 +21,6 @@ template static void EnumerateAPInts(unsigned Bits, Fn TestFn) { } while (++N != 0); } -APInt MULHS(APInt X, APInt Y) { - unsigned Bits = X.getBitWidth(); - unsigned WideBits = 2 * Bits; - return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits); -} - APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor, SignedDivisionByConstantInfo Magics) { unsigned Bits = Numerator.getBitWidth(); @@ -48,7 +42,7 @@ APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor, } // Multiply the numerator by the magic value. - APInt Q = MULHS(Numerator, Magics.Magic); + APInt Q = APIntOps::mulhs(Numerator, Magics.Magic); // (Optionally) Add/subtract the numerator using Factor. Factor = Numerator * Factor; @@ -89,12 +83,6 @@ TEST(SignedDivisionByConstantTest, Test) { } } -APInt MULHU(APInt X, APInt Y) { - unsigned Bits = X.getBitWidth(); - unsigned WideBits = 2 * Bits; - return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits); -} - APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor, bool LZOptimization, bool AllowEvenDivisorOptimization, bool ForceNPQ, @@ -129,16 +117,16 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor, APInt Q = Numerator.lshr(PreShift); // Multiply the numerator by the magic value. - Q = MULHU(Q, Magics.Magic); + Q = APIntOps::mulhu(Q, Magics.Magic); if (UseNPQ || ForceNPQ) { APInt NPQ = Numerator - Q; // For vectors we might have a mix of non-NPQ/NPQ paths, so use - // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero. + // mulhu to act as a SRL-by-1 for NPQ, else multiply by zero. APInt NPQ_Scalar = NPQ.lshr(1); (void)NPQ_Scalar; - NPQ = MULHU(NPQ, NPQFactor); + NPQ = APIntOps::mulhu(NPQ, NPQFactor); assert(!UseNPQ || NPQ == NPQ_Scalar); Q = NPQ + Q; diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp index 658f3796721c4..65bb228cbc73c 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -537,19 +537,13 @@ TEST(KnownBitsTest, BinaryExhaustive) { [](const KnownBits &Known1, const KnownBits &Known2) { return KnownBits::mulhs(Known1, Known2); }, - [](const APInt &N1, const APInt &N2) { - unsigned Bits = N1.getBitWidth(); - return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits); - }, + [](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); }, checkCorrectnessOnlyBinary); testBinaryOpExhaustive( [](const KnownBits &Known1, const KnownBits &Known2) { return KnownBits::mulhu(Known1, Known2); }, - [](const APInt &N1, const APInt &N2) { - unsigned Bits = N1.getBitWidth(); - return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits); - }, + [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); }, checkCorrectnessOnlyBinary); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 0f71c19c23b65..c705051f0f440 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -434,9 +434,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor, // Invoke the constant fold helper again to calculate the 'high' result. Attribute highAttr = constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { - unsigned bitWidth = a.getBitWidth(); - APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2); - return fullProduct.extractBits(bitWidth, bitWidth); + return llvm::APIntOps::mulhs(a, b); }); assert(highAttr && "Unexpected constant-folding failure"); @@ -491,9 +489,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor, // Invoke the constant fold helper again to calculate the 'high' result. Attribute highAttr = constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { - unsigned bitWidth = a.getBitWidth(); - APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); - return fullProduct.extractBits(bitWidth, bitWidth); + return llvm::APIntOps::mulhu(a, b); }); assert(highAttr && "Unexpected constant-folding failure"); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 4c62289a1e945..eb1e97e7ecc90 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -250,14 +250,11 @@ struct MulExtendedFold final : OpRewritePattern { auto highBits = constFoldBinaryOp( {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) { - unsigned bitWidth = a.getBitWidth(); - APInt c; if (IsSigned) { - c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2); + return llvm::APIntOps::mulhs(a, b); } else { - c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); + return llvm::APIntOps::mulhu(a, b); } - return c.extractBits(bitWidth, bitWidth); // Extract high result }); if (!highBits)