Skip to content

Commit

Permalink
[CodeGen][ExpandMemcmp] Add an option for allowing overlapping loads.
Browse files Browse the repository at this point in the history
Summary:
This allows expanding {7,11,13,14,15,21,22,23,25,26,27,28,29,30,31}-byte memcmp
in just two loads on X86. These were previously calling memcmp.

Reviewers: spatel, gchatelet

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D55263

llvm-svn: 349731
  • Loading branch information
legrosbuffle committed Dec 20, 2018
1 parent d3bd614 commit 1bb6e1b
Show file tree
Hide file tree
Showing 6 changed files with 639 additions and 265 deletions.
8 changes: 6 additions & 2 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Expand Up @@ -581,13 +581,17 @@ class TargetTransformInfo {
// Options describing which loads a memcmp expansion may use. A pointer to an
// instance of this struct is what enableMemCmpExpansion() returns.
struct MemCmpExpansionOptions {
// The list of available load sizes (in bytes), sorted in decreasing order.
SmallVector<unsigned, 8> LoadSizes;
// Set to true to allow overlapping loads. For example, 7-byte compares can
// be done with two 4-byte compares instead of 4+2+1-byte compares. This
// requires all loads in LoadSizes to be doable in an unaligned way.
// Defaults to false, so targets have to opt in explicitly.
bool AllowOverlappingLoads = false;
};
const MemCmpExpansionOptions *enableMemCmpExpansion(bool IsZeroCmp) const;

/// Enable matching of interleaved access groups.
bool enableInterleavedAccessVectorization() const;

/// Enable matching of interleaved access groups that contain predicated
/// Enable matching of interleaved access groups that contain predicated
/// accesses or gaps and therefore vectorized using masked
/// vector loads/stores.
bool enableMaskedInterleavedAccessVectorization() const;
Expand Down Expand Up @@ -772,7 +776,7 @@ class TargetTransformInfo {
/// \return The cost of a shuffle instruction of kind Kind and of type Tp.
/// The index and subtype parameters are used by the subvector insertion and
/// extraction shuffle kinds to show the insert/extract point and the type of
/// the subvector being inserted/extracted.
/// the subvector being inserted/extracted.
/// NOTE: For subvector extractions Tp represents the source type.
int getShuffleCost(ShuffleKind Kind, Type *Tp, int Index = 0,
Type *SubTp = nullptr) const;
Expand Down
233 changes: 137 additions & 96 deletions llvm/lib/CodeGen/ExpandMemCmp.cpp
Expand Up @@ -66,23 +66,18 @@ class MemCmpExpansion {
// Represents the decomposition in blocks of the expansion. For example,
// comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
// 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
// TODO(courbet): Involve the target more in this computation. On X86, 7
// bytes can be done more efficiently with two overlapping 4-byte loads than
// covering the interval with [{4, 0}, {2, 4}, {1, 6}].
struct LoadEntry {
LoadEntry(unsigned LoadSize, uint64_t Offset)
: LoadSize(LoadSize), Offset(Offset) {
assert(Offset % LoadSize == 0 && "invalid load entry");
}

uint64_t getGEPIndex() const { return Offset / LoadSize; }

// The size of the load for this block, in bytes.
const unsigned LoadSize;
// The offset of this load WRT the base pointer, in bytes.
const uint64_t Offset;
unsigned LoadSize;
// The offset of this load from the base pointer, in bytes.
uint64_t Offset;
};
SmallVector<LoadEntry, 8> LoadSequence;
using LoadEntryVector = SmallVector<LoadEntry, 8>;
LoadEntryVector LoadSequence;

void createLoadCmpBlocks();
void createResultBlock();
Expand All @@ -92,13 +87,23 @@ class MemCmpExpansion {
void emitLoadCompareBlock(unsigned BlockIndex);
void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
unsigned &LoadIndex);
void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
void emitMemCmpResultBlock();
Value *getMemCmpExpansionZeroCase();
Value *getMemCmpEqZeroOneBlock();
Value *getMemCmpOneBlock();
Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
uint64_t OffsetBytes);

static LoadEntryVector
computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte);
static LoadEntryVector
computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
unsigned MaxNumLoads,
unsigned &NumLoadsNonOneByte);

