diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index f8093d3042c50..8b9f51d38374c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4187,6 +4187,15 @@ def SPIRV_INTEL_StoreCacheControlAttr : SPIRV_INTEL_SCC_WriteBack, SPIRV_INTEL_SCC_Streaming ]>; +// Non-uniform quad swap direction attribute +def SPIRV_QuadSwapDirectionAttr : SPIRV_I32EnumAttr< + "QuadSwapDirection", "Swap direction of a GroupNonUniformQuadSwap", "quad_swap_direction", + [ + I32EnumAttrCase<"Horizontal", 0>, + I32EnumAttrCase<"Vertical", 1>, + I32EnumAttrCase<"Diagonal", 2> + ]>; + //===----------------------------------------------------------------------===// // SPIR-V attribute definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td index d5a339115aaaa..37df66372f51c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td @@ -1657,8 +1657,8 @@ def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [ #### Example: ```mlir - %0 = spirv.GroupNonUniformQuadSwap %value %dir : f32, i32 - %1 = spirv.GroupNonUniformQuadSwap %value %dir : vector<4xf32>, i32 + %0 = spirv.GroupNonUniformQuadSwap %value : f32 + %1 = spirv.GroupNonUniformQuadSwap %value : vector<4xf32> ``` }]; @@ -1672,7 +1672,7 @@ def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [ let arguments = (ins SPIRV_ScopeAttr:$execution_scope, AnyTypeOf<[SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf]>:$value, - SPIRV_SignlessOrUnsignedInt:$direction + SPIRV_QuadSwapDirectionAttr:$direction ); let results = (outs @@ -1682,7 +1682,7 @@ def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [ let hasVerifier = 0; let assemblyFormat = [{ - $execution_scope $value $direction attr-dict `:` type($value) `,` type($direction) + $execution_scope $direction $value attr-dict `:` type($value) }]; } diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir index 168823a6e9c2d..5383f7656a1be 100644 --- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir @@ -769,28 +769,29 @@ func.func @group_non_uniform_all_equal(%value: f32) -> i1 { // CHECK-LABEL: @group_non_uniform_quad_swap func.func @group_non_uniform_quad_swap(%value: f32) -> f32 { - %dir = spirv.Constant 0 : i32 - // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap %{{.+}} %{{.+}} : f32, i32 - %0 = spirv.GroupNonUniformQuadSwap %value %dir : f32, i32 - return %0: f32 + // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap %{{.+}} : f32 + %0 = spirv.GroupNonUniformQuadSwap %value : f32 + // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap %{{.+}} : f32 + %1 = spirv.GroupNonUniformQuadSwap %0 : f32 + // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap %{{.+}} : f32 + %2 = spirv.GroupNonUniformQuadSwap %1 : f32 + return %2: f32 } // ----- // CHECK-LABEL: @group_non_uniform_quad_swap func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> { - %dir = spirv.Constant 0 : i32 - // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap %{{.+}} %{{.+}} : vector<4xf32>, i32 - %0 = spirv.GroupNonUniformQuadSwap %value %dir : vector<4xf32>, i32 + // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap %{{.+}} : vector<4xf32> + %0 = spirv.GroupNonUniformQuadSwap %value : vector<4xf32> return %0: vector<4xf32> } // ----- func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> { - %dir = spirv.Constant 0 : i32 // expected-error @+1 {{execution_scope must be Scope of value Subgroup}} - %0 = spirv.GroupNonUniformQuadSwap %value %dir : vector<4xf32>, i32 + %0 = spirv.GroupNonUniformQuadSwap %value : vector<4xf32> return %0: vector<4xf32> } @@ -798,16 +799,15 @@ func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> { func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> { %dir = spirv.Constant 0.0 : f32 - // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'f32'}} - %0 = spirv.GroupNonUniformQuadSwap %value %dir : vector<4xf32>, f32 + // expected-error @+1 {{expected '<'}} + %0 = spirv.GroupNonUniformQuadSwap %dir %value : vector<4xf32> return %0: vector<4xf32> } // ----- func.func @group_non_uniform_quad_swap(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> { - %dir = spirv.Constant 0 : i32 // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}} - %0 = spirv.GroupNonUniformQuadSwap %value %dir : !spirv.array<3 x i32>, i32 + %0 = spirv.GroupNonUniformQuadSwap %value : !spirv.array<3 x i32> return %0: !spirv.array<3 x i32> } diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir index 0c5df594ac1d7..6975836d3ddee 100644 --- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir +++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir @@ -149,16 +149,14 @@ spirv.module Logical GLSL450 requires #spirv.vce) -> vector<4xf32> "None" { - %dir = spirv.Constant 0 : i32 - // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap %{{.+}} %{{.+}} : vector<4xf32>, i32 - %0 = spirv.GroupNonUniformQuadSwap %val %dir : vector<4xf32>, i32 + // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap %{{.+}} : vector<4xf32> + %0 = spirv.GroupNonUniformQuadSwap %val : vector<4xf32> spirv.ReturnValue %0: vector<4xf32> } spirv.func @group_non_uniform_quad_swap_scalar(%val: f32) -> f32 "None" { - %dir = spirv.Constant 0 : i32 - // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap %{{.+}} %{{.+}} : f32, i32 - %0 = spirv.GroupNonUniformQuadSwap %val %dir : f32, i32 + // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap %{{.+}} : f32 + %0 = spirv.GroupNonUniformQuadSwap %val : f32 spirv.ReturnValue %0: f32 } } diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 0b1771ffcee71..9cb48934b2c10 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -503,6 +503,7 @@ constexpr llvm::StringLiteral constantIdEnumAttrs[] = { "SPIRV_MatrixLayoutAttr", "SPIRV_TosaExtAccTypeAttr", "SPIRV_TosaExtNaNPropagationModeAttr", + "SPIRV_QuadSwapDirectionAttr", }; /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The