-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][arith] arith-to-apfloat: Add vector support
#171024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
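@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) Changes: Add support for vectorized operations such as `arith.addf ... : vector<4xf4E2M1FN>`. The computation is scalarized: scalar operands are extracted with `vector.to_elements`, multiple scalar computations are performed, and the result is inserted back into a vector with `vector.from_elements`. Patch is 35.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171024.diff 4 Files Affected: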
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 75ab4b64b7f38..fcbaf3ccc1486 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -198,7 +198,8 @@ def ArithToAPFloatConversionPass
calls (APFloatWrappers.cpp). APFloat is a software implementation of
floating-point arithmetic operations.
}];
- let dependentDialects = ["func::FuncDialect"];
+ let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
+ "vector::VectorDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 4776ba0f49b94..e18316eae486b 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
@@ -90,6 +91,75 @@ static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
b.getIntegerAttr(b.getI32Type(), sem));
}
+/// Apply `fn` to the given operands, scalarizing vector operands.
+///
+/// If `operand1` is not a vector, `fn(operand1, operand2, resultType)` is
+/// called once and its result is returned directly. Otherwise, both operands
+/// are decomposed with `vector.to_elements`, `fn` is called once per element
+/// position with the scalar operands and the element type of `resultType`,
+/// and the scalar results are packaged with `vector.from_elements`.
+///
+/// `operand2` is optional: pass a null `Value` for unary operations, in which
+/// case `fn` receives a null `Value` as its second argument. When present,
+/// `operand2` must have the same type as `operand1`. For vector operands,
+/// `resultType` must be a vector type with the same shape as the operands.
+template <typename Fn>
+static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
+                                Value operand1, Value operand2, Type resultType,
+                                Fn fn) {
+  auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+  if (operand2) {
+    // Sanity check: Operand types must match.
+    assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
+           "expected same vector types");
+  }
+  if (!vecTy1) {
+    // Not a vector. Call the function directly.
+    return fn(operand1, operand2, resultType);
+  }
+
+  // Sanity check: The result must be a vector of the same shape, otherwise
+  // the per-element results cannot be packaged back into it.
+  auto resultVecType = cast<VectorType>(resultType);
+  assert(resultVecType.getShape() == vecTy1.getShape() &&
+         "expected result vector with the same shape as the operands");
+
+  // Prepare scalar operands.
+  auto scalars1 = vector::ToElementsOp::create(rewriter, loc, operand1);
+  SmallVector<Value> scalars2;
+  if (!operand2) {
+    // No second operand. Create a vector of empty values so that both operand
+    // lists can be iterated in lockstep.
+    scalars2.assign(vecTy1.getNumElements(), Value());
+  } else {
+    llvm::append_range(
+        scalars2,
+        vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+  }
+
+  // Call the function for each pair of scalar operands. `zip_equal` asserts
+  // that both operand lists have the same length instead of silently
+  // truncating to the shorter one.
+  SmallVector<Value> results;
+  results.reserve(vecTy1.getNumElements());
+  for (auto [scalar1, scalar2] :
+       llvm::zip_equal(scalars1->getResults(), scalars2)) {
+    Value result = fn(scalar1, scalar2, resultVecType.getElementType());
+    results.push_back(result);
+  }
+
+  // Package the results into a vector. The element type is taken from the
+  // values that `fn` actually produced.
+  return vector::FromElementsOp::create(
+      rewriter, loc,
+      vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+      results);
+}
+
+
+/// Verify that `op` can be lowered to APFloat runtime calls.
+///
+/// Every operand type and every result type is inspected (operands first).
+/// For vector types, the element type is inspected instead. The check fails
+/// with a match-failure notification if any inspected type
+///   1. is neither an integer nor a float, or
+///   2. is wider than 64 bits.
+static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
+  // Validate a single operand / result type.
+  auto isSupported = [&](Type ty) -> LogicalResult {
+    // Look through vectors: only the element type matters.
+    if (auto vectorTy = dyn_cast<VectorType>(ty))
+      ty = vectorTy.getElementType();
+    if (!ty.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "only integers and floats (or vectors thereof) are supported");
+    if (ty.getIntOrFloatBitWidth() > 64)
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+    return success();
+  };
+
+  for (Type operandTy : op->getOperandTypes())
+    if (failed(isSupported(operandTy)))
+      return failure();
+  for (Type resultTy : op->getResultTypes())
+    if (failed(isSupported(resultTy)))
+      return failure();
+  return success();
+}
+
+
/// Rewrite a binary arithmetic operation to an APFloat function call.
template <typename OpTy>
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
@@ -102,9 +172,8 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
FailureOr<FuncOp> fn =
@@ -112,31 +181,37 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto floatTy = cast<FloatType>(op.getType());
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- auto int64Type = rewriter.getI64Type();
- Value lhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
- Value rhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
-
- // Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
- resultOp->getResult(0));
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
+ [&](Value lhs, Value rhs, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto floatTy = cast<FloatType>(resultType);
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ auto int64Type = rewriter.getI64Type();
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, lhs));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, rhs));
+
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -152,10 +227,8 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64 ||
- op.getOperand().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -165,30 +238,36 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto inFloatTy = cast<FloatType>(op.getOperand().getType());
- auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
-
- // Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
- auto outFloatTy = cast<FloatType>(op.getType());
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
- std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
- resultOp->getResult(0));
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits));
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inFloatTy = cast<FloatType>(operand1.getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
+
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outFloatTy = cast<FloatType>(resultType);
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(
+ rewriter, loc, outIntWType, resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -204,10 +283,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64 ||
- op.getOperand().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -219,33 +296,39 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto inFloatTy = cast<FloatType>(op.getOperand().getType());
- auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
-
- // Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
- auto outIntTy = cast<IntegerType>(op.getType());
- Value outWidthValue = arith::ConstantOp::create(
- rewriter, loc, i32Type,
- rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
- Value isUnsignedValue = arith::ConstantOp::create(
- rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
- SmallVector<Value> params = {inSemValue, outWidthValue, isUnsignedValue,
- operandBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntTy,
- resultOp->getResult(0));
- rewriter.replaceOp(op, truncatedBits);
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inFloatTy = cast<FloatType>(operand1.getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
+
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outIntTy = cast<IntegerType>(resultType);
+ Value outWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {inSemValue, outWidthValue,
+ isUnsignedValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ return arith::TruncIOp::create(rewriter, loc, outIntTy,
+ resultOp->getResult(0));
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -262,10 +345,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64 ||
- op.getOperand().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -277,42 +358,48 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto inIntTy = cast<IntegerType>(op.getOperand().getType());
- Value operandBits = op.getOperand();
- if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
- if (isUnsigned) {
- operandBits =
- arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
- } else {
- operandBits =
- arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
- }
- }
-
- // Call APFloat function.
- auto outFloatTy = cast<FloatType>(op.getType());
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
- Value inWidthValue = arith::ConstantOp::create(
- rewriter, loc, i32Type,
- rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
- Value isUnsignedValue = arith::ConstantOp::create(
- rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
- SmallVector<Value> params = {outSemValue, inWidthValue, isUnsignedValue,
- operandBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
- resultOp->getResult(0));
- Value result =
- arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits);
- rewriter.replaceOp(op, result);
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inIntTy = cast<IntegerType>(operand1.getType());
+ Value operandBits = operand1;
+ if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
+ if (isUnsigned) {
+ operandBits =
+ arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+ } else {
+ operandBits =
+ arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
+ }
+ }
+
+ // Call APFloat function.
+ auto outFloatTy = cast<FloatType>(resultType);
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value inWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {outSemValue, inWidthValue,
+ isUnsignedValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(
+ rewriter, loc, outIntWType, resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -327,9 +414,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
LogicalResult matchAndRewrite(arith::CmpFOp op,
PatternRewriter &rewriter) const override {
- if (op.getLhs().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- ...
[truncated]
|
9da9241 to
a8474dd
Compare
a8474dd to
c369b96
Compare
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/177/builds/25495 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/21903 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/22810 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/198/builds/10292 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/143/builds/13112 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/17/builds/13274 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/199/builds/7753 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/4/builds/11006 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/41/builds/10743 Here is the relevant piece of the build log for the reference |
Add support for vectorized operations such as
`arith.addf ... : vector<4xf4E2M1FN>`. The computation is scalarized: scalar operands are extracted with `vector.to_elements`, multiple scalar computations are performed, and the result is inserted back into a vector with `vector.from_elements`.