diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 23c5798f9d0af..9bbca15ffb6c1 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -221,41 +221,17 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx, static SPIRV::ExecutionModel::ExecutionModel getExecutionModel(const SPIRVSubtarget &STI, const Function &F) { + assert(STI.getEnv() != SPIRVSubtarget::Unknown && + "Environment must be resolved before lowering entry points."); + if (STI.isKernel()) return SPIRV::ExecutionModel::Kernel; - if (STI.isShader()) { - auto attribute = F.getFnAttribute("hlsl.shader"); - if (!attribute.isValid()) { - report_fatal_error( - "This entry point lacks mandatory hlsl.shader attribute."); - } - - const auto value = attribute.getValueAsString(); - if (value == "compute") - return SPIRV::ExecutionModel::GLCompute; - if (value == "vertex") - return SPIRV::ExecutionModel::Vertex; - if (value == "pixel") - return SPIRV::ExecutionModel::Fragment; - - report_fatal_error( - "This HLSL entry point is not supported by this backend."); - } - - assert(STI.getEnv() == SPIRVSubtarget::Unknown); - // "hlsl.shader" attribute is mandatory for Vulkan, so we can set Env to - // Shader whenever we find it, and to Kernel otherwise. - - // We will now change the Env based on the attribute, so we need to strip - // `const` out of the ref to STI. - SPIRVSubtarget *NonConstSTI = const_cast(&STI); auto attribute = F.getFnAttribute("hlsl.shader"); if (!attribute.isValid()) { - NonConstSTI->setEnv(SPIRVSubtarget::Kernel); - return SPIRV::ExecutionModel::Kernel; + report_fatal_error( + "This entry point lacks mandatory hlsl.shader attribute."); } - NonConstSTI->setEnv(SPIRVSubtarget::Shader); const auto value = attribute.getValueAsString(); if (value == "compute") @@ -432,11 +408,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, // Handle entry points and function linkage. if (isEntryPoint(F)) { - // EntryPoints can help us to determine the environment we're working on. - // Therefore, we need a non-const pointer to SPIRVSubtarget to update the - // environment if we need to. - const SPIRVSubtarget *ST = - static_cast(&MIRBuilder.getMF().getSubtarget()); auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) .addImm(static_cast(getExecutionModel(*ST, F))) .addUse(FuncVReg); diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp index a3425704f050d..408173eb57394 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp @@ -656,6 +656,13 @@ bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) { } bool SPIRVPrepareFunctions::runOnModule(Module &M) { + // Resolve the SPIR-V environment from module content before any + // function-level processing. This must happen before legalization so that + // isShader()/isKernel() return correct values. + const_cast(TM) + .getMutableSubtargetImpl() + ->resolveEnvFromModule(M); + bool Changed = false; for (Function &F : M) { Changed |= substituteIntrinsicCalls(&F); diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp index ad6c9cd421b7c..6a798057240de 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp @@ -173,6 +173,42 @@ void SPIRVSubtarget::initAvailableExtInstSets() { accountForAMDShaderTrinaryMinmax(); } +void SPIRVSubtarget::setEnv(SPIRVEnvType E) { + if (E == Unknown) + report_fatal_error("Unknown environment is not allowed."); + if (Env != Unknown && Env != E) + report_fatal_error("Environment is already set to a different value."); + if (Env == E) + return; + + Env = E; + + // Reinitialize Env-dependent state aka ExtInstSet and legalizer info. + initAvailableExtInstSets(); + Legalizer = std::make_unique(*this); +} + +void SPIRVSubtarget::resolveEnvFromModule(const Module &M) { + if (Env != Unknown) { + assert(!(isKernel() && any_of(M, + [](const Function &F) { + return F.hasFnAttribute("hlsl.shader"); + })) && + "Module has hlsl.shader attributes but environment is Kernel"); + return; + } + + bool HasShaderAttr = false; + for (const Function &F : M) { + if (F.hasFnAttribute("hlsl.shader")) { + HasShaderAttr = true; + break; + } + } + + setEnv(HasShaderAttr ? Shader : Kernel); +} + // Set available extensions after SPIRVSubtarget is created. void SPIRVSubtarget::initAvailableExtensions( const std::set &AllowedExtIds) { diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.h b/llvm/lib/Target/SPIRV/SPIRVSubtarget.h index ad3e38d296ed7..18f7e0179270c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.h +++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.h @@ -25,6 +25,7 @@ #include "llvm/CodeGen/SelectionDAGTargetInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Module.h" #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Triple.h" @@ -62,8 +63,6 @@ class SPIRVSubtarget : public SPIRVGenSubtargetInfo { std::unique_ptr InstSelector; std::unique_ptr InlineAsmInfo; - // TODO: Initialise the available extensions, extended instruction sets - // based on the environment settings. void initAvailableExtInstSets(); void accountForAMDShaderTrinaryMinmax(); @@ -76,6 +75,7 @@ class SPIRVSubtarget : public SPIRVGenSubtargetInfo { void initAvailableExtensions( const std::set &AllowedExtIds); + void resolveEnvFromModule(const Module &M); // Parses features string setting specified subtarget options. // The definition of this function is auto generated by tblgen. @@ -83,14 +83,7 @@ class SPIRVSubtarget : public SPIRVGenSubtargetInfo { unsigned getPointerSize() const { return PointerSize; } unsigned getBound() const { return GR->getBound(); } bool canDirectlyComparePointers() const; - void setEnv(SPIRVEnvType E) { - if (E == Unknown) - report_fatal_error("Unknown environment is not allowed."); - if (Env != Unknown) - report_fatal_error("Environment is already set."); - - Env = E; - } + void setEnv(SPIRVEnvType E); SPIRVEnvType getEnv() const { return Env; } bool isKernel() const { return getEnv() == Kernel; } bool isShader() const { return getEnv() == Shader; } diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.h b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.h index 9c59d021dfc1b..ea09fe98c55ee 100644 --- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.h +++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.h @@ -35,6 +35,8 @@ class SPIRVTargetMachine : public CodeGenTargetMachineImpl { return &Subtarget; } + SPIRVSubtarget *getMutableSubtargetImpl() { return &Subtarget; } + TargetTransformInfo getTargetTransformInfo(const Function &F) const override; TargetPassConfig *createPassConfig(PassManagerBase &PM) override; diff --git a/llvm/test/CodeGen/SPIRV/is-shader-env.ll b/llvm/test/CodeGen/SPIRV/is-shader-env.ll new file mode 100644 index 0000000000000..0365f93e781e6 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/is-shader-env.ll @@ -0,0 +1,34 @@ +; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %} + +; Regression test for https://github.com/llvm/llvm-project/issues/171898 +; When triple is spirv-unknown-unknown and a non-entry-point function using +; wide vectors (e.g. <8 x i32>) appears before the entry point with +; hlsl.shader attribute, the environment must be resolved early enough that +; legalization uses the correct vector size limits. + +; CHECK-DAG: OpCapability Shader +; CHECK-DAG: OpEntryPoint GLCompute %[[#entry:]] "main" +; CHECK-NOT: OpTypeVector %{{.*}} 8 + +@GVec4 = internal addrspace(10) global <4 x double> zeroinitializer +@Lows = internal addrspace(10) global <4 x i32> zeroinitializer +@Highs = internal addrspace(10) global <4 x i32> zeroinitializer + +define internal void @test_split() { +entry: + %0 = load <8 x i32>, ptr addrspace(10) @GVec4, align 32 + %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + store <4 x i32> %1, ptr addrspace(10) @Lows, align 16 + store <4 x i32> %2, ptr addrspace(10) @Highs, align 16 + ret void +} + +define void @main() local_unnamed_addr #0 { +entry: + call void @test_split() + ret void +} + +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }