diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index dbe8e18ecfcfc..d91923b41ddd3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -507,7 +507,9 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister, static Register buildBuiltinVariableLoad( MachineIRBuilder &MIRBuilder, SPIRVType *VariableType, SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType, - Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) { + Register Reg = Register(0), bool isConst = true, + const std::optional &LinkageTy = { + SPIRV::LinkageType::Import}) { Register NewRegister = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::pIDRegClass); MIRBuilder.getMRI()->setType( @@ -521,9 +523,8 @@ static Register buildBuiltinVariableLoad( // Set up the global OpVariable with the necessary builtin decorations. Register Variable = GR->buildGlobalVariable( NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr, - SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst, - /* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder, - false); + SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst, LinkageTy, + MIRBuilder, false); // Load the value from the global variable. Register LoadedRegister = @@ -1851,7 +1852,7 @@ static bool generateWaveInst(const SPIRV::IncomingCall *Call, return buildBuiltinVariableLoad( MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister, - /* isConst= */ false, /* hasLinkageTy= */ false); + /* isConst= */ false, /* LinkageType= */ std::nullopt); } // We expect a builtin diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 9d5f8285f447b..9e11c3a281a1b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -479,18 +479,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, .addImm(static_cast(getExecutionModel(*ST, F))) .addUse(FuncVReg); addStringImm(F.getName(), MIB); - } else if (!F.hasLocalLinkage() && - F.getVisibility() != GlobalValue::HiddenVisibility) { - SPIRV::LinkageType::LinkageType LnkTy = - F.isDeclaration() - ? SPIRV::LinkageType::Import - : (F.getLinkage() == GlobalValue::LinkOnceODRLinkage && - ST->canUseExtension( - SPIRV::Extension::SPV_KHR_linkonce_odr) - ? SPIRV::LinkageType::LinkOnceODR - : SPIRV::LinkageType::Export); + } else if (const auto LnkTy = getSpirvLinkageTypeFor(*ST, F)) { buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, - {static_cast(LnkTy)}, F.getName()); + {static_cast(*LnkTy)}, F.getName()); } // Handle function pointers decoration diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 6fd1c7ed78c06..6181abb281cc6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -712,9 +712,9 @@ SPIRVGlobalRegistry::buildConstantSampler(Register ResReg, unsigned AddrMode, Register SPIRVGlobalRegistry::buildGlobalVariable( Register ResVReg, SPIRVType *BaseType, StringRef Name, const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage, - const MachineInstr *Init, bool IsConst, bool HasLinkageTy, - SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, - bool IsInstSelector) { + const MachineInstr *Init, bool IsConst, + const std::optional &LinkageType, + MachineIRBuilder &MIRBuilder, bool IsInstSelector) { const GlobalVariable *GVar = nullptr; if (GV) { GVar = cast(GV); @@ -792,9 +792,9 @@ Register SPIRVGlobalRegistry::buildGlobalVariable( buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment}); } - if (HasLinkageTy) + if (LinkageType) buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, - {static_cast(LinkageType)}, Name); + {static_cast(*LinkageType)}, Name); SPIRV::BuiltIn::BuiltIn BuiltInId; if (getSpirvBuiltInIdByName(Name, BuiltInId)) @@ -821,8 +821,8 @@ Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding( MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass); buildGlobalVariable(VarReg, VarType, Name, nullptr, - getPointerStorageClass(VarType), nullptr, false, false, - SPIRV::LinkageType::Import, MIRBuilder, false); + getPointerStorageClass(VarType), nullptr, false, + std::nullopt, MIRBuilder, false); buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::DescriptorSet, {Set}); buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::Binding, {Binding}); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index a648defa0a888..c230e62e795e8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -548,14 +548,12 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { MachineIRBuilder &MIRBuilder); Register getOrCreateUndef(MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII); - Register buildGlobalVariable(Register Reg, SPIRVType *BaseType, - StringRef Name, const GlobalValue *GV, - SPIRV::StorageClass::StorageClass Storage, - const MachineInstr *Init, bool IsConst, - bool HasLinkageTy, - SPIRV::LinkageType::LinkageType LinkageType, - MachineIRBuilder &MIRBuilder, - bool IsInstSelector); + Register buildGlobalVariable( + Register Reg, SPIRVType *BaseType, StringRef Name, const GlobalValue *GV, + SPIRV::StorageClass::StorageClass Storage, const MachineInstr *Init, + bool IsConst, + const std::optional &LinkageType, + MachineIRBuilder &MIRBuilder, bool IsInstSelector); Register getOrCreateGlobalVariableWithBinding(const SPIRVType *VarType, uint32_t Set, uint32_t Binding, StringRef Name, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 86d1444180855..5591d9ffa9292 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -4350,14 +4350,8 @@ bool SPIRVInstructionSelector::selectGlobalValue( if (hasInitializer(GlobalVar) && !Init) return true; - bool HasLnkTy = !GV->hasLocalLinkage() && !GV->hasHiddenVisibility(); - SPIRV::LinkageType::LinkageType LnkType = - GV->isDeclarationForLinker() - ? SPIRV::LinkageType::Import - : (GV->hasLinkOnceODRLinkage() && - STI.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr) - ? SPIRV::LinkageType::LinkOnceODR - : SPIRV::LinkageType::Export); + const std::optional LnkType = + getSpirvLinkageTypeFor(STI, *GV); const unsigned AddrSpace = GV->getAddressSpace(); SPIRV::StorageClass::StorageClass StorageClass = @@ -4365,7 +4359,7 @@ bool SPIRVInstructionSelector::selectGlobalValue( SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(GVType, I, StorageClass); Register Reg = GR.buildGlobalVariable( ResVReg, ResType, GlobalIdent, GV, StorageClass, Init, - GlobalVar->isConstant(), HasLnkTy, LnkType, MIRBuilder, true); + GlobalVar->isConstant(), LnkType, MIRBuilder, true); return Reg.isValid(); } @@ -4516,8 +4510,8 @@ bool SPIRVInstructionSelector::loadVec3BuiltinInputID( // builtin variable. Register Variable = GR.buildGlobalVariable( NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr, - SPIRV::StorageClass::Input, nullptr, true, false, - SPIRV::LinkageType::Import, MIRBuilder, false); + SPIRV::StorageClass::Input, nullptr, true, std::nullopt, MIRBuilder, + false); // Create new register for loading value. MachineRegisterInfo *MRI = MIRBuilder.getMRI(); @@ -4569,8 +4563,8 @@ bool SPIRVInstructionSelector::loadBuiltinInputID( // builtin variable. Register Variable = GR.buildGlobalVariable( NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr, - SPIRV::StorageClass::Input, nullptr, true, false, - SPIRV::LinkageType::Import, MIRBuilder, false); + SPIRV::StorageClass::Input, nullptr, true, std::nullopt, MIRBuilder, + false); // Load uint value from the global variable. auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad)) diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 1d47c892d03b1..4e2cc882ed6ba 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -1040,4 +1040,19 @@ getFirstValidInstructionInsertPoint(MachineBasicBlock &BB) { : VarPos; } +std::optional +getSpirvLinkageTypeFor(const SPIRVSubtarget &ST, const GlobalValue &GV) { + if (GV.hasLocalLinkage() || GV.hasHiddenVisibility()) + return std::nullopt; + + if (GV.isDeclarationForLinker()) + return SPIRV::LinkageType::Import; + + if (GV.hasLinkOnceODRLinkage() && + ST.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr)) + return SPIRV::LinkageType::LinkOnceODR; + + return SPIRV::LinkageType::Export; +} + } // namespace llvm diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 5777a24faabed..99d9d403ea70c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -559,5 +559,8 @@ unsigned getArrayComponentCount(const MachineRegisterInfo *MRI, const MachineInstr *ResType); MachineBasicBlock::iterator getFirstValidInstructionInsertPoint(MachineBasicBlock &BB); + +std::optional +getSpirvLinkageTypeFor(const SPIRVSubtarget &ST, const GlobalValue &GV); } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H