Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIRV] Improve builtins matching and type inference in SPIR-V Backend, fix target ext type constants #89948

Merged
merged 2 commits into from
Apr 26, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Apr 24, 2024

This PR is to improve builtins matching and type inference in SPIR-V Backend. The model test case is printf call from OpenCL.std that has several features allowing for a wider look at builtins support/type inference:
(1) call in a "spirv-friendly" style (prefixed by _spirv_ocl)
(2) restricted type of the 1st argument

Attached test cases checks several possible inputs. Support of the extension SPV_EXT_relaxed_printf_string_address_space is to do (see: https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/EXT/SPV_EXT_relaxed_printf_string_address_space.asciidoc).

This PR also fixes target ext type constants and OpGroupAsyncCopy/OpGroupWaitEvents generation. A new test case is attached.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR is to improve builtins matching and type inference in SPIR-V Backend. The model test case is printf call from OpenCL.std that has several features allowing for a wider look at builtins support/type inference:
(1) call in a "spirv-friendly" style (prefixed by _spirv_ocl)
(2) restricted type of the 1st argument

Attached test cases checks several possible inputs. Support of the extension SPV_EXT_relaxed_printf_string_address_space is to do (see: https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/EXT/SPV_EXT_relaxed_printf_string_address_space.asciidoc).


Full diff: https://github.com/llvm/llvm-project/pull/89948.diff

3 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+5-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+64-23)
  • (added) llvm/test/CodeGen/SPIRV/printf.ll (+40)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 4b07d7e61fa113..64227ae062dfde 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -189,6 +189,10 @@ lookupBuiltin(StringRef DemangledCall,
   std::string BuiltinName =
       DemangledCall.substr(0, DemangledCall.find('(')).str();
 
+  // Account for possible "__spirv_ocl_" prefix in SPIR-V friendly LLVM IR
+  if (BuiltinName.rfind("__spirv_ocl_", 0) == 0)
+    BuiltinName = BuiltinName.substr(12);
+
   // Check if the extracted name contains type information between angle
   // brackets. If so, the builtin is an instantiated template - needs to have
   // the information after angle brackets and return type removed.
@@ -2306,7 +2310,7 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
     // parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
     // base types.
     if (TypeStr.ends_with("*"))
-      TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" "));
+      TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" *"));
 
     return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
                                                Ctx);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 472bc8638c9af1..204a662240e54f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -98,6 +98,8 @@ class SPIRVEmitIntrinsics
     return B.CreateIntrinsic(IntrID, {Types}, Args);
   }
 
+  void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);
+
   void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B);
   void processInstrAfterVisit(Instruction *I, IRBuilder<> &B);
   void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B);
@@ -111,6 +113,7 @@ class SPIRVEmitIntrinsics
   void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
   void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
   void processParamTypes(Function *F, IRBuilder<> &B);
+  void processParamTypesByFunHeader(Function *F, IRBuilder<> &B);
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
                                   std::unordered_set<Function *> &FVisited);
@@ -194,6 +197,17 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
                        false);
 }
 
+void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
+                                         Value *Arg) {
+  CallInst *AssignPtrTyCI =
+      buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Arg->getType()},
+                      Constant::getNullValue(ElemTy), Arg,
+                      {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
+  GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
+  GR->addDeducedElementType(Arg, ElemTy);
+  AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
+}
+
 // Set element pointer type to the given value of ValueTy and tries to
 // specify this type further (recursively) by Operand value, if needed.
 Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
@@ -232,6 +246,19 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
   return nullptr;
 }
 
+// Implements what we know in advance about intrinsics and builtin calls
+// TODO: consider feasibility of this particular case to be generalized by
+// encoding knowledge about intrinsics and builtin calls by corresponding
+// specification rules
+static Type *getPointeeTypeByCallInst(StringRef DemangledName,
+                                      Function *CalledF, unsigned OpIdx) {
+  if ((DemangledName.starts_with("__spirv_ocl_printf(") ||
+       DemangledName.starts_with("printf(")) &&
+      OpIdx == 0)
+    return IntegerType::getInt8Ty(CalledF->getContext());
+  return nullptr;
+}
+
 // Deduce and return a successfully deduced Type of the Instruction,
 // or nullptr otherwise.
 Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
@@ -795,6 +822,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
     return;
 
   // collect information about formal parameter types
+  std::string DemangledName =
+      getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
   Function *CalledF = CI->getCalledFunction();
   SmallVector<Type *, 4> CalledArgTys;
   bool HaveTypes = false;
@@ -811,10 +840,15 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
       if (!ElemTy && hasPointeeTypeAttr(CalledArg))
         ElemTy = getPointeeTypeByAttr(CalledArg);
       if (!ElemTy) {
-        for (User *U : CalledArg->users()) {
-          if (Instruction *Inst = dyn_cast<Instruction>(U)) {
-            if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
-              break;
+        ElemTy = getPointeeTypeByCallInst(DemangledName, CalledF, OpIdx);
+        if (ElemTy) {
+          GR->addDeducedElementType(CalledArg, ElemTy);
+        } else {
+          for (User *U : CalledArg->users()) {
+            if (Instruction *Inst = dyn_cast<Instruction>(U)) {
+              if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
+                break;
+            }
           }
         }
       }
@@ -823,8 +857,6 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
     }
   }
 
-  std::string DemangledName =
-      getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
   if (DemangledName.empty() && !HaveTypes)
     return;
 
@@ -835,8 +867,14 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
       continue;
 
     // Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs()
-    if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand))
-      continue;
+    if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand)) {
+      // However, we may have assumptions about the formal argument's type and
+      // may have a need to insert a ptr cast for the actual parameter of this
+      // call.
+      Argument *CalledArg = CalledF->getArg(OpIdx);
+      if (!GR->findDeducedElementType(CalledArg))
+        continue;
+    }
 
     Type *ExpectedType =
         OpIdx < CalledArgTys.size() ? CalledArgTys[OpIdx] : nullptr;
@@ -1179,28 +1217,29 @@ Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
   return nullptr;
 }
 
-void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+void SPIRVEmitIntrinsics::processParamTypesByFunHeader(Function *F,
+                                                       IRBuilder<> &B) {
   B.SetInsertPointPastAllocas(F);
   for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
     Argument *Arg = F->getArg(OpIdx);
     if (!isUntypedPointerTy(Arg->getType()))
       continue;
+    Type *ElemTy = GR->findDeducedElementType(Arg);
+    if (!ElemTy && hasPointeeTypeAttr(Arg) &&
+        (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr)
+      buildAssignPtr(B, ElemTy, Arg);
+  }
+}
 
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+  B.SetInsertPointPastAllocas(F);
+  for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
+    Argument *Arg = F->getArg(OpIdx);
+    if (!isUntypedPointerTy(Arg->getType()))
+      continue;
     Type *ElemTy = GR->findDeducedElementType(Arg);
-    if (!ElemTy) {
-      if (hasPointeeTypeAttr(Arg) &&
-          (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
-        GR->addDeducedElementType(Arg, ElemTy);
-      } else if ((ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
-        CallInst *AssignPtrTyCI = buildIntrWithMD(
-            Intrinsic::spv_assign_ptr_type, {Arg->getType()},
-            Constant::getNullValue(ElemTy), Arg,
-            {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
-        GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
-        GR->addDeducedElementType(Arg, ElemTy);
-        AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
-      }
-    }
+    if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr)
+      buildAssignPtr(B, ElemTy, Arg);
   }
 }
 
@@ -1217,6 +1256,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   AggrConstTypes.clear();
   AggrStores.clear();
 
+  processParamTypesByFunHeader(F, B);
+
   // StoreInst's operand type can be changed during the next transformations,
   // so we need to store it in the set. Also store already transformed types.
   for (auto &I : instructions(Func)) {
diff --git a/llvm/test/CodeGen/SPIRV/printf.ll b/llvm/test/CodeGen/SPIRV/printf.ll
new file mode 100644
index 00000000000000..483fc1f244e57c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/printf.ll
@@ -0,0 +1,40 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#ExtImport:]] = OpExtInstImport "OpenCL.std"
+; CHECK: %[[#Char:]] = OpTypeInt 8 0
+; CHECK: %[[#CharPtr:]] = OpTypePointer UniformConstant %[[#Char]]
+; CHECK: %[[#GV:]] = OpVariable %[[#]] UniformConstant %[[#]]
+; CHECK: OpFunction
+; CHECK: %[[#Arg1:]] = OpFunctionParameter
+; CHECK: %[[#Arg2:]] = OpFunctionParameter
+; CHECK: %[[#CastedGV:]] = OpBitcast %[[#CharPtr]] %[[#GV]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedGV]] %[[#ArgConst:]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedGV]] %[[#ArgConst]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#Arg1]] %[[#ArgConst:]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#Arg1]] %[[#ArgConst]]
+; CHECK-NEXT: %[[#CastedArg2:]] = OpBitcast %[[#CharPtr]] %[[#Arg2]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedArg2]] %[[#ArgConst]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedArg2]] %[[#ArgConst]]
+; CHECK: OpFunctionEnd
+
+%struct = type { [6 x i8] }
+
+@FmtStr = internal addrspace(2) constant [6 x i8] c"c=%c\0A\00", align 1
+
+define spir_kernel void @foo(ptr addrspace(2) %_arg_fmt1, ptr addrspace(2) byval(%struct) %_arg_fmt2) {
+entry:
+  %r1 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) @FmtStr, i8 signext 97)
+  %r2 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) @FmtStr, i8 signext 97)
+  %r3 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt1, i8 signext 97)
+  %r4 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt1, i8 signext 97)
+  %r5 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt2, i8 signext 97)
+  %r6 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt2, i8 signext 97)
+  ret void
+}
+
+declare dso_local spir_func i32 @_Z6printfPU3AS2Kcz(ptr addrspace(2), ...)
+declare dso_local spir_func i32 @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2), ...)

@VyacheslavLevytskyy VyacheslavLevytskyy changed the title [SPIRV] Improve builtins matching and type inference in SPIR-V Backend [SPIRV] Improve builtins matching and type inference in SPIR-V Backend, fix target ext type constants Apr 25, 2024
@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 1ed1ec9 into llvm:main Apr 26, 2024
4 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants