diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index dd5d00de4147d..fc9e2f11092e6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4588,6 +4588,7 @@ def SPIRV_OC_OpGroupNonUniformAll : I32EnumAttrCase<"OpGroupNonUnifo def SPIRV_OC_OpGroupNonUniformAny : I32EnumAttrCase<"OpGroupNonUniformAny", 335>; def SPIRV_OC_OpGroupNonUniformAllEqual : I32EnumAttrCase<"OpGroupNonUniformAllEqual", 336>; def SPIRV_OC_OpGroupNonUniformBroadcast : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>; +def SPIRV_OC_OpGroupNonUniformBroadcastFirst : I32EnumAttrCase<"OpGroupNonUniformBroadcastFirst", 338>; def SPIRV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>; def SPIRV_OC_OpGroupNonUniformBallotBitCount : I32EnumAttrCase<"OpGroupNonUniformBallotBitCount", 342>; def SPIRV_OC_OpGroupNonUniformBallotFindLSB : I32EnumAttrCase<"OpGroupNonUniformBallotFindLSB", 343>; @@ -4727,7 +4728,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed, SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll, SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual, - SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot, + SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBroadcastFirst, + SPIRV_OC_OpGroupNonUniformBallot, SPIRV_OC_OpGroupNonUniformBallotBitCount, SPIRV_OC_OpGroupNonUniformBallotFindLSB, SPIRV_OC_OpGroupNonUniformBallotFindMSB, SPIRV_OC_OpGroupNonUniformShuffle, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td index 784eb40141b74..7ede319f85a5b 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td @@ -269,6 +269,60 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast", // ----- +def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst", + [Pure, SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>]> { + let summary = [{ + Broadcast the value from the active invocation with the lowest id in + the subgroup. + }]; + + let description = [{ + Result is the Value of the invocation from the active invocations with + the lowest id within the Execution scope to all active invocations + within the Execution scope. + + Result Type must be a scalar or vector of floating-point type, integer + type, or Boolean type. + + Execution must be Subgroup Scope. + + The type of Value must be the same as Result Type. + + #### Example: + + ```mlir + %scalar_value = ... : f32 + %vector_value = ... : vector<4xf32> + %0 = spirv.GroupNonUniformBroadcastFirst %scalar_value : f32 + %1 = spirv.GroupNonUniformBroadcastFirst %vector_value : vector<4xf32> + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_GroupNonUniformBallot]> + ]; + + let arguments = (ins + SPIRV_ScopeAttr:$execution_scope, + AnyTypeOf<[SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf]>:$value + ); + + let results = (outs + AnyTypeOf<[SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf, SPIRV_ScalarOrVectorOf]>:$result + ); + + let hasVerifier = 0; + + let assemblyFormat = [{ + $execution_scope operands attr-dict `:` type($value) + }]; +} + +// ----- + def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", []> { let summary = [{ Result is true only in the active invocation with the lowest id in the diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir index 5383f7656a1be..2ca7601a1c9ad 100644 --- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir @@ -123,6 +123,53 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid: // ----- +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformBroadcastFirst +//===----------------------------------------------------------------------===// + +// ----- + +func.func @group_non_uniform_broadcast_first_scalar(%value: f32) -> f32 { + // CHECK: spirv.GroupNonUniformBroadcastFirst %{{.*}} : f32 + %0 = spirv.GroupNonUniformBroadcastFirst %value : f32 + return %0 : f32 +} + +// ----- + +func.func @group_non_uniform_broadcast_first_vector(%value: vector<4xf32>) -> vector<4xf32> { + // CHECK: spirv.GroupNonUniformBroadcastFirst %{{.*}} : vector<4xf32> + %0 = spirv.GroupNonUniformBroadcastFirst %value : vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + +func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32 { + // expected-error @+1 {{execution_scope must be Scope of value Subgroup}} + %0 = spirv.GroupNonUniformBroadcastFirst %value : f32 + return %0 : f32 +} + +// ----- + + +func.func @group_non_uniform_broadcast_first_negative_type(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> { + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}} + %0 = spirv.GroupNonUniformBroadcastFirst %value : !spirv.array<3 x i32> + return %0 : !spirv.array<3 x i32> +} + +// ----- + +func.func @group_non_uniform_broadcast_first_negative_type_mismatch(%value: f32) -> i32 { + // expected-error @+1 {{'spirv.GroupNonUniformBroadcastFirst' op failed to verify that all of {value, result} have same type}} + %0 = "spirv.GroupNonUniformBroadcastFirst"(%value) {execution_scope = #spirv.scope} : (f32) -> i32 + return %0 : i32 +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.GroupNonUniformElect //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir index 6975836d3ddee..1dbc4a43395a4 100644 --- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir +++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir @@ -21,6 +21,13 @@ spirv.module Logical GLSL450 requires #spirv.vce f32 "None" { + // CHECK: spirv.GroupNonUniformBroadcastFirst %{{.*}} : f32 + %0 = spirv.GroupNonUniformBroadcastFirst %value : f32 + spirv.ReturnValue %0: f32 + } + // CHECK-LABEL: @group_non_uniform_elect spirv.func @group_non_uniform_elect() -> i1 "None" { // CHECK: %{{.+}} = spirv.GroupNonUniformElect : i1