Skip to content

Commit

Permalink
[mlir] Add arith.addi_carry op
Browse files Browse the repository at this point in the history
The `arith.addi_carry` op implements integer addition with overflows. The carry is returned via the second result, as `i1`.

Reviewed By: antiagainst, bondhugula

Differential Revision: https://reviews.llvm.org/D131893
  • Loading branch information
kuhar committed Aug 17, 2022
1 parent 36bdec4 commit 4309170
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 1 deletion.
35 changes: 35 additions & 0 deletions mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
Expand Up @@ -202,6 +202,41 @@ def Arith_AddIOp : Arith_IntBinaryOp<"addi", [Commutative]> {
let hasCanonicalizer = 1;
}


def Arith_AddICarryOp : Arith_Op<"addi_carry", [Commutative,
AllTypesMatch<["lhs", "rhs", "sum"]>]> {
let summary = "integer addition operation returning both the sum and carry";
let description = [{
The `addi_carry` operation takes two operands and returns two results: the
sum (same type as both operands), and the carry (boolean-like).

Example:

```mlir
// Scalar addition.
%sum, %carry = arith.addi_carry %b, %c : i64, i1

// Vector element-wise addition.
%b:2 = arith.addi_carry %g, %h : vector<4xi32>, vector<4xi1>

// Tensor element-wise addition.
%c:2 = arith.addi_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1>
```
}];

let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
let results = (outs SignlessIntegerLike:$sum, BoolLike:$carry);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($sum) `,` type($carry)
}];

let hasFolder = 1;

let extraClassDeclaration = [{
::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
}];
}

//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//
Expand Down
78 changes: 77 additions & 1 deletion mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
Expand Up @@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include <cassert>
#include <utility>

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
Expand All @@ -15,9 +16,9 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SmallString.h"

#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallString.h"

using namespace mlir;
using namespace mlir::arith;
Expand Down Expand Up @@ -216,6 +217,81 @@ void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
context);
}

//===----------------------------------------------------------------------===//
// AddICarryOp
//===----------------------------------------------------------------------===//

Optional<SmallVector<int64_t, 4>> arith::AddICarryOp::getShapeForUnroll() {
if (auto vt = getType(0).dyn_cast<VectorType>())
return llvm::to_vector<4>(vt.getShape());
return None;
}

// Returns the carry bit, assuming that `sum` is the result of addition of
// `operand` and another number.
static APInt calculateCarry(const APInt &sum, const APInt &operand) {
return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
}

LogicalResult arith::AddICarryOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
auto carryTy = getCarry().getType();
// addi_carry(x, 0) -> x, false
if (matchPattern(getRhs(), m_Zero())) {
auto carryZero = APInt::getZero(1);
Builder builder(getContext());
auto falseValue = builder.getZeroAttr(carryTy);

results.push_back(getLhs());
results.push_back(falseValue);
return success();
}

// addi_carry(constant_a, constant_b) -> constant_sum, constant_carry
// Let the `constFoldBinaryOp` utility attempt to fold the sum of both
// operands. If that succeeds, calculate the carry boolean based on the sum
// and the first (constant) operand, `lhs`. Note that we cannot simply call
// `constFoldBinaryOp` again to calculate the carry (bit) because the
// constructed attribute is of the same element type as both operands.
if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
operands, [](APInt a, const APInt &b) { return std::move(a) + b; })) {
Attribute carryAttr;
if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
// Both arguments are scalars, calculate the scalar carry value.
auto sum = sumAttr.cast<IntegerAttr>();
carryAttr = IntegerAttr::get(
carryTy, calculateCarry(sum.getValue(), lhs.getValue()));
} else if (auto lhs = operands[0].dyn_cast<SplatElementsAttr>()) {
// Both arguments are splats, calculate the splat carry value.
auto sum = sumAttr.cast<SplatElementsAttr>();
APInt carry = calculateCarry(sum.getSplatValue<APInt>(),
lhs.getSplatValue<APInt>());
carryAttr = SplatElementsAttr::get(carryTy, carry);
} else if (auto lhs = operands[0].dyn_cast<ElementsAttr>()) {
// Othwerwise calculate element-wise carry values.
auto sum = sumAttr.cast<ElementsAttr>();
const auto numElems = static_cast<size_t>(sum.getNumElements());
SmallVector<APInt> carryValues;
carryValues.reserve(numElems);

auto sumIt = sum.value_begin<APInt>();
auto lhsIt = lhs.value_begin<APInt>();
for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
carryValues.push_back(calculateCarry(*sumIt, *lhsIt));

carryAttr = DenseElementsAttr::get(carryTy, carryValues);
} else {
return failure();
}

results.push_back(sumAttr);
results.push_back(carryAttr);
return success();
}

return failure();
}

//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//
Expand Down
81 changes: 81 additions & 0 deletions mlir/test/Dialect/Arithmetic/canonicalize.mlir
Expand Up @@ -544,6 +544,87 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
return %add : index
}

// CHECK-LABEL: @addiCarryZeroRhs
// CHECK-NEXT: %[[false:.+]] = arith.constant false
// CHECK-NEXT: return %arg0, %[[false]]
func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) {
%zero = arith.constant 0 : i32
%sum, %carry = arith.addi_carry %arg0, %zero: i32, i1
return %sum, %carry : i32, i1
}

// CHECK-LABEL: @addiCarryZeroRhsSplat
// CHECK-NEXT: %[[false:.+]] = arith.constant dense<false> : vector<4xi1>
// CHECK-NEXT: return %arg0, %[[false]]
func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
%zero = arith.constant dense<0> : vector<4xi32>
%sum, %carry = arith.addi_carry %arg0, %zero: vector<4xi32>, vector<4xi1>
return %sum, %carry : vector<4xi32>, vector<4xi1>
}

// CHECK-LABEL: @addiCarryZeroLhs
// CHECK-NEXT: %[[false:.+]] = arith.constant false
// CHECK-NEXT: return %arg0, %[[false]]
func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) {
%zero = arith.constant 0 : i32
%sum, %carry = arith.addi_carry %zero, %arg0: i32, i1
return %sum, %carry : i32, i1
}

// CHECK-LABEL: @addiCarryConstants
// CHECK-DAG: %[[false:.+]] = arith.constant false
// CHECK-DAG: %[[c50:.+]] = arith.constant 50 : i32
// CHECK-NEXT: return %[[c50]], %[[false]]
func.func @addiCarryConstants() -> (i32, i1) {
%c13 = arith.constant 13 : i32
%c37 = arith.constant 37 : i32
%sum, %carry = arith.addi_carry %c13, %c37: i32, i1
return %sum, %carry : i32, i1
}

// CHECK-LABEL: @addiCarryConstantsOverflow1
// CHECK-DAG: %[[true:.+]] = arith.constant true
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : i32
// CHECK-NEXT: return %[[c0]], %[[true]]
func.func @addiCarryConstantsOverflow1() -> (i32, i1) {
%max = arith.constant 4294967295 : i32
%c1 = arith.constant 1 : i32
%sum, %carry = arith.addi_carry %max, %c1: i32, i1
return %sum, %carry : i32, i1
}

// CHECK-LABEL: @addiCarryConstantsOverflow2
// CHECK-DAG: %[[true:.+]] = arith.constant true
// CHECK-DAG: %[[c_2:.+]] = arith.constant -2 : i32
// CHECK-NEXT: return %[[c_2]], %[[true]]
func.func @addiCarryConstantsOverflow2() -> (i32, i1) {
%max = arith.constant 4294967295 : i32
%sum, %carry = arith.addi_carry %max, %max: i32, i1
return %sum, %carry : i32, i1
}

// CHECK-LABEL: @addiCarryConstantsOverflowVector
// CHECK-DAG: %[[sum:.+]] = arith.constant dense<[1, 6, 2, 14]> : vector<4xi32>
// CHECK-DAG: %[[carry:.+]] = arith.constant dense<[false, false, true, false]> : vector<4xi1>
// CHECK-NEXT: return %[[sum]], %[[carry]]
func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) {
%v1 = arith.constant dense<[1, 3, 3, 7]> : vector<4xi32>
%v2 = arith.constant dense<[0, 3, 4294967295, 7]> : vector<4xi32>
%sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
return %sum, %carry : vector<4xi32>, vector<4xi1>
}

// CHECK-LABEL: @addiCarryConstantsSplatVector
// CHECK-DAG: %[[sum:.+]] = arith.constant dense<3> : vector<4xi32>
// CHECK-DAG: %[[carry:.+]] = arith.constant dense<false> : vector<4xi1>
// CHECK-NEXT: return %[[sum]], %[[carry]]
func.func @addiCarryConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>) {
%v1 = arith.constant dense<1> : vector<4xi32>
%v2 = arith.constant dense<2> : vector<4xi32>
%sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
return %sum, %carry : vector<4xi32>, vector<4xi1>
}

// CHECK-LABEL: @notCmpEQ
// CHECK: %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
Expand Down
32 changes: 32 additions & 0 deletions mlir/test/Dialect/Arithmetic/invalid.mlir
Expand Up @@ -110,6 +110,38 @@ func.func @func_with_ops(f32) {

// -----

func.func @func_with_ops(%a: f32) {
// expected-error@+1 {{'arith.addi_carry' op operand #0 must be signless-integer-like}}
%r:2 = arith.addi_carry %a, %a : f32, i32
return
}

// -----

func.func @func_with_ops(%a: i32) {
// expected-error@+1 {{'arith.addi_carry' op result #1 must be bool-like}}
%r:2 = arith.addi_carry %a, %a : i32, i32
return
}

// -----

func.func @func_with_ops(%a: vector<8xi32>) {
// expected-error@+1 {{'arith.addi_carry' op if an operand is non-scalar, then all results must be non-scalar}}
%r:2 = arith.addi_carry %a, %a : vector<8xi32>, i1
return
}

// -----

func.func @func_with_ops(%a: vector<8xi32>) {
// expected-error@+1 {{'arith.addi_carry' op all non-scalar operands/results must have the same shape and base type}}
%r:2 = arith.addi_carry %a, %a : vector<8xi32>, tensor<8xi1>
return
}

// -----

func.func @func_with_ops(i32) {
^bb0(%a : i32):
%sf = arith.addf %a, %a : i32 // expected-error {{'arith.addf' op operand #0 must be floating-point-like}}
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Arithmetic/ops.mlir
Expand Up @@ -25,6 +25,30 @@ func.func @test_addi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]
return %0 : vector<[8]xi64>
}

// CHECK-LABEL: test_addi_carry
func.func @test_addi_carry(%arg0 : i64, %arg1 : i64) -> i64 {
%sum, %carry = arith.addi_carry %arg0, %arg1 : i64, i1
return %sum : i64
}

// CHECK-LABEL: test_addi_carry_tensor
func.func @test_addi_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
%sum, %carry = arith.addi_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1>
return %sum : tensor<8x8xi64>
}

// CHECK-LABEL: test_addi_carry_vector
func.func @test_addi_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
%0:2 = arith.addi_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1>
return %0#0 : vector<8xi64>
}

// CHECK-LABEL: test_addi_carry_scalable_vector
func.func @test_addi_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0:2 = arith.addi_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1>
return %0#0 : vector<[8]xi64>
}

// CHECK-LABEL: test_subi
func.func @test_subi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.subi %arg0, %arg1 : i64
Expand Down

0 comments on commit 4309170

Please sign in to comment.