Skip to content

[mlir][spirv][tosa] Extend TOSA to SPIR-V TOSA op conversion#200009

Merged
davidegrohmann merged 3 commits into
llvm:mainfrom
davidegrohmann:mlir-spirv-extend-tosa-to-spirv-tosa-pass
May 29, 2026
Merged

[mlir][spirv][tosa] Extend TOSA to SPIR-V TOSA op conversion#200009
davidegrohmann merged 3 commits into
llvm:mainfrom
davidegrohmann:mlir-spirv-extend-tosa-to-spirv-tosa-pass

Conversation

@davidegrohmann
Copy link
Copy Markdown
Contributor

@davidegrohmann davidegrohmann commented May 27, 2026

Add conversion patterns for additional TOSA 1.0 operations targeting
the SPIR-V TOSA extended instruction set.

Introduce a common TosaOpConvert pattern with small replacer classes
to share result type conversion while keeping op-specific replacement
logic explicit.

The newly covered operations include:

  • elementwise ops such as clamp, arithmetic_right_shift, mul, table,
    negate, and select
  • reductions and argmax
  • gather, scatter, resize
  • reshape, reverse, slice, tile, transpose
  • cast and const_shape

Also add conversion tests.

Introduce a common TosaOpConvert pattern with small replacer classes
to share result type conversion while keeping replacement logic
explicit.

This prepares the pass for adding more TOSA 1.0 operations without
duplicating the common type-conversion boilerplate.

Change-Id: I7796339135e583b425e8d66c07f379acdcc530b8
Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
Add conversion patterns for additional TOSA 1.0 operations targeting
the SPIR-V TOSA extended instruction set.

The newly covered operations include:
* elementwise ops such as clamp, arithmetic_right_shift, mul, table,
  negate, and select
* reductions and argmax
* gather, scatter, resize
* reshape, reverse, slice, tile, transpose
* cast and const_shape

Also add conversion tests.

Change-Id: Ife2286c50cece52821037c950baf5885e7fa6931
Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
@llvmorg-github-actions
Copy link
Copy Markdown

llvmorg-github-actions Bot commented May 27, 2026

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Davide Grohmann (davidegrohmann)

Changes

Add conversion patterns for additional TOSA 1.0 operations targeting
the SPIR-V TOSA extended instruction set.

Introduce a common TosaOpConvert pattern with small replacer classes
to share result type conversion while keeping op-specific replacement
logic explicit.

The newly covered operations include:

  • elementwise ops such as clamp, arithmetic_right_shift, mul, table,
    negate, and select
  • reductions and argmax
  • gather, scatter, resize
  • reshape, reverse, slice, tile, transpose
  • cast and const_shape

Also add conversion tests.


Patch is 36.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/200009.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp (+299-66)
  • (modified) mlir/test/Conversion/TosaToSPIRVTosa/tosa-to-spirv.mlir (+299)
diff --git a/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp
index bb4f4b76c9b29..eef01b298d7b6 100644
--- a/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp
+++ b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp
@@ -21,18 +21,26 @@ namespace mlir::tosa {
 namespace {
 
 template <typename OpAdaptor>
-Value getInput1(OpAdaptor adaptor) {
-  return adaptor.getInput1();
+spirv::TosaExtNaNPropagationModeType getNanMode(OpAdaptor adaptor) {
+  return static_cast<spirv::TosaExtNaNPropagationModeType>(
+      adaptor.getNanMode());
 }
 
-Value getInput1(tosa::ErfOpAdaptor adaptor) { return adaptor.getInput(); }
-
-Value getInput1(tosa::SigmoidOpAdaptor adaptor) { return adaptor.getInput(); }
+template <typename OpAdaptor>
+spirv::TosaExtResizeModeType getResizeMode(OpAdaptor adaptor) {
+  return static_cast<spirv::TosaExtResizeModeType>(adaptor.getMode());
+}
 
-Value getInput1(tosa::TanhOpAdaptor adaptor) { return adaptor.getInput(); }
+DenseIntElementsAttr getI32TensorArmAttr(ArrayRef<int32_t> values,
+                                         ConversionPatternRewriter &rewriter) {
+  return DenseIntElementsAttr::get(
+      spirv::TensorArmType::get(static_cast<int64_t>(values.size()),
+                                IntegerType::get(rewriter.getContext(), 32)),
+      values);
+}
 
-template <typename SourceOp, typename TargetOp>
-struct UnaryElementwiseOpConvert final : public OpConversionPattern<SourceOp> {
+template <typename SourceOp, typename Replacer>
+struct TosaOpConvert final : public OpConversionPattern<SourceOp> {
   using OpConversionPattern<SourceOp>::OpConversionPattern;
 
   LogicalResult
@@ -41,42 +49,223 @@ struct UnaryElementwiseOpConvert final : public OpConversionPattern<SourceOp> {
     Type type = this->getTypeConverter()->convertType(op.getType());
     if (!type)
       return rewriter.notifyMatchFailure(op, "type conversion failed");
-    rewriter.replaceOpWithNewOp<TargetOp>(op, type, getInput1(adaptor));
+    return Replacer::replace(op, adaptor, type, rewriter);
+  }
+};
+
+template <typename TargetOp>
+struct UnaryInput1Replace {
+  template <typename SourceOp>
+  static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput1());
     return success();
   }
 };
 
-template <typename SourceOp, typename TargetOp>
-struct BinaryElementwiseOpConvert final : public OpConversionPattern<SourceOp> {
-  using OpConversionPattern<SourceOp>::OpConversionPattern;
+template <typename TargetOp>
+struct UnaryInputReplace {
+  template <typename SourceOp>
+  static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput());
+    return success();
+  }
+};
 
-  LogicalResult
-  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Type type = this->getTypeConverter()->convertType(op.getType());
-    if (!type)
-      return rewriter.notifyMatchFailure(op, "type conversion failed");
+template <typename TargetOp>
+struct BinaryElementwiseReplace {
+  template <typename SourceOp>
+  static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
     rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput1(),
                                           adaptor.getInput2());
     return success();
   }
 };
 
-template <typename SourceOp, typename TargetOp>
-struct BinaryNanModeElementwiseOpConvert final
-    : public OpConversionPattern<SourceOp> {
-  using OpConversionPattern<SourceOp>::OpConversionPattern;
+template <typename TargetOp>
+struct BinaryNanModeElementwiseReplace {
+  template <typename SourceOp>
+  static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<TargetOp>(op, type, getNanMode(adaptor),
+                                          adaptor.getInput1(),
+                                          adaptor.getInput2());
+    return success();
+  }
+};
 
-  LogicalResult
-  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto nanMode =
-        static_cast<spirv::TosaExtNaNPropagationModeType>(adaptor.getNanMode());
-    Type type = this->getTypeConverter()->convertType(op.getType());
-    if (!type)
-      return rewriter.notifyMatchFailure(op, "type conversion failed");
+template <typename TargetOp>
+struct ReductionReplace {
+  template <typename SourceOp>
+  static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getAxis(),
+                                          adaptor.getInput());
+    return success();
+  }
+};
+
+template <typename TargetOp>
+struct NanModeReductionReplace {
+  template <typename SourceOp>
+  static LogicalResult replace(SourceOp op, typename SourceOp::Adaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
     rewriter.replaceOpWithNewOp<TargetOp>(
-        op, type, nanMode, adaptor.getInput1(), adaptor.getInput2());
+        op, type, adaptor.getAxis(), getNanMode(adaptor), adaptor.getInput());
+    return success();
+  }
+};
+
+struct ClampReplace {
+  static LogicalResult replace(tosa::ClampOp op, tosa::ClampOpAdaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaClampOp>(
+        op, type, adaptor.getMinVal(), adaptor.getMaxVal(), getNanMode(adaptor),
+        adaptor.getInput());
+    return success();
+  }
+};
+
+struct ArithmeticRightShiftReplace {
+  static LogicalResult replace(tosa::ArithmeticRightShiftOp op,
+                               tosa::ArithmeticRightShiftOpAdaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaArithmeticRightShiftOp>(
+        op, type, adaptor.getRound(), adaptor.getInput1(), adaptor.getInput2());
+    return success();
+  }
+};
+
+struct MulReplace {
+  static LogicalResult replace(tosa::MulOp op, tosa::MulOpAdaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaMulOp>(
+        op, type, adaptor.getInput1(), adaptor.getInput2(), adaptor.getShift());
+    return success();
+  }
+};
+
+struct TableReplace {
+  static LogicalResult replace(tosa::TableOp op, tosa::TableOpAdaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaTableOp>(
+        op, type, adaptor.getInput1(), adaptor.getTable());
+    return success();
+  }
+};
+
+struct NegateReplace {
+  static LogicalResult replace(tosa::NegateOp op, tosa::NegateOpAdaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaNegateOp>(
+        op, type, adaptor.getInput1(), adaptor.getInput1Zp(),
+        adaptor.getOutputZp());
+    return success();
+  }
+};
+
+struct SelectReplace {
+  static LogicalResult replace(tosa::SelectOp op, tosa::SelectOpAdaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaSelectOp>(
+        op, type, adaptor.getInput1(), adaptor.getInput2(),
+        adaptor.getInput3());
+    return success();
+  }
+};
+
+struct ReshapeReplace {
+  static LogicalResult replace(tosa::ReshapeOp op,
+                               tosa::ReshapeOpAdaptor adaptor, Type type,
+                               ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaReshapeOp>(
+        op, type, adaptor.getInput1(), adaptor.getShape());
+    return success();
+  }
+};
+
+struct ReverseReplace {
+  static LogicalResult replace(tosa::ReverseOp op,
+                               tosa::ReverseOpAdaptor adaptor, Type type,
+                               ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaReverseOp>(
+        op, type, adaptor.getAxis(), adaptor.getInput1());
+    return success();
+  }
+};
+
+struct SliceReplace {
+  static LogicalResult replace(tosa::SliceOp op, tosa::SliceOpAdaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaSliceOp>(
+        op, type, adaptor.getInput1(), adaptor.getStart(), adaptor.getSize());
+    return success();
+  }
+};
+
+struct TileReplace {
+  static LogicalResult replace(tosa::TileOp op, tosa::TileOpAdaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaTileOp>(
+        op, type, adaptor.getInput1(), adaptor.getMultiples());
+    return success();
+  }
+};
+
+struct TransposeReplace {
+  static LogicalResult replace(tosa::TransposeOp op,
+                               tosa::TransposeOpAdaptor adaptor, Type type,
+                               ConversionPatternRewriter &rewriter) {
+    DenseIntElementsAttr perms =
+        getI32TensorArmAttr(adaptor.getPerms(), rewriter);
+    rewriter.replaceOpWithNewOp<spirv::TosaTransposeOp>(op, type, perms,
+                                                        adaptor.getInput1());
+    return success();
+  }
+};
+
+struct GatherReplace {
+  static LogicalResult replace(tosa::GatherOp op, tosa::GatherOpAdaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaGatherOp>(
+        op, type, adaptor.getValues(), adaptor.getIndices());
+    return success();
+  }
+};
+
+struct ScatterReplace {
+  static LogicalResult replace(tosa::ScatterOp op,
+                               tosa::ScatterOpAdaptor adaptor, Type type,
+                               ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaScatterOp>(
+        op, type, adaptor.getValuesIn(), adaptor.getIndices(),
+        adaptor.getInput());
+    return success();
+  }
+};
+
+struct ResizeReplace {
+  static LogicalResult replace(tosa::ResizeOp op, tosa::ResizeOpAdaptor adaptor,
+                               Type type, ConversionPatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<spirv::TosaResizeOp>(
+        op, type, getResizeMode(adaptor), adaptor.getInput(),
+        adaptor.getScale(), adaptor.getOffset(), adaptor.getBorder());
+    return success();
+  }
+};
+
+struct ConstShapeReplace {
+  static LogicalResult replace(tosa::ConstShapeOp op,
+                               tosa::ConstShapeOpAdaptor adaptor, Type type,
+                               ConversionPatternRewriter &rewriter) {
+    SmallVector<int32_t> values;
+    for (const APInt &value : adaptor.getValues().getValues<APInt>())
+      values.push_back(value.getSExtValue());
+
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, type, getI32TensorArmAttr(values, rewriter));
     return success();
   }
 };
