Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for atomic instruction on floating-point numbers #81683

Conversation

VyacheslavLevytskyy
Copy link
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Feb 13, 2024

This PR adds support for atomic instruction on floating-point numbers:

  • SPV_EXT_shader_atomic_float_add
  • SPV_EXT_shader_atomic_float_min_max
  • SPV_EXT_shader_atomic_float16_add

and fixes asm printer output for half floating-type.

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 13, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR adds support for atomic instruction on floating-point numbers:

  • SPV_EXT_shader_atomic_float_add
  • SPV_EXT_shader_atomic_float_min_max
  • SPV_EXT_shader_atomic_float16_add
    and fixes asm printer output for half floating-type.

Patch is 40.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81683.diff

18 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp (+3-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+64-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+39)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.h (+8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+27-4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+66)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp (+17-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+3)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_double.ll (+43)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll (+43)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_half.ll (+46)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_min_max/atomicrmw_fminfmax_double.ll (+45)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_min_max/atomicrmw_fminfmax_float.ll (+45)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_min_max/atomicrmw_fminfmax_half.ll (+45)
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 163b2ec0fefe4d..b468b71cc0efb4 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -13,6 +13,7 @@
 #include "SPIRVInstPrinter.h"
 #include "SPIRV.h"
 #include "SPIRVBaseInfo.h"
+#include "SPIRVInstrInfo.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/CodeGen/Register.h"
 #include "llvm/MC/MCAsmInfo.h"
@@ -50,6 +51,7 @@ void SPIRVInstPrinter::printRemainingVariableOps(const MCInst *MI,
 void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
                                              unsigned StartIndex,
                                              raw_ostream &O) {
+  unsigned IsBitwidth16 = MI->getFlags() & SPIRV::ASM_PRINTER_WIDTH16;
   const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
 
   assert((NumVarOps == 1 || NumVarOps == 2) &&
@@ -65,7 +67,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
   }
 
   // Format and print float values.
-  if (MI->getOpcode() == SPIRV::OpConstantF) {
+  if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
     APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
                                 : APFloat(APInt(64, Imm).bitsToDouble());
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 8721b900c8beee..545891b9ba6d73 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -93,6 +93,14 @@ struct IntelSubgroupsBuiltin {
 #define GET_IntelSubgroupsBuiltins_DECL
 #define GET_IntelSubgroupsBuiltins_IMPL
 
+struct AtomicFloatingBuiltin {
+  StringRef Name;
+  uint32_t Opcode;
+};
+
+#define GET_AtomicFloatingBuiltins_DECL
+#define GET_AtomicFloatingBuiltins_IMPL
+
 struct GetBuiltin {
   StringRef Name;
   InstructionSet::InstructionSet Set;
@@ -402,7 +410,7 @@ getSPIRVMemSemantics(std::memory_order MemOrder) {
   case std::memory_order::memory_order_seq_cst:
     return SPIRV::MemorySemantics::SequentiallyConsistent;
   default:
-    llvm_unreachable("Unknown CL memory scope");
+    report_fatal_error("Unknown CL memory scope");
   }
 }
 
@@ -419,7 +427,7 @@ static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
   case SPIRV::CLMemoryScope::memory_scope_sub_group:
     return SPIRV::Scope::Subgroup;
   }
-  llvm_unreachable("Unknown CL memory scope");
+  report_fatal_error("Unknown CL memory scope");
 }
 
 static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
@@ -676,6 +684,38 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
   return true;
 }
 
+/// Helper function for building an atomic floating-type instruction.
+static bool buildAtomicFloatingRMWInst(const SPIRV::IncomingCall *Call,
+                                       unsigned Opcode,
+                                       MachineIRBuilder &MIRBuilder,
+                                       SPIRVGlobalRegistry *GR) {
+  assert(Call->Arguments.size() == 4 &&
+         "Wrong number of atomic floating-type builtin");
+
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+
+  Register PtrReg = Call->Arguments[0];
+  MRI->setRegClass(PtrReg, &SPIRV::IDRegClass);
+
+  Register ScopeReg = Call->Arguments[1];
+  MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
+
+  Register MemSemanticsReg = Call->Arguments[2];
+  MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
+
+  Register ValueReg = Call->Arguments[3];
+  MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
+
+  MIRBuilder.buildInstr(Opcode)
+      .addDef(Call->ReturnRegister)
+      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+      .addUse(PtrReg)
+      .addUse(ScopeReg)
+      .addUse(MemSemanticsReg)
+      .addUse(ValueReg);
+  return true;
+}
+
 /// Helper function for building atomic flag instructions (e.g.
 /// OpAtomicFlagTestAndSet).
 static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
@@ -786,7 +826,7 @@ static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) {
   case SPIRV::Dim::DIM_3D:
     return 3;
   default:
-    llvm_unreachable("Cannot get num components for given Dim");
+    report_fatal_error("Cannot get num components for given Dim");
   }
 }
 
@@ -1157,6 +1197,23 @@ static bool generateAtomicInst(const SPIRV::IncomingCall *Call,
   }
 }
 
+static bool generateAtomicFloatingInst(const SPIRV::IncomingCall *Call,
+                                       MachineIRBuilder &MIRBuilder,
+                                       SPIRVGlobalRegistry *GR) {
+  // Lookup the instruction opcode in the TableGen records.
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name)->Opcode;
+
+  switch (Opcode) {
+  case SPIRV::OpAtomicFAddEXT:
+  case SPIRV::OpAtomicFMinEXT:
+  case SPIRV::OpAtomicFMaxEXT:
+    return buildAtomicFloatingRMWInst(Call, Opcode, MIRBuilder, GR);
+  default:
+    return false;
+  }
+}
+
 static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
@@ -1311,7 +1368,7 @@ getSamplerAddressingModeFromBitmask(unsigned Bitmask) {
   case SPIRV::CLK_ADDRESS_NONE:
     return SPIRV::SamplerAddressingMode::None;
   default:
-    llvm_unreachable("Unknown CL address mode");
+    report_fatal_error("Unknown CL address mode");
   }
 }
 
@@ -2021,6 +2078,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
     return generateBuiltinVar(Call.get(), MIRBuilder, GR);
   case SPIRV::Atomic:
     return generateAtomicInst(Call.get(), MIRBuilder, GR);
+  case SPIRV::AtomicFloating:
+    return generateAtomicFloatingInst(Call.get(), MIRBuilder, GR);
   case SPIRV::Barrier:
     return generateBarrierInst(Call.get(), MIRBuilder, GR);
   case SPIRV::Dot:
@@ -2089,7 +2148,7 @@ static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
     return Type::getFloatTy(Context);
   else if (Name.starts_with("half"))
     return Type::getHalfTy(Context);
-  llvm_unreachable("Unable to recognize type!");
+  report_fatal_error("Unable to recognize type!");
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 4013dd22f4ab57..92009ae33ef32f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -55,6 +55,7 @@ def AsyncCopy : BuiltinGroup;
 def VectorLoadStore : BuiltinGroup;
 def LoadStore : BuiltinGroup;
 def IntelSubgroups : BuiltinGroup;
+def AtomicFloating : BuiltinGroup;
 
 //===----------------------------------------------------------------------===//
 // Class defining a demangled builtin record. The information in the record
@@ -872,6 +873,44 @@ defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_xors", Wo
 defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
 defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_xor", WorkOrSub, OpGroupNonUniformLogicalXor>;
 
+//===----------------------------------------------------------------------===//
+// Class defining an atomic instruction on floating-point numbers.
+//
+// name is the demangled name of the given builtin.
+// opcode specifies the SPIR-V operation code of the generated instruction.
+//===----------------------------------------------------------------------===//
+class AtomicFloatingBuiltin<string name, Op operation> {
+  string Name = name;
+  Op Opcode = operation;
+}
+
+// Table gathering all the Intel sub group builtins.
+def AtomicFloatingBuiltins : GenericTable {
+  let FilterClass = "AtomicFloatingBuiltin";
+  let Fields = ["Name", "Opcode"];
+}
+
+// Function to lookup group builtins by their name and set.
+def lookupAtomicFloatingBuiltin : SearchIndex {
+  let Table = AtomicFloatingBuiltins;
+  let Key = ["Name"];
+}
+
+// Multiclass used to define incoming builtin records for the SPV_INTEL_subgroups extension
+// and corresponding work/sub group builtin records.
+multiclass DemangledAtomicFloatingBuiltin<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
+  def : DemangledBuiltin<!strconcat("__spirv_AtomicF", name), OpenCL_std, AtomicFloating, minNumArgs, maxNumArgs>;
+  def : AtomicFloatingBuiltin<!strconcat("__spirv_AtomicF", name), operation>;
+}
+
+// SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max, SPV_EXT_shader_atomic_float16_add
+// Atomic add, min and max instruction on floating-point numbers:
+defm : DemangledAtomicFloatingBuiltin<"AddEXT", 4, 4, OpAtomicFAddEXT>;
+defm : DemangledAtomicFloatingBuiltin<"MinEXT", 4, 4, OpAtomicFMinEXT>;
+defm : DemangledAtomicFloatingBuiltin<"MaxEXT", 4, 4, OpAtomicFMaxEXT>;
+// TODO: add support for cl_ext_float_atomics to enable performing atomic operations
+//       on floating-point numbers in memory (float arguments for atomic_fetch_add, ...)
+
 //===----------------------------------------------------------------------===//
 // Class defining a sub group builtin that should be translated into a
 // SPIR-V instruction using the SPV_INTEL_subgroups extension.
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
index c01e30e109bd5b..b9a3425cb5a1c3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
@@ -53,6 +53,14 @@ class SPIRVInstrInfo : public SPIRVGenInstrInfo {
                    bool KillSrc) const override;
   bool expandPostRAPseudo(MachineInstr &MI) const override;
 };
+
+namespace SPIRV {
+enum AsmComments {
+  // It is a half type
+  ASM_PRINTER_WIDTH16 = MachineInstr::TAsmComments
+};
+}; // namespace SPIRV
+
 } // namespace llvm
 
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVINSTRINFO_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 904fef1d6c82f9..7965dd969e5cfa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -643,6 +643,9 @@ def OpAtomicAnd: AtomicOpVal<"OpAtomicAnd", 240>;
 def OpAtomicOr: AtomicOpVal<"OpAtomicOr", 241>;
 def OpAtomicXor: AtomicOpVal<"OpAtomicXor", 242>;
 
+def OpAtomicFAddEXT: AtomicOpVal<"OpAtomicFAddEXT", 6035>;
+def OpAtomicFMinEXT: AtomicOpVal<"OpAtomicFMinEXT", 5614>;
+def OpAtomicFMaxEXT: AtomicOpVal<"OpAtomicFMaxEXT", 5615>;
 
 def OpAtomicFlagTestAndSet: AtomicOp<"OpAtomicFlagTestAndSet", 318>;
 def OpAtomicFlagClear: Op<319, (outs), (ins ID:$ptr, ID:$sc, ID:$sem),
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 52eeb8a523e6f6..79ccf5bee7ff37 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -102,7 +102,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectMemOperation(Register ResVReg, MachineInstr &I) const;
 
   bool selectAtomicRMW(Register ResVReg, const SPIRVType *ResType,
-                       MachineInstr &I, unsigned NewOpcode) const;
+                       MachineInstr &I, unsigned NewOpcode,
+                       unsigned NegateOpcode = 0) const;
 
   bool selectAtomicCmpXchg(Register ResVReg, const SPIRVType *ResType,
                            MachineInstr &I) const;
