Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Add WaveGetLaneIndex() intrinsic support #85979

Merged
merged 1 commit into from
Mar 25, 2024

Conversation

Keenuts
Copy link
Contributor

@Keenuts Keenuts commented Mar 20, 2024

Add support to generate valid SPIR-V for the WaveGetLaneIndex() HLSL builtin.

To implement this, I had to fix a few small issues in the backend, like the i8* pointer type being emitted, even if we have the type information elsewhere.

Add support to generate valid SPIR-V for the WaveGetLaneIndex()
HLSL builtin.

To implement this, I had to fix a few small issues in the backend, like
the i8* pointer type being emitted, even if we have the type information
elsewhere.

Signed-off-by: Nathan Gauër <brioche@google.com>
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 20, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Nathan Gauër (Keenuts)

Changes

Add support to generate valid SPIR-V for the WaveGetLaneIndex() HLSL builtin.

To implement this, I had to fix a few small issues in the backend, like the i8* pointer type being emitted, even if we have the type information elsewhere.


Full diff: https://github.com/llvm/llvm-project/pull/85979.diff

10 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+25-8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+16-7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+8-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+3-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+4-5)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll (+68)
  • (modified) llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll (+4-5)
  • (modified) llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll (-2)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 07be0b34b18271..804c264e21e5ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -368,12 +368,10 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
 
 /// Helper function for building a load instruction for loading a builtin global
 /// variable of \p BuiltinValue value.
-static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
-                                         SPIRVType *VariableType,
-                                         SPIRVGlobalRegistry *GR,
-                                         SPIRV::BuiltIn::BuiltIn BuiltinValue,
-                                         LLT LLType,
-                                         Register Reg = Register(0)) {
+static Register buildBuiltinVariableLoad(
+    MachineIRBuilder &MIRBuilder, SPIRVType *VariableType,
+    SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType,
+    Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) {
   Register NewRegister =
       MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
   MIRBuilder.getMRI()->setType(NewRegister,
@@ -385,8 +383,9 @@ static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
   // Set up the global OpVariable with the necessary builtin decorations.
   Register Variable = GR->buildGlobalVariable(
       NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
-      SPIRV::StorageClass::Input, nullptr, true, true,
-      SPIRV::LinkageType::Import, MIRBuilder, false);
+      SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst,
+      /* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder,
+      false);
 
   // Load the value from the global variable.
   Register LoadedRegister =
@@ -1300,6 +1299,22 @@ static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
   return true;
 }
 
+static bool generateWaveInst(const SPIRV::IncomingCall *Call,
+                             MachineIRBuilder &MIRBuilder,
+                             SPIRVGlobalRegistry *GR) {
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  SPIRV::BuiltIn::BuiltIn Value =
+      SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
+
+  // For now, we only support a single Wave intrinsic with a single return type.
+  assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt);
+  LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(Call->ReturnType));
+
+  return buildBuiltinVariableLoad(
+      MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister,
+      /* isConst= */ false, /* hasLinkageTy= */ false);
+}
+
 static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
                                  MachineIRBuilder &MIRBuilder,
                                  SPIRVGlobalRegistry *GR) {
@@ -2187,6 +2202,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
     return generateBarrierInst(Call.get(), MIRBuilder, GR);
   case SPIRV::Dot:
     return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
+  case SPIRV::Wave:
+    return generateWaveInst(Call.get(), MIRBuilder, GR);
   case SPIRV::GetQuery:
     return generateGetQueryInst(Call.get(), MIRBuilder, GR);
   case SPIRV::ImageSizeQuery:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index eb26f70b1861f2..3fdfde625fbe9d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -41,6 +41,7 @@ def Variable : BuiltinGroup;
 def Atomic : BuiltinGroup;
 def Barrier : BuiltinGroup;
 def Dot : BuiltinGroup;
+def Wave : BuiltinGroup;
 def GetQuery : BuiltinGroup;
 def ImageSizeQuery : BuiltinGroup;
 def ImageMiscQuery : BuiltinGroup;
@@ -1119,6 +1120,7 @@ defm : DemangledGetBuiltin<"get_global_size", OpenCL_std, GetQuery, GlobalSize>;
 defm : DemangledGetBuiltin<"get_group_id", OpenCL_std, GetQuery, WorkgroupId>;
 defm : DemangledGetBuiltin<"get_enqueued_local_size", OpenCL_std, GetQuery, EnqueuedWorkgroupSize>;
 defm : DemangledGetBuiltin<"get_num_groups", OpenCL_std, GetQuery, NumWorkgroups>;
+defm : DemangledGetBuiltin<"__hlsl_wave_get_lane_index", GLSL_std_450, Wave, SubgroupLocalInvocationId>;
 
 //===----------------------------------------------------------------------===//
 // Class defining an image query builtin record used for lowering the OpenCL
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 6f23f055b8c2ab..afdca01561b0bc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -493,9 +493,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   Register ResVReg =
       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
-  // TODO: check that it's OCL builtin, then apply OpenCL_std.
-  if (!DemangledName.empty() && CF && CF->isDeclaration() &&
-      ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
+
+  bool isFunctionDecl = CF && CF->isDeclaration();
+  bool canUseOpenCL = ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std);
+  bool canUseGLSL = ST->canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450);
+  assert(canUseGLSL != canUseOpenCL &&
+         "Scenario where both sets are enabled is not supported.");
+
+  if (isFunctionDecl && !DemangledName.empty() &&
+      (canUseGLSL || canUseOpenCL)) {
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
@@ -504,12 +510,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
       if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
         GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
     }
-    if (auto Res = SPIRV::lowerBuiltin(
-            DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
-            ResVReg, OrigRetTy, ArgVRegs, GR))
+    auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
+                                       : SPIRV::InstructionSet::GLSL_std_450;
+    if (auto Res =
+            SPIRV::lowerBuiltin(DemangledName, instructionSet, MIRBuilder,
+                                ResVReg, OrigRetTy, ArgVRegs, GR))
       return *Res;
   }
-  if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {
+
+  if (isFunctionDecl && !GR->find(CF, &MF).isValid()) {
     // Emit the type info and forward function declaration to the first MBB
     // to ensure VReg definition dependencies are valid across all MBBs.
     MachineIRBuilder FirstBlockBuilder;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 42f8397a3023b1..f865853776a1b9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -721,11 +721,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     AddrSpace = PType->getAddressSpace();
   else
     report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
-  SPIRVType *SpvElementType;
-  // At the moment, all opaque pointers correspond to i8 element type.
-  // TODO: change the implementation once opaque pointers are supported
-  // in the SPIR-V specification.
-  SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
+
+  SPIRVType *SpvElementType = nullptr;
+  if (auto PType = dyn_cast<TypedPointerType>(Ty))
+    SpvElementType = getOrCreateSPIRVType(PType->getElementType(), MIRBuilder,
+                                          AccQual, EmitIR);
+  else
+    SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
+
   // Get access to information about available extensions
   const SPIRVSubtarget *ST =
       static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 00d0cbd763736d..40c3e5f9c6bdab 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -658,7 +658,7 @@ void RequirementHandler::initAvailableCapabilitiesForVulkan(
 
   // Provided by all supported Vulkan versions.
   addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
-                    Capability::Float64});
+                    Capability::Float64, Capability::GroupNonUniform});
 }
 
 } // namespace SPIRV
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index d547f91ba4a565..ea53f937d31982 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -186,7 +186,9 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
       }
       case TargetOpcode::G_GLOBAL_VALUE: {
         MIB.setInsertPt(*MI->getParent(), MI);
-        Type *Ty = MI->getOperand(1).getGlobal()->getType();
+        const auto *Global = MI->getOperand(1).getGlobal();
+        auto *Ty = TypedPointerType::get(Global->getValueType(),
+                                         Global->getType()->getAddressSpace());
         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
         break;
       }
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index fc7502479fdcdd..c87c1293c622fc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -306,10 +306,12 @@ static bool isNonMangledOCLBuiltin(StringRef Name) {
 std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
   bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
   bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
+  bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
   bool IsMangled = Name.starts_with("_Z");
 
-  if (!IsNonMangledOCL && !IsNonMangledSPIRV && !IsMangled)
-    return std::string();
+  // Otherwise use simple demangling to return the function name.
+  if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
+    return Name.str();
 
   // Try to use the itanium demangler.
   if (char *DemangledName = itaniumDemangle(Name.data())) {
@@ -317,9 +319,6 @@ std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
     free(DemangledName);
     return Result;
   }
-  // Otherwise use simple demangling to return the function name.
-  if (IsNonMangledOCL || IsNonMangledSPIRV)
-    return Name.str();
 
   // Autocheck C++, maybe need to do explicit check of the source language.
   // OpenCL C++ built-ins are declared in cl namespace.
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll
new file mode 100644
index 00000000000000..ec35690ac1547c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll
@@ -0,0 +1,68 @@
+; RUN: llc -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; This file generated from the following command:
+; clang -cc1 -triple spirv-vulkan-compute -x hlsl -emit-llvm -finclude-default-header -o - - <<EOF
+; [numthreads(1, 1, 1)]
+; void main() {
+;   int idx = WaveGetLaneIndex();
+; }
+; EOF
+
+; CHECK-DAG:                         OpCapability Shader
+; CHECK-DAG:                         OpCapability GroupNonUniform
+; CHECK-DAG:                         OpDecorate %[[#var:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG:            %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG:           %[[#ptri:]] = OpTypePointer Input %[[#int]]
+; CHECK-DAG:           %[[#ptrf:]] = OpTypePointer Function %[[#int]]
+; CHECK-DAG:             %[[#var]] = OpVariable %[[#ptri]] Input
+
+; CHECK-NOT:                         OpDecorate %[[#var]] LinkageAttributes
+
+
+; ModuleID = '-'
+source_filename = "-"
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spirv-unknown-vulkan-compute"
+
+; Function Attrs: convergent noinline norecurse nounwind optnone
+define internal spir_func void @main() #0 {
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+  %idx = alloca i32, align 4
+; CHECK:     %[[#idx:]] = OpVariable %[[#ptrf]] Function
+
+  %1 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %0) ]
+; CHECK:    %[[#tmp:]] = OpLoad %[[#int]] %[[#var]]
+
+  store i32 %1, ptr %idx, align 4
+; CHECK:                 OpStore %[[#idx]] %[[#tmp]]
+
+  ret void
+}
+
+; Function Attrs: norecurse
+define void @main.1() #1 {
+entry:
+  call void @main()
+  ret void
+}
+
+; Function Attrs: convergent
+declare i32 @__hlsl_wave_get_lane_index() #2
+
+; Function Attrs: convergent nocallback nofree nosync nounwind willreturn memory(none)
+declare token @llvm.experimental.convergence.entry() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent }
+attributes #3 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+!2 = !{!"clang version 19.0.0git (/usr/local/google/home/nathangauer/projects/llvm-project/clang bc6fd04b73a195981ee77823cf1382d04ab96c44)"}
+
diff --git a/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll b/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll
index 329399bab3e5b9..2ea5c767730e19 100644
--- a/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll
+++ b/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll
@@ -1,5 +1,6 @@
 ; RUN: llc -mtriple=spirv-unknown-unknown -O0 %s -o - | FileCheck %s
 
+; CHECK-DAG:                  OpDecorate %[[#SubgroupLocalInvocationId:]] BuiltIn SubgroupLocalInvocationId
 ; CHECK-DAG:    %[[#bool:]] = OpTypeBool
 ; CHECK-DAG:    %[[#uint:]] = OpTypeInt 32 0
 ; CHECK-DAG:  %[[#uint_0:]] = OpConstant %[[#uint]] 0
@@ -37,10 +38,10 @@ l1_continue:
 ; CHECK-NEXT:                      OpBranch %[[#l1_header]]
 
 l1_end:
-  %call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+  %call = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %tl1) ]
   br label %end
 ; CHECK-DAG:   %[[#l1_end]] = OpLabel
-; CHECK-DAG:         %[[#]] = OpFunctionCall
+; CHECK-DAG:         %[[#]] = OpLoad %[[#]] %[[#SubgroupLocalInvocationId]]
 ; CHECK-NEXT:                 OpBranch %[[#end:]]
 
 l2:
@@ -76,6 +77,4 @@ declare token @llvm.experimental.convergence.entry()
 declare token @llvm.experimental.convergence.control()
 declare token @llvm.experimental.convergence.loop()
 
-; This intrinsic is not convergent. This is only because the backend doesn't
-; support convergent operations yet.
-declare spir_func i32 @_Z3absi(i32) convergent
+declare i32 @__hlsl_wave_get_lane_index() convergent
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll
index 3551030843d062..e0172ec3c1bdb7 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll
@@ -1,6 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
 ;
-; CHECK-SPIRV-DAG: %[[#i8:]] = OpTypeInt 8 0
 ; CHECK-SPIRV-DAG: %[[#i32:]] = OpTypeInt 32 0
 ; CHECK-SPIRV-DAG: %[[#one:]] = OpConstant %[[#i32]] 1
 ; CHECK-SPIRV-DAG: %[[#two:]] = OpConstant %[[#i32]] 2
@@ -13,7 +12,6 @@
 ; CHECK-SPIRV:     %[[#test_arr2:]] = OpVariable %[[#const_i32x3_ptr]] UniformConstant %[[#test_arr_init]]
 ; CHECK-SPIRV:     %[[#test_arr:]] = OpVariable %[[#const_i32x3_ptr]] UniformConstant %[[#test_arr_init]]
 
-; CHECK-SPIRV-DAG: %[[#const_i8_ptr:]] = OpTypePointer UniformConstant %[[#i8]]
 ; CHECK-SPIRV-DAG: %[[#i32x3_ptr:]] = OpTypePointer Function %[[#i32x3]]
 
 ; CHECK-SPIRV:     %[[#arr:]] = OpVariable %[[#i32x3_ptr]] Function

Copy link
Member

@sudonatalie sudonatalie left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@Keenuts Keenuts merged commit f0eb908 into llvm:main Mar 25, 2024
7 checks passed
@Keenuts Keenuts deleted the add-wave-intrinsic branch March 25, 2024 10:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants