Skip to content

Commit

Permalink
[SPIRV] read kernel arg attributes from fuction/module metadata
Browse files Browse the repository at this point in the history
The patch introduces reading the attributes of kernel arguments both from
function-attached and module-level metadata, during kernel arguments lowering.
Two tests are added to show the improvement.

Differential Revision: https://reviews.llvm.org/D135106

Co-authored-by: Aleksandr Bezzubikov <zuban32s@gmail.com>
Co-authored-by: Michal Paszkowski <michal.paszkowski@outlook.com>
Co-authored-by: Andrey Tretyakov <andrey.tretyakov@mail.com>
Co-authored-by: Konrad Trifunovic <konrad.trifunovic@intel.com>
  • Loading branch information
5 people committed Oct 6, 2022
1 parent 933fefb commit 25ee36c
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 19 deletions.
127 changes: 108 additions & 19 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Expand Up @@ -115,6 +115,102 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
}

static MDString *getKernelArgAttribute(const Function &KernelFunction,
unsigned ArgIdx,
const StringRef AttributeName) {
assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to kernel functions");

// Lookup the argument attribute in metadata attached to the kernel function.
MDNode *Node = KernelFunction.getMetadata(AttributeName);
if (Node && ArgIdx < Node->getNumOperands())
return cast<MDString>(Node->getOperand(ArgIdx));

// Sometimes metadata containing kernel attributes is not attached to the
// function, but can be found in the named module-level metadata instead.
// For example:
// !opencl.kernels = !{!0}
// !0 = !{void ()* @someKernelFunction, !1, ...}
// !1 = !{!"kernel_arg_addr_space", ...}
// In this case the actual index of searched argument attribute is ArgIdx + 1,
// since the first metadata node operand is occupied by attribute name
// ("kernel_arg_addr_space" in the example above).
unsigned MDArgIdx = ArgIdx + 1;
NamedMDNode *OpenCLKernelsMD =
KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
return nullptr;

// KernelToMDNodeList contains kernel function declarations followed by
// corresponding MDNodes for each attribute. Search only MDNodes "belonging"
// to the currently lowered kernel function.
MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
bool FoundLoweredKernelFunction = false;
for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
KernelFunction.getName()) {
FoundLoweredKernelFunction = true;
continue;
}
if (MaybeValue && FoundLoweredKernelFunction)
return nullptr;

MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
if (FoundLoweredKernelFunction && MaybeNode &&
cast<MDString>(MaybeNode->getOperand(0))->getString() ==
AttributeName &&
MDArgIdx < MaybeNode->getNumOperands())
return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
}
return nullptr;
}

static SPIRV::AccessQualifier::AccessQualifier
getArgAccessQual(const Function &F, unsigned ArgIdx) {
if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
return SPIRV::AccessQualifier::ReadWrite;

MDString *ArgAttribute =
getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
if (!ArgAttribute)
return SPIRV::AccessQualifier::ReadWrite;

if (ArgAttribute->getString().compare("read_only") == 0)
return SPIRV::AccessQualifier::ReadOnly;
if (ArgAttribute->getString().compare("write_only") == 0)
return SPIRV::AccessQualifier::WriteOnly;
return SPIRV::AccessQualifier::ReadWrite;
}

static std::vector<SPIRV::Decoration::Decoration>
getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
MDString *ArgAttribute =
getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
return {SPIRV::Decoration::Volatile};
return {};
}

static Type *getArgType(const Function &F, unsigned ArgIdx) {
Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
isSpecialOpaqueType(OriginalArgType))
return OriginalArgType;

MDString *MDKernelArgType =
getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t"))
return OriginalArgType;

std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str();
Type *ExistingOpaqueType =
StructType::getTypeByName(F.getContext(), KernelArgTypeStr);
return ExistingOpaqueType
? ExistingOpaqueType
: StructType::create(F.getContext(), KernelArgTypeStr);
}

bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
const Function &F,
ArrayRef<ArrayRef<Register>> VRegs,
Expand All @@ -132,18 +228,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (VRegs[i].size() > 1)
return false;
Type *ArgTy = FTy->getParamType(i);
SPIRV::AccessQualifier::AccessQualifier AQ =
SPIRV::AccessQualifier::ReadWrite;
MDNode *Node = F.getMetadata("kernel_arg_access_qual");
if (Node && i < Node->getNumOperands()) {
StringRef AQString = cast<MDString>(Node->getOperand(i))->getString();
if (AQString.compare("read_only") == 0)
AQ = SPIRV::AccessQualifier::ReadOnly;
else if (AQString.compare("write_only") == 0)
AQ = SPIRV::AccessQualifier::WriteOnly;
}
auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ);
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
getArgAccessQual(F, i);
auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0],
MIRBuilder, ArgAccessQual);
ArgTypeVRegs.push_back(SpirvTy);

if (Arg.hasName())
Expand Down Expand Up @@ -178,14 +266,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
buildOpDecorate(VRegs[i][0], MIRBuilder,
SPIRV::Decoration::FuncParamAttr, {Attr});
}
Node = F.getMetadata("kernel_arg_type_qual");
if (Node && i < Node->getNumOperands()) {
StringRef TypeQual = cast<MDString>(Node->getOperand(i))->getString();
if (TypeQual.compare("volatile") == 0)
buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile,
{});

if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
getKernelArgTypeQual(F, i);
for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
}
Node = F.getMetadata("spirv.ParameterDecorations");

MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
if (Node && i < Node->getNumOperands() &&
isa<MDNode>(Node->getOperand(i))) {
MDNode *MD = cast<MDNode>(Node->getOperand(i));
Expand Down
@@ -0,0 +1,12 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s

; CHECK: %[[#TypeSampler:]] = OpTypeSampler
define spir_kernel void @foo(i64 %sampler) !kernel_arg_addr_space !7 !kernel_arg_access_qual !8 !kernel_arg_type !9 !kernel_arg_type_qual !10 !kernel_arg_base_type !9 {
entry:
ret void
}

!7 = !{i32 0}
!8 = !{!"none"}
!9 = !{!"sampler_t"}
!10 = !{!""}
16 changes: 16 additions & 0 deletions llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_module_metadata.ll
@@ -0,0 +1,16 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s

; CHECK: %[[#TypeSampler:]] = OpTypeSampler
define spir_kernel void @foo(i64 %sampler) {
entry:
ret void
}
!opencl.kernels = !{!0}

!0 = !{void (i64)* @foo, !1, !2, !3, !4, !5, !6}
!1 = !{!"kernel_arg_addr_space", i32 0}
!2 = !{!"kernel_arg_access_qual", !"none"}
!3 = !{!"kernel_arg_type", !"sampler_t"}
!4 = !{!"kernel_arg_type_qual", !""}
!5 = !{!"kernel_arg_base_type", !"sampler_t"}
!6 = !{!"kernel_arg_name", !"sampler"}

0 comments on commit 25ee36c

Please sign in to comment.