Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Cast ptr kernel args to i8* when used as Store's value operand #78603

Conversation

michalpaszkowski
Copy link
Member

@michalpaszkowski michalpaszkowski commented Jan 18, 2024

Handle a special case when StoreInst's value operand is a kernel argument of a pointer type. Since these arguments could have either a basic element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast the StoreInst's value operand to default pointer element type (i8).

This pull request addresses the issue #72864

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 18, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Michal Paszkowski (michalpaszkowski)

Changes

Handle a special case when StoreInst's value operand is a kernel argument of a pointer type. Since these arguments could have either a basic element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast the StoreInst's value operand to default pointer element type (i8).


Full diff: https://github.com/llvm/llvm-project/pull/78603.diff

6 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+5-58)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+28-6)
  • (added) llvm/lib/Target/SPIRV/SPIRVMetadata.cpp (+92)
  • (added) llvm/lib/Target/SPIRV/SPIRVMetadata.h (+31)
  • (added) llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll (+19)
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 7d17c307db13a04..d9e24375dcb243f 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -26,6 +26,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVISelLowering.cpp
   SPIRVLegalizerInfo.cpp
   SPIRVMCInstLower.cpp
+  SPIRVMetadata.cpp
   SPIRVModuleAnalysis.cpp
   SPIRVPreLegalizer.cpp
   SPIRVPrepareFunctions.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 0a8b5499a1fc2ac..62c08bab46eee27 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -17,6 +17,7 @@
 #include "SPIRVBuiltins.h"
 #include "SPIRVGlobalRegistry.h"
 #include "SPIRVISelLowering.h"
+#include "SPIRVMetadata.h"
 #include "SPIRVRegisterInfo.h"
 #include "SPIRVSubtarget.h"
 #include "SPIRVUtils.h"
@@ -117,64 +118,12 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
   return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
 }
 
-static MDString *getKernelArgAttribute(const Function &KernelFunction,
-                                       unsigned ArgIdx,
-                                       const StringRef AttributeName) {
-  assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
-         "Kernel attributes are attached/belong only to kernel functions");
-
-  // Lookup the argument attribute in metadata attached to the kernel function.
-  MDNode *Node = KernelFunction.getMetadata(AttributeName);
-  if (Node && ArgIdx < Node->getNumOperands())
-    return cast<MDString>(Node->getOperand(ArgIdx));
-
-  // Sometimes metadata containing kernel attributes is not attached to the
-  // function, but can be found in the named module-level metadata instead.
-  // For example:
-  //   !opencl.kernels = !{!0}
-  //   !0 = !{void ()* @someKernelFunction, !1, ...}
-  //   !1 = !{!"kernel_arg_addr_space", ...}
-  // In this case the actual index of searched argument attribute is ArgIdx + 1,
-  // since the first metadata node operand is occupied by attribute name
-  // ("kernel_arg_addr_space" in the example above).
-  unsigned MDArgIdx = ArgIdx + 1;
-  NamedMDNode *OpenCLKernelsMD =
-      KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
-  if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
-    return nullptr;
-
-  // KernelToMDNodeList contains kernel function declarations followed by
-  // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
-  // to the currently lowered kernel function.
-  MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
-  bool FoundLoweredKernelFunction = false;
-  for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
-    ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
-    if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
-                          KernelFunction.getName()) {
-      FoundLoweredKernelFunction = true;
-      continue;
-    }
-    if (MaybeValue && FoundLoweredKernelFunction)
-      return nullptr;
-
-    MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
-    if (FoundLoweredKernelFunction && MaybeNode &&
-        cast<MDString>(MaybeNode->getOperand(0))->getString() ==
-            AttributeName &&
-        MDArgIdx < MaybeNode->getNumOperands())
-      return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
-  }
-  return nullptr;
-}
-
 static SPIRV::AccessQualifier::AccessQualifier
 getArgAccessQual(const Function &F, unsigned ArgIdx) {
   if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
     return SPIRV::AccessQualifier::ReadWrite;
 
-  MDString *ArgAttribute =
-      getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
+  MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
   if (!ArgAttribute)
     return SPIRV::AccessQualifier::ReadWrite;
 
@@ -186,9 +135,8 @@ getArgAccessQual(const Function &F, unsigned ArgIdx) {
 }
 
 static std::vector<SPIRV::Decoration::Decoration>
-getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
-  MDString *ArgAttribute =
-      getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
+getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
+  MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
   if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
     return {SPIRV::Decoration::Volatile};
   return {};
@@ -209,8 +157,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
       isSpecialOpaqueType(OriginalArgType))
     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 
-  MDString *MDKernelArgType =
-      getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
+  MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx);
   if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
                            !MDKernelArgType->getString().ends_with("_t")))
     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 90ec98bb361d3c0..56384ba1c006478 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRV.h"
+#include "SPIRVMetadata.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
 #include "llvm/IR/IRBuilder.h"
@@ -282,7 +283,26 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
   Value *Pointer;
   Type *ExpectedElementType;
   unsigned OperandToReplace;
-  if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
+  bool AllowCastingToChar = false;
+
+  StoreInst *SI = dyn_cast<StoreInst>(I);
+  if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
+      SI->getValueOperand()->getType()->isPointerTy() &&
+      isa<Argument>(SI->getValueOperand())) {
+    Argument *Arg = dyn_cast<Argument>(SI->getValueOperand());
+    MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
+    if (!ArgType || ArgType->getString().starts_with("uchar*"))
+      return;
+
+    // Handle special case when StoreInst's value operand is a kernel argument
+    // of a pointer type. Since these arguments could have either a basic
+    // element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast
+    // the StoreInst's value operand to default pointer element type (i8).
+    Pointer = Arg;
+    ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
+    OperandToReplace = 0;
+    AllowCastingToChar = true;
+  } else if (SI) {
     Pointer = SI->getPointerOperand();
     ExpectedElementType = SI->getValueOperand()->getType();
     OperandToReplace = 1;
@@ -364,13 +384,15 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
 
   // Do not emit spv_ptrcast if it would cast to the default pointer element
   // type (i8) of the same address space.
-  if (ExpectedElementType->isIntegerTy(8))
+  if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
     return;
 
-  // If this would be the first spv_ptrcast and there is no spv_assign_ptr_type
-  // for this pointer before, do not emit spv_ptrcast but emit
-  // spv_assign_ptr_type instead.
-  if (FirstPtrCastOrAssignPtrType && isa<Instruction>(Pointer)) {
+  // If this would be the first spv_ptrcast, the pointer's defining instruction
+  // requires spv_assign_ptr_type and does not already have one, do not emit
+  // spv_ptrcast and emit spv_assign_ptr_type instead.
+  Instruction *PointerDefInst = dyn_cast<Instruction>(Pointer);
+  if (FirstPtrCastOrAssignPtrType && PointerDefInst &&
+      requireAssignPtrType(PointerDefInst)) {
     buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
                     ExpectedElementTypeConst, Pointer,
                     {IRB->getInt32(AddressSpace)});
diff --git a/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp b/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
new file mode 100644
index 000000000000000..e8c707742f24437
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
@@ -0,0 +1,92 @@
+//===--- SPIRVMetadata.cpp ---- IR Metadata Parsing Funcs -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains functions needed for parsing LLVM IR metadata relevant
+// to the SPIR-V target.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVMetadata.h"
+
+using namespace llvm;
+
+static MDString *getOCLKernelArgAttribute(const Function &F, unsigned ArgIdx,
+                                          const StringRef AttributeName) {
+  assert(
+      F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+      "Kernel attributes are attached/belong only to OpenCL kernel functions");
+
+  // Lookup the argument attribute in metadata attached to the kernel function.
+  MDNode *Node = F.getMetadata(AttributeName);
+  if (Node && ArgIdx < Node->getNumOperands())
+    return cast<MDString>(Node->getOperand(ArgIdx));
+
+  // Sometimes metadata containing kernel attributes is not attached to the
+  // function, but can be found in the named module-level metadata instead.
+  // For example:
+  //   !opencl.kernels = !{!0}
+  //   !0 = !{void ()* @someKernelFunction, !1, ...}
+  //   !1 = !{!"kernel_arg_addr_space", ...}
+  // In this case the actual index of searched argument attribute is ArgIdx + 1,
+  // since the first metadata node operand is occupied by attribute name
+  // ("kernel_arg_addr_space" in the example above).
+  unsigned MDArgIdx = ArgIdx + 1;
+  NamedMDNode *OpenCLKernelsMD =
+      F.getParent()->getNamedMetadata("opencl.kernels");
+  if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
+    return nullptr;
+
+  // KernelToMDNodeList contains kernel function declarations followed by
+  // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
+  // to the currently lowered kernel function.
+  MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
+  bool FoundLoweredKernelFunction = false;
+  for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
+    ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
+    if (MaybeValue &&
+        dyn_cast<Function>(MaybeValue->getValue())->getName() == F.getName()) {
+      FoundLoweredKernelFunction = true;
+      continue;
+    }
+    if (MaybeValue && FoundLoweredKernelFunction)
+      return nullptr;
+
+    MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
+    if (FoundLoweredKernelFunction && MaybeNode &&
+        cast<MDString>(MaybeNode->getOperand(0))->getString() ==
+            AttributeName &&
+        MDArgIdx < MaybeNode->getNumOperands())
+      return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
+  }
+  return nullptr;
+}
+
+namespace llvm {
+
+MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx) {
+  assert(
+      F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+      "Kernel attributes are attached/belong only to OpenCL kernel functions");
+  return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
+}
+
+MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
+  assert(
+      F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+      "Kernel attributes are attached/belong only to OpenCL kernel functions");
+  return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type_qual");
+}
+
+MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx) {
+  assert(
+      F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+      "Kernel attributes are attached/belong only to OpenCL kernel functions");
+  return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
+}
+
+} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVMetadata.h b/llvm/lib/Target/SPIRV/SPIRVMetadata.h
new file mode 100644
index 000000000000000..50aee7234395927
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMetadata.h
@@ -0,0 +1,31 @@
+//===--- SPIRVMetadata.h ---- IR Metadata Parsing Funcs ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains functions needed for parsing LLVM IR metadata relevant
+// to the SPIR-V target.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
+
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+
+namespace llvm {
+
+//===----------------------------------------------------------------------===//
+// OpenCL Metadata
+//
+
+MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx);
+MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx);
+MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx);
+
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_METADATA_H
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll b/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll
new file mode 100644
index 000000000000000..e7ce3ef621e83a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll
@@ -0,0 +1,19 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
+  %var = alloca ptr addrspace(1), align 8
+; CHECK: %[[#VAR:]] = OpVariable %[[#]] Function
+  store ptr addrspace(1) %arg, ptr %var, align 8
+; The test itends to verify that OpStore uses OpVariable result directly (without a bitcast).
+; Other type checking is done by spirv-val.
+; CHECK: OpStore %[[#VAR]] %[[#]] Aligned 8
+  %lod = load ptr addrspace(1), ptr %var, align 8
+  %idx = getelementptr inbounds i64, ptr addrspace(1) %lod, i64 0
+  ret void
+}
+
+!1 = !{i32 1}
+!2 = !{!"none"}
+!3 = !{!"ulong*"}
+!4 = !{!""}

@michalpaszkowski michalpaszkowski linked an issue Jan 18, 2024 that may be closed by this pull request
SI->getValueOperand()->getType()->isPointerTy() &&
isa<Argument>(SI->getValueOperand())) {
Argument *Arg = dyn_cast<Argument>(SI->getValueOperand());
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can Arg be nullptr?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to safely use cast<> here instead of dyn_cast since you have the isa check.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @sudonatalie! Fixed!

; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a test that covers the case you mentioned where the arguments can be found in the module-level metadata?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we do have a test in llvm/test/CodeGen/SPIRV/opencl/metadata/kernel_arg_type_module_metadata.ll

SI->getValueOperand()->getType()->isPointerTy() &&
isa<Argument>(SI->getValueOperand())) {
Argument *Arg = dyn_cast<Argument>(SI->getValueOperand());
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to safely use cast<> here instead of dyn_cast since you have the isa check.

Handle a special case when StoreInst's value operand is a kernel
argument of a pointer type. Since these arguments could have either a
basic element type (e.g. float*) or OpenCL builtin type (sampler_t),
bitcast the StoreInst's value operand to default pointer element type
(i8).
@michalpaszkowski michalpaszkowski force-pushed the fix_store_value_operand_kernel_ptr_arg branch from da3ccdd to 497f7ba Compare January 25, 2024 12:28
@michalpaszkowski michalpaszkowski merged commit 0fbaf03 into llvm:main Jan 29, 2024
4 of 5 checks passed
@michalpaszkowski michalpaszkowski deleted the fix_store_value_operand_kernel_ptr_arg branch January 29, 2024 03:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[SPIR-V][OpenCL] Type mismatch on OpStore instruction
5 participants