public:
public:
MemCmpExpansion(CallInst *CI, uint64_t Size,
const TargetTransformInfo::MemCmpExpansionOptions &Options,
unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
Expand All @@ -110,6 +115,76 @@ class MemCmpExpansion {
Value *getMemCmpExpansion();
};

// Decomposes `Size` bytes into consecutive, non-overlapping loads, walking
// `LoadSizes` front to back and emitting as many loads of each size as fit in
// the bytes that are still uncovered.
//
// `NumLoadsNonOneByte` is reset to zero and then incremented once for every
// entry of `LoadSizes` larger than one byte that contributed at least one
// load. Returns an empty vector as soon as the decomposition would need more
// than `MaxNumLoads` loads.
//
// NOTE(review): if `LoadSizes` is exhausted while bytes remain, the returned
// sequence does not cover the tail. Presumably callers always supply a 1-byte
// entry; confirm.
MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
    uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
    const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
  NumLoadsNonOneByte = 0;
  LoadEntryVector Entries;
  uint64_t NextOffset = 0;
  for (const unsigned CurLoadSize : LoadSizes) {
    if (Size == 0)
      break;
    const uint64_t Count = Size / CurLoadSize;
    // Give up before materializing anything for this size: bailing out early
    // keeps the memory used for the decomposition bounded even when the
    // requested size is huge.
    if (Entries.size() + Count > MaxNumLoads)
      return {};
    if (Count == 0)
      continue;
    for (uint64_t Remaining = Count; Remaining != 0; --Remaining) {
      Entries.push_back({CurLoadSize, NextOffset});
      NextOffset += CurLoadSize;
    }
    if (CurLoadSize > 1)
      ++NumLoadsNonOneByte;
    Size %= CurLoadSize;
  }
  return Entries;
}

// Covers `Size` bytes with loads of `MaxLoadSize` bytes only: as many
// back-to-back loads as fit from the start, plus one final load that ends
// exactly at byte `Size` and therefore overlaps the previous one.
//
// Returns an empty vector when this strategy does not apply (see the guards
// below). `NumLoadsNonOneByte` is written (set to 1) only when a sequence is
// actually returned.
MemCmpExpansion::LoadEntryVector
MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
                                                const unsigned MaxLoadSize,
                                                const unsigned MaxNumLoads,
                                                unsigned &NumLoadsNonOneByte) {
  // Tiny sizes and one-byte maximum loads are left to the greedy approach.
  if (Size < 2 || MaxLoadSize < 2)
    return {};

  // Place as many non-overlapping loads as possible from the beginning.
  const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
  assert(NumNonOverlappingLoads && "there must be at least one load");
  // Between 0 and (MaxLoadSize - 1) bytes are left over; they are picked up
  // by a single overlapping load.
  Size %= MaxLoadSize;
  // Nothing to do if no overlapping load is needed (the greedy approach
  // already handles that case), or if the non-overlapping loads plus the
  // overlapping one would exceed the allowed number of loads.
  if (Size == 0 || NumNonOverlappingLoads >= MaxNumLoads)
    return {};

  // Non-overlapping loads, laid out back to back.
  LoadEntryVector Entries;
  for (uint64_t I = 0; I != NumNonOverlappingLoads; ++I)
    Entries.push_back({MaxLoadSize, I * MaxLoadSize});

  // Final load, shifted back so that it ends on the last byte to compare.
  assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
  const uint64_t EndOfFullLoads = NumNonOverlappingLoads * MaxLoadSize;
  Entries.push_back({MaxLoadSize, EndOfFullLoads + Size - MaxLoadSize});
  NumLoadsNonOneByte = 1;
  return Entries;
}

// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
Expand All @@ -133,38 +208,31 @@ MemCmpExpansion::MemCmpExpansion(
Builder(CI) {
assert(Size > 0 && "zero blocks");
// Scale the max size down if the target can load more bytes than we need.
size_t LoadSizeIndex = 0;
while (LoadSizeIndex < Options.LoadSizes.size() &&
Options.LoadSizes[LoadSizeIndex] > Size) {
++LoadSizeIndex;
llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
while (!LoadSizes.empty() && LoadSizes.front() > Size) {
LoadSizes = LoadSizes.drop_front();
}
this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
assert(!LoadSizes.empty() && "cannot load Size bytes");
MaxLoadSize = LoadSizes.front();
// Compute the decomposition.
uint64_t CurSize = Size;
uint64_t Offset = 0;
while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
assert(LoadSize > 0 && "zero load size");
const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
// Do not expand if the total number of loads is larger than what the
// target allows. Note that it's important that we exit before completing
// the expansion to avoid using a ton of memory to store the expansion for
// large sizes.
LoadSequence.clear();
return;
}
if (NumLoadsForThisSize > 0) {
for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
LoadSequence.push_back({LoadSize, Offset});
Offset += LoadSize;
}
if (LoadSize > 1) {
++NumLoadsNonOneByte;
}
CurSize = CurSize % LoadSize;
unsigned GreedyNumLoadsNonOneByte = 0;
LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, MaxNumLoads,
GreedyNumLoadsNonOneByte);
NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
// If we allow overlapping loads and the load sequence is not already optimal,
// use overlapping loads.
if (Options.AllowOverlappingLoads &&
(LoadSequence.empty() || LoadSequence.size() > 2)) {
unsigned OverlappingNumLoadsNonOneByte = 0;
auto OverlappingLoads = computeOverlappingLoadSequence(
Size, MaxLoadSize, MaxNumLoads, OverlappingNumLoadsNonOneByte);
if (!OverlappingLoads.empty() &&
(LoadSequence.empty() ||
OverlappingLoads.size() < LoadSequence.size())) {
LoadSequence = OverlappingLoads;
NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
}
++LoadSizeIndex;
}
assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
}
Expand All @@ -189,30 +257,32 @@ void MemCmpExpansion::createResultBlock() {
EndBlock->getParent(), EndBlock);
}

