Skip to content

Commit

Permalink
[LoopIdiom] let the pass deal with runtime memset size
Browse files Browse the repository at this point in the history
The current LIR does not deal with runtime-determined memset-size. This patch
utilizes SCEV and check if the PointerStrideSCEV and the MemsetSizeSCEV are equal.
Before comparison the pass would try to fold the expression that is already
protected by the loop guard.

Testcase file `memset-runtime.ll`, `memset-runtime-debug.ll` added.

This patch deals with proper loop-idiom. Proceeding patch wants to deal with SCEV-s
that are inequal after folding with the loop guards.

Reviewed By: lebedev.ri, Whitney

Differential Revision: https://reviews.llvm.org/D107353
  • Loading branch information
eopXD authored and Yueh-Ting Chen committed Aug 14, 2021
1 parent fe86632 commit 0121736
Show file tree
Hide file tree
Showing 3 changed files with 445 additions and 31 deletions.
96 changes: 65 additions & 31 deletions llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
Expand Up @@ -896,8 +896,8 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
/// processLoopMemSet - See if this memset can be promoted to a large memset.
bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
const SCEV *BECount) {
// We can only handle non-volatile memsets with a constant size.
if (MSI->isVolatile() || !isa<ConstantInt>(MSI->getLength()))
// We can only handle non-volatile memsets.
if (MSI->isVolatile())
return false;

// If we're not allowed to hack on memset, we fail.
Expand All @@ -910,23 +910,72 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
// loop, which indicates a strided store. If we have something else, it's a
// random store we can't handle.
const SCEVAddRecExpr *Ev = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Pointer));
if (!Ev || Ev->getLoop() != CurLoop || !Ev->isAffine())
if (!Ev || Ev->getLoop() != CurLoop)
return false;

// Reject memsets that are so large that they overflow an unsigned.
uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
if ((SizeInBytes >> 32) != 0)
if (!Ev->isAffine()) {
LLVM_DEBUG(dbgs() << " Pointer is not affine, abort\n");
return false;
}

// Check to see if the stride matches the size of the memset. If so, then we
// know that every byte is touched in the loop.
const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
if (!ConstStride)
const SCEV *PointerStrideSCEV = Ev->getOperand(1);
const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
if (!PointerStrideSCEV || !MemsetSizeSCEV)
return false;

APInt Stride = ConstStride->getAPInt();
if (SizeInBytes != Stride && SizeInBytes != -Stride)
return false;
bool IsNegStride = false;
const bool IsConstantSize = isa<ConstantInt>(MSI->getLength());

if (IsConstantSize) {
// Memset size is constant.
// Check if the pointer stride matches the memset size. If so, then
// we know that every byte is touched in the loop.
LLVM_DEBUG(dbgs() << " memset size is constant\n");
uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
if (!ConstStride)
return false;

APInt Stride = ConstStride->getAPInt();
if (SizeInBytes != Stride && SizeInBytes != -Stride)
return false;

IsNegStride = SizeInBytes == -Stride;
} else {
// Memset size is non-constant.
// Check if the pointer stride matches the memset size.
// To be conservative, the pass would not promote pointers that aren't in
// address space zero. Also, the pass only handles memset length and stride
// that are invariant for the top level loop.
LLVM_DEBUG(dbgs() << " memset size is non-constant\n");
if (Pointer->getType()->getPointerAddressSpace() != 0) {
LLVM_DEBUG(dbgs() << " pointer is not in address space zero, "
<< "abort\n");
return false;
}
if (!SE->isLoopInvariant(MemsetSizeSCEV, CurLoop)) {
LLVM_DEBUG(dbgs() << " memset size is not a loop-invariant, "
<< "abort\n");
return false;
}

// Compare positive direction PointerStrideSCEV with MemsetSizeSCEV
IsNegStride = PointerStrideSCEV->isNonConstantNegative();
const SCEV *PositiveStrideSCEV =
IsNegStride ? SE->getNegativeSCEV(PointerStrideSCEV)
: PointerStrideSCEV;
LLVM_DEBUG(dbgs() << " MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n"
<< " PositiveStrideSCEV: " << *PositiveStrideSCEV
<< "\n");

if (PositiveStrideSCEV != MemsetSizeSCEV) {
// TODO: folding can be done to the SCEVs
// The folding is to fold expressions that is covered by the loop guard
// at loop entry. After the folding, compare again and proceed
// optimization if equal.
LLVM_DEBUG(dbgs() << " SCEV don't match, abort\n");
return false;
}
}

// Verify that the memset value is loop invariant. If not, we can't promote
// the memset.
Expand All @@ -936,7 +985,6 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,

SmallPtrSet<Instruction *, 1> MSIs;
MSIs.insert(MSI);
bool IsNegStride = SizeInBytes == -Stride;
return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
MaybeAlign(MSI->getDestAlignment()),
SplatValue, MSI, MSIs, Ev, BECount,
Expand Down Expand Up @@ -1028,20 +1076,6 @@ static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
///
/// This also maps the SCEV into the provided type and tries to handle the
/// computation in a way that will fold cleanly.
static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
unsigned StoreSize, Loop *CurLoop,
const DataLayout *DL, ScalarEvolution *SE) {
const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);

// And scale it based on the store size.
if (StoreSize != 1) {
return SE->getMulExpr(TripCountSCEV, SE->getConstant(IntPtr, StoreSize),
SCEV::FlagNUW);
}
return TripCountSCEV;
}

/// getNumBytes that takes StoreSize as a SCEV
static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
const SCEV *StoreSizeSCEV, Loop *CurLoop,
const DataLayout *DL, ScalarEvolution *SE) {
Expand Down Expand Up @@ -1342,8 +1376,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(

// Okay, everything is safe, we can transform this!

const SCEV *NumBytesS =
getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE);
const SCEV *NumBytesS = getNumBytes(
BECount, IntIdxTy, SE->getConstant(IntIdxTy, StoreSize), CurLoop, DL, SE);

Value *NumBytes =
Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());
Expand Down

0 comments on commit 0121736

Please sign in to comment.