diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index dd57b74d79a5e..b61b9e5d4fc7f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -131,47 +131,6 @@ fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F, return FunctionType::get(const_cast(RetTy), ArgTys, false); } -// This code restores function args/retvalue types for composite cases -// because the final types should still be aggregate whereas they're i32 -// during the translation to cope with aggregate flattening etc. -static FunctionType *getOriginalFunctionType(const Function &F) { - auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs"); - if (NamedMD == nullptr) - return F.getFunctionType(); - - Type *RetTy = F.getFunctionType()->getReturnType(); - SmallVector ArgTypes; - for (auto &Arg : F.args()) - ArgTypes.push_back(Arg.getType()); - - auto ThisFuncMDIt = - std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) { - return isa(N->getOperand(0)) && - cast(N->getOperand(0))->getString() == F.getName(); - }); - if (ThisFuncMDIt != NamedMD->op_end()) { - auto *ThisFuncMD = *ThisFuncMDIt; - for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) { - MDNode *MD = dyn_cast(ThisFuncMD->getOperand(I)); - assert(MD && "MDNode operand is expected"); - ConstantInt *Const = getConstInt(MD, 0); - if (Const) { - auto *CMeta = dyn_cast(MD->getOperand(1)); - assert(CMeta && "ConstantAsMetadata operand is expected"); - assert(Const->getSExtValue() >= -1); - // Currently -1 indicates return value, greater values mean - // argument numbers. - if (Const->getSExtValue() == -1) - RetTy = CMeta->getType(); - else - ArgTypes[Const->getSExtValue()] = CMeta->getType(); - } - } - } - - return FunctionType::get(RetTy, ArgTypes, F.isVarArg()); -} - static SPIRV::AccessQualifier::AccessQualifier getArgAccessQual(const Function &F, unsigned ArgIdx) { if (F.getCallingConv() != CallingConv::SPIR_KERNEL) @@ -204,7 +163,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx, SPIRV::AccessQualifier::AccessQualifier ArgAccessQual = getArgAccessQual(F, ArgIdx); - Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx); + Type *OriginalArgType = + SPIRV::getOriginalFunctionType(F)->getParamType(ArgIdx); // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot // be legally reassigned later). @@ -421,7 +381,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, auto MRI = MIRBuilder.getMRI(); Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass); - FunctionType *FTy = getOriginalFunctionType(F); + FunctionType *FTy = SPIRV::getOriginalFunctionType(F); Type *FRetTy = FTy->getReturnType(); if (isUntypedPointerTy(FRetTy)) { if (Type *FRetElemTy = GR->findDeducedElementType(&F)) { @@ -506,10 +466,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, // - add a topological sort of IndirectCalls to ensure the best types knowledge // - we may need to fix function formal parameter types if they are opaque // pointers used as function pointers in these indirect calls +// - defaulting to StorageClass::Function in the absence of the +// SPV_INTEL_function_pointers extension seems wrong, as that might not be +// able to hold a full width pointer to function, and it also does not model +// the semantics of a pointer to function in a generic fashion. void SPIRVCallLowering::produceIndirectPtrTypes( MachineIRBuilder &MIRBuilder) const { // Create indirect call data types if any MachineFunction &MF = MIRBuilder.getMF(); + const SPIRVSubtarget &ST = MF.getSubtarget(); for (auto const &IC : IndirectCalls) { SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType( IC.RetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true); @@ -527,8 +492,11 @@ void SPIRVCallLowering::produceIndirectPtrTypes( SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs( FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder); // SPIR-V pointer to function type: - SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType( - SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function); + auto SC = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers) + ? SPIRV::StorageClass::CodeSectionINTEL + : SPIRV::StorageClass::Function; + SPIRVType *IndirectFuncPtrTy = + GR->getOrCreateSPIRVPointerType(SpirvFuncTy, MIRBuilder, SC); // Correct the Callee type GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF); } @@ -556,12 +524,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, // TODO: support constexpr casts and indirect calls. if (CF == nullptr) return false; - if (FunctionType *FTy = getOriginalFunctionType(*CF)) { - OrigRetTy = FTy->getReturnType(); - if (isUntypedPointerTy(OrigRetTy)) { - if (auto *DerivedRetTy = GR->findReturnType(CF)) - OrigRetTy = DerivedRetTy; - } + + FunctionType *FTy = SPIRV::getOriginalFunctionType(*CF); + OrigRetTy = FTy->getReturnType(); + if (isUntypedPointerTy(OrigRetTy)) { + if (auto *DerivedRetTy = GR->findReturnType(CF)) + OrigRetTy = DerivedRetTy; } } @@ -683,11 +651,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, if (CalleeReg.isValid()) { SPIRVCallLowering::SPIRVIndirectCall IndirectCall; IndirectCall.Callee = CalleeReg; - IndirectCall.RetTy = OrigRetTy; - for (const auto &Arg : Info.OrigArgs) { - assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); - IndirectCall.ArgTys.push_back(Arg.Ty); - IndirectCall.ArgRegs.push_back(Arg.Regs[0]); + FunctionType *FTy = SPIRV::getOriginalFunctionType(*Info.CB); + IndirectCall.RetTy = OrigRetTy = FTy->getReturnType(); + assert(FTy->getNumParams() == Info.OrigArgs.size() && + "Function types mismatch"); + for (unsigned I = 0; I != Info.OrigArgs.size(); ++I) { + assert(Info.OrigArgs[I].Regs.size() == 1 && + "Call arg has multiple VRegs"); + IndirectCall.ArgTys.push_back(FTy->getParamType(I)); + IndirectCall.ArgRegs.push_back(Info.OrigArgs[I].Regs[0]); } IndirectCalls.push_back(IndirectCall); } diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 8e14fb03127fc..72b0e031a4a3f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -360,6 +360,17 @@ static void emitAssignName(Instruction *I, IRBuilder<> &B) { if (!I->hasName() || I->getType()->isAggregateType() || expectIgnoredInIRTranslation(I)) return; + + if (isa(I)) { + // TODO: this is a temporary workaround meant to prevent inserting internal + // noise into the generated binary; remove once we rework the entire + // aggregate removal machinery. + StringRef Name = I->getName(); + if (Name.starts_with("spv.mutated_callsite")) + return; + if (Name.starts_with("spv.named_mutated_callsite")) + I->setName(Name.substr(Name.rfind('.') + 1)); + } reportFatalOnTokenType(I); setInsertPointAfterDef(B, I); LLVMContext &Ctx = I->getContext(); @@ -759,10 +770,15 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( if (Type *ElemTy = getPointeeType(KnownTy)) maybeAssignPtrType(Ty, I, ElemTy, UnknownElemTypeI8); } else if (auto *Ref = dyn_cast(I)) { - Ty = deduceElementTypeByValueDeep( - Ref->getValueType(), - Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited, - UnknownElemTypeI8); + if (auto *Fn = dyn_cast(Ref)) { + Ty = SPIRV::getOriginalFunctionType(*Fn); + GR->addDeducedElementType(I, Ty); + } else { + Ty = deduceElementTypeByValueDeep( + Ref->getValueType(), + Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited, + UnknownElemTypeI8); + } } else if (auto *Ref = dyn_cast(I)) { Type *RefTy = deduceElementTypeHelper(Ref->getPointerOperand(), Visited, UnknownElemTypeI8); @@ -1062,10 +1078,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer( if (!Op || !isPointerTy(Op->getType())) return; Ops.push_back(std::make_pair(Op, std::numeric_limits::max())); - FunctionType *FTy = CI->getFunctionType(); + FunctionType *FTy = SPIRV::getOriginalFunctionType(*CI); bool IsNewFTy = false, IsIncomplete = false; SmallVector ArgTys; - for (Value *Arg : CI->args()) { + for (auto &&[ParmIdx, Arg] : llvm::enumerate(CI->args())) { Type *ArgTy = Arg->getType(); if (ArgTy->isPointerTy()) { if (Type *ElemTy = GR->findDeducedElementType(Arg)) { @@ -1076,6 +1092,8 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer( } else { IsIncomplete = true; } + } else { + ArgTy = FTy->getFunctionParamType(ParmIdx); } ArgTys.push_back(ArgTy); } diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index 09c77f0cfd4f5..16f3260bf4ffc 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -214,6 +214,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { if (Value *GlobalElem = Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr) ElementTy = findDeducedCompositeType(GlobalElem); + else if (const Function *Fn = dyn_cast(Global)) + ElementTy = SPIRV::getOriginalFunctionType(*Fn); } return ElementTy ? ElementTy : Global->getValueType(); } diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 0f4b3d59b904a..f3c886aafc131 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -257,9 +257,12 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, Register Def = MI.getOperand(0).getReg(); Register Source = MI.getOperand(2).getReg(); Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0); - SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType( - ElemTy, MI, - addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST)); + auto SC = + isa(ElemTy) + ? SPIRV::StorageClass::CodeSectionINTEL + : addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST); + SPIRVType *AssignedPtrType = + GR->getOrCreateSPIRVPointerType(ElemTy, MI, SC); // If the ptrcast would be redundant, replace all uses with the source // register. diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp index be88f334d2171..fdd0af871e03e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp @@ -26,6 +26,8 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/CodeGen/IntrinsicLowering.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsSPIRV.h" @@ -41,6 +43,7 @@ class SPIRVPrepareFunctions : public ModulePass { const SPIRVTargetMachine &TM; bool substituteIntrinsicCalls(Function *F); Function *removeAggregateTypesFromSignature(Function *F); + bool removeAggregateTypesFromCalls(Function *F); public: static char ID; @@ -469,6 +472,23 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { return Changed; } +static void +addFunctionTypeMutation(NamedMDNode *NMD, + SmallVector> ChangedTys, + StringRef Name) { + + LLVMContext &Ctx = NMD->getParent()->getContext(); + Type *I32Ty = IntegerType::getInt32Ty(Ctx); + + SmallVector MDArgs; + MDArgs.push_back(MDString::get(Ctx, Name)); + transform(ChangedTys, std::back_inserter(MDArgs), [=, &Ctx](auto &&CTy) { + return MDNode::get( + Ctx, {ConstantAsMetadata::get(ConstantInt::get(I32Ty, CTy.first, true)), + ValueAsMetadata::get(Constant::getNullValue(CTy.second))}); + }); + NMD->addOperand(MDNode::get(Ctx, MDArgs)); +} // Returns F if aggregate argument/return types are not present or cloned F // function with the types replaced by i32 types. The change in types is // noted in 'spv.cloned_funcs' metadata for later restoration. @@ -503,7 +523,8 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) { FunctionType *NewFTy = FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg()); Function *NewF = - Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent()); + Function::Create(NewFTy, F->getLinkage(), F->getAddressSpace(), + F->getName(), F->getParent()); ValueToValueMapTy VMap; auto NewFArgIt = NewF->arg_begin(); @@ -518,22 +539,18 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) { Returns); NewF->takeName(F); - NamedMDNode *FuncMD = - F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"); - SmallVector MDArgs; - MDArgs.push_back(MDString::get(B.getContext(), NewF->getName())); - for (auto &ChangedTyP : ChangedTypes) - MDArgs.push_back(MDNode::get( - B.getContext(), - {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)), - ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))})); - MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs); - FuncMD->addOperand(ThisFuncMD); + addFunctionTypeMutation( + NewF->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"), + std::move(ChangedTypes), NewF->getName()); for (auto *U : make_early_inc_range(F->users())) { - if (auto *CI = dyn_cast(U)) + if (CallInst *CI; + (CI = dyn_cast(U)) && CI->getCalledFunction() == F) CI->mutateFunctionType(NewF->getFunctionType()); - U->replaceUsesOfWith(F, NewF); + if (auto *C = dyn_cast(U)) + C->handleOperandChange(F, NewF); + else + U->replaceUsesOfWith(F, NewF); } // register the mutation @@ -543,11 +560,78 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) { return NewF; } +// Mutates indirect callsites iff if aggregate argument/return types are present +// with the types replaced by i32 types. The change in types is noted in +// 'spv.mutated_callsites' metadata for later restoration. +bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) { + if (F->isDeclaration() || F->isIntrinsic()) + return false; + + SmallVector> Calls; + for (auto &&I : instructions(F)) { + if (auto *CB = dyn_cast(&I)) { + if (!CB->getCalledOperand() || CB->getCalledFunction()) + continue; + if (CB->getType()->isAggregateType() || + any_of(CB->args(), + [](auto &&Arg) { return Arg->getType()->isAggregateType(); })) + Calls.emplace_back(CB, nullptr); + } + } + + if (Calls.empty()) + return false; + + IRBuilder<> B(F->getContext()); + + for (auto &&[CB, NewFnTy] : Calls) { + SmallVector> ChangedTypes; + SmallVector NewArgTypes; + + Type *RetTy = CB->getType(); + if (RetTy->isAggregateType()) { + ChangedTypes.emplace_back(-1, RetTy); + RetTy = B.getInt32Ty(); + } + + for (auto &&Arg : CB->args()) { + if (Arg->getType()->isAggregateType()) { + NewArgTypes.push_back(B.getInt32Ty()); + ChangedTypes.emplace_back(Arg.getOperandNo(), Arg->getType()); + } else { + NewArgTypes.push_back(Arg->getType()); + } + } + NewFnTy = FunctionType::get(RetTy, NewArgTypes, + CB->getFunctionType()->isVarArg()); + + if (!CB->hasName()) + CB->setName("spv.mutated_callsite." + F->getName()); + else + CB->setName("spv.named_mutated_callsite." + F->getName() + "." + + CB->getName()); + + addFunctionTypeMutation( + F->getParent()->getOrInsertNamedMetadata("spv.mutated_callsites"), + std::move(ChangedTypes), CB->getName()); + } + + for (auto &&[CB, NewFTy] : Calls) { + if (NewFTy->getReturnType() != CB->getType()) + TM.getSubtarget(*F).getSPIRVGlobalRegistry()->addMutated( + CB, CB->getType()); + CB->mutateFunctionType(NewFTy); + } + + return true; +} + bool SPIRVPrepareFunctions::runOnModule(Module &M) { bool Changed = false; for (Function &F : M) { Changed |= substituteIntrinsicCalls(&F); Changed |= sortBlocks(F); + Changed |= removeAggregateTypesFromCalls(&F); } std::vector FuncsWorklist; diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 8f2fc01da476f..5757e2a382b85 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -28,6 +28,69 @@ #include namespace llvm { +namespace SPIRV { +// This code restores function args/retvalue types for composite cases +// because the final types should still be aggregate whereas they're i32 +// during the translation to cope with aggregate flattening etc. +// TODO: should these just return nullptr when there's no metadata? +static FunctionType *extractFunctionTypeFromMetadata(NamedMDNode *NMD, + FunctionType *FTy, + StringRef Name) { + if (!NMD) + return FTy; + + constexpr auto getConstInt = [](MDNode *MD, unsigned OpId) -> ConstantInt * { + if (MD->getNumOperands() <= OpId) + return nullptr; + if (auto *CMeta = dyn_cast(MD->getOperand(OpId))) + return dyn_cast(CMeta->getValue()); + return nullptr; + }; + + auto It = find_if(NMD->operands(), [Name](MDNode *N) { + if (auto *MDS = dyn_cast_or_null(N->getOperand(0))) + return MDS->getString() == Name; + return false; + }); + + if (It == NMD->op_end()) + return FTy; + + Type *RetTy = FTy->getReturnType(); + SmallVector PTys(FTy->params()); + + for (unsigned I = 1; I != (*It)->getNumOperands(); ++I) { + MDNode *MD = dyn_cast((*It)->getOperand(I)); + assert(MD && "MDNode operand is expected"); + + if (auto *Const = getConstInt(MD, 0)) { + auto *CMeta = dyn_cast(MD->getOperand(1)); + assert(CMeta && "ConstantAsMetadata operand is expected"); + assert(Const->getSExtValue() >= -1); + // Currently -1 indicates return value, greater values mean + // argument numbers. + if (Const->getSExtValue() == -1) + RetTy = CMeta->getType(); + else + PTys[Const->getSExtValue()] = CMeta->getType(); + } + } + + return FunctionType::get(RetTy, PTys, FTy->isVarArg()); +} + +FunctionType *getOriginalFunctionType(const Function &F) { + return extractFunctionTypeFromMetadata( + F.getParent()->getNamedMetadata("spv.cloned_funcs"), F.getFunctionType(), + F.getName()); +} + +FunctionType *getOriginalFunctionType(const CallBase &CB) { + return extractFunctionTypeFromMetadata( + CB.getModule()->getNamedMetadata("spv.mutated_callsites"), + CB.getFunctionType(), CB.getName()); +} +} // Namespace SPIRV // The following functions are used to add these string literals as a series of // 32-bit integer operands with the correct format, and unpack them if necessary diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 99d9d403ea70c..3da77f1d6c1ec 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -159,6 +159,11 @@ struct FPFastMathDefaultInfoVector } }; +// This code restores function args/retvalue types for composite cases +// because the final types should still be aggregate whereas they're i32 +// during the translation to cope with aggregate flattening etc. +FunctionType *getOriginalFunctionType(const Function &F); +FunctionType *getOriginalFunctionType(const CallBase &CB); } // namespace SPIRV // Add the given string as a series of integer operand, inserting null diff --git a/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll b/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll new file mode 100644 index 0000000000000..ec3fd41f7de9e --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/fun-with-aggregate-arg-in-const-init.ll @@ -0,0 +1,107 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s +; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - -filetype=obj | spirv-val %} + +; CHECK: OpCapability Kernel +; CHECK-DAG: OpCapability FunctionPointersINTEL +; CHECK-DAG: OpExtension "SPV_INTEL_function_pointers" +; CHECK-DAG: OpName %[[#fArray:]] "array" +; CHECK-DAG: OpName %[[#fStruct:]] "struct" + +; CHECK-DAG: %[[#Int8Ty:]] = OpTypeInt 8 0 +; CHECK: %[[#GlobalInt8PtrTy:]] = OpTypePointer CrossWorkgroup %[[#Int8Ty]] +; CHECK: %[[#VoidTy:]] = OpTypeVoid +; CHECK: %[[#TestFnTy:]] = OpTypeFunction %[[#VoidTy]] %[[#GlobalInt8PtrTy]] +; CHECK: %[[#F16Ty:]] = OpTypeFloat 16 +; CHECK: %[[#t_halfTy:]] = OpTypeStruct %[[#F16Ty]] +; CHECK: %[[#FnTy:]] = OpTypeFunction %[[#t_halfTy]] %[[#GlobalInt8PtrTy]] %[[#t_halfTy]] +; CHECK: %[[#IntelFnPtrTy:]] = OpTypePointer CodeSectionINTEL %[[#FnTy]] +; CHECK: %[[#Int8PtrTy:]] = OpTypePointer Function %[[#Int8Ty]] +; CHECK: %[[#Int32Ty:]] = OpTypeInt 32 0 +; CHECK: %[[#I32Const3:]] = OpConstant %[[#Int32Ty]] 3 +; CHECK: %[[#FnArrTy:]] = OpTypeArray %[[#Int8PtrTy]] %[[#I32Const3]] +; CHECK: %[[#GlobalFnArrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#FnArrTy]] +; CHECK: %[[#GlobalFnPtrTy:]] = OpTypePointer CrossWorkgroup %[[#FnTy]] +; CHECK: %[[#FnPtrTy:]] = OpTypePointer Function %[[#FnTy]] +; CHECK: %[[#StructWithPfnTy:]] = OpTypeStruct %[[#FnPtrTy]] %[[#FnPtrTy]] %[[#FnPtrTy]] +; CHECK: %[[#ArrayOfPfnTy:]] = OpTypeArray %[[#FnPtrTy]] %[[#I32Const3]] +; CHECK: %[[#Int64Ty:]] = OpTypeInt 64 0 +; CHECK: %[[#GlobalStructWithPfnPtrTy:]] = OpTypePointer CrossWorkgroup %[[#StructWithPfnTy]] +; CHECK: %[[#GlobalArrOfPfnPtrTy:]] = OpTypePointer CrossWorkgroup %[[#ArrayOfPfnTy]] +; CHECK: %[[#I64Const2:]] = OpConstant %[[#Int64Ty]] 2 +; CHECK: %[[#I64Const1:]] = OpConstant %[[#Int64Ty]] 1 +; CHECK: %[[#I64Const0:]] = OpConstantNull %[[#Int64Ty]] +; CHECK: %[[#f0Pfn:]] = OpConstantFunctionPointerINTEL %[[#IntelFnPtrTy]] %28 +; CHECK: %[[#f1Pfn:]] = OpConstantFunctionPointerINTEL %[[#IntelFnPtrTy]] %32 +; CHECK: %[[#f2Pfn:]] = OpConstantFunctionPointerINTEL %[[#IntelFnPtrTy]] %36 +; CHECK: %[[#f0Cast:]] = OpSpecConstantOp %[[#FnPtrTy]] Bitcast %[[#f0Pfn]] +; CHECK: %[[#f1Cast:]] = OpSpecConstantOp %[[#FnPtrTy]] Bitcast %[[#f1Pfn]] +; CHECK: %[[#f2Cast:]] = OpSpecConstantOp %[[#FnPtrTy]] Bitcast %[[#f2Pfn]] +; CHECK: %[[#fnptrTy:]] = OpConstantComposite %[[#ArrayOfPfnTy]] %[[#f0Cast]] %[[#f1Cast]] %[[#f2Cast]] +; CHECK: %[[#fnptr:]] = OpVariable %[[#GlobalArrOfPfnPtrTy]] CrossWorkgroup %[[#fnptrTy]] +; CHECK: %[[#fnstructTy:]] = OpConstantComposite %[[#StructWithPfnTy]] %[[#f0Cast]] %[[#f1Cast]] %[[#f2Cast]] +; CHECK: %[[#fnstruct:]] = OpVariable %[[#GlobalStructWithPfnPtrTy:]] CrossWorkgroup %[[#fnstructTy]] +; CHECK-DAG: %[[#GlobalInt8PtrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#Int8PtrTy]] +; CHECK: %[[#StructWithPtrTy:]] = OpTypeStruct %[[#Int8PtrTy]] %[[#Int8PtrTy]] %[[#Int8PtrTy]] +; CHECK: %[[#GlobalStructWithPtrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#StructWithPtrTy]] +; CHECK: %[[#I32Const2:]] = OpConstant %[[#Int32Ty]] 2 +; CHECK: %[[#I32Const1:]] = OpConstant %[[#Int32Ty]] 1 +; CHECK: %[[#I32Const0:]] = OpConstantNull %[[#Int32Ty]] +; CHECK: %[[#GlobalFnPtrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#FnPtrTy]] +%t_half = type { half } +%struct.anon = type { ptr, ptr, ptr } + +declare spir_func %t_half @f0(ptr addrspace(1) %a, %t_half %b) +declare spir_func %t_half @f1(ptr addrspace(1) %a, %t_half %b) +declare spir_func %t_half @f2(ptr addrspace(1) %a, %t_half %b) + +@fnptr = addrspace(1) constant [3 x ptr] [ptr @f0, ptr @f1, ptr @f2] +@fnstruct = addrspace(1) constant %struct.anon { ptr @f0, ptr @f1, ptr @f2 }, align 8 + +; CHECK-DAG: %[[#fArray]] = OpFunction %[[#VoidTy]] None %[[#TestFnTy]] +; CHECK-DAG: %[[#fnptrCast:]] = OpBitcast %[[#GlobalFnArrPtrTy]] %[[#fnptr]] +; CHECK: %[[#f0GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalFnArrPtrTy]] %[[#fnptrCast]] %[[#I64Const0]] +; CHECK: %[[#f0GEPCast:]] = OpBitcast %[[#GlobalFnPtrTy]] %[[#f0GEP]] +; CHECK: %[[#f1GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalFnArrPtrTy]] %[[#fnptrCast]] %[[#I64Const1]] +; CHECK: %[[#f1GEPCast:]] = OpBitcast %[[#GlobalFnPtrTy]] %[[#f1GEP]] +; CHECK: %[[#f2GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalFnArrPtrTy]] %[[#fnptrCast]] %[[#I64Const2]] +; CHECK: %[[#f2GEPCast:]] = OpBitcast %[[#GlobalFnPtrTy]] %[[#f2GEP]] +; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f0GEPCast]] +; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f1GEPCast]] +; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f2GEPCast]] +define spir_func void @array(ptr addrspace(1) %p) { +entry: + %f = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 0 + %g = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 1 + %h = getelementptr inbounds [3 x ptr], ptr addrspace(1) @fnptr, i64 2 + %0 = call spir_func addrspace(1) %t_half %f(ptr addrspace(1) %p, %t_half poison) + %1 = call spir_func addrspace(1) %t_half %g(ptr addrspace(1) %p, %t_half %0) + %2 = call spir_func addrspace(1) %t_half %h(ptr addrspace(1) %p, %t_half %1) + + ret void +} + +; CHECK-DAG: %[[#fStruct]] = OpFunction %[[#VoidTy]] None %[[#TestFnTy]] +; CHECK-DAG: %[[#fnStructCast0:]] = OpBitcast %[[#GlobalInt8PtrPtrTy]] %[[#fnstruct]] +; CHECK: %[[#fnStructCast1:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#fnStructCast0]] +; CHECK: %[[#f0Load:]] = OpLoad %[[#FnPtrTy]] %[[#fnStructCast1]] +; CHECK: %[[#fnStructCast2:]] = OpBitcast %[[#GlobalStructWithPtrPtrTy]] %[[#fnstruct]] +; CHECK: %[[#f1GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalInt8PtrPtrTy]] %[[#fnStructCast2]] %[[#I32Const0]] %[[#I32Const1]] +; CHECK: %[[#f1GEPCast:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#f1GEP]] +; CHECK: %[[#f1Load:]] = OpLoad %[[#FnPtrTy]] %[[#f1GEPCast]] +; CHECK: %[[#f2GEP:]] = OpInBoundsPtrAccessChain %[[#GlobalInt8PtrPtrTy]] %[[#fnStructCast2]] %[[#I32Const0]] %[[#I32Const2]] +; CHECK: %[[#f2GEPCast:]] = OpBitcast %[[#GlobalFnPtrPtrTy]] %[[#f2GEP]] +; CHECK: %[[#f2Load:]] = OpLoad %[[#FnPtrTy]] %[[#f2GEPCast]] +; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f0Load]] +; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f1Load]] +; CHECK: %{{.*}} = OpFunctionPointerCallINTEL %[[#t_halfTy]] %[[#f2Load]] +define spir_func void @struct(ptr addrspace(1) %p) { +entry: + %f = load ptr, ptr addrspace(1) @fnstruct + %g = load ptr, ptr addrspace(1) getelementptr inbounds (%struct.anon, ptr addrspace(1) @fnstruct, i32 0, i32 1) + %h = load ptr, ptr addrspace(1) getelementptr inbounds (%struct.anon, ptr addrspace(1) @fnstruct, i32 0, i32 2) + %0 = call spir_func noundef %t_half %f(ptr addrspace(1) %p, %t_half poison) + %1 = call spir_func noundef %t_half %g(ptr addrspace(1) %p, %t_half %0) + %2 = call spir_func noundef %t_half %h(ptr addrspace(1) %p, %t_half %1) + + ret void +}