10 changes: 10 additions & 0 deletions llvm/lib/IR/IntrinsicInst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const {
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(3)));
}

/// Set the counter-index operand (operand 3) of this counter intrinsic to
/// \p Idx, replacing it with a fresh i32 constant.
void InstrProfCntrInstBase::setIndex(uint32_t Idx) {
  // classof-based check: guards against calling this through a pointer that
  // was miscast to InstrProfCntrInstBase.
  assert(isa<InstrProfCntrInstBase>(this));
  setArgOperand(3, ConstantInt::get(Type::getInt32Ty(getContext()), Idx));
}

Value *InstrProfIncrementInst::getStep() const {
if (InstrProfIncrementInstStep::classof(this)) {
return const_cast<Value *>(getArgOperand(4));
Expand All @@ -300,6 +305,11 @@ Value *InstrProfCallsite::getCallee() const {
return nullptr;
}

/// Set the callee operand (operand 4) of this callsite instrumentation
/// intrinsic to \p Callee.
void InstrProfCallsite::setCallee(Value *Callee) {
  // classof-based check: guards against calling this through a pointer that
  // was miscast to InstrProfCallsite.
  assert(isa<InstrProfCallsite>(this));
  setArgOperand(4, Callee);
}

std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
unsigned NumOperands = arg_size();
Metadata *MD = nullptr;
Expand Down
82 changes: 82 additions & 0 deletions llvm/lib/ProfileData/PGOCtxProfWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

#include "llvm/ProfileData/PGOCtxProfWriter.h"
#include "llvm/Bitstream/BitCodeEnums.h"
#include "llvm/ProfileData/CtxInstrContextNode.h"
#include "llvm/Support/JSON.h"

using namespace llvm;
using namespace llvm::ctx_profile;
Expand Down Expand Up @@ -81,3 +83,83 @@ void PGOCtxProfileWriter::writeImpl(std::optional<uint32_t> CallerIndex,
void PGOCtxProfileWriter::write(const ContextNode &RootNode) {
writeImpl(std::nullopt, RootNode);
}

namespace {
// A structural representation of the JSON input.
struct DeserializableCtx {
  // GUID of the function this context belongs to.
  ctx_profile::GUID Guid = 0;
  // Counter values for this context.
  std::vector<uint64_t> Counters;
  // Callsites[I] is the list of callee contexts observed at callsite I.
  std::vector<std::vector<DeserializableCtx>> Callsites;
};

// Forward declaration: the two createNode overloads are mutually recursive.
ctx_profile::ContextNode *
createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
           const std::vector<DeserializableCtx> &DCList);

// Convert a DeserializableCtx into a ContextNode, potentially linking it to
// its sibling (e.g. callee at same callsite) "Next".
ctx_profile::ContextNode *
createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
           const DeserializableCtx &DC,
           ctx_profile::ContextNode *Next = nullptr) {
  // ContextNode is a variable-size object; allocate raw storage (owned by
  // Nodes) sized for this node's counters and callsites, then placement-new
  // the node into it.
  auto AllocSize = ctx_profile::ContextNode::getAllocSize(DC.Counters.size(),
                                                          DC.Callsites.size());
  auto *Mem = Nodes.emplace_back(std::make_unique<char[]>(AllocSize)).get();
  std::memset(Mem, 0, AllocSize);
  auto *Ret = new (Mem) ctx_profile::ContextNode(DC.Guid, DC.Counters.size(),
                                                 DC.Callsites.size(), Next);
  std::memcpy(Ret->counters(), DC.Counters.data(),
              sizeof(uint64_t) * DC.Counters.size());
  // Recursively materialize the callee list of each callsite.
  for (const auto &[I, DCList] : llvm::enumerate(DC.Callsites))
    Ret->subContexts()[I] = createNode(Nodes, DCList);
  return Ret;
}

// Convert a list of DeserializableCtx into a linked list of ContextNodes.
ctx_profile::ContextNode *
createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
           const std::vector<DeserializableCtx> &DCList) {
  ctx_profile::ContextNode *List = nullptr;
  for (const auto &DC : DCList)
    List = createNode(Nodes, DC, List);
  return List;
}
} // namespace

namespace llvm {
namespace json {
// Hook into the JSON deserialization: map a json::Value onto a
// DeserializableCtx. "Guid" and "Counters" are required; "Callsites" may be
// absent (leaf contexts).
bool fromJSON(const Value &E, DeserializableCtx &R, Path P) {
  json::ObjectMapper Mapper(E, P);
  return Mapper && Mapper.map("Guid", R.Guid) &&
         Mapper.map("Counters", R.Counters) &&
         Mapper.mapOptional("Callsites", R.Callsites);
}
} // namespace json
} // namespace llvm

// Deserialize a JSON contextual profile (schema given by fromJSON above) and
// write it, in bitstream format, to Out. Returns an error on malformed JSON
// or on failure to convert the parsed structure.
Error llvm::createCtxProfFromJSON(StringRef Profile, raw_ostream &Out) {
  auto P = json::parse(Profile);
  if (!P)
    return P.takeError();

  json::Path::Root R("");
  std::vector<DeserializableCtx> DCList;
  if (!fromJSON(*P, DCList, R))
    return R.getError();
  // Nodes provides memory backing for the ContextNodes.
  std::vector<std::unique_ptr<char[]>> Nodes;
  PGOCtxProfileWriter Writer(Out);
  for (const auto &DC : DCList) {
    auto *TopList = createNode(Nodes, DC);
    if (!TopList)
      return createStringError(
          "Unexpected error converting internal structure to ctx profile");
    Writer.write(*TopList);
  }
  return Error::success();
}
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Coroutines/CoroEarly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ void Lowerer::lowerEarlyIntrinsics(Function &F) {
if (auto *CII = cast<CoroIdInst>(&I)) {
if (CII->getInfo().isPreSplit()) {
assert(F.isPresplitCoroutine() &&
"The frontend uses Swtich-Resumed ABI should emit "
"The frontend uses Switch-Resumed ABI should emit "
"\"presplitcoroutine\" attribute for the coroutine.");
setCannotDuplicate(CII);
CII->setCoroutineSelf();
Expand Down
163 changes: 98 additions & 65 deletions llvm/lib/Transforms/Coroutines/CoroFrame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
// This file contains classes used to discover if for a particular value
// there from sue to definition that crosses a suspend block.
// its definition precedes and its uses follow a suspend block. This is
// referred to as a suspend crossing value.
//
// Using the information discovered we form a Coroutine Frame structure to
// contain those values. All uses of those values are replaced with appropriate
Expand Down Expand Up @@ -124,7 +125,8 @@ class SuspendCrossingInfo {
public:
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void dump() const;
void dump(StringRef Label, BitVector const &BV) const;
void dump(StringRef Label, BitVector const &BV,
const ReversePostOrderTraversal<Function *> &RPOT) const;
#endif

SuspendCrossingInfo(Function &F, coro::Shape &Shape);
Expand Down Expand Up @@ -207,21 +209,41 @@ class SuspendCrossingInfo {
} // end anonymous namespace

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(StringRef Label,
BitVector const &BV) const {
static std::string getBasicBlockLabel(const BasicBlock *BB) {
if (BB->hasName())
return BB->getName().str();

std::string S;
raw_string_ostream OS(S);
BB->printAsOperand(OS, false);
return OS.str().substr(1);
}

// Print Label followed by the label of every block whose bit is set in BV.
// Blocks are emitted in reverse post-order (not bitvector-index order) so the
// output is stable and follows the CFG's natural reading order.
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(
    StringRef Label, BitVector const &BV,
    const ReversePostOrderTraversal<Function *> &RPOT) const {
  dbgs() << Label << ":";
  for (const BasicBlock *BB : RPOT) {
    auto BBNo = Mapping.blockToIndex(BB);
    if (BV[BBNo])
      dbgs() << " " << getBasicBlockLabel(BB);
  }
  dbgs() << "\n";
}

// Dump the Consumes and Kills sets of every analyzed block, visiting blocks
// in reverse post-order for readable, deterministic output.
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
  if (Block.empty())
    return;

  // All mapped blocks belong to the same function; recover it from block 0 to
  // drive the RPO traversal.
  BasicBlock *const B = Mapping.indexToBlock(0);
  Function *F = B->getParent();

  ReversePostOrderTraversal<Function *> RPOT(F);
  for (const BasicBlock *BB : RPOT) {
    auto BBNo = Mapping.blockToIndex(BB);
    dbgs() << getBasicBlockLabel(BB) << ":\n";
    dump(" Consumes", Block[BBNo].Consumes, RPOT);
    dump(" Kills", Block[BBNo].Kills, RPOT);
  }
  dbgs() << "\n";
}
Expand Down Expand Up @@ -418,10 +440,7 @@ struct RematGraph {
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void dump() const {
dbgs() << "Entry (";
if (EntryNode->Node->getParent()->hasName())
dbgs() << EntryNode->Node->getParent()->getName();
else
EntryNode->Node->getParent()->printAsOperand(dbgs(), false);
dbgs() << getBasicBlockLabel(EntryNode->Node->getParent());
dbgs() << ") : " << *EntryNode->Node << "\n";
for (auto &E : Remats) {
dbgs() << *(E.first) << "\n";
Expand Down Expand Up @@ -551,7 +570,7 @@ struct FrameDataInfo {

#ifndef NDEBUG
static void dumpSpills(StringRef Title, const SpillInfo &Spills) {
dbgs() << "------------- " << Title << "--------------\n";
dbgs() << "------------- " << Title << " --------------\n";
for (const auto &E : Spills) {
E.first->dump();
dbgs() << " user: ";
Expand Down Expand Up @@ -813,7 +832,7 @@ void FrameTypeBuilder::addFieldForAllocas(const Function &F,
StackLifetime StackLifetimeAnalyzer(F, ExtractAllocas(),
StackLifetime::LivenessType::May);
StackLifetimeAnalyzer.run();
auto IsAllocaInferenre = [&](const AllocaInst *AI1, const AllocaInst *AI2) {
auto DoAllocasInterfere = [&](const AllocaInst *AI1, const AllocaInst *AI2) {
return StackLifetimeAnalyzer.getLiveRange(AI1).overlaps(
StackLifetimeAnalyzer.getLiveRange(AI2));
};
Expand All @@ -833,13 +852,13 @@ void FrameTypeBuilder::addFieldForAllocas(const Function &F,
for (const auto &A : FrameData.Allocas) {
AllocaInst *Alloca = A.Alloca;
bool Merged = false;
// Try to find if the Alloca is not inferenced with any existing
// Try to find if the Alloca does not interfere with any existing
// NonOverlappedAllocaSet. If it is true, insert the alloca to that
// NonOverlappedAllocaSet.
for (auto &AllocaSet : NonOverlapedAllocas) {
assert(!AllocaSet.empty() && "Processing Alloca Set is not empty.\n");
bool NoInference = none_of(AllocaSet, [&](auto Iter) {
return IsAllocaInferenre(Alloca, Iter);
bool NoInterference = none_of(AllocaSet, [&](auto Iter) {
return DoAllocasInterfere(Alloca, Iter);
});
// If the alignment of A is multiple of the alignment of B, the address
// of A should satisfy the requirement for aligning for B.
Expand All @@ -852,7 +871,7 @@ void FrameTypeBuilder::addFieldForAllocas(const Function &F,
return LargestAlloca->getAlign().value() % Alloca->getAlign().value() ==
0;
}();
bool CouldMerge = NoInference && Alignable;
bool CouldMerge = NoInterference && Alignable;
if (!CouldMerge)
continue;
AllocaSet.push_back(Alloca);
Expand Down Expand Up @@ -1714,6 +1733,51 @@ static Instruction *splitBeforeCatchSwitch(CatchSwitchInst *CatchSwitch) {
return CleanupRet;
}

// Compute where the store spilling \p Def into the coroutine frame should be
// inserted. May split an edge or a block (for invoke results / catchswitch
// terminators) and, for spilled arguments, drops their 'nocapture' attribute.
static BasicBlock::iterator getSpillInsertionPt(const coro::Shape &Shape,
                                                Value *Def,
                                                const DominatorTree &DT) {
  if (auto *Arg = dyn_cast<Argument>(Def)) {
    // If we're spilling an Argument, make sure we clear 'nocapture' from the
    // coroutine function: the value now escapes into the frame.
    Arg->getParent()->removeParamAttr(Arg->getArgNo(), Attribute::NoCapture);
    // Arguments are stored right after the coroutine frame pointer
    // instruction, i.e. coro.begin.
    return Shape.getInsertPtAfterFramePtr();
  }

  if (auto *CSI = dyn_cast<AnyCoroSuspendInst>(Def))
    // Don't spill immediately after a suspend; splitting assumes that the
    // suspend will be followed by a branch.
    return CSI->getParent()->getSingleSuccessor()->getFirstNonPHIIt();

  auto *I = cast<Instruction>(Def);
  if (!DT.dominates(Shape.CoroBegin, I))
    // A def not dominated by CoroBegin is spilled as soon as the coroutine
    // frame has been computed.
    return Shape.getInsertPtAfterFramePtr();

  if (auto *II = dyn_cast<InvokeInst>(I)) {
    // When spilling the result of an invoke instruction, split the normal
    // edge and place the spill in the new block.
    auto *SpillBB = SplitEdge(II->getParent(), II->getNormalDest());
    return SpillBB->getTerminator()->getIterator();
  }

  if (isa<PHINode>(I)) {
    // Skip over the PHI nodes and EH pads of the defining block.
    BasicBlock *DefBlock = I->getParent();
    if (auto *CSI = dyn_cast<CatchSwitchInst>(DefBlock->getTerminator()))
      return splitBeforeCatchSwitch(CSI)->getIterator();
    return DefBlock->getFirstInsertionPt();
  }

  assert(!I->isTerminator() && "unexpected terminator");
  // Every other value is spilled immediately after its definition.
  return I->getNextNode()->getIterator();
}

// Replace all alloca and SSA values that are accessed across suspend points
// with GetElementPointer from coroutine frame + loads and stores. Create an
// AllocaSpillBB that will become the new entry block for the resume parts of
Expand All @@ -1736,9 +1800,8 @@ static Instruction *splitBeforeCatchSwitch(CatchSwitchInst *CatchSwitch) {
//
//
static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
auto *CB = Shape.CoroBegin;
LLVMContext &C = CB->getContext();
Function *F = CB->getFunction();
LLVMContext &C = Shape.CoroBegin->getContext();
Function *F = Shape.CoroBegin->getFunction();
IRBuilder<> Builder(C);
StructType *FrameTy = Shape.FrameTy;
Value *FramePtr = Shape.FramePtr;
Expand Down Expand Up @@ -1799,47 +1862,16 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
auto SpillAlignment = Align(FrameData.getAlign(Def));
// Create a store instruction storing the value into the
// coroutine frame.
BasicBlock::iterator InsertPt;
BasicBlock::iterator InsertPt = getSpillInsertionPt(Shape, Def, DT);

Type *ByValTy = nullptr;
if (auto *Arg = dyn_cast<Argument>(Def)) {
// For arguments, we will place the store instruction right after
// the coroutine frame pointer instruction, i.e. coro.begin.
InsertPt = Shape.getInsertPtAfterFramePtr();

// If we're spilling an Argument, make sure we clear 'nocapture'
// from the coroutine function.
Arg->getParent()->removeParamAttr(Arg->getArgNo(), Attribute::NoCapture);

if (Arg->hasByValAttr())
ByValTy = Arg->getParamByValType();
} else if (auto *CSI = dyn_cast<AnyCoroSuspendInst>(Def)) {
// Don't spill immediately after a suspend; splitting assumes
// that the suspend will be followed by a branch.
InsertPt = CSI->getParent()->getSingleSuccessor()->getFirstNonPHIIt();
} else {
auto *I = cast<Instruction>(Def);
if (!DT.dominates(CB, I)) {
// If it is not dominated by CoroBegin, then spill should be
// inserted immediately after CoroFrame is computed.
InsertPt = Shape.getInsertPtAfterFramePtr();
} else if (auto *II = dyn_cast<InvokeInst>(I)) {
// If we are spilling the result of the invoke instruction, split
// the normal edge and insert the spill in the new block.
auto *NewBB = SplitEdge(II->getParent(), II->getNormalDest());
InsertPt = NewBB->getTerminator()->getIterator();
} else if (isa<PHINode>(I)) {
// Skip the PHINodes and EH pads instructions.
BasicBlock *DefBlock = I->getParent();
if (auto *CSI = dyn_cast<CatchSwitchInst>(DefBlock->getTerminator()))
InsertPt = splitBeforeCatchSwitch(CSI)->getIterator();
else
InsertPt = DefBlock->getFirstInsertionPt();
} else {
assert(!I->isTerminator() && "unexpected terminator");
// For all other values, the spill is placed immediately after
// the definition.
InsertPt = I->getNextNode()->getIterator();
}
}

auto Index = FrameData.getFieldIndex(Def);
Expand Down Expand Up @@ -1982,7 +2014,7 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
UsersToUpdate.clear();
for (User *U : Alloca->users()) {
auto *I = cast<Instruction>(U);
if (DT.dominates(CB, I))
if (DT.dominates(Shape.CoroBegin, I))
UsersToUpdate.push_back(I);
}
if (UsersToUpdate.empty())
Expand Down Expand Up @@ -2024,7 +2056,7 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
Builder.CreateStore(Value, G);
}
// For each alias to Alloca created before CoroBegin but used after
// CoroBegin, we recreate them after CoroBegin by appplying the offset
// CoroBegin, we recreate them after CoroBegin by applying the offset
// to the pointer in the frame.
for (const auto &Alias : A.Aliases) {
auto *FramePtr = GetFramePointer(Alloca);
Expand All @@ -2033,7 +2065,7 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
auto *AliasPtr =
Builder.CreatePtrAdd(FramePtr, ConstantInt::get(ITy, Value));
Alias.first->replaceUsesWithIf(
AliasPtr, [&](Use &U) { return DT.dominates(CB, U); });
AliasPtr, [&](Use &U) { return DT.dominates(Shape.CoroBegin, U); });
}
}

Expand All @@ -2046,7 +2078,7 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
// If there is memory accessing to promise alloca before CoroBegin;
bool HasAccessingPromiseBeforeCB = llvm::any_of(PA->uses(), [&](Use &U) {
auto *Inst = dyn_cast<Instruction>(U.getUser());
if (!Inst || DT.dominates(CB, Inst))
if (!Inst || DT.dominates(Shape.CoroBegin, Inst))
return false;

if (auto *CI = dyn_cast<CallInst>(Inst)) {
Expand Down Expand Up @@ -2692,7 +2724,7 @@ static void eliminateSwiftError(Function &F, coro::Shape &Shape) {
}
}

/// retcon and retcon.once conventions assume that all spill uses can be sunk
/// Async and Retcon{Once} conventions assume that all spill uses can be sunk
/// after the coro.begin intrinsic.
static void sinkSpillUsesAfterCoroBegin(Function &F,
const FrameDataInfo &FrameData,
Expand Down Expand Up @@ -2728,7 +2760,7 @@ static void sinkSpillUsesAfterCoroBegin(Function &F,
// Sort by dominance.
SmallVector<Instruction *, 64> InsertionList(ToMove.begin(), ToMove.end());
llvm::sort(InsertionList, [&Dom](Instruction *A, Instruction *B) -> bool {
// If a dominates b it should preceed (<) b.
// If a dominates b it should precede (<) b.
return Dom.dominates(A, B);
});

Expand Down Expand Up @@ -3126,7 +3158,7 @@ void coro::buildCoroutineFrame(
cleanupSinglePredPHIs(F);

// Transforms multi-edge PHI Nodes, so that any value feeding into a PHI will
// never has its definition separated from the PHI by the suspend point.
// never have its definition separated from the PHI by the suspend point.
rewritePHIs(F);

// Build suspend crossing info.
Expand Down Expand Up @@ -3223,6 +3255,7 @@ void coro::buildCoroutineFrame(
Shape.FramePtr = Shape.CoroBegin;
// For now, this works for C++ programs only.
buildFrameDebugInfo(F, Shape, FrameData);
// Insert spills and reloads
insertSpills(FrameData, Shape);
lowerLocalAllocas(LocalAllocas, DeadInstructions);

Expand Down
9 changes: 4 additions & 5 deletions llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1221,11 +1221,10 @@ static void postSplitCleanup(Function &F) {
// frame if possible.
static void handleNoSuspendCoroutine(coro::Shape &Shape) {
auto *CoroBegin = Shape.CoroBegin;
auto *CoroId = CoroBegin->getId();
auto *AllocInst = CoroId->getCoroAlloc();
switch (Shape.ABI) {
case coro::ABI::Switch: {
auto SwitchId = cast<CoroIdInst>(CoroId);
auto SwitchId = Shape.getSwitchCoroId();
auto *AllocInst = SwitchId->getCoroAlloc();
coro::replaceCoroFree(SwitchId, /*Elide=*/AllocInst != nullptr);
if (AllocInst) {
IRBuilder<> Builder(AllocInst);
Expand Down Expand Up @@ -1689,7 +1688,7 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
auto &Context = F.getContext();
auto *Int8PtrTy = PointerType::getUnqual(Context);

auto *Id = cast<CoroIdAsyncInst>(Shape.CoroBegin->getId());
auto *Id = Shape.getAsyncCoroId();
IRBuilder<> Builder(Id);

auto *FramePtr = Id->getStorage();
Expand Down Expand Up @@ -1783,7 +1782,7 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape,
F.removeRetAttr(Attribute::NonNull);

// Allocate the frame.
auto *Id = cast<AnyCoroIdRetconInst>(Shape.CoroBegin->getId());
auto *Id = Shape.getRetconCoroId();
Value *RawFramePtr;
if (Shape.RetconLowering.IsFrameInlineInStorage) {
RawFramePtr = Id->getStorage();
Expand Down
86 changes: 85 additions & 1 deletion llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/ProfileData/PGOCtxProfReader.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
Expand Down Expand Up @@ -572,6 +574,88 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
return promoteCall(NewInst, Callee);
}

// Promote the indirect call \p CB — known from contextual profiling to
// frequently target \p Callee — into an if-then-else between a direct call
// and the original indirect call, updating the contextual profile \p CtxProf
// to reflect the transformation. Returns the new direct call, or nullptr if
// the callsite cannot be handled (unknown callee or missing callsite
// instrumentation).
CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
                                          PGOContextualProfile &CtxProf) {
  assert(CB.isIndirectCall());
  if (!CtxProf.isFunctionKnown(Callee))
    return nullptr;
  auto &Caller = *CB.getFunction();
  auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB);
  if (!CSInstr)
    return nullptr;
  const uint64_t CSIndex = CSInstr->getIndex()->getZExtValue();

  CallBase &DirectCall = promoteCall(
      versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee);
  // The fallback indirect callsite keeps the original instrumentation; the
  // new direct callsite gets a freshly-allocated callsite index.
  CSInstr->moveBefore(&CB);
  const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller);
  auto *NewCSInstr = cast<InstrProfCallsite>(CSInstr->clone());
  NewCSInstr->setIndex(NewCSID);
  NewCSInstr->setCallee(&Callee);
  NewCSInstr->insertBefore(&DirectCall);
  auto &DirectBB = *DirectCall.getParent();
  auto &IndirectBB = *CB.getParent();

  assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) &&
         "The ICP direct BB is new, it shouldn't have instrumentation");
  assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) &&
         "The ICP indirect BB is new, it shouldn't have instrumentation");

  // Allocate counters for the new basic blocks, and clone the entry block's
  // increment instruction as a template for their instrumentation.
  const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller);
  const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller);
  auto *EntryBBIns =
      CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock());
  auto *DirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
  DirectBBIns->setIndex(DirectID);
  DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt());

  auto *IndirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
  IndirectBBIns->setIndex(IndirectID);
  IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt());

  const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee);
  const uint32_t NewCountersSize = IndirectID + 1;

  auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) {
    assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
    assert(NewCountersSize - 2 == Ctx.counters().size());
    // All the ctx-es belonging to a function must have the same size counters.
    Ctx.resizeCounters(NewCountersSize);

    // Maybe in this context, the indirect callsite wasn't observed at all
    if (!Ctx.hasCallsite(CSIndex))
      return;
    auto &CSData = Ctx.callsite(CSIndex);
    auto It = CSData.find(CalleeGUID);

    // Maybe we did notice the indirect callsite, but to other targets.
    if (It == CSData.end())
      return;

    assert(CalleeGUID == It->second.guid());

    // Counter values are 64-bit; accumulate in 64 bits so hot counts aren't
    // truncated.
    const uint64_t DirectCount = It->second.getEntrycount();
    uint64_t TotalCount = 0;
    for (const auto &[_, V] : CSData)
      TotalCount += V.getEntrycount();
    assert(TotalCount >= DirectCount);
    const uint64_t IndirectCount = TotalCount - DirectCount;
    // The ICP's effect is as-if the direct BB would have been taken DirectCount
    // times, and the indirect BB, IndirectCount times
    Ctx.counters()[DirectID] = DirectCount;
    Ctx.counters()[IndirectID] = IndirectCount;

    // This particular indirect target needs to be moved to this caller under
    // the newly-allocated callsite index.
    assert(Ctx.callsites().count(NewCSID) == 0);
    Ctx.ingestContext(NewCSID, std::move(It->second));
    CSData.erase(CalleeGUID);
  };
  CtxProf.update(ProfileUpdater, &Caller);
  return &DirectCall;
}

CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
Function *Callee,
ArrayRef<Constant *> AddressPoints,
Expand Down
82 changes: 7 additions & 75 deletions llvm/tools/llvm-ctxprof-util/llvm-ctxprof-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,90 +46,22 @@ static cl::opt<std::string> OutputFilename("output", cl::value_desc("output"),
cl::desc("Output file"),
cl::sub(FromJSON));

namespace {
// A structural representation of the JSON input.
struct DeserializableCtx {
GlobalValue::GUID Guid = 0;
std::vector<uint64_t> Counters;
std::vector<std::vector<DeserializableCtx>> Callsites;
};

ctx_profile::ContextNode *
createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
const std::vector<DeserializableCtx> &DCList);

// Convert a DeserializableCtx into a ContextNode, potentially linking it to
// its sibling (e.g. callee at same callsite) "Next".
ctx_profile::ContextNode *
createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
const DeserializableCtx &DC,
ctx_profile::ContextNode *Next = nullptr) {
auto AllocSize = ctx_profile::ContextNode::getAllocSize(DC.Counters.size(),
DC.Callsites.size());
auto *Mem = Nodes.emplace_back(std::make_unique<char[]>(AllocSize)).get();
std::memset(Mem, 0, AllocSize);
auto *Ret = new (Mem) ctx_profile::ContextNode(DC.Guid, DC.Counters.size(),
DC.Callsites.size(), Next);
std::memcpy(Ret->counters(), DC.Counters.data(),
sizeof(uint64_t) * DC.Counters.size());
for (const auto &[I, DCList] : llvm::enumerate(DC.Callsites))
Ret->subContexts()[I] = createNode(Nodes, DCList);
return Ret;
}

// Convert a list of DeserializableCtx into a linked list of ContextNodes.
ctx_profile::ContextNode *
createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
const std::vector<DeserializableCtx> &DCList) {
ctx_profile::ContextNode *List = nullptr;
for (const auto &DC : DCList)
List = createNode(Nodes, DC, List);
return List;
}
} // namespace

namespace llvm {
namespace json {
// Hook into the JSON deserialization.
bool fromJSON(const Value &E, DeserializableCtx &R, Path P) {
json::ObjectMapper Mapper(E, P);
return Mapper && Mapper.map("Guid", R.Guid) &&
Mapper.map("Counters", R.Counters) &&
Mapper.mapOptional("Callsites", R.Callsites);
}
} // namespace json
} // namespace llvm

// Save the bitstream profile from the JSON representation: read the input
// file, open the output file, and delegate the actual conversion to
// llvm::createCtxProfFromJSON.
Error convertFromJSON() {
  auto BufOrError = MemoryBuffer::getFileOrSTDIN(InputFilename);
  if (!BufOrError)
    return createFileError(InputFilename, BufOrError.getError());

  std::error_code EC;
  // Using a fd_ostream instead of a fd_stream. The latter would be more
  // efficient as the bitstream writer supports incremental flush to it, but the
  // json scenario is for test, and file size scalability doesn't really concern
  // us.
  raw_fd_ostream Out(OutputFilename, EC);
  if (EC)
    return createStringError(EC, "failed to open output");

  return llvm::createCtxProfFromJSON(BufOrError.get()->getBuffer(), Out);
}

int main(int argc, const char **argv) {
Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/Transforms/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set(LLVM_LINK_COMPONENTS
AsmParser
BitWriter
Core
ProfileData
Support
TransformUtils
Passes
Expand Down
157 changes: 157 additions & 0 deletions llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,21 @@
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/ProfileData/PGOCtxProfReader.h"
#include "llvm/ProfileData/PGOCtxProfWriter.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Testing/Support/SupportHelpers.h"
#include "gtest/gtest.h"

using namespace llvm;
Expand Down Expand Up @@ -456,3 +463,153 @@ declare void @_ZN5Base35func3Ev(ptr)
// 1 call instruction from the entry block.
EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
}

// End-to-end check for promoteCallWithIfThenElse with a contextual profile:
// after promoting testfunc1's indirect call to @f2, the direct and fallback
// blocks must get fresh counters, and @f2's profile subtree must move under
// the newly-allocated direct callsite index — in every context in which
// testfunc1 appears (as root, and nested under testfunc2).
TEST(CallPromotionUtilsTest, PromoteWithIcmpAndCtxProf) {
  LLVMContext C;
  std::unique_ptr<Module> M = parseIR(C,
R"IR(
define i32 @testfunc1(ptr %d) !guid !0 {
call void @llvm.instrprof.increment(ptr @testfunc1, i64 0, i32 1, i32 0)
call void @llvm.instrprof.callsite(ptr @testfunc1, i64 0, i32 1, i32 0, ptr %d)
%call = call i32 %d()
ret i32 %call
}
define i32 @f1() !guid !1 {
call void @llvm.instrprof.increment(ptr @f1, i64 0, i32 1, i32 0)
ret i32 2
}
define i32 @f2() !guid !2 {
call void @llvm.instrprof.increment(ptr @f2, i64 0, i32 1, i32 0)
call void @llvm.instrprof.callsite(ptr @f2, i64 0, i32 1, i32 0, ptr @f4)
%r = call i32 @f4()
ret i32 %r
}
define i32 @testfunc2(ptr %p) !guid !4 {
call void @llvm.instrprof.increment(ptr @testfunc2, i64 0, i32 1, i32 0)
call void @llvm.instrprof.callsite(ptr @testfunc2, i64 0, i32 1, i32 0, ptr @testfunc1)
%r = call i32 @testfunc1(ptr %p)
ret i32 %r
}
declare i32 @f3()
define i32 @f4() !guid !3 {
ret i32 3
}
!0 = !{i64 1000}
!1 = !{i64 1001}
!2 = !{i64 1002}
!3 = !{i64 1004}
!4 = !{i64 1005}
)IR");

  // The input profile: testfunc1 (guid 1000) observed calling 1001, 1002
  // (which calls 1004), and 1003 at its one callsite; the same shape appears
  // again nested under testfunc2 (guid 1005).
  const char *Profile = R"json(
[
{
"Guid": 1000,
"Counters": [1],
"Callsites": [
[{ "Guid": 1001,
"Counters": [10]},
{ "Guid": 1002,
"Counters": [11],
"Callsites": [[{"Guid": 1004, "Counters":[13]}]]
},
{ "Guid": 1003,
"Counters": [12]
}]]
},
{
"Guid": 1005,
"Counters": [2],
"Callsites": [
[{ "Guid": 1000,
"Counters": [1],
"Callsites": [
[{ "Guid": 1001,
"Counters": [101]},
{ "Guid": 1002,
"Counters": [102],
"Callsites": [[{"Guid": 1004, "Counters":[104]}]]
},
{ "Guid": 1003,
"Counters": [103]
}]]}]]}]
)json";

  // Materialize the JSON profile as a bitstream file that CtxProfAnalysis
  // can load.
  llvm::unittest::TempFile ProfileFile("ctx_profile", "", "", /*Unique=*/true);
  {
    std::error_code EC;
    raw_fd_stream Out(ProfileFile.path(), EC);
    ASSERT_FALSE(EC);
    // "False" means no error.
    ASSERT_FALSE(llvm::createCtxProfFromJSON(Profile, Out));
  }

  ModuleAnalysisManager MAM;
  MAM.registerPass([&]() { return CtxProfAnalysis(ProfileFile.path()); });
  MAM.registerPass([&]() { return PassInstrumentationAnalysis(); });
  auto &CtxProf = MAM.getResult<CtxProfAnalysis>(*M);
  auto *Caller = M->getFunction("testfunc1");
  ASSERT_NE(Caller, nullptr);
  auto *Callee = M->getFunction("f2");
  ASSERT_NE(Callee, nullptr);
  // Find the (single) indirect call in the caller.
  auto *IndirectCS = [&]() -> CallBase * {
    for (auto &BB : *Caller)
      for (auto &I : BB)
        if (auto *CB = dyn_cast<CallBase>(&I); CB && CB->isIndirectCall())
          return CB;
    return nullptr;
  }();
  ASSERT_NE(IndirectCS, nullptr);
  promoteCallWithIfThenElse(*IndirectCS, *Callee, CtxProf);

  // Serialize the updated profile and compare it (as parsed JSON, so
  // formatting is irrelevant) against the expected post-ICP shape: guid 1002
  // now hangs off a second callsite, and the new direct/indirect counters
  // reflect the promoted target's share of the calls.
  std::string Str;
  raw_string_ostream OS(Str);
  CtxProfAnalysisPrinterPass Printer(
      OS, CtxProfAnalysisPrinterPass::PrintMode::JSON);
  Printer.run(*M, MAM);
  const char *Expected = R"json(
[
{
"Guid": 1000,
"Counters": [1, 11, 22],
"Callsites": [
[{ "Guid": 1001,
"Counters": [10]},
{ "Guid": 1003,
"Counters": [12]
}],
[{ "Guid": 1002,
"Counters": [11],
"Callsites": [
[{ "Guid": 1004,
"Counters": [13] }]]}]]
},
{
"Guid": 1005,
"Counters": [2],
"Callsites": [
[{ "Guid": 1000,
"Counters": [1, 102, 204],
"Callsites": [
[{ "Guid": 1001,
"Counters": [101]},
{ "Guid": 1003,
"Counters": [103]}],
[{ "Guid": 1002,
"Counters": [102],
"Callsites": [
[{ "Guid": 1004,
"Counters": [104]}]]}]]}]]}
])json";
  auto ExpectedJSON = json::parse(Expected);
  ASSERT_TRUE(!!ExpectedJSON);
  auto ProducedJSON = json::parse(Str);
  ASSERT_TRUE(!!ProducedJSON);
  EXPECT_EQ(*ProducedJSON, *ExpectedJSON);
}
11 changes: 11 additions & 0 deletions mlir/include/mlir-c/Dialect/GPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ MLIR_CAPI_EXPORTED MlirAttribute
mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, uint32_t format,
MlirStringRef objectStrRef, MlirAttribute mlirObjectProps);

MLIR_CAPI_EXPORTED MlirAttribute mlirGPUObjectAttrGetWithKernels(
MlirContext mlirCtx, MlirAttribute target, uint32_t format,
MlirStringRef objectStrRef, MlirAttribute mlirObjectProps,
MlirAttribute mlirKernelsAttr);

MLIR_CAPI_EXPORTED MlirAttribute
mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr);

Expand All @@ -52,6 +57,12 @@ mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr);
MLIR_CAPI_EXPORTED MlirAttribute
mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr);

MLIR_CAPI_EXPORTED bool
mlirGPUObjectAttrHasKernels(MlirAttribute mlirObjectAttr);

MLIR_CAPI_EXPORTED MlirAttribute
mlirGPUObjectAttrGetKernels(MlirAttribute mlirObjectAttr);

#ifdef __cplusplus
}
#endif
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def GPUTargetAttrInterface : AttrInterface<"TargetAttrInterface"> {
InterfaceMethod<[{
Creates a GPU object attribute from a binary string.

The `module` parameter must be a `GPUModuleOp` and can be used to
retrieve additional information like the list of kernels in the binary.
The `object` parameter is a binary string. The `options` parameter is
meant to be used for passing additional options that are not in the
attribute.
Expand Down
175 changes: 171 additions & 4 deletions mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,155 @@
include "mlir/Dialect/GPU/IR/GPUBase.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"

//===----------------------------------------------------------------------===//
// GPU kernel metadata attribute
//===----------------------------------------------------------------------===//

def GPU_KernelMetadataAttr : GPU_Attr<"KernelMetadata", "kernel_metadata"> {
  let description = [{
    GPU attribute for storing metadata related to a compiled kernel. The
    attribute contains the name and arguments type of the kernel.

    The attribute also contains optional parameters for storing the arguments
    attributes as well as a dictionary for additional metadata, like occupancy
    information or other function attributes.

    Note: The `arg_attrs` parameter is expected to follow all the constraints
    imposed by the `mlir::FunctionOpInterface` interface.

    Examples:
    ```mlir
      #gpu.kernel_metadata<@kernel1, (i32) -> (), arg_attrs = [...], metadata = {reg_count = 255, ...}>
      #gpu.kernel_metadata<@kernel2, (i32, f64) -> ()>
    ```
  }];
  // `name` and `function_type` are required; `arg_attrs`/`metadata` may be
  // null and are elided from the assembly format when absent.
  let parameters = (ins
    "StringAttr":$name,
    "Type":$function_type,
    OptionalParameter<"ArrayAttr", "arguments attributes">:$arg_attrs,
    OptionalParameter<"DictionaryAttr", "metadata dictionary">:$metadata
  );
  let assemblyFormat = [{
    `<` $name `,` $function_type (`,` struct($arg_attrs, $metadata)^)? `>`
  }];
  let builders = [
    AttrBuilderWithInferredContext<(ins "StringAttr":$name,
                                        "Type":$functionType,
                                        CArg<"ArrayAttr", "nullptr">:$argAttrs,
                                        CArg<"DictionaryAttr",
                                             "nullptr">:$metadata), [{
      assert(name && "invalid name");
      return $_get(name.getContext(), name, functionType, argAttrs, metadata);
    }]>,
    // Convenience builder that extracts name/type/arg-attrs from a
    // function-like op; defined out-of-line in GPUDialect.cpp.
    AttrBuilderWithInferredContext<(ins "FunctionOpInterface":$kernel,
                                        CArg<"DictionaryAttr",
                                             "nullptr">:$metadata)>
  ];
  // Generated verifier declaration; implemented as
  // KernelMetadataAttr::verify in GPUDialect.cpp.
  let genVerifyDecl = 1;
  let extraClassDeclaration = [{
    /// Compare two kernels based on the name.
    bool operator<(const KernelMetadataAttr& other) const {
      return getName().getValue() < other.getName().getValue();
    }

    /// Returns the metadata attribute corresponding to `key` or `nullptr`
    /// if missing.
    Attribute getAttr(StringRef key) const {
      DictionaryAttr attrs = getMetadata();
      return attrs ? attrs.get(key) : nullptr;
    }
    /// Typed variant: returns `nullptr` when the entry is missing or is not
    /// a `ConcreteAttr`.
    template <typename ConcreteAttr>
    ConcreteAttr getAttr(StringRef key) const {
      return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
    }
    /// Overload taking an interned StringAttr key.
    Attribute getAttr(StringAttr key) const {
      DictionaryAttr attrs = getMetadata();
      return attrs ? attrs.get(key) : nullptr;
    }
    template <typename ConcreteAttr>
    ConcreteAttr getAttr(StringAttr key) const {
      return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
    }

    /// Returns the attribute dictionary at position `index`.
    /// NOTE(review): no bounds check — callers must pass a valid argument
    /// index when `arg_attrs` is present.
    DictionaryAttr getArgAttrDict(unsigned index) {
      ArrayAttr argArray = getArgAttrs();
      return argArray ? llvm::cast<DictionaryAttr>(argArray[index]) : nullptr;
    }

    /// Return the specified attribute, if present, for the argument at 'index',
    /// null otherwise.
    Attribute getArgAttr(unsigned index, StringAttr name) {
      DictionaryAttr argDict = getArgAttrDict(index);
      return argDict ? argDict.get(name) : nullptr;
    }
    Attribute getArgAttr(unsigned index, StringRef name) {
      DictionaryAttr argDict = getArgAttrDict(index);
      return argDict ? argDict.get(name) : nullptr;
    }

    /// Returns a new KernelMetadataAttr that contains `attrs` in the metadata dictionary.
    KernelMetadataAttr appendMetadata(ArrayRef<NamedAttribute> attrs) const;
  }];
}

//===----------------------------------------------------------------------===//
// GPU kernel table attribute
//===----------------------------------------------------------------------===//

def GPU_KernelTableAttr : GPU_Attr<"KernelTable", "kernel_table"> {
  let description = [{
    GPU attribute representing a list of `#gpu.kernel_metadata` attributes. This
    attribute supports searching kernels by name. All kernels in the table must
    have an unique name.

    Examples:
    ```mlir
      // Empty table.
      #gpu.kernel_table<>

      // Table with a single kernel.
      #gpu.kernel_table<[#gpu.kernel_metadata<kernel0, () -> () >]>

      // Table with multiple kernels.
      #gpu.kernel_table<[
        #gpu.kernel_metadata<"kernel0", (i32, f32) -> (), metadata = {sgpr_count = 255}>,
        #gpu.kernel_metadata<"kernel1", (i32) -> ()>
      ]>
    ```
  }];
  let parameters = (ins
    OptionalArrayRefParameter<"KernelMetadataAttr", "array of kernels">:$kernel_table
  );
  let assemblyFormat = [{
    `<` (`[` qualified($kernel_table)^ `]`)? `>`
  }];
  // Default builders are skipped so the custom builder (GPUDialect.cpp) can
  // keep the table sorted by kernel name, enabling binary-search `lookup`.
  let builders = [
    AttrBuilder<(ins "ArrayRef<KernelMetadataAttr>":$kernels,
                     CArg<"bool", "false">:$isSorted)>
  ];
  let skipDefaultBuilders = 1;
  // Verifier enforces unique kernel names; see KernelTableAttr::verify.
  let genVerifyDecl = 1;
  let extraClassDeclaration = [{
    /// Iteration over the (name-sorted) kernel entries.
    llvm::ArrayRef<KernelMetadataAttr>::iterator begin() const {
      return getKernelTable().begin();
    }
    llvm::ArrayRef<KernelMetadataAttr>::iterator end() const {
      return getKernelTable().end();
    }
    /// Returns the number of kernels in the table.
    size_t size() const {
      return getKernelTable().size();
    }
    /// Returns true when the table contains no kernels.
    bool empty() const {
      return getKernelTable().empty();
    }

    /// Returns the kernel with name `key` or `nullptr` if not present.
    KernelMetadataAttr lookup(StringRef key) const;
    KernelMetadataAttr lookup(StringAttr key) const;
  }];
}

//===----------------------------------------------------------------------===//
// GPU object attribute.
//===----------------------------------------------------------------------===//
Expand All @@ -36,8 +185,9 @@ def GPU_CompilationTargetEnum : GPU_I32Enum<
def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
let description = [{
A GPU object attribute glues together a GPU target, the object kind, a
binary string with the object, and the object properties, encapsulating how
the object was generated and its properties with the object itself.
binary string with the object, the object properties, and kernel metadata,
encapsulating how the object was generated and its properties with the
object itself.

There are four object formats:
1. `Offload`: represents generic objects not described by the other three
Expand All @@ -55,6 +205,10 @@ def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {

Object properties are specified through the `properties` dictionary
attribute and can be used to define additional information.

Kernel metadata is specified through the `kernels` parameter, and can be
used to specify additional information on a kernel by kernel basis.

The target attribute must implement or promise the `TargetAttrInterface`
interface.

Expand All @@ -63,16 +217,29 @@ def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
#gpu.object<#nvvm.target, properties = {O = 3 : i32}, assembly = "..."> // An assembly object with additional properties.
#gpu.object<#rocdl.target, bin = "..."> // A binary object.
#gpu.object<#nvvm.target, "..."> // A fatbin object.
#gpu.object<#nvvm.target, kernels = #gpu.kernel_table<...>, "..."> // An object with a kernel table.
```
}];
let parameters = (ins
"Attribute":$target,
DefaultValuedParameter<"CompilationTarget", "CompilationTarget::Fatbin">:$format,
"StringAttr":$object,
OptionalParameter<"DictionaryAttr">:$properties
OptionalParameter<"DictionaryAttr">:$properties,
OptionalParameter<"KernelTableAttr">:$kernels
);
let builders = [
AttrBuilderWithInferredContext<(ins "Attribute":$target,
"CompilationTarget":$format,
"StringAttr":$object,
CArg<"DictionaryAttr", "nullptr">:$properties,
CArg<"KernelTableAttr", "nullptr">:$kernels), [{
assert(target && "invalid target");
return $_get(target.getContext(), target, format, object, properties, kernels);
}]>
];
let assemblyFormat = [{ `<`
$target `,` (`properties` `=` $properties ^ `,`)?
$target `,` (`properties` `=` $properties^ `,`)?
(`kernels` `=` $kernels^ `,`)?
custom<Object>($format, $object)
`>`
}];
Expand Down
15 changes: 15 additions & 0 deletions mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_TARGET_LLVM_ROCDL_UTILS_H

#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVM/ModuleToObject.h"
Expand Down Expand Up @@ -107,6 +108,20 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
/// AMD GCN libraries to use when linking, the default is using none.
AMDGCNLibraries deviceLibs = AMDGCNLibraries::None;
};

/// Returns a map containing the `amdhsa.kernels` ELF metadata for each of the
/// kernels in the binary, or `std::nullopt` if the metadata couldn't be
/// retrieved. The map associates the name of the kernel with the list of named
/// attributes found in `amdhsa.kernels`. For more information on the ELF
/// metadata see: https://llvm.org/docs/AMDGPUUsage.html#amdhsa
std::optional<DenseMap<StringAttr, NamedAttrList>>
getAMDHSAKernelsELFMetadata(Builder &builder, ArrayRef<char> elfData);

/// Returns a `#gpu.kernel_table` containing kernel metadata for each of the
/// kernels in `gpuModule`. If `elfData` is valid, then the `amdhsa.kernels` ELF
/// metadata will be added to the `#gpu.kernel_table`.
gpu::KernelTableAttr getKernelMetadata(Operation *gpuModule,
ArrayRef<char> elfData = {});
} // namespace ROCDL
} // namespace mlir

Expand Down
23 changes: 17 additions & 6 deletions mlir/lib/Bindings/Python/DialectGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,21 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
.def_classmethod(
"get",
[](py::object cls, MlirAttribute target, uint32_t format,
py::bytes object, std::optional<MlirAttribute> mlirObjectProps) {
py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
std::optional<MlirAttribute> mlirKernelsAttr) {
py::buffer_info info(py::buffer(object).request());
MlirStringRef objectStrRef =
mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
return cls(mlirGPUObjectAttrGet(
return cls(mlirGPUObjectAttrGetWithKernels(
mlirAttributeGetContext(target), target, format, objectStrRef,
mlirObjectProps.has_value() ? *mlirObjectProps
: MlirAttribute{nullptr},
mlirKernelsAttr.has_value() ? *mlirKernelsAttr
: MlirAttribute{nullptr}));
},
"cls"_a, "target"_a, "format"_a, "object"_a,
"properties"_a = py::none(), "Gets a gpu.object from parameters.")
"properties"_a = py::none(), "kernels"_a = py::none(),
"Gets a gpu.object from parameters.")
.def_property_readonly(
"target",
[](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
Expand All @@ -71,9 +75,16 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
return py::bytes(stringRef.data, stringRef.length);
})
.def_property_readonly("properties", [](MlirAttribute self) {
if (mlirGPUObjectAttrHasProperties(self))
return py::cast(mlirGPUObjectAttrGetProperties(self));
.def_property_readonly("properties",
[](MlirAttribute self) {
if (mlirGPUObjectAttrHasProperties(self))
return py::cast(
mlirGPUObjectAttrGetProperties(self));
return py::none().cast<py::object>();
})
.def_property_readonly("kernels", [](MlirAttribute self) {
if (mlirGPUObjectAttrHasKernels(self))
return py::cast(mlirGPUObjectAttrGetKernels(self));
return py::none().cast<py::object>();
});
}
37 changes: 34 additions & 3 deletions mlir/lib/CAPI/Dialect/GPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,28 @@ MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target,
DictionaryAttr objectProps;
if (mlirObjectProps.ptr != nullptr)
objectProps = llvm::cast<DictionaryAttr>(unwrap(mlirObjectProps));
return wrap(gpu::ObjectAttr::get(ctx, unwrap(target),
static_cast<gpu::CompilationTarget>(format),
StringAttr::get(ctx, object), objectProps));
return wrap(gpu::ObjectAttr::get(
ctx, unwrap(target), static_cast<gpu::CompilationTarget>(format),
StringAttr::get(ctx, object), objectProps, nullptr));
}

/// C-API entry point: builds a `#gpu.object` attribute that additionally
/// carries an optional kernel-metadata table. Null `mlirObjectProps` /
/// `mlirKernelsAttr` handles are treated as "absent".
MlirAttribute mlirGPUObjectAttrGetWithKernels(MlirContext mlirCtx,
                                              MlirAttribute target,
                                              uint32_t format,
                                              MlirStringRef objectStrRef,
                                              MlirAttribute mlirObjectProps,
                                              MlirAttribute mlirKernelsAttr) {
  MLIRContext *context = unwrap(mlirCtx);
  // Convert the optional C handles, leaving the attributes null when absent.
  DictionaryAttr properties =
      mlirObjectProps.ptr ? llvm::cast<DictionaryAttr>(unwrap(mlirObjectProps))
                          : DictionaryAttr();
  gpu::KernelTableAttr kernelTable =
      mlirKernelsAttr.ptr
          ? llvm::cast<gpu::KernelTableAttr>(unwrap(mlirKernelsAttr))
          : gpu::KernelTableAttr();
  StringAttr binary = StringAttr::get(context, unwrap(objectStrRef));
  return wrap(gpu::ObjectAttr::get(context, unwrap(target),
                                   static_cast<gpu::CompilationTarget>(format),
                                   binary, properties, kernelTable));
}

MlirAttribute mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr) {
Expand Down Expand Up @@ -78,3 +97,15 @@ MlirAttribute mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr) {
llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
return wrap(objectAttr.getProperties());
}

/// Returns true when the wrapped `#gpu.object` attribute carries a kernel
/// table.
bool mlirGPUObjectAttrHasKernels(MlirAttribute mlirObjectAttr) {
  return llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr)).getKernels() !=
         nullptr;
}

/// Returns the kernel table of the wrapped `#gpu.object` attribute (the
/// result wraps null when no table is present).
MlirAttribute mlirGPUObjectAttrGetKernels(MlirAttribute mlirObjectAttr) {
  auto objectAttr = llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
  return wrap(objectAttr.getKernels());
}
110 changes: 109 additions & 1 deletion mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2091,7 +2091,8 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,

LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Attribute target, CompilationTarget format,
StringAttr object, DictionaryAttr properties) {
StringAttr object, DictionaryAttr properties,
KernelTableAttr kernels) {
if (!target)
return emitError() << "the target attribute cannot be null";
if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
Expand Down Expand Up @@ -2177,6 +2178,113 @@ LogicalResult gpu::DynamicSharedMemoryOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// GPU KernelMetadataAttr
//===----------------------------------------------------------------------===//

/// Builds a KernelMetadataAttr directly from a function-like op, pulling the
/// kernel name, function type, and argument attributes from the op itself.
KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
                                           DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  StringAttr kernelName = kernel.getNameAttr();
  Type kernelType = kernel.getFunctionType();
  ArrayAttr argumentAttrs = kernel.getAllArgAttrs();
  return get(kernelName, kernelType, argumentAttrs, metadata);
}

/// Same as the FunctionOpInterface `get`, but routes through the verifying
/// builder so malformed parameters produce a diagnostic via `emitError`.
KernelMetadataAttr
KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               FunctionOpInterface kernel,
                               DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  StringAttr kernelName = kernel.getNameAttr();
  return getChecked(emitError, kernelName, kernel.getFunctionType(),
                    kernel.getAllArgAttrs(), metadata);
}

/// Returns a copy of this attribute whose metadata dictionary additionally
/// contains `attrs`; returns `*this` unchanged when `attrs` is empty.
KernelMetadataAttr
KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
  if (attrs.empty())
    return *this;
  // Start from the existing dictionary (if any), then append the new entries.
  NamedAttrList merged;
  if (DictionaryAttr existing = getMetadata())
    merged.append(existing);
  merged.append(attrs);
  DictionaryAttr mergedDict = merged.getDictionary(getContext());
  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
                                 mergedDict);
}

/// Verifier shared by the parser and the checked builders: the kernel name
/// must be non-empty, and when `argAttrs` is present it must hold one
/// DictionaryAttr per argument (mirroring `mlir::FunctionOpInterface`).
LogicalResult
KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           StringAttr name, Type functionType,
                           ArrayAttr argAttrs, DictionaryAttr metadata) {
  // Guard against a null `name`: `getChecked` can be invoked programmatically
  // with arbitrary parameters, and calling `empty()` on a null StringAttr
  // would crash instead of producing a diagnostic.
  if (!name || name.empty())
    return emitError() << "the kernel name can't be empty";
  if (argAttrs) {
    if (llvm::any_of(argAttrs, [](Attribute attr) {
          return !llvm::isa<DictionaryAttr>(attr);
        }))
      return emitError()
             << "all attributes in the array must be a dictionary attribute";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GPU KernelTableAttr
//===----------------------------------------------------------------------===//

/// Uniques a kernel table, sorting the kernels by name first when needed.
KernelTableAttr KernelTableAttr::get(MLIRContext *context,
                                     ArrayRef<KernelMetadataAttr> kernels,
                                     bool isSorted) {
  // `llvm::is_sorted` runs at most once here even with assertions enabled:
  // the assert short-circuits when `isSorted` is false, and the branch below
  // short-circuits when it is true.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  if (isSorted || llvm::is_sorted(kernels))
    return Base::get(context, kernels);
  // Sort a local copy (by kernel name, via operator<) before uniquing.
  SmallVector<KernelMetadataAttr> sorted(kernels);
  llvm::array_pod_sort(sorted.begin(), sorted.end());
  return Base::get(context, sorted);
}

/// Verifying variant of `get`: sorts the kernels by name when needed, then
/// uniques through the checked base builder.
KernelTableAttr KernelTableAttr::getChecked(
    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
    ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
  // `llvm::is_sorted` runs at most once here even with assertions enabled:
  // the assert short-circuits when `isSorted` is false, and the branch below
  // short-circuits when it is true.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  if (isSorted || llvm::is_sorted(kernels))
    return Base::getChecked(emitError, context, kernels);
  // Sort a local copy (by kernel name, via operator<) before uniquing.
  SmallVector<KernelMetadataAttr> sorted(kernels);
  llvm::array_pod_sort(sorted.begin(), sorted.end());
  return Base::getChecked(emitError, context, sorted);
}

/// Checks that every kernel in the table has a distinct name.
LogicalResult
KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                        ArrayRef<KernelMetadataAttr> kernels) {
  // Zero or one kernel is trivially unique.
  if (kernels.size() < 2)
    return success();
  // The table is kept sorted by name, so any duplicates must be adjacent.
  auto sameName = [](KernelMetadataAttr lhs, KernelMetadataAttr rhs) {
    return lhs.getName() == rhs.getName();
  };
  if (std::adjacent_find(kernels.begin(), kernels.end(), sameName) !=
      kernels.end())
    return emitError() << "expected all kernels to be uniquely named";
  return success();
}

/// Binary-searches the name-sorted table; returns null when absent.
KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
  const auto result = impl::findAttrSorted(begin(), end(), key);
  return result.second ? *result.first : KernelMetadataAttr();
}

/// Overload taking an interned StringAttr key; returns null when absent.
KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
  const auto result = impl::findAttrSorted(begin(), end(), key);
  return result.second ? *result.first : KernelMetadataAttr();
}

//===----------------------------------------------------------------------===//
// GPU target options
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Target/LLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,12 @@ endif()

add_mlir_dialect_library(MLIRROCDLTarget
ROCDL/Target.cpp
ROCDL/Utils.cpp

OBJECT

LINK_COMPONENTS
FrontendOffloading
MCParser
${AMDGPU_LIBS}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/LLVM/NVVM/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,5 +604,5 @@ NVVMTargetAttrImpl::createObject(Attribute attribute, Operation *module,
return builder.getAttr<gpu::ObjectAttr>(
attribute, format,
builder.getStringAttr(StringRef(object.data(), object.size())),
objectProps);
objectProps, /*kernels=*/nullptr);
}
14 changes: 8 additions & 6 deletions mlir/lib/Target/LLVM/ROCDL/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,13 +506,15 @@ ROCDLTargetAttrImpl::createObject(Attribute attribute, Operation *module,
gpu::CompilationTarget format = options.getCompilationTarget();
// If format is `fatbin` transform it to binary as `fatbin` is not yet
// supported.
if (format > gpu::CompilationTarget::Binary)
gpu::KernelTableAttr kernels;
if (format > gpu::CompilationTarget::Binary) {
format = gpu::CompilationTarget::Binary;

kernels = ROCDL::getKernelMetadata(module, object);
}
DictionaryAttr properties{};
Builder builder(attribute.getContext());
return builder.getAttr<gpu::ObjectAttr>(
attribute, format,
builder.getStringAttr(StringRef(object.data(), object.size())),
properties);
StringAttr objectStr =
builder.getStringAttr(StringRef(object.data(), object.size()));
return builder.getAttr<gpu::ObjectAttr>(attribute, format, objectStr,
properties, kernels);
}
87 changes: 87 additions & 0 deletions mlir/lib/Target/LLVM/ROCDL/Utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//===- Utils.cpp - MLIR ROCDL target utils ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines ROCDL target-related utility classes and functions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Target/LLVM/ROCDL/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"

#include "llvm/ADT/StringMap.h"
#include "llvm/Frontend/Offloading/Utility.h"

using namespace mlir;
using namespace mlir::ROCDL;

std::optional<DenseMap<StringAttr, NamedAttrList>>
mlir::ROCDL::getAMDHSAKernelsELFMetadata(Builder &builder,
                                         ArrayRef<char> elfData) {
  uint16_t elfABIVersion;
  llvm::StringMap<llvm::offloading::amdgpu::AMDGPUKernelMetaData> kernels;
  // Wrap the ELF bytes in a non-owning memory buffer for the reader.
  llvm::MemoryBufferRef buffer(StringRef(elfData.data(), elfData.size()),
                               "buffer");
  // Get the metadata.
  llvm::Error error = llvm::offloading::amdgpu::getAMDGPUMetaDataFromImage(
      buffer, kernels, elfABIVersion);
  // Return `nullopt` if the metadata couldn't be retrieved.
  if (error) {
    llvm::consumeError(std::move(error));
    return std::nullopt;
  }
  // Helper lambda for converting values. The metadata's fixed 3-element
  // uint32 triples become dense i32 array attributes.
  auto getI32Array = [&builder](const uint32_t *array) {
    return builder.getDenseI32ArrayAttr({static_cast<int32_t>(array[0]),
                                         static_cast<int32_t>(array[1]),
                                         static_cast<int32_t>(array[2])});
  };
  DenseMap<StringAttr, NamedAttrList> kernelMD;
  // Translate each kernel's AMDGPU ELF metadata into named MLIR attributes.
  for (const auto &[name, kernel] : kernels) {
    NamedAttrList attrs;
    // Add kernel metadata.
    attrs.append("agpr_count", builder.getI64IntegerAttr(kernel.AGPRCount));
    attrs.append("sgpr_count", builder.getI64IntegerAttr(kernel.SGPRCount));
    attrs.append("vgpr_count", builder.getI64IntegerAttr(kernel.VGPRCount));
    attrs.append("sgpr_spill_count",
                 builder.getI64IntegerAttr(kernel.SGPRSpillCount));
    attrs.append("vgpr_spill_count",
                 builder.getI64IntegerAttr(kernel.VGPRSpillCount));
    attrs.append("wavefront_size",
                 builder.getI64IntegerAttr(kernel.WavefrontSize));
    attrs.append("max_flat_workgroup_size",
                 builder.getI64IntegerAttr(kernel.MaxFlatWorkgroupSize));
    // NOTE(review): `GroupSegmentList` is the source field for
    // `group_segment_fixed_size` — presumably the field name in
    // AMDGPUKernelMetaData; confirm against Frontend/Offloading/Utility.h.
    attrs.append("group_segment_fixed_size",
                 builder.getI64IntegerAttr(kernel.GroupSegmentList));
    attrs.append("private_segment_fixed_size",
                 builder.getI64IntegerAttr(kernel.PrivateSegmentSize));
    attrs.append("reqd_workgroup_size",
                 getI32Array(kernel.RequestedWorkgroupSize));
    attrs.append("workgroup_size_hint", getI32Array(kernel.WorkgroupSizeHint));
    kernelMD[builder.getStringAttr(name)] = std::move(attrs);
  }
  return std::move(kernelMD);
}

/// Builds a `#gpu.kernel_table` for every `rocdl.kernel` function in the GPU
/// module, attaching per-kernel ELF metadata when it can be extracted from
/// `elfData`.
gpu::KernelTableAttr mlir::ROCDL::getKernelMetadata(Operation *gpuModule,
                                                    ArrayRef<char> elfData) {
  auto gpuMod = cast<gpu::GPUModuleOp>(gpuModule);
  Builder builder(gpuMod.getContext());
  // Try to pull the per-kernel ELF metadata out of the binary; this can fail,
  // in which case the kernels carry no metadata dictionary.
  std::optional<DenseMap<StringAttr, NamedAttrList>> elfMD =
      getAMDHSAKernelsELFMetadata(builder, elfData);
  SmallVector<gpu::KernelMetadataAttr> kernelAttrs;
  for (LLVM::LLVMFuncOp func : gpuMod.getBody()->getOps<LLVM::LLVMFuncOp>()) {
    // Only functions tagged as ROCDL kernels belong in the table.
    if (!func->getDiscardableAttr("rocdl.kernel"))
      continue;
    DictionaryAttr funcMD =
        elfMD ? builder.getDictionaryAttr(elfMD->lookup(func.getNameAttr()))
              : nullptr;
    kernelAttrs.push_back(gpu::KernelMetadataAttr::get(func, funcMD));
  }
  return gpu::KernelTableAttr::get(gpuModule->getContext(), kernelAttrs);
}
2 changes: 1 addition & 1 deletion mlir/lib/Target/SPIRV/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,5 @@ SPIRVTargetAttrImpl::createObject(Attribute attribute, Operation *module,
return builder.getAttr<gpu::ObjectAttr>(
attribute, format,
builder.getStringAttr(StringRef(object.data(), object.size())),
objectProps);
objectProps, /*kernels=*/nullptr);
}
12 changes: 12 additions & 0 deletions mlir/test/Dialect/GPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -848,3 +848,15 @@ module attributes {gpu.container_module} {
gpu.module @kernel <> {
}
}

// -----

gpu.binary @binary [#gpu.object<#rocdl.target<chip = "gfx900">,
// expected-error@+1{{expected all kernels to be uniquely named}}
kernels = #gpu.kernel_table<[
#gpu.kernel_metadata<"kernel", (i32) -> ()>,
#gpu.kernel_metadata<"kernel", (i32, f32) -> (), metadata = {sgpr_count = 255}>
// expected-error@below{{failed to parse GPU_ObjectAttr parameter 'kernels' which is to be a `KernelTableAttr`}}
]>,
bin = "BLOB">
]
23 changes: 23 additions & 0 deletions mlir/test/Dialect/GPU/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,26 @@ gpu.module @module_with_two_target [#nvvm.target, #rocdl.target<chip = "gfx90a">

gpu.module @module_with_offload_handler <#gpu.select_object<0>> [#nvvm.target] {
}

// Test kernel attributes
gpu.binary @kernel_attrs_1 [
#gpu.object<#rocdl.target<chip = "gfx900">,
kernels = #gpu.kernel_table<[
#gpu.kernel_metadata<"kernel0", (i32, f32) -> (), metadata = {sgpr_count = 255}>,
#gpu.kernel_metadata<"kernel1", (i32) -> (), arg_attrs = [{llvm.read_only}]>
]>,
bin = "BLOB">
]

// Verify the kernels are sorted
// CHECK-LABEL: gpu.binary @kernel_attrs_2
gpu.binary @kernel_attrs_2 [
// CHECK: [#gpu.kernel_metadata<"a_kernel", () -> ()>, #gpu.kernel_metadata<"m_kernel", () -> ()>, #gpu.kernel_metadata<"z_kernel", () -> ()>]
#gpu.object<#rocdl.target<chip = "gfx900">,
kernels = #gpu.kernel_table<[
#gpu.kernel_metadata<"z_kernel", () -> ()>,
#gpu.kernel_metadata<"m_kernel", () -> ()>,
#gpu.kernel_metadata<"a_kernel", () -> ()>
]>,
bin = "BLOB">
]
9 changes: 9 additions & 0 deletions mlir/test/python/dialects/gpu/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,12 @@ def testObjectAttr():
# CHECK: #gpu.object<#nvvm.target, "//\0A// Generated by LLVM NVPTX Back-End\0A//\0A\0A.version 6.0\0A.target sm_50">
print(o)
assert o.object == object

object = b"BC\xc0\xde5\x14\x00\x00\x05\x00\x00\x00b\x0c0$MY\xbef"
kernelTable = Attribute.parse(
'#gpu.kernel_table<[#gpu.kernel_metadata<"kernel", () -> ()>]>'
)
o = gpu.ObjectAttr.get(target, format, object, kernels=kernelTable)
# CHECK: #gpu.object<#nvvm.target, kernels = <[#gpu.kernel_metadata<"kernel", () -> ()>]>, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
print(o)
assert o.kernels == kernelTable
66 changes: 66 additions & 0 deletions mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,69 @@ TEST_F(MLIRTargetLLVMROCDL, SKIP_WITHOUT_AMDGPU(SerializeROCDLToBinary)) {
ASSERT_FALSE(object->empty());
}
}

// Test that ROCDL kernel metadata extracted from the serialized ELF is
// surfaced through `ROCDL::getKernelMetadata`.
TEST_F(MLIRTargetLLVMROCDL, SKIP_WITHOUT_AMDGPU(GetELFMetadata)) {
  if (!hasROCMTools())
    GTEST_SKIP() << "ROCm installation not found, skipping test.";

  MLIRContext context(registry);

  // MLIR module used for the tests.
  const std::string moduleStr = R"mlir(
  gpu.module @rocdl_test {
    llvm.func @rocdl_kernel_1(%arg0: f32) attributes {gpu.kernel, rocdl.kernel} {
      llvm.return
    }
    llvm.func @rocdl_kernel_0(%arg0: f32) attributes {gpu.kernel, rocdl.kernel} {
      llvm.return
    }
    llvm.func @rocdl_kernel_2(%arg0: f32) attributes {gpu.kernel, rocdl.kernel} {
      llvm.return
    }
    llvm.func @a_kernel(%arg0: f32) attributes {gpu.kernel, rocdl.kernel} {
      llvm.return
    }
  })mlir";

  OwningOpRef<ModuleOp> module =
      parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);

  // Create a ROCDL target.
  ROCDL::ROCDLTargetAttr target = ROCDL::ROCDLTargetAttr::get(&context);

  // Serialize the module.
  auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
  ASSERT_TRUE(!!serializer);
  gpu::TargetOptions options("", {}, "", gpu::CompilationTarget::Binary);
  for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
    std::optional<SmallVector<char, 0>> object =
        serializer.serializeToObject(gpuModule, options);
    // Check that the serializer was successful. ASSERT_* aborts this test on
    // failure, so `object` is guaranteed valid past these lines (the former
    // `if (!object) continue;` was unreachable and has been removed).
    ASSERT_TRUE(object != std::nullopt);
    ASSERT_FALSE(object->empty());
    // Get the metadata.
    gpu::KernelTableAttr metadata =
        ROCDL::getKernelMetadata(gpuModule, *object);
    ASSERT_TRUE(metadata != nullptr);
    // There should be 4 kernels.
    ASSERT_EQ(metadata.size(), 4u);
    // Check that the lookup method finds the kernels.
    ASSERT_TRUE(metadata.lookup("a_kernel") != nullptr);
    ASSERT_TRUE(metadata.lookup("rocdl_kernel_0") != nullptr);
    // Check that a non-existent kernel is not found.
    ASSERT_TRUE(metadata.lookup("not_existent_kernel") == nullptr);
    // Test the `KernelMetadataAttr` iterators.
    for (gpu::KernelMetadataAttr kernel : metadata) {
      // Check that the ELF metadata is present.
      ASSERT_TRUE(kernel.getMetadata() != nullptr);
      // Verify that `sgpr_count` is present and it is an integer attribute.
      ASSERT_TRUE(kernel.getAttr<IntegerAttr>("sgpr_count") != nullptr);
      // Verify that `vgpr_count` is present and it is an integer attribute.
      ASSERT_TRUE(kernel.getAttr<IntegerAttr>("vgpr_count") != nullptr);
    }
  }
}
8 changes: 7 additions & 1 deletion mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) {
std::optional<SmallVector<char, 0>>
TargetAttrImpl::serializeToObject(Attribute attribute, Operation *module,
const gpu::TargetOptions &options) const {
// Set a dummy attr to be retrieved by `createObject`.
module->setAttr("serialize_attr", UnitAttr::get(module->getContext()));
std::string targetTriple = llvm::sys::getProcessTriple();
LLVM::ModuleToObject serializer(*module, targetTriple, "", "");
Expand All @@ -112,13 +113,18 @@ Attribute
TargetAttrImpl::createObject(Attribute attribute, Operation *module,
                             const SmallVector<char, 0> &object,
                             const gpu::TargetOptions &options) const {
  // Create a GPU object with the GPU module dictionary as the object
  // properties; this lets the test observe attributes set during
  // serialization.
  MLIRContext *ctx = module->getContext();
  StringAttr content =
      StringAttr::get(ctx, StringRef(object.data(), object.size()));
  return gpu::ObjectAttr::get(ctx, attribute, gpu::CompilationTarget::Offload,
                              content, module->getAttrDictionary(),
                              /*kernels=*/nullptr);
}

// This test checks the correct functioning of `TargetAttrInterface` as an API.
// In particular, it shows how `TargetAttrInterface::createObject` can leverage
// the `module` operation argument to retrieve information from the module.
TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(TargetAttrAPI)) {
MLIRContext context(registry);
context.loadAllAvailableDialects();
Expand Down