Skip to content

Commit

Permalink
Add support for SPIR-V extension: SPV_INTEL_function_pointers (#80759)
Browse files Browse the repository at this point in the history
This PR adds initial support for "SPV_INTEL_function_pointers" SPIR-V
extension:
https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_function_pointers.asciidoc

The goal of the extension is to support indirect function calls and
translation of function pointers into SPIR-V.
  • Loading branch information
VyacheslavLevytskyy committed Feb 12, 2024
1 parent 9d8a236 commit d153ef6
Show file tree
Hide file tree
Showing 13 changed files with 307 additions and 24 deletions.
121 changes: 99 additions & 22 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
const Value *Val, ArrayRef<Register> VRegs,
FunctionLoweringInfo &FLI,
Register SwiftErrorVReg) const {
// Maybe run postponed production of types for function pointers
if (IndirectCalls.size() > 0) {
produceIndirectPtrTypes(MIRBuilder);
IndirectCalls.clear();
}

// Currently all return types should use a single register.
// TODO: handle the case of multiple registers.
if (VRegs.size() > 1)
Expand Down Expand Up @@ -292,7 +298,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
}
}

// Generate a SPIR-V type for the function.
auto MRI = MIRBuilder.getMRI();
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
Expand All @@ -301,17 +306,17 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, RetTy, ArgTypeVRegs, MIRBuilder);

// Build the OpTypeFunction declaring it.
uint32_t FuncControl = getFunctionControl(F);

MIRBuilder.buildInstr(SPIRV::OpFunction)
.addDef(FuncVReg)
.addUse(GR->getSPIRVTypeID(RetTy))
.addImm(FuncControl)
.addUse(GR->getSPIRVTypeID(FuncTy));
// Add OpFunction instruction
MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
.addDef(FuncVReg)
.addUse(GR->getSPIRVTypeID(RetTy))
.addImm(FuncControl)
.addUse(GR->getSPIRVTypeID(FuncTy));
GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0));

// Add OpFunctionParameters.
// Add OpFunctionParameter instructions
int i = 0;
for (const auto &Arg : F.args()) {
assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
Expand Down Expand Up @@ -343,9 +348,56 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
{static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
}

// Handle function pointers decoration
const auto *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
bool hasFunctionPointers =
ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
if (hasFunctionPointers) {
if (F.hasFnAttribute("referenced-indirectly")) {
assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
"Unexpected 'referenced-indirectly' attribute of the kernel "
"function");
buildOpDecorate(FuncVReg, MIRBuilder,
SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
}
}

return true;
}

// Used to postpone producing of indirect function pointer types after all
// indirect calls info is collected
// TODO:
// - add a topological sort of IndirectCalls to ensure the best types knowledge
// - we may need to fix function formal parameter types if they are opaque
// pointers used as function pointers in these indirect calls
void SPIRVCallLowering::produceIndirectPtrTypes(
MachineIRBuilder &MIRBuilder) const {
// Create indirect call data types if any
MachineFunction &MF = MIRBuilder.getMF();
for (auto const &IC : IndirectCalls) {
SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
SmallVector<SPIRVType *, 4> SpirvArgTypes;
for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
SpirvArgTypes.push_back(SPIRVTy);
if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
}
// SPIR-V function type:
FunctionType *FTy =
FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false);
SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
// SPIR-V pointer to function type:
SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
// Correct the Calee type
GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
}
}

bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
CallLoweringInfo &Info) const {
// Currently call returns should have single vregs.
Expand All @@ -356,45 +408,44 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
GR->setCurrentFunc(MF);
FunctionType *FTy = nullptr;
const Function *CF = nullptr;
std::string DemangledName;
const Type *OrigRetTy = Info.OrigRet.Ty;

// Emit a regular OpFunctionCall. If it's an externally declared function,
// be sure to emit its type and function declaration here. It will be hoisted
// globally later.
if (Info.Callee.isGlobal()) {
std::string FuncName = Info.Callee.getGlobal()->getName().str();
DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
FTy = getOriginalFunctionType(*CF);
if ((FTy = getOriginalFunctionType(*CF)) != nullptr)
OrigRetTy = FTy->getReturnType();
}

MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register ResVReg =
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
std::string FuncName = Info.Callee.getGlobal()->getName().str();
std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
// TODO: check that it's OCL builtin, then apply OpenCL_std.
if (!DemangledName.empty() && CF && CF->isDeclaration() &&
ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
const Type *OrigRetTy = Info.OrigRet.Ty;
if (FTy)
OrigRetTy = FTy->getReturnType();
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
ArgVRegs.push_back(Arg.Regs[0]);
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
}
if (auto Res = SPIRV::lowerBuiltin(
DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
ResVReg, OrigRetTy, ArgVRegs, GR))
return *Res;
}
if (CF && CF->isDeclaration() &&
!GR->find(CF, &MIRBuilder.getMF()).isValid()) {
if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {
// Emit the type info and forward function declaration to the first MBB
// to ensure VReg definition dependencies are valid across all MBBs.
MachineIRBuilder FirstBlockBuilder;
Expand All @@ -416,14 +467,40 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
}

unsigned CallOp;
if (Info.CB->isIndirectCall()) {
if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
report_fatal_error("An indirect call is encountered but SPIR-V without "
"extensions does not support it",
false);
// Set instruction operation according to SPV_INTEL_function_pointers
CallOp = SPIRV::OpFunctionPointerCallINTEL;
// Collect information about the indirect call to support possible
// specification of opaque ptr types of parent function's parameters
Register CalleeReg = Info.Callee.getReg();
if (CalleeReg.isValid()) {
SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
IndirectCall.Callee = CalleeReg;
IndirectCall.RetTy = OrigRetTy;
for (const auto &Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
IndirectCall.ArgTys.push_back(Arg.Ty);
IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
}
IndirectCalls.push_back(IndirectCall);
}
} else {
// Emit a regular OpFunctionCall
CallOp = SPIRV::OpFunctionCall;
}

// Make sure there's a valid return reg, even for functions returning void.
if (!ResVReg.isValid())
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
SPIRVType *RetType =
GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);

// Emit the OpFunctionCall and its args.
auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
// Emit the call instruction and its args.
auto MIB = MIRBuilder.buildInstr(CallOp)
.addDef(ResVReg)
.addUse(GR->getSPIRVTypeID(RetType))
.add(Info.Callee);
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ class SPIRVCallLowering : public CallLowering {
// Used to create and assign function, argument, and return type information.
SPIRVGlobalRegistry *GR;

// Used to postpone producing of indirect function pointer types
// after all indirect calls info is collected
struct SPIRVIndirectCall {
const Type *RetTy = nullptr;
SmallVector<Type *> ArgTys;
SmallVector<Register> ArgRegs;
Register Callee;
};
void produceIndirectPtrTypes(MachineIRBuilder &MIRBuilder) const;
mutable SmallVector<SPIRVIndirectCall> IndirectCalls;

public:
SPIRVCallLowering(const SPIRVTargetLowering &TLI, SPIRVGlobalRegistry *GR);

Expand Down
9 changes: 8 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
MachineOperand &Op = MI->getOperand(i);
if (!Op.isReg())
continue;
MachineOperand *RegOp = &MRI.getVRegDef(Op.getReg())->getOperand(0);
MachineInstr *VRegDef = MRI.getVRegDef(Op.getReg());
// References to a function via function pointers generate virtual
// registers without a definition. We are able to resolve this
// reference using Globar Register info into an OpFunction instruction
// but do not expect to find it in Reg2Entry.
if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL && i == 2)
continue;
MachineOperand *RegOp = &VRegDef->getOperand(0);
assert((MI->getOpcode() == SPIRV::OpVariable && i == 3) ||
Reg2Entry.count(RegOp));
if (Reg2Entry.count(RegOp))
Expand Down
29 changes: 29 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class SPIRVGlobalRegistry {

DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;

// map a Function to its definition (as a machine instruction operand)
DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
// map function pointer (as a machine instruction operand) to the used
// Function
DenseMap<const MachineOperand *, const Function *> InstrToFunction;

// Look for an equivalent of the newType in the map. Return the equivalent
// if it's found, otherwise insert newType to the map and return the type.
const MachineInstr *checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
Expand Down Expand Up @@ -101,6 +107,29 @@ class SPIRVGlobalRegistry {
DT.buildDepsGraph(Graph, MMI);
}

// Map a machine operand that represents a use of a function via function
// pointer to a machine operand that represents the function definition.
// Return either the register or invalid value, because we have no context for
// a good diagnostic message in case of unexpectedly missing references.
const MachineOperand *getFunctionDefinitionByUse(const MachineOperand *Use) {
auto ResF = InstrToFunction.find(Use);
if (ResF == InstrToFunction.end())
return nullptr;
auto ResReg = FunctionToInstr.find(ResF->second);
return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
}
// map function pointer (as a machine instruction operand) to the used
// Function
void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
InstrToFunction[MO] = F;
}
// map a Function to its definition (as a machine instruction)
void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
FunctionToInstr[F] = MO;
}
// Return true if any OpConstantFunctionPointerINTEL were generated
bool hasConstFunPtr() { return !InstrToFunction.empty(); }

