diff --git a/llvm/include/llvm/IR/Module.h b/llvm/include/llvm/IR/Module.h index a99937a90cbb7..ac6c20b81d68c 100644 --- a/llvm/include/llvm/IR/Module.h +++ b/llvm/include/llvm/IR/Module.h @@ -710,6 +710,17 @@ class LLVM_ABI Module { return make_range(begin(), end()); } + /// Get an iterator range over all function definitions (excluding + /// declarations). + auto getFunctionDefs() { + return make_filter_range(functions(), + [](Function &F) { return !F.isDeclaration(); }); + } + auto getFunctionDefs() const { + return make_filter_range( + functions(), [](const Function &F) { return !F.isDeclaration(); }); + } + /// @} /// @name Alias Iteration /// @{ diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index 7402782bfd404..7b8d3f093a3d1 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -395,9 +395,7 @@ class MIR2VecTool { /// FIXME: Use --target option to get target info directly, avoiding the need /// to parse machine functions for pre-training operations. bool initializeVocabularyForLayout(const Module &M) { - for (const Function &F : M) { - if (F.isDeclaration()) - continue; + for (const Function &F : M.getFunctionDefs()) { MachineFunction *MF = MMI.getMachineFunction(F); if (!MF) @@ -431,9 +429,7 @@ class MIR2VecTool { std::string Relationships; raw_string_ostream RelOS(Relationships); - for (const Function &F : M) { - if (F.isDeclaration()) - continue; + for (const Function &F : M.getFunctionDefs()) { MachineFunction *MF = MMI.getMachineFunction(F); if (!MF) { @@ -532,9 +528,7 @@ class MIR2VecTool { return; } - for (const Function &F : M) { - if (F.isDeclaration()) - continue; + for (const Function &F : M.getFunctionDefs()) { MachineFunction *MF = MMI.getMachineFunction(F); if (!MF) { diff --git a/llvm/unittests/IR/ModuleTest.cpp b/llvm/unittests/IR/ModuleTest.cpp index 30eda738020d0..1e4565b219386 100644 --- a/llvm/unittests/IR/ModuleTest.cpp +++ b/llvm/unittests/IR/ModuleTest.cpp @@ -433,4 +433,82 @@ define void @Foo2() { ASSERT_EQ(M2Str, M1Print); } +TEST(ModuleTest, FunctionDefinitions) { + // Test getFunctionDefs() method which returns only functions with bodies + LLVMContext Context; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString(R"( +declare void @Decl1() +declare void @Decl2() + +define void @Def1() { + ret void +} + +define void @Def2() { + ret void +} + +declare void @Decl3() + +define void @Def3() { + ret void +} +)", + Err, Context); + ASSERT_TRUE(M); + + // Count total functions (should be 6: 3 declarations + 3 definitions) + size_t TotalFunctions = 0; + for (Function &F : *M) { + (void)F; + ++TotalFunctions; + } + EXPECT_EQ(TotalFunctions, 6u); + + // Count function definitions only (should be 3) + size_t DefinitionCount = 0; + for (Function &F : M->getFunctionDefs()) { + EXPECT_FALSE(F.isDeclaration()); + ++DefinitionCount; + } + EXPECT_EQ(DefinitionCount, 3u); + + // Verify the names of the definitions + auto DefRange = M->getFunctionDefs(); + auto It = DefRange.begin(); + EXPECT_EQ(It->getName(), "Def1"); + ++It; + EXPECT_EQ(It->getName(), "Def2"); + ++It; + EXPECT_EQ(It->getName(), "Def3"); + ++It; + EXPECT_EQ(It, DefRange.end()); +} + +TEST(ModuleTest, FunctionDefinitionsEmpty) { + // Test getFunctionDefs() with no definitions (only declarations) + LLVMContext Context; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString(R"( +declare void @Decl1() +declare void @Decl2() +declare void @Decl3() +)", + Err, Context); + ASSERT_TRUE(M); + + // Should have functions + EXPECT_FALSE(M->empty()); + EXPECT_EQ(M->size(), 3u); + + // But no definitions + size_t DefinitionCount = 0; + for (Function &F : M->getFunctionDefs()) { + (void)F; + ++DefinitionCount; + } + EXPECT_EQ(DefinitionCount, 0u); +} + } // end namespace