diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 774941d1f17ea9..18193bf2a9ad5e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -115,6 +115,102 @@ static FunctionType *getOriginalFunctionType(const Function &F) { return FunctionType::get(RetTy, ArgTypes, F.isVarArg()); } +static MDString *getKernelArgAttribute(const Function &KernelFunction, + unsigned ArgIdx, + const StringRef AttributeName) { + assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL && + "Kernel attributes are attached/belong only to kernel functions"); + + // Lookup the argument attribute in metadata attached to the kernel function. + MDNode *Node = KernelFunction.getMetadata(AttributeName); + if (Node && ArgIdx < Node->getNumOperands()) + return cast(Node->getOperand(ArgIdx)); + + // Sometimes metadata containing kernel attributes is not attached to the + // function, but can be found in the named module-level metadata instead. + // For example: + // !opencl.kernels = !{!0} + // !0 = !{void ()* @someKernelFunction, !1, ...} + // !1 = !{!"kernel_arg_addr_space", ...} + // In this case the actual index of searched argument attribute is ArgIdx + 1, + // since the first metadata node operand is occupied by attribute name + // ("kernel_arg_addr_space" in the example above). + unsigned MDArgIdx = ArgIdx + 1; + NamedMDNode *OpenCLKernelsMD = + KernelFunction.getParent()->getNamedMetadata("opencl.kernels"); + if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0) + return nullptr; + + // KernelToMDNodeList contains kernel function declarations followed by + // corresponding MDNodes for each attribute. Search only MDNodes "belonging" + // to the currently lowered kernel function. + MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0); + bool FoundLoweredKernelFunction = false; + for (const MDOperand &Operand : KernelToMDNodeList->operands()) { + ValueAsMetadata *MaybeValue = dyn_cast(Operand); + if (MaybeValue && dyn_cast(MaybeValue->getValue())->getName() == + KernelFunction.getName()) { + FoundLoweredKernelFunction = true; + continue; + } + if (MaybeValue && FoundLoweredKernelFunction) + return nullptr; + + MDNode *MaybeNode = dyn_cast(Operand); + if (FoundLoweredKernelFunction && MaybeNode && + cast(MaybeNode->getOperand(0))->getString() == + AttributeName && + MDArgIdx < MaybeNode->getNumOperands()) + return cast(MaybeNode->getOperand(MDArgIdx)); + } + return nullptr; +} + +static SPIRV::AccessQualifier::AccessQualifier +getArgAccessQual(const Function &F, unsigned ArgIdx) { + if (F.getCallingConv() != CallingConv::SPIR_KERNEL) + return SPIRV::AccessQualifier::ReadWrite; + + MDString *ArgAttribute = + getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual"); + if (!ArgAttribute) + return SPIRV::AccessQualifier::ReadWrite; + + if (ArgAttribute->getString().compare("read_only") == 0) + return SPIRV::AccessQualifier::ReadOnly; + if (ArgAttribute->getString().compare("write_only") == 0) + return SPIRV::AccessQualifier::WriteOnly; + return SPIRV::AccessQualifier::ReadWrite; +} + +static std::vector +getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) { + MDString *ArgAttribute = + getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual"); + if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0) + return {SPIRV::Decoration::Volatile}; + return {}; +} + +static Type *getArgType(const Function &F, unsigned ArgIdx) { + Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx); + if (F.getCallingConv() != CallingConv::SPIR_KERNEL || + isSpecialOpaqueType(OriginalArgType)) + return OriginalArgType; + + MDString *MDKernelArgType = + getKernelArgAttribute(F, ArgIdx, "kernel_arg_type"); + if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t")) + return OriginalArgType; + + std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str(); + Type *ExistingOpaqueType = + StructType::getTypeByName(F.getContext(), KernelArgTypeStr); + return ExistingOpaqueType + ? ExistingOpaqueType + : StructType::create(F.getContext(), KernelArgTypeStr); +} + bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F, ArrayRef> VRegs, @@ -132,18 +228,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, // TODO: handle the case of multiple registers. if (VRegs[i].size() > 1) return false; - Type *ArgTy = FTy->getParamType(i); - SPIRV::AccessQualifier::AccessQualifier AQ = - SPIRV::AccessQualifier::ReadWrite; - MDNode *Node = F.getMetadata("kernel_arg_access_qual"); - if (Node && i < Node->getNumOperands()) { - StringRef AQString = cast(Node->getOperand(i))->getString(); - if (AQString.compare("read_only") == 0) - AQ = SPIRV::AccessQualifier::ReadOnly; - else if (AQString.compare("write_only") == 0) - AQ = SPIRV::AccessQualifier::WriteOnly; - } - auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ); + SPIRV::AccessQualifier::AccessQualifier ArgAccessQual = + getArgAccessQual(F, i); + auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0], + MIRBuilder, ArgAccessQual); ArgTypeVRegs.push_back(SpirvTy); if (Arg.hasName()) @@ -178,14 +266,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::FuncParamAttr, {Attr}); } - Node = F.getMetadata("kernel_arg_type_qual"); - if (Node && i < Node->getNumOperands()) { - StringRef TypeQual = cast(Node->getOperand(i))->getString(); - if (TypeQual.compare("volatile") == 0) - buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile, - {}); + + if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { + std::vector ArgTypeQualDecs = + getKernelArgTypeQual(F, i); + for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs) + buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {}); } - Node = F.getMetadata("spirv.ParameterDecorations"); + + MDNode *Node = F.getMetadata("spirv.ParameterDecorations"); if (Node && i < Node->getNumOperands() && isa(Node->getOperand(i))) { MDNode *MD = cast(Node->getOperand(i)); diff --git a/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_function_metadata.ll b/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_function_metadata.ll new file mode 100644 index 00000000000000..ce5910efc6ccd4 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_function_metadata.ll @@ -0,0 +1,12 @@ +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s + +; CHECK: %[[#TypeSampler:]] = OpTypeSampler +define spir_kernel void @foo(i64 %sampler) !kernel_arg_addr_space !7 !kernel_arg_access_qual !8 !kernel_arg_type !9 !kernel_arg_type_qual !10 !kernel_arg_base_type !9 { +entry: + ret void +} + +!7 = !{i32 0} +!8 = !{!"none"} +!9 = !{!"sampler_t"} +!10 = !{!""} diff --git a/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_module_metadata.ll b/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_module_metadata.ll new file mode 100644 index 00000000000000..b5bb8433321daf --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_module_metadata.ll @@ -0,0 +1,16 @@ +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s + +; CHECK: %[[#TypeSampler:]] = OpTypeSampler +define spir_kernel void @foo(i64 %sampler) { +entry: + ret void +} +!opencl.kernels = !{!0} + +!0 = !{void (i64)* @foo, !1, !2, !3, !4, !5, !6} +!1 = !{!"kernel_arg_addr_space", i32 0} +!2 = !{!"kernel_arg_access_qual", !"none"} +!3 = !{!"kernel_arg_type", !"sampler_t"} +!4 = !{!"kernel_arg_type_qual", !""} +!5 = !{!"kernel_arg_base_type", !"sampler_t"} +!6 = !{!"kernel_arg_name", !"sampler"}