Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for SPIR-V extension: SPV_INTEL_function_pointers #80759

Merged

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR adds initial support for "SPV_INTEL_function_pointers" SPIR-V extension: https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_function_pointers.asciidoc

The goal of the extension is to support indirect function calls and translation of function pointers into SPIR-V.

Copy link

github-actions bot commented Feb 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@VyacheslavLevytskyy VyacheslavLevytskyy force-pushed the ext_function_pointers branch 2 times, most recently from 8b3b27c to 3fb82d2 Compare February 6, 2024 09:54
@VyacheslavLevytskyy VyacheslavLevytskyy marked this pull request as ready for review February 7, 2024 20:51
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 7, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR adds initial support for "SPV_INTEL_function_pointers" SPIR-V extension: https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_function_pointers.asciidoc

The goal of the extension is to support indirect function calls and translation of function pointers into SPIR-V.


Patch is 33.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80759.diff

13 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+181-26)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.h (+28)
  • (modified) llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp (+8-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+29)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+10)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+27)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+42)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp (+4-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+6)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll (+34)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll (+34)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 97b25147ffb34b..42deba3b330e8a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -34,6 +34,10 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
                                     const Value *Val, ArrayRef<Register> VRegs,
                                     FunctionLoweringInfo &FLI,
                                     Register SwiftErrorVReg) const {
+  // Maybe run postponed production of OpFunction/OpFunctionParameter's
+  if (FormalArgs.F != nullptr)
+    FormalArgs.produceFunArgsInstructions(MIRBuilder, GR, IndirectCalls);
+
   // Currently all return types should use a single register.
   // TODO: handle the case of multiple registers.
   if (VRegs.size() > 1)
@@ -217,6 +221,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   // Assign types and names to all args, and store their types for later.
   FunctionType *FTy = getOriginalFunctionType(F);
   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
+  bool HasOpaquePtrArg = false;
   if (VRegs.size() > 0) {
     unsigned i = 0;
     for (const auto &Arg : F.args()) {
@@ -231,6 +236,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       if (Arg.hasName())
         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
       if (Arg.getType()->isPointerTy()) {
+        HasOpaquePtrArg = true;
         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
         if (DerefBytes != 0)
           buildOpDecorate(VRegs[i][0], MIRBuilder,
@@ -292,33 +298,60 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     }
   }
 
-  // Generate a SPIR-V type for the function.
+  // If there is support of indirect calls and there are opaque pointer formal
+  // arguments, there is a chance to specify opaque ptr types later (after the
+  // function's body is processed) by information about the indirect call. To
+  // support this case we may postpone generation of some SPIR-V types, and
+  // OpFunction and OpFunctionParameter's. Otherwise we generate all SPIR-V
+  // types related to the function along with instructions.
+  const auto *ST =
+      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+  bool hasFunctionPointers =
+      ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
+  bool PostponeOpFunction = HasOpaquePtrArg && hasFunctionPointers;
+
   auto MRI = MIRBuilder.getMRI();
   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
   if (F.isDeclaration())
     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
   SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
-  SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
-      FTy, RetTy, ArgTypeVRegs, MIRBuilder);
-
-  // Build the OpTypeFunction declaring it.
+  SPIRVType *FuncTy = PostponeOpFunction
+                          ? nullptr
+                          : GR->getOrCreateOpTypeFunctionWithArgs(
+                                FTy, RetTy, ArgTypeVRegs, MIRBuilder);
   uint32_t FuncControl = getFunctionControl(F);
 
-  MIRBuilder.buildInstr(SPIRV::OpFunction)
-      .addDef(FuncVReg)
-      .addUse(GR->getSPIRVTypeID(RetTy))
-      .addImm(FuncControl)
-      .addUse(GR->getSPIRVTypeID(FuncTy));
+  if (PostponeOpFunction) {
+    FormalArgs.F = &F;
+    FormalArgs.KeepMBB = &(MIRBuilder.getMBB());
+    FormalArgs.KeepInsertPt = MIRBuilder.getInsertPt();
+    FormalArgs.FuncVReg = FuncVReg;
+    FormalArgs.RetTy = RetTy;
+    FormalArgs.FuncControl = FuncControl;
+    FormalArgs.OrigFTy = FTy;
+    FormalArgs.ArgTypeVRegs = ArgTypeVRegs;
+  } else {
+    MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
+                                 .addDef(FuncVReg)
+                                 .addUse(GR->getSPIRVTypeID(RetTy))
+                                 .addImm(FuncControl)
+                                 .addUse(GR->getSPIRVTypeID(FuncTy));
+    GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0));
+  }
 
   // Add OpFunctionParameters.
   int i = 0;
   for (const auto &Arg : F.args()) {
     assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
     MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
-    MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
-        .addDef(VRegs[i][0])
-        .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
+    if (PostponeOpFunction) {
+      FormalArgs.ArgVRegs.push_back(VRegs[i][0]);
+    } else {
+      MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
+          .addDef(VRegs[i][0])
+          .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
+    }
     if (F.isDeclaration())
       GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
     i++;
@@ -343,9 +376,106 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
                     {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
   }
 
+  // Handle function pointers decoration
+  if (hasFunctionPointers) {
+    if (F.hasFnAttribute("referenced-indirectly")) {
+      assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
+             "Unexpected 'referenced-indirectly' attribute of the kernel "
+             "function");
+      buildOpDecorate(FuncVReg, MIRBuilder,
+                      SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
+    }
+  }
+
   return true;
 }
 
+// Use collect during function's body analysis information about the indirect
+// call to specify opaque ptr types of parent function's parameters
+void SPIRVCallLowering::SPIRVFunFormalArgs::produceFunArgsInstructions(
+    MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR,
+    SmallVector<SPIRVCallLowering::SPIRVIndirectCall> &IndirectCalls) {
+  // Store current insertion point
+  MachineBasicBlock &NextKeepMBB = MIRBuilder.getMBB();
+  MachineBasicBlock::iterator NextKeepInsertPt = MIRBuilder.getInsertPt();
+  // Set a new insertion point
+  MIRBuilder.setInsertPt(*KeepMBB, KeepInsertPt);
+
+  bool IsTypeUpd = false;
+  if (IndirectCalls.size() > 0) {
+    // TODO: add a topological sort of IndirectCalls
+    // Create indirect call data types if any
+    MachineFunction &MF = MIRBuilder.getMF();
+    for (auto const &IC : IndirectCalls) {
+      SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
+      SmallVector<SPIRVType *, 4> SpirvArgTypes;
+      for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
+        SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
+        SpirvArgTypes.push_back(SPIRVTy);
+        if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
+          GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
+      }
+      // SPIR-V function type:
+      FunctionType *FTy =
+          FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false);
+      SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
+          FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
+      // SPIR-V pointer to function type:
+      SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
+          SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
+      // Correct the Calee type
+      GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
+    }
+
+    // Check if our knowledge about a type of the function parameter is updated
+    // as a result of indirect calls analysis
+    for (size_t i = 0; i < ArgVRegs.size(); ++i) {
+      SPIRVType *ArgTy = GR->getSPIRVTypeForVReg(ArgVRegs[i]);
+      if (ArgTy && ArgTypeVRegs[i] != ArgTy) {
+        ArgTypeVRegs[i] = ArgTy;
+        IsTypeUpd = true;
+      }
+    }
+  }
+
+  // If we have update about function parameter types, create a new function
+  // type instead of the stored
+  // TODO: (maybe) allocated in getOriginalFunctionType(F) this->OrigFTy may be
+  // overwritten and is not used (tracked?) anywhere
+  FunctionType *UpdateFTy = OrigFTy;
+  if (IsTypeUpd) {
+    SmallVector<Type *, 4> ArgTys;
+    for (size_t i = 0; i < ArgTypeVRegs.size(); ++i) {
+      const Type *Ty = GR->getTypeForSPIRVType(ArgTypeVRegs[i]);
+      ArgTys.push_back(const_cast<Type *>(Ty));
+    }
+    // Argument types were specified, we must update function type
+    UpdateFTy = FunctionType::get(F->getReturnType(), ArgTys,
+                                  F->getFunctionType()->isVarArg());
+  }
+  // Create SPIR-V function type
+  SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
+      UpdateFTy, RetTy, ArgTypeVRegs, MIRBuilder);
+
+  // Emit OpFunction
+  MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
+                               .addDef(FuncVReg)
+                               .addUse(GR->getSPIRVTypeID(RetTy))
+                               .addImm(FuncControl)
+                               .addUse(GR->getSPIRVTypeID(FuncTy));
+  GR->recordFunctionDefinition(F, &MB.getInstr()->getOperand(0));
+
+  // Emit OpFunctionParameter's
+  for (size_t i = 0; i < ArgVRegs.size(); ++i) {
+    MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
+        .addDef(ArgVRegs[i])
+        .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
+  }
+
+  // Restore insertion point
+  MIRBuilder.setInsertPt(NextKeepMBB, NextKeepInsertPt);
+}
+
 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
                                   CallLoweringInfo &Info) const {
   // Currently call returns should have single vregs.
@@ -356,45 +486,44 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   GR->setCurrentFunc(MF);
   FunctionType *FTy = nullptr;
   const Function *CF = nullptr;
+  std::string DemangledName;
+  const Type *OrigRetTy = Info.OrigRet.Ty;
 
   // Emit a regular OpFunctionCall. If it's an externally declared function,
   // be sure to emit its type and function declaration here. It will be hoisted
   // globally later.
   if (Info.Callee.isGlobal()) {
+    std::string FuncName = Info.Callee.getGlobal()->getName().str();
+    DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
     CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
     // TODO: support constexpr casts and indirect calls.
     if (CF == nullptr)
       return false;
-    FTy = getOriginalFunctionType(*CF);
+    if ((FTy = getOriginalFunctionType(*CF)) != nullptr)
+      OrigRetTy = FTy->getReturnType();
   }
 
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register ResVReg =
       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
-  std::string FuncName = Info.Callee.getGlobal()->getName().str();
-  std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
   // TODO: check that it's OCL builtin, then apply OpenCL_std.
   if (!DemangledName.empty() && CF && CF->isDeclaration() &&
       ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
-    const Type *OrigRetTy = Info.OrigRet.Ty;
-    if (FTy)
-      OrigRetTy = FTy->getReturnType();
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
       ArgVRegs.push_back(Arg.Regs[0]);
       SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
       if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
-        GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
+        GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
     }
     if (auto Res = SPIRV::lowerBuiltin(
             DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
             ResVReg, OrigRetTy, ArgVRegs, GR))
       return *Res;
   }
-  if (CF && CF->isDeclaration() &&
-      !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
+  if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {
     // Emit the type info and forward function declaration to the first MBB
     // to ensure VReg definition dependencies are valid across all MBBs.
     MachineIRBuilder FirstBlockBuilder;
@@ -416,14 +545,40 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
   }
 
+  unsigned CallOp;
+  if (Info.CB->isIndirectCall()) {
+    if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
+      report_fatal_error("An indirect call is encountered but SPIR-V without "
+                         "extensions does not support it",
+                         false);
+    // Set instruction operation according to SPV_INTEL_function_pointers
+    CallOp = SPIRV::OpFunctionPointerCallINTEL;
+    // Collect information about the indirect call to support possible
+    // specification of opaque ptr types of parent function's parameters
+    Register CalleeReg = Info.Callee.getReg();
+    if (CalleeReg.isValid()) {
+      SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
+      IndirectCall.Callee = CalleeReg;
+      IndirectCall.RetTy = OrigRetTy;
+      for (const auto &Arg : Info.OrigArgs) {
+        assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
+        IndirectCall.ArgTys.push_back(Arg.Ty);
+        IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
+      }
+      IndirectCalls.push_back(IndirectCall);
+    }
+  } else {
+    // Emit a regular OpFunctionCall
+    CallOp = SPIRV::OpFunctionCall;
+  }
+
   // Make sure there's a valid return reg, even for functions returning void.
   if (!ResVReg.isValid())
     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
-  SPIRVType *RetType =
-      GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
+  SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
 
-  // Emit the OpFunctionCall and its args.
-  auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
+  // Emit the call instruction and its args.
+  auto MIB = MIRBuilder.buildInstr(CallOp)
                  .addDef(ResVReg)
                  .addUse(GR->getSPIRVTypeID(RetType))
                  .add(Info.Callee);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h
index c2d6ad82d507d1..680db7ca7b1be3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h
@@ -26,6 +26,34 @@ class SPIRVCallLowering : public CallLowering {
   // Used to create and assign function, argument, and return type information.
   SPIRVGlobalRegistry *GR;
 
+  // Used to postpone producing of OpFunction and OpFunctionParameter
+  // and use indirect calls to specify argument types
+  struct SPIRVIndirectCall {
+    const Type *RetTy = nullptr;
+    SmallVector<Type *> ArgTys;
+    SmallVector<Register> ArgRegs;
+    Register Callee;
+  };
+  struct SPIRVFunFormalArgs {
+    const Function *F = nullptr;
+    // the insertion point
+    MachineBasicBlock *KeepMBB = nullptr;
+    MachineBasicBlock::iterator KeepInsertPt;
+    // OpFunction and OpFunctionParameter operands
+    Register FuncVReg;
+    SPIRVType *RetTy = nullptr;
+    uint32_t FuncControl;
+    FunctionType *OrigFTy = nullptr;
+    SmallVector<SPIRVType *, 4> ArgTypeVRegs;
+    SmallVector<Register> ArgVRegs;
+
+    void produceFunArgsInstructions(
+        MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR,
+        SmallVector<SPIRVCallLowering::SPIRVIndirectCall> &IndirectCalls);
+  };
+  mutable SPIRVFunFormalArgs FormalArgs;
+  mutable SmallVector<SPIRVIndirectCall> IndirectCalls;
+
 public:
   SPIRVCallLowering(const SPIRVTargetLowering &TLI, SPIRVGlobalRegistry *GR);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
index cbe1a53fd75688..d82fb2df4539a3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
@@ -54,7 +54,14 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
         MachineOperand &Op = MI->getOperand(i);
         if (!Op.isReg())
           continue;
-        MachineOperand *RegOp = &MRI.getVRegDef(Op.getReg())->getOperand(0);
+        MachineInstr *VRegDef = MRI.getVRegDef(Op.getReg());
+        // References to a function via function pointers generate virtual
+        // registers without a definition. We are able to resolve this
+        // reference using Globar Register info into an OpFunction instruction
+        // but do not expect to find it in Reg2Entry.
+        if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL && i == 2)
+          continue;
+        MachineOperand *RegOp = &VRegDef->getOperand(0);
         assert((MI->getOpcode() == SPIRV::OpVariable && i == 3) ||
                Reg2Entry.count(RegOp));
         if (Reg2Entry.count(RegOp))
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index f3280928c25dfa..792a00786f0aaf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -38,6 +38,12 @@ class SPIRVGlobalRegistry {
 
   DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
 
+  // map a Function to its definition (as a machine instruction operand)
+  DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
+  // map function pointer (as a machine instruction operand) to the used
+  // Function
+  DenseMap<const MachineOperand *, const Function *> InstrToFunction;
+
   // Look for an equivalent of the newType in the map. Return the equivalent
   // if it's found, otherwise insert newType to the map and return the type.
   const MachineInstr *checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
@@ -101,6 +107,29 @@ class SPIRVGlobalRegistry {
     DT.buildDepsGraph(Graph, MMI);
   }
 
+  // Map a machine operand that represents a use of a function via function
+  // pointer to a machine operand that represents the function definition.
+  // Return either the register or invalid value, because we have no context for
+  // a good diagnostic message in case of unexpectedly missing references.
+  const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
+    auto ResF = InstrToFunction.find(Use);
+    if (ResF == InstrToFunction.end())
+      return nullptr;
+    auto ResReg = FunctionToInstr.find(ResF->second);
+    return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
+  }
+  // map function pointer (as a machine instruction operand) to the used
+  // Function
+  void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
+    InstrToFunction[MO] = F;
+  }
+  // map a Function to its definition (as a machine instruction)
+  void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
+    FunctionToInstr[F] = MO;
+  }
+  // Return true if any OpConstantFunctionPointerINTEL were generated
+  bool hasConstFunPtr() { return !InstrToFunction.empty(); }
+
   // Get or create a SPIR-V type corresponding the given LLVM IR type,
   // and map it to the given VReg by creating an ASSIGN_TYPE instruction.
   SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index 42317453a2370e..e3f76419f13137 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -40,6 +40,7 @@ bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const {
   case SPIRV::OpSpecConstantComposite:
   case SPIRV::OpSpecConstantOp:
   case SPIRV::OpUndef:
+  case SPIRV::OpConstantFunctionPointerINTEL:
     return true;
   default:
     return false;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index da033ba32624cc..3683fe9ec16482 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -761,3 +761,13 @@ def OpGroupNonUniformBitwiseXor: OpGroupNUGroup<"BitwiseXor", 361>;
 def OpGroupNonUniformLogicalAnd: OpGroupNUGroup<"LogicalAnd", 362>;
 def OpGroupNonUniformLogicalOr: OpGroupNUGroup<"LogicalOr", 363>;
 def OpGroupNonUniformLogicalXor: OpGroupNUGroup<"LogicalXor", 364>;
+
+// 3.49.7, Constant-Creation Instructions
+
+//  - SPV_INTEL_function_pointers
+def OpConstantFunctionPointerINTEL: Op<5600, (outs ID:$res), (ins TYPE:$ty, ID:$fun), "$res = OpConstantFunctionPointerINTEL $ty $fun">;
+
+/...
[truncated]

@VyacheslavLevytskyy
Copy link
Contributor Author

@michalpaszkowski I simplified implementation by removing the feature of resolving opaque pointer types by information about indirect calls. I think it's better to address this later, as just one of many cases when we'd benefit from a more general scheme of improvement of resolving opaque pointer types. This subject appears time from time and probably deserves a generalized solution.

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VyacheslavLevytskyy Thank you for the patch! LGTM! Verified -- the patch does not influence OpenCL CTS pass rate.

// SPIR-V pointer to function type:
SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
// Correct the Calee type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: Calee -> Callee

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I'll fix it in the next PR/commit

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit d153ef6 into llvm:main Feb 12, 2024
4 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants