[InstCombine] Make MatchBSwap also match bit reversals
MatchBSwap has most of the functionality to match bit reversals already. If we switch it from looking at bytes to individual bits and remove a few early exits, we can extend the main recursive function to match any sequence of ORs, ANDs and shifts that assemble a value from different parts of another, base value. Once we have this bit->bit mapping, we can very simply detect if it is appropriate for a bswap or bitreverse.

llvm-svn: 255334
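
As a rough illustration of the final detection step described above, the sketch below classifies a finished bit-to-bit mapping. It assumes a plain `std::vector<int>` in which `prov[to] == from` means result bit `to` is a copy of input bit `from`; the names `Permutation` and `classify` are hypothetical and do not appear in the patch. The two per-bit conditions follow the `bitTransformIsCorrectForBSwap` and `bitTransformIsCorrectForBitReverse` helpers in the diff.

```cpp
#include <cassert>
#include <vector>

// Hypothetical result type for this sketch (not part of LLVM).
enum class Permutation { None, ByteSwap, BitReverse };

// prov[to] == from: result bit `to` is taken from input bit `from`.
// A negative entry means no source was found for that bit.
Permutation classify(const std::vector<int> &prov) {
  const unsigned width = static_cast<unsigned>(prov.size());
  assert(width > 0 && "need at least one bit");

  // A byte swap only makes sense for an even number of bytes.
  bool okSwap = width % 16 == 0;
  bool okReverse = true;

  for (unsigned to = 0; to < width; ++to) {
    if (prov[to] < 0)
      return Permutation::None;
    const unsigned from = static_cast<unsigned>(prov[to]);

    // Bit reversal: bit `from` lands at the mirrored position.
    okReverse &= (from == width - to - 1);

    // Byte swap: same position within the byte, mirrored byte index.
    okSwap &= (from % 8 == to % 8) && (from / 8 == width / 8 - to / 8 - 1);
  }

  if (okSwap)
    return Permutation::ByteSwap;
  if (okReverse)
    return Permutation::BitReverse;
  return Permutation::None;
}
```

Checking the byte-swap case before the bit-reverse case mirrors the order used in the patch; an identity mapping satisfies neither condition and is rejected.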
James Molloy committed Dec 11, 2015
1 parent 4865148 commit 37b82e7
Showing 3 changed files with 250 additions and 103 deletions.
237 changes: 135 additions & 102 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1566,157 +1566,190 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
return Changed ? &I : nullptr;
}


/// Analyze the specified subexpression and see if it is capable of providing
/// pieces of a bswap. The subexpression provides pieces of a bswap if it is
/// proven that each of the non-zero bytes in the output of the expression came
/// from the corresponding "byte swapped" byte in some other value.
/// For example, if the current subexpression is "(shl i32 %X, 24)" then
/// we know that the expression deposits the low byte of %X into the high byte
/// of the bswap result and that all other bytes are zero. This expression is
/// accepted, the high byte of ByteValues is set to X to indicate a correct
/// match.
/// pieces of a bswap or bitreverse. The subexpression provides a potential
/// piece of a bswap or bitreverse if it can be proven that each non-zero bit in
/// the output of the expression came from a corresponding bit in some other
/// value. This function is recursive, and the end result is a mapping of
/// (value, bitnumber) to bitnumber. It is the caller's responsibility to
/// validate that all `value`s are identical and that the bitnumber to bitnumber
/// mapping is correct for a bswap or bitreverse.
///
/// For example, if the current subexpression is "(shl i32 %X, 24)" then we know
/// that the expression deposits the low byte of %X into the high byte of the
/// result and that all other bits are zero. This expression is accepted,
/// BitValues[24-31] are set to %X and BitProvenance[24-31] are set to [0-7].
///
/// This function returns true if the match was unsuccessful and false if it
/// was successful.
/// On entry to the function the "OverallLeftShift" is a signed integer value
/// indicating the number of bytes that the subexpression is later shifted. For
/// indicating the number of bits that the subexpression is later shifted. For
/// example, if the expression is later right shifted by 16 bits, the
/// OverallLeftShift value would be -2 on entry. This is used to specify which
/// byte of ByteValues is actually being set.
/// OverallLeftShift value would be -16 on entry. This is used to specify which
/// bits of BitValues are actually being set.
///
/// Similarly, ByteMask is a bitmask where a bit is clear if its corresponding
/// byte is masked to zero by a user. For example, in (X & 255), X will be
/// processed with a bytemask of 1. Because bytemask is 32-bits, this limits
/// this function to working on up to 32-byte (256 bit) values. ByteMask is
/// always in the local (OverallLeftShift) coordinate space.
/// Similarly, BitMask is a bitmask where a bit is clear if its corresponding
/// bit is masked to zero by a user. For example, in (X & 255), X will be
/// processed with a BitMask of 255. BitMask is always in the local
/// (OverallLeftShift) coordinate space.
///
static bool CollectBSwapParts(Value *V, int OverallLeftShift, uint32_t ByteMask,
SmallVectorImpl<Value *> &ByteValues) {
static bool CollectBitParts(Value *V, int OverallLeftShift, APInt BitMask,
SmallVectorImpl<Value *> &BitValues,
SmallVectorImpl<int> &BitProvenance) {
if (Instruction *I = dyn_cast<Instruction>(V)) {
// If this is an or instruction, it may be an inner node of the bswap.
if (I->getOpcode() == Instruction::Or) {
return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask,
ByteValues) ||
CollectBSwapParts(I->getOperand(1), OverallLeftShift, ByteMask,
ByteValues);
}

// If this is a logical shift by a constant multiple of 8, recurse with
// OverallLeftShift and ByteMask adjusted.
if (I->getOpcode() == Instruction::Or)
return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
BitValues, BitProvenance) ||
CollectBitParts(I->getOperand(1), OverallLeftShift, BitMask,
BitValues, BitProvenance);

// If this is a logical shift by a constant, recurse with OverallLeftShift
// and BitMask adjusted.
if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) {
unsigned ShAmt =
cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U);
// Ensure the shift amount is defined and of a byte value.
if ((ShAmt & 7) || (ShAmt > 8*ByteValues.size()))
cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U);
// Ensure the shift amount is defined.
if (ShAmt > BitValues.size())
return true;

unsigned ByteShift = ShAmt >> 3;
unsigned BitShift = ShAmt;
if (I->getOpcode() == Instruction::Shl) {
// X << 2 -> collect(X, +2)
OverallLeftShift += ByteShift;
ByteMask >>= ByteShift;
// X << C -> collect(X, +C)
OverallLeftShift += BitShift;
BitMask = BitMask.lshr(BitShift);
} else {
// X >>u 2 -> collect(X, -2)
OverallLeftShift -= ByteShift;
ByteMask <<= ByteShift;
ByteMask &= (~0U >> (32-ByteValues.size()));
// X >>u C -> collect(X, -C)
OverallLeftShift -= BitShift;
BitMask = BitMask.shl(BitShift);
}

if (OverallLeftShift >= (int)ByteValues.size()) return true;
if (OverallLeftShift <= -(int)ByteValues.size()) return true;
if (OverallLeftShift >= (int)BitValues.size())
return true;
if (OverallLeftShift <= -(int)BitValues.size())
return true;

return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask,
ByteValues);
return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
BitValues, BitProvenance);
}

// If this is a logical 'and' with a mask that clears bytes, clear the
// corresponding bytes in ByteMask.
// If this is a logical 'and' with a mask that clears bits, clear the
// corresponding bits in BitMask.
if (I->getOpcode() == Instruction::And &&
isa<ConstantInt>(I->getOperand(1))) {
// Scan every byte of the and mask, seeing if the byte is either 0 or 255.
unsigned NumBytes = ByteValues.size();
APInt Byte(I->getType()->getPrimitiveSizeInBits(), 255);
unsigned NumBits = BitValues.size();
APInt Bit(I->getType()->getPrimitiveSizeInBits(), 1);
const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue();

for (unsigned i = 0; i != NumBytes; ++i, Byte <<= 8) {
// If this byte is masked out by a later operation, we don't care what
for (unsigned i = 0; i != NumBits; ++i, Bit <<= 1) {
// If this bit is masked out by a later operation, we don't care what
// the and mask is.
if ((ByteMask & (1 << i)) == 0)
if (BitMask[i] == 0)
continue;

// If the AndMask is all zeros for this byte, clear the bit.
APInt MaskB = AndMask & Byte;
// If the AndMask is zero for this bit, clear the bit.
APInt MaskB = AndMask & Bit;
if (MaskB == 0) {
ByteMask &= ~(1U << i);
BitMask.clearBit(i);
continue;
}

// If the AndMask is not all ones for this byte, it's not a bytezap.
if (MaskB != Byte)
return true;

// Otherwise, this byte is kept.
// Otherwise, this bit is kept.
}

return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask,
ByteValues);
return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
BitValues, BitProvenance);
}
}

// Okay, we got to something that isn't a shift, 'or' or 'and'. This must be
// the input value to the bswap. Some observations: 1) if more than one byte
// is demanded from this input, then it could not be successfully assembled
// into a byteswap. At least one of the two bytes would not be aligned with
// their ultimate destination.
if (!isPowerOf2_32(ByteMask)) return true;
unsigned InputByteNo = countTrailingZeros(ByteMask);

// 2) The input and ultimate destinations must line up: if byte 3 of an i32
// is demanded, it needs to go into byte 0 of the result. This means that the
// byte needs to be shifted until it lands in the right byte bucket. The
// shift amount depends on the position: if the byte is coming from the high
// part of the value (e.g. byte 3) then it must be shifted right. If from the
// low part, it must be shifted left.
unsigned DestByteNo = InputByteNo + OverallLeftShift;
if (ByteValues.size()-1-DestByteNo != InputByteNo)
// the input value to the bswap/bitreverse. To be part of a bswap or
// bitreverse we must be demanding a contiguous range of bits from it.
unsigned InputBitLen = BitMask.countPopulation();
unsigned InputBitNo = BitMask.countTrailingZeros();
if (BitMask.getBitWidth() - BitMask.countLeadingZeros() - InputBitNo !=
InputBitLen)
// Not a contiguous set range of bits!
return true;

// If the destination byte value is already defined, the values are or'd
// together, which isn't a bswap (unless it's an or of the same bits).
if (ByteValues[DestByteNo] && ByteValues[DestByteNo] != V)
// We know we're moving a contiguous range of bits from the input to the
// output. Record which bits in the output came from which bits in the input.
unsigned DestBitNo = InputBitNo + OverallLeftShift;
for (unsigned I = 0; I < InputBitLen; ++I)
BitProvenance[DestBitNo + I] = InputBitNo + I;

// If the destination bit value is already defined, the values are or'd
// together, which isn't a bswap/bitreverse (unless it's an or of the same
// bits).
if (BitValues[DestBitNo] && BitValues[DestBitNo] != V)
return true;
ByteValues[DestByteNo] = V;
for (unsigned I = 0; I < InputBitLen; ++I)
BitValues[DestBitNo + I] = V;

return false;
}

/// Given an OR instruction, check to see if this is a bswap idiom.
/// If so, insert the new bswap intrinsic and return it.
Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) {
IntegerType *ITy = dyn_cast<IntegerType>(I.getType());
if (!ITy || ITy->getBitWidth() % 16 ||
// ByteMask only allows up to 32-byte values.
ITy->getBitWidth() > 32*8)
return nullptr; // Can only bswap pairs of bytes. Can't do vectors.
static bool bitTransformIsCorrectForBSwap(unsigned From, unsigned To,
unsigned BitWidth) {
if (From % 8 != To % 8)
return false;
// Convert from bit indices to byte indices and check for a byte reversal.
From >>= 3;
To >>= 3;
BitWidth >>= 3;
return From == BitWidth - To - 1;
}

/// ByteValues - For each byte of the result, we keep track of which value
/// defines each byte.
SmallVector<Value*, 8> ByteValues;
ByteValues.resize(ITy->getBitWidth()/8);
static bool bitTransformIsCorrectForBitReverse(unsigned From, unsigned To,
unsigned BitWidth) {
return From == BitWidth - To - 1;
}

/// Given an OR instruction, check to see if this is a bswap or bitreverse
/// idiom. If so, insert the new intrinsic and return it.
Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) {
IntegerType *ITy = dyn_cast<IntegerType>(I.getType());
if (!ITy)
return nullptr; // Can't do vectors.
unsigned BW = ITy->getBitWidth();

/// We keep track of which bit (BitProvenance) inside which value (BitValues)
/// defines each bit in the result.
SmallVector<Value *, 8> BitValues(BW, nullptr);
SmallVector<int, 8> BitProvenance(BW, -1);

// Try to find all the pieces corresponding to the bswap.
uint32_t ByteMask = ~0U >> (32-ByteValues.size());
if (CollectBSwapParts(&I, 0, ByteMask, ByteValues))
APInt BitMask = APInt::getAllOnesValue(BitValues.size());
if (CollectBitParts(&I, 0, BitMask, BitValues, BitProvenance))
return nullptr;

// Check to see if all of the bytes come from the same value.
Value *V = ByteValues[0];
if (!V) return nullptr; // Didn't find a byte? Must be zero.
// Check to see if all of the bits come from the same value.
Value *V = BitValues[0];
if (!V) return nullptr; // Didn't find a bit? Must be zero.

// Check to make sure that all of the bytes come from the same value.
for (unsigned i = 1, e = ByteValues.size(); i != e; ++i)
if (ByteValues[i] != V)
return nullptr;
if (!std::all_of(BitValues.begin(), BitValues.end(),
[&](const Value *X) { return X == V; }))
return nullptr;

// Now, is the bit permutation correct for a bswap or a bitreverse? We can
// only byteswap values with an even number of bytes.
bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true;
for (unsigned i = 0, e = BitValues.size(); i != e; ++i) {
OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW);
OKForBitReverse &=
bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW);
}

Intrinsic::ID Intrin;
if (OKForBSwap)
Intrin = Intrinsic::bswap;
else if (OKForBitReverse)
Intrin = Intrinsic::bitreverse;
else
return nullptr;

Module *M = I.getParent()->getParent()->getParent();
Function *F = Intrinsic::getDeclaration(M, Intrinsic::bswap, ITy);
Function *F = Intrinsic::getDeclaration(M, Intrin, ITy);
return CallInst::Create(F, V);
}
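
The recursive collection performed by CollectBitParts can be modeled outside LLVM to see how OverallLeftShift and BitMask evolve. The following is a minimal sketch under stated assumptions: `Expr` and `collectBits` are hypothetical stand-ins for LLVM's `Value` and the patch's function, values are limited to 32 bits so a `uint32_t` can replace `APInt`, and the bounds check at the leaf is an addition of this sketch rather than something the patch does.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical expression node for this sketch (not an LLVM class).
struct Expr {
  enum Kind { Input, Or, Shl, LShr, And } kind;
  const Expr *a; // first operand (unused for Input)
  const Expr *b; // second operand (Or only)
  uint32_t c;    // shift amount (Shl/LShr) or constant mask (And)
};

// Records, for each result bit, the input node and input bit it came from.
// Returns true on failure, following the convention used in the patch.
bool collectBits(const Expr *e, int shift, uint32_t mask, unsigned width,
                 std::vector<const Expr *> &values, std::vector<int> &prov) {
  assert(width >= 1 && width <= 32);
  assert(values.size() == width && prov.size() == width);
  const uint32_t all = width == 32 ? ~0u : (1u << width) - 1;

  switch (e->kind) {
  case Expr::Or: // both operands contribute bits to the same result
    return collectBits(e->a, shift, mask, width, values, prov) ||
           collectBits(e->b, shift, mask, width, values, prov);
  case Expr::Shl: // X << C -> collect(X, +C)
    if (e->c >= width)
      return true;
    return collectBits(e->a, shift + static_cast<int>(e->c), mask >> e->c,
                       width, values, prov);
  case Expr::LShr: // X >>u C -> collect(X, -C)
    if (e->c >= width)
      return true;
    return collectBits(e->a, shift - static_cast<int>(e->c),
                       (mask << e->c) & all, width, values, prov);
  case Expr::And: // bits cleared by the constant are no longer demanded
    return collectBits(e->a, shift, mask & e->c, width, values, prov);
  case Expr::Input:
    break;
  }

  // Leaf: the demanded bits must form one contiguous, non-empty run.
  if (mask == 0)
    return true;
  unsigned lo = 0;
  while (((mask >> lo) & 1u) == 0)
    ++lo;
  uint32_t run = mask >> lo;
  if (run & (run + 1))
    return true; // there is a hole in the run
  unsigned len = 0;
  for (; run != 0; run >>= 1)
    ++len;

  // Added by this sketch: reject destinations outside the result.
  const int dest = static_cast<int>(lo) + shift;
  if (dest < 0 || dest + static_cast<int>(len) > static_cast<int>(width))
    return true;

  for (unsigned i = 0; i < len; ++i) {
    if (values[dest + i] && values[dest + i] != e)
      return true; // two different inputs or'd into the same bit
    values[dest + i] = e;
    prov[dest + i] = static_cast<int>(lo + i);
  }
  return false;
}
```

For a 16-bit input `x`, walking `(x << 8) | (x >> 8)` with an all-ones starting mask fills result bits 0 to 7 with input bits 8 to 15 and result bits 8 to 15 with input bits 0 to 7, which is the mapping a byte swap requires.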

@@ -2265,7 +2298,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
match(Op1, m_And(m_Value(), m_Value()));

if (OrOfOrs || OrOfShifts || OrOfAnds)
if (Instruction *BSwap = MatchBSwap(I))
if (Instruction *BSwap = MatchBSwapOrBitReverse(I))
return BSwap;

// (X^C)|Y -> (X|Y)^C iff Y&C == 0
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -556,7 +556,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner
Value *InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, bool isSigned,
bool Inside);
Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI);
Instruction *MatchBSwap(BinaryOperator &I);
Instruction *MatchBSwapOrBitReverse(BinaryOperator &I);
bool SimplifyStoreAtEndOfBlock(StoreInst &SI);
Instruction *SimplifyMemTransfer(MemIntrinsic *MI);
Instruction *SimplifyMemSet(MemSetInst *MI);