diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 4aad2c950f7448..e00709a54465ff 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -2206,6 +2206,15 @@ Optional SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, Optional GetMostSignificantDifferentBit(const APInt &A, const APInt &B); +/// Splat/Merge neighboring bits to widen/narrow the bitmask represented +/// by \p A to \p NewBitWidth bits. +/// +/// e.g. ScaleBitMask(0b0101, 8) -> 0b00110011 +/// e.g. ScaleBitMask(0b00011011, 4) -> 0b0111 +/// A.getBitWidth() or NewBitWidth must be a whole multiple of the other. +/// +/// TODO: Do we need a mode where all bits must be set when merging down? +APInt ScaleBitMask(const APInt &A, unsigned NewBitWidth); } // namespace APIntOps // See friend declaration above. This additional declaration is required in diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 08ff77b5ef33ef..6808b212bb2b23 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -2990,11 +2990,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts, // bits from the overlapping larger input elements and extracting the // sub sections we actually care about. 
unsigned SubScale = SubBitWidth / BitWidth; - APInt SubDemandedElts(NumElts / SubScale, 0); - for (unsigned i = 0; i != NumElts; ++i) - if (DemandedElts[i]) - SubDemandedElts.setBit(i / SubScale); - + APInt SubDemandedElts = + APIntOps::ScaleBitMask(DemandedElts, NumElts / SubScale); Known2 = computeKnownBits(N0, SubDemandedElts, Depth + 1); Known.Zero.setAllBits(); Known.One.setAllBits(); @@ -3802,10 +3799,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts, assert(VT.isVector() && "Expected bitcast to vector"); unsigned Scale = SrcBits / VTBits; - APInt SrcDemandedElts(NumElts / Scale, 0); - for (unsigned i = 0; i != NumElts; ++i) - if (DemandedElts[i]) - SrcDemandedElts.setBit(i / Scale); + APInt SrcDemandedElts = + APIntOps::ScaleBitMask(DemandedElts, NumElts / Scale); // Fast case - sign splat can be simply split across the small elements. Tmp = ComputeNumSignBits(N0, SrcDemandedElts, Depth + 1); diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 6a3eba1cb6231c..1b84aff7ac2fcd 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -2469,17 +2469,13 @@ bool TargetLowering::SimplifyDemandedVectorElts( return SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef, KnownZero, TLO, Depth + 1); - APInt SrcZero, SrcUndef; - APInt SrcDemandedElts = APInt::getZero(NumSrcElts); + APInt SrcDemandedElts, SrcZero, SrcUndef; // Bitcast from 'large element' src vector to 'small element' vector, we // must demand a source element if any DemandedElt maps to it. 
if ((NumElts % NumSrcElts) == 0) { unsigned Scale = NumElts / NumSrcElts; - for (unsigned i = 0; i != NumElts; ++i) - if (DemandedElts[i]) - SrcDemandedElts.setBit(i / Scale); - + SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts); if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero, TLO, Depth + 1)) return true; @@ -2519,10 +2515,7 @@ bool TargetLowering::SimplifyDemandedVectorElts( // of this vector. if ((NumSrcElts % NumElts) == 0) { unsigned Scale = NumSrcElts / NumElts; - for (unsigned i = 0; i != NumElts; ++i) - if (DemandedElts[i]) - SrcDemandedElts.setBits(i * Scale, (i + 1) * Scale); - + SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts); if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero, TLO, Depth + 1)) return true; diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index 39824905434ccd..a630050c0157a2 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -2948,6 +2948,40 @@ llvm::APIntOps::GetMostSignificantDifferentBit(const APInt &A, const APInt &B) { return A.getBitWidth() - ((A ^ B).countLeadingZeros() + 1); } +APInt llvm::APIntOps::ScaleBitMask(const APInt &A, unsigned NewBitWidth) { + unsigned OldBitWidth = A.getBitWidth(); + assert((((OldBitWidth % NewBitWidth) == 0) || + ((NewBitWidth % OldBitWidth) == 0)) && + "One size should be a multiple of the other one. " + "Can't do fractional scaling."); + + // Matching bitwidths - nothing to scale, return the mask unchanged. + if (OldBitWidth == NewBitWidth) + return A; + + APInt NewA = APInt::getNullValue(NewBitWidth); + + // Null input - an all-zero mask scales to an all-zero mask. + if (A.isNullValue()) + return NewA; + + if (NewBitWidth > OldBitWidth) { + // Repeat bits - each set old bit becomes Scale consecutive set new bits. + unsigned Scale = NewBitWidth / OldBitWidth; + for (unsigned i = 0; i != OldBitWidth; ++i) + if (A[i]) + NewA.setBits(i * Scale, (i + 1) * Scale); + } else { + // Merge bits - set each new bit if any of its Scale old bits is set. 
+ unsigned Scale = OldBitWidth / NewBitWidth; + for (unsigned i = 0; i != NewBitWidth; ++i) + if (!A.extractBits(Scale, i * Scale).isNullValue()) + NewA.setBit(i); + } + + return NewA; +} + /// StoreIntToMemory - Fills the StoreBytes bytes of memory starting from Dst /// with the integer held in IntVal. void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst, diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index 0f8f4626692b4c..6da9d2ddeb993d 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -2982,4 +2982,24 @@ TEST(APIntTest, ZeroWidth) { EXPECT_EQ(0U, MZW1.getBitWidth()); } +TEST(APIntTest, ScaleBitMask) { + EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x00), 8), APInt(8, 0x00)); + EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x01), 8), APInt(8, 0x0F)); + EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x02), 8), APInt(8, 0xF0)); + EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x03), 8), APInt(8, 0xFF)); + + EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0x00), 4), APInt(4, 0x00)); + EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0xFF), 4), APInt(4, 0x0F)); + EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0xE4), 4), APInt(4, 0x0E)); + + EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0x00), 8), APInt(8, 0x00)); + + EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getNullValue(1024), 4096), + APInt::getNullValue(4096)); + EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getAllOnes(4096), 256), + APInt::getAllOnes(256)); + EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getOneBitSet(4096, 32), 256), + APInt::getOneBitSet(256, 2)); +} + } // end anonymous namespace