diff --git a/llvm/include/llvm/IR/Module.h b/llvm/include/llvm/IR/Module.h index 3895e99c10d232..3052651a372267 100644 --- a/llvm/include/llvm/IR/Module.h +++ b/llvm/include/llvm/IR/Module.h @@ -156,6 +156,11 @@ class Module { /// converted result in MFB. static bool isValidModFlagBehavior(Metadata *MD, ModFlagBehavior &MFB); + /// Check if the given module flag metadata represents a valid module flag, + /// and store the flag behavior, the key string and the value metadata. + static bool isValidModuleFlag(const MDNode &ModFlag, ModFlagBehavior &MFB, + MDString *&Key, Metadata *&Val); + struct ModuleFlagEntry { ModFlagBehavior Behavior; MDString *Key; @@ -493,10 +498,12 @@ class Module { void addModuleFlag(ModFlagBehavior Behavior, StringRef Key, Constant *Val); void addModuleFlag(ModFlagBehavior Behavior, StringRef Key, uint32_t Val); void addModuleFlag(MDNode *Node); + /// Like addModuleFlag but replaces the old module flag if it already exists. + void setModuleFlag(ModFlagBehavior Behavior, StringRef Key, Metadata *Val); -/// @} -/// @name Materialization -/// @{ + /// @} + /// @name Materialization + /// @{ /// Sets the GVMaterializer to GVM. This module must not yet have a /// Materializer. To reset the materializer for a module that already has one, diff --git a/llvm/lib/IR/Module.cpp b/llvm/lib/IR/Module.cpp index f1acf4653de66e..9ac1edb2519d37 100644 --- a/llvm/lib/IR/Module.cpp +++ b/llvm/lib/IR/Module.cpp @@ -283,6 +283,20 @@ bool Module::isValidModFlagBehavior(Metadata *MD, ModFlagBehavior &MFB) { return false; } +bool Module::isValidModuleFlag(const MDNode &ModFlag, ModFlagBehavior &MFB, + MDString *&Key, Metadata *&Val) { + if (ModFlag.getNumOperands() < 3) + return false; + if (!isValidModFlagBehavior(ModFlag.getOperand(0), MFB)) + return false; + MDString *K = dyn_cast_or_null(ModFlag.getOperand(1)); + if (!K) + return false; + Key = K; + Val = ModFlag.getOperand(2); + return true; +} + /// getModuleFlagsMetadata - Returns the module flags in the provided vector. void Module:: getModuleFlagsMetadata(SmallVectorImpl &Flags) const { @@ -291,13 +305,11 @@ getModuleFlagsMetadata(SmallVectorImpl &Flags) const { for (const MDNode *Flag : ModFlags->operands()) { ModFlagBehavior MFB; - if (Flag->getNumOperands() >= 3 && - isValidModFlagBehavior(Flag->getOperand(0), MFB) && - dyn_cast_or_null(Flag->getOperand(1))) { + MDString *Key = nullptr; + Metadata *Val = nullptr; + if (isValidModuleFlag(*Flag, MFB, Key, Val)) { // Check the operands of the MDNode before accessing the operands. // The verifier will actually catch these failures. - MDString *Key = cast(Flag->getOperand(1)); - Metadata *Val = Flag->getOperand(2); Flags.push_back(ModuleFlagEntry(MFB, Key, Val)); } } @@ -358,6 +370,23 @@ void Module::addModuleFlag(MDNode *Node) { getOrInsertModuleFlagsMetadata()->addOperand(Node); } +void Module::setModuleFlag(ModFlagBehavior Behavior, StringRef Key, + Metadata *Val) { + NamedMDNode *ModFlags = getOrInsertModuleFlagsMetadata(); + // Replace the flag if it already exists. + for (unsigned I = 0, E = ModFlags->getNumOperands(); I != E; ++I) { + MDNode *Flag = ModFlags->getOperand(I); + ModFlagBehavior MFB; + MDString *K = nullptr; + Metadata *V = nullptr; + if (isValidModuleFlag(*Flag, MFB, K, V) && K->getString() == Key) { + Flag->replaceOperandWith(2, Val); + return; + } + } + addModuleFlag(Behavior, Key, Val); +} + void Module::setDataLayout(StringRef Desc) { DL.reset(Desc); } @@ -547,9 +576,9 @@ void Module::setCodeModel(CodeModel::Model CL) { void Module::setProfileSummary(Metadata *M, ProfileSummary::Kind Kind) { if (Kind == ProfileSummary::PSK_CSInstr) - addModuleFlag(ModFlagBehavior::Error, "CSProfileSummary", M); + setModuleFlag(ModFlagBehavior::Error, "CSProfileSummary", M); else - addModuleFlag(ModFlagBehavior::Error, "ProfileSummary", M); + setModuleFlag(ModFlagBehavior::Error, "ProfileSummary", M); } Metadata *Module::getProfileSummary(bool IsCS) { diff --git a/llvm/unittests/IR/ModuleTest.cpp b/llvm/unittests/IR/ModuleTest.cpp index f642b002a5eb1e..7b34d5d0ee5546 100644 --- a/llvm/unittests/IR/ModuleTest.cpp +++ b/llvm/unittests/IR/ModuleTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/IR/Module.h" +#include "llvm/AsmParser/Parser.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/Pass.h" #include "llvm/Support/RandomNumberGenerator.h" @@ -72,4 +73,52 @@ TEST(ModuleTest, randomNumberGenerator) { RandomStreams[1].begin())); } +TEST(ModuleTest, setModuleFlag) { + LLVMContext Context; + Module M("M", Context); + StringRef Key = "Key"; + Metadata *Val1 = MDString::get(Context, "Val1"); + Metadata *Val2 = MDString::get(Context, "Val2"); + EXPECT_EQ(nullptr, M.getModuleFlag(Key)); + M.setModuleFlag(Module::ModFlagBehavior::Error, Key, Val1); + EXPECT_EQ(Val1, M.getModuleFlag(Key)); + M.setModuleFlag(Module::ModFlagBehavior::Error, Key, Val2); + EXPECT_EQ(Val2, M.getModuleFlag(Key)); +} + +const char *IRString = R"IR( + !llvm.module.flags = !{!0} + + !0 = !{i32 1, !"ProfileSummary", !1} + !1 = !{!2, !3, !4, !5, !6, !7, !8, !9} + !2 = !{!"ProfileFormat", !"SampleProfile"} + !3 = !{!"TotalCount", i64 10000} + !4 = !{!"MaxCount", i64 10} + !5 = !{!"MaxInternalCount", i64 1} + !6 = !{!"MaxFunctionCount", i64 1000} + !7 = !{!"NumCounts", i64 200} + !8 = !{!"NumFunctions", i64 3} + !9 = !{!"DetailedSummary", !10} + !10 = !{!11, !12, !13} + !11 = !{i32 10000, i64 1000, i32 1} + !12 = !{i32 990000, i64 300, i32 10} + !13 = !{i32 999999, i64 5, i32 100} +)IR"; + +TEST(ModuleTest, setProfileSummary) { + SMDiagnostic Err; + LLVMContext Context; + std::unique_ptr M = parseAssemblyString(IRString, Err, Context); + auto *PS = ProfileSummary::getFromMD(M->getProfileSummary(/*IsCS*/ false)); + EXPECT_NE(nullptr, PS); + EXPECT_EQ(false, PS->isPartialProfile()); + PS->setPartialProfile(true); + M->setProfileSummary(PS->getMD(Context), ProfileSummary::PSK_Sample); + delete PS; + PS = ProfileSummary::getFromMD(M->getProfileSummary(/*IsCS*/ false)); + EXPECT_NE(nullptr, PS); + EXPECT_EQ(true, PS->isPartialProfile()); + delete PS; +} + } // end namespace