Skip to content

Commit

Permalink
[mlir][arith] Allow to specify constFoldBinaryOp result type
Browse files Browse the repository at this point in the history
This enables us to use the common fold helpers on elementwise ops that
produce a different result type than their operand types, e.g., `arith.cmpi`
or `arith.addui_extended`.

Use the updated helper to teach `arith.cmpi` to fold constant vectors.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D143779
  • Loading branch information
kuhar committed Feb 13, 2023
1 parent dc38cbc commit 892bf09
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 65 deletions.
57 changes: 51 additions & 6 deletions mlir/include/mlir/Dialect/CommonFolders.h
Expand Up @@ -24,14 +24,16 @@
namespace mlir {
/// Performs constant folding `calculate` with element-wise behavior on the two
/// attributes in `operands` and returns the result if possible.
/// Uses `resultType` for the type of the returned attribute.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class CalculationT = function_ref<
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type resultType,
const CalculationT &calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
if (!operands[0] || !operands[1])
if (!resultType || !operands[0] || !operands[1])
return {};

if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
Expand All @@ -45,7 +47,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!calRes)
return {};

return AttrElementT::get(lhs.getType(), *calRes);
return AttrElementT::get(resultType, *calRes);
}

if (operands[0].isa<SplatElementsAttr>() &&
Expand All @@ -62,9 +64,10 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!elementResult)
return {};

return DenseElementsAttr::get(lhs.getType(), *elementResult);
} else if (operands[0].isa<ElementsAttr>() &&
operands[1].isa<ElementsAttr>()) {
return DenseElementsAttr::get(resultType, *elementResult);
}

if (operands[0].isa<ElementsAttr>() && operands[1].isa<ElementsAttr>()) {
// Operands are ElementsAttr-derived; perform an element-wise fold by
// expanding the values.
auto lhs = operands[0].cast<ElementsAttr>();
Expand All @@ -83,11 +86,53 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
elementResults.push_back(*elementResult);
}

return DenseElementsAttr::get(lhs.getType(), elementResults);
return DenseElementsAttr::get(resultType, elementResults);
}
return {};
}

/// Performs constant folding `calculate` with element-wise behavior on the two
/// attributes in `operands` and returns the result if possible.
/// Uses the operand element type for the element type of the returned
/// attribute.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT = function_ref<
              std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
                                       const CalculationT &calculate) {
  assert(operands.size() == 2 && "binary op takes two operands");

  // Extract the type carried by a (possibly null) attribute.
  auto extractType = [](Attribute attr) -> Type {
    if (auto typedAttr = attr.dyn_cast_or_null<TypedAttr>())
      return typedAttr.getType();
    return {};
  };

  Type lhsType = extractType(operands[0]);
  Type rhsType = extractType(operands[1]);
  // Both operands must be typed and agree on the type; the fold result then
  // uses that common type.
  if (!lhsType || !rhsType || lhsType != rhsType)
    return {};

  return constFoldBinaryOpConditional<AttrElementT, ElementValueT,
                                      CalculationT>(operands, lhsType,
                                                    calculate);
}

/// Performs constant folding `calculate` with element-wise behavior on the two
/// attributes in `operands` and returns the result if possible.
/// Uses `resultType` for the type of the returned attribute. Unlike the
/// `Conditional` variant, `calculate` is infallible.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
                            const CalculationT &calculate) {
  // Adapt the infallible callback to the optional-returning interface
  // expected by the conditional fold helper.
  auto wrapped =
      [&calculate](ElementValueT lhs,
                   ElementValueT rhs) -> std::optional<ElementValueT> {
    return calculate(lhs, rhs);
  };
  return constFoldBinaryOpConditional<AttrElementT>(operands, resultType,
                                                    wrapped);
}

template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class CalculationT =
Expand Down
90 changes: 35 additions & 55 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Expand Up @@ -13,11 +13,14 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
Expand Down Expand Up @@ -107,6 +110,23 @@ namespace {
#include "ArithCanonicalization.inc"
} // namespace

//===----------------------------------------------------------------------===//
// Common helpers
//===----------------------------------------------------------------------===//

/// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
  Type boolType = IntegerType::get(type.getContext(), 1);
  if (auto rankedTy = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(rankedTy.getShape(), boolType);
  if (type.isa<UnrankedTensorType>())
    return UnrankedTensorType::get(boolType);
  if (auto vecTy = type.dyn_cast<VectorType>())
    return VectorType::get(vecTy.getShape(), boolType,
                           vecTy.getNumScalableDims());
  // Scalar case: plain i1.
  return boolType;
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -276,41 +296,16 @@ arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
// addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
// Let the `constFoldBinaryOp` utility attempt to fold the sum of both
// operands. If that succeeds, calculate the overflow bit based on the sum
// and the first (constant) operand, `lhs`. Note that we cannot simply call
// `constFoldBinaryOp` again to calculate the overflow bit because the
// constructed attribute is of the same element type as both operands.
// and the first (constant) operand, `lhs`.
if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(),
[](APInt a, const APInt &b) { return std::move(a) + b; })) {
Attribute overflowAttr;
if (auto lhs = adaptor.getLhs().dyn_cast<IntegerAttr>()) {
// Both arguments are scalars, calculate the scalar overflow value.
auto sum = sumAttr.cast<IntegerAttr>();
overflowAttr = IntegerAttr::get(
overflowTy,
calculateUnsignedOverflow(sum.getValue(), lhs.getValue()));
} else if (auto lhs = adaptor.getLhs().dyn_cast<SplatElementsAttr>()) {
// Both arguments are splats, calculate the splat overflow value.
auto sum = sumAttr.cast<SplatElementsAttr>();
APInt overflow = calculateUnsignedOverflow(sum.getSplatValue<APInt>(),
lhs.getSplatValue<APInt>());
overflowAttr = SplatElementsAttr::get(overflowTy, overflow);
} else if (auto lhs = adaptor.getLhs().dyn_cast<ElementsAttr>()) {
// Othwerwise calculate element-wise overflow values.
auto sum = sumAttr.cast<ElementsAttr>();
const auto numElems = static_cast<size_t>(sum.getNumElements());
SmallVector<APInt> overflowValues;
overflowValues.reserve(numElems);

auto sumIt = sum.value_begin<APInt>();
auto lhsIt = lhs.value_begin<APInt>();
for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
overflowValues.push_back(calculateUnsignedOverflow(*sumIt, *lhsIt));

overflowAttr = DenseElementsAttr::get(overflowTy, overflowValues);
} else {
Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
ArrayRef({sumAttr, adaptor.getLhs()}),
getI1SameShape(sumAttr.cast<TypedAttr>().getType()),
calculateUnsignedOverflow);
if (!overflowAttr)
return failure();
}

results.push_back(sumAttr);
results.push_back(overflowAttr);
Expand Down Expand Up @@ -1534,23 +1529,6 @@ void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<BitcastOfBitcast>(context);
}

//===----------------------------------------------------------------------===//
// Helpers for compare ops
//===----------------------------------------------------------------------===//

/// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type);
if (type.isa<UnrankedTensorType>())
return UnrankedTensorType::get(i1Type);
if (auto vectorType = type.dyn_cast<VectorType>())
return VectorType::get(vectorType.getShape(), i1Type,
vectorType.getNumScalableDims());
return i1Type;
}

//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1671,16 +1649,18 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
llvm_unreachable("unknown cmpi predicate kind");
}

auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
if (!lhs)
return {};

// We are moving constants to the right side; So if lhs is constant rhs is
// guaranteed to be a constant.
auto rhs = adaptor.getRhs().cast<IntegerAttr>();
if (auto lhs = adaptor.getLhs().dyn_cast_or_null<TypedAttr>()) {
return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), getI1SameShape(lhs.getType()),
[pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
return APInt(1,
static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
});
}

auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return BoolAttr::get(getContext(), val);
return {};
}

void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Expand Up @@ -322,6 +322,46 @@ func.func @cmpIExtUIEQ(%arg0: i8, %arg1: i8) -> i1 {
return %res : i1
}

// Folding of cmpi over two dense constant vectors (element-wise eq).
// CHECK-LABEL: @cmpIFoldEQ
// CHECK: %[[res:.+]] = arith.constant dense<[true, true, false]> : vector<3xi1>
// CHECK: return %[[res]]
func.func @cmpIFoldEQ() -> vector<3xi1> {
%lhs = arith.constant dense<[1, 2, 3]> : vector<3xi32>
%rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32>
%res = arith.cmpi eq, %lhs, %rhs : vector<3xi32>
return %res : vector<3xi1>
}

// Folding of cmpi over two dense constant vectors (element-wise ne).
// CHECK-LABEL: @cmpIFoldNE
// CHECK: %[[res:.+]] = arith.constant dense<[false, false, true]> : vector<3xi1>
// CHECK: return %[[res]]
func.func @cmpIFoldNE() -> vector<3xi1> {
%lhs = arith.constant dense<[1, 2, 3]> : vector<3xi32>
%rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32>
%res = arith.cmpi ne, %lhs, %rhs : vector<3xi32>
return %res : vector<3xi1>
}

// Folding of cmpi with a splat constant on one side and a dense constant on
// the other (signed comparison).
// CHECK-LABEL: @cmpIFoldSGE
// CHECK: %[[res:.+]] = arith.constant dense<[true, true, false]> : vector<3xi1>
// CHECK: return %[[res]]
func.func @cmpIFoldSGE() -> vector<3xi1> {
%lhs = arith.constant dense<2> : vector<3xi32>
%rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32>
%res = arith.cmpi sge, %lhs, %rhs : vector<3xi32>
return %res : vector<3xi1>
}

// Folding of cmpi with two splat constants; the result stays a splat.
// CHECK-LABEL: @cmpIFoldULT
// CHECK: %[[res:.+]] = arith.constant dense<false> : vector<3xi1>
// CHECK: return %[[res]]
func.func @cmpIFoldULT() -> vector<3xi1> {
%lhs = arith.constant dense<2> : vector<3xi32>
%rhs = arith.constant dense<1> : vector<3xi32>
%res = arith.cmpi ult, %lhs, %rhs : vector<3xi32>
return %res : vector<3xi1>
}

// -----

// CHECK-LABEL: @andOfExtSI
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Expand Up @@ -1070,13 +1070,13 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3
// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]

// CHECK-LABEL: @while_loop_invariant_argument_different_order
func.func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
%cst_0 = arith.constant dense<0> : tensor<i32>
%cst_1 = arith.constant dense<1> : tensor<i32>
%cst_42 = arith.constant dense<42> : tensor<i32>

%0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
%1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
%1 = arith.cmpi slt, %arg0, %arg : tensor<i32>
%2 = tensor.extract %1[] : tensor<i1>
scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
} do {
Expand All @@ -1087,11 +1087,11 @@ func.func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tens
}
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
// CHECK-SAME: (%[[ARG:.+]]: tensor<i32>)
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
// CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]]
// CHECK: arith.cmpi sgt, %[[ARG]], %[[ZERO]]
// CHECK: tensor.extract %{{.*}}[]
// CHECK: scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
// CHECK: } do {
Expand Down

0 comments on commit 892bf09

Please sign in to comment.