diff --git a/llvm/include/llvm/CodeGen/BasicBlockSectionsProfileReader.h b/llvm/include/llvm/CodeGen/BasicBlockSectionsProfileReader.h index ee1f28377f7e4..891f50c9696ad 100644 --- a/llvm/include/llvm/CodeGen/BasicBlockSectionsProfileReader.h +++ b/llvm/include/llvm/CodeGen/BasicBlockSectionsProfileReader.h @@ -42,6 +42,22 @@ struct BBClusterInfo { unsigned PositionInCluster; }; +// Identifies a position in block `BBID`: immediately after the call with the +// given index (the first call has an index of 1). A zero callsite index means +// the start of the block. +struct CallsiteID { + UniqueBBID BBID; + unsigned CallsiteIndex; +}; + +// This represents a prefetch hint to be injected at site `SiteID`, targeting +// `TargetID` in function `TargetFunction`. +struct PrefetchHint { + CallsiteID SiteID; + StringRef TargetFunction; + CallsiteID TargetID; +}; + // This represents the raw input profile for one function. struct FunctionPathAndClusterInfo { // BB Cluster information specified by `UniqueBBID`s. @@ -50,9 +66,13 @@ struct FunctionPathAndClusterInfo { // the edge a -> b (a is not cloned). The index of the path in this vector // determines the `UniqueBBID::CloneID` of the cloned blocks in that path. SmallVector> ClonePaths; + // Code prefetch targets (each specified by the callsite ID of the position + // to prefetch) and the prefetch hints to be injected in this function. + SmallVector PrefetchTargets; + SmallVector PrefetchHints; // Node counts for each basic block. DenseMap NodeCounts; - // Edge counts for each edge, stored as a nested map. + // Edge counts for each edge. DenseMap> EdgeCounts; // Hash for each basic block. 
The Hashes are stored for every original block // (not cloned blocks), hence the map key being unsigned instead of @@ -86,6 +106,15 @@ class BasicBlockSectionsProfileReader { uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID, const UniqueBBID &SinkBBID) const; + // Returns the prefetch targets (identified by their containing callsite IDs) + // for function `FuncName`. + SmallVector + getPrefetchTargetsForFunction(StringRef FuncName) const; + + // Returns the prefetch hints to be injected in function `FuncName`. + SmallVector + getPrefetchHintsForFunction(StringRef FuncName) const; + private: StringRef getAliasName(StringRef FuncName) const { auto R = FuncAliasMap.find(FuncName); @@ -195,6 +224,12 @@ class BasicBlockSectionsProfileReaderWrapperPass : public ImmutablePass { uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID, const UniqueBBID &DestBBID) const; + SmallVector + getPrefetchHintsForFunction(StringRef FuncName) const; + + SmallVector + getPrefetchTargetsForFunction(StringRef FuncName) const; + // Initializes the FunctionNameToDIFilename map for the current module and // then reads the profile for the matching functions. bool doInitialization(Module &M) override; diff --git a/llvm/include/llvm/CodeGen/InsertCodePrefetch.h b/llvm/include/llvm/CodeGen/InsertCodePrefetch.h new file mode 100644 index 0000000000000..99241248862d3 --- /dev/null +++ b/llvm/include/llvm/CodeGen/InsertCodePrefetch.h @@ -0,0 +1,25 @@ +//===- InsertCodePrefetch.h - Utilities for code prefetch insertion ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CODEGEN_INSERTCODEPREFETCH_H +#define LLVM_CODEGEN_INSERTCODEPREFETCH_H + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/UniqueBBID.h" + +namespace llvm { +/// Returns the prefetch target symbol name for the given callsite in a block. +SmallString<128> getPrefetchTargetSymbolName(StringRef FunctionName, + const UniqueBBID &BBID, + unsigned CallsiteIndex); + +} // end namespace llvm + +#endif // LLVM_CODEGEN_INSERTCODEPREFETCH_H diff --git a/llvm/include/llvm/CodeGen/MachineBasicBlock.h b/llvm/include/llvm/CodeGen/MachineBasicBlock.h index fcf7bab09fcff..be2fe2b3ef80b 100644 --- a/llvm/include/llvm/CodeGen/MachineBasicBlock.h +++ b/llvm/include/llvm/CodeGen/MachineBasicBlock.h @@ -229,6 +229,12 @@ class MachineBasicBlock /// is only computed once and is cached. mutable MCSymbol *CachedMCSymbol = nullptr; + /// Contains the callsite indices in this block that are targets of code + /// prefetching. The index `i` specifies the `i`th call, with zero + /// representing the beginning of the block and `1` representing the first + /// call. Must be in ascending order and without duplicates. + SmallVector<unsigned> PrefetchTargetCallsiteIndexes; + /// Cached MCSymbol for this block (used if IsEHContTarget). 
mutable MCSymbol *CachedEHContMCSymbol = nullptr; @@ -710,6 +716,16 @@ class MachineBasicBlock std::optional getBBID() const { return BBID; } + /// Returns the callsite indices in this block that are prefetch targets. + const SmallVector<unsigned> &getPrefetchTargetCallsiteIndexes() const { + return PrefetchTargetCallsiteIndexes; + } + + /// Sets the prefetch target callsite indices; \p V must be sorted and unique. + void setPrefetchTargetCallsiteIndexes(const SmallVector<unsigned> &V) { + PrefetchTargetCallsiteIndexes = V; + } + /// Returns the section ID of this basic block. MBBSectionID getSectionID() const { return SectionID; } diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h index a8525554b142e..f148d050a5772 100644 --- a/llvm/include/llvm/CodeGen/Passes.h +++ b/llvm/include/llvm/CodeGen/Passes.h @@ -69,6 +69,8 @@ LLVM_ABI MachineFunctionPass *createBasicBlockSectionsPass(); LLVM_ABI MachineFunctionPass *createBasicBlockPathCloningPass(); +LLVM_ABI MachineFunctionPass *createInsertCodePrefetchPass(); + /// createMachineBlockHashInfoPass - This pass computes basic block hashes. LLVM_ABI MachineFunctionPass *createMachineBlockHashInfoPass(); diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h index 18142c2c0adf3..118b0b8ec7f82 100644 --- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h +++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h @@ -2381,6 +2381,14 @@ class LLVM_ABI TargetInstrInfo : public MCInstrInfo { llvm_unreachable("unknown number of operands necessary"); } + /// Inserts a code prefetch instruction targeting `GV` before `InsertBefore` + /// in block `MBB`. Returns true if an instruction was inserted. 
+ virtual bool insertCodePrefetchInstr(MachineBasicBlock &MBB, + MachineBasicBlock::iterator InsertBefore, + const GlobalValue *GV) const { + return false; + } + private: mutable std::unique_ptr Formatter; unsigned CallFrameSetupOpcode, CallFrameDestroyOpcode; diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h index 10a4d8525a9e8..35d5ab14dc226 100644 --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -56,6 +56,7 @@ LLVM_ABI void initializeAssignmentTrackingAnalysisPass(PassRegistry &); LLVM_ABI void initializeAssumptionCacheTrackerPass(PassRegistry &); LLVM_ABI void initializeAtomicExpandLegacyPass(PassRegistry &); LLVM_ABI void initializeBasicBlockPathCloningPass(PassRegistry &); +LLVM_ABI void initializeInsertCodePrefetchPass(PassRegistry &); LLVM_ABI void initializeBasicBlockSectionsProfileReaderWrapperPassPass(PassRegistry &); LLVM_ABI void initializeBasicBlockSectionsPass(PassRegistry &); diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index 3aa245b7f3f1e..79d4ff8fef27b 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -39,6 +39,7 @@ #include "llvm/BinaryFormat/ELF.h" #include "llvm/CodeGen/GCMetadata.h" #include "llvm/CodeGen/GCMetadataPrinter.h" +#include "llvm/CodeGen/InsertCodePrefetch.h" #include "llvm/CodeGen/LazyMachineBlockFrequencyInfo.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineBlockHashInfo.h" @@ -1985,7 +1986,33 @@ void AsmPrinter::emitFunctionBody() { // Print a label for the basic block. emitBasicBlockStart(MBB); DenseMap MnemonicCounts; + + const auto &PrefetchTargets = + MBB.getPrefetchTargetCallsiteIndexes(); + auto PrefetchTargetIt = PrefetchTargets.begin(); + unsigned LastCallsiteIndex = 0; + // Helper to emit a symbol for the prefetch target and proceed to the next + // one. 
+ auto EmitPrefetchTargetSymbolIfNeeded = [&]() { + if (PrefetchTargetIt != PrefetchTargets.end() && + *PrefetchTargetIt == LastCallsiteIndex) { + // Use the shared helper so the label matches the name used by hints. + MCSymbol *PrefetchTargetSymbol = + OutContext.getOrCreateSymbol(getPrefetchTargetSymbolName( + MF->getName(), *MBB.getBBID(), *PrefetchTargetIt)); + // If the function is weak-linkage it may be replaced by a strong + // version, in which case the prefetch targets should also be replaced. + OutStreamer->emitSymbolAttribute( + PrefetchTargetSymbol, + MF->getFunction().isWeakForLinker() ? MCSA_Weak : MCSA_Global); + OutStreamer->emitLabel(PrefetchTargetSymbol); + ++PrefetchTargetIt; + } + }; + for (auto &MI : MBB) { + EmitPrefetchTargetSymbolIfNeeded(); + // Print the assembly for the instruction. if (!MI.isPosition() && !MI.isImplicitDef() && !MI.isKill() && !MI.isDebugInstr()) { @@ -2123,8 +2150,11 @@ void AsmPrinter::emitFunctionBody() { break; } - if (MI.isCall() && MF->getTarget().Options.BBAddrMap) - OutStreamer->emitLabel(createCallsiteEndSymbol(MBB)); + if (MI.isCall()) { + if (MF->getTarget().Options.BBAddrMap) + OutStreamer->emitLabel(createCallsiteEndSymbol(MBB)); + ++LastCallsiteIndex; + } if (TM.Options.EmitCallGraphSection && MI.isCall()) handleCallsiteForCallgraph(FuncCGInfo, CallSitesInfoMap, MI); @@ -2136,6 +2166,8 @@ void AsmPrinter::emitFunctionBody() { for (auto &Handler : Handlers) Handler->endInstruction(); } + // Emit the last prefetch target in case the last instruction was a call. 
+ EmitPrefetchTargetSymbolIfNeeded(); // We must emit temporary symbol for the end of this basic block, if either // we have BBLabels enabled or if this basic blocks marks the end of a diff --git a/llvm/lib/CodeGen/BasicBlockSectionsProfileReader.cpp b/llvm/lib/CodeGen/BasicBlockSectionsProfileReader.cpp index c234c0f1b0b34..223831bb94805 100644 --- a/llvm/lib/CodeGen/BasicBlockSectionsProfileReader.cpp +++ b/llvm/lib/CodeGen/BasicBlockSectionsProfileReader.cpp @@ -93,6 +93,19 @@ uint64_t BasicBlockSectionsProfileReader::getEdgeCount( return EdgeIt->second; } +SmallVector<CallsiteID> +BasicBlockSectionsProfileReader::getPrefetchTargetsForFunction( + StringRef FuncName) const { + return ProgramPathAndClusterInfo.lookup(getAliasName(FuncName)) + .PrefetchTargets; +} + +SmallVector<PrefetchHint> +BasicBlockSectionsProfileReader::getPrefetchHintsForFunction( + StringRef FuncName) const { + return ProgramPathAndClusterInfo.lookup(getAliasName(FuncName)).PrefetchHints; +} + // Reads the version 1 basic block sections profile. Profile for each function // is encoded as follows: // m @@ -148,6 +161,54 @@ uint64_t BasicBlockSectionsProfileReader::getEdgeCount( // +-->: 5 : // .... 
// **************************************************************************** +// This profile can also specify prefetch targets (starting with 't') which +// instruct the compiler to emit a prefetch symbol for the given target and +// prefetch hints (starting with 'i') which instruct the compiler to insert a +// prefetch hint instruction at the given site for the given target. +// +// A prefetch target is specified by a pair "<bbid>,<subblock_index>" where +// bbid specifies the target basic block and subblock_index is a zero-based +// index. Subblock 0 refers to the region at the beginning of the block up to +// the first callsite. Subblock `i > 0` refers to the region immediately after +// the `i`-th callsite up to the `i+1`-th callsite (or the end of the block). +// The prefetch target is always emitted at the beginning of the subblock. +// This is the beginning of the basic block for `i = 0` and immediately after +// the `i`-th call for every `i > 0`. +// +// A prefetch hint is specified by a pair "site target", where site is +// specified as a pair "<bbid>,<subblock_index>" similar to prefetch +// targets, and target is specified as a triple +// "<function_name>,<bbid>,<subblock_index>". +// +// Example: A basic block in function "foo" with BBID 10 and two call +// instructions (call_A, call_B). This block is conceptually split into +// subblocks, with the prefetch target symbol emitted at the beginning of +// each subblock. 
+// +// +----------------------------------+ +// | __llvm_prefetch_target_foo_10_0: | <--- Subblock 0 (before call_A) +// | Instruction 1 | +// | Instruction 2 | +// | call_A (1st call) | +// | __llvm_prefetch_target_foo_10_1: | <--- Subblock 1 (after call_A, +// | | before call_B) +// | Instruction 3 | +// | call_B (2nd call) | +// | __llvm_prefetch_target_foo_10_2: | <--- Subblock 2 (after call_B, +// | | up to the block end) +// | Instruction 4 | +// +----------------------------------+ +// +// A prefetch hint specified in function "bar" as "120,1 foo,10,2" results +// in a hint inserted after the first call in block #120 of bar: +// +// +----------------------------------------------------+ +// | Instruction 1 | +// | call_C (1st call) | +// | code_prefetch __llvm_prefetch_target_foo_10_2 | +// | Instruction 2 | +// +----------------------------------------------------+ +// Error BasicBlockSectionsProfileReader::ReadV1Profile() { auto FI = ProgramPathAndClusterInfo.end(); @@ -308,6 +369,67 @@ Error BasicBlockSectionsProfileReader::ReadV1Profile() { } continue; } + case 'i': { // Prefetch hint specifier. + // Skip the profile when the profile iterator (FI) refers to the + // past-the-end element. 
+ if (FI == ProgramPathAndClusterInfo.end()) + continue; + if (Values.size() != 2) + return createProfileParseError(Twine("Prefetch hint expected: ") + S); + SmallVector<StringRef, 2> PrefetchSiteStr; + Values[0].split(PrefetchSiteStr, ','); + if (PrefetchSiteStr.size() != 2) + return createProfileParseError(Twine("Prefetch site expected: ") + + Values[0]); + auto SiteBBID = parseUniqueBBID(PrefetchSiteStr[0]); + if (!SiteBBID) + return SiteBBID.takeError(); + unsigned SiteCallsiteIndex; // getAsInteger also rejects out-of-range values. + if (PrefetchSiteStr[1].getAsInteger(10, SiteCallsiteIndex)) + return createProfileParseError(Twine("unsigned integer expected: '") + + PrefetchSiteStr[1] + "'"); + + SmallVector<StringRef, 3> PrefetchTargetStr; + Values[1].split(PrefetchTargetStr, ','); + if (PrefetchTargetStr.size() != 3) + return createProfileParseError( + Twine("Prefetch target expected: ") + Values[1]); + auto TargetBBID = parseUniqueBBID(PrefetchTargetStr[1]); + if (!TargetBBID) + return TargetBBID.takeError(); + unsigned TargetCallsiteIndex; + if (PrefetchTargetStr[2].getAsInteger(10, TargetCallsiteIndex)) + return createProfileParseError(Twine("unsigned integer expected: '") + + PrefetchTargetStr[2] + "'"); + FI->second.PrefetchHints.push_back(PrefetchHint{ + CallsiteID{*SiteBBID, SiteCallsiteIndex}, + PrefetchTargetStr[0], + CallsiteID{*TargetBBID, TargetCallsiteIndex}}); + continue; + } + case 't': { // Prefetch target specifier. + // Skip the profile when the profile iterator (FI) refers to the + // past-the-end element. 
+ if (FI == ProgramPathAndClusterInfo.end()) + continue; + SmallVector<StringRef, 2> PrefetchTargetStr; + if (Values.size() != 1) + return createProfileParseError(Twine("Prefetch target expected: ") + S); + Values[0].split(PrefetchTargetStr, ','); + if (PrefetchTargetStr.size() != 2) + return createProfileParseError(Twine("Prefetch target expected: ") + + Values[0]); + auto TargetBBID = parseUniqueBBID(PrefetchTargetStr[0]); + if (!TargetBBID) + return TargetBBID.takeError(); + unsigned CallsiteIndex; // getAsInteger also rejects out-of-range values. + if (PrefetchTargetStr[1].getAsInteger(10, CallsiteIndex)) + return createProfileParseError(Twine("unsigned integer expected: '") + + PrefetchTargetStr[1] + "'"); + FI->second.PrefetchTargets.push_back( + CallsiteID{*TargetBBID, CallsiteIndex}); + continue; + } default: return createProfileParseError(Twine("invalid specifier: '") + Twine(Specifier) + "'"); @@ -514,6 +636,18 @@ uint64_t BasicBlockSectionsProfileReaderWrapperPass::getEdgeCount( return BBSPR.getEdgeCount(FuncName, SrcBBID, SinkBBID); } +SmallVector<CallsiteID> +BasicBlockSectionsProfileReaderWrapperPass::getPrefetchTargetsForFunction( + StringRef FuncName) const { + return BBSPR.getPrefetchTargetsForFunction(FuncName); +} + +SmallVector<PrefetchHint> +BasicBlockSectionsProfileReaderWrapperPass::getPrefetchHintsForFunction( + StringRef FuncName) const { + return BBSPR.getPrefetchHintsForFunction(FuncName); +} + BasicBlockSectionsProfileReader & BasicBlockSectionsProfileReaderWrapperPass::getBBSPR() { return BBSPR; diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt index 1cf0b4964760b..fcf28247179ca 100644 --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -79,6 +79,7 @@ add_llvm_component_library(LLVMCodeGen IndirectBrExpandPass.cpp InitUndef.cpp InlineSpiller.cpp + InsertCodePrefetch.cpp InterferenceCache.cpp InterleavedAccessPass.cpp InterleavedLoadCombinePass.cpp diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp 
b/llvm/lib/CodeGen/CodeGenPrepare.cpp index 
587c1372b19cb..47c7bbea739ae 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -22,6 +22,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/FloatingPointPredicateUtils.h" diff --git a/llvm/lib/CodeGen/InsertCodePrefetch.cpp b/llvm/lib/CodeGen/InsertCodePrefetch.cpp new file mode 100644 index 0000000000000..68a500c545651 --- /dev/null +++ b/llvm/lib/CodeGen/InsertCodePrefetch.cpp @@ -0,0 +1,164 @@ +//===-- InsertCodePrefetch.cpp - Code prefetch insertion pass ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// Code Prefetch Insertion Pass. +//===----------------------------------------------------------------------===// +/// This pass inserts code prefetch instructions according to the prefetch +/// directives in the basic block section profile. The target of a prefetch can +/// be the beginning of any dynamic basic block, that is the beginning of a +/// machine basic block, or immediately after a callsite. A global symbol is +/// emitted at the position of the target so it can be addressed from the +/// prefetch instruction from any module. In order to insert prefetch hints, +/// `TargetInstrInfo::insertCodePrefetchInstr` must be implemented by the +/// target. 
+//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/InsertCodePrefetch.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/CodeGen/BasicBlockSectionUtils.h" +#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h" +#include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/InitializePasses.h" + +using namespace llvm; +#define DEBUG_TYPE "insert-code-prefetch" + +namespace llvm { +SmallString<128> getPrefetchTargetSymbolName(StringRef FunctionName, + const UniqueBBID &BBID, + unsigned CallsiteIndex) { + SmallString<128> R("__llvm_prefetch_target_"); + R += FunctionName; + R += "_"; + R += utostr(BBID.BaseID); // NOTE(review): CloneID is ignored; confirm. + R += "_"; + R += utostr(CallsiteIndex); + return R; +} +} // namespace llvm + +namespace { +class InsertCodePrefetch : public MachineFunctionPass { +public: + static char ID; + + InsertCodePrefetch() : MachineFunctionPass(ID) { + initializeInsertCodePrefetchPass(*PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { + return "Code Prefetch Inserter Pass"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override; + + // Sets prefetch targets and inserts hints per the bb section profile. 
+ bool runOnMachineFunction(MachineFunction &MF) override; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Implementation +//===----------------------------------------------------------------------===// + +char InsertCodePrefetch::ID = 0; +INITIALIZE_PASS_BEGIN(InsertCodePrefetch, DEBUG_TYPE, "Code prefetch insertion", + true, false) +INITIALIZE_PASS_DEPENDENCY(BasicBlockSectionsProfileReaderWrapperPass) +INITIALIZE_PASS_END(InsertCodePrefetch, DEBUG_TYPE, "Code prefetch insertion", + true, false) + +bool InsertCodePrefetch::runOnMachineFunction(MachineFunction &MF) { + assert(MF.getTarget().getBBSectionsType() == BasicBlockSection::List && + "BB Sections list not enabled!"); + if (hasInstrProfHashMismatch(MF)) + return false; + // Set each block's prefetch targets so AsmPrinter can emit a special symbol + // there. + SmallVector<CallsiteID> PrefetchTargets = + getAnalysis<BasicBlockSectionsProfileReaderWrapperPass>() + .getPrefetchTargetsForFunction(MF.getName()); + DenseMap<UniqueBBID, SmallVector<unsigned>> PrefetchTargetsByBBID; + for (const auto &Target : PrefetchTargets) + PrefetchTargetsByBBID[Target.BBID].push_back(Target.CallsiteIndex); + // Sort and uniquify the callsite indices for every block. 
+ for (auto &[K, V] : PrefetchTargetsByBBID) { + llvm::sort(V); + V.erase(llvm::unique(V), V.end()); + } + for (auto &MBB : MF) { + auto R = PrefetchTargetsByBBID.find(*MBB.getBBID()); + if (R == PrefetchTargetsByBBID.end()) + continue; + MBB.setPrefetchTargetCallsiteIndexes(R->second); + } + SmallVector<PrefetchHint> PrefetchHints = + getAnalysis<BasicBlockSectionsProfileReaderWrapperPass>() + .getPrefetchHintsForFunction(MF.getName()); + DenseMap<UniqueBBID, SmallVector<PrefetchHint>> PrefetchHintsBySiteBBID; + for (const auto &H : PrefetchHints) + PrefetchHintsBySiteBBID[H.SiteID.BBID].push_back(H); + // Stable-sort the hints by callsite index so one pass over the block can + // insert them, keeping profile order among hints at the same callsite. 
+ for (auto &[SiteBBID, Hints] : PrefetchHintsBySiteBBID) { + llvm::stable_sort(Hints, [](const PrefetchHint &A, const PrefetchHint &B) { + return A.SiteID.CallsiteIndex < B.SiteID.CallsiteIndex; + }); + } + auto PtrTy = + PointerType::getUnqual(MF.getFunction().getParent()->getContext()); + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + for (auto &BB : MF) { + auto It = PrefetchHintsBySiteBBID.find(*BB.getBBID()); + if (It == PrefetchHintsBySiteBBID.end()) + continue; + const auto &SiteHints = It->second; + unsigned NumCallsInBB = 0; + auto InstrIt = BB.begin(); + for (auto HintIt = SiteHints.begin(); HintIt != SiteHints.end();) { + auto NextInstrIt = InstrIt == BB.end() ? BB.end() : std::next(InstrIt); + // Insert all the prefetch hints which must be placed after this call (or + // at the beginning of the block if `NumCallsInBB` is zero). + while (HintIt != SiteHints.end() && + NumCallsInBB >= HintIt->SiteID.CallsiteIndex) { + auto *GV = MF.getFunction().getParent()->getOrInsertGlobal( + getPrefetchTargetSymbolName(HintIt->TargetFunction, + HintIt->TargetID.BBID, + HintIt->TargetID.CallsiteIndex), + PtrTy); + TII->insertCodePrefetchInstr(BB, InstrIt, GV); + ++HintIt; + } + if (InstrIt == BB.end()) + break; + if (InstrIt->isCall()) + ++NumCallsInBB; + InstrIt = NextInstrIt; + } + } + return true; +} + +void InsertCodePrefetch::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<BasicBlockSectionsProfileReaderWrapperPass>(); + MachineFunctionPass::getAnalysisUsage(AU); +} + +MachineFunctionPass *llvm::createInsertCodePrefetchPass() { + return new InsertCodePrefetch(); +} diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp index ceae0d29eea90..5334c5596d018 100644 --- a/llvm/lib/CodeGen/TargetPassConfig.cpp +++ b/llvm/lib/CodeGen/TargetPassConfig.cpp @@ -1291,6 +1291,7 @@ void TargetPassConfig::addMachinePasses() { addPass(llvm::createBasicBlockSectionsProfileReaderWrapperPass( 
TM->getBBSectionsFuncListBuf())); addPass(llvm::createBasicBlockPathCloningPass()); + addPass(llvm::createInsertCodePrefetchPass()); } addPass(llvm::createBasicBlockSectionsPass()); } diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp index cb0208a4a5f32..6556e16241557 100644 --- a/llvm/lib/Target/X86/X86InstrInfo.cpp +++ b/llvm/lib/Target/X86/X86InstrInfo.cpp @@ -10978,5 +10978,25 @@ void X86InstrInfo::getFrameIndexOperands(SmallVectorImpl &Ops, M.getFullAddress(Ops); } +bool X86InstrInfo::insertCodePrefetchInstr( + MachineBasicBlock &MBB, MachineBasicBlock::iterator InsertBefore, + const GlobalValue *GV) const { + MachineFunction &MF = *MBB.getParent(); + MachineInstr *PrefetchInstr = MF.CreateMachineInstr( + get(X86::PREFETCHIT1), /* NOTE(review): no subtarget check here; confirm. */ + InsertBefore == MBB.instr_end() ? MBB.findPrevDebugLoc(InsertBefore) + : InsertBefore->getDebugLoc(), + /*NoImplicit=*/true); + MachineInstrBuilder MIB(MF, PrefetchInstr); + MIB.addMemOperand(MF.getMachineMemOperand(MachinePointerInfo(GV), + MachineMemOperand::MOLoad, /*s=*/8, + /*base_alignment=*/llvm::Align(1))); + MIB.addReg(X86::RIP).addImm(1).addReg(X86::NoRegister); /* base, scale, index */ + MIB.addGlobalAddress(GV); /* displacement: the prefetch target symbol */ + MIB.addReg(X86::NoRegister); /* segment */ + MBB.insert(InsertBefore, PrefetchInstr); + return true; +} + #define GET_INSTRINFO_HELPERS #include "X86GenInstrInfo.inc" diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h index a547fcd421411..2fe67c56e1bcd 100644 --- a/llvm/lib/Target/X86/X86InstrInfo.h +++ b/llvm/lib/Target/X86/X86InstrInfo.h @@ -767,6 +767,10 @@ class X86InstrInfo final : public X86GenInstrInfo { /// \returns the index of operand that is commuted with \p Idx1. If the method /// fails to commute the operands, it will return \p Idx1. 
unsigned commuteOperandsForFold(MachineInstr &MI, unsigned Idx1) const; + + bool insertCodePrefetchInstr(MachineBasicBlock &MBB, + MachineBasicBlock::iterator InsertBefore, + const GlobalValue *GV) const override; }; } // namespace llvm diff --git a/llvm/test/CodeGen/X86/basic-block-sections-code-prefetch.ll b/llvm/test/CodeGen/X86/basic-block-sections-code-prefetch.ll new file mode 100644 index 0000000000000..81fdccbbf73af --- /dev/null +++ b/llvm/test/CodeGen/X86/basic-block-sections-code-prefetch.ll @@ -0,0 +1,75 @@ +;; Check prefetch directives in basic block section profiles. +;; +;; Specify the bb sections profile: +; RUN: echo 'v1' > %t +; RUN: echo 'f _Z3foob' >> %t +; RUN: echo 't 0,0' >> %t +; RUN: echo 't 1,0' >> %t +; RUN: echo 't 1,1' >> %t +; RUN: echo 't 2,1' >> %t +; RUN: echo 't 3,0' >> %t +; RUN: echo 'i 3,0 _Z3barv,0,0' >> %t +; RUN: echo 'i 2,1 _Z3foob,1,0' >> %t +; RUN: echo 'f _Z3barv' >> %t +; RUN: echo 't 0,0' >> %t +; RUN: echo 't 21,1' >> %t +; RUN: echo 'i 0,1 _Z3foob,0,0' >> %t +;; +; RUN: llc < %s -O0 -mtriple=x86_64-pc-linux -asm-verbose=false -function-sections -basic-block-sections=%t | FileCheck %s + +define i32 @_Z3foob(i1 zeroext %0) nounwind { + %2 = alloca i32, align 4 + %3 = alloca i8, align 1 + %4 = zext i1 %0 to i8 + store i8 %4, ptr %3, align 1 + %5 = load i8, ptr %3, align 1 + %6 = trunc i8 %5 to i1 + %7 = zext i1 %6 to i32 + %8 = icmp sgt i32 %7, 0 + br i1 %8, label %9, label %11 +; CHECK: _Z3foob: +; CHECK-NEXT: .globl __llvm_prefetch_target__Z3foob_0_0 +; CHECK-NEXT: __llvm_prefetch_target__Z3foob_0_0: + +9: ; preds = %1 + %10 = call i32 @_Z3barv() + store i32 %10, ptr %2, align 4 + br label %13 +; CHECK: .globl __llvm_prefetch_target__Z3foob_1_0 +; CHECK-NEXT: __llvm_prefetch_target__Z3foob_1_0: +; CHECK-NEXT: callq _Z3barv@PLT +; CHECK-NEXT: .globl __llvm_prefetch_target__Z3foob_1_1 +; CHECK-NEXT: __llvm_prefetch_target__Z3foob_1_1: + +11: ; preds = %1 + %12 = call i32 @_Z3bazv() + store i32 %12, ptr %2, align 4 + br 
label %13 +; CHECK: callq _Z3bazv@PLT +; CHECK-NEXT: .globl __llvm_prefetch_target__Z3foob_2_1 +; CHECK-NEXT: __llvm_prefetch_target__Z3foob_2_1: +; CHECK-NEXT: prefetchit1 __llvm_prefetch_target__Z3foob_1_0(%rip) + +13: ; preds = %11, %9 + %14 = load i32, ptr %2, align 4 + ret i32 %14 +; CHECK: .LBB0_3: +; CHECK-NEXT: .globl __llvm_prefetch_target__Z3foob_3_0 +; CHECK-NEXT: __llvm_prefetch_target__Z3foob_3_0: +; CHECK-NEXT: prefetchit1 __llvm_prefetch_target__Z3barv_0_0(%rip) +; CHECK: retq +} + +define weak i32 @_Z3barv() nounwind { + %1 = call i32 @_Z3bazv() + br label %2 +; CHECK: _Z3barv: +; CHECK-NEXT: .weak __llvm_prefetch_target__Z3barv_0_0 +; CHECK-NEXT: __llvm_prefetch_target__Z3barv_0_0: +; CHECK: callq _Z3bazv@PLT +; CHECK-NEXT: prefetchit1 __llvm_prefetch_target__Z3foob_0_0(%rip) +2: + ret i32 %1 +} + +declare i32 @_Z3bazv() #1