diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 61f8d82a615d8..3226b5d99114a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -866,6 +866,46 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertSubI
+//===----------------------------------------------------------------------===//
+
+struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
+
+    Type newElemTy = reduceInnermostDim(newTy);
+
+    auto [lhsElem0, lhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    auto [rhsElem0, rhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+
+    // Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - CARRY] where
+    // CARRY is 1 or 0.
+    Value low = rewriter.create<arith::SubIOp>(loc, lhsElem0, rhsElem0);
+    // We have a carry if lhsElem0 < rhsElem0.
+    Value carry0 = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
+    Value carryVal = rewriter.create<arith::ExtUIOp>(loc, newElemTy, carry0);
+
+    Value high0 = rewriter.create<arith::SubIOp>(loc, lhsElem1, carryVal);
+    Value high = rewriter.create<arith::SubIOp>(loc, high0, rhsElem1);
+
+    Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertSIToFP
 //===----------------------------------------------------------------------===//
@@ -885,22 +925,16 @@ struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
       return rewriter.notifyMatchFailure(
           loc, llvm::formatv("unsupported type: {0}", oldTy));
 
-    unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
     Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
-    Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
-    Value allOnesCst = createScalarOrSplatConstant(
-        rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
 
     // To avoid operating on very large unsigned numbers, perform the
     // conversion on the absolute value. Then, decide whether to negate the
-    // result or not based on that sign bit. We assume two's complement and
-    // implement negation by flipping all bits and adding 1.
-    // Note that this relies on the the other conversion patterns to legalize
-    // created ops and narrow the bit widths.
+    // result or not based on that sign bit. We implement negation by
+    // subtracting from zero. Note that this relies on the other conversion
+    // patterns to legalize created ops and narrow the bit widths.
     Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                  in, zeroCst);
-    Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
-    Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
+    Value neg = rewriter.create<arith::SubIOp>(loc, zeroCst, in);
     Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
 
     Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
@@ -1139,7 +1173,7 @@ void arith::populateArithWideIntEmulationPatterns(
       ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
      ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
       ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
-      ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
+      ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
       // Bitwise binary ops.
       ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
       ConvertBitwiseBinary<arith::XOrIOp>,
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index ed08779c10266..52da80ce26a73 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -130,6 +130,44 @@ func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi
   return %x : vector<4xi64>
 }
 
+// CHECK-LABEL: func @subi_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : i32 from vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : i32 from vector<2xi32>
+// CHECK-NEXT: [[SUB_L:%.+]] = arith.subi [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[ULT:%.+]] = arith.cmpi ult, [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[ULT]] : i1 to i32
+// CHECK-NEXT: [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : i32
+// CHECK-NEXT: [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : i32
+// CHECK: [[INS0:%.+]] = vector.insert [[SUB_L]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SUB_H1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @subi_scalar(%a : i64, %b : i64) -> i64 {
+  %x = arith.subi %a, %b : i64
+  return %x : i64
+}
+
+// CHECK-LABEL: func @subi_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x2xi32>, [[ARG1:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
+// CHECK-NEXT: [[SUB_L:%.+]] = arith.subi [[LOW0]], [[LOW1]] : vector<4x1xi32>
+// CHECK-NEXT: [[ULT:%.+]] = arith.cmpi ult, [[LOW0]], [[LOW1]] : vector<4x1xi32>
+// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[ULT]] : vector<4x1xi1> to vector<4x1xi32>
+// CHECK-NEXT: [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : vector<4x1xi32>
+// CHECK-NEXT: [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : vector<4x1xi32>
+// CHECK: [[INS0:%.+]] = vector.insert_strided_slice [[SUB_L]], {{%.+}} {offsets = [0, 0], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert_strided_slice [[SUB_H1]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<4x2xi32>
+func.func @subi_vector(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi64> {
+  %x = arith.subi %a, %b : vector<4xi64>
+  return %x : vector<4xi64>
+}
+
 // CHECK-LABEL: func.func @cmpi_eq_scalar
 // CHECK-SAME: ([[LHS:%.+]]: vector<2xi32>, [[RHS:%.+]]: vector<2xi32>)
 // CHECK-NEXT: [[LHSLOW:%.+]] = vector.extract [[LHS]][0] : i32 from vector<2xi32>
@@ -967,11 +1005,12 @@ func.func @uitofp_i64_f16(%a : i64) -> f16 {
 
 // CHECK-LABEL: func @sitofp_i64_f64
 // CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> f64
-// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<2xi32>
-// CHECK: [[ONES1:%.+]] = vector.extract [[VONES]][0] : i32 from vector<2xi32>
-// CHECK-NEXT: [[ONES2:%.+]] = vector.extract [[VONES]][1] : i32 from vector<2xi32>
-// CHECK: arith.xori {{%.+}}, [[ONES1]] : i32
-// CHECK-NEXT: arith.xori {{%.+}}, [[ONES2]] : i32
+// CHECK: [[VZERO:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK: vector.extract [[VZERO]][0] : i32 from vector<2xi32>
+// CHECK: [[ZERO1:%.+]] = vector.extract [[VZERO]][0] : i32 from vector<2xi32>
+// CHECK-NEXT: [[ZERO2:%.+]] = vector.extract [[VZERO]][1] : i32 from vector<2xi32>
+// CHECK: arith.subi [[ZERO1]], {{%.+}} : i32
+// CHECK: arith.subi [[ZERO2]], {{%.+}} : i32
 // CHECK: [[CST0:%.+]] = arith.constant 0 : i32
 // CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0]] : i32
 // CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : i32 to f64
@@ -990,9 +1029,9 @@ func.func @sitofp_i64_f64(%a : i64) -> f64 {
 
 // CHECK-LABEL: func @sitofp_i64_f64_vector
 // CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xf64>
-// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<3x2xi32>
-// CHECK: arith.xori
-// CHECK-NEXT: arith.xori
+// CHECK: [[VZERO:%.+]] = arith.constant dense<0> : vector<3x2xi32>
+// CHECK: arith.subi
+// CHECK: arith.subi
 // CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0:%.+]] : vector<3xi32>
 // CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : vector<3xi32> to vector<3xf64>
 // CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI:%.+]] : vector<3xi32> to vector<3xf64>
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
new file mode 100644
index 0000000000000..63d2c941c48e7
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-subi-i32.mlir
@@ -0,0 +1,104 @@
+// Ops in this function will be emulated using i16 types.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:   --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:     --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=16" \
+// RUN:   --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:   --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-runner -e entry -entry-point-result=void \
+// RUN:     --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+func.func @emulate_subi(%arg: i32, %arg0: i32) -> i32 {
+  %res = arith.subi %arg, %arg0 : i32
+  return %res : i32
+}
+
+func.func @check_subi(%arg : i32, %arg0 : i32) -> () {
+  %res = func.call @emulate_subi(%arg, %arg0) : (i32, i32) -> (i32)
+  vector.print %res : i32
+  return
+}
+
+func.func @entry() {
+  %lhs1 = arith.constant 1 : i32
+  %rhs1 = arith.constant 2 : i32
+
+  // CHECK: -1
+  func.call @check_subi(%lhs1, %rhs1) : (i32, i32) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_subi(%rhs1, %lhs1) : (i32, i32) -> ()
+
+  %lhs2 = arith.constant 1 : i32
+  %rhs2 = arith.constant -2 : i32
+
+  // CHECK-NEXT: 3
+  func.call @check_subi(%lhs2, %rhs2) : (i32, i32) -> ()
+  // CHECK-NEXT: -3
+  func.call @check_subi(%rhs2, %lhs2) : (i32, i32) -> ()
+
+  %lhs3 = arith.constant -1 : i32
+  %rhs3 = arith.constant -2 : i32
+
+  // CHECK-NEXT: 1
+  func.call @check_subi(%lhs3, %rhs3) : (i32, i32) -> ()
+  // CHECK-NEXT: -1
+  func.call @check_subi(%rhs3, %lhs3) : (i32, i32) -> ()
+
+  // Overflow from the upper/lower part.
+  %lhs4 = arith.constant 131074 : i32
+  %rhs4 = arith.constant 3 : i32
+
+  // CHECK-NEXT: 131071
+  func.call @check_subi(%lhs4, %rhs4) : (i32, i32) -> ()
+  // CHECK-NEXT: -131071
+  func.call @check_subi(%rhs4, %lhs4) : (i32, i32) -> ()
+
+  // Overflow in both parts.
+  %lhs5 = arith.constant 16385027 : i32
+  %rhs5 = arith.constant 16450564 : i32
+
+  // CHECK-NEXT: -65537
+  func.call @check_subi(%lhs5, %rhs5) : (i32, i32) -> ()
+  // CHECK-NEXT: 65537
+  func.call @check_subi(%rhs5, %lhs5) : (i32, i32) -> ()
+
+  %lhs6 = arith.constant 65536 : i32
+  %rhs6 = arith.constant 1 : i32
+
+  // CHECK-NEXT: 65535
+  func.call @check_subi(%lhs6, %rhs6) : (i32, i32) -> ()
+  // CHECK-NEXT: -65535
+  func.call @check_subi(%rhs6, %lhs6) : (i32, i32) -> ()
+
+  // Max/Min (un)signed integers.
+  %sintmax = arith.constant 2147483647 : i32
+  %sintmin = arith.constant -2147483648 : i32
+  %uintmax = arith.constant -1 : i32
+  %uintmin = arith.constant 0 : i32
+  %cst1 = arith.constant 1 : i32
+
+  // CHECK-NEXT: -1
+  func.call @check_subi(%sintmax, %sintmin) : (i32, i32) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_subi(%sintmin, %sintmax) : (i32, i32) -> ()
+  // CHECK-NEXT: 2147483647
+  func.call @check_subi(%sintmin, %cst1) : (i32, i32) -> ()
+  // CHECK-NEXT: -2147483648
+  func.call @check_subi(%sintmax, %uintmax) : (i32, i32) -> ()
+  // CHECK-NEXT: -2
+  func.call @check_subi(%uintmax, %cst1) : (i32, i32) -> ()
+  // CHECK-NEXT: 0
+  func.call @check_subi(%uintmax, %uintmax) : (i32, i32) -> ()
+  // CHECK-NEXT: -1
+  func.call @check_subi(%uintmin, %cst1) : (i32, i32) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_subi(%uintmin, %uintmax) : (i32, i32) -> ()
+
+
+  return
+}