From ad63e7f1c5700c4f6b92a0fb3a7de9d876d38169 Mon Sep 17 00:00:00 2001 From: Alex Voicu Date: Mon, 1 Dec 2025 22:18:24 +0000 Subject: [PATCH 1/4] Add support for some FP16 vector atomics, via the `SPV_NV_shader_atomic_fp16_vector` extension. --- llvm/docs/SPIRVUsage.rst | 2 + llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 2 + .../Target/SPIRV/SPIRVInstructionSelector.cpp | 3 +- llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 8 +++- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 41 ++++++++++++++++ .../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 3 ++ .../atomicrmw_faddfsub_vec_float16.ll | 47 +++++++++++++++++++ .../atomicrmw_fminfmax_vec_float16.ll | 45 ++++++++++++++++++ 8 files changed, 148 insertions(+), 3 deletions(-) create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst index 88164e6fa53d8..e2f85ba3c2774 100644 --- a/llvm/docs/SPIRVUsage.rst +++ b/llvm/docs/SPIRVUsage.rst @@ -169,6 +169,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e - Adds atomic min and max instruction on floating-point numbers. * - ``SPV_INTEL_16bit_atomics`` - Extends the SPV_EXT_shader_atomic_float_add and SPV_EXT_shader_atomic_float_min_max to support addition, minimum and maximum on 16-bit `bfloat16` floating-point numbers in memory. + * - ``SPV_NV_shader_atomic_fp16_vector`` + - Adds atomic add, min and max instructions on 2 or 4-component vectors with 16-bit float components. * - ``SPV_INTEL_2d_block_io`` - Adds additional subgroup block prefetch, load, load transposed, load transformed and store instructions to read two-dimensional blocks of data from a two-dimensional region of memory, or to write two-dimensional blocks of data to a two dimensional region of memory. * - ``SPV_ALTERA_arbitrary_precision_integers`` diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 146384f4bf08c..d2a8fddc5d8e4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -31,6 +31,8 @@ static const std::map> SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max}, {"SPV_INTEL_16bit_atomics", SPIRV::Extension::Extension::SPV_INTEL_16bit_atomics}, + {"SPV_NV_shader_atomic_fp16_vector", + SPIRV::Extension::Extension::SPV_NV_shader_atomic_fp16_vector}, {"SPV_EXT_arithmetic_fence", SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence}, {"SPV_EXT_demote_to_helper_invocation", diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index a2e29366dc4cc..c0c60b839a2a9 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1193,7 +1193,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, case TargetOpcode::G_ATOMICRMW_FSUB: // Translate G_ATOMICRMW_FSUB to OpAtomicFAddEXT with negative value operand return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT, - SPIRV::OpFNegate); + ResType->getOpcode() == SPIRV::OpTypeVector + ? SPIRV::OpFNegateV : SPIRV::OpFNegate); case TargetOpcode::G_ATOMICRMW_FMIN: return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMinEXT); case TargetOpcode::G_ATOMICRMW_FMAX: diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index 9d5a685fdbc84..4bc86b0168f1e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -131,6 +131,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; + auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16}; + auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12}; @@ -339,10 +341,12 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { getActionDefinitionsBuilder( {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX}) - .legalForCartesianProduct(allFloatScalars, allPtrs); + .legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s, + allPtrs); getActionDefinitionsBuilder(G_ATOMICRMW_XCHG) - .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs); + .legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s, + allPtrs); getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); // TODO: add proper legalization rules. diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 2feb73d8dedfa..73432279c3306 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -14,6 +14,10 @@ // //===----------------------------------------------------------------------===// +// TODO: uses or report_fatal_error (which is also deprecated) / +// ReportFatalUsageError in this file should be refactored, as per LLVM +// best practices, to rely on the Diagnostic infrastructure. + #include "SPIRVModuleAnalysis.h" #include "MCTargetDesc/SPIRVBaseInfo.h" #include "MCTargetDesc/SPIRVMCTargetDesc.h" @@ -1071,6 +1075,39 @@ static bool isBFloat16Type(const SPIRVType *TypeDef) { #define ATOM_FLT_REQ_EXT_MSG(ExtName) \ "The atomic float instruction requires the following SPIR-V " \ "extension: SPV_EXT_shader_atomic_float" ExtName +static void AddAtomicVectorFloatRequirements(const MachineInstr &MI, + SPIRV::RequirementHandler &Reqs, + const SPIRVSubtarget &ST) { + SPIRVType *VecTypeDef = + MI.getMF()->getRegInfo().getVRegDef(MI.getOperand(1).getReg()); + + const unsigned Rank = VecTypeDef->getOperand(2).getImm(); + if (Rank != 2 && Rank != 4) + reportFatalUsageError("Result type of an atomic vector float instruction " + "must be a 2-component or 4 component vector"); + + SPIRVType *EltTypeDef = + MI.getMF()->getRegInfo().getVRegDef(VecTypeDef->getOperand(1).getReg()); + + if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat || + EltTypeDef->getOperand(1).getImm() != 16) + reportFatalUsageError( + "The element type for the result type of an atomic vector float " + "instruction must be a 16-bit floating-point scalar"); + + if (isBFloat16Type(EltTypeDef)) + reportFatalUsageError( + "The element type for the result type of an atomic vector float " + "instruction cannot be a bfloat16 scalar"); + if (!ST.canUseExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector)) + reportFatalUsageError( + "The atomic float16 vector instruction requires the following SPIR-V " + "extension: SPV_NV_shader_atomic_fp16_vector"); + + Reqs.addExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector); + Reqs.addCapability(SPIRV::Capability::AtomicFloat16VectorNV); +} + static void AddAtomicFloatRequirements(const MachineInstr &MI, SPIRV::RequirementHandler &Reqs, const SPIRVSubtarget &ST) { @@ -1078,6 +1115,10 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI, "Expect register operand in atomic float instruction"); Register TypeReg = MI.getOperand(1).getReg(); SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg); + + if (TypeDef->getOpcode() == SPIRV::OpTypeVector) + return AddAtomicVectorFloatRequirements(MI, Reqs, ST); + if (TypeDef->getOpcode() != SPIRV::OpTypeFloat) report_fatal_error("Result type of an atomic float instruction must be a " "floating-point type scalar"); diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 94e0138c66487..078f1dff839ea 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -391,6 +391,8 @@ defm SPV_INTEL_bfloat16_arithmetic : ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>; defm SPV_INTEL_16bit_atomics : ExtensionOperand<130, [EnvVulkan, EnvOpenCL]>; defm SPV_ALTERA_arbitrary_precision_fixed_point : ExtensionOperand<131, [EnvOpenCL, EnvVulkan]>; +defm SPV_NV_shader_atomic_fp16_vector + : ExtensionOperand<132, [EnvVulkan, EnvOpenCL]>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -573,6 +575,7 @@ defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atom defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>; defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>; defm AtomicBFloat16MinMaxINTEL : CapabilityOperand<6256, 0, 0, [SPV_INTEL_16bit_atomics], []>; +defm AtomicFloat16VectorNV : CapabilityOperand<5404, 0, 0, [SPV_NV_shader_atomic_fp16_vector], []>; defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>; defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>; defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>; diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll new file mode 100644 index 0000000000000..36f6e38fc75de --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll @@ -0,0 +1,47 @@ +; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_NV_shader_atomic_fp16_vector %s -o - | FileCheck %s + +; CHECK-ERROR: LLVM ERROR: The atomic float16 vector instruction requires the following SPIR-V extension: SPV_NV_shader_atomic_fp16_vector + +; CHECK: Capability Float16 +; CHECK-DAG: Capability AtomicFloat16VectorNV +; CHECK: Extension "SPV_NV_shader_atomic_fp16_vector" +; CHECK-DAG: %[[TyF16:[0-9]+]] = OpTypeFloat 16 +; CHECK: %[[TyF16Vec2:[0-9]+]] = OpTypeVector %[[TyF16]] 2 +; CHECK: %[[TyF16Vec4:[0-9]+]] = OpTypeVector %[[TyF16]] 4 +; CHECK: %[[TyF16Vec4Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec4]] +; CHECK: %[[TyF16Vec2Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec2]] +; CHECK: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0 +; CHECK: %[[ConstF16:[0-9]+]] = OpConstant %[[TyF16]] 20800{{$}} +; CHECK: %[[Const0F16Vec2:[0-9]+]] = OpConstantNull %[[TyF16Vec2]] +; CHECK: %[[f:[0-9]+]] = OpVariable %[[TyF16Vec2Ptr]] CrossWorkgroup %[[Const0F16Vec2]] +; CHECK: %[[Const0F16Vec4:[0-9]+]] = OpConstantNull %[[TyF16Vec4]] +; CHECK: %[[g:[0-9]+]] = OpVariable %[[TyF16Vec4Ptr]] CrossWorkgroup %[[Const0F16Vec4]] +; CHECK: %[[ConstF16Vec2:[0-9]+]] = OpConstantComposite %[[TyF16Vec2]] %[[ConstF16]] %[[ConstF16]] +; CHECK: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]] +; CHECK: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}} +; CHECK: %[[ConstF16Vec4:[0-9]+]] = OpConstantComposite %[[TyF16Vec4]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]] + +@f = common dso_local local_unnamed_addr addrspace(1) global <2 x half> +@g = common dso_local local_unnamed_addr addrspace(1) global <4 x half> + +; CHECK-DAG: OpAtomicFAddEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]] +; CHECK: %[[NegatedConstF16Vec2:[0-9]+]] = OpFNegate %[[TyF16Vec2]] %[[ConstF16Vec2]] +; CHECK: OpAtomicFAddEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstF16Vec2]] +define dso_local spir_func void @test1() local_unnamed_addr { +entry: + %addval = atomicrmw fadd ptr addrspace(1) @f, <2 x half> seq_cst + %subval = atomicrmw fsub ptr addrspace(1) @f, <2 x half> seq_cst + ret void +} + +; CHECK-DAG: OpAtomicFAddEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]] +; CHECK: %[[NegatedConstF16Vec4:[0-9]+]] = OpFNegate %[[TyF16Vec4]] %[[ConstF16Vec4]] +; CHECK: OpAtomicFAddEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstF16Vec4]] +define dso_local spir_func void @test2() local_unnamed_addr { +entry: + %addval = atomicrmw fadd ptr addrspace(1) @g, <4 x half> seq_cst + %subval = atomicrmw fsub ptr addrspace(1) @g, <4 x half> seq_cst + ret void +} \ No newline at end of file diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll new file mode 100644 index 0000000000000..7ac772bf5d094 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll @@ -0,0 +1,45 @@ +; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_NV_shader_atomic_fp16_vector %s -o - | FileCheck %s + +; CHECK-ERROR: LLVM ERROR: The atomic float16 vector instruction requires the following SPIR-V extension: SPV_NV_shader_atomic_fp16_vector + +; CHECK: Capability Float16 +; CHECK-DAG: Capability AtomicFloat16VectorNV +; CHECK: Extension "SPV_NV_shader_atomic_fp16_vector" +; CHECK-DAG: %[[TyF16:[0-9]+]] = OpTypeFloat 16 +; CHECK: %[[TyF16Vec2:[0-9]+]] = OpTypeVector %[[TyF16]] 2 +; CHECK: %[[TyF16Vec4:[0-9]+]] = OpTypeVector %[[TyF16]] 4 +; CHECK: %[[TyF16Vec4Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec4]] +; CHECK: %[[TyF16Vec2Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec2]] +; CHECK: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0 +; CHECK: %[[ConstF16:[0-9]+]] = OpConstant %[[TyF16]] 20800{{$}} +; CHECK: %[[Const0F16Vec2:[0-9]+]] = OpConstantNull %[[TyF16Vec2]] +; CHECK: %[[f:[0-9]+]] = OpVariable %[[TyF16Vec2Ptr]] CrossWorkgroup %[[Const0F16Vec2]] +; CHECK: %[[Const0F16Vec4:[0-9]+]] = OpConstantNull %[[TyF16Vec4]] +; CHECK: %[[g:[0-9]+]] = OpVariable %[[TyF16Vec4Ptr]] CrossWorkgroup %[[Const0F16Vec4]] +; CHECK: %[[ConstF16Vec2:[0-9]+]] = OpConstantComposite %[[TyF16Vec2]] %[[ConstF16]] %[[ConstF16]] +; CHECK: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]] +; CHECK: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}} +; CHECK: %[[ConstF16Vec4:[0-9]+]] = OpConstantComposite %[[TyF16Vec4]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]] + +@f = common dso_local local_unnamed_addr addrspace(1) global <2 x half> +@g = common dso_local local_unnamed_addr addrspace(1) global <4 x half> + +; CHECK-DAG: OpAtomicFMinEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]] +; CHECK: OpAtomicFMaxEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]] +define dso_local spir_func void @test1() local_unnamed_addr { +entry: + %minval = atomicrmw fmin ptr addrspace(1) @f, <2 x half> seq_cst + %maxval = atomicrmw fmax ptr addrspace(1) @f, <2 x half> seq_cst + ret void +} + +; CHECK-DAG: OpAtomicFMinEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]] +; CHECK: OpAtomicFMaxEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]] +define dso_local spir_func void @test2() local_unnamed_addr { +entry: + %minval = atomicrmw fmin ptr addrspace(1) @g, <4 x half> seq_cst + %maxval = atomicrmw fmax ptr addrspace(1) @g, <4 x half> seq_cst + ret void +} \ No newline at end of file From 38a0f78f892d40c54d81c6e2e2b3eeedde6f8244 Mon Sep 17 00:00:00 2001 From: Alex Voicu Date: Mon, 1 Dec 2025 22:26:48 +0000 Subject: [PATCH 2/4] Fix formatting. --- llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index c0c60b839a2a9..f5bcc5e05428f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1194,7 +1194,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, // Translate G_ATOMICRMW_FSUB to OpAtomicFAddEXT with negative value operand return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT, ResType->getOpcode() == SPIRV::OpTypeVector - ? SPIRV::OpFNegateV : SPIRV::OpFNegate); + ? SPIRV::OpFNegateV + : SPIRV::OpFNegate); case TargetOpcode::G_ATOMICRMW_FMIN: return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMinEXT); case TargetOpcode::G_ATOMICRMW_FMAX: From db23857fe98995e912256bbac39e978196fd78b1 Mon Sep 17 00:00:00 2001 From: Alex Voicu Date: Mon, 1 Dec 2025 22:50:41 +0000 Subject: [PATCH 3/4] Remove now unused `allFloatScalars`. --- llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index 4bc86b0168f1e..8d6292190337b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -125,14 +125,12 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { auto allIntScalars = {s8, s16, s32, s64}; - auto allFloatScalars = {s16, s32, s64}; + auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16}; auto allFloatScalarsAndVectors = { s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; - auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16}; - auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12}; From 9c058804480aeea9b98b7e0a686917e95cbfce2b Mon Sep 17 00:00:00 2001 From: Alex Voicu Date: Mon, 1 Dec 2025 23:32:42 +0000 Subject: [PATCH 4/4] Fix accidental noise. --- llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index 8d6292190337b..3d8a9a5ae384b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -343,8 +343,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { allPtrs); getActionDefinitionsBuilder(G_ATOMICRMW_XCHG) - .legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s, - allPtrs); + .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs); getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); // TODO: add proper legalization rules.