diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst index 965a21b8c84b8..bf3de11a2640e 100644 --- a/llvm/docs/MLGO.rst +++ b/llvm/docs/MLGO.rst @@ -508,7 +508,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance. .. code-block:: c++ - const ir2vec::Embedding &FuncVector = Emb->getFunctionVector(); + ir2vec::Embedding FuncVector = Emb->getFunctionVector(); Currently, ``Embedder`` can generate embeddings at three levels: Instructions, Basic Blocks, and Functions. Appropriate getters are provided to access the diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 81409df7337c5..6bc51feb580d9 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -533,21 +533,20 @@ class Embedder { /// in the IR instructions to generate the vector representation. const float OpcWeight, TypeWeight, ArgWeight; - // Utility maps - these are used to store the vector representations of - // instructions, basic blocks and functions. - mutable Embedding FuncVector; - mutable BBEmbeddingsMap BBVecMap; - mutable InstEmbeddingsMap InstVecMap; - - LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab); + LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab) + : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()), + OpcWeight(ir2vec::OpcWeight), TypeWeight(ir2vec::TypeWeight), + ArgWeight(ir2vec::ArgWeight) {} - /// Function to compute embeddings. It generates embeddings for all - /// the instructions and basic blocks in the function F. - void computeEmbeddings() const; + /// Function to compute embeddings. + Embedding computeEmbeddings() const; /// Function to compute the embedding for a given basic block. + Embedding computeEmbeddings(const BasicBlock &BB) const; + + /// Function to compute the embedding for a given instruction. /// Specific to the kind of embeddings being computed. - virtual void computeEmbeddings(const BasicBlock &BB) const = 0; + virtual Embedding computeEmbeddings(const Instruction &I) const = 0; public: virtual ~Embedder() = default; @@ -556,23 +555,27 @@ class Embedder { LLVM_ABI static std::unique_ptr create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab); - /// Returns a map containing instructions and the corresponding embeddings for - /// the function F if it has been computed. If not, it computes the embeddings - /// for the function and returns the map. - LLVM_ABI const InstEmbeddingsMap &getInstVecMap() const; - - /// Returns a map containing basic block and the corresponding embeddings for - /// the function F if it has been computed. If not, it computes the embeddings - /// for the function and returns the map. - LLVM_ABI const BBEmbeddingsMap &getBBVecMap() const; + /// Computes and returns the embedding for a given instruction in the function + /// F + LLVM_ABI Embedding getInstVector(const Instruction &I) const { + return computeEmbeddings(I); + } - /// Returns the embedding for a given basic block in the function F if it has - /// been computed. If not, it computes the embedding for the basic block and - /// returns it. - LLVM_ABI const Embedding &getBBVector(const BasicBlock &BB) const; + /// Computes and returns the embedding for a given basic block in the function + /// F + LLVM_ABI Embedding getBBVector(const BasicBlock &BB) const { + return computeEmbeddings(BB); + } /// Computes and returns the embedding for the current function. - LLVM_ABI const Embedding &getFunctionVector() const; + LLVM_ABI Embedding getFunctionVector() const { return computeEmbeddings(); } + + /// Invalidate embeddings if cached. The embeddings may not be relevant + /// anymore when the IR changes due to transformations. In such cases, the + /// cached embeddings should be invalidated to ensure + /// correctness/recomputation. This is a no-op for SymbolicEmbedder but + /// removes all the cached entries in FlowAwareEmbedder. + virtual void invalidateEmbeddings() { return; } }; /// Class for computing the Symbolic embeddings of IR2Vec. @@ -580,7 +583,7 @@ class Embedder { /// representations obtained from the Vocabulary. class LLVM_ABI SymbolicEmbedder : public Embedder { private: - void computeEmbeddings(const BasicBlock &BB) const override; + Embedding computeEmbeddings(const Instruction &I) const override; public: SymbolicEmbedder(const Function &F, const Vocabulary &Vocab) @@ -592,11 +595,15 @@ class LLVM_ABI SymbolicEmbedder : public Embedder { /// embeddings, and additionally capture the flow information in the IR. class LLVM_ABI FlowAwareEmbedder : public Embedder { private: - void computeEmbeddings(const BasicBlock &BB) const override; + // FlowAware embeddings would benefit from caching instruction embeddings as + // they are reused while computing the embeddings of other instructions. + mutable InstEmbeddingsMap InstVecMap; + Embedding computeEmbeddings(const Instruction &I) const override; public: FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab) : Embedder(F, Vocab) {} + void invalidateEmbeddings() override { InstVecMap.clear(); } }; } // namespace ir2vec diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 1794a604b991d..85b5372c961c1 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const { // Embedder and its subclasses //===----------------------------------------------------------------------===// -Embedder::Embedder(const Function &F, const Vocabulary &Vocab) - : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()), - OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight), - FuncVector(Embedding(Dimension)) {} - std::unique_ptr Embedder::create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab) { switch (Mode) { @@ -169,110 +164,85 @@ std::unique_ptr Embedder::create(IR2VecKind Mode, const Function &F, return nullptr; } -const InstEmbeddingsMap &Embedder::getInstVecMap() const { - if (InstVecMap.empty()) - computeEmbeddings(); - return InstVecMap; -} - -const BBEmbeddingsMap &Embedder::getBBVecMap() const { - if (BBVecMap.empty()) - computeEmbeddings(); - return BBVecMap; -} - -const Embedding &Embedder::getBBVector(const BasicBlock &BB) const { - auto It = BBVecMap.find(&BB); - if (It != BBVecMap.end()) - return It->second; - computeEmbeddings(BB); - return BBVecMap[&BB]; -} +Embedding Embedder::computeEmbeddings() const { + Embedding FuncVector(Dimension, 0.0); -const Embedding &Embedder::getFunctionVector() const { - // Currently, we always (re)compute the embeddings for the function. - // This is cheaper than caching the vector. - computeEmbeddings(); - return FuncVector; -} - -void Embedder::computeEmbeddings() const { if (F.isDeclaration()) - return; - - FuncVector = Embedding(Dimension, 0.0); + return FuncVector; // Consider only the basic blocks that are reachable from entry - for (const BasicBlock *BB : depth_first(&F)) { - computeEmbeddings(*BB); - FuncVector += BBVecMap[BB]; - } + for (const BasicBlock *BB : depth_first(&F)) + FuncVector += computeEmbeddings(*BB); + return FuncVector; } -void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { +Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const { Embedding BBVector(Dimension, 0); // We consider only the non-debug and non-pseudo instructions - for (const auto &I : BB.instructionsWithoutDebug()) { - Embedding ArgEmb(Dimension, 0); - for (const auto &Op : I.operands()) - ArgEmb += Vocab[*Op]; - auto InstVector = - Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; - if (const auto *IC = dyn_cast(&I)) - InstVector += Vocab[IC->getPredicate()]; - InstVecMap[&I] = InstVector; - BBVector += InstVector; - } - BBVecMap[&BB] = BBVector; -} - -void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { - Embedding BBVector(Dimension, 0); + for (const auto &I : BB.instructionsWithoutDebug()) + BBVector += computeEmbeddings(I); + return BBVector; +} + +Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const { + // Currently, we always (re)compute the embeddings for symbolic embedder. + // This is cheaper than caching the vectors. + Embedding ArgEmb(Dimension, 0); + for (const auto &Op : I.operands()) + ArgEmb += Vocab[*Op]; + auto InstVector = + Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + if (const auto *IC = dyn_cast(&I)) + InstVector += Vocab[IC->getPredicate()]; + return InstVector; +} + +Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const { + // If we have already computed the embedding for this instruction, return it + auto It = InstVecMap.find(&I); + if (It != InstVecMap.end()) + return It->second; - // We consider only the non-debug and non-pseudo instructions - for (const auto &I : BB.instructionsWithoutDebug()) { - // TODO: Handle call instructions differently. - // For now, we treat them like other instructions - Embedding ArgEmb(Dimension, 0); - for (const auto &Op : I.operands()) { - // If the operand is defined elsewhere, we use its embedding - if (const auto *DefInst = dyn_cast(Op)) { - auto DefIt = InstVecMap.find(DefInst); - // Fixme (#159171): Ideally we should never miss an instruction - // embedding here. - // But when we have cyclic dependencies (e.g., phi - // nodes), we might miss the embedding. In such cases, we fall back to - // using the vocabulary embedding. This can be fixed by iterating to a - // fixed-point, or by using a simple solver for the set of simultaneous - // equations. - // Another case when we might miss an instruction embedding is when - // the operand instruction is in a different basic block that has not - // been processed yet. This can be fixed by processing the basic blocks - // in a topological order. - if (DefIt != InstVecMap.end()) - ArgEmb += DefIt->second; - else - ArgEmb += Vocab[*Op]; - } - // If the operand is not defined by an instruction, we use the vocabulary - else { - LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: " - << *Op << "=" << Vocab[*Op][0] << "\n"); + // TODO: Handle call instructions differently. + // For now, we treat them like other instructions + Embedding ArgEmb(Dimension, 0); + for (const auto &Op : I.operands()) { + // If the operand is defined elsewhere, we use its embedding + if (const auto *DefInst = dyn_cast(Op)) { + auto DefIt = InstVecMap.find(DefInst); + // Fixme (#159171): Ideally we should never miss an instruction + // embedding here. + // But when we have cyclic dependencies (e.g., phi + // nodes), we might miss the embedding. In such cases, we fall back to + // using the vocabulary embedding. This can be fixed by iterating to a + // fixed-point, or by using a simple solver for the set of simultaneous + // equations. + // Another case when we might miss an instruction embedding is when + // the operand instruction is in a different basic block that has not + // been processed yet. This can be fixed by processing the basic blocks + // in a topological order. + if (DefIt != InstVecMap.end()) + ArgEmb += DefIt->second; + else ArgEmb += Vocab[*Op]; - } } - // Create the instruction vector by combining opcode, type, and arguments - // embeddings - auto InstVector = - Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; - // Add compare predicate embedding as an additional operand if applicable - if (const auto *IC = dyn_cast(&I)) - InstVector += Vocab[IC->getPredicate()]; - InstVecMap[&I] = InstVector; - BBVector += InstVector; + // If the operand is not defined by an instruction, we use the + // vocabulary + else { + LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: " + << *Op << "=" << Vocab[*Op][0] << "\n"); + ArgEmb += Vocab[*Op]; + } } - BBVecMap[&BB] = BBVector; + // Create the instruction vector by combining opcode, type, and arguments + // embeddings + auto InstVector = + Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + if (const auto *IC = dyn_cast(&I)) + InstVector += Vocab[IC->getPredicate()]; + InstVecMap[&I] = InstVector; + return InstVector; } // ==----------------------------------------------------------------------===// @@ -695,25 +665,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M, Emb->getFunctionVector().print(OS); OS << "Basic block vectors:\n"; - const auto &BBMap = Emb->getBBVecMap(); for (const BasicBlock &BB : F) { - auto It = BBMap.find(&BB); - if (It != BBMap.end()) { - OS << "Basic block: " << BB.getName() << ":\n"; - It->second.print(OS); - } + OS << "Basic block: " << BB.getName() << ":\n"; + Emb->getBBVector(BB).print(OS); } OS << "Instruction vectors:\n"; - const auto &InstMap = Emb->getInstVecMap(); for (const BasicBlock &BB : F) { for (const Instruction &I : BB) { - auto It = InstMap.find(&I); - if (It != InstMap.end()) { - OS << "Instruction: "; - I.print(OS); - It->second.print(OS); - } + OS << "Instruction: "; + I.print(OS); + Emb->getInstVector(I).print(OS); } } } diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll index 9be0ee1c2de7a..627e2c9ac6b2d 100644 --- a/llvm/test/Analysis/IR2Vec/unreachable.ll +++ b/llvm/test/Analysis/IR2Vec/unreachable.ll @@ -30,13 +30,17 @@ return: ; preds = %if.else, %if.then %4 = load i32, ptr %retval, align 4 ret i32 %4 } - -; CHECK: Basic block vectors: +; We'll get individual basic block embeddings for all blocks in the function. +; But unreachable blocks are not counted for computing the function embedding. +; CHECK: Function vector: [ 1301.20 1318.20 1335.20 ] +; CHECK-NEXT: Basic block vectors: ; CHECK-NEXT: Basic block: entry: ; CHECK-NEXT: [ 816.20 825.20 834.20 ] ; CHECK-NEXT: Basic block: if.then: ; CHECK-NEXT: [ 195.00 198.00 201.00 ] ; CHECK-NEXT: Basic block: if.else: ; CHECK-NEXT: [ 195.00 198.00 201.00 ] +; CHECK-NEXT: Basic block: unreachable: +; CHECK-NEXT: [ 101.00 103.00 105.00 ] ; CHECK-NEXT: Basic block: return: ; CHECK-NEXT: [ 95.00 97.00 99.00 ] diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index 434449c7c5117..1031932116c1e 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -253,25 +253,17 @@ class IR2VecTool { break; } case BasicBlockLevel: { - const auto &BBVecMap = Emb->getBBVecMap(); for (const BasicBlock &BB : F) { - auto It = BBVecMap.find(&BB); - if (It != BBVecMap.end()) { - OS << BB.getName() << ":"; - It->second.print(OS); - } + OS << BB.getName() << ":"; + Emb->getBBVector(BB).print(OS); } break; } case InstructionLevel: { - const auto &InstMap = Emb->getInstVecMap(); for (const BasicBlock &BB : F) { for (const Instruction &I : BB) { - auto It = InstMap.find(&I); - if (It != InstMap.end()) { - I.print(OS); - It->second.print(OS); - } + I.print(OS); + Emb->getInstVector(I).print(OS); } } break; diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 40b4aa21f2b46..8ffc5f61d5e55 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -30,7 +30,9 @@ namespace { class TestableEmbedder : public Embedder { public: TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {} - void computeEmbeddings(const BasicBlock &BB) const override {} + Embedding computeEmbeddings(const Instruction &I) const override { + return Embedding(); + } }; TEST(EmbeddingTest, ConstructorsAndAccessors) { @@ -321,18 +323,12 @@ class IR2VecTestFixture : public ::testing::Test { } }; -TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) { +TEST_F(IR2VecTestFixture, GetInstVec_Symbolic) { auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V); ASSERT_TRUE(static_cast(Emb)); - const auto &InstMap = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap.size(), 2u); - EXPECT_TRUE(InstMap.count(AddInst)); - EXPECT_TRUE(InstMap.count(RetInst)); - - const auto &AddEmb = InstMap.at(AddInst); - const auto &RetEmb = InstMap.at(RetInst); + const auto &AddEmb = Emb->getInstVector(*AddInst); + const auto &RetEmb = Emb->getInstVector(*RetInst); EXPECT_EQ(AddEmb.size(), 2u); EXPECT_EQ(RetEmb.size(), 2u); @@ -340,51 +336,17 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) { EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 15.5))); } -TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) { - auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V); - ASSERT_TRUE(static_cast(Emb)); - - const auto &InstMap = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap.size(), 2u); - EXPECT_TRUE(InstMap.count(AddInst)); - EXPECT_TRUE(InstMap.count(RetInst)); - - EXPECT_EQ(InstMap.at(AddInst).size(), 2u); - EXPECT_EQ(InstMap.at(RetInst).size(), 2u); - - EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 25.5))); - EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 32.6))); -} - -TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) { - auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V); - ASSERT_TRUE(static_cast(Emb)); - - const auto &BBMap = Emb->getBBVecMap(); - - EXPECT_EQ(BBMap.size(), 1u); - EXPECT_TRUE(BBMap.count(BB)); - EXPECT_EQ(BBMap.at(BB).size(), 2u); - - // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} = - // {41.0, 41.0} - EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 41.0))); -} - -TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) { +TEST_F(IR2VecTestFixture, GetInstVec_FlowAware) { auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V); ASSERT_TRUE(static_cast(Emb)); - const auto &BBMap = Emb->getBBVecMap(); - - EXPECT_EQ(BBMap.size(), 1u); - EXPECT_TRUE(BBMap.count(BB)); - EXPECT_EQ(BBMap.at(BB).size(), 2u); + const auto &AddEmb = Emb->getInstVector(*AddInst); + const auto &RetEmb = Emb->getInstVector(*RetInst); + EXPECT_EQ(AddEmb.size(), 2u); + EXPECT_EQ(RetEmb.size(), 2u); - // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} = - // {58.1, 58.1} - EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 58.1))); + EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 25.5))); + EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 32.6))); } TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) { @@ -394,6 +356,8 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) { const auto &BBVec = Emb->getBBVector(*BB); EXPECT_EQ(BBVec.size(), 2u); + // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} = + // {41.0, 41.0} EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 41.0))); } @@ -404,6 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) { const auto &BBVec = Emb->getBBVector(*BB); EXPECT_EQ(BBVec.size(), 2u); + // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} = + // {58.1, 58.1} EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 58.1))); } @@ -446,15 +412,9 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) { EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3)); EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3)); - // Also check that instruction vectors remain consistent - const auto &InstMap1 = Emb->getInstVecMap(); - const auto &InstMap2 = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap1.size(), InstMap2.size()); - for (const auto &[Inst, Vec1] : InstMap1) { - ASSERT_TRUE(InstMap2.count(Inst)); - EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst))); - } + Emb->invalidateEmbeddings(); + const auto &FuncVec4 = Emb->getFunctionVector(); + EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4)); } TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) { @@ -473,15 +433,9 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) { EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3)); EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3)); - // Also check that instruction vectors remain consistent - const auto &InstMap1 = Emb->getInstVecMap(); - const auto &InstMap2 = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap1.size(), InstMap2.size()); - for (const auto &[Inst, Vec1] : InstMap1) { - ASSERT_TRUE(InstMap2.count(Inst)); - EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst))); - } + Emb->invalidateEmbeddings(); + const auto &FuncVec4 = Emb->getFunctionVector(); + EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4)); } static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;