diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 297dde3a67b2a..27abd13b8ddb1 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4512,6 +4512,8 @@ def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 1 def SPIRV_OC_OpISubBorrow : I32EnumAttrCase<"OpISubBorrow", 150>; def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended", 151>; def SPIRV_OC_OpSMulExtended : I32EnumAttrCase<"OpSMulExtended", 152>; +def SPIRV_OC_OpAny : I32EnumAttrCase<"OpAny", 154>; +def SPIRV_OC_OpAll : I32EnumAttrCase<"OpAll", 155>; def SPIRV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>; def SPIRV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>; def SPIRV_OC_OpIsFinite : I32EnumAttrCase<"OpIsFinite", 158>; @@ -4714,6 +4716,7 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpOuterProduct, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, + SPIRV_OC_OpAny, SPIRV_OC_OpAll, SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpIsFinite, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td index 179219042c882..4de76ca884534 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -48,6 +48,70 @@ class SPIRV_LogicalUnaryOp { + let summary = "Result is true if any component of Vector is true."; + + let description = [{ + Result Type must be a Boolean type scalar. + + Vector must be a vector of Boolean type. + + + + #### Example: + + ```mlir + %result = spirv.Any %vector : vector<4xi1> + ``` + }]; + + let arguments = (ins + SPIRV_VectorOf:$vector + ); + + let results = (outs + SPIRV_Bool:$result + ); + + let assemblyFormat = "$vector attr-dict `:` type($vector)"; + + let hasVerifier = 0; +} + +// ----- + +def SPIRV_AllOp : SPIRV_Op<"All", [Pure]> { + let summary = "Result is true if all components of Vector are true."; + + let description = [{ + Result Type must be a Boolean type scalar. + + Vector must be a vector of Boolean type. + + + + #### Example: + + ```mlir + %result = spirv.All %vector : vector<4xi1> + ``` + }]; + + let arguments = (ins + SPIRV_VectorOf:$vector + ); + + let results = (outs + SPIRV_Bool:$result + ); + + let assemblyFormat = "$vector attr-dict `:` type($vector)"; + + let hasVerifier = 0; +} + +// ----- + def SPIRV_FOrdEqualOp : SPIRV_LogicalBinaryOp<"FOrdEqual", SPIRV_Float, [Commutative]> { let summary = "Floating-point comparison for being ordered and equal."; diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir index 1018751cf65e0..d8d80ddfae097 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -1,5 +1,99 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s +//===----------------------------------------------------------------------===// +// spirv.Any +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @any_vector2 +func.func @any_vector2(%arg0: vector<2xi1>) -> i1 { + // CHECK: spirv.Any %{{.*}} : vector<2xi1> + %0 = spirv.Any %arg0 : vector<2xi1> + return %0 : i1 +} + +// ----- + +// CHECK-LABEL: @any_vector3 +func.func @any_vector3(%arg0: vector<3xi1>) -> i1 { + // CHECK: spirv.Any %{{.*}} : vector<3xi1> + %0 = spirv.Any %arg0 : vector<3xi1> + return %0 : i1 +} + +// ----- + +// CHECK-LABEL: @any_vector4 +func.func @any_vector4(%arg0: vector<4xi1>) -> i1 { + // CHECK: spirv.Any %{{.*}} : vector<4xi1> + %0 = spirv.Any %arg0 : vector<4xi1> + return %0 : i1 +} + +// ----- + +func.func @any_scalar(%arg0: i1) -> i1 { + // expected-error @+1 {{invalid kind of type specified: expected builtin.vector, but found 'i1'}} + %0 = spirv.Any %arg0 : i1 + return %0 : i1 +} + +// ----- + +func.func @any_wrong_result_type(%arg0: vector<2xi1>) -> vector<2xi1> { + // expected-error @+1 {{'spirv.Any' op result #0 must be bool, but got 'vector<2xi1>'}} + %0 = "spirv.Any"(%arg0) : (vector<2xi1>) -> vector<2xi1> + return %0 : vector<2xi1> +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.All +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @all_vector2 +func.func @all_vector2(%arg0: vector<2xi1>) -> i1 { + // CHECK: spirv.All %{{.*}} : vector<2xi1> + %0 = spirv.All %arg0 : vector<2xi1> + return %0 : i1 +} + +// ----- + +// CHECK-LABEL: @all_vector3 +func.func @all_vector3(%arg0: vector<3xi1>) -> i1 { + // CHECK: spirv.All %{{.*}} : vector<3xi1> + %0 = spirv.All %arg0 : vector<3xi1> + return %0 : i1 +} + +// ----- + +// CHECK-LABEL: @all_vector4 +func.func @all_vector4(%arg0: vector<4xi1>) -> i1 { + // CHECK: spirv.All %{{.*}} : vector<4xi1> + %0 = spirv.All %arg0 : vector<4xi1> + return %0 : i1 +} + +// ----- + +func.func @all_scalar(%arg0: i1) -> i1 { + // expected-error @+1 {{invalid kind of type specified: expected builtin.vector, but found 'i1'}} + %0 = spirv.All %arg0 : i1 + return %0 : i1 +} + +// ----- + +func.func @all_wrong_result_type(%arg0: vector<2xi1>) -> vector<2xi1> { + // expected-error @+1 {{'spirv.All' op result #0 must be bool, but got 'vector<2xi1>'}} + %0 = "spirv.All"(%arg0) : (vector<2xi1>) -> vector<2xi1> + return %0 : vector<2xi1> +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.IEqual //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir index d570815448b7c..83459b980a1d7 100644 --- a/mlir/test/Target/SPIRV/logical-ops.mlir +++ b/mlir/test/Target/SPIRV/logical-ops.mlir @@ -6,6 +6,16 @@ // RUN: %if spirv-tools %{ spirv-val %t %} spirv.module Logical OpenCL requires #spirv.vce { + spirv.func @any_vector(%arg0: vector<4xi1>) "None" { + // CHECK: {{.*}} = spirv.Any {{.*}} : vector<4xi1> + %0 = spirv.Any %arg0 : vector<4xi1> + spirv.Return + } + spirv.func @all_vector(%arg0: vector<4xi1>) "None" { + // CHECK: {{.*}} = spirv.All {{.*}} : vector<4xi1> + %0 = spirv.All %arg0 : vector<4xi1> + spirv.Return + } spirv.func @iequal_scalar(%arg0: i32, %arg1: i32) "None" { // CHECK: {{.*}} = spirv.IEqual {{.*}}, {{.*}} : i32 %0 = spirv.IEqual %arg0, %arg1 : i32