-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPIRV] Add type inference of function parameters by call instances #85077
[SPIRV] Add type inference of function parameters by call instances #85077
Conversation
@llvm/pr-subscribers-backend-spir-v Author: Vyacheslav Levytskyy (VyacheslavLevytskyy) ChangesThis PR adds type inference of function parameters by call instances. Two use cases that demonstrate the problem are added. Full diff: https://github.com/llvm/llvm-project/pull/85077.diff 5 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index f1fbe2ba1bc416..77319f58ff4d97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -209,7 +209,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
// type.
Argument *Arg = F.getArg(ArgIdx);
- if (Arg->hasByValAttr() || Arg->hasByRefAttr()) {
+ if (HasPointeeTypeAttr(Arg)) {
Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
: Arg->getParamByRefType();
SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index c5b901235402c1..e9099fac1c1a3c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -91,6 +91,7 @@ class SPIRVEmitIntrinsics
IRBuilder<> &B);
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
+ void processParamTypes(Function *F, IRBuilder<> &B);
public:
static char ID;
@@ -794,6 +795,62 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
}
}
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+ DenseMap<unsigned, Argument *> Args;
+ unsigned i = 0;
+ for (Argument &Arg : F->args()) {
+ if (isUntypedPointerTy(Arg.getType()) &&
+ DeducedElTys.find(&Arg) == DeducedElTys.end() &&
+ !HasPointeeTypeAttr(&Arg))
+ Args[i++] = &Arg;
+ }
+ if (i == 0)
+ return;
+
+ // Args contains opaque pointers without element type definition
+ B.SetInsertPointPastAllocas(F);
+ std::unordered_set<Value *> Visited;
+ for (User *U : F->users()) {
+ CallInst *CI = dyn_cast<CallInst>(U);
+ if (!CI)
+ continue;
+ for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && Args.size() > 0;
+ OpIdx++) {
+ auto It = Args.find(OpIdx);
+ Argument *Arg = It == Args.end() ? nullptr : It->second;
+ if (!Arg)
+ continue;
+ Value *OpArg = CI->getArgOperand(OpIdx);
+ if (!isPointerTy(OpArg->getType()))
+ continue;
+ // maybe we already know the operand's element type
+ auto DeducedIt = DeducedElTys.find(OpArg);
+ Type *ElemTy = DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
+ if (!ElemTy) {
+ for (User *OpU : OpArg->users()) {
+ if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
+ Visited.clear();
+ ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
+ if (ElemTy)
+ break;
+ }
+ }
+ }
+ if (ElemTy) {
+ unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
+ CallInst *AssignPtrTyCI = buildIntrWithMD(
+ Intrinsic::spv_assign_ptr_type, {Arg->getType()},
+ Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
+ DeducedElTys[AssignPtrTyCI] = ElemTy;
+ DeducedElTys[Arg] = ElemTy;
+ Args.erase(It);
+ }
+ }
+ if (Args.size() == 0)
+ break;
+ }
+}
+
bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
if (Func.isDeclaration())
return false;
@@ -839,6 +896,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
continue;
processInstrAfterVisit(I, B);
}
+
+ // check if function parameter types are set
+ if (!F->isIntrinsic())
+ processParamTypes(F, B);
+
return true;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index d5ed501def9986..eb87349f0941c5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -126,5 +126,10 @@ inline unsigned getPointerAddressSpace(const Type *T) {
: cast<TypedPointerType>(SubT)->getAddressSpace();
}
+// Return true if the Argument is decorated with a pointee type
+inline bool HasPointeeTypeAttr(Argument *Arg) {
+ return Arg->hasByValAttr() || Arg->hasByRefAttr();
+}
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
new file mode 100644
index 00000000000000..3f8edfef78b03c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-SPIRV-DAG: OpName %[[ArgToDeduce:.*]] "unknown_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
+; CHECK-SPIRV-DAG: %[[Long:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[LongPtr:.*]] = OpTypePointer CrossWorkgroup %[[Long]]
+; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[LongPtr]]
+; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[ArgToDeduce]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[ArgToDeduce]]
+; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[LongPtr]]
+
+define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
+entry:
+ %elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
+ call spir_func void @foo(ptr addrspace(1) %unknown_type_ptr)
+ ret void
+}
+
+define void @foo(ptr addrspace(1) %known_type_ptr) {
+entry:
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
new file mode 100644
index 00000000000000..be8582f9226d5c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-SPIRV-DAG: OpName %[[ArgToDeduce:.*]] "unknown_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
+; CHECK-SPIRV-DAG: %[[Long:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[LongPtr:.*]] = OpTypePointer CrossWorkgroup %[[Long]]
+; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[LongPtr]]
+; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[ArgToDeduce]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[ArgToDeduce]]
+
+define void @foo(ptr addrspace(1) %known_type_ptr) {
+entry:
+ ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
+entry:
+ %elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
+ call spir_func void @foo(ptr addrspace(1) %unknown_type_ptr)
+ ret void
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the patch! LGTM! No regressions in OpenCL CTS. Two new tests pass SPIR-V validation: basic/progvar_prog_scope_misc, basic/vstore_private.
This PR adds type inference of function parameters by call instances. Two use cases that demonstrate the problem are added.