-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[SPIRV] Add support for bfloat16 atomics via the SPV_INTEL_16bit_atomics extension
#166257
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-globalisel @llvm/pr-subscribers-backend-spir-v Author: Alex Voicu (AlexVlx) ChangesThis enables support for atomic RMW ops (add, sub, min and max to be precise) with Full diff: https://github.com/llvm/llvm-project/pull/166257.diff 7 Files Affected:
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 85eeabf10244a..eaa7693ad87e5 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -167,6 +167,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
- Adds atomic add instruction on floating-point numbers.
* - ``SPV_EXT_shader_atomic_float_min_max``
- Adds atomic min and max instruction on floating-point numbers.
+ * - ``SPV_INTEL_16bit_atomics``
+ - Extends the SPV_EXT_shader_atomic_float_add and SPV_EXT_shader_atomic_float_min_max to support addition, minimum and maximum on 16-bit `bfloat16` floating-point numbers in memory.
* - ``SPV_INTEL_2d_block_io``
- Adds additional subgroup block prefetch, load, load transposed, load transformed and store instructions to read two-dimensional blocks of data from a two-dimensional region of memory, or to write two-dimensional blocks of data to a two dimensional region of memory.
* - ``SPV_INTEL_arbitrary_precision_integers``
@@ -226,9 +228,9 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
* - ``SPV_INTEL_fp_max_error``
- Adds the ability to specify the maximum error for floating-point operations.
* - ``SPV_INTEL_ternary_bitwise_function``
- - Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
+ - Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
* - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
- - Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
+ - Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
* - ``SPV_INTEL_int4``
- Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
* - ``SPV_KHR_float_controls2``
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 1fc90d0852aad..6a1da45de9eae 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -3453,7 +3453,7 @@ bool IRTranslator::translateAtomicCmpXchg(const User &U,
bool IRTranslator::translateAtomicRMW(const User &U,
MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
+ if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
return false;
const AtomicRMWInst &I = cast<AtomicRMWInst>(U);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 96f5dee21bc2a..03593cf65237d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -29,6 +29,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float16_add},
{"SPV_EXT_shader_atomic_float_min_max",
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
+ {"SPV_INTEL_16bit_atomics",
+ SPIRV::Extension::Extension::SPV_INTEL_16bit_atomics},
{"SPV_EXT_arithmetic_fence",
SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
{"SPV_EXT_demote_to_helper_invocation",
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index db036a55ee6c6..38948050b89d6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1058,6 +1058,13 @@ static void addOpTypeImageReqs(const MachineInstr &MI,
}
}
+static bool isBFloat16Type(const SPIRVType *TypeDef) {
+ return TypeDef && TypeDef->getNumOperands() == 3 &&
+ TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+ TypeDef->getOperand(1).getImm() == 16 &&
+ TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
+}
+
// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
"The atomic float instruction requires the following SPIR-V " \
@@ -1081,11 +1088,20 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
switch (BitWidth) {
case 16:
- if (!ST.canUseExtension(
- SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
- report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
- Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
- Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+ if (isBFloat16Type(TypeDef)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
+ report_fatal_error(
+ "The atomic bfloat16 instruction requires the following SPIR-V "
+ "extension: SPV_INTEL_16bit_atomics", false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
+ Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL);
+ } else {
+ if (!ST.canUseExtension(
+ SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
+ report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
+ Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+ }
break;
case 32:
Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
@@ -1104,7 +1120,16 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
switch (BitWidth) {
case 16:
- Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+ if (isBFloat16Type(TypeDef)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
+ report_fatal_error(
+ "The atomic bfloat16 instruction requires the following SPIR-V "
+ "extension: SPV_INTEL_16bit_atomics", false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
+ Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
+ } else {
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+ }
break;
case 32:
Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
@@ -1328,13 +1353,6 @@ void addPrintfRequirements(const MachineInstr &MI,
}
}
-static bool isBFloat16Type(const SPIRVType *TypeDef) {
- return TypeDef && TypeDef->getNumOperands() == 3 &&
- TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
- TypeDef->getOperand(1).getImm() == 16 &&
- TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
-}
-
void addInstrRequirements(const MachineInstr &MI,
SPIRV::ModuleAnalysisInfo &MAI,
const SPIRVSubtarget &ST) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 7d08b29a51a6e..8257963d9842e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -387,6 +387,7 @@ defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>;
defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>;
+defm SPV_INTEL_16bit_atomics : ExtensionOperand<130, [EnvVulkan, EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -564,9 +565,11 @@ defm FloatControls2
defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>;
+defm AtomicBFloat16AddINTEL : CapabilityOperand<6255, 0, 0, [SPV_INTEL_16bit_atomics], []>;
defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
+defm AtomicBFloat16MinMaxINTEL : CapabilityOperand<6256, 0, 0, [SPV_INTEL_16bit_atomics], []>;
defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
@@ -1919,7 +1922,7 @@ defm GenericCastToPtr : SpecConstantOpOperandsOperand<122, [], [Kernel]>;
defm PtrCastToGeneric : SpecConstantOpOperandsOperand<121, [], [Kernel]>;
defm Bitcast : SpecConstantOpOperandsOperand<124, [], []>;
defm QuantizeToF16 : SpecConstantOpOperandsOperand<116, [], [Shader]>;
-// Arithmetic
+// Arithmetic
defm SNegate : SpecConstantOpOperandsOperand<126, [], []>;
defm Not : SpecConstantOpOperandsOperand<200, [], []>;
defm IAdd : SpecConstantOpOperandsOperand<128, [], []>;
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll
new file mode 100644
index 0000000000000..6a27251fa756a
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll
@@ -0,0 +1,34 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR1
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add,+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR2
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add,+SPV_INTEL_16bit_atomics,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+
+; CHECK-ERROR1: LLVM ERROR: The atomic float instruction requires the following SPIR-V extension: SPV_EXT_shader_atomic_float_add
+; CHECK-ERROR2: LLVM ERROR: The atomic bfloat16 instruction requires the following SPIR-V extension: SPV_INTEL_16bit_atomics
+
+; CHECK: Capability BFloat16TypeKHR
+; CHECK: Capability AtomicBFloat16AddINTEL
+; CHECK: Extension "SPV_KHR_bfloat16"
+; CHECK: Extension "SPV_EXT_shader_atomic_float_add"
+; CHECK: Extension "SPV_INTEL_16bit_atomics"
+; CHECK-DAG: %[[TyBF16:[0-9]+]] = OpTypeFloat 16 0
+; CHECK-DAG: %[[TyBF16Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyBF16]]
+; CHECK-DAG: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: %[[ConstBF16:[0-9]+]] = OpConstant %[[TyBF16]] 16936{{$}}
+; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstantNull %[[TyBF16]]
+; CHECK-DAG: %[[BF16Ptr:[0-9]+]] = OpVariable %[[TyBF16Ptr]] CrossWorkgroup %[[Const0]]
+; CHECK-DAG: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
+; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
+; CHECK: OpAtomicFAddEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
+; CHECK: %[[NegatedConstBF16:[0-9]+]] = OpFNegate %[[TyBF16]] %[[ConstBF16]]
+; CHECK: OpAtomicFAddEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstBF16]]
+
+
+@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
+
+define dso_local spir_func void @test1() local_unnamed_addr {
+entry:
+ %addval = atomicrmw fadd ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
+ %subval = atomicrmw fsub ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll
new file mode 100644
index 0000000000000..507135e0ed783
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll
@@ -0,0 +1,28 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_min_max,+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_min_max,+SPV_INTEL_16bit_atomics,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+
+; CHECK-ERROR: LLVM ERROR: The atomic bfloat16 instruction requires the following SPIR-V extension: SPV_INTEL_16bit_atomics
+
+; CHECK: Capability AtomicBFloat16MinMaxINTEL
+; CHECK: Extension "SPV_KHR_bfloat16"
+; CHECK: Extension "SPV_EXT_shader_atomic_float_min_max"
+; CHECK: Extension "SPV_INTEL_16bit_atomics"
+; CHECK-DAG: %[[TyBF16:[0-9]+]] = OpTypeFloat 16 0
+; CHECK-DAG: %[[TyBF16Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyBF16]]
+; CHECK-DAG: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: %[[ConstBF16:[0-9]+]] = OpConstant %[[TyBF16]] 16936{{$}}
+; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstantNull %[[TyBF16]]
+; CHECK-DAG: %[[BF16Ptr:[0-9]+]] = OpVariable %[[TyBF16Ptr]] CrossWorkgroup %[[Const0]]
+; CHECK-DAG: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
+; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
+; CHECK: OpAtomicFMinEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
+; CHECK: OpAtomicFMaxEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
+
+@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
+
+define dso_local spir_func void @test1() local_unnamed_addr {
+entry:
+ %minval = atomicrmw fmin ptr addrspace(1) @f, bfloat 42.0e+00 seq_cst
+ %maxval = atomicrmw fmax ptr addrspace(1) @f, bfloat 42.0e+00 seq_cst
+ ret void
+}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
@YixingZhang007 for some reason I cannot add you as a reviewer, however I suspect this might be of interest. |
| Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add); | ||
| Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT); | ||
| if (isBFloat16Type(TypeDef)) { | ||
| if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pass should probably not be using fatal errors for anything
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC there is a later mechanism which checks which extension is allowed for the current target and raises an "SPIR-V module cannot be satisfied" error if an extension/capability is showing but is not allowed.
| ; CHECK-ERROR1: LLVM ERROR: The atomic float instruction requires the following SPIR-V extension: SPV_EXT_shader_atomic_float_add | ||
| ; CHECK-ERROR2: LLVM ERROR: The atomic bfloat16 instruction requires the following SPIR-V extension: SPV_INTEL_16bit_atomics |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's not available it should go through the AtomicExpand legalization
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this pass enabled by default for all the pipelines?
At least if it's not for SPIR-V BE, then the reported error seems more legit in this case.
llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll
Outdated
Show resolved
Hide resolved
…micrmw_fminfmax_bfloat16.ll Co-authored-by: Matt Arsenault <arsenm2@gmail.com>
Keenuts
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM (aside from what's already raised)
| Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add); | ||
| Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT); | ||
| if (isBFloat16Type(TypeDef)) { | ||
| if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC there is a later mechanism which checks which extension is allowed for the current target and raises an "SPIR-V module cannot be satisfied" error if an extension/capability is showing but is not allowed.
AlexVlx
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC there is a later mechanism which checks which extension is allowed for the current target and raises an "SPIR-V module cannot be satisfied" error if an extension/capability is showing but is not allowed.
Right, but that is somewhat opaque i.e. the error message at the moment is an approximation of "something happened". We should look at refactoring the pass itself at some point to rather go through the Diag infra, but that feels like a separate PR / piece of work, as currently all of the extension related stuff is done via report_fatal_error.
|
LGTM! Thank you :) |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/95/builds/19050 Here is the relevant piece of the build log for the reference |
This enables support for atomic RMW ops (add, sub, min and max to be precise) with
bfloat16operands, via the SPV_INTEL_16bit_atomics extension. It's logically a successor to #166031 (I should've used a stack), but I'm putting it up for early review.