/// Return a pointer to an element of type `LoadSizeType` at offset
/// `OffsetBytes`.
///
/// \param Source       Base pointer the offset is applied to.
/// \param LoadSizeType Pointee type of the returned pointer.
/// \param OffsetBytes  Distance from `Source` in bytes. When it is zero no
///                     GEP is emitted and `Source` is only cast.
/// \return `Source` advanced by `OffsetBytes` bytes, cast to `LoadSizeType*`.
Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
                                                Type *LoadSizeType,
                                                uint64_t OffsetBytes) {
  if (OffsetBytes > 0) {
    auto *ByteType = Type::getInt8Ty(CI->getContext());
    // The GEP steps over i8 elements, but the index constant must be wide
    // enough to hold the whole offset. GEP indices are interpreted as signed,
    // so an i8 index would turn any offset >= 128 into a negative (or wrapped)
    // displacement. Use a 64-bit index, matching the width of OffsetBytes.
    auto *IndexType = Type::getInt64Ty(CI->getContext());
    Source = Builder.CreateGEP(
        ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
        ConstantInt::get(IndexType, OffsetBytes));
  }
  return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
}

// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters with the given
// GEPIndex. It then subtracts the two loaded values and adds this result to the
// final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
unsigned GEPIndex) {
Value *Source1 = CI->getArgOperand(0);
Value *Source2 = CI->getArgOperand(1);

unsigned OffsetBytes) {
Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
// Cast source to LoadSizeType*.
if (Source1->getType() != LoadSizeType)
Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
if (Source2->getType() != LoadSizeType)
Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

// Get the base address using the GEPIndex.
if (GEPIndex != 0) {
Source1 = Builder.CreateGEP(LoadSizeType, Source1,
ConstantInt::get(LoadSizeType, GEPIndex));
Source2 = Builder.CreateGEP(LoadSizeType, Source2,
ConstantInt::get(LoadSizeType, GEPIndex));
}
Value *Source1 =
getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
Value *Source2 =
getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);

Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
Expand Down Expand Up @@ -270,24 +340,10 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
IntegerType *LoadSizeType =
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

Value *Source1 = CI->getArgOperand(0);
Value *Source2 = CI->getArgOperand(1);

// Cast source to LoadSizeType*.
if (Source1->getType() != LoadSizeType)
Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
if (Source2->getType() != LoadSizeType)
Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

// Get the base address using a GEP.
if (CurLoadEntry.Offset != 0) {
Source1 = Builder.CreateGEP(
LoadSizeType, Source1,
ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
Source2 = Builder.CreateGEP(
LoadSizeType, Source2,
ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
}
Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
CurLoadEntry.Offset);
Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
CurLoadEntry.Offset);

// Get a constant or load a value for each source address.
Value *LoadSrc1 = nullptr;
Expand Down Expand Up @@ -378,8 +434,7 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

if (CurLoadEntry.LoadSize == 1) {
MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
CurLoadEntry.getGEPIndex());
MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
return;
}

Expand All @@ -388,25 +443,12 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

Value *Source1 = CI->getArgOperand(0);
Value *Source2 = CI->getArgOperand(1);

Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
// Cast source to LoadSizeType*.
if (Source1->getType() != LoadSizeType)
Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
if (Source2->getType() != LoadSizeType)
Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

// Get the base address using a GEP.
if (CurLoadEntry.Offset != 0) {
Source1 = Builder.CreateGEP(
LoadSizeType, Source1,
ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
Source2 = Builder.CreateGEP(
LoadSizeType, Source2,
ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
}
Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
CurLoadEntry.Offset);
Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
CurLoadEntry.Offset);

// Load LoadSizeType from the base address.
Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
Expand Down Expand Up @@ -694,7 +736,6 @@ static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
if (SizeVal == 0) {
return false;
}

// TTI call to check if target would like to expand memcmp. Also, get the
// available load sizes.
const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Target/X86/X86TargetTransformInfo.cpp
Expand Up @@ -1886,7 +1886,7 @@ int X86TTIImpl::getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy,
{ ISD::FSQRT, MVT::v4f32, 56 }, // Pentium III from http://www.agner.org/
};
static const CostTblEntry X64CostTbl[] = { // 64-bit targets
{ ISD::BITREVERSE, MVT::i64, 14 }
{ ISD::BITREVERSE, MVT::i64, 14 }
};
static const CostTblEntry X86CostTbl[] = { // 32 or 64-bit targets
{ ISD::BITREVERSE, MVT::i32, 14 },
Expand Down Expand Up @@ -2899,6 +2899,9 @@ X86TTIImpl::enableMemCmpExpansion(bool IsZeroCmp) const {
Options.LoadSizes.push_back(4);
Options.LoadSizes.push_back(2);
Options.LoadSizes.push_back(1);
// All GPR and vector loads can be unaligned. SIMD compare requires integer
// vectors (SSE2/AVX2).
Options.AllowOverlappingLoads = true;
return Options;
}();
return IsZeroCmp ? &EqZeroOptions : &ThreeWayOptions;
Expand Down

0 comments on commit 1bb6e1b

Please sign in to comment.