26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -171,6 +171,32 @@ def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
let summary = "Convert ops to destination-passing style";
let description = [{
Converts ops that operate on tensors but are not in
destination-passing style (DPS) into an equivalent
`linalg.generic` op, which is in DPS. For example:
```mlir
%0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>
```
gets rewritten as:
```mlir
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?xi32>
%0 = tensor.empty(%dim) : tensor<?xf32>
%1 = linalg.generic
{indexing_maps = [#map, #map], iterator_types = ["parallel"]}
ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
^bb0(%in: i32, %out: f32):
%2 = arith.uitofp %in : i32 to f32
linalg.yield %2 : f32
} -> tensor<?xf32>
```
}];
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1377,6 +1377,23 @@ rewriteInDestinationPassingStyle(RewriterBase &rewriter,
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
tensor::PadOp padOp);

FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::UIToFPOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::SIToFPOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::FPToUIOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::FPToSIOp op);

FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::AddIOp op);

FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::AddFOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::DivFOp op);

/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
/// and linalg.matmul.
///
@@ -2611,7 +2611,9 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
rewriter.setInsertionPoint(target);
FailureOr<Operation *> maybeResult =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
.Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
.Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp,
arith::UIToFPOp, arith::SIToFPOp, arith::FPToUIOp,
arith::FPToSIOp, arith::AddIOp, arith::AddFOp, arith::DivFOp>(
[&rewriter](auto op) {
return rewriteInDestinationPassingStyle(rewriter, op);
});
141 changes: 140 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -17,13 +17,20 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::tensor;

@@ -96,7 +103,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
RankedTensorType resultType = padOp.getResultType();

// Examine the yielded value to decide if a linalg.generic is neede or a
// Examine the yielded value to decide if a linalg.generic is needed or a
// linalg.fill is sufficient.
Value yieldedValue =
cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
@@ -603,6 +610,94 @@ Value linalg::bufferizeToAllocation(
}

namespace {
/// Rewrites an arith op operating on tensors, e.g.
/// `%z = arith.addf %x, %y : tensor<5xf32>`
/// into an equivalent linalg.generic in destination-passing style.
/// ```mlir
/// %0 = tensor.empty() : tensor<5xf32>
/// %1 = linalg.generic ...
/// ins(%x, %y : tensor<5xf32>, tensor<5xf32>)
/// outs(%0 : tensor<5xf32>) {
/// ^bb0(%in: f32, %in_0: f32, %out: f32):
/// %2 = arith.addf %in, %in_0 : f32
/// linalg.yield %2 : f32
/// } -> tensor<5xf32>
/// ```
template <typename OpTy>
FailureOr<Operation *>
rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
// Reject ops such as `arith.constant` and `arith.select`.
// constants don't need dps conversion and select is a a `todo`.
Contributor:

Suggested change
// constants don't need dps conversion and select is a a `todo`.
// constants don't need dps conversion and select is a TODO.

auto numOperands = op->getNumOperands();
if (numOperands == 0 || numOperands > 2)
Contributor:

Why?

Contributor (Author):

only unary and binary we care about.

Contributor (@banach-space, Sep 15, 2025):

> Why?
>
> only unary and binary we care about.

This is what the code tells me, yes. But it doesn't say why. Also, if that's the case, then this would be clearer to me:

if (numOperands != 1 && numOperands != 2)

EDIT: Sorry, posted my comment before noticing that this comment has been updated.

> constants don't need dps conversion

Why not? This is assuming that the only purpose of this code is to help with bufferization. That is fine with me, but make it clear. Otherwise, "constants don't need dps conversion" sounds very arbitrary (the "why" is missing).

return failure();

// destination passing style rewrite is only for ops on tensor types.
Member:

Suggested change
// destination passing style rewrite is only for ops on tensor types.
// Destination passing style rewrite is only for ops on tensor types.

Type resultType = op->getResult(0).getType();
auto tensorType = dyn_cast<RankedTensorType>(resultType);
if (!tensorType)
return failure();

auto loc = op.getLoc();
OpBuilder::InsertionGuard g(rewriter);
auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));

// Create tensor.empty for `outs` of destination-passing-style.
Value outs = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);

// Create linalg.generic
Member:

Suggested change
// Create linalg.generic
// Create linalg.generic.

auto rank = tensorType.getRank();
SmallVector<AffineMap> indexingMaps(numOperands + 1,
rewriter.getMultiDimIdentityMap(rank));
SmallVector<utils::IteratorType> iteratorTypes(rank,
utils::IteratorType::parallel);

// Check 'fast-math'. If present, propagate it.
Member:

What about overflowFlags, e.g., nsw?

Member:

Do we model those when specializing? While a generic will accept any payload, specializing will hide the payload.

If today we reject specialization based on these flags, that's ok. But if we don't, then we'll change the semantics and the round-trip will fail.

Contributor:

> If today we reject specialization based on these flags, that's ok

+1, but let's make sure this is verified with tests.
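For reference, a sketch (not part of this patch) of propagating the integer overflow flags analogously to the fastmath handling below, assuming the op implements `arith::ArithIntegerOverflowFlagsInterface` and that the attribute name is `overflowFlags`:

```cpp
// Hypothetical: mirror the fastmath handling for nsw/nuw on integer ops
// such as arith.addi.
if (auto ovfInterface =
        llvm::dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(
            op.getOperation())) {
  NamedAttribute ovf =
      rewriter.getNamedAttr("overflowFlags", ovfInterface.getOverflowAttr());
  // Pass `ovf` to builder.create<OpTy>(...) when re-creating the scalar op
  // inside the linalg.generic body, exactly as done for fastmath below.
}
```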

auto fmfOpInterface =
llvm::dyn_cast<arith::ArithFastMathInterface>(op.getOperation());

auto genericOp = linalg::GenericOp::create(
rewriter, loc, tensorType,
op->getOperands(), // inputs
ValueRange{outs}, // outputs
Comment on lines +660 to +661
Member:

Use /*argName=*/ param -- some tools can then warn if the names don't match
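For illustration, the suggestion applied to this call might read as follows (parameter names `inputs`/`outputs` assumed from the `linalg::GenericOp` build signature; `bodyBuilder` stands in for the lambda below):

```cpp
auto genericOp = linalg::GenericOp::create(
    rewriter, loc, tensorType,
    /*inputs=*/op->getOperands(),
    /*outputs=*/ValueRange{outs},
    indexingMaps, iteratorTypes, bodyBuilder);
```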

indexingMaps, iteratorTypes,
[&](OpBuilder &builder, Location loc, ValueRange args) {
Value res;
if (args.size() == 2) {
if (fmfOpInterface) {
auto attr = fmfOpInterface.getFastMathFlagsAttr();
auto fmf = rewriter.getNamedAttr("fastmath", attr);
res = builder
.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]},
fmf)
.getResult();
} else {
res = builder
.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
.getResult();
}
} else if (args.size() == 3) {
if (fmfOpInterface) {
auto attr = fmfOpInterface.getFastMathFlagsAttr();
auto fmf = rewriter.getNamedAttr("fastmath", attr);
res = builder
.create<OpTy>(loc, args[2].getType(),
ValueRange{args[0], args[1]}, fmf)
.getResult();
} else {
res = builder
.create<OpTy>(loc, args[2].getType(),
ValueRange{args[0], args[1]})
.getResult();
}
} else
llvm_unreachable("did not expect ops other than nary and binary");
linalg::YieldOp::create(builder, loc, res);
Comment on lines +692 to +694
Member:

nit: I think the convention is that if one if/else branch uses braces, all the other ones should too

Suggested change
} else
llvm_unreachable("did not expect ops other than nary and binary");
linalg::YieldOp::create(builder, loc, res);
} else {
llvm_unreachable("did not expect ops other than nary and binary");
}
linalg::YieldOp::create(builder, loc, res);

});

rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
rewriter.eraseOp(op);
return genericOp.getOperation();
}

template <typename OpTy>
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
@@ -612,9 +707,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,

} // namespace

#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY) \
FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle( \
RewriterBase &rewriter, OPTY op) { \
return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op); \
Member:

Suggested change
return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op); \
return rewriteArithInDestinationPassingStyle(rewriter, op); \

}

STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)

STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)

STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)
Member:

We can undef this macro here
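Concretely, a one-line sketch of that cleanup right after the last expansion:

```cpp
#undef STAMP_OUT_ARITH_DPS_FUNCS
```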


void linalg::populateConvertToDestinationStylePatterns(
RewritePatternSet &patterns) {
patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);

patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);

patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>);

patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>);
patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>);
}

namespace {
struct LinalgConvertToDPSPass
: public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
using impl::LinalgConvertToDPSPassBase<
LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;
Comment on lines +746 to +747
Member:

Suggested change
using impl::LinalgConvertToDPSPassBase<
LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;
using Base::Base;


void runOnOperation() override;
};

void LinalgConvertToDPSPass::runOnOperation() {

RewritePatternSet patterns(&getContext());
linalg::populateConvertToDestinationStylePatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
Member (@kuhar, Sep 15, 2025):

Would the walk rewrite driver work here? I think this will never have to rewrite the same op twice or visit newly created ops. https://mlir.llvm.org/docs/PatternRewriter/#walk-pattern-rewrite-driver
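A sketch of that alternative, assuming the walk driver's one-pass semantics are sufficient here (it does not report failure, so `signalPassFailure` would go away):

```cpp
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

void LinalgConvertToDPSPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  linalg::populateConvertToDestinationStylePatterns(patterns);
  // One post-order walk: each op is matched at most once, and newly
  // created ops are not revisited.
  walkAndApplyPatterns(getOperation(), std::move(patterns));
}
```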

signalPassFailure();
}
} // namespace
Contributor:

I would suggest creating a new test file with Arith Ops. I see two reasons:

  • Separate tests for Tensor and Arith Ops (these trigger different patterns, so separating them makes sense IMHO).
  • For Tensor Ops, the naming scheme for test functions is @tensor_<op-name>_variant. For Arith, you are using @arith_<binary|unary>_op. So that's a bit inconsistent.

If you prefer to keep everything in one file, could you add a big comment separating Tensor and Arith Ops? Here's an example block comment:

///----------------------------------------------------------------------------------------
/// Tests for tensor.pad
///----------------------------------------------------------------------------------------

@@ -252,3 +252,96 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @arith_unary_op(
// CHECK-SAME: %[[X:.+]]: tensor<64xi32>) -> tensor<64xf32> {
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
// CHECK-SAME: ins(%[[X:.+]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
// CHECK: ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32):
// CHECK: %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32
// CHECK: linalg.yield %[[z]] : f32
// CHECK: return %[[GENERIC]] : tensor<64xf32>

func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
%z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
return %z : tensor<64xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.uitofp"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_in_destination_passing_style %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @arith_binop(
// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32
// CHECK: linalg.yield %[[z]] : f32
// CHECK: return %[[GENERIC]] : tensor<?xf32>

func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
Contributor:

Please use consistent naming.

Suggested change
func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
func.func @arith_bin_op(%x : tensor<?xf32>, %y : tensor<?xf32>)

-> tensor<?xf32> {
%z = arith.addf %x, %y : tensor<?xf32>
return %z : tensor<?xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addf"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_in_destination_passing_style %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @arith_binop_fastmath(
// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] fastmath<fast> : f32
// CHECK: linalg.yield %[[z]] : f32
// CHECK: return %[[GENERIC]] : tensor<?xf32>
Comment on lines +319 to +331
Contributor:

You have already checked all the fine details when testing @arith_binop; here it is sufficient to make sure that fastmath is propagated correctly. I suggest simplifying:

Suggested change
// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @arith_binop_fastmath(
// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] fastmath<fast> : f32
// CHECK: linalg.yield %[[z]] : f32
// CHECK: return %[[GENERIC]] : tensor<?xf32>
// CHECK-LABEL: func @arith_binop_fastmath(
// CHECK: linalg.generic
// CHECK-SAME: ins({{.*}} : tensor<?xf32>, tensor<?xf32>) outs({{.*}} tensor<?xf32>) {
// CHECK: ^bb0({.*}}: f32, %{{.*}}: f32, %[[Out:.+]]: f32):
// CHECK: arith.addf {{.*}} fastmath<fast> : f32


func.func @arith_binop_fastmath(%x : tensor<?xf32>, %y : tensor<?xf32>)
-> tensor<?xf32> {
%z = arith.addf %x, %y fastmath<fast> : tensor<?xf32>
return %z : tensor<?xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addf"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_in_destination_passing_style %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
@@ -0,0 +1,26 @@
// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \
// RUN: -linalg-specialize-generic-ops -cse | FileCheck %s

// CHECK-LABEL: func.func @lower_qcast_to_dps(
// CHECK-SAME: %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>>
// CHECK-DAG: %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8>
// CHECK-DAG: %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32>
// CHECK: %[[E:.+]] = tensor.empty() : tensor<10xf32>
// CHECK: %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>)
// CHECK-SAME: outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32>
//
// CHECK: %[[SITOFP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>)
// CHECK: %{{.*}} = arith.sitofp %{{.*}} : i8 to f32
//
// CHECK: %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>)
// CHECK: %{{.*}} = linalg.generic
// CHECK-SAME: ins(%[[ADD]] : tensor<10xf32>)
// CHECK: %{{.*}} = arith.fptosi %{{.*}} : f32 to i8


!qalias = !quant.uniform<i8:f32, 2.0:10>
func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> {
%0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
return %0 : tensor<10x!qalias>
}