-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][linalg] convert arith ops to destination-passing-style. #157854
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,13 +17,20 @@ | |||||||||||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" | ||||||||||||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" | ||||||||||||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||||||||||||||||
#include "mlir/Dialect/Linalg/Passes.h" | ||||||||||||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||||||||||||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||||||||||||||||
#include "mlir/Dialect/Utils/StaticValueUtils.h" | ||||||||||||||||
#include "mlir/IR/Matchers.h" | ||||||||||||||||
#include "mlir/IR/PatternMatch.h" | ||||||||||||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||||||||||||||||
#include "llvm/ADT/STLExtras.h" | ||||||||||||||||
|
||||||||||||||||
namespace mlir { | ||||||||||||||||
#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS | ||||||||||||||||
#include "mlir/Dialect/Linalg/Passes.h.inc" | ||||||||||||||||
} // namespace mlir | ||||||||||||||||
|
||||||||||||||||
using namespace mlir; | ||||||||||||||||
using namespace mlir::tensor; | ||||||||||||||||
|
||||||||||||||||
|
@@ -96,7 +103,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, | |||||||||||||||
OpBuilder::InsertionGuard g(rewriter); | ||||||||||||||||
RankedTensorType resultType = padOp.getResultType(); | ||||||||||||||||
|
||||||||||||||||
// Examine the yielded value to decide if a linalg.generic is neede or a | ||||||||||||||||
// Examine the yielded value to decide if a linalg.generic is needed or a | ||||||||||||||||
// linalg.fill is sufficient. | ||||||||||||||||
Value yieldedValue = | ||||||||||||||||
cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue(); | ||||||||||||||||
|
@@ -603,6 +610,94 @@ Value linalg::bufferizeToAllocation( | |||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
namespace { | ||||||||||||||||
/// Rewrites an arith op operating on tensors, e.g. | ||||||||||||||||
/// `%z = arith.addf %x, %y : tensor<5xf32>` | ||||||||||||||||
/// into an equivalent linalg.generic in destination-passing-style. | ||||||||||||||||
/// ```mlir | ||||||||||||||||
/// %0 = tensor.empty() : tensor<5xf32> | ||||||||||||||||
/// %1 = linalg.generic ... | ||||||||||||||||
/// ins(%x, %y : tensor<5xf32>, tensor<5xf32>) | ||||||||||||||||
/// outs(%0 : tensor<5xf32>) { | ||||||||||||||||
/// ^bb0(%in: f32, %in_0: f32, %out: f32): | ||||||||||||||||
/// %2 = arith.addf %in, %in_0 : f32 | ||||||||||||||||
/// linalg.yield %2 : f32 | ||||||||||||||||
/// } -> tensor<5xf32> | ||||||||||||||||
template <typename OpTy> | ||||||||||||||||
FailureOr<Operation *> | ||||||||||||||||
rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) { | ||||||||||||||||
// Reject ops such as `arith.constant` and `arith.select`. | ||||||||||||||||
// constants don't need dps conversion and select is a a `todo`. | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
auto numOperands = op->getNumOperands(); | ||||||||||||||||
if (numOperands == 0 || numOperands > 2) | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. only unary and binary we care about. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why?
This is what the code tells me, yes. But it doesn't say why. Also, if that's the case, then this would be clearer to me: if (numOperands != 1 && numOperands != 2) EDIT Sorry, posted my comment before noticing that this comment has been updated.
Why not? This is assuming that the only purpose of this code is to help with bufferization. That is fine with, but make it clear. Otherwise, "constants don't need dps conversion" sounds very arbitrary (missing "why"). |
||||||||||||||||
return failure(); | ||||||||||||||||
|
||||||||||||||||
// destination passing style rewrite is only for ops on tensor types. | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
Type resultType = op->getResult(0).getType(); | ||||||||||||||||
auto tensorType = dyn_cast<RankedTensorType>(resultType); | ||||||||||||||||
if (!tensorType) | ||||||||||||||||
return failure(); | ||||||||||||||||
|
||||||||||||||||
auto loc = op.getLoc(); | ||||||||||||||||
OpBuilder::InsertionGuard g(rewriter); | ||||||||||||||||
auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0)); | ||||||||||||||||
|
||||||||||||||||
// Create tensor.empty for `outs` of destination-passing-style. | ||||||||||||||||
Value outs = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes); | ||||||||||||||||
|
||||||||||||||||
// Create linalg.generic | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
auto rank = tensorType.getRank(); | ||||||||||||||||
SmallVector<AffineMap> indexingMaps(numOperands + 1, | ||||||||||||||||
rewriter.getMultiDimIdentityMap(rank)); | ||||||||||||||||
SmallVector<utils::IteratorType> iteratorTypes(rank, | ||||||||||||||||
utils::IteratorType::parallel); | ||||||||||||||||
|
||||||||||||||||
// Check 'fast-math'. If present, propagate it. | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we model those when specializing? While a If today we reject specialization based on these flags, that's ok. But if we don't, then we'll change the semantics and the round-trip will fail. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
+1, but lets make sure this is verified with tests. |
||||||||||||||||
auto fmfOpInterface = | ||||||||||||||||
llvm::dyn_cast<arith::ArithFastMathInterface>(op.getOperation()); | ||||||||||||||||
|
||||||||||||||||
auto genericOp = linalg::GenericOp::create( | ||||||||||||||||
rewriter, loc, tensorType, | ||||||||||||||||
op->getOperands(), // inputs | ||||||||||||||||
ValueRange{outs}, // outputs | ||||||||||||||||
Comment on lines
+660
to
+661
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||||||||||||||||
indexingMaps, iteratorTypes, | ||||||||||||||||
[&](OpBuilder &builder, Location loc, ValueRange args) { | ||||||||||||||||
Value res; | ||||||||||||||||
if (args.size() == 2) { | ||||||||||||||||
if (fmfOpInterface) { | ||||||||||||||||
auto attr = fmfOpInterface.getFastMathFlagsAttr(); | ||||||||||||||||
auto fmf = rewriter.getNamedAttr("fastmath", attr); | ||||||||||||||||
res = builder | ||||||||||||||||
.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]}, | ||||||||||||||||
fmf) | ||||||||||||||||
.getResult(); | ||||||||||||||||
} else { | ||||||||||||||||
res = builder | ||||||||||||||||
.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]}) | ||||||||||||||||
.getResult(); | ||||||||||||||||
} | ||||||||||||||||
} else if (args.size() == 3) { | ||||||||||||||||
if (fmfOpInterface) { | ||||||||||||||||
auto attr = fmfOpInterface.getFastMathFlagsAttr(); | ||||||||||||||||
auto fmf = rewriter.getNamedAttr("fastmath", attr); | ||||||||||||||||
res = builder | ||||||||||||||||
.create<OpTy>(loc, args[2].getType(), | ||||||||||||||||
ValueRange{args[0], args[1]}, fmf) | ||||||||||||||||
.getResult(); | ||||||||||||||||
} else { | ||||||||||||||||
res = builder | ||||||||||||||||
.create<OpTy>(loc, args[2].getType(), | ||||||||||||||||
ValueRange{args[0], args[1]}) | ||||||||||||||||
.getResult(); | ||||||||||||||||
} | ||||||||||||||||
} else | ||||||||||||||||
llvm_unreachable("did not expect ops other than nary and binary"); | ||||||||||||||||
linalg::YieldOp::create(builder, loc, res); | ||||||||||||||||
Comment on lines
+692
to
+694
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I think the convention is that if one branch if/else branch uses braces, all the other ones should too
Suggested change
|
||||||||||||||||
}); | ||||||||||||||||
|
||||||||||||||||
rewriter.replaceAllUsesWith(op, genericOp.getResult(0)); | ||||||||||||||||
rewriter.eraseOp(op); | ||||||||||||||||
return genericOp.getOperation(); | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
template <typename OpTy> | ||||||||||||||||
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op, | ||||||||||||||||
|
@@ -612,9 +707,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op, | |||||||||||||||
|
||||||||||||||||
} // namespace | ||||||||||||||||
|
||||||||||||||||
#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY) \ | ||||||||||||||||
FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle( \ | ||||||||||||||||
RewriterBase &rewriter, OPTY op) { \ | ||||||||||||||||
return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op); \ | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp) | ||||||||||||||||
STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp) | ||||||||||||||||
STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp) | ||||||||||||||||
STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp) | ||||||||||||||||
|
||||||||||||||||
STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp) | ||||||||||||||||
|
||||||||||||||||
STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp) | ||||||||||||||||
STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp) | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can undef this macro here |
||||||||||||||||
|
||||||||||||||||
void linalg::populateConvertToDestinationStylePatterns( | ||||||||||||||||
RewritePatternSet &patterns) { | ||||||||||||||||
patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>); | ||||||||||||||||
patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>); | ||||||||||||||||
patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>); | ||||||||||||||||
|
||||||||||||||||
patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>); | ||||||||||||||||
patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>); | ||||||||||||||||
patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>); | ||||||||||||||||
patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>); | ||||||||||||||||
|
||||||||||||||||
patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>); | ||||||||||||||||
|
||||||||||||||||
patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>); | ||||||||||||||||
patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>); | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
namespace { | ||||||||||||||||
struct LinalgConvertToDPSPass | ||||||||||||||||
: public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> { | ||||||||||||||||
using impl::LinalgConvertToDPSPassBase< | ||||||||||||||||
LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase; | ||||||||||||||||
Comment on lines
+746
to
+747
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
|
||||||||||||||||
void runOnOperation() override; | ||||||||||||||||
}; | ||||||||||||||||
|
||||||||||||||||
void LinalgConvertToDPSPass::runOnOperation() { | ||||||||||||||||
|
||||||||||||||||
RewritePatternSet patterns(&getContext()); | ||||||||||||||||
linalg::populateConvertToDestinationStylePatterns(patterns); | ||||||||||||||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would the walk rewrite driver work here? I think this will never have to rewrite the same op twice or visit newly created ops. https://mlir.llvm.org/docs/PatternRewriter/#walk-pattern-rewrite-driver |
||||||||||||||||
signalPassFailure(); | ||||||||||||||||
} | ||||||||||||||||
} // namespace |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would suggest creating a new test file with Arith Ops. I see two reasons:
If you prefer to keep everything in one file, could you add a big comment separating
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -252,3 +252,96 @@ module attributes {transform.with_named_sequence} { | |||||||||||||||||||||||||||||||||||||
transform.yield | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
// ----- | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)> | ||||||||||||||||||||||||||||||||||||||
// CHECK-LABEL: func @arith_unary_op( | ||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[X:.+]]: tensor<64xi32>) -> tensor<64xf32> { | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32> | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[GENERIC:.+]] = linalg.generic | ||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]} | ||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: ins(%[[X:.+]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) { | ||||||||||||||||||||||||||||||||||||||
// CHECK: ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32): | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32 | ||||||||||||||||||||||||||||||||||||||
// CHECK: linalg.yield %[[z]] : f32 | ||||||||||||||||||||||||||||||||||||||
// CHECK: return %[[GENERIC]] : tensor<64xf32> | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> { | ||||||||||||||||||||||||||||||||||||||
%z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32> | ||||||||||||||||||||||||||||||||||||||
return %z : tensor<64xf32> | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
module attributes {transform.with_named_sequence} { | ||||||||||||||||||||||||||||||||||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||||||||||||||||||||||||||||||||||||||
%0 = transform.structured.match ops{["arith.uitofp"]} in %arg1 | ||||||||||||||||||||||||||||||||||||||
: (!transform.any_op) -> !transform.any_op | ||||||||||||||||||||||||||||||||||||||
transform.structured.rewrite_in_destination_passing_style %0 | ||||||||||||||||||||||||||||||||||||||
: (!transform.any_op) -> !transform.any_op | ||||||||||||||||||||||||||||||||||||||
transform.yield | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
// ----- | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)> | ||||||||||||||||||||||||||||||||||||||
// CHECK-LABEL: func @arith_binop( | ||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[C0:.+]] = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[GENERIC:.+]] = linalg.generic | ||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]} | ||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) { | ||||||||||||||||||||||||||||||||||||||
// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32): | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32 | ||||||||||||||||||||||||||||||||||||||
// CHECK: linalg.yield %[[z]] : f32 | ||||||||||||||||||||||||||||||||||||||
// CHECK: return %[[GENERIC]] : tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>) | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use consistent naming.
Suggested change
|
||||||||||||||||||||||||||||||||||||||
-> tensor<?xf32> { | ||||||||||||||||||||||||||||||||||||||
%z = arith.addf %x, %y : tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
return %z : tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
module attributes {transform.with_named_sequence} { | ||||||||||||||||||||||||||||||||||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||||||||||||||||||||||||||||||||||||||
%0 = transform.structured.match ops{["arith.addf"]} in %arg1 | ||||||||||||||||||||||||||||||||||||||
: (!transform.any_op) -> !transform.any_op | ||||||||||||||||||||||||||||||||||||||
transform.structured.rewrite_in_destination_passing_style %0 | ||||||||||||||||||||||||||||||||||||||
: (!transform.any_op) -> !transform.any_op | ||||||||||||||||||||||||||||||||||||||
transform.yield | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
kuhar marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
// ----- | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)> | ||||||||||||||||||||||||||||||||||||||
// CHECK-LABEL: func @arith_binop_fastmath( | ||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[C0:.+]] = arith.constant 0 : index | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[GENERIC:.+]] = linalg.generic | ||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]} | ||||||||||||||||||||||||||||||||||||||
// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) { | ||||||||||||||||||||||||||||||||||||||
// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32): | ||||||||||||||||||||||||||||||||||||||
// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] fastmath<fast> : f32 | ||||||||||||||||||||||||||||||||||||||
// CHECK: linalg.yield %[[z]] : f32 | ||||||||||||||||||||||||||||||||||||||
// CHECK: return %[[GENERIC]] : tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
Comment on lines
+319
to
+331
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You have already checked all the fine details when testing
Suggested change
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
func.func @arith_binop_fastmath(%x : tensor<?xf32>, %y : tensor<?xf32>) | ||||||||||||||||||||||||||||||||||||||
-> tensor<?xf32> { | ||||||||||||||||||||||||||||||||||||||
%z = arith.addf %x, %y fastmath<fast> : tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
return %z : tensor<?xf32> | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
module attributes {transform.with_named_sequence} { | ||||||||||||||||||||||||||||||||||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||||||||||||||||||||||||||||||||||||||
%0 = transform.structured.match ops{["arith.addf"]} in %arg1 | ||||||||||||||||||||||||||||||||||||||
: (!transform.any_op) -> !transform.any_op | ||||||||||||||||||||||||||||||||||||||
transform.structured.rewrite_in_destination_passing_style %0 | ||||||||||||||||||||||||||||||||||||||
: (!transform.any_op) -> !transform.any_op | ||||||||||||||||||||||||||||||||||||||
transform.yield | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \ | ||
// RUN: -linalg-specialize-generic-ops -cse | FileCheck %s | ||
|
||
// CHECK-LABEL: func.func @lower_qcast_to_dps( | ||
// CHECK-SAME: %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>> | ||
// CHECK-DAG: %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8> | ||
// CHECK-DAG: %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32> | ||
// CHECK: %[[E:.+]] = tensor.empty() : tensor<10xf32> | ||
// CHECK: %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>) | ||
// CHECK-SAME: outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32> | ||
// | ||
// CHECK: %[[SITOFP:.+]] = linalg.generic | ||
// CHECK-SAME: ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>) | ||
// CHECK: %{{.*}} = arith.sitofp %{{.*}} : i8 to f32 | ||
// | ||
// CHECK: %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>) | ||
// CHECK: %{{.*}} = linalg.generic | ||
// CHECK-SAME: ins(%[[ADD]] : tensor<10xf32>) | ||
// CHECK: %{{.*}} = arith.fptosi %{{.*}} : f32 to i8 | ||
|
||
|
||
!qalias = !quant.uniform<i8:f32, 2.0:10> | ||
func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> { | ||
%0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias> | ||
return %0 : tensor<10x!qalias> | ||
} |
Uh oh!
There was an error while loading. Please reload this page.