diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 97b25147ffb34..8ac498e1556be 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -34,6 +34,12 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, const Value *Val, ArrayRef VRegs, FunctionLoweringInfo &FLI, Register SwiftErrorVReg) const { + // Maybe run postponed production of types for function pointers + if (IndirectCalls.size() > 0) { + produceIndirectPtrTypes(MIRBuilder); + IndirectCalls.clear(); + } + // Currently all return types should use a single register. // TODO: handle the case of multiple registers. if (VRegs.size() > 1) @@ -292,7 +298,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, } } - // Generate a SPIR-V type for the function. auto MRI = MIRBuilder.getMRI(); Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); @@ -301,17 +306,17 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder); SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs( FTy, RetTy, ArgTypeVRegs, MIRBuilder); - - // Build the OpTypeFunction declaring it. uint32_t FuncControl = getFunctionControl(F); - MIRBuilder.buildInstr(SPIRV::OpFunction) - .addDef(FuncVReg) - .addUse(GR->getSPIRVTypeID(RetTy)) - .addImm(FuncControl) - .addUse(GR->getSPIRVTypeID(FuncTy)); + // Add OpFunction instruction + MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction) + .addDef(FuncVReg) + .addUse(GR->getSPIRVTypeID(RetTy)) + .addImm(FuncControl) + .addUse(GR->getSPIRVTypeID(FuncTy)); + GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0)); - // Add OpFunctionParameters. + // Add OpFunctionParameter instructions int i = 0; for (const auto &Arg : F.args()) { assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); @@ -343,9 +348,56 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, {static_cast(LnkTy)}, F.getGlobalIdentifier()); } + // Handle function pointers decoration + const auto *ST = + static_cast(&MIRBuilder.getMF().getSubtarget()); + bool hasFunctionPointers = + ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers); + if (hasFunctionPointers) { + if (F.hasFnAttribute("referenced-indirectly")) { + assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) && + "Unexpected 'referenced-indirectly' attribute of the kernel " + "function"); + buildOpDecorate(FuncVReg, MIRBuilder, + SPIRV::Decoration::ReferencedIndirectlyINTEL, {}); + } + } + return true; } +// Used to postpone producing of indirect function pointer types after all +// indirect calls info is collected +// TODO: +// - add a topological sort of IndirectCalls to ensure the best types knowledge +// - we may need to fix function formal parameter types if they are opaque +// pointers used as function pointers in these indirect calls +void SPIRVCallLowering::produceIndirectPtrTypes( + MachineIRBuilder &MIRBuilder) const { + // Create indirect call data types if any + MachineFunction &MF = MIRBuilder.getMF(); + for (auto const &IC : IndirectCalls) { + SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder); + SmallVector SpirvArgTypes; + for (size_t i = 0; i < IC.ArgTys.size(); ++i) { + SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder); + SpirvArgTypes.push_back(SPIRVTy); + if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i])) + GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF); + } + // SPIR-V function type: + FunctionType *FTy = + FunctionType::get(const_cast(IC.RetTy), IC.ArgTys, false); + SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs( + FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder); + // SPIR-V pointer to function type: + SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType( + SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function); + // Correct the Calee type + GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF); + } +} + bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, CallLoweringInfo &Info) const { // Currently call returns should have single vregs. @@ -356,45 +408,44 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, GR->setCurrentFunc(MF); FunctionType *FTy = nullptr; const Function *CF = nullptr; + std::string DemangledName; + const Type *OrigRetTy = Info.OrigRet.Ty; // Emit a regular OpFunctionCall. If it's an externally declared function, // be sure to emit its type and function declaration here. It will be hoisted // globally later. if (Info.Callee.isGlobal()) { + std::string FuncName = Info.Callee.getGlobal()->getName().str(); + DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName); CF = dyn_cast_or_null(Info.Callee.getGlobal()); // TODO: support constexpr casts and indirect calls. if (CF == nullptr) return false; - FTy = getOriginalFunctionType(*CF); + if ((FTy = getOriginalFunctionType(*CF)) != nullptr) + OrigRetTy = FTy->getReturnType(); } MachineRegisterInfo *MRI = MIRBuilder.getMRI(); Register ResVReg = Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; - std::string FuncName = Info.Callee.getGlobal()->getName().str(); - std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName); const auto *ST = static_cast(&MF.getSubtarget()); // TODO: check that it's OCL builtin, then apply OpenCL_std. if (!DemangledName.empty() && CF && CF->isDeclaration() && ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { - const Type *OrigRetTy = Info.OrigRet.Ty; - if (FTy) - OrigRetTy = FTy->getReturnType(); SmallVector ArgVRegs; for (auto Arg : Info.OrigArgs) { assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); ArgVRegs.push_back(Arg.Regs[0]); SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder); if (!GR->getSPIRVTypeForVReg(Arg.Regs[0])) - GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF()); + GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF); } if (auto Res = SPIRV::lowerBuiltin( DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR)) return *Res; } - if (CF && CF->isDeclaration() && - !GR->find(CF, &MIRBuilder.getMF()).isValid()) { + if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) { // Emit the type info and forward function declaration to the first MBB // to ensure VReg definition dependencies are valid across all MBBs. MachineIRBuilder FirstBlockBuilder; @@ -416,14 +467,40 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo); } + unsigned CallOp; + if (Info.CB->isIndirectCall()) { + if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) + report_fatal_error("An indirect call is encountered but SPIR-V without " + "extensions does not support it", + false); + // Set instruction operation according to SPV_INTEL_function_pointers + CallOp = SPIRV::OpFunctionPointerCallINTEL; + // Collect information about the indirect call to support possible + // specification of opaque ptr types of parent function's parameters + Register CalleeReg = Info.Callee.getReg(); + if (CalleeReg.isValid()) { + SPIRVCallLowering::SPIRVIndirectCall IndirectCall; + IndirectCall.Callee = CalleeReg; + IndirectCall.RetTy = OrigRetTy; + for (const auto &Arg : Info.OrigArgs) { + assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); + IndirectCall.ArgTys.push_back(Arg.Ty); + IndirectCall.ArgRegs.push_back(Arg.Regs[0]); + } + IndirectCalls.push_back(IndirectCall); + } + } else { + // Emit a regular OpFunctionCall + CallOp = SPIRV::OpFunctionCall; + } + // Make sure there's a valid return reg, even for functions returning void. if (!ResVReg.isValid()) ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); - SPIRVType *RetType = - GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder); + SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder); - // Emit the OpFunctionCall and its args. - auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall) + // Emit the call instruction and its args. + auto MIB = MIRBuilder.buildInstr(CallOp) .addDef(ResVReg) .addUse(GR->getSPIRVTypeID(RetType)) .add(Info.Callee); diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h index c2d6ad82d507d..48b3a5eb2671c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h @@ -26,6 +26,17 @@ class SPIRVCallLowering : public CallLowering { // Used to create and assign function, argument, and return type information. SPIRVGlobalRegistry *GR; + // Used to postpone producing of indirect function pointer types + // after all indirect calls info is collected + struct SPIRVIndirectCall { + const Type *RetTy = nullptr; + SmallVector ArgTys; + SmallVector ArgRegs; + Register Callee; + }; + void produceIndirectPtrTypes(MachineIRBuilder &MIRBuilder) const; + mutable SmallVector IndirectCalls; + public: SPIRVCallLowering(const SPIRVTargetLowering &TLI, SPIRVGlobalRegistry *GR); diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp index cbe1a53fd7568..d82fb2df4539a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp @@ -54,7 +54,14 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph( MachineOperand &Op = MI->getOperand(i); if (!Op.isReg()) continue; - MachineOperand *RegOp = &MRI.getVRegDef(Op.getReg())->getOperand(0); + MachineInstr *VRegDef = MRI.getVRegDef(Op.getReg()); + // References to a function via function pointers generate virtual + // registers without a definition. We are able to resolve this + // reference using Globar Register info into an OpFunction instruction + // but do not expect to find it in Reg2Entry. + if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL && i == 2) + continue; + MachineOperand *RegOp = &VRegDef->getOperand(0); assert((MI->getOpcode() == SPIRV::OpVariable && i == 3) || Reg2Entry.count(RegOp)); if (Reg2Entry.count(RegOp)) diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index f3280928c25df..792a00786f0aa 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -38,6 +38,12 @@ class SPIRVGlobalRegistry { DenseMap SPIRVToLLVMType; + // map a Function to its definition (as a machine instruction operand) + DenseMap FunctionToInstr; + // map function pointer (as a machine instruction operand) to the used + // Function + DenseMap InstrToFunction; + // Look for an equivalent of the newType in the map. Return the equivalent // if it's found, otherwise insert newType to the map and return the type. const MachineInstr *checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD, @@ -101,6 +107,29 @@ class SPIRVGlobalRegistry { DT.buildDepsGraph(Graph, MMI); } + // Map a machine operand that represents a use of a function via function + // pointer to a machine operand that represents the function definition. + // Return either the register or invalid value, because we have no context for + // a good diagnostic message in case of unexpectedly missing references. + const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) { + auto ResF = InstrToFunction.find(Use); + if (ResF == InstrToFunction.end()) + return nullptr; + auto ResReg = FunctionToInstr.find(ResF->second); + return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second; + } + // map function pointer (as a machine instruction operand) to the used + // Function + void recordFunctionPointer(const MachineOperand *MO, const Function *F) { + InstrToFunction[MO] = F; + } + // map a Function to its definition (as a machine instruction) + void recordFunctionDefinition(const Function *F, const MachineOperand *MO) { + FunctionToInstr[F] = MO; + } + // Return true if any OpConstantFunctionPointerINTEL were generated + bool hasConstFunPtr() { return !InstrToFunction.empty(); } + // Get or create a SPIR-V type corresponding the given LLVM IR type, // and map it to the given VReg by creating an ASSIGN_TYPE instruction. SPIRVType *assignTypeToVReg(const Type *Type, Register VReg, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp index 42317453a2370..e3f76419f1313 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp @@ -40,6 +40,7 @@ bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const { case SPIRV::OpSpecConstantComposite: case SPIRV::OpSpecConstantOp: case SPIRV::OpUndef: + case SPIRV::OpConstantFunctionPointerINTEL: return true; default: return false; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index caf2ae43480b1..904fef1d6c82f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -762,6 +762,16 @@ def OpGroupNonUniformLogicalAnd: OpGroupNUGroup<"LogicalAnd", 362>; def OpGroupNonUniformLogicalOr: OpGroupNUGroup<"LogicalOr", 363>; def OpGroupNonUniformLogicalXor: OpGroupNUGroup<"LogicalXor", 364>; +// 3.49.7, Constant-Creation Instructions + +// - SPV_INTEL_function_pointers +def OpConstantFunctionPointerINTEL: Op<5600, (outs ID:$res), (ins TYPE:$ty, ID:$fun), "$res = OpConstantFunctionPointerINTEL $ty $fun">; + +// 3.49.9. Function Instructions + +// - SPV_INTEL_function_pointers +def OpFunctionPointerCallINTEL: Op<5601, (outs ID:$res), (ins TYPE:$ty, ID:$funPtr, variable_ops), "$res = OpFunctionPointerCallINTEL $ty $funPtr">; + // 3.49.21. Group and Subgroup Instructions def OpSubgroupShuffleINTEL: Op<5571, (outs ID:$res), (ins TYPE:$type, ID:$data, ID:$invocationId), "$res = OpSubgroupShuffleINTEL $type $data $invocationId">; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 8c1dfc5e626db..52eeb8a523e6f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1534,6 +1534,12 @@ bool SPIRVInstructionSelector::selectGlobalValue( GlobalIdent = GV->getGlobalIdentifier(); } + // Behaviour of functions as operands depends on availability of the + // corresponding extension (SPV_INTEL_function_pointers): + // - If there is an extension to operate with functions as operands: + // We create a proper constant operand and evaluate a correct type for a + // function pointer. + // - Without the required extension: // We have functions as operands in tests with blocks of instruction e.g. in // transcoding/global_block.ll. These operands are not used and should be // substituted by zero constants. Their type is expected to be always @@ -1545,6 +1551,27 @@ bool SPIRVInstructionSelector::selectGlobalValue( if (!NewReg.isValid()) { Register NewReg = ResVReg; GR.add(ConstVal, GR.CurMF, NewReg); + const Function *GVFun = + STI.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers) + ? dyn_cast(GV) + : nullptr; + if (GVFun) { + // References to a function via function pointers generate virtual + // registers without a definition. We will resolve it later, during + // module analysis stage. + MachineRegisterInfo *MRI = MIRBuilder.getMRI(); + Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); + MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); + MachineInstrBuilder MB = + BuildMI(BB, I, I.getDebugLoc(), + TII.get(SPIRV::OpConstantFunctionPointerINTEL)) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(FuncVReg); + // mapping the function pointer to the used Function + GR.recordFunctionPointer(&MB.getInstr()->getOperand(2), GVFun); + return MB.constrainAllUses(TII, TRI, RBI); + } return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) .addDef(NewReg) .addUse(GR.getSPIRVTypeID(ResType)) diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 2dfb71dad193a..a18aae1761c83 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -291,6 +291,32 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI, } } +// References to a function via function pointers generate virtual +// registers without a definition. We are able to resolve this +// reference using Globar Register info into an OpFunction instruction +// and replace dummy operands by the corresponding global register references. +void SPIRVModuleAnalysis::collectFuncPtrs() { + for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars]) + if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL) + collectFuncPtrs(MI); +} + +void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) { + const MachineOperand *FunUse = &MI->getOperand(2); + if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) { + const MachineInstr *FunDefMI = FunDef->getParent(); + assert(FunDefMI->getOpcode() == SPIRV::OpFunction && + "Constant function pointer must refer to function definition"); + Register FunDefReg = FunDef->getReg(); + Register GlobalFunDefReg = + MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg); + assert(GlobalFunDefReg.isValid() && + "Function definition must refer to a global register"); + Register FunPtrReg = FunUse->getReg(); + MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg); + } +} + using InstrSignature = SmallVector; using InstrTraces = std::set; @@ -938,6 +964,18 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR); } break; + case SPIRV::OpConstantFunctionPointerINTEL: + if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) { + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers); + Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL); + } + break; + case SPIRV::OpFunctionPointerCallINTEL: + if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) { + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers); + Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL); + } + break; default: break; } @@ -1096,6 +1134,10 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) { // Number rest of registers from N+1 onwards. numberRegistersGlobally(M); + // Update references to OpFunction instructions to use Global Registers + if (GR->hasConstFunPtr()) + collectFuncPtrs(); + // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions. processOtherInstrs(M); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h index d0b8027edd420..b05526b06e7da 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h @@ -224,6 +224,8 @@ struct SPIRVModuleAnalysis : public ModulePass { void collectFuncNames(MachineInstr &MI, const Function *F); void processOtherInstrs(const Module &M); void numberRegistersGlobally(const Module &M); + void collectFuncPtrs(); + void collectFuncPtrs(MachineInstr *MI); const SPIRVSubtarget *ST; SPIRVGlobalRegistry *GR; diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp index 6eb81f2deb3ab..effedc2f17d35 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp @@ -53,7 +53,10 @@ cl::list Extensions( clEnumValN(SPIRV::Extension::SPV_KHR_bit_instructions, "SPV_KHR_bit_instructions", "This enables bit instructions to be used by SPIR-V modules " - "without requiring the Shader capability"))); + "without requiring the Shader capability"), + clEnumValN(SPIRV::Extension::SPV_INTEL_function_pointers, + "SPV_INTEL_function_pointers", + "Allows translation of function pointers"))); // Compare version numbers, but allow 0 to mean unspecified. static bool isAtLeastVer(uint32_t Target, uint32_t VerToCompareTo) { diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 58ba7781b7777..5d252275ac709 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -295,6 +295,7 @@ defm SPV_INTEL_usm_storage_classes : ExtensionOperand<100>; defm SPV_INTEL_fpga_latency_control : ExtensionOperand<101>; defm SPV_INTEL_fpga_argument_interfaces : ExtensionOperand<102>; defm SPV_INTEL_optnone : ExtensionOperand<103>; +defm SPV_INTEL_function_pointers : ExtensionOperand<104>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -452,6 +453,8 @@ defm ArbitraryPrecisionIntegersINTEL : CapabilityOperand<5844, 0, 0, [SPV_INTEL_ defm OptNoneINTEL : CapabilityOperand<6094, 0, 0, [SPV_INTEL_optnone], []>; defm BitInstructions : CapabilityOperand<6025, 0, 0, [SPV_KHR_bit_instructions], []>; defm ExpectAssumeKHR : CapabilityOperand<5629, 0, 0, [SPV_KHR_expect_assume], []>; +defm FunctionPointersINTEL : CapabilityOperand<5603, 0, 0, [SPV_INTEL_function_pointers], []>; +defm IndirectReferencesINTEL : CapabilityOperand<5604, 0, 0, [SPV_INTEL_function_pointers], []>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time @@ -688,6 +691,7 @@ defm HitAttributeNV : StorageClassOperand<5339, [RayTracingNV]>; defm IncomingRayPayloadNV : StorageClassOperand<5342, [RayTracingNV]>; defm ShaderRecordBufferNV : StorageClassOperand<5343, [RayTracingNV]>; defm PhysicalStorageBufferEXT : StorageClassOperand<5349, [PhysicalStorageBufferAddressesEXT]>; +defm CodeSectionINTEL : StorageClassOperand<5605, [FunctionPointersINTEL]>; //===----------------------------------------------------------------------===// // Multiclass used to define Dim enum values and at the same time @@ -1179,6 +1183,8 @@ defm CountBuffer : DecorationOperand<5634, 0, 0, [], []>; defm UserSemantic : DecorationOperand<5635, 0, 0, [], []>; defm RestrictPointerEXT : DecorationOperand<5355, 0, 0, [], [PhysicalStorageBufferAddressesEXT]>; defm AliasedPointerEXT : DecorationOperand<5356, 0, 0, [], [PhysicalStorageBufferAddressesEXT]>; +defm ReferencedIndirectlyINTEL : DecorationOperand<5602, 0, 0, [], [IndirectReferencesINTEL]>; +defm ArgumentAttributeINTEL : DecorationOperand<6409, 0, 0, [], [FunctionPointersINTEL]>; //===----------------------------------------------------------------------===// // Multiclass used to define BuiltIn enum values and at the same time diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll new file mode 100644 index 0000000000000..0bd1b5d776a94 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll @@ -0,0 +1,34 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_function_pointers %s -o - | FileCheck %s +; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpCapability Int8 +; CHECK-DAG: OpCapability FunctionPointersINTEL +; CHECK-DAG: OpCapability Int64 +; CHECK: OpExtension "SPV_INTEL_function_pointers" +; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0 +; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid +; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0 +; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyVoid]] %[[TyInt64]] +; CHECK-DAG: %[[ConstInt64:.*]] = OpConstant %[[TyInt64]] 42 +; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]] +; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFunFp]] %[[DefFunFp:.*]] +; CHECK: %[[FunPtr1:.*]] = OpBitcast %[[#]] %[[ConstFunFp]] +; CHECK: %[[FunPtr2:.*]] = OpLoad %[[#]] %[[FunPtr1]] +; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FunPtr2]] %[[ConstInt64]] +; CHECK: OpReturn +; CHECK: OpFunctionEnd +; CHECK: %[[DefFunFp]] = OpFunction %[[TyVoid]] None %[[TyFunFp]] + +target triple = "spir64-unknown-unknown" + +define spir_kernel void @test() { +entry: + %0 = load ptr, ptr @foo + %1 = call i64 %0(i64 42) + ret void +} + +define void @foo(i64 %a) { +entry: + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll new file mode 100644 index 0000000000000..89de098dead9c --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll @@ -0,0 +1,34 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_function_pointers %s -o - | FileCheck %s +; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpCapability Int8 +; CHECK-DAG: OpCapability FunctionPointersINTEL +; CHECK-DAG: OpCapability Int64 +; CHECK: OpExtension "SPV_INTEL_function_pointers" +; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0 +; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid +; CHECK-DAG: %[[TyFloat32:.*]] = OpTypeFloat 32 +; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0 +; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]] +; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrInt8]] +; CHECK-DAG: %[[TyFunBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrInt8]] %[[TyPtrInt8]] +; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]] +; CHECK-DAG: %[[TyPtrFunBar:.*]] = OpTypePointer Function %[[TyFunBar]] +; CHECK-DAG: %[[TyFunTest:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrInt8]] %[[TyPtrInt8]] %[[TyPtrInt8]] +; CHECK: %[[FunTest:.*]] = OpFunction %[[TyVoid]] None %[[TyFunTest]] +; CHECK: %[[ArgFp:.*]] = OpFunctionParameter %[[TyPtrInt8]] +; CHECK: %[[ArgData:.*]] = OpFunctionParameter %[[TyPtrInt8]] +; CHECK: %[[ArgBar:.*]] = OpFunctionParameter %[[TyPtrInt8]] +; CHECK: OpFunctionPointerCallINTEL %[[TyFloat32]] %[[ArgFp]] %[[ArgBar]] +; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[ArgBar]] %[[ArgFp]] %[[ArgData]] +; CHECK: OpReturn +; CHECK: OpFunctionEnd + +target triple = "spir64-unknown-unknown" + +define spir_kernel void @test(ptr %fp, ptr %data, ptr %bar) { +entry: + %0 = call spir_func float %fp(ptr %bar) + %1 = call spir_func i64 %bar(ptr %fp, ptr %data) + ret void +}