Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Add WaveGetLaneIndex() intrinsic support #85979

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 25 additions & 8 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,10 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,

/// Helper function for building a load instruction for loading a builtin global
/// variable of \p BuiltinValue value.
static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
SPIRVType *VariableType,
SPIRVGlobalRegistry *GR,
SPIRV::BuiltIn::BuiltIn BuiltinValue,
LLT LLType,
Register Reg = Register(0)) {
static Register buildBuiltinVariableLoad(
MachineIRBuilder &MIRBuilder, SPIRVType *VariableType,
SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType,
Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) {
Register NewRegister =
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
MIRBuilder.getMRI()->setType(NewRegister,
Expand All @@ -385,8 +383,9 @@ static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
// Set up the global OpVariable with the necessary builtin decorations.
Register Variable = GR->buildGlobalVariable(
NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
SPIRV::StorageClass::Input, nullptr, true, true,
SPIRV::LinkageType::Import, MIRBuilder, false);
SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst,
/* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder,
false);

// Load the value from the global variable.
Register LoadedRegister =
Expand Down Expand Up @@ -1300,6 +1299,22 @@ static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
return true;
}

/// Lowers a call to a builtin of the Wave group by loading the SPIR-V builtin
/// variable that the demangled builtin name maps to.
///
/// The BuiltIn value is looked up from the demangled name and instruction set
/// of \p Call, and the loaded value is requested into the call's return
/// register. The helper is invoked with isConst=false and hasLinkageTy=false.
///
/// \returns true when the helper hands back a valid (non-zero) register.
static bool generateWaveInst(const SPIRV::IncomingCall *Call,
                             MachineIRBuilder &MIRBuilder,
                             SPIRVGlobalRegistry *GR) {
  // Resolve which SPIR-V BuiltIn this demangled call corresponds to.
  const SPIRV::DemangledBuiltin *Demangled = Call->Builtin;
  const auto *GetRecord =
      SPIRV::lookupGetBuiltin(Demangled->Name, Demangled->Set);
  SPIRV::BuiltIn::BuiltIn BuiltinValue = GetRecord->Value;

  // Only one Wave intrinsic is handled so far, and it returns an integer.
  const auto *ResultType = Call->ReturnType;
  assert(ResultType->getOpcode() == SPIRV::OpTypeInt);

  // The low-level type of the load mirrors the bit width of the result type.
  const unsigned ResultBits = GR->getScalarOrVectorBitWidth(ResultType);
  LLT ResultLLT = LLT::scalar(ResultBits);

  Register Loaded = buildBuiltinVariableLoad(
      MIRBuilder, ResultType, GR, BuiltinValue, ResultLLT,
      Call->ReturnRegister, /*isConst=*/false, /*hasLinkageTy=*/false);
  return Loaded.isValid();
}

static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
Expand Down Expand Up @@ -2187,6 +2202,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateBarrierInst(Call.get(), MIRBuilder, GR);
case SPIRV::Dot:
return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
case SPIRV::Wave:
return generateWaveInst(Call.get(), MIRBuilder, GR);
case SPIRV::GetQuery:
return generateGetQueryInst(Call.get(), MIRBuilder, GR);
case SPIRV::ImageSizeQuery:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def Variable : BuiltinGroup;
def Atomic : BuiltinGroup;
def Barrier : BuiltinGroup;
def Dot : BuiltinGroup;
def Wave : BuiltinGroup;
def GetQuery : BuiltinGroup;
def ImageSizeQuery : BuiltinGroup;
def ImageMiscQuery : BuiltinGroup;
Expand Down Expand Up @@ -1119,6 +1120,7 @@ defm : DemangledGetBuiltin<"get_global_size", OpenCL_std, GetQuery, GlobalSize>;
defm : DemangledGetBuiltin<"get_group_id", OpenCL_std, GetQuery, WorkgroupId>;
defm : DemangledGetBuiltin<"get_enqueued_local_size", OpenCL_std, GetQuery, EnqueuedWorkgroupSize>;
defm : DemangledGetBuiltin<"get_num_groups", OpenCL_std, GetQuery, NumWorkgroups>;
defm : DemangledGetBuiltin<"__hlsl_wave_get_lane_index", GLSL_std_450, Wave, SubgroupLocalInvocationId>;

//===----------------------------------------------------------------------===//
// Class defining an image query builtin record used for lowering the OpenCL
Expand Down
23 changes: 16 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,9 +493,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
Register ResVReg =
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
// TODO: check that it's OCL builtin, then apply OpenCL_std.
if (!DemangledName.empty() && CF && CF->isDeclaration() &&
ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {

bool isFunctionDecl = CF && CF->isDeclaration();
bool canUseOpenCL = ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std);
bool canUseGLSL = ST->canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450);
assert(canUseGLSL != canUseOpenCL &&
"Scenario where both sets are enabled is not supported.");

if (isFunctionDecl && !DemangledName.empty() &&
(canUseGLSL || canUseOpenCL)) {
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
Expand All @@ -504,12 +510,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
}
if (auto Res = SPIRV::lowerBuiltin(
DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
ResVReg, OrigRetTy, ArgVRegs, GR))
auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
: SPIRV::InstructionSet::GLSL_std_450;
if (auto Res =
SPIRV::lowerBuiltin(DemangledName, instructionSet, MIRBuilder,
ResVReg, OrigRetTy, ArgVRegs, GR))
return *Res;
}
if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {

if (isFunctionDecl && !GR->find(CF, &MF).isValid()) {
// Emit the type info and forward function declaration to the first MBB
// to ensure VReg definition dependencies are valid across all MBBs.
MachineIRBuilder FirstBlockBuilder;
Expand Down
13 changes: 8 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -721,11 +721,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
AddrSpace = PType->getAddressSpace();
else
report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
SPIRVType *SpvElementType;
// At the moment, all opaque pointers correspond to i8 element type.
// TODO: change the implementation once opaque pointers are supported
// in the SPIR-V specification.
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);

SPIRVType *SpvElementType = nullptr;
if (auto PType = dyn_cast<TypedPointerType>(Ty))
SpvElementType = getOrCreateSPIRVType(PType->getElementType(), MIRBuilder,
AccQual, EmitIR);
else
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);

// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ void RequirementHandler::initAvailableCapabilitiesForVulkan(

// Provided by all supported Vulkan versions.
addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
Capability::Float64});
Capability::Float64, Capability::GroupNonUniform});
}

} // namespace SPIRV
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
}
case TargetOpcode::G_GLOBAL_VALUE: {
MIB.setInsertPt(*MI->getParent(), MI);
Type *Ty = MI->getOperand(1).getGlobal()->getType();
const auto *Global = MI->getOperand(1).getGlobal();
auto *Ty = TypedPointerType::get(Global->getValueType(),
Global->getType()->getAddressSpace());
SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
break;
}
Expand Down
9 changes: 4 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,20 +306,19 @@ static bool isNonMangledOCLBuiltin(StringRef Name) {
std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
bool IsMangled = Name.starts_with("_Z");

if (!IsNonMangledOCL && !IsNonMangledSPIRV && !IsMangled)
return std::string();
// Otherwise use simple demangling to return the function name.
if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
return Name.str();

// Try to use the itanium demangler.
if (char *DemangledName = itaniumDemangle(Name.data())) {
std::string Result = DemangledName;
free(DemangledName);
return Result;
}
// Otherwise use simple demangling to return the function name.
if (IsNonMangledOCL || IsNonMangledSPIRV)
return Name.str();

// Autocheck C++, maybe need to do explicit check of the source language.
// OpenCL C++ built-ins are declared in cl namespace.
Expand Down
68 changes: 68 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
; RUN: llc -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}

; This file generated from the following command:
; clang -cc1 -triple spirv-vulkan-compute -x hlsl -emit-llvm -finclude-default-header -o - - <<EOF
; [numthreads(1, 1, 1)]
; void main() {
; int idx = WaveGetLaneIndex();
; }
; EOF

; CHECK-DAG: OpCapability Shader
; CHECK-DAG: OpCapability GroupNonUniform
; CHECK-DAG: OpDecorate %[[#var:]] BuiltIn SubgroupLocalInvocationId
; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#ptri:]] = OpTypePointer Input %[[#int]]
; CHECK-DAG: %[[#ptrf:]] = OpTypePointer Function %[[#int]]
; CHECK-DAG: %[[#var]] = OpVariable %[[#ptri]] Input

; CHECK-NOT: OpDecorate %[[#var]] LinkageAttributes


; ModuleID = '-'
source_filename = "-"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spirv-unknown-vulkan-compute"

; Function Attrs: convergent noinline norecurse nounwind optnone
; Internal HLSL entry body: obtains the lane index through the
; @__hlsl_wave_get_lane_index builtin and stores it into the local %idx.
define internal spir_func void @main() #0 {
entry:
; Convergence token passed to the builtin call below via an operand bundle.
%0 = call token @llvm.experimental.convergence.entry()
; Stack slot for `idx`; expected to become a Function-storage OpVariable.
%idx = alloca i32, align 4
; CHECK: %[[#idx:]] = OpVariable %[[#ptrf]] Function

; The builtin call is expected to lower to a load from the variable decorated
; with BuiltIn SubgroupLocalInvocationId (matched as #var earlier in the file).
%1 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %0) ]
; CHECK: %[[#tmp:]] = OpLoad %[[#int]] %[[#var]]

; The loaded lane index is written to the local variable.
store i32 %1, ptr %idx, align 4
; CHECK: OpStore %[[#idx]] %[[#tmp]]

ret void
}

; Function Attrs: norecurse
; Externally visible wrapper whose attribute group (#1) carries the
; "hlsl.shader" and "hlsl.numthreads" attributes; it only calls @main.
define void @main.1() #1 {
entry:
call void @main()
ret void
}

; Function Attrs: convergent
declare i32 @__hlsl_wave_get_lane_index() #2

; Function Attrs: convergent nocallback nofree nosync nounwind willreturn memory(none)
declare token @llvm.experimental.convergence.entry() #3

attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #2 = { convergent }
attributes #3 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }

!llvm.module.flags = !{!0, !1}
!llvm.ident = !{!2}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
!2 = !{!"clang version 19.0.0git (/usr/local/google/home/nathangauer/projects/llvm-project/clang bc6fd04b73a195981ee77823cf1382d04ab96c44)"}

9 changes: 4 additions & 5 deletions llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
; RUN: llc -mtriple=spirv-unknown-unknown -O0 %s -o - | FileCheck %s

; CHECK-DAG: OpDecorate %[[#SubgroupLocalInvocationId:]] BuiltIn SubgroupLocalInvocationId
; CHECK-DAG: %[[#bool:]] = OpTypeBool
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
Expand Down Expand Up @@ -37,10 +38,10 @@ l1_continue:
; CHECK-NEXT: OpBranch %[[#l1_header]]

l1_end:
%call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
%call = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %tl1) ]
br label %end
; CHECK-DAG: %[[#l1_end]] = OpLabel
; CHECK-DAG: %[[#]] = OpFunctionCall
; CHECK-DAG: %[[#]] = OpLoad %[[#]] %[[#SubgroupLocalInvocationId]]
; CHECK-NEXT: OpBranch %[[#end:]]

l2:
Expand Down Expand Up @@ -76,6 +77,4 @@ declare token @llvm.experimental.convergence.entry()
declare token @llvm.experimental.convergence.control()
declare token @llvm.experimental.convergence.loop()

; This intrinsic is not convergent. This is only because the backend doesn't
; support convergent operations yet.
declare spir_func i32 @_Z3absi(i32) convergent
declare i32 @__hlsl_wave_get_lane_index() convergent
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
;
; CHECK-SPIRV-DAG: %[[#i8:]] = OpTypeInt 8 0
; CHECK-SPIRV-DAG: %[[#i32:]] = OpTypeInt 32 0
; CHECK-SPIRV-DAG: %[[#one:]] = OpConstant %[[#i32]] 1
; CHECK-SPIRV-DAG: %[[#two:]] = OpConstant %[[#i32]] 2
Expand All @@ -13,7 +12,6 @@
; CHECK-SPIRV: %[[#test_arr2:]] = OpVariable %[[#const_i32x3_ptr]] UniformConstant %[[#test_arr_init]]
; CHECK-SPIRV: %[[#test_arr:]] = OpVariable %[[#const_i32x3_ptr]] UniformConstant %[[#test_arr_init]]

; CHECK-SPIRV-DAG: %[[#const_i8_ptr:]] = OpTypePointer UniformConstant %[[#i8]]
; CHECK-SPIRV-DAG: %[[#i32x3_ptr:]] = OpTypePointer Function %[[#i32x3]]

; CHECK-SPIRV: %[[#arr:]] = OpVariable %[[#i32x3_ptr]] Function
Expand Down