Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Do not use OpenCL metadata for ptr element type resolution #82678

Conversation

michalpaszkowski
Copy link
Member

This pull request aims to remove any dependency on OpenCL/SPIR-V type information in LLVM IR metadata. While, using metadata might simplify and prettify the resulting SPIR-V output (and restore some of the information missed in the transformation to opaque pointers), the overall methodology for resolving kernel parameter types is highly inefficient.

This pull request is work in progress, but the high-level strategy is to assign kernel parameter types in this order:

  1. Resolving the types using builtin function calls as mangled names must contain type information or by looking up builtin definition in SPIRVBuiltins.td. Then:
  • Assigning the type temporarily using an intrinsic and later setting the right SPIR-V type in SPIRVGlobalRegistry after IRTranslation
  • Inserting a bitcast
  1. Defaulting to LLVM IR types (in case of pointers the generic i8* type)

In case of type incompatibility (e.g. parameter defined initially as sampler_t and later used as image_t) the error will be found early on before IRTranslation (most likely in the SPIRVEmitIntrinsics pass).

The code repetition in parseBuiltinCallArgumentBaseType(...) will be removed in an amended commit.

Copy link

github-actions bot commented Feb 22, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@michalpaszkowski michalpaszkowski force-pushed the feature_do_not_rely_on_type_metadata branch from a4a0d28 to 8b57539 Compare February 28, 2024 12:47
@michalpaszkowski michalpaszkowski marked this pull request as ready for review February 28, 2024 12:49
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 28, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Michal Paszkowski (michalpaszkowski)

Changes

This pull request aims to remove any dependency on OpenCL/SPIR-V type information in LLVM IR metadata. While, using metadata might simplify and prettify the resulting SPIR-V output (and restore some of the information missed in the transformation to opaque pointers), the overall methodology for resolving kernel parameter types is highly inefficient.

This pull request is work in progress, but the high-level strategy is to assign kernel parameter types in this order:

  1. Resolving the types using builtin function calls as mangled names must contain type information or by looking up builtin definition in SPIRVBuiltins.td. Then:
  • Assigning the type temporarily using an intrinsic and later setting the right SPIR-V type in SPIRVGlobalRegistry after IRTranslation
  • Inserting a bitcast
  1. Defaulting to LLVM IR types (in case of pointers the generic i8* type)

In case of type incompatibility (e.g. parameter defined initially as sampler_t and later used as image_t) the error will be found early on before IRTranslation (most likely in the SPIRVEmitIntrinsics pass).

The code repetition in parseBuiltinCallArgumentBaseType(...) will be removed in an amended commit.


Patch is 62.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82678.diff

35 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+65-10)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.h (+13-3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+35-19)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+108-71)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+5-39)
  • (modified) llvm/lib/Target/SPIRV/SPIRVMetadata.cpp (-7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVMetadata.h (-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+4-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+29-13)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+3-4)
  • (modified) llvm/test/CodeGen/SPIRV/function/alloca-load-store.ll (+4-7)
  • (modified) llvm/test/CodeGen/SPIRV/half_no_extension.ll (-3)
  • (modified) llvm/test/CodeGen/SPIRV/instructions/undef-nested-composite-store.ll (+4-6)
  • (modified) llvm/test/CodeGen/SPIRV/instructions/undef-simple-composite-store.ll (+4-6)
  • (modified) llvm/test/CodeGen/SPIRV/opaque_pointers.ll (+5-8)
  • (modified) llvm/test/CodeGen/SPIRV/opencl/basic/get_global_offset.ll (+10-14)
  • (removed) llvm/test/CodeGen/SPIRV/opencl/metadata/kernel_arg_type_function_metadata.ll (-12)
  • (removed) llvm/test/CodeGen/SPIRV/opencl/metadata/kernel_arg_type_module_metadata.ll (-16)
  • (modified) llvm/test/CodeGen/SPIRV/opencl/vload2.ll (+14-6)
  • (added) llvm/test/CodeGen/SPIRV/opencl/vstore2.ll (+23)
  • (added) llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-TargetExtType-arg-no-spv_assign_type.ll (+12)
  • (added) llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-divergent-spv_assign_ptr_type.ll (+12)
  • (added) llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-duplicate-spv_assign_type.ll (+14)
  • (modified) llvm/test/CodeGen/SPIRV/pointers/getelementptr-kernel-arg-char.ll (+6-14)
  • (added) llvm/test/CodeGen/SPIRV/pointers/kernel-argument-builtin-vload-type-discrapency.ll (+35)
  • (added) llvm/test/CodeGen/SPIRV/pointers/kernel-argument-pointer-type-deduction-mismatch.ll (+12)
  • (added) llvm/test/CodeGen/SPIRV/pointers/kernel-argument-pointer-type-deduction-no-metadata.ll (+13)
  • (added) llvm/test/CodeGen/SPIRV/pointers/store-operand-ptr-to-struct.ll (+19)
  • (renamed) llvm/test/CodeGen/SPIRV/pointers/two-bitcast-or-param-users.ll (+3-7)
  • (modified) llvm/test/CodeGen/SPIRV/pointers/two-subsequent-bitcasts.ll (+3-4)
  • (modified) llvm/test/CodeGen/SPIRV/sitofp-with-bool.ll (+2-3)
  • (modified) llvm/test/CodeGen/SPIRV/transcoding/OpenCL/atomic_cmpxchg.ll (+2-3)
  • (modified) llvm/test/CodeGen/SPIRV/transcoding/OpenCL/atomic_legacy.ll (+2-3)
  • (modified) llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll (-1)
  • (modified) llvm/test/CodeGen/SPIRV/uitofp-with-bool.ll (+2-3)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c1bb27322443ff..119a15b5f1bfb9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1775,7 +1775,7 @@ static const Type *getMachineInstrType(MachineInstr *MI) {
     return nullptr;
   Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0);
   assert(Ty && "Type is expected");
-  return getTypedPtrEltType(Ty);
+  return Ty;
 }
 
 static const Type *getBlockStructType(Register ParamReg,
@@ -1787,7 +1787,7 @@ static const Type *getBlockStructType(Register ParamReg,
   // section 6.12.5 should guarantee that we can do this.
   MachineInstr *MI = getBlockStructInstr(ParamReg, MRI);
   if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE)
-    return getTypedPtrEltType(MI->getOperand(1).getGlobal()->getType());
+    return MI->getOperand(1).getGlobal()->getType();
   assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) &&
          "Blocks in OpenCL C must be traceable to allocation site");
   return getMachineInstrType(MI);
@@ -2043,7 +2043,8 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
           .addImm(Builtin->Number);
   for (auto Argument : Call->Arguments)
     MIB.addUse(Argument);
-  MIB.addImm(Builtin->ElementCount);
+  if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
+    MIB.addImm(Builtin->ElementCount);
 
   // Rounding mode should be passed as a last argument in the MI for builtins
   // like "vstorea_halfn_r".
@@ -2179,6 +2180,61 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
   return false;
 }
 
+Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
+                                       unsigned ArgIdx, LLVMContext &Ctx) {
+  SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
+  StringRef BuiltinArgs =
+      DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
+  BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
+  if (ArgIdx >= BuiltinArgsTypeStrs.size())
+    return nullptr;
+  StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
+
+  // Parse strings representing OpenCL builtin types.
+  if (hasBuiltinTypePrefix(TypeStr)) {
+    // OpenCL builtin types in demangled call strings have the following format:
+    // e.g. ocl_image2d_ro
+    bool IsOCLBuiltinType = TypeStr.consume_front("ocl_");
+    assert(IsOCLBuiltinType && "Invalid OpenCL builtin prefix");
+
+    // Check if this is pointer to a builtin type and not just pointer
+    // representing a builtin type. In case it is a pointer to builtin type,
+    // this will require additional handling in the method calling
+    // parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
+    // base types.
+    if (TypeStr.ends_with("*"))
+      TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" "));
+
+    return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
+                                               Ctx);
+  }
+
+  // Parse type name in either "typeN" or "type vector[N]" format, where
+  // N is the number of elements of the vector.
+  Type *BaseType;
+  unsigned VecElts = 0;
+
+  BaseType = parseBasicTypeName(TypeStr, Ctx);
+  if (!BaseType)
+    // Unable to recognize SPIRV type name.
+    return nullptr;
+
+  if (BaseType->isVoidTy())
+    BaseType = Type::getInt8Ty(Ctx);
+
+  // Handle "typeN*" or "type vector[N]*".
+  TypeStr.consume_back("*");
+
+  if (TypeStr.consume_front(" vector["))
+    TypeStr = TypeStr.substr(0, TypeStr.find(']'));
+
+  TypeStr.getAsInteger(10, VecElts);
+  if (VecElts > 0)
+    BaseType = VectorType::get(BaseType, VecElts, false);
+
+  return BaseType;
+}
+
 struct BuiltinType {
   StringRef Name;
   uint32_t Opcode;
@@ -2277,9 +2333,8 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
 }
 
 namespace SPIRV {
-const TargetExtType *
-parseBuiltinTypeNameToTargetExtType(std::string TypeName,
-                                    MachineIRBuilder &MIRBuilder) {
+TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
+                                                   LLVMContext &Context) {
   StringRef NameWithParameters = TypeName;
 
   // Pointers-to-opaque-structs representing OpenCL types are first translated
@@ -2303,7 +2358,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
   // Parameterized SPIR-V builtins names follow this format:
   // e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
   if (!NameWithParameters.contains('_'))
-    return TargetExtType::get(MIRBuilder.getContext(), NameWithParameters);
+    return TargetExtType::get(Context, NameWithParameters);
 
   SmallVector<StringRef> Parameters;
   unsigned BaseNameLength = NameWithParameters.find('_') - 1;
@@ -2313,7 +2368,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
   bool HasTypeParameter = !isDigit(Parameters[0][0]);
   if (HasTypeParameter)
     TypeParameters.push_back(parseTypeString(
-        Parameters[0], MIRBuilder.getMF().getFunction().getContext()));
+        Parameters[0], Context));
   SmallVector<unsigned> IntParameters;
   for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
     unsigned IntParameter = 0;
@@ -2323,7 +2378,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
            "Invalid format of SPIR-V builtin parameter literal!");
     IntParameters.push_back(IntParameter);
   }
-  return TargetExtType::get(MIRBuilder.getContext(),
+  return TargetExtType::get(Context,
                             NameWithParameters.substr(0, BaseNameLength),
                             TypeParameters, IntParameters);
 }
@@ -2343,7 +2398,7 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
   const TargetExtType *BuiltinType = dyn_cast<TargetExtType>(OpaqueType);
   if (!BuiltinType)
     BuiltinType = parseBuiltinTypeNameToTargetExtType(
-        OpaqueType->getStructName().str(), MIRBuilder);
+        OpaqueType->getStructName().str(), MIRBuilder.getContext());
 
   unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 6f957295464812..649f5bfd1d7c26 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -38,6 +38,17 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  const SmallVectorImpl<Register> &Args,
                                  SPIRVGlobalRegistry *GR);
 
+/// Parses the provided \p ArgIdx argument base type in the \p DemangledCall
+/// skeleton. A base type is either a basic type (e.g. i32 for int), pointer
+/// element type (e.g. i8 for char*), or builtin type (TargetExtType).
+///
+/// \return LLVM Type or nullptr if unrecognized
+///
+/// \p DemangledCall is the skeleton of the lowered builtin function call.
+/// \p ArgIdx is the index of the argument to parse.
+Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
+                                       unsigned ArgIdx, LLVMContext &Ctx);
+
 /// Translates a string representing a SPIR-V or OpenCL builtin type to a
 /// TargetExtType that can be further lowered with lowerBuiltinType().
 ///
@@ -45,9 +56,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
 ///
 /// \p TypeName is the full string representation of the SPIR-V or OpenCL
 /// builtin type.
-const TargetExtType *
-parseBuiltinTypeNameToTargetExtType(std::string TypeName,
-                                    MachineIRBuilder &MIRBuilder);
+TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
+                                                   LLVMContext &Context);
 
 /// Handles the translation of the provided special opaque/builtin type \p Type
 /// to SPIR-V type. Generates the corresponding machine instructions for the
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index cc438b2bb8d4d7..f9197b805f0637 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -22,6 +22,8 @@
 #include "SPIRVSubtarget.h"
 #include "SPIRVUtils.h"
 #include "llvm/CodeGen/FunctionLoweringInfo.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/ModRef.h"
 
 using namespace llvm;
@@ -157,28 +159,42 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
 
   Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
 
-  // In case of non-kernel SPIR-V function or already TargetExtType, use the
-  // original IR type.
-  if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
-      isSpecialOpaqueType(OriginalArgType))
+  // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
+  // be legally reassigned later).
+  if (!OriginalArgType->isPointerTy())
     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 
-  SPIRVType *ResArgType = nullptr;
-  if (MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx)) {
-    StringRef MDTypeStr = MDKernelArgType->getString();
-    if (MDTypeStr.ends_with("*"))
-      ResArgType = GR->getOrCreateSPIRVTypeByName(
-          MDTypeStr, MIRBuilder,
-          addressSpaceToStorageClass(
-              OriginalArgType->getPointerAddressSpace()));
-    else if (MDTypeStr.ends_with("_t"))
-      ResArgType = GR->getOrCreateSPIRVTypeByName(
-          "opencl." + MDTypeStr.str(), MIRBuilder,
-          SPIRV::StorageClass::Function, ArgAccessQual);
+  // In case OriginalArgType is of pointer type, there are two possibilities:
+  // 1) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
+  // intrinsic assigning a TargetExtType.
+  // 2) This is a pointer, try to retrieve pointer element type from a
+  // spv_assign_ptr_type intrinsic or otherwise use default pointer element
+  // type.
+  for (auto User : F.getArg(ArgIdx)->users()) {
+    auto *II = dyn_cast<IntrinsicInst>(User);
+    // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
+    if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
+      MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
+      Type *BuiltinType =
+          cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
+      assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
+      return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual);
+    }
+
+    // Check if this is spv_assign_ptr_type assigning pointer element type.
+    if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
+      continue;
+
+    MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
+    SPIRVType *ElementType = GR->getOrCreateSPIRVType(
+        cast<ConstantAsMetadata>(VMD->getMetadata())->getType(), MIRBuilder);
+    return GR->getOrCreateSPIRVPointerType(
+        ElementType, MIRBuilder,
+        addressSpaceToStorageClass(
+            cast<ConstantInt>(II->getOperand(2))->getZExtValue()));
   }
-  return ResArgType ? ResArgType
-                    : GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder,
-                                               ArgAccessQual);
+
+  return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 }
 
 static SPIRV::ExecutionModel::ExecutionModel
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e32cd50be56e38..c627427bd9c7a0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRV.h"
+#include "SPIRVBuiltins.h"
 #include "SPIRVMetadata.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
@@ -75,7 +76,11 @@ class SPIRVEmitIntrinsics
   void processInstrAfterVisit(Instruction *I);
   void insertAssignPtrTypeIntrs(Instruction *I);
   void insertAssignTypeIntrs(Instruction *I);
-  void insertPtrCastInstr(Instruction *I);
+  void insertAssignTypeInstrForTargetExtTypes(TargetExtType* AssignedType, Value *V);
+  void replacePointerOperandWithPtrCast(Instruction *I, Value *Pointer,
+                                        Type *ExpectedElementType,
+                                        unsigned OperandToReplace);
+  void insertPtrCastOrAssignTypeInstr(Instruction *I);
   void processGlobalValue(GlobalVariable &GV);
 
 public:
@@ -130,13 +135,6 @@ static void setInsertPointSkippingPhis(IRBuilder<> &B, Instruction *I) {
     B.SetInsertPoint(I);
 }
 
-static bool requireAssignPtrType(Instruction *I) {
-  if (isa<AllocaInst>(I) || isa<GetElementPtrInst>(I))
-    return true;
-
-  return false;
-}
-
 static bool requireAssignType(Instruction *I) {
   IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(I);
   if (Intr) {
@@ -269,7 +267,7 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
   // SPIR-V, contrary to LLVM 17+ IR, supports bitcasts between pointers of
   // varying element types. In case of IR coming from older versions of LLVM
   // such bitcasts do not provide sufficient information, should be just skipped
-  // here, and handled in insertPtrCastInstr.
+  // here, and handled in insertPtrCastOrAssignTypeInstr.
   if (I.getType()->isPointerTy()) {
     I.replaceAllUsesWith(Source);
     I.eraseFromParent();
@@ -286,34 +284,37 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
   return NewI;
 }
 
-void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
-  Value *Pointer;
-  Type *ExpectedElementType;
-  unsigned OperandToReplace;
+void SPIRVEmitIntrinsics::insertAssignTypeInstrForTargetExtTypes(
+    TargetExtType *AssignedType, Value *V) {
+  // Do not emit spv_assign_type if the V is of the AssignedType already.
+  if (V->getType() == AssignedType)
+    return;
 
-  StoreInst *SI = dyn_cast<StoreInst>(I);
-  if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
-      SI->getValueOperand()->getType()->isPointerTy() &&
-      isa<Argument>(SI->getValueOperand())) {
-    Pointer = SI->getValueOperand();
-    ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
-    OperandToReplace = 0;
-  } else if (SI) {
-    Pointer = SI->getPointerOperand();
-    ExpectedElementType = SI->getValueOperand()->getType();
-    OperandToReplace = 1;
-  } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
-    Pointer = LI->getPointerOperand();
-    ExpectedElementType = LI->getType();
-    OperandToReplace = 0;
-  } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
-    Pointer = GEPI->getPointerOperand();
-    ExpectedElementType = GEPI->getSourceElementType();
-    OperandToReplace = 0;
-  } else {
+  // Do not emit spv_assign_type if there is one already targetting V. If the
+  // found spv_assign_type assigns a type different than AssignedType, report an
+  // error. Builtin types cannot be redeclared or casted.
+  for (auto User : V->users()) {
+    auto *II = dyn_cast<IntrinsicInst>(User);
+    if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_type)
+      continue;
+
+    MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
+    Type *BuiltinType = dyn_cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
+    if (BuiltinType != AssignedType)
+      report_fatal_error("Type mismatch " + BuiltinType->getTargetExtName() +
+                             "/" + AssignedType->getTargetExtName() +
+                             " for value " + V->getName(),
+                         false);
     return;
   }
 
+  Constant *Const = UndefValue::get(AssignedType);
+  buildIntrWithMD(Intrinsic::spv_assign_type, {V->getType()}, Const, V, {});
+}
+
+void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
+    Instruction *I, Value *Pointer, Type *ExpectedElementType,
+    unsigned OperandToReplace) {
   // If Pointer is the result of nop BitCastInst (ptr -> ptr), use the source
   // pointer instead. The BitCastInst should be later removed when visited.
   while (BitCastInst *BC = dyn_cast<BitCastInst>(Pointer))
@@ -378,38 +379,76 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
     return;
   }
 
-  // Do not emit spv_ptrcast if it would cast to the default pointer element
-  // type (i8) of the same address space. In case of OpenCL kernels, make sure
-  // i8 is the pointer element type defined for the given kernel argument.
-  if (ExpectedElementType->isIntegerTy(8) &&
-      F->getCallingConv() != CallingConv::SPIR_KERNEL)
-    return;
+  // // Do not emit spv_ptrcast if it would cast to the default pointer element
+  // // type (i8) of the same address space.
+  // if (ExpectedElementType->isIntegerTy(8))
+  //   return;
 
-  Argument *Arg = dyn_cast<Argument>(Pointer);
-  if (ExpectedElementType->isIntegerTy(8) &&
-      F->getCallingConv() == CallingConv::SPIR_KERNEL && Arg) {
-    MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
-    if (ArgType && ArgType->getString().starts_with("uchar*"))
-      return;
-  }
-
-  // If this would be the first spv_ptrcast, the pointer's defining instruction
-  // requires spv_assign_ptr_type and does not already have one, do not emit
-  // spv_ptrcast and emit spv_assign_ptr_type instead.
-  Instruction *PointerDefInst = dyn_cast<Instruction>(Pointer);
-  if (FirstPtrCastOrAssignPtrType && PointerDefInst &&
-      requireAssignPtrType(PointerDefInst)) {
+  // If this would be the first spv_ptrcast, do not emit spv_ptrcast and emit
+  // spv_assign_ptr_type instead.
+  if (FirstPtrCastOrAssignPtrType &&
+      (isa<Instruction>(Pointer) || isa<Argument>(Pointer))) {
     buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
                     ExpectedElementTypeConst, Pointer,
                     {IRB->getInt32(AddressSpace)});
     return;
-  } else {
-    SmallVector<Type *, 2> Types = {Pointer->getType(), Pointer->getType()};
-    SmallVector<Value *, 2> Args = {Pointer, VMD, IRB->getInt32(AddressSpace)};
-    auto *PtrCastI =
-        IRB->CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
-    I->setOperand(OperandToReplace, PtrCastI);
+  }
+
+  // Emit spv_ptrcast
+  SmallVector<Type *, 2> Types = {Pointer->getType(), Pointer->getType()};
+  SmallVector<Value *, 2> Args = {Pointer, VMD, IRB->getInt32(AddressSpace)};
+  auto *PtrCastI = IRB->CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+  I->setOperand(OperandToReplace, PtrCastI);
+}
+
+void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I) {
+  // Handle basic instructions:
+  StoreInst *SI = dyn_cast<StoreInst>(I);
+  if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
+      SI->getValueOperand()->getType()->isPointerTy() &&
+      isa<Argument>(SI->getValueOperand())) {
+    return replacePointerOperandWithPtrCast(
+        I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0);
+  } else if (SI) {
+    return replacePointerOperandWithPtrCast(
+        I, SI->getPointerOperand(), SI->getValueOperand()->getType(), 1);
+  } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
+    return replacePointerOperandWithPtrCast(I, LI->getPointerOperand(),
+                                            LI->getType(), 0);
+  } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
+    return replacePointerOperandWithPtrCast(I, GEPI->getPointerOperand(),
+                                            GEPI->getSourceElementType(), 0);
+  }
+
+  // Handle calls to builtins (non-intrinsics):
+  CallInst *CI = dyn_cast<CallInst>(I);
+  if (!CI || CI->isIndirectCall() || CI->getCalledFunction()->isIntrinsic())
+    return;
+
+  std::string DemangledName =
+      getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
+  if (DemangledName.empty())
     return;
+
+  for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) {
+    Value *ArgOperand = CI->getArgOperand(OpIdx);
+    if (!isa<PointerType>(ArgOperand->getType()))
+      continue;
+
+    // Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs()
+    if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand))
+      continue;
+
+    Type *ExpectedType = SPIRV::parseBuiltinCallArgumentBaseType(
+        DemangledName, OpIdx, I->getContext());
+    if (!ExpectedType)
+      cont...
[truncated]

@michalpaszkowski michalpaszkowski force-pushed the feature_do_not_rely_on_type_metadata branch from 8b57539 to 275e0c8 Compare February 28, 2024 12:52
@michalpaszkowski michalpaszkowski changed the title [SPIR-V] Do not rely on type metadata for ptr element type resolution [SPIR-V] Do not use OpenCL metadata for ptr element type resolution Feb 28, 2024
@michalpaszkowski
Copy link
Member Author

The patch is mostly finished and ready for review. I might be able to remove some more dead code (working on this), need to rebase, and will add a couple more tests by the end of this week.

In the next patch, I am planning to remove the rest of SPIRVMetadata.

@@ -0,0 +1,19 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this test, the generated SPIR-V appears to be working fine with the NEO driver but differs when compared to the output of SPIR-V Translator. OpFunctionParameter has i32* type instead of %nested_struct*.

@@ -597,7 +636,7 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I) {
if (isa<UndefValue>(Op) && Op->getType()->isAggregateType())
buildIntrWithMD(Intrinsic::spv_assign_type, {IRB->getInt32Ty()}, Op,
UndefValue::get(IRB->getInt32Ty()), {});
else
else if (!isa<Instruction>(Op)) // TODO: This case could be removed
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will remove/refactor this in the next patch, so that it does not obscure the changes here

@@ -0,0 +1,12 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -print-after-all -o - 2>&1 | FileCheck %s

; CHECK: *** IR Dump After SPIRV emit intrinsics (emit-intrinsics) ***
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests check the output LLVM IR after running the SPIRVEmitIntrinsics pass. This might be a good way to document the behavior of some of the more complex passes we have.

@michalpaszkowski michalpaszkowski force-pushed the feature_do_not_rely_on_type_metadata branch from 275e0c8 to bf3df62 Compare March 4, 2024 05:19
@michalpaszkowski
Copy link
Member Author

I will merge the patch in a moment as discussed in the SPIR-V BE meeting

@michalpaszkowski michalpaszkowski merged commit 43222bd into llvm:main Mar 4, 2024
4 of 5 checks passed
llvm_unreachable("Unexpected instruction!");
}
else if (I->getType()->isPointerTy())
EltTyConst = UndefValue::get(IntegerType::getInt8Ty(I->getContext()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropping the else/llvm_unreachable case here results in a -Wsometimes-uninitialized warning. We probably want to add it back to get a reliable crash in debug builds rather than UB.

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp:623:12: error: variable 'EltTyConst' is used uninitialized whenever 'if' condition is false [-Werror,-Wsometimes-uninitialized]
  else if (I->getType()->isPointerTy())
           ^~~~~~~~~~~~~~~~~~~~~~~~~~~
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp:626:67: note: uninitialized use occurs here
  buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()}, EltTyConst, I,
                                                                  ^~~~~~~~~~
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp:623:8: note: remove the 'if' if its condition is always true
  else if (I->getType()->isPointerTy())
       ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp:617:23: note: initialize the variable 'EltTyConst' to silence this warning
  Constant *EltTyConst;
                      ^
                       = nullptr

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bogner Thank you for catching this! Yes, this should be in an else clause. Here is a pull request fixing this: #83901

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants