Conversation

Contributor

@AlexVlx AlexVlx commented Dec 1, 2025

This adds support for the `SPV_NV_shader_atomic_fp16_vector` extension, and then uses it to enable lowering of atomic add, sub, min, and max on 2- and 4-component vectors of FP16, which are rather common in ML workloads. Even though `bfloat16` also works in practice, we do not enable it, since it's not specified in the extension (which might need updating / promoting to KHR, at least). A `TODO` is also inserted in `SPIRVModuleAnalysis.cpp` regarding the need to upgrade its ample usage of `report_fatal_error`; I have a WiP patch for that, but it still needs a bit of baking. Finally, a paired patch will be necessary in the Translator, as it's not aware of the extension either - I'll update this review to reference the PR once I create it.
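
For illustration, here is a minimal sketch of the kind of IR this change accepts, distilled from the tests added in this patch (the global and function names are illustrative, not taken from the tests):

```llvm
; Compile with the extension enabled, mirroring the RUN line in the new tests:
;   llc -O0 -mtriple=spirv64-unknown-unknown \
;       --spirv-ext=+SPV_NV_shader_atomic_fp16_vector in.ll -o -

@buf = addrspace(1) global <2 x half> zeroinitializer

define spir_func void @vec_atomic_fadd() {
entry:
  ; Lowered to OpAtomicFAddEXT on a 2-component FP16 vector, which requires
  ; the AtomicFloat16VectorNV capability declared by this extension.
  %old = atomicrmw fadd ptr addrspace(1) @buf, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
  ret void
}
```

Without the `--spirv-ext` flag, module analysis rejects such IR with the fatal usage error shown in the tests below.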

Member

llvmbot commented Dec 1, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Alex Voicu (AlexVlx)

Changes

This adds support for the `SPV_NV_shader_atomic_fp16_vector` extension, and then uses it to enable lowering of atomic add, sub, min, and max on 2- and 4-component vectors of FP16, which are rather common in ML workloads. Even though `bfloat16` also works in practice, we do not enable it, since it's not specified in the extension (which might need updating / promoting to KHR, at least). A `TODO` is also inserted in `SPIRVModuleAnalysis.cpp` regarding the need to upgrade its ample usage of `report_fatal_error`; I have a WiP patch for that, but it still needs a bit of baking. Finally, a paired patch will be necessary in the Translator, as it's not aware of the extension either - I'll update this review to reference the PR once I create it.


Full diff: https://github.com/llvm/llvm-project/pull/170213.diff

8 Files Affected:

  • (modified) llvm/docs/SPIRVUsage.rst (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+2-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+6-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+41)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+3)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll (+47)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll (+45)
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 88164e6fa53d8..e2f85ba3c2774 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -169,6 +169,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
      - Adds atomic min and max instruction on floating-point numbers.
    * - ``SPV_INTEL_16bit_atomics``
      - Extends the SPV_EXT_shader_atomic_float_add and SPV_EXT_shader_atomic_float_min_max to support addition, minimum and maximum on 16-bit `bfloat16` floating-point numbers in memory.
+   * - ``SPV_NV_shader_atomic_fp16_vector``
+     - Adds atomic add, min and max instructions on 2- or 4-component vectors with 16-bit float components.
    * - ``SPV_INTEL_2d_block_io``
      - Adds additional subgroup block prefetch, load, load transposed, load transformed and store instructions to read two-dimensional blocks of data from a two-dimensional region of memory, or to write two-dimensional blocks of data to a two dimensional region of memory.
    * - ``SPV_ALTERA_arbitrary_precision_integers``
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 146384f4bf08c..d2a8fddc5d8e4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -31,6 +31,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
         {"SPV_INTEL_16bit_atomics",
          SPIRV::Extension::Extension::SPV_INTEL_16bit_atomics},
+        {"SPV_NV_shader_atomic_fp16_vector",
+         SPIRV::Extension::Extension::SPV_NV_shader_atomic_fp16_vector},
         {"SPV_EXT_arithmetic_fence",
          SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
         {"SPV_EXT_demote_to_helper_invocation",
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index a2e29366dc4cc..c0c60b839a2a9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1193,7 +1193,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_ATOMICRMW_FSUB:
     // Translate G_ATOMICRMW_FSUB to OpAtomicFAddEXT with negative value operand
     return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT,
-                           SPIRV::OpFNegate);
+                           ResType->getOpcode() == SPIRV::OpTypeVector
+                              ? SPIRV::OpFNegateV : SPIRV::OpFNegate);
   case TargetOpcode::G_ATOMICRMW_FMIN:
     return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMinEXT);
   case TargetOpcode::G_ATOMICRMW_FMAX:
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 9d5a685fdbc84..4bc86b0168f1e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -131,6 +131,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
       v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
 
+  auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16};
+
   auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,  p2,  p3,
                                        p4, p5,  p6,  p7,  p8, p10, p11, p12};
 
@@ -339,10 +341,12 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   getActionDefinitionsBuilder(
       {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
-      .legalForCartesianProduct(allFloatScalars, allPtrs);
+      .legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s,
+                                allPtrs);
 
   getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
-      .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);
+      .legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s,
+                                allPtrs);
 
   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
   // TODO: add proper legalization rules.
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 2feb73d8dedfa..73432279c3306 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -14,6 +14,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+// TODO: uses of report_fatal_error (which is also deprecated) /
+//       ReportFatalUsageError in this file should be refactored, as per LLVM
+//       best practices, to rely on the Diagnostic infrastructure.
+
 #include "SPIRVModuleAnalysis.h"
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
@@ -1071,6 +1075,39 @@ static bool isBFloat16Type(const SPIRVType *TypeDef) {
 #define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
   "The atomic float instruction requires the following SPIR-V "                \
   "extension: SPV_EXT_shader_atomic_float" ExtName
+static void AddAtomicVectorFloatRequirements(const MachineInstr &MI,
+                                             SPIRV::RequirementHandler &Reqs,
+                                             const SPIRVSubtarget &ST) {
+  SPIRVType *VecTypeDef =
+      MI.getMF()->getRegInfo().getVRegDef(MI.getOperand(1).getReg());
+
+  const unsigned Rank = VecTypeDef->getOperand(2).getImm();
+  if (Rank != 2 && Rank != 4)
+    reportFatalUsageError("Result type of an atomic vector float instruction "
+                          "must be a 2-component or 4-component vector");
+
+  SPIRVType *EltTypeDef =
+      MI.getMF()->getRegInfo().getVRegDef(VecTypeDef->getOperand(1).getReg());
+
+  if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat ||
+      EltTypeDef->getOperand(1).getImm() != 16)
+    reportFatalUsageError(
+        "The element type for the result type of an atomic vector float "
+        "instruction must be a 16-bit floating-point scalar");
+
+  if (isBFloat16Type(EltTypeDef))
+    reportFatalUsageError(
+        "The element type for the result type of an atomic vector float "
+        "instruction cannot be a bfloat16 scalar");
+  if (!ST.canUseExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector))
+    reportFatalUsageError(
+        "The atomic float16 vector instruction requires the following SPIR-V "
+        "extension: SPV_NV_shader_atomic_fp16_vector");
+
+  Reqs.addExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector);
+  Reqs.addCapability(SPIRV::Capability::AtomicFloat16VectorNV);
+}
+
 static void AddAtomicFloatRequirements(const MachineInstr &MI,
                                        SPIRV::RequirementHandler &Reqs,
                                        const SPIRVSubtarget &ST) {
@@ -1078,6 +1115,10 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
          "Expect register operand in atomic float instruction");
   Register TypeReg = MI.getOperand(1).getReg();
   SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
+
+  if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
+    return AddAtomicVectorFloatRequirements(MI, Reqs, ST);
+
   if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
     report_fatal_error("Result type of an atomic float instruction must be a "
                        "floating-point type scalar");
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 94e0138c66487..078f1dff839ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -391,6 +391,8 @@ defm SPV_INTEL_bfloat16_arithmetic
     : ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_16bit_atomics : ExtensionOperand<130, [EnvVulkan, EnvOpenCL]>;
 defm SPV_ALTERA_arbitrary_precision_fixed_point : ExtensionOperand<131, [EnvOpenCL, EnvVulkan]>;
+defm SPV_NV_shader_atomic_fp16_vector
+    : ExtensionOperand<132, [EnvVulkan, EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -573,6 +575,7 @@ defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atom
 defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm AtomicBFloat16MinMaxINTEL : CapabilityOperand<6256, 0, 0, [SPV_INTEL_16bit_atomics], []>;
+defm AtomicFloat16VectorNV : CapabilityOperand<5404, 0, 0, [SPV_NV_shader_atomic_fp16_vector], []>;
 defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
 defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll
new file mode 100644
index 0000000000000..36f6e38fc75de
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll
@@ -0,0 +1,47 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_NV_shader_atomic_fp16_vector %s -o - | FileCheck %s
+
+; CHECK-ERROR: LLVM ERROR: The atomic float16 vector instruction requires the following SPIR-V extension: SPV_NV_shader_atomic_fp16_vector
+
+; CHECK: Capability Float16
+; CHECK-DAG: Capability AtomicFloat16VectorNV
+; CHECK: Extension "SPV_NV_shader_atomic_fp16_vector"
+; CHECK-DAG: %[[TyF16:[0-9]+]] = OpTypeFloat 16
+; CHECK: %[[TyF16Vec2:[0-9]+]] = OpTypeVector %[[TyF16]] 2
+; CHECK: %[[TyF16Vec4:[0-9]+]] = OpTypeVector %[[TyF16]] 4
+; CHECK: %[[TyF16Vec4Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec4]]
+; CHECK: %[[TyF16Vec2Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec2]]
+; CHECK: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
+; CHECK: %[[ConstF16:[0-9]+]] = OpConstant %[[TyF16]] 20800{{$}}
+; CHECK: %[[Const0F16Vec2:[0-9]+]] = OpConstantNull %[[TyF16Vec2]]
+; CHECK: %[[f:[0-9]+]] = OpVariable %[[TyF16Vec2Ptr]] CrossWorkgroup %[[Const0F16Vec2]]
+; CHECK: %[[Const0F16Vec4:[0-9]+]] = OpConstantNull %[[TyF16Vec4]]
+; CHECK: %[[g:[0-9]+]] = OpVariable %[[TyF16Vec4Ptr]] CrossWorkgroup %[[Const0F16Vec4]]
+; CHECK: %[[ConstF16Vec2:[0-9]+]] = OpConstantComposite %[[TyF16Vec2]] %[[ConstF16]] %[[ConstF16]]
+; CHECK: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
+; CHECK: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
+; CHECK: %[[ConstF16Vec4:[0-9]+]] = OpConstantComposite %[[TyF16Vec4]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]]
+
+@f = common dso_local local_unnamed_addr addrspace(1) global <2 x half> <half 0.000000e+00, half 0.000000e+00>
+@g = common dso_local local_unnamed_addr addrspace(1) global <4 x half> <half 0.000000e+00, half 0.000000e+00, half 0.000000e+00, half 0.000000e+00>
+
+; CHECK-DAG: OpAtomicFAddEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]]
+; CHECK: %[[NegatedConstF16Vec2:[0-9]+]] = OpFNegate %[[TyF16Vec2]] %[[ConstF16Vec2]]
+; CHECK: OpAtomicFAddEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstF16Vec2]]
+define dso_local spir_func void @test1() local_unnamed_addr {
+entry:
+  %addval = atomicrmw fadd ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
+  %subval = atomicrmw fsub ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
+  ret void
+}
+
+; CHECK-DAG: OpAtomicFAddEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]]
+; CHECK: %[[NegatedConstF16Vec4:[0-9]+]] = OpFNegate %[[TyF16Vec4]] %[[ConstF16Vec4]]
+; CHECK: OpAtomicFAddEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstF16Vec4]]
+define dso_local spir_func void @test2() local_unnamed_addr {
+entry:
+  %addval = atomicrmw fadd ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
+  %subval = atomicrmw fsub ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
+  ret void
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll
new file mode 100644
index 0000000000000..7ac772bf5d094
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll
@@ -0,0 +1,45 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_NV_shader_atomic_fp16_vector %s -o - | FileCheck %s
+
+; CHECK-ERROR: LLVM ERROR: The atomic float16 vector instruction requires the following SPIR-V extension: SPV_NV_shader_atomic_fp16_vector
+
+; CHECK: Capability Float16
+; CHECK-DAG: Capability AtomicFloat16VectorNV
+; CHECK: Extension "SPV_NV_shader_atomic_fp16_vector"
+; CHECK-DAG: %[[TyF16:[0-9]+]] = OpTypeFloat 16
+; CHECK: %[[TyF16Vec2:[0-9]+]] = OpTypeVector %[[TyF16]] 2
+; CHECK: %[[TyF16Vec4:[0-9]+]] = OpTypeVector %[[TyF16]] 4
+; CHECK: %[[TyF16Vec4Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec4]]
+; CHECK: %[[TyF16Vec2Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec2]]
+; CHECK: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
+; CHECK: %[[ConstF16:[0-9]+]] = OpConstant %[[TyF16]] 20800{{$}}
+; CHECK: %[[Const0F16Vec2:[0-9]+]] = OpConstantNull %[[TyF16Vec2]]
+; CHECK: %[[f:[0-9]+]] = OpVariable %[[TyF16Vec2Ptr]] CrossWorkgroup %[[Const0F16Vec2]]
+; CHECK: %[[Const0F16Vec4:[0-9]+]] = OpConstantNull %[[TyF16Vec4]]
+; CHECK: %[[g:[0-9]+]] = OpVariable %[[TyF16Vec4Ptr]] CrossWorkgroup %[[Const0F16Vec4]]
+; CHECK: %[[ConstF16Vec2:[0-9]+]] = OpConstantComposite %[[TyF16Vec2]] %[[ConstF16]] %[[ConstF16]]
+; CHECK: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
+; CHECK: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
+; CHECK: %[[ConstF16Vec4:[0-9]+]] = OpConstantComposite %[[TyF16Vec4]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]]
+
+@f = common dso_local local_unnamed_addr addrspace(1) global <2 x half> <half 0.000000e+00, half 0.000000e+00>
+@g = common dso_local local_unnamed_addr addrspace(1) global <4 x half> <half 0.000000e+00, half 0.000000e+00, half 0.000000e+00, half 0.000000e+00>
+
+; CHECK-DAG: OpAtomicFMinEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]]
+; CHECK: OpAtomicFMaxEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]]
+define dso_local spir_func void @test1() local_unnamed_addr {
+entry:
+  %minval = atomicrmw fmin ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
+  %maxval = atomicrmw fmax ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
+  ret void
+}
+
+; CHECK-DAG: OpAtomicFMinEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]]
+; CHECK: OpAtomicFMaxEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]]
+define dso_local spir_func void @test2() local_unnamed_addr {
+entry:
+  %minval = atomicrmw fmin ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
+  %maxval = atomicrmw fmax ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
+  ret void
+}
\ No newline at end of file

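One detail worth calling out from the instruction-selector change above: SPIR-V has no atomic FSub opcode, so `G_ATOMICRMW_FSUB` is emitted as a negate of the value operand followed by `OpAtomicFAddEXT`; for the vector case the selector now picks the vector negate pseudo (`OpFNegateV`), which the tests show printing as plain `OpFNegate`. A condensed restatement of what the fadd/fsub test checks (ids are illustrative):

```llvm
@f = addrspace(1) global <2 x half> zeroinitializer

define spir_func void @fsub_example() {
entry:
  ; Expected SPIR-V shape, per the CHECK lines above (ids illustrative):
  ;   %neg = OpFNegate %v2f16 %val
  ;   %old = OpAtomicFAddEXT %v2f16 %f %scope %semantics %neg
  %subval = atomicrmw fsub ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
  ret void
}
```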

github-actions bot commented Dec 1, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor Author

AlexVlx commented Dec 1, 2025

@YixingZhang007 and @vmaksimo: I still cannot seem to add either of you as a reviewer, but as with the bfloat16 bit, I assume this might be of interest.


github-actions bot commented Dec 1, 2025

🐧 Linux x64 Test Results

  • 186816 tests passed
  • 4893 tests skipped

✅ The build succeeded and all tests passed.

Contributor

@jmmartinez jmmartinez left a comment

LGTM

@AlexVlx AlexVlx merged commit 93d64a5 into llvm:main Dec 5, 2025
12 checks passed
@AlexVlx AlexVlx deleted the spirv_be_staging_10 branch December 5, 2025 20:23
honeygoyal pushed a commit to honeygoyal/llvm-project that referenced this pull request Dec 9, 2025
…atomic_fp16_vector` (llvm#170213)

This adds support for the `SPV_NV_shader_atomic_fp16_vector` extension,
and then uses it to enable lowering of atomic add, sub, min and max on 2
and 4 component vectors of FP16, which are rather common options in ML
workloads. Even though `bfloat16` also works in practice, we do not
enable it since it's not specified in the extension (which might need
updating / promoting to KHR at least). A `TODO` is also inserted in
`SPIRVModuleAnalysis.cpp` regarding the need to upgrade its ample usage
of `report_fatal_error`; I have a WiP patch for that, but it still needs
a bit of baking. Finally, a paired patch will be necessary in the
Translator, as it's not aware of the extension either - I'll update this
review to reference the PR once I create it.
