From a620697340671aea2b0c65449fcddf3c2e4d1917 Mon Sep 17 00:00:00 2001 From: Arthur Eubanks Date: Wed, 8 May 2024 10:14:51 -0700 Subject: [PATCH] [IR] Check callee param attributes as well in CallBase::getParamAttr() (#91394) These methods aren't used yet, but may be in the future. This keeps them in line with other methods like getFnAttr(). --- llvm/include/llvm/IR/InstrTypes.h | 12 ++++++-- llvm/lib/IR/Instructions.cpp | 16 ++++++++++ llvm/unittests/IR/AttributesTest.cpp | 46 ++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h index eaade9ce4755fc..9dd1bb455a718e 100644 --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -1997,13 +1997,19 @@ class CallBase : public Instruction { /// Get the attribute of a given kind from a given arg Attribute getParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const { assert(ArgNo < arg_size() && "Out of bounds"); - return getAttributes().getParamAttr(ArgNo, Kind); + Attribute A = getAttributes().getParamAttr(ArgNo, Kind); + if (A.isValid()) + return A; + return getParamAttrOnCalledFunction(ArgNo, Kind); } /// Get the attribute of a given kind from a given arg Attribute getParamAttr(unsigned ArgNo, StringRef Kind) const { assert(ArgNo < arg_size() && "Out of bounds"); - return getAttributes().getParamAttr(ArgNo, Kind); + Attribute A = getAttributes().getParamAttr(ArgNo, Kind); + if (A.isValid()) + return A; + return getParamAttrOnCalledFunction(ArgNo, Kind); } /// Return true if the data operand at index \p i has the attribute \p @@ -2652,6 +2658,8 @@ class CallBase : public Instruction { return hasFnAttrOnCalledFunction(Kind); } template Attribute getFnAttrOnCalledFunction(AK Kind) const; + template + Attribute getParamAttrOnCalledFunction(unsigned ArgNo, AK Kind) const; /// Determine whether the return value has the given attribute. Supports /// Attribute::AttrKind and StringRef as \p AttrKind types. diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 7ad1ad4cddb703..32af58a43b68e4 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -500,6 +500,22 @@ template Attribute CallBase::getFnAttrOnCalledFunction(Attribute::AttrKind Kind) const; template Attribute CallBase::getFnAttrOnCalledFunction(StringRef Kind) const; +template +Attribute CallBase::getParamAttrOnCalledFunction(unsigned ArgNo, + AK Kind) const { + Value *V = getCalledOperand(); + + if (auto *F = dyn_cast(V)) + return F->getAttributes().getParamAttr(ArgNo, Kind); + + return Attribute(); +} +template Attribute +CallBase::getParamAttrOnCalledFunction(unsigned ArgNo, + Attribute::AttrKind Kind) const; +template Attribute CallBase::getParamAttrOnCalledFunction(unsigned ArgNo, + StringRef Kind) const; + void CallBase::getOperandBundlesAsDefs( SmallVectorImpl &Defs) const { for (unsigned i = 0, e = getNumOperandBundles(); i != e; ++i) diff --git a/llvm/unittests/IR/AttributesTest.cpp b/llvm/unittests/IR/AttributesTest.cpp index a7967593c2f960..da72fa14510cbe 100644 --- a/llvm/unittests/IR/AttributesTest.cpp +++ b/llvm/unittests/IR/AttributesTest.cpp @@ -340,4 +340,50 @@ TEST(Attributes, ConstantRangeAttributeCAPI) { } } +TEST(Attributes, CalleeAttributes) { + const char *IRString = R"IR( + declare void @f1(i32 %i) + declare void @f2(i32 range(i32 1, 2) %i) + + define void @g1(i32 %i) { + call void @f1(i32 %i) + ret void + } + define void @g2(i32 %i) { + call void @f2(i32 %i) + ret void + } + define void @g3(i32 %i) { + call void @f1(i32 range(i32 3, 4) %i) + ret void + } + define void @g4(i32 %i) { + call void @f2(i32 range(i32 3, 4) %i) + ret void + } + )IR"; + + SMDiagnostic Err; + LLVMContext Context; + std::unique_ptr M = parseAssemblyString(IRString, Err, Context); + ASSERT_TRUE(M); + + { + auto *I = cast(&M->getFunction("g1")->getEntryBlock().front()); + ASSERT_FALSE(I->getParamAttr(0, Attribute::Range).isValid()); + } + { + auto *I = cast(&M->getFunction("g2")->getEntryBlock().front()); + ASSERT_TRUE(I->getParamAttr(0, Attribute::Range).isValid()); + } + { + auto *I = cast(&M->getFunction("g3")->getEntryBlock().front()); + ASSERT_TRUE(I->getParamAttr(0, Attribute::Range).isValid()); + } + { + auto *I = cast(&M->getFunction("g4")->getEntryBlock().front()); + ASSERT_TRUE(I->getParamAttr(0, Attribute::Range).isValid()); + } +} + } // end anonymous namespace