@@ -489,6 +490,17 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_ATOMIC_CMPXCHG:
     return selectAtomicCmpXchg(ResVReg, ResType, I);
 
+  case TargetOpcode::G_ATOMICRMW_FADD:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT);
+  case TargetOpcode::G_ATOMICRMW_FSUB:
+    // Translate G_ATOMICRMW_FSUB to OpAtomicFAddEXT with negative value operand
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT,
+                           SPIRV::OpFNegate);
+  case TargetOpcode::G_ATOMICRMW_FMIN:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMinEXT);
+  case TargetOpcode::G_ATOMICRMW_FMAX:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMaxEXT);
+
   case TargetOpcode::G_FENCE:
     return selectFence(I);
 
@@ -686,7 +698,8 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
 bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
                                                const SPIRVType *ResType,
                                                MachineInstr &I,
-                                               unsigned NewOpcode) const {
+                                               unsigned NewOpcode,
+                                               unsigned NegateOpcode) const {
   assert(I.hasOneMemOperand());
   const MachineMemOperand *MemOp = *I.memoperands_begin();
   uint32_t Scope = static_cast<uint32_t>(getScope(MemOp->getSyncScopeID()));
@@ -700,14 +713,24 @@ bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
   uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
   Register MemSemReg = buildI32Constant(MemSem /*| ScSem*/, I);
 
-  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
+  bool Result = false;
+  Register ValueReg = I.getOperand(2).getReg();
+  if (NegateOpcode != 0) {
+    // Translation with negative value operand is requested
+    Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    Result |= selectUnOpWithSrc(TmpReg, ResType, I, ValueReg, NegateOpcode);
+    ValueReg = TmpReg;
+  }
+
+  Result |= BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
       .addUse(Ptr)
       .addUse(ScopeReg)
       .addUse(MemSemReg)
-      .addUse(I.getOperand(2).getReg())
+      .addUse(ValueReg)
       .constrainAllUses(TII, TRI, RBI);
+  return Result;
 }
 
 bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 061bc967423712..011a550a7b3d9b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -125,6 +125,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   auto allIntScalars = {s8, s16, s32, s64};
 
+  auto allFloatScalars = {s16, s32, s64};
+
   auto allFloatScalarsAndVectors = {
       s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
       v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
@@ -205,6 +207,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                                G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
       .legalForCartesianProduct(allIntScalars, allWritablePtrs);
 
+  getActionDefinitionsBuilder(
+      {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
+      .legalForCartesianProduct(allFloatScalars, allWritablePtrs);
+
   getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
       .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
index 0fa05d377d9f10..8c6649bf628265 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
@@ -23,6 +23,8 @@ using namespace llvm;
 void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
                              SPIRV::ModuleAnalysisInfo *MAI) const {
   OutMI.setOpcode(MI->getOpcode());
+  // Propagate previously set flags
+  OutMI.setFlags(MI->getAsmPrinterFlags());
   const MachineFunction *MF = MI->getMF();
   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
     const MachineOperand &MO = MI->getOperand(i);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index a18aae1761c834..a716732ea57056 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -704,6 +704,67 @@ static void addOpTypeImageReqs(const MachineInstr &MI,
     Reqs.addRequirements(SPIRV::Capability::ImageBasic);
 }
 
+// Add requirements for handling atomic float instructions
+#define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
+  "The atomic float instruction requires the following SPIR-V "                \
+  "extension: SPV_EXT_shader_atomic_float" ExtName
+static void AddAtomicFloatRequirements(const MachineInstr &MI,
+                                       SPIRV::RequirementHandler &Reqs,
+                                       const SPIRVSubtarget &ST) {
+  assert(MI.getOperand(1).isReg() &&
+         "Expect register operand in atomic float instruction");
+  Register TypeReg = MI.getOperand(1).getReg();
+  SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
+  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
+    report_fatal_error("Result type of an atomic float instruction must be a "
+                       "floating-point type scalar");
+
+  unsigned BitWidth = TypeDef->getOperand(1).getImm();
+  unsigned Op = MI.getOpcode();
+  if (Op == SPIRV::OpAtomicFAddEXT) {
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
+      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
+    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
+    switch (BitWidth) {
+    case 16:
+      if (!ST.canUseExtension(
+              SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
+        report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
+      Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+      break;
+    case 32:
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
+      break;
+    case 64:
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
+      break;
+    default:
+      report_fatal_error(
+          "Unexpected floating-point type width in atomic float instruction");
+    }
+  } else {
+    if (!ST.canUseExtension(
+            SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
+      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
+    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
+    switch (BitWidth) {
+    case 16:
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+      break;
+    case 32:
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
+      break;
+    case 64:
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
+      break;
+    default:
+      report_fatal_error(
+          "Unexpected floating-point type width in atomic float instruction");
+    }
+  }
+}
+
 void addInstrRequirements(const MachineInstr &MI,
                           SPIRV::RequirementHandler &Reqs,
                           const SPIRVSubtarget &ST) {
@@ -976,6 +1037,11 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
     }
     break;
+  case SPIRV::OpAtomicFAddEXT:
+  case SPIRV::OpAtomicFMinEXT:
+  case SPIRV::OpAtomicFMaxEXT:
+    AddAtomicFloatRequirements(MI, Reqs, ST);
+    break;
   default:
     break;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index effedc2f17d351..f8fd6309ce0e72 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -31,12 +31,24 @@ cl::list<SPIRV::Extension::Extension> Extensions(
     "spirv-extensions", cl::desc("SPIR-V extensions"), cl::ZeroOrMore,
     cl::Hidden,
     cl::values(
+        clEnumValN(SPIRV::Extension::SPV_EXT_shader_atomic_float_add,
+                   "SPV_EXT_shader_atomic_float_add",
+                   "Adds atomic add instruction on floating-point numbers."),
+        clEnumValN(
+            SPIRV::Extension::SPV_EXT_shader_atomic_float16_add,
+            "SPV_EXT_shader_atomic_float16_add",
+            "Extends the SPV_EXT_shader_atomic_float_add extension to support "
+            "atomically adding to 16-bit floating-point numbers in memory."),
+        clEnumValN(
+            SPIRV::E...
[truncated]

Copy link

github-actions bot commented Feb 13, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.


; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_EXT_shader_atomic_float_add %s -o - | FileCheck %s

; CHECK-ERROR: LLVM ERROR: The atomic float instruction requires the following SPIR-V extension: SPV_EXT_shader_atomic_float_add
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat idea!

@VyacheslavLevytskyy
Copy link
Contributor Author

SPIR-V Tests failing isn't something connected with the PR, but an unrelated general error:

[2277/2903] Performing download step for 'SPIRVTools' FAILED: tools/spirv-tools/SPIRVTools-prefix/src/SPIRVTools-stamp/SPIRVTools-download /mnt/build/tools/spirv-tools/SPIRVTools-prefix/src/SPIRVTools-stamp/SPIRVTools-download cd /mnt/build/tools/spirv-tools/SPIRVTools-prefix/src && git clone https://github.com/KhronosGroup/SPIRV-Tools.git SPIRVTools && cd SPIRVTools && /__t/Python/3.11.8/x64/bin/python3 utils/git-sync-deps && /usr/bin/cmake -E touch /mnt/build/tools/spirv-tools/SPIRVTools-prefix/src/SPIRVTools-stamp/SPIRVTools-download /bin/sh: 1: git: not found

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 925768e into llvm:main Feb 19, 2024
4 of 5 checks passed
MaskRay added a commit that referenced this pull request Feb 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants