[AggressiveInstCombine] Combine consecutive loads which are being merged to form a wider load.

This patch simplifies the following patterns:

1. (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
2. (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)

These patterns indicate that the loads are being merged into a wider value whose only use is as a wider load. In that case, for non-atomic, non-volatile loads, reduce the pattern to a single combined load, which improves the cost modeling for inlining, unrolling, vectorization, etc.
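As an illustration (a hypothetical source-level helper, not code from this patch), the byte-assembly idiom below produces exactly the or-of-shifted-zext-loads IR matched above, and with this patch it folds to a single i32 load on little-endian targets where the wider access is legal and fast:

    uint32_t load32_le(const uint8_t *p) {
      // Each byte is zero-extended and shifted into position.
      return (uint32_t)p[0] | ((uint32_t)p[1] << 8) |
             ((uint32_t)p[2] << 16) | ((uint32_t)p[3] << 24);
    }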

Also fix the error reported on the reverse load merge.

Differential Revision: https://reviews.llvm.org/D127392
bipmis committed Sep 28, 2022
1 parent bc2a85f commit 3b49a9f
Showing 3 changed files with 1,228 additions and 687 deletions.
215 changes: 210 additions & 5 deletions llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -50,6 +50,10 @@ STATISTIC(NumGuardedFunnelShifts,
"Number of guarded funnel shifts transformed into funnel shifts");
STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");

static cl::opt<unsigned> MaxInstrsToScan(
    "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
    cl::desc("Max number of instructions to scan for aggressive instcombine."));

namespace {
/// Contains expression pattern combiner logic.
/// This class provides both the logic to combine expression patterns and
@@ -635,18 +639,214 @@ static bool tryToRecognizeTableBasedCttz(Instruction &I) {
  return true;
}

/// Used by foldLoadsRecursive() to capture the root load of an
/// or(load, load) tree while the wide load is built up recursively. Also
/// captures the shift amount, zero-extend type, and combined load size.
struct LoadOps {
  LoadInst *Root = nullptr;
  bool FoundRoot = false;
  uint64_t LoadSize = 0;
  Value *Shift = nullptr;
  Type *ZextType;
  AAMDNodes AATags;
};
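// For example (hypothetically merging two adjacent zero-extended i8 loads
// into an i16): Root is the lower-address load, LoadSize is 16, and Shift
// is the root load's shift amount, if any.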

// Recursively identify and merge consecutive loads that form the patterns
// (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
                               AliasAnalysis &AA) {
  Value *ShAmt2 = nullptr;
  Value *X;
  Instruction *L1, *L2;

  // Descend to the last node of the or-chain that contains loads.
  if (match(V, m_OneUse(m_c_Or(
                   m_Value(X),
                   m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
                                  m_Value(ShAmt2)))))) ||
      match(V, m_OneUse(m_Or(m_Value(X),
                             m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2))))))))
    foldLoadsRecursive(X, LOps, DL, AA);
  else
    return false;

  // If the recursion has not yet found a root, check whether X itself is the
  // (possibly shifted) zero-extended root load.
  LoadInst *LI1 = LOps.Root;
  Value *ShAmt1 = LOps.Shift;
  if (LOps.FoundRoot == false &&
      (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
       match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
                               m_Value(ShAmt1)))))) {
    LI1 = dyn_cast<LoadInst>(L1);
  }
  LoadInst *LI2 = dyn_cast<LoadInst>(L2);

  // Bail out if the two loads are identical or missing, if either is atomic
  // or volatile, or if they are in different address spaces.
  if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
      LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
    return false;

  // Both loads must come from the same basic block.
  if (LI1->getParent() != LI2->getParent())
    return false;

  // Swap the loads if LI1 comes later, since only forward loads are handled,
  // and InstCombine folds the lowest node of forward loads into reverse
  // order. The implementation will subsequently be extended to handle all
  // reverse loads. Swapping is only valid at the root of the pattern; deeper
  // in the chain, bail out instead.
  if (!LI1->comesBefore(LI2)) {
    if (LOps.FoundRoot == false) {
      std::swap(LI1, LI2);
      std::swap(ShAmt1, ShAmt2);
    } else
      return false;
  }

  // Query the endianness from the data layout.
  bool IsBigEndian = DL.isBigEndian();

  // Check that the loads are consecutive and of the same size.
  Value *Load1Ptr = LI1->getPointerOperand();
  APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
  Load1Ptr =
      Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1,
                                                  /* AllowNonInbounds */ true);

  Value *Load2Ptr = LI2->getPointerOperand();
  APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0);
  Load2Ptr =
      Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2,
                                                  /* AllowNonInbounds */ true);
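  // E.g. a constant GEP such as `getelementptr i8, ptr %p, i64 1` strips back
  // to the base pointer %p with an accumulated byte Offset of 1.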

  // Verify that both loads have the same base pointer and the same size.
  uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
  uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
  if (Load1Ptr != Load2Ptr || LoadSize1 != LoadSize2)
    return false;

  // Only support load sizes of at least 8 bits that are a power of 2.
  if (LoadSize1 < 8 || !isPowerOf2_64(LoadSize1))
    return false;

  // Use alias analysis to check for stores between the loads.
  MemoryLocation Loc = MemoryLocation::get(LI2);
  unsigned NumScanned = 0;
  for (Instruction &Inst : make_range(LI1->getIterator(), LI2->getIterator())) {
    if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc)))
      return false;
    if (++NumScanned > MaxInstrsToScan)
      return false;
  }

  // On big-endian targets the earlier (lower-address) load supplies the more
  // significant bytes, so swap the shift amounts to keep the checks below
  // uniform.
  if (IsBigEndian)
    std::swap(ShAmt1, ShAmt2);

  // Extract the constant shift amounts.
  const APInt *Temp;
  uint64_t Shift1 = 0, Shift2 = 0;
  if (ShAmt1 && match(ShAmt1, m_APInt(Temp)))
    Shift1 = Temp->getZExtValue();
  if (ShAmt2 && match(ShAmt2, m_APInt(Temp)))
    Shift2 = Temp->getZExtValue();

  // The first load is always LI1. This is where we put the new load.
  // Use the merged load size available from LI1 if we already combined loads.
  if (LOps.FoundRoot)
    LoadSize1 = LOps.LoadSize;

  // Verify that the shift amounts line up with the load sizes and that the
  // pointer offsets line up with the store size, i.e. the loads are
  // consecutive.
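  // For example (a hypothetical little-endian case), when merging two i8
  // loads into an i16: ShiftDiff = LoadSize1 = 8, so Shift2 - Shift1 must be
  // 8, and PrevSize is 1 byte, so Offset2 - Offset1 must be 1.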
  uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
  uint64_t PrevSize =
      DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
  if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
    return false;

  // Update LOps.
  AAMDNodes AATags1 = LOps.AATags;
  AAMDNodes AATags2 = LI2->getAAMetadata();
  if (LOps.FoundRoot == false) {
    LOps.FoundRoot = true;
    LOps.LoadSize = LoadSize1 + LoadSize2;
    AATags1 = LI1->getAAMetadata();
  } else
    LOps.LoadSize = LOps.LoadSize + LoadSize2;

  // Concatenate the AATags of the merged loads.
  LOps.AATags = AATags1.concat(AATags2);

  LOps.Root = LI1;
  LOps.Shift = ShAmt1;
  LOps.ZextType = X->getType();
  return true;
}

// For a given basic-block instruction, check whether the chain of loads
// feeding it matches a pattern indicating that the loads can be combined,
// i.e. the sole use of every load is to build the wider value.
static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
                                 TargetTransformInfo &TTI, AliasAnalysis &AA) {
  LoadOps LOps;
  if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot)
    return false;

  IRBuilder<> Builder(&I);
  LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;

  // TTI-based checks on whether to proceed with the wider load.
  bool Allowed =
      TTI.isTypeLegal(IntegerType::get(I.getContext(), LOps.LoadSize));
  if (!Allowed)
    return false;

  unsigned AS = LI1->getPointerAddressSpace();
  bool Fast = false;
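  // Only fold if the target reports that a (possibly misaligned) access of
  // LOps.LoadSize bits in this address space is both allowed and fast;
  // otherwise the wide load could be slower than the original narrow loads.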
  Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize,
                                               AS, LI1->getAlign(), &Fast);
  if (!Allowed || !Fast)
    return false;

  // Generate the new load.
  Value *Load1Ptr = LI1->getPointerOperand();
  Builder.SetInsertPoint(LI1);
  NewLoad = Builder.CreateAlignedLoad(
      IntegerType::get(Load1Ptr->getContext(), LOps.LoadSize), Load1Ptr,
      LI1->getAlign(), LI1->isVolatile(), "");
  NewLoad->takeName(LI1);
  // Set the new load's AATags metadata.
  if (LOps.AATags)
    NewLoad->setAAMetadata(LOps.AATags);

  Value *NewOp = NewLoad;
  // Zero-extend if needed.
  if (LOps.ZextType)
    NewOp = Builder.CreateZExt(NewOp, LOps.ZextType);

  // Shift if needed: the result must be shifted by the first load's shift
  // amount, if it is non-zero.
  if (LOps.Shift)
    NewOp = Builder.CreateShl(NewOp, LOps.Shift);
  I.replaceAllUsesWith(NewOp);

  return true;
}

/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
                                TargetTransformInfo &TTI,
                                TargetLibraryInfo &TLI) {
                                TargetLibraryInfo &TLI, AliasAnalysis &AA) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;

    const DataLayout &DL = F.getParent()->getDataLayout();

    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // Also, we want to avoid matching partial patterns.
@@ -658,6 +858,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
      MadeChange |= tryToRecognizePopCount(I);
      MadeChange |= tryToFPToSat(I, TTI);
      MadeChange |= tryToRecognizeTableBasedCttz(I);
      MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA);
      // NOTE: This function may erase the instruction `I`, so it needs to be
      // called at the end of this sequence; otherwise we may introduce bugs.
@@ -676,12 +877,13 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
                    TargetLibraryInfo &TLI, DominatorTree &DT) {
                    TargetLibraryInfo &TLI, DominatorTree &DT,
                    AliasAnalysis &AA) {
  bool MadeChange = false;
  const DataLayout &DL = F.getParent()->getDataLayout();
  TruncInstCombine TIC(AC, TLI, DL, DT);
  MadeChange |= TIC.run(F);
  MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI);
  MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA);
  return MadeChange;
}

@@ -696,14 +898,16 @@ void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
  AU.addPreserved<BasicAAWrapperPass>();
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addPreserved<GlobalsAAWrapperPass>();
  AU.addRequired<AAResultsWrapperPass>();
}

bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) {
  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  return runImpl(F, AC, TTI, TLI, DT);
  auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
  return runImpl(F, AC, TTI, TLI, DT, AA);
}

PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
@@ -712,7 +916,8 @@ PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  if (!runImpl(F, AC, TTI, TLI, DT)) {
  auto &AA = AM.getResult<AAManager>(F);
  if (!runImpl(F, AC, TTI, TLI, DT, AA)) {
    // No changes, all analyses are preserved.
    return PreservedAnalyses::all();
  }
