diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 9e11c3a281a1b..dd57b74d79a5e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -149,23 +149,23 @@ static FunctionType *getOriginalFunctionType(const Function &F) { return isa(N->getOperand(0)) && cast(N->getOperand(0))->getString() == F.getName(); }); - // TODO: probably one function can have numerous type mutations, - // so we should support this. if (ThisFuncMDIt != NamedMD->op_end()) { auto *ThisFuncMD = *ThisFuncMDIt; - MDNode *MD = dyn_cast(ThisFuncMD->getOperand(1)); - assert(MD && "MDNode operand is expected"); - ConstantInt *Const = getConstInt(MD, 0); - if (Const) { - auto *CMeta = dyn_cast(MD->getOperand(1)); - assert(CMeta && "ConstantAsMetadata operand is expected"); - assert(Const->getSExtValue() >= -1); - // Currently -1 indicates return value, greater values mean - // argument numbers. - if (Const->getSExtValue() == -1) - RetTy = CMeta->getType(); - else - ArgTypes[Const->getSExtValue()] = CMeta->getType(); + for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) { + MDNode *MD = dyn_cast(ThisFuncMD->getOperand(I)); + assert(MD && "MDNode operand is expected"); + ConstantInt *Const = getConstInt(MD, 0); + if (Const) { + auto *CMeta = dyn_cast(MD->getOperand(1)); + assert(CMeta && "ConstantAsMetadata operand is expected"); + assert(Const->getSExtValue() >= -1); + // Currently -1 indicates return value, greater values mean + // argument numbers. + if (Const->getSExtValue() == -1) + RetTy = CMeta->getType(); + else + ArgTypes[Const->getSExtValue()] = CMeta->getType(); + } } } diff --git a/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll b/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll index 73c46b18bfa78..c9b2968a4aed7 100644 --- a/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll +++ b/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll @@ -10,6 +10,7 @@ ; CHECK-DAG: %[[#Int8:]] = OpTypeInt 8 0 ; CHECK-DAG: %[[#Half:]] = OpTypeFloat 16 +; CHECK-DAG: %[[#Float:]] = OpTypeFloat 32 ; CHECK-DAG: %[[#Struct:]] = OpTypeStruct %[[#Half]] ; CHECK-DAG: %[[#Void:]] = OpTypeVoid ; CHECK-DAG: %[[#PtrInt8:]] = OpTypePointer CrossWorkgroup %[[#Int8:]] @@ -17,12 +18,20 @@ ; CHECK-DAG: %[[#Int64:]] = OpTypeInt 64 0 ; CHECK-DAG: %[[#PtrInt64:]] = OpTypePointer CrossWorkgroup %[[#Int64]] ; CHECK-DAG: %[[#BarType:]] = OpTypeFunction %[[#Void]] %[[#PtrInt64]] %[[#Struct]] +; CHECK-DAG: %[[#BazType:]] = OpTypeFunction %[[#Void]] %[[#PtrInt8]] %[[#Struct]] %[[#Int8]] %[[#Struct]] %[[#Float]] %[[#Struct]] ; CHECK: OpFunction %[[#Void]] None %[[#FooType]] ; CHECK: OpFunctionParameter %[[#PtrInt8]] ; CHECK: OpFunctionParameter %[[#Struct]] ; CHECK: OpFunction %[[#Void]] None %[[#BarType]] ; CHECK: OpFunctionParameter %[[#PtrInt64]] ; CHECK: OpFunctionParameter %[[#Struct]] +; CHECK: OpFunction %[[#Void]] None %[[#BazType]] +; CHECK: OpFunctionParameter %[[#PtrInt8]] +; CHECK: OpFunctionParameter %[[#Struct]] +; CHECK: OpFunctionParameter %[[#Int8]] +; CHECK: OpFunctionParameter %[[#Struct]] +; CHECK: OpFunctionParameter %[[#Float]] +; CHECK: OpFunctionParameter %[[#Struct]] %t_half = type { half } @@ -38,4 +47,9 @@ entry: ret void } +define spir_kernel void @baz(ptr addrspace(1) %a, %t_half %b, i8 %c, %t_half %d, float %e, %t_half %f) { +entry: + ret void +} + declare spir_func %t_half @_Z29__spirv_SpecConstantComposite(half)