// Get or create a SPIR-V type corresponding the given LLVM IR type,
// and map it to the given VReg by creating an ASSIGN_TYPE instruction.
SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const {
case SPIRV::OpSpecConstantComposite:
case SPIRV::OpSpecConstantOp:
case SPIRV::OpUndef:
case SPIRV::OpConstantFunctionPointerINTEL:
return true;
default:
return false;
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,16 @@ def OpGroupNonUniformLogicalAnd: OpGroupNUGroup<"LogicalAnd", 362>;
def OpGroupNonUniformLogicalOr: OpGroupNUGroup<"LogicalOr", 363>;
def OpGroupNonUniformLogicalXor: OpGroupNUGroup<"LogicalXor", 364>;

// 3.49.7, Constant-Creation Instructions

// - SPV_INTEL_function_pointers
def OpConstantFunctionPointerINTEL: Op<5600, (outs ID:$res), (ins TYPE:$ty, ID:$fun), "$res = OpConstantFunctionPointerINTEL $ty $fun">;

// 3.49.9. Function Instructions

// - SPV_INTEL_function_pointers
def OpFunctionPointerCallINTEL: Op<5601, (outs ID:$res), (ins TYPE:$ty, ID:$funPtr, variable_ops), "$res = OpFunctionPointerCallINTEL $ty $funPtr">;

// 3.49.21. Group and Subgroup Instructions
def OpSubgroupShuffleINTEL: Op<5571, (outs ID:$res), (ins TYPE:$type, ID:$data, ID:$invocationId),
"$res = OpSubgroupShuffleINTEL $type $data $invocationId">;
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,12 @@ bool SPIRVInstructionSelector::selectGlobalValue(
GlobalIdent = GV->getGlobalIdentifier();
}

// Behaviour of functions as operands depends on availability of the
// corresponding extension (SPV_INTEL_function_pointers):
// - If there is an extension to operate with functions as operands:
// We create a proper constant operand and evaluate a correct type for a
// function pointer.
// - Without the required extension:
// We have functions as operands in tests with blocks of instruction e.g. in
// transcoding/global_block.ll. These operands are not used and should be
// substituted by zero constants. Their type is expected to be always
Expand All @@ -1545,6 +1551,27 @@ bool SPIRVInstructionSelector::selectGlobalValue(
if (!NewReg.isValid()) {
Register NewReg = ResVReg;
GR.add(ConstVal, GR.CurMF, NewReg);
const Function *GVFun =
STI.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)
? dyn_cast<Function>(GV)
: nullptr;
if (GVFun) {
// References to a function via function pointers generate virtual
// registers without a definition. We will resolve it later, during
// module analysis stage.
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
MachineInstrBuilder MB =
BuildMI(BB, I, I.getDebugLoc(),
TII.get(SPIRV::OpConstantFunctionPointerINTEL))
.addDef(NewReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(FuncVReg);
// mapping the function pointer to the used Function
GR.recordFunctionPointer(&MB.getInstr()->getOperand(2), GVFun);
return MB.constrainAllUses(TII, TRI, RBI);
}
return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
.addDef(NewReg)
.addUse(GR.getSPIRVTypeID(ResType))
Expand Down

0 comments on commit d153ef6

Please sign in to comment.