diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index c101a95685a25..2149892cc603d 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -435,22 +435,28 @@ struct VectorReductionPattern final : OpConversionPattern { result = fop::create(rewriter, loc, resultType, result, next); \ break +#define INT_CASE(kind, iop) \ + case vector::CombiningKind::kind: \ + assert(isa(resultType)); \ + result = spirv::iop::create(rewriter, loc, resultType, result, next); \ + break + INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp); INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp); INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp); INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp); + INT_CASE(AND, BitwiseAndOp); + INT_CASE(OR, BitwiseOrOp); + INT_CASE(XOR, BitwiseXorOp); - case vector::CombiningKind::AND: - case vector::CombiningKind::OR: - case vector::CombiningKind::XOR: - return rewriter.notifyMatchFailure(reduceOp, "unimplemented"); default: return rewriter.notifyMatchFailure(reduceOp, "not handled here"); } #undef INT_AND_FLOAT_CASE #undef INT_OR_FLOAT_CASE +#undef INT_CASE } rewriter.replaceOp(reduceOp, result); diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index c399250151261..9eaf99203836a 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -856,6 +856,105 @@ func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 { // ----- +// CHECK-LABEL: func @reduction_and +// CHECK-SAME: (%[[V:.+]]: vector<4xi32>) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32> +// CHECK: %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32> +// CHECK: %[[AND0:.+]] = spirv.BitwiseAnd %[[S0]], %[[S1]] +// CHECK: %[[AND1:.+]] = spirv.BitwiseAnd %[[AND0]], %[[S2]] +// CHECK: %[[AND2:.+]] = spirv.BitwiseAnd %[[AND1]], %[[S3]] +// CHECK: return %[[AND2]] +func.func @reduction_and(%v : vector<4xi32>) -> i32 { + %reduce = vector.reduction , %v : vector<4xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_and_acc +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[AND0:.+]] = spirv.BitwiseAnd %[[S0]], %[[S1]] +// CHECK: %[[AND1:.+]] = spirv.BitwiseAnd %[[AND0]], %[[S2]] +// CHECK: %[[AND2:.+]] = spirv.BitwiseAnd %[[AND1]], %[[S]] +// CHECK: return %[[AND2]] +func.func @reduction_and_acc(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_or +// CHECK-SAME: (%[[V:.+]]: vector<4xi32>) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32> +// CHECK: %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32> +// CHECK: %[[OR0:.+]] = spirv.BitwiseOr %[[S0]], %[[S1]] +// CHECK: %[[OR1:.+]] = spirv.BitwiseOr %[[OR0]], %[[S2]] +// CHECK: %[[OR2:.+]] = spirv.BitwiseOr %[[OR1]], %[[S3]] +// CHECK: return %[[OR2]] +func.func @reduction_or(%v : vector<4xi32>) -> i32 { + %reduce = vector.reduction , %v : vector<4xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_or_acc +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[OR0:.+]] = spirv.BitwiseOr %[[S0]], %[[S1]] +// CHECK: %[[OR1:.+]] = spirv.BitwiseOr %[[OR0]], %[[S2]] +// CHECK: %[[OR2:.+]] = spirv.BitwiseOr %[[OR1]], %[[S]] +// CHECK: return %[[OR2]] +func.func @reduction_or_acc(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_xor +// CHECK-SAME: (%[[V:.+]]: vector<4xi32>) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32> +// CHECK: %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32> +// CHECK: %[[XOR0:.+]] = spirv.BitwiseXor %[[S0]], %[[S1]] +// CHECK: %[[XOR1:.+]] = spirv.BitwiseXor %[[XOR0]], %[[S2]] +// CHECK: %[[XOR2:.+]] = spirv.BitwiseXor %[[XOR1]], %[[S3]] +// CHECK: return %[[XOR2]] +func.func @reduction_xor(%v : vector<4xi32>) -> i32 { + %reduce = vector.reduction , %v : vector<4xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_xor_acc +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[XOR0:.+]] = spirv.BitwiseXor %[[S0]], %[[S1]] +// CHECK: %[[XOR1:.+]] = spirv.BitwiseXor %[[XOR0]], %[[S2]] +// CHECK: %[[XOR2:.+]] = spirv.BitwiseXor %[[XOR1]], %[[S]] +// CHECK: return %[[XOR2]] +func.func @reduction_xor_acc(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// ----- + module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { // CHECK-LABEL: func @reduction_bf16_addf_mulf