@@ -86,41 +275,85 @@ struct BinaryNanModeElementwiseOpConvert final
 void populateTosaToSPIRVTosaOpsConversionPatterns(
     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
   patterns.add<
-      UnaryElementwiseOpConvert<tosa::ErfOp, spirv::TosaErfOp>,
-      UnaryElementwiseOpConvert<tosa::SigmoidOp, spirv::TosaSigmoidOp>,
-      UnaryElementwiseOpConvert<tosa::TanhOp, spirv::TosaTanhOp>,
-      BinaryElementwiseOpConvert<tosa::AddOp, spirv::TosaAddOp>,
-      BinaryElementwiseOpConvert<tosa::BitwiseAndOp, spirv::TosaBitwiseAndOp>,
-      BinaryElementwiseOpConvert<tosa::BitwiseOrOp, spirv::TosaBitwiseOrOp>,
-      BinaryElementwiseOpConvert<tosa::BitwiseXorOp, spirv::TosaBitwiseXorOp>,
-      BinaryElementwiseOpConvert<tosa::IntDivOp, spirv::TosaIntDivOp>,
-      BinaryElementwiseOpConvert<tosa::LogicalAndOp, spirv::TosaLogicalAndOp>,
-      BinaryElementwiseOpConvert<tosa::LogicalLeftShiftOp,
-                                 spirv::TosaLogicalLeftShiftOp>,
-      BinaryElementwiseOpConvert<tosa::LogicalRightShiftOp,
-                                 spirv::TosaLogicalRightShiftOp>,
-      BinaryElementwiseOpConvert<tosa::LogicalOrOp, spirv::TosaLogicalOrOp>,
-      BinaryElementwiseOpConvert<tosa::LogicalXorOp, spirv::TosaLogicalXorOp>,
-      BinaryNanModeElementwiseOpConvert<tosa::MaximumOp, spirv::TosaMaximumOp>,
-      BinaryNanModeElementwiseOpConvert<tosa::MinimumOp, spirv::TosaMinimumOp>,
-      BinaryElementwiseOpConvert<tosa::PowOp, spirv::TosaPowOp>,
-      BinaryElementwiseOpConvert<tosa::SubOp, spirv::TosaSubOp>,
-      UnaryElementwiseOpConvert<tosa::AbsOp, spirv::TosaAbsOp>,
-      UnaryElementwiseOpConvert<tosa::BitwiseNotOp, spirv::TosaBitwiseNotOp>,
-      UnaryElementwiseOpConvert<tosa::CeilOp, spirv::TosaCeilOp>,
-      UnaryElementwiseOpConvert<tosa::ClzOp, spirv::TosaClzOp>,
-      UnaryElementwiseOpConvert<tosa::CosOp, spirv::TosaCosOp>,
-      UnaryElementwiseOpConvert<tosa::ExpOp, spirv::TosaExpOp>,
-      UnaryElementwiseOpConvert<tosa::FloorOp, spirv::TosaFloorOp>,
-      UnaryElementwiseOpConvert<tosa::LogOp, spirv::TosaLogOp>,
-      UnaryElementwiseOpConvert<tosa::LogicalNotOp, spirv::TosaLogicalNotOp>,
-      UnaryElementwiseOpConvert<tosa::ReciprocalOp, spirv::TosaReciprocalOp>,
-      UnaryElementwiseOpConvert<tosa::RsqrtOp, spirv::TosaRsqrtOp>,
-      UnaryElementwiseOpConvert<tosa::SinOp, spirv::TosaSinOp>,
-      BinaryElementwiseOpConvert<tosa::EqualOp, spirv::TosaEqualOp>,
-      BinaryElementwiseOpConvert<tosa::GreaterOp, spirv::TosaGreaterOp>,
-      BinaryElementwiseOpConvert<tosa::GreaterEqualOp,
-                                 spirv::TosaGreaterEqualOp>>(
+      TosaOpConvert<tosa::ArgMaxOp,
+                    NanModeReductionReplace<spirv::TosaArgMaxOp>>,
+      TosaOpConvert<tosa::ClampOp, ClampReplace>,
+      TosaOpConvert<tosa::ErfOp, UnaryInputReplace<spirv::TosaErfOp>>,
+      TosaOpConvert<tosa::SigmoidOp, UnaryInputReplace<spirv::TosaSigmoidOp>>,
+      TosaOpConvert<tosa::TanhOp, UnaryInputReplace<spirv::TosaTanhOp>>,
+      TosaOpConvert<tosa::AddOp, BinaryElementwiseReplace<spirv::TosaAddOp>>,
+      TosaOpConvert<tosa::ArithmeticRightShiftOp, ArithmeticRightShiftReplace>,
+      TosaOpConvert<tosa::BitwiseAndOp,
+                    BinaryElementwiseReplace<spirv::TosaBitwiseAndOp>>,
+      TosaOpConvert<tosa::BitwiseOrOp,
+                    BinaryElementwiseReplace<spirv::TosaBitwiseOrOp>>,
+      TosaOpConvert<tosa::BitwiseXorOp,
+                    BinaryElementwiseReplace<spirv::TosaBitwiseXorOp>>,
+      TosaOpConvert<tosa::IntDivOp,
+                    BinaryElementwiseReplace<spirv::TosaIntDivOp>>,
+      TosaOpConvert<tosa::LogicalAndOp,
+                    BinaryElementwiseReplace<spirv::TosaLogicalAndOp>>,
+      TosaOpConvert<tosa::LogicalLeftShiftOp,
+                    BinaryElementwiseReplace<spirv::TosaLogicalLeftShiftOp>>,
+      TosaOpConvert<tosa::LogicalRightShiftOp,
+                    BinaryElementwiseReplace<spirv::TosaLogicalRightShiftOp>>,
+      TosaOpConvert<tosa::LogicalOrOp,
+                    BinaryElementwiseReplace<spirv::TosaLogicalOrOp>>,
+      TosaOpConvert<tosa::LogicalXorOp,
+                    BinaryElementwiseReplace<spirv::TosaLogicalXorOp>>,
+      TosaOpConvert<tosa::MaximumOp,
+                    BinaryNanModeElementwiseReplace<spirv::TosaMaximumOp>>,
+      TosaOpConvert<tosa::MinimumOp,
+                    BinaryNanModeElementwiseReplace<spirv::TosaMinimumOp>>,
+      TosaOpConvert<tosa::MulOp, MulReplace>,
+      TosaOpConvert<tosa::PowOp, BinaryElementwiseReplace<spirv::TosaPowOp>>,
+      TosaOpConvert<tosa::SubOp, BinaryElementwiseReplace<spirv::TosaSubOp>>,
+      TosaOpConvert<tosa::TableOp, TableReplace>,
+      TosaOpConvert<tosa::AbsOp, UnaryInput1Replace<spirv::TosaAbsOp>>,
+      TosaOpConvert<tosa::BitwiseNotOp,
+                    UnaryInput1Replace<spirv::TosaBitwiseNotOp>>,
+      TosaOpConvert<tosa::CeilOp, UnaryInput1Replace<spirv::TosaCeilOp>>,
+      TosaOpConvert<tosa::ClzOp, UnaryInput1Replace<spirv::TosaClzOp>>,
+      TosaOpConvert<tosa::CosOp, UnaryInput1Replace<spirv::TosaCosOp>>,
+      TosaOpConvert<tosa::ExpOp, UnaryInput1Replace<spirv::TosaExpOp>>,
+      TosaOpConvert<tosa::FloorOp, UnaryInput1Replace<spirv::TosaFloorOp>>,
+      TosaOpConvert<tosa::LogOp, UnaryInput1Replace<spirv::TosaLogOp>>,
+      TosaOpConvert<tosa::LogicalNotOp,
+                    UnaryInput1Replace<spirv::TosaLogicalNotOp>>,
+      TosaOpConvert<tosa::NegateOp, NegateReplace>,
+      TosaOpConvert<tosa::ReciprocalOp,
+                    UnaryInput1Replace<spirv::TosaReciprocalOp>>,
+      TosaOpConvert<tosa::RsqrtOp, UnaryInput1Replace<spirv::TosaRsqrtOp>>,
+      TosaOpConvert<tosa::SinOp, UnaryInput1Replace<spirv::TosaSinOp>>,
+      TosaOpConvert<tosa::SelectOp, SelectReplace>,
+      TosaOpConvert<tosa::EqualOp,
+                    BinaryElementwiseReplace<spirv::TosaEqualOp>>,
+      TosaOpConvert<tosa::GreaterOp,
+                    BinaryElementwiseReplace<spirv::TosaGreaterOp>>,
+      TosaOpConvert<tosa::GreaterEqualOp,
+                    BinaryElementwiseReplace<spirv::TosaGreaterEqualOp>>,
+      TosaOpConvert<tosa::ReduceAllOp,
+                    ReductionReplace<spirv::TosaReduceAllOp>>,
+      TosaOpConvert<tosa::ReduceAnyOp,
+                    ReductionReplace<spirv::TosaReduceAnyOp>>,
+      TosaOpConvert<tosa::ReduceMaxOp,
+                    NanModeReductionReplace<spirv::TosaReduceMaxOp>>,
+      TosaOpConvert<tosa::ReduceMinOp,
+                    NanModeReductionReplace<spirv::TosaReduceMinOp>>,
+      TosaOpConvert<tosa::ReduceProductOp,
+                    ReductionReplace<spirv::TosaReduceProductOp>>,
+      TosaOpConvert<tosa::ReduceSumOp,
+                    ReductionReplace<spirv::TosaReduceSumOp>>,
+      TosaOpConvert<tosa::ReshapeOp, ReshapeReplace>,
+      TosaOpConvert<tosa::ReverseOp, ReverseReplace>,
+      TosaOpConvert<tosa::SliceOp, SliceReplace>,
+      TosaOpConvert<tosa::TileOp, TileReplace>,
+      TosaOpConvert<tosa::TransposeOp, TransposeReplace>,
+      TosaOpConvert<tosa::GatherOp, GatherReplace>,
+      TosaOpConvert<tosa::ScatterOp, ScatterReplace>,
+      TosaOpConvert<tosa::ResizeOp, ResizeReplace>,
+      TosaOpConvert<tosa::CastOp, UnaryInputReplace<spirv::TosaCastOp>>,
+      TosaOpConvert<tosa::ConstShapeOp, ConstShapeReplace>>(
       typeConverter, patterns.getContext());
 }
 
diff --git a/mlir/test/Conversion/TosaToSPIRVTosa/tosa-to-spirv.mlir b/mlir/test/Conversion/TosaToSPIRVTosa/tosa-to-spirv.mlir
index 4e54e5cdd3634..a175baf62eda1 100644
--- a/mlir/test/Conversion/TosaToSPIRVTosa/tosa-to-spirv.mlir
+++ b/mlir/test/Conversion/TosaToSPIRVTosa/tosa-to-spirv.mlir
@@ -1,5 +1,31 @@
 // RUN: mlir-opt --split-input-file --tosa-to-spirv-tosa %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ArgMax
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: spirv.ARM.Graph @argmax_int
+func.func @argmax_int(%arg0: tensor<2x3x4xi8>) -> tensor<2x4xi32> {
+  // CHECK: %[[ARGMAX:.*]] = spirv.Tosa.ArgMax axis = 1, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x3x4xi8> -> !spirv.arm.tensor<2x4xi32>
+  %res = tosa.argmax %arg0 {axis = 1 : i32, nan_mode = PROPAGATE} : (tensor<2x3x4xi8>) -> tensor<2x4xi32>
+  return %res : tensor<2x4xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: spirv.ARM.Graph @clamp_int
+func.func @clamp_int(%arg0: tensor<4x8xi8>) -> tensor<4x8xi8> {
+  // CHECK: %[[CLAMP:.*]] = spirv.Tosa.Clam...
[truncated]

Comment thread mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp Outdated
Replace stateless replacer classes with plain replacement functions
passed to TosaOpConvert as non-type template parameters.

Change-Id: Ieb3e4bba0385996089f2f490ca3a03f33cb97eb6
Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
Copy link
Copy Markdown
Contributor

@IgWod IgWod left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment thread mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaOps.cpp
@davidegrohmann davidegrohmann merged commit bbf8a33 into llvm:main May 29, 2026
10 checks passed
@davidegrohmann davidegrohmann deleted the mlir-spirv-extend-tosa-to-spirv-tosa-pass branch May 29, 2026 06:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants