Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIRV] Add type inference of function parameters by call instances #85077

Merged
merged 4 commits into from
Mar 14, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR adds type inference of function parameters by call instances. Two use cases that demonstrate the problem are added.

@VyacheslavLevytskyy VyacheslavLevytskyy changed the title Add type inference of function parameters by call instances [SPIRV] Add type inference of function parameters by call instances Mar 13, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 13, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR adds type inference of function parameters by call instances. Two use cases that demonstrate the problem are added.


Full diff: https://github.com/llvm/llvm-project/pull/85077.diff

5 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+62)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+5)
  • (added) llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll (+28)
  • (added) llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll (+28)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index f1fbe2ba1bc416..77319f58ff4d97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -209,7 +209,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   // spv_assign_ptr_type intrinsic or otherwise use default pointer element
   // type.
   Argument *Arg = F.getArg(ArgIdx);
-  if (Arg->hasByValAttr() || Arg->hasByRefAttr()) {
+  if (HasPointeeTypeAttr(Arg)) {
     Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
                                              : Arg->getParamByRefType();
     SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index c5b901235402c1..e9099fac1c1a3c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -91,6 +91,7 @@ class SPIRVEmitIntrinsics
                                         IRBuilder<> &B);
   void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
   void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
+  void processParamTypes(Function *F, IRBuilder<> &B);
 
 public:
   static char ID;
@@ -794,6 +795,62 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
   }
 }
 
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+  DenseMap<unsigned, Argument *> Args;
+  unsigned i = 0;
+  for (Argument &Arg : F->args()) {
+    if (isUntypedPointerTy(Arg.getType()) &&
+        DeducedElTys.find(&Arg) == DeducedElTys.end() &&
+        !HasPointeeTypeAttr(&Arg))
+      Args[i++] = &Arg;
+  }
+  if (i == 0)
+    return;
+
+  // Args contains opaque pointers without element type definition
+  B.SetInsertPointPastAllocas(F);
+  std::unordered_set<Value *> Visited;
+  for (User *U : F->users()) {
+    CallInst *CI = dyn_cast<CallInst>(U);
+    if (!CI)
+      continue;
+    for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && Args.size() > 0;
+         OpIdx++) {
+      auto It = Args.find(OpIdx);
+      Argument *Arg = It == Args.end() ? nullptr : It->second;
+      if (!Arg)
+        continue;
+      Value *OpArg = CI->getArgOperand(OpIdx);
+      if (!isPointerTy(OpArg->getType()))
+        continue;
+      // maybe we already know the operand's element type
+      auto DeducedIt = DeducedElTys.find(OpArg);
+      Type *ElemTy = DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
+      if (!ElemTy) {
+        for (User *OpU : OpArg->users()) {
+          if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
+            Visited.clear();
+            ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
+            if (ElemTy)
+              break;
+          }
+        }
+      }
+      if (ElemTy) {
+        unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
+        CallInst *AssignPtrTyCI = buildIntrWithMD(
+            Intrinsic::spv_assign_ptr_type, {Arg->getType()},
+            Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
+        DeducedElTys[AssignPtrTyCI] = ElemTy;
+        DeducedElTys[Arg] = ElemTy;
+        Args.erase(It);
+      }
+    }
+    if (Args.size() == 0)
+      break;
+  }
+}
+
 bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   if (Func.isDeclaration())
     return false;
@@ -839,6 +896,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
       continue;
     processInstrAfterVisit(I, B);
   }
+
+  // check if function parameter types are set
+  if (!F->isIntrinsic())
+    processParamTypes(F, B);
+
   return true;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index d5ed501def9986..eb87349f0941c5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -126,5 +126,10 @@ inline unsigned getPointerAddressSpace(const Type *T) {
              : cast<TypedPointerType>(SubT)->getAddressSpace();
 }
 
+// Return true if the Argument is decorated with a pointee type
+inline bool HasPointeeTypeAttr(Argument *Arg) {
+  return Arg->hasByValAttr() || Arg->hasByRefAttr();
+}
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
new file mode 100644
index 00000000000000..3f8edfef78b03c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-SPIRV-DAG: OpName %[[ArgToDeduce:.*]] "unknown_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
+; CHECK-SPIRV-DAG: %[[Long:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[LongPtr:.*]] = OpTypePointer CrossWorkgroup %[[Long]]
+; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[LongPtr]]
+; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[ArgToDeduce]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[ArgToDeduce]]
+; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[LongPtr]]
+
+define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
+entry:
+  %elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
+  call spir_func void @foo(ptr addrspace(1) %unknown_type_ptr)
+  ret void
+}
+
+define void @foo(ptr addrspace(1) %known_type_ptr) {
+entry:
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
new file mode 100644
index 00000000000000..be8582f9226d5c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-SPIRV-DAG: OpName %[[ArgToDeduce:.*]] "unknown_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
+; CHECK-SPIRV-DAG: %[[Long:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[LongPtr:.*]] = OpTypePointer CrossWorkgroup %[[Long]]
+; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[LongPtr]]
+; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[ArgToDeduce]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[ArgToDeduce]]
+
+define void @foo(ptr addrspace(1) %known_type_ptr) {
+entry:
+  ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
+entry:
+  %elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
+  call spir_func void @foo(ptr addrspace(1) %unknown_type_ptr)
+  ret void
+}

Copy link

github-actions bot commented Mar 13, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the patch! LGTM! No regressions in OpenCL CTS. Two new tests pass SPIR-V validation: basic/progvar_prog_scope_misc, basic/vstore_private.

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit afec257 into llvm:main Mar 14, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants