[mlir][spirv][tosa] Extend TOSA to SPIR-V TOSA op conversion#200009
Merged
davidegrohmann merged 3 commits intoMay 29, 2026
Merged
Conversation
Introduce a common TosaOpConvert pattern with small replacer classes to share result type conversion while keeping replacement logic explicit. This prepares the pass for adding more TOSA 1.0 operations without duplicating the common type-conversion boilerplate. Change-Id: I7796339135e583b425e8d66c07f379acdcc530b8 Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
Add conversion patterns for additional TOSA 1.0 operations targeting the SPIR-V TOSA extended instruction set. The newly covered operations include: * elementwise ops such as clamp, arithmetic_right_shift, mul, table, negate, and select * reductions and argmax * gather, scatter, resize * reshape, reverse, slice, tile, transpose * cast and const_shape Also add conversion tests. Change-Id: Ife2286c50cece52821037c950baf5885e7fa6931 Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
|
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir-tosa Author: Davide Grohmann (davidegrohmann) ChangesAdd conversion patterns for additional TOSA 1.0 operations targeting Introduce a common TosaOpConvert pattern with small replacer classes The newly covered operations include:
Also add conversion tests. Patch is 36.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/200009.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp
index bb4f4b76c9b29..eef01b298d7b6 100644
--- a/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp
+++ b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp
@@ -21,18 +21,26 @@ namespace mlir::tosa {
namespace {
template <typename OpAdaptor>
-Value getInput1(OpAdaptor adaptor) {
- return adaptor.getInput1();
+spirv::TosaExtNaNPropagationModeType getNanMode(OpAdaptor adaptor) {
+ return static_cast<spirv::TosaExtNaNPropagationModeType>(
+ adaptor.getNanMode());
}
-Value getInput1(tosa::ErfOpAdaptor adaptor) { return adaptor.getInput(); }
-
-Value getInput1(tosa::SigmoidOpAdaptor adaptor) { return adaptor.getInput(); }
+template <typename OpAdaptor>
+spirv::TosaExtResizeModeType getResizeMode(OpAdaptor adaptor) {
+ return static_cast<spirv::TosaExtResizeModeType>(adaptor.getMode());
+}
-Value getInput1(tosa::TanhOpAdaptor adaptor) { return adaptor.getInput(); }
+DenseIntElementsAttr getI32TensorArmAttr(ArrayRef<int32_t> values,
+ ConversionPatternRewriter &rewriter) {
+ return DenseIntElementsAttr::get(
+ spirv::TensorArmType::get(static_cast<int64_t>(values.size()),
+ IntegerType::get(rewriter.getContext(), 32)),
+ values);
+}
-template <typename SourceOp, typename TargetOp>
-struct UnaryElementwiseOpConvert final : public OpConversionPattern<SourceOp> {
+template <typename SourceOp, typename Replacer>
+struct TosaOpConvert final : public OpConversionPattern<SourceOp> {
using OpConversionPattern<SourceOp>::OpConversionPattern;
LogicalResult
@@ -41,42 +49,223 @@ struct UnaryElementwiseOpConvert final : public OpConversionPattern<SourceOp> {
Type type = this->getTypeConverter()->convertType(op.getType());
if (!type)
return rewriter.notifyMatchFailure(op, "type conversion failed");
- rewriter.replaceOpWithNewOp<TargetOp>(op, type, getInput1(adaptor));
+ return Replacer::replace(op, adaptor, type, rewriter);
+ }
+};
+
+template <typename TargetOp>
+struct UnaryInput1Replace {
+ template <typename SourceOp>
+ static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput1());
return success();
}
};
-template <typename SourceOp, typename TargetOp>
-struct BinaryElementwiseOpConvert final : public OpConversionPattern<SourceOp> {
- using OpConversionPattern<SourceOp>::OpConversionPattern;
+template <typename TargetOp>
+struct UnaryInputReplace {
+ template <typename SourceOp>
+ static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput());
+ return success();
+ }
+};
- LogicalResult
- matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type type = this->getTypeConverter()->convertType(op.getType());
- if (!type)
- return rewriter.notifyMatchFailure(op, "type conversion failed");
+template <typename TargetOp>
+struct BinaryElementwiseReplace {
+ template <typename SourceOp>
+ static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput1(),
adaptor.getInput2());
return success();
}
};
-template <typename SourceOp, typename TargetOp>
-struct BinaryNanModeElementwiseOpConvert final
- : public OpConversionPattern<SourceOp> {
- using OpConversionPattern<SourceOp>::OpConversionPattern;
+template <typename TargetOp>
+struct BinaryNanModeElementwiseReplace {
+ template <typename SourceOp>
+ static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<TargetOp>(op, type, getNanMode(adaptor),
+ adaptor.getInput1(),
+ adaptor.getInput2());
+ return success();
+ }
+};
- LogicalResult
- matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto nanMode =
- static_cast<spirv::TosaExtNaNPropagationModeType>(adaptor.getNanMode());
- Type type = this->getTypeConverter()->convertType(op.getType());
- if (!type)
- return rewriter.notifyMatchFailure(op, "type conversion failed");
+template <typename TargetOp>
+struct ReductionReplace {
+ template <typename SourceOp>
+ static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getAxis(),
+ adaptor.getInput());
+ return success();
+ }
+};
+
+template <typename TargetOp>
+struct NanModeReductionReplace {
+ template <typename SourceOp>
+ static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<TargetOp>(
- op, type, nanMode, adaptor.getInput1(), adaptor.getInput2());
+ op, type, adaptor.getAxis(), getNanMode(adaptor), adaptor.getInput());
+ return success();
+ }
+};
+
+struct ClampReplace {
+ static LogicalResult replace(tosa::ClampOp op, tosa::ClampOpAdaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaClampOp>(
+ op, type, adaptor.getMinVal(), adaptor.getMaxVal(), getNanMode(adaptor),
+ adaptor.getInput());
+ return success();
+ }
+};
+
+struct ArithmeticRightShiftReplace {
+ static LogicalResult replace(tosa::ArithmeticRightShiftOp op,
+ tosa::ArithmeticRightShiftOpAdaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaArithmeticRightShiftOp>(
+ op, type, adaptor.getRound(), adaptor.getInput1(), adaptor.getInput2());
+ return success();
+ }
+};
+
+struct MulReplace {
+ static LogicalResult replace(tosa::MulOp op, tosa::MulOpAdaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaMulOp>(
+ op, type, adaptor.getInput1(), adaptor.getInput2(), adaptor.getShift());
+ return success();
+ }
+};
+
+struct TableReplace {
+ static LogicalResult replace(tosa::TableOp op, tosa::TableOpAdaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaTableOp>(
+ op, type, adaptor.getInput1(), adaptor.getTable());
+ return success();
+ }
+};
+
+struct NegateReplace {
+ static LogicalResult replace(tosa::NegateOp op, tosa::NegateOpAdaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaNegateOp>(
+ op, type, adaptor.getInput1(), adaptor.getInput1Zp(),
+ adaptor.getOutputZp());
+ return success();
+ }
+};
+
+struct SelectReplace {
+ static LogicalResult replace(tosa::SelectOp op, tosa::SelectOpAdaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaSelectOp>(
+ op, type, adaptor.getInput1(), adaptor.getInput2(),
+ adaptor.getInput3());
+ return success();
+ }
+};
+
+struct ReshapeReplace {
+ static LogicalResult replace(tosa::ReshapeOp op,
+ tosa::ReshapeOpAdaptor adaptor, Type type,
+ ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaReshapeOp>(
+ op, type, adaptor.getInput1(), adaptor.getShape());
+ return success();
+ }
+};
+
+struct ReverseReplace {
+ static LogicalResult replace(tosa::ReverseOp op,
+ tosa::ReverseOpAdaptor adaptor, Type type,
+ ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaReverseOp>(
+ op, type, adaptor.getAxis(), adaptor.getInput1());
+ return success();
+ }
+};
+
+struct SliceReplace {
+ static LogicalResult replace(tosa::SliceOp op, tosa::SliceOpAdaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaSliceOp>(
+ op, type, adaptor.getInput1(), adaptor.getStart(), adaptor.getSize());
+ return success();
+ }
+};
+
+struct TileReplace {
+ static LogicalResult replace(tosa::TileOp op, tosa::TileOpAdaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaTileOp>(
+ op, type, adaptor.getInput1(), adaptor.getMultiples());
+ return success();
+ }
+};
+
+struct TransposeReplace {
+ static LogicalResult replace(tosa::TransposeOp op,
+ tosa::TransposeOpAdaptor adaptor, Type type,
+ ConversionPatternRewriter &rewriter) {
+ DenseIntElementsAttr perms =
+ getI32TensorArmAttr(adaptor.getPerms(), rewriter);
+ rewriter.replaceOpWithNewOp<spirv::TosaTransposeOp>(op, type, perms,
+ adaptor.getInput1());
+ return success();
+ }
+};
+
+struct GatherReplace {
+ static LogicalResult replace(tosa::GatherOp op, tosa::GatherOpAdaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaGatherOp>(
+ op, type, adaptor.getValues(), adaptor.getIndices());
+ return success();
+ }
+};
+
+struct ScatterReplace {
+ static LogicalResult replace(tosa::ScatterOp op,
+ tosa::ScatterOpAdaptor adaptor, Type type,
+ ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaScatterOp>(
+ op, type, adaptor.getValuesIn(), adaptor.getIndices(),
+ adaptor.getInput());
+ return success();
+ }
+};
+
+struct ResizeReplace {
+ static LogicalResult replace(tosa::ResizeOp op, tosa::ResizeOpAdaptor adaptor,
+ Type type, ConversionPatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<spirv::TosaResizeOp>(
+ op, type, getResizeMode(adaptor), adaptor.getInput(),
+ adaptor.getScale(), adaptor.getOffset(), adaptor.getBorder());
+ return success();
+ }
+};
+
+struct ConstShapeReplace {
+ static LogicalResult replace(tosa::ConstShapeOp op,
+ tosa::ConstShapeOpAdaptor adaptor, Type type,
+ ConversionPatternRewriter &rewriter) {
+ SmallVector<int32_t> values;
+ for (const APInt &value : adaptor.getValues().getValues<APInt>())
+ values.push_back(value.getSExtValue());
+
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+ op, type, getI32TensorArmAttr(values, rewriter));
return success();
}
};
@@ -86,41 +275,85 @@ struct BinaryNanModeElementwiseOpConvert final
void populateTosaToSPIRVTosaOpsConversionPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
- UnaryElementwiseOpConvert<tosa::ErfOp, spirv::TosaErfOp>,
- UnaryElementwiseOpConvert<tosa::SigmoidOp, spirv::TosaSigmoidOp>,
- UnaryElementwiseOpConvert<tosa::TanhOp, spirv::TosaTanhOp>,
- BinaryElementwiseOpConvert<tosa::AddOp, spirv::TosaAddOp>,
- BinaryElementwiseOpConvert<tosa::BitwiseAndOp, spirv::TosaBitwiseAndOp>,
- BinaryElementwiseOpConvert<tosa::BitwiseOrOp, spirv::TosaBitwiseOrOp>,
- BinaryElementwiseOpConvert<tosa::BitwiseXorOp, spirv::TosaBitwiseXorOp>,
- BinaryElementwiseOpConvert<tosa::IntDivOp, spirv::TosaIntDivOp>,
- BinaryElementwiseOpConvert<tosa::LogicalAndOp, spirv::TosaLogicalAndOp>,
- BinaryElementwiseOpConvert<tosa::LogicalLeftShiftOp,
- spirv::TosaLogicalLeftShiftOp>,
- BinaryElementwiseOpConvert<tosa::LogicalRightShiftOp,
- spirv::TosaLogicalRightShiftOp>,
- BinaryElementwiseOpConvert<tosa::LogicalOrOp, spirv::TosaLogicalOrOp>,
- BinaryElementwiseOpConvert<tosa::LogicalXorOp, spirv::TosaLogicalXorOp>,
- BinaryNanModeElementwiseOpConvert<tosa::MaximumOp, spirv::TosaMaximumOp>,
- BinaryNanModeElementwiseOpConvert<tosa::MinimumOp, spirv::TosaMinimumOp>,
- BinaryElementwiseOpConvert<tosa::PowOp, spirv::TosaPowOp>,
- BinaryElementwiseOpConvert<tosa::SubOp, spirv::TosaSubOp>,
- UnaryElementwiseOpConvert<tosa::AbsOp, spirv::TosaAbsOp>,
- UnaryElementwiseOpConvert<tosa::BitwiseNotOp, spirv::TosaBitwiseNotOp>,
- UnaryElementwiseOpConvert<tosa::CeilOp, spirv::TosaCeilOp>,
- UnaryElementwiseOpConvert<tosa::ClzOp, spirv::TosaClzOp>,
- UnaryElementwiseOpConvert<tosa::CosOp, spirv::TosaCosOp>,
- UnaryElementwiseOpConvert<tosa::ExpOp, spirv::TosaExpOp>,
- UnaryElementwiseOpConvert<tosa::FloorOp, spirv::TosaFloorOp>,
- UnaryElementwiseOpConvert<tosa::LogOp, spirv::TosaLogOp>,
- UnaryElementwiseOpConvert<tosa::LogicalNotOp, spirv::TosaLogicalNotOp>,
- UnaryElementwiseOpConvert<tosa::ReciprocalOp, spirv::TosaReciprocalOp>,
- UnaryElementwiseOpConvert<tosa::RsqrtOp, spirv::TosaRsqrtOp>,
- UnaryElementwiseOpConvert<tosa::SinOp, spirv::TosaSinOp>,
- BinaryElementwiseOpConvert<tosa::EqualOp, spirv::TosaEqualOp>,
- BinaryElementwiseOpConvert<tosa::GreaterOp, spirv::TosaGreaterOp>,
- BinaryElementwiseOpConvert<tosa::GreaterEqualOp,
- spirv::TosaGreaterEqualOp>>(
+ TosaOpConvert<tosa::ArgMaxOp,
+ NanModeReductionReplace<spirv::TosaArgMaxOp>>,
+ TosaOpConvert<tosa::ClampOp, ClampReplace>,
+ TosaOpConvert<tosa::ErfOp, UnaryInputReplace<spirv::TosaErfOp>>,
+ TosaOpConvert<tosa::SigmoidOp, UnaryInputReplace<spirv::TosaSigmoidOp>>,
+ TosaOpConvert<tosa::TanhOp, UnaryInputReplace<spirv::TosaTanhOp>>,
+ TosaOpConvert<tosa::AddOp, BinaryElementwiseReplace<spirv::TosaAddOp>>,
+ TosaOpConvert<tosa::ArithmeticRightShiftOp, ArithmeticRightShiftReplace>,
+ TosaOpConvert<tosa::BitwiseAndOp,
+ BinaryElementwiseReplace<spirv::TosaBitwiseAndOp>>,
+ TosaOpConvert<tosa::BitwiseOrOp,
+ BinaryElementwiseReplace<spirv::TosaBitwiseOrOp>>,
+ TosaOpConvert<tosa::BitwiseXorOp,
+ BinaryElementwiseReplace<spirv::TosaBitwiseXorOp>>,
+ TosaOpConvert<tosa::IntDivOp,
+ BinaryElementwiseReplace<spirv::TosaIntDivOp>>,
+ TosaOpConvert<tosa::LogicalAndOp,
+ BinaryElementwiseReplace<spirv::TosaLogicalAndOp>>,
+ TosaOpConvert<tosa::LogicalLeftShiftOp,
+ BinaryElementwiseReplace<spirv::TosaLogicalLeftShiftOp>>,
+ TosaOpConvert<tosa::LogicalRightShiftOp,
+ BinaryElementwiseReplace<spirv::TosaLogicalRightShiftOp>>,
+ TosaOpConvert<tosa::LogicalOrOp,
+ BinaryElementwiseReplace<spirv::TosaLogicalOrOp>>,
+ TosaOpConvert<tosa::LogicalXorOp,
+ BinaryElementwiseReplace<spirv::TosaLogicalXorOp>>,
+ TosaOpConvert<tosa::MaximumOp,
+ BinaryNanModeElementwiseReplace<spirv::TosaMaximumOp>>,
+ TosaOpConvert<tosa::MinimumOp,
+ BinaryNanModeElementwiseReplace<spirv::TosaMinimumOp>>,
+ TosaOpConvert<tosa::MulOp, MulReplace>,
+ TosaOpConvert<tosa::PowOp, BinaryElementwiseReplace<spirv::TosaPowOp>>,
+ TosaOpConvert<tosa::SubOp, BinaryElementwiseReplace<spirv::TosaSubOp>>,
+ TosaOpConvert<tosa::TableOp, TableReplace>,
+ TosaOpConvert<tosa::AbsOp, UnaryInput1Replace<spirv::TosaAbsOp>>,
+ TosaOpConvert<tosa::BitwiseNotOp,
+ UnaryInput1Replace<spirv::TosaBitwiseNotOp>>,
+ TosaOpConvert<tosa::CeilOp, UnaryInput1Replace<spirv::TosaCeilOp>>,
+ TosaOpConvert<tosa::ClzOp, UnaryInput1Replace<spirv::TosaClzOp>>,
+ TosaOpConvert<tosa::CosOp, UnaryInput1Replace<spirv::TosaCosOp>>,
+ TosaOpConvert<tosa::ExpOp, UnaryInput1Replace<spirv::TosaExpOp>>,
+ TosaOpConvert<tosa::FloorOp, UnaryInput1Replace<spirv::TosaFloorOp>>,
+ TosaOpConvert<tosa::LogOp, UnaryInput1Replace<spirv::TosaLogOp>>,
+ TosaOpConvert<tosa::LogicalNotOp,
+ UnaryInput1Replace<spirv::TosaLogicalNotOp>>,
+ TosaOpConvert<tosa::NegateOp, NegateReplace>,
+ TosaOpConvert<tosa::ReciprocalOp,
+ UnaryInput1Replace<spirv::TosaReciprocalOp>>,
+ TosaOpConvert<tosa::RsqrtOp, UnaryInput1Replace<spirv::TosaRsqrtOp>>,
+ TosaOpConvert<tosa::SinOp, UnaryInput1Replace<spirv::TosaSinOp>>,
+ TosaOpConvert<tosa::SelectOp, SelectReplace>,
+ TosaOpConvert<tosa::EqualOp,
+ BinaryElementwiseReplace<spirv::TosaEqualOp>>,
+ TosaOpConvert<tosa::GreaterOp,
+ BinaryElementwiseReplace<spirv::TosaGreaterOp>>,
+ TosaOpConvert<tosa::GreaterEqualOp,
+ BinaryElementwiseReplace<spirv::TosaGreaterEqualOp>>,
+ TosaOpConvert<tosa::ReduceAllOp,
+ ReductionReplace<spirv::TosaReduceAllOp>>,
+ TosaOpConvert<tosa::ReduceAnyOp,
+ ReductionReplace<spirv::TosaReduceAnyOp>>,
+ TosaOpConvert<tosa::ReduceMaxOp,
+ NanModeReductionReplace<spirv::TosaReduceMaxOp>>,
+ TosaOpConvert<tosa::ReduceMinOp,
+ NanModeReductionReplace<spirv::TosaReduceMinOp>>,
+ TosaOpConvert<tosa::ReduceProductOp,
+ ReductionReplace<spirv::TosaReduceProductOp>>,
+ TosaOpConvert<tosa::ReduceSumOp,
+ ReductionReplace<spirv::TosaReduceSumOp>>,
+ TosaOpConvert<tosa::ReshapeOp, ReshapeReplace>,
+ TosaOpConvert<tosa::ReverseOp, ReverseReplace>,
+ TosaOpConvert<tosa::SliceOp, SliceReplace>,
+ TosaOpConvert<tosa::TileOp, TileReplace>,
+ TosaOpConvert<tosa::TransposeOp, TransposeReplace>,
+ TosaOpConvert<tosa::GatherOp, GatherReplace>,
+ TosaOpConvert<tosa::ScatterOp, ScatterReplace>,
+ TosaOpConvert<tosa::ResizeOp, ResizeReplace>,
+ TosaOpConvert<tosa::CastOp, UnaryInputReplace<spirv::TosaCastOp>>,
+ TosaOpConvert<tosa::ConstShapeOp, ConstShapeReplace>>(
typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/TosaToSPIRVTosa/tosa-to-spirv.mlir b/mlir/test/Conversion/TosaToSPIRVTosa/tosa-to-spirv.mlir
index 4e54e5cdd3634..a175baf62eda1 100644
--- a/mlir/test/Conversion/TosaToSPIRVTosa/tosa-to-spirv.mlir
+++ b/mlir/test/Conversion/TosaToSPIRVTosa/tosa-to-spirv.mlir
@@ -1,5 +1,31 @@
// RUN: mlir-opt --split-input-file --tosa-to-spirv-tosa %s | FileCheck %s
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: spirv.ARM.Graph @argmax_int
+func.func @argmax_int(%arg0: tensor<2x3x4xi8>) -> tensor<2x4xi32> {
+ // CHECK: %[[ARGMAX:.*]] = spirv.Tosa.ArgMax axis = 1, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x3x4xi8> -> !spirv.arm.tensor<2x4xi32>
+ %res = tosa.argmax %arg0 {axis = 1 : i32, nan_mode = PROPAGATE} : (tensor<2x3x4xi8>) -> tensor<2x4xi32>
+ return %res : tensor<2x4xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: spirv.ARM.Graph @clamp_int
+func.func @clamp_int(%arg0: tensor<4x8xi8>) -> tensor<4x8xi8> {
+ // CHECK: %[[CLAMP:.*]] = spirv.Tosa.Clam...
[truncated]
|
kuhar
reviewed
May 28, 2026
Replace stateless replacer classes with plain replacement functions passed to TosaOpConvert as non-type template parameters. Change-Id: Ieb3e4bba0385996089f2f490ca3a03f33cb97eb6 Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
kuhar
approved these changes
May 29, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add conversion patterns for additional TOSA 1.0 operations targeting
the SPIR-V TOSA extended instruction set.
Introduce a common TosaOpConvert pattern with small replacer classes
to share result type conversion while keeping op-specific replacement
logic explicit.
The newly covered operations include:
negate, and select
Also add conversion tests.