Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ class AttrConverterConstrainedFPToLLVM {
convertedAttr.set(TargetOp::getRoundingModeAttrName(),
convertArithRoundingModeAttrToLLVM(arithAttr));
}
// Constrained intrinsics (llvm.intr.experimental.constrained.*) do not
// support fastmath flags. Remove the arith fastmath attribute if present.
if constexpr (SourceOp::template hasTrait<
Comment thread
matthias-springer marked this conversation as resolved.
arith::ArithFastMathInterface::Trait>())
convertedAttr.erase(srcOp.getFastMathAttrName());
convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
}
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def Arith_Dialect : Dialect {
Manipulating value with type `i0` isn't supported in this dialect at the
moment and is considered invalid. This can change in the future if some
motivating use-cases are presented.

Some floating-point operations may specify rounding modes and/or fast-math
flags. In the absence of an explicit rounding mode, the arith dialect uses
this default round mode for internal purposes such as constant folding and
canonicalization: round-to-nearest, ties-to-even. The runtime behavior of
operations without an explicit rounding mode is deferred to the target
backend and may differ from the default arith rounding mode.
}];

let hasConstantMaterializer = 1;
Expand Down
103 changes: 75 additions & 28 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,52 @@ class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic,
!listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>],
traits)>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
traits)> {
let arguments = (ins FloatLike:$lhs, FloatLike:$rhs,
DefaultValuedAttr<
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
Results<(outs FloatLike:$result)> {
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
let results = (outs FloatLike:$result);
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($result) }];
}

// Base class for floating point binary operations with an optional rounding
// mode.
class Arith_FloatBinaryOpWithRoundingMode<string mnemonic,
list<Trait> traits = []> :
Arith_FloatBinaryOp<mnemonic,
!listconcat([DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
traits)> {
let arguments = (ins FloatLike:$lhs, FloatLike:$rhs,
DefaultValuedAttr<
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode);
let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs,
CArg<"::mlir::arith::FastMathFlags",
"::mlir::arith::FastMathFlags::none">:$fastmath), [{
build($_builder, $_state, lhs, rhs, fastmath,
::mlir::arith::RoundingModeAttr{});
}]>,
OpBuilder<(ins "Value":$lhs, "Value":$rhs,
"::mlir::arith::FastMathFlagsAttr":$fastmath), [{
build($_builder, $_state, lhs, rhs, fastmath,
::mlir::arith::RoundingModeAttr{});
}]>,
Comment thread
rengolin marked this conversation as resolved.
OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs,
CArg<"::mlir::arith::FastMathFlags",
"::mlir::arith::FastMathFlags::none">:$fastmath), [{
build($_builder, $_state, type, lhs, rhs,
::mlir::arith::FastMathFlagsAttr::get(
$_builder.getContext(), fastmath),
::mlir::arith::RoundingModeAttr{});
}]>,
];
let assemblyFormat = [{ $lhs `,` $rhs ($roundingmode^)?
(`fastmath` `` $fastmath^)?
attr-dict `:` type($result) }];
}

// Checks that tensor input and outputs have identical shapes. This is stricker
// than the verification done in `SameOperandsAndResultShape` that allows for
// tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
Expand Down Expand Up @@ -957,7 +994,7 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
// AddFOp
//===----------------------------------------------------------------------===//

def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
def Arith_AddFOp : Arith_FloatBinaryOpWithRoundingMode<"addf", [Commutative]> {
let summary = "floating point addition operation";
let description = [{
The `addf` operation takes two operands and returns one result, each of
Expand All @@ -976,10 +1013,10 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {

// Tensor addition.
%x = arith.addf %y, %z : tensor<4x?xbf16>
```

TODO: In the distant future, this will accept optional attributes for fast
math, contraction, rounding mode, and other controls.
// Scalar addition with rounding mode.
%a = arith.addf %b, %c to_nearest_even : f64
```
}];
let hasFolder = 1;
}
Expand All @@ -988,7 +1025,7 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
// SubFOp
//===----------------------------------------------------------------------===//

def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
def Arith_SubFOp : Arith_FloatBinaryOpWithRoundingMode<"subf"> {
let summary = "floating point subtraction operation";
let description = [{
The `subf` operation takes two operands and returns one result, each of
Expand All @@ -1007,10 +1044,10 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {

// Tensor subtraction.
%x = arith.subf %y, %z : tensor<4x?xbf16>
```

TODO: In the distant future, this will accept optional attributes for fast
math, contraction, rounding mode, and other controls.
// Scalar subtraction with rounding mode.
%a = arith.subf %b, %c downward : f64
```
}];
let hasFolder = 1;
}
Expand Down Expand Up @@ -1139,7 +1176,7 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
// MulFOp
//===----------------------------------------------------------------------===//

def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
def Arith_MulFOp : Arith_FloatBinaryOpWithRoundingMode<"mulf", [Commutative]> {
let summary = "floating point multiplication operation";
let description = [{
The `mulf` operation takes two operands and returns one result, each of
Expand All @@ -1158,10 +1195,10 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {

// Tensor pointwise multiplication.
%x = arith.mulf %y, %z : tensor<4x?xbf16>
```

TODO: In the distant future, this will accept optional attributes for fast
math, contraction, rounding mode, and other controls.
// Scalar multiplication with rounding mode.
%a = arith.mulf %b, %c upward : f64
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
Expand All @@ -1171,8 +1208,24 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
// DivFOp
//===----------------------------------------------------------------------===//

def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
def Arith_DivFOp : Arith_FloatBinaryOpWithRoundingMode<"divf"> {
let summary = "floating point division operation";
let description = [{
The `divf` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be a floating point
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.

Example:

```mlir
// Scalar division.
%a = arith.divf %b, %c : f64

// Scalar division with rounding mode.
%a = arith.divf %b, %c toward_zero : f64
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
Expand All @@ -1186,6 +1239,8 @@ def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
let description = [{
Returns the floating point division remainder.
The remainder has the same sign as the dividend (lhs operand).

TODO: Add support for rounding modes.
}];
let hasFolder = 1;
}
Expand Down Expand Up @@ -1420,8 +1475,6 @@ def Arith_TruncFOp :
let description = [{
Truncate a floating-point value to a smaller floating-point-typed value.
The destination type must be strictly narrower than the source type.
If the value cannot be exactly represented, it is rounded using the
provided rounding mode or the default one if no rounding mode is provided.
When operating on vectors, casts elementwise.
}];
let builders = [
Expand Down Expand Up @@ -1461,9 +1514,7 @@ def Arith_ConvertFOp :
be represented by `arith.extf` or `arith.truncf`.

The source and destination element types must be different and must have
the same bitwidth. If the value cannot be exactly represented, it is
rounded using the provided rounding mode or the default one if no rounding
mode is provided. When operating on vectors, casts elementwise.
the same bitwidth. When operating on vectors, casts elementwise.
}];

let hasFolder = 1;
Expand Down Expand Up @@ -1552,9 +1603,7 @@ def Arith_UIToFPOp :
let summary = "cast from unsigned integer type to floating-point";
let description = [{
Cast from a value interpreted as unsigned integer to the corresponding
floating-point value. If the value cannot be exactly represented, it is
rounded using the default rounding mode. When operating on vectors, casts
elementwise.
floating-point value. When operating on vectors, casts elementwise.

When the `nneg` flag is present, the operand is assumed to have
the most significant bit set to 0. In this case, zero extension is
Expand Down Expand Up @@ -1589,9 +1638,7 @@ def Arith_SIToFPOp : Arith_IToFCastOp<"sitofp"> {
let summary = "cast from integer type to floating-point";
let description = [{
Cast from a value interpreted as a signed integer to the corresponding
floating-point value. If the value cannot be exactly represented, it is
rounded using the default rounding mode. When operating on vectors, casts
elementwise.
floating-point value. When operating on vectors, casts elementwise.
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
Expand Down
70 changes: 49 additions & 21 deletions mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,20 @@ using namespace mlir;

namespace {

/// Operations whose conversion will depend on whether they are passed a
/// rounding mode attribute or not.
/// Lowering pattern that matches only when the source op's rounding mode
/// presence agrees with `HasRoundingMode`. This allows registering two
/// instances of the same pattern for one source op: one that handles the
/// unconstrained case (no rounding mode, lowering to a regular LLVM op) and
/// one that handles the constrained case (rounding mode present, lowering to
/// a constrained LLVM intrinsic).
///
/// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
/// to; `AttrConvert` is the attribute conversion to convert the rounding mode
/// attribute.
template <typename SourceOp, typename TargetOp, bool Constrained,
/// * `HasRoundingMode`: the pattern matches if and only if the source op has
/// a rounding mode attribute.
/// * `AttrConvert`: attribute converter to translate source attributes to
/// target attributes.
/// * `FailOnUnsupportedFP`: whether to fail if the source op has unsupported
/// floating point types.
template <typename SourceOp, typename TargetOp, bool HasRoundingMode,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough,
bool FailOnUnsupportedFP = false>
Expand All @@ -49,7 +56,7 @@ struct ConstrainedVectorConvertToLLVMPattern
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
if (HasRoundingMode != static_cast<bool>(op.getRoundingModeAttr()))
return failure();
return VectorConvertToLLVMPattern<
SourceOp, TargetOp, AttrConvert,
Expand Down Expand Up @@ -81,19 +88,27 @@ struct IdentityBitcastLowering final
//===----------------------------------------------------------------------===//

using AddFOpLowering =
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
/*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::AddFOp, LLVM::ConstrainedFAddIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
arith::AttrConvertOverflowToLLVM>;
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
/*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::DivFOp, LLVM::ConstrainedFDivIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
Expand Down Expand Up @@ -139,9 +154,13 @@ using MinSIOpLowering =
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
/*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::MulFOp, LLVM::ConstrainedFMulIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
arith::AttrConvertOverflowToLLVM>;
Expand Down Expand Up @@ -170,18 +189,23 @@ using ShRUIOpLowering =
using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
/*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::SubFOp, LLVM::ConstrainedFSubIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
false, AttrConvertPassThrough,
/*HasRoundingMode=*/false,
AttrConvertPassThrough,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
Expand Down Expand Up @@ -700,6 +724,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
// clang-format off
patterns.add<
AddFOpLowering,
ConstrainedAddFOpLowering,
AddIOpLowering,
AndIOpLowering,
AddUIExtendedOpLowering,
Expand All @@ -708,6 +733,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
CmpFOpLowering,
CmpIOpLowering,
DivFOpLowering,
ConstrainedDivFOpLowering,
DivSIOpLowering,
DivUIOpLowering,
ExtFOpLowering,
Expand All @@ -727,6 +753,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
MinSIOpLowering,
MinUIOpLowering,
MulFOpLowering,
ConstrainedMulFOpLowering,
MulIOpLowering,
MulSIExtendedOpLowering,
MulUIExtendedOpLowering,
Expand All @@ -742,6 +769,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
ShRUIOpLowering,
SIToFPOpLowering,
SubFOpLowering,
ConstrainedSubFOpLowering,
SubIOpLowering,
TruncFOpLowering,
ConstrainedTruncFOpLowering,
Expand Down
Loading