diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index d6b1e9552fbc5..1b5a8728dd3f8 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -182,6 +182,11 @@ struct ElementwiseArithOpPattern final : OpConversionPattern { matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() <= 3); + // Reject boolean types to allow specialized boolean patterns to handle + // them (e.g., addi/subi on i1 should use LogicalNotEqual, not IAdd/ISub). + if (!adaptor.getOperands().empty() && + isBoolScalarOrVector(adaptor.getOperands().front().getType())) + return failure(); auto converter = this->template getTypeConverter(); Type dstType = converter->convertType(op.getType()); if (!dstType) { @@ -572,6 +577,98 @@ struct XOrIOpBooleanPattern final : public OpConversionPattern { } }; +/// Converts an arith integer op to the given SPIR-V boolean op if the type is +/// i1 or vector of i1. Each mapping follows from the boolean truth table of +/// the operation: +/// addi(a, b) = a ^ b (add mod 2 = XOR = LogicalNotEqual) +/// subi(a, b) = a ^ b (sub mod 2 = XOR = LogicalNotEqual) +/// muli(a, b) = a & b (1*1=1, else 0 = LogicalAnd) +/// divui(a, b) = a & b (a/1=a, a/0=UB; truth table matches AND) +/// divsi(a, b) = a & b (same as divui on i1) +/// maxsi(a, b) = a & b (signed i1: 1 represents -1, so max is 0 unless both +/// are 1) +/// maxui(a, b) = a | b (unsigned max on i1: 1 when either operand is 1) +/// minsi(a, b) = a | b (signed i1: -1 < 0, so min is 1 when either operand +/// is 1) +/// minui(a, b) = a & b (unsigned min on i1: 1 only when both operands are +/// 1) +template +struct BoolIOpPattern final : public OpConversionPattern { + BoolIOpPattern(const TypeConverter &converter, MLIRContext *context) + // benefit=2: takes priority over the generic ElementwiseArithOpPattern + // (benefit=1) when the operand type is i1. + : OpConversionPattern(converter, context, /*benefit=*/2) {} + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) + return failure(); + + Type dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + rewriter.replaceOpWithNewOp(op, dstType, adaptor.getOperands()); + return success(); + } +}; + +/// Converts an arith binary op on i1 to spirv.LogicalAnd(lhs, +/// spirv.LogicalNot(rhs)). This covers shift-left, shift-right-unsigned, and +/// unsigned remainder on i1: +/// shli(a, b) = a & ~b (shift left clears the bit when b=1) +/// shrui(a, b) = a & ~b (shift right unsigned clears the bit when b=1) +/// remui(a, b) = a & ~b (only defined when b=1; a%1=0, and ~b=~1=0, so AND +/// gives 0) +/// remsi(a, b) = a & ~b (only defined when b=1; a%1=0, and ~b=~1=0, so AND +/// gives 0) +template +struct BoolIOpAndNotPattern final : public OpConversionPattern { + BoolIOpAndNotPattern(const TypeConverter &converter, MLIRContext *context) + // benefit=2: takes priority over the generic ElementwiseArithOpPattern + // (benefit=1) when the operand type is i1. + : OpConversionPattern(converter, context, /*benefit=*/2) {} + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) + return failure(); + + Type dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + Location loc = op.getLoc(); + Value notRhs = spirv::LogicalNotOp::create(rewriter, loc, dstType, + adaptor.getOperands()[1]); + rewriter.replaceOpWithNewOp( + op, dstType, adaptor.getOperands()[0], notRhs); + return success(); + } +}; + +/// Converts arith.shrsi on i1 to identity: arithmetic right shift of a 1-bit +/// signed value always yields the original value (0 >> n = 0, -1 >> n = -1). +struct ShRSIBoolPattern final : public OpConversionPattern { + ShRSIBoolPattern(const TypeConverter &converter, MLIRContext *context) + // benefit=2: takes priority over the generic spirv::ElementwiseOpPattern + // (benefit=1) when the operand type is i1. + : OpConversionPattern(converter, context, + /*benefit=*/2) {} + + LogicalResult + matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) + return failure(); + + rewriter.replaceOp(op, adaptor.getOperands().front()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // UIToFPOp //===----------------------------------------------------------------------===// @@ -1410,18 +1507,28 @@ void mlir::arith::populateArithToSPIRVPatterns( patterns.add< ConstantCompositeOpPattern, ConstantScalarOpPattern, + BoolIOpPattern, // add mod 2 = XOR = not-equal ElementwiseArithOpPattern, + BoolIOpPattern, // sub mod 2 = XOR = not-equal ElementwiseArithOpPattern, + BoolIOpPattern, // 1*1=1, else 0 = AND ElementwiseArithOpPattern, + BoolIOpPattern, // a/1=a, a/0=UB; truth table = AND spirv::ElementwiseOpPattern, + BoolIOpPattern, // same as divui on i1 spirv::ElementwiseOpPattern, + BoolIOpAndNotPattern, // remui(a,b) = a & ~b (see pattern comment) spirv::ElementwiseOpPattern, + BoolIOpAndNotPattern, // remsi(a,b) = a & ~b (see pattern comment) RemSIOpGLPattern, RemSIOpCLPattern, BitwiseOpPattern, BitwiseOpPattern, XOrIOpLogicalPattern, XOrIOpBooleanPattern, + BoolIOpAndNotPattern, // shli(a,b) = a & ~b (see pattern comment) ElementwiseArithOpPattern, + BoolIOpAndNotPattern, // shrui(a,b) = a & ~b (see pattern comment) spirv::ElementwiseOpPattern, + ShRSIBoolPattern, // shrsi(a,b) = a (identity; see pattern comment) spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, @@ -1454,6 +1561,10 @@ void mlir::arith::populateArithToSPIRVPatterns( MinimumMaximumFOpPattern, MinNumMaxNumFOpPattern, MinNumMaxNumFOpPattern, + BoolIOpPattern, // signed i1: 1=-1, so max=0 unless both are 1 + BoolIOpPattern, // unsigned max on i1: 1 when either is 1 + BoolIOpPattern, // signed i1: -1<0, so min=1 when either is 1 + BoolIOpPattern, // unsigned min on i1: 1 only when both are 1 spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 9c726b8643a46..31b70177a0d19 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -272,6 +272,28 @@ func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { return } +// CHECK-LABEL: @bool_arith_scalar +func.func @bool_arith_scalar(%arg0 : i1, %arg1 : i1) { + // CHECK: spirv.LogicalNotEqual + %0 = arith.addi %arg0, %arg1 : i1 + // CHECK: spirv.LogicalNotEqual + %1 = arith.subi %arg0, %arg1 : i1 + // CHECK: spirv.LogicalAnd + %2 = arith.muli %arg0, %arg1 : i1 + return +} + +// CHECK-LABEL: @bool_arith_vector +func.func @bool_arith_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { + // CHECK: spirv.LogicalNotEqual + %0 = arith.addi %arg0, %arg1 : vector<4xi1> + // CHECK: spirv.LogicalNotEqual + %1 = arith.subi %arg0, %arg1 : vector<4xi1> + // CHECK: spirv.LogicalAnd + %2 = arith.muli %arg0, %arg1 : vector<4xi1> + return +} + // CHECK-LABEL: @shift_scalar func.func @shift_scalar(%arg0 : i32, %arg1 : i32) { // CHECK: spirv.ShiftLeftLogical @@ -298,6 +320,63 @@ func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { // ----- +// Test i1 lowerings for shift, div, rem, and min/max ops. + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: @bool_shift_div_rem_scalar +func.func @bool_shift_div_rem_scalar(%arg0 : i1, %arg1 : i1) { + // shli(a,b) = a & ~b + // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1 + // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]] + %0 = arith.shli %arg0, %arg1 : i1 + // shrui(a,b) = a & ~b + // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1 + // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]] + %1 = arith.shrui %arg0, %arg1 : i1 + // shrsi(a,b) = a (arithmetic right shift of i1 is identity) + // CHECK-NOT: spirv.ShiftRightArithmetic + %2 = arith.shrsi %arg0, %arg1 : i1 + // divui(a,b) = a & b (only valid for b=1, same as muli) + // CHECK: spirv.LogicalAnd %arg0, %arg1 + %3 = arith.divui %arg0, %arg1 : i1 + // divsi(a,b) = a & b (only non-UB/non-overflow case: 0/-1 = 0) + // CHECK: spirv.LogicalAnd %arg0, %arg1 + %4 = arith.divsi %arg0, %arg1 : i1 + // remui(a,b) = a & ~b (a % 1 = 0 for valid b=1) + // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1 + // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]] + %5 = arith.remui %arg0, %arg1 : i1 + // remsi(a,b) = a & ~b + // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1 + // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]] + %6 = arith.remsi %arg0, %arg1 : i1 + return +} + +// CHECK-LABEL: @bool_minmax_scalar +func.func @bool_minmax_scalar(%arg0 : i1, %arg1 : i1) { + // maxui(a,b) = a | b (unsigned max of two booleans is OR) + // CHECK: spirv.LogicalOr %arg0, %arg1 + %0 = arith.maxui %arg0, %arg1 : i1 + // maxsi(a,b) = a & b (signed max: max(-1,0)=0, so max(true,false)=false → AND) + // CHECK: spirv.LogicalAnd %arg0, %arg1 + %1 = arith.maxsi %arg0, %arg1 : i1 + // minui(a,b) = a & b (unsigned min of two booleans is AND) + // CHECK: spirv.LogicalAnd %arg0, %arg1 + %2 = arith.minui %arg0, %arg1 : i1 + // minsi(a,b) = a | b (signed min: min(-1,0)=-1, so min(true,false)=true → OR) + // CHECK: spirv.LogicalOr %arg0, %arg1 + %3 = arith.minsi %arg0, %arg1 : i1 + return +} + +} // end module + +// ----- + //===----------------------------------------------------------------------===// // arith.cmpf //===----------------------------------------------------------------------===//