Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HLSL][SPIR-V] Add SV_DispatchThreadID semantic support #82536

Merged
merged 3 commits into from
Mar 4, 2024

Conversation

sudonatalie
Copy link
Member

Add SPIR-V backend support for the HLSL SV_DispatchThreadID semantic attribute, which is lowered to a @llvm.dx.thread.id intrinsic in LLVM IR. In the SPIR-V backend, this is now correctly translated to a GlobalInvocationId builtin variable.

Fixes #82534

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 21, 2024

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-spir-v

Author: Natalie Chouinard (sudonatalie)

Changes

Add SPIR-V backend support for the HLSL SV_DispatchThreadID semantic attribute, which is lowered to a @llvm.dx.thread.id intrinsic in LLVM IR. In the SPIR-V backend, this is now correctly translated to a GlobalInvocationId builtin variable.

Fixes #82534


Full diff: https://github.com/llvm/llvm-project/pull/82536.diff

3 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+3-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+69)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll (+76)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 47fec745c3f18a..91562364383ab3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -525,7 +525,9 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
 
   // Output decorations for the GV.
   // TODO: maybe move to GenerateDecorations pass.
-  if (IsConst)
+  const SPIRVSubtarget &ST =
+      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+  if (IsConst && ST.isOpenCLEnv())
     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
 
   if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 52eeb8a523e6f6..751ecf9e9840cf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -27,6 +27,7 @@
 #include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/Debug.h"
 
@@ -182,6 +183,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectLog10(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
+  bool selectDXThreadId(Register ResVReg, const SPIRVType *ResType,
+                        MachineInstr &I) const;
+
   Register buildI32Constant(uint32_t Val, MachineInstr &I,
                             const SPIRVType *ResType = nullptr) const;
 
@@ -284,6 +288,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_IMPLICIT_DEF:
     return selectOpUndef(ResVReg, ResType, I);
 
+  case TargetOpcode::G_INTRINSIC:
   case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
   case TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS:
     return selectIntrinsic(ResVReg, ResType, I);
@@ -1427,6 +1432,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
           .addUse(I.getOperand(2).getReg())
           .addUse(I.getOperand(3).getReg());
     break;
+  case Intrinsic::dx_thread_id:
+    return selectDXThreadId(ResVReg, ResType, I);
   default:
     llvm_unreachable("Intrinsic selection not implemented");
   }
@@ -1660,6 +1667,68 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
   return Result;
 }
 
+bool SPIRVInstructionSelector::selectDXThreadId(Register ResVReg,
+                                                const SPIRVType *ResType,
+                                                MachineInstr &I) const {
+  // DX intrinsic: @llvm.dx.thread.id(i32)
+  // ID  Name      Description
+  // 93  ThreadId  reads the thread ID
+
+  MachineIRBuilder MIRBuilder(I);
+  const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder);
+  const SPIRVType *Vec3Ty =
+      GR.getOrCreateSPIRVVectorType(U32Type, 3, MIRBuilder);
+  const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
+      Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input);
+
+  // Create new register for GlobalInvocationID builtin variable.
+  Register NewRegister =
+      MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 32));
+  GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
+
+  // Build GlobalInvocationID global variable with the necessary decorations.
+  Register Variable = GR.buildGlobalVariable(
+      NewRegister, PtrType,
+      getLinkStringForBuiltIn(SPIRV::BuiltIn::GlobalInvocationId), nullptr,
+      SPIRV::StorageClass::Input, nullptr, true, true,
+      SPIRV::LinkageType::Import, MIRBuilder, false);
+
+  // Create new register for loading value.
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  Register LoadedRegister = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setType(LoadedRegister, LLT::pointer(0, 32));
+  GR.assignSPIRVTypeToVReg(Vec3Ty, LoadedRegister, MIRBuilder.getMF());
+
+  // Load v3uint value from the global variable.
+  BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
+      .addDef(LoadedRegister)
+      .addUse(GR.getSPIRVTypeID(Vec3Ty))
+      .addUse(Variable);
+
+  // Get Thread ID index. Expecting operand is a constant immediate value,
+  // wrapped in a type assignment.
+  assert(I.getOperand(2).isReg());
+  Register ThreadIdReg = I.getOperand(2).getReg();
+  SPIRVType *ConstTy = this->MRI->getVRegDef(ThreadIdReg);
+  assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE &&
+         ConstTy->getOperand(1).isReg());
+  Register ConstReg = ConstTy->getOperand(1).getReg();
+  const MachineInstr *Const = this->MRI->getVRegDef(ConstReg);
+  assert(Const && Const->getOpcode() == TargetOpcode::G_CONSTANT);
+  const llvm::APInt &Val = Const->getOperand(1).getCImm()->getValue();
+  const uint32_t ThreadId = Val.getZExtValue();
+
+  // Extract the thread ID from the loaded vector value.
+  MachineBasicBlock &BB = *I.getParent();
+  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType))
+                 .addUse(LoadedRegister)
+                 .addImm(ThreadId);
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
 namespace llvm {
 InstructionSelector *
 createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll
new file mode 100644
index 00000000000000..4915c0d3277075
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll
@@ -0,0 +1,76 @@
+; RUN: llc -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; This file generated from the following HLSL:
+; clang -cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -finclude-default-header -o - DispatchThreadID.hlsl
+;
+; [shader("compute")]
+; [numthreads(1,1,1)]
+; void main(uint3 ID : SV_DispatchThreadID) {}
+
+; CHECK-DAG:        %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG:        %[[#v3int:]] = OpTypeVector %[[#int]] 3
+; CHECK-DAG:        %[[#ptr_Input_v3int:]] = OpTypePointer Input %[[#v3int]]
+; CHECK-DAG:        %[[#tempvar:]] = OpUndef %[[#v3int]]
+; CHECK-DAG:        %[[#GlobalInvocationId:]] = OpVariable %[[#ptr_Input_v3int]] Input
+
+; CHECK-DAG:        OpEntryPoint GLCompute {{.*}} %[[#GlobalInvocationId]]
+; CHECK-DAG:        OpName %[[#GlobalInvocationId]] "__spirv_BuiltInGlobalInvocationId"
+; CHECK-DAG:        OpDecorate %[[#GlobalInvocationId]] LinkageAttributes "__spirv_BuiltInGlobalInvocationId" Import
+; CHECK-DAG:        OpDecorate %[[#GlobalInvocationId]] BuiltIn GlobalInvocationId
+
+; ModuleID = 'DispatchThreadID.hlsl'
+source_filename = "DispatchThreadID.hlsl"
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spirv-unknown-vulkan-library"
+
+; Function Attrs: noinline norecurse nounwind optnone
+define internal spir_func void @main(<3 x i32> noundef %ID) #0 {
+entry:
+  %ID.addr = alloca <3 x i32>, align 16
+  store <3 x i32> %ID, ptr %ID.addr, align 16
+  ret void
+}
+
+; Function Attrs: norecurse
+define void @main.1() #1 {
+entry:
+
+; CHECK:        %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
+; CHECK:        %[[#load0:]] = OpCompositeExtract %[[#int]] %[[#load]] 0
+  %0 = call i32 @llvm.dx.thread.id(i32 0)
+
+; CHECK:        %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load0]] %[[#tempvar]] 0
+  %1 = insertelement <3 x i32> poison, i32 %0, i64 0
+
+; CHECK:        %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
+; CHECK:        %[[#load1:]] = OpCompositeExtract %[[#int]] %[[#load]] 1
+  %2 = call i32 @llvm.dx.thread.id(i32 1)
+
+; CHECK:        %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load1]] %[[#tempvar]] 1
+  %3 = insertelement <3 x i32> %1, i32 %2, i64 1
+
+; CHECK:        %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
+; CHECK:        %[[#load2:]] = OpCompositeExtract %[[#int]] %[[#load]] 2
+  %4 = call i32 @llvm.dx.thread.id(i32 2)
+
+; CHECK:        %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load2]] %[[#tempvar]] 2
+  %5 = insertelement <3 x i32> %3, i32 %4, i64 2
+
+  call void @main(<3 x i32> %5)
+  ret void
+}
+
+; Function Attrs: nounwind willreturn memory(none)
+declare i32 @llvm.dx.thread.id(i32) #2
+
+attributes #0 = { noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+!2 = !{!"clang version 19.0.0git (git@github.com:llvm/llvm-project.git c9afeaa6434a61b3b3a57c8eda6d2cfb25ab675b)"}

Copy link
Contributor

@Keenuts Keenuts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Just a comment for future work, otherwise LGTM

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:codegen HLSL HLSL Language Support llvm:ir labels Feb 26, 2024
sudonatalie added a commit to sudonatalie/llvm-project that referenced this pull request Feb 26, 2024
Noticed while implementing llvm#82536 that this test was also missing the
call the FileCheck.
sudonatalie added a commit that referenced this pull request Feb 27, 2024
Noticed while implementing #82536 that this test was also missing the
call the FileCheck.
Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I don't like how we're choosing an intrinsic but changing that can come in as a follow up if you like.

clang/lib/CodeGen/CGHLSLRuntime.cpp Show resolved Hide resolved
Add SPIR-V backend support for the HLSL SV_DispatchThreadID semantic
attribute, which is lowered to a @llvm.dx.thread.id intrinsic.

Fixes llvm#82534
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:SPIR-V clang:codegen clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[HLSL][SPIR-V] Add SV_DispatchThreadID semantic support
7 participants