Skip to content

[mlir][tosa] Add TOSA Max Pool 2D Adaptive#191225

Merged
lhutton1 merged 1 commit into
llvm:mainfrom
iliyan-georgiev-arm:max_pool2d_adaptive_redux
Apr 10, 2026
Merged

[mlir][tosa] Add TOSA Max Pool 2D Adaptive#191225
lhutton1 merged 1 commit into
llvm:mainfrom
iliyan-georgiev-arm:max_pool2d_adaptive_redux

Conversation

@iliyan-georgiev-arm
Copy link
Copy Markdown
Contributor

@iliyan-georgiev-arm iliyan-georgiev-arm commented Apr 9, 2026

Implements:

  • Operator definition
  • Operator verifier
  • Validation
  • Tests
  • Adds NoMemoryEffect to AvgPool2dAdaptive

Signed-off-by: Iliyan Georgiev Iliyan.Georgiev@arm.com

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 9, 2026

@llvm/pr-subscribers-mlir

Author: Iliyan Georgiev (iliyan-georgiev-arm)

Changes

Implements:

  • Operator definition
  • Operator verifier
  • Validation
  • Tests
  • Adds NoMemoryEffect to AvgPool2dAdaptive

Signed-off-by: Iliyan Georgiev <Iliyan.Georgiev@arm.com>
Change-Id: I7550cc588ffc0da684605d67db71d989fb51da62


Patch is 26.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/191225.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+13)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+37-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+53-3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+9)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+4-3)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+116)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+47)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+11)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+10)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+24)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a575024a6144a..d3e2cd129028e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -65,6 +65,11 @@ profileComplianceMap = {
       {{Profile::pro_fp},
        {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
         {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+    {"tosa.max_pool2d_adaptive",
+     {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Profile::pro_fp},
+       {{{fp16T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.transpose_conv2d",
      {{{Profile::pro_int},
        {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
@@ -657,6 +662,14 @@ extensionComplianceMap = {
       {{Extension::fp8e5m2},
        {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
       {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+    {"tosa.max_pool2d_adaptive",
+     {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e4m3},
+       {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e5m2},
+       {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::bf16},
+       {{{bf16T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.rfft2d",
      {{{Extension::fft},
        {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5ac91e6b65457..45d1388a28749 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -124,7 +124,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d", [NoMemoryEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: avg_pool2d_adaptive
 //===----------------------------------------------------------------------===//
-def Tosa_AvgPool2dAdaptiveOp : Tosa_InferShapedTypeOp<"avg_pool2d_adaptive"> {
+def Tosa_AvgPool2dAdaptiveOp
+    : Tosa_InferShapedTypeOp<"avg_pool2d_adaptive", [NoMemoryEffect]> {
   let summary = "Performs average pooling on the input with shape operands.";
 
   let description = [{
@@ -524,6 +525,41 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d", [Pure]> {
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: max_pool2d_adaptive
+//===----------------------------------------------------------------------===//
+def Tosa_MaxPool2dAdaptiveOp
+    : Tosa_InferShapedTypeOp<"max_pool2d_adaptive", [Pure]> {
+  let summary = "Performs max pooling on the input.";
+
+  let description = [{
+    This performs a max pooling over the given input tensor. A sliding window of
+    size given by <kernel size> is passed over the input tensor, with the
+    maximum value being placed in the output tensor.
+    Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride, 
+    pad arguments as inputs rather than attributes.
+  }];
+
+  let arguments =
+      (ins Tosa_Tensor4D:$input, Rank2TosaShape:$kernel, Rank2TosaShape:$stride,
+          Rank4TosaShape:$pad,
+
+          DefaultValuedAttr<
+              Tosa_NanPropagationModeAttr,
+              "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode);
+
+  let results = (outs Tosa_Tensor4D:$output);
+
+  list<Availability> availability =
+      [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+       Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2,
+                  Tosa_EXT_BF16]>,
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: rfft2d
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 29318023092a1..3bf878304429e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -486,6 +486,15 @@ void MaxPool2dOp::print(OpAsmPrinter &parser) {
   printWithNanPropagationHandling(parser, *this);
 }
 
+ParseResult MaxPool2dAdaptiveOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MaxPool2dAdaptiveOp::print(OpAsmPrinter &parser) {
+  printWithNanPropagationHandling(parser, *this);
+}
+
 ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
   return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
 }
@@ -1228,9 +1237,8 @@ struct AdaptivePoolingConstShapeValues {
 
 template <typename T>
 static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp =
-    std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>
-    // || std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>
-    ;
+    std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
+    std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
 
 template <typename T,
           typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
@@ -4085,6 +4093,33 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
                                  inferredReturnShapes);
 }
 
+LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    MaxPool2dAdaptiveOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+
+  llvm::SmallVector<int64_t> kernelValues;
+  llvm::SmallVector<int64_t> strideValues;
+  llvm::SmallVector<int64_t> padValues;
+  if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
+                                kernelValues) &&
+      tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+                                strideValues) &&
+      tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
+    return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
+                                   padValues, inferredReturnShapes);
+  }
+
+  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
+  if (inputShape.hasRank()) {
+    outputShape[0] = inputShape.getDimSize(0);
+    outputShape[3] = inputShape.getDimSize(3);
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
 LogicalResult MaxPool2dOp::verify() {
   if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
                                     /* outType = */ getOutput().getType())))
@@ -4096,6 +4131,21 @@ LogicalResult MaxPool2dOp::verify() {
   return success();
 }
 
+LogicalResult MaxPool2dAdaptiveOp::verify() {
+  if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
+                                    /* outType = */ getOutput().getType())))
+    return failure();
+
+  AdaptivePoolingConstShapeValues values;
+  extractAdaptivePoolingConstShapeOperands(*this, values);
+
+  if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
+                                 values.pad, getInput(), getOutput())))
+    return failure();
+
+  return success();
+}
+
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     DepthwiseConv2DOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 78bf700597c3c..01c85be4f704f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -88,6 +88,14 @@ ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dAdaptiveOp op) {
   return success();
 }
 
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::MaxPool2dAdaptiveOp op) {
+  addValue(op.getInput());
+  addValue(op.getOutput());
+  return success();
+}
+
 template <typename T>
 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
   addValue(op.getInput());
@@ -288,6 +296,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Variable)
   POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
   POPULATE_PROFILE_INFO_CUSTOM(Dim)
+  POPULATE_PROFILE_INFO_CUSTOM(MaxPool2dAdaptive)
 
   // For the most of tosa operators, all operands are profile/extension related
   // and hence are all considered in this profile-based compilance check.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6169003881487..8c00603d7abb4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -359,9 +359,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   template <typename T>
   static constexpr bool IsSupportedAdaptivePoolOp =
-      std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>
-      // || std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>
-      ;
+      std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
+      std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
 
   template <typename T, typename std::enable_if<IsSupportedAdaptivePoolOp<T>,
                                                 int>::type = 0>
@@ -817,6 +816,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_SIZES(MatMul);
   CHECK_SIZES(MatmulTBlockScaled);
   CHECK_SIZES(MaxPool2d);
+  CHECK_SIZES(MaxPool2dAdaptive);
   CHECK_SIZES(RFFT2d);
   // Scatter/Gather Operators
   CHECK_SIZES(Gather);
@@ -918,6 +918,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
       failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
       failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
       failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
+      failed(levelCheckAdaptivePool<tosa::MaxPool2dAdaptiveOp>(op)) ||
       failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
       failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
       failed(levelCheckConv2DBlockScaled(op))) {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index b3bdb02c20103..ca4d2dca0e7c9 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -929,6 +929,122 @@ func.func @test_maxpool2d_stride_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x3
 
 // -----
 
+func.func @test_maxpool2d_adaptive_kernel_y(%arg0: tensor<1x8194x32x8xf32>) -> tensor<1x2x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: kernel <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x8194x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x2x32x8xf32>
+  return %0 : tensor<1x2x32x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_kernel_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 8193]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: kernel <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_stride_y(%arg0: tensor<1x8194x32x8xf32>) -> tensor<1x2x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: stride <= MAX_STRIDE (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x8194x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x2x32x8xf32>
+  return %0 : tensor<1x2x32x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_stride_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 8193]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: stride <= MAX_STRIDE (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_first(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+  // This is a workaround for the above so that we can level check the padding.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+  
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[8193, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_second(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+  // This is a workaround for the above so that we can level check the padding.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 8193, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_third(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+  // This is a workaround for the above so that we can level check the padding.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 8193, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_forth(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+  // This is a workaround for the above so that we can level check the padding.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 8193]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
 func.func @test_rfft2d_input_h(%arg0: tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>) {
   // expected-error@+1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL (8192), got 16384}}
   %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>)
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e80d3d84a8105..b30e92c4a9621 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -253,6 +253,53 @@ func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8
   return %0 : tensor<1x32x32x8xf16>
 }
 
+// CHECK-LABEL: max_pool2d_adaptive_f32
+func.func @test_max_pool2d_adaptive_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_adaptive_bf16
+func.func @test_max_pool2d_adaptive_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xbf16>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xbf16>
+  return %0 : tensor<1x32x32x8xbf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_adaptive_f16
+func.func @test_max_pool2d_adaptive_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_...
[truncated]

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 9, 2026

@llvm/pr-subscribers-mlir-tosa

Author: Iliyan Georgiev (iliyan-georgiev-arm)

Changes

Implements:

  • Operator definition
  • Operator verifier
  • Validation
  • Tests
  • Adds NoMemoryEffect to AvgPool2dAdaptive

Signed-off-by: Iliyan Georgiev <Iliyan.Georgiev@arm.com>
Change-Id: I7550cc588ffc0da684605d67db71d989fb51da62


Patch is 26.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/191225.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+13)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+37-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+53-3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+9)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+4-3)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+116)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+47)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+11)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+10)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+24)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a575024a6144a..d3e2cd129028e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -65,6 +65,11 @@ profileComplianceMap = {
       {{Profile::pro_fp},
        {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
         {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+    {"tosa.max_pool2d_adaptive",
+     {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Profile::pro_fp},
+       {{{fp16T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.transpose_conv2d",
      {{{Profile::pro_int},
        {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
@@ -657,6 +662,14 @@ extensionComplianceMap = {
       {{Extension::fp8e5m2},
        {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
       {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+    {"tosa.max_pool2d_adaptive",
+     {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e4m3},
+       {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e5m2},
+       {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::bf16},
+       {{{bf16T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.rfft2d",
      {{{Extension::fft},
        {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5ac91e6b65457..45d1388a28749 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -124,7 +124,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d", [NoMemoryEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: avg_pool2d_adaptive
 //===----------------------------------------------------------------------===//
-def Tosa_AvgPool2dAdaptiveOp : Tosa_InferShapedTypeOp<"avg_pool2d_adaptive"> {
+def Tosa_AvgPool2dAdaptiveOp
+    : Tosa_InferShapedTypeOp<"avg_pool2d_adaptive", [NoMemoryEffect]> {
   let summary = "Performs average pooling on the input with shape operands.";
 
   let description = [{
@@ -524,6 +525,41 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d", [Pure]> {
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: max_pool2d_adaptive
+//===----------------------------------------------------------------------===//
+def Tosa_MaxPool2dAdaptiveOp
+    : Tosa_InferShapedTypeOp<"max_pool2d_adaptive", [Pure]> {
+  let summary = "Performs max pooling on the input.";
+
+  let description = [{
+    This performs a max pooling over the given input tensor. A sliding window of
+    size given by <kernel size> is passed over the input tensor, with the
+    maximum value being placed in the output tensor.
+    Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride, 
+    pad arguments as inputs rather than attributes.
+  }];
+
+  let arguments =
+      (ins Tosa_Tensor4D:$input, Rank2TosaShape:$kernel, Rank2TosaShape:$stride,
+          Rank4TosaShape:$pad,
+
+          DefaultValuedAttr<
+              Tosa_NanPropagationModeAttr,
+              "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode);
+
+  let results = (outs Tosa_Tensor4D:$output);
+
+  list<Availability> availability =
+      [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+       Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2,
+                  Tosa_EXT_BF16]>,
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: rfft2d
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 29318023092a1..3bf878304429e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -486,6 +486,15 @@ void MaxPool2dOp::print(OpAsmPrinter &parser) {
   printWithNanPropagationHandling(parser, *this);
 }
 
+ParseResult MaxPool2dAdaptiveOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MaxPool2dAdaptiveOp::print(OpAsmPrinter &parser) {
+  printWithNanPropagationHandling(parser, *this);
+}
+
 ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
   return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
 }
@@ -1228,9 +1237,8 @@ struct AdaptivePoolingConstShapeValues {
 
 template <typename T>
 static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp =
-    std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>
-    // || std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>
-    ;
+    std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
+    std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
 
 template <typename T,
           typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
@@ -4085,6 +4093,33 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
                                  inferredReturnShapes);
 }
 
+LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    MaxPool2dAdaptiveOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+
+  llvm::SmallVector<int64_t> kernelValues;
+  llvm::SmallVector<int64_t> strideValues;
+  llvm::SmallVector<int64_t> padValues;
+  if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
+                                kernelValues) &&
+      tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+                                strideValues) &&
+      tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
+    return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
+                                   padValues, inferredReturnShapes);
+  }
+
+  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
+  if (inputShape.hasRank()) {
+    outputShape[0] = inputShape.getDimSize(0);
+    outputShape[3] = inputShape.getDimSize(3);
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
 LogicalResult MaxPool2dOp::verify() {
   if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
                                     /* outType = */ getOutput().getType())))
@@ -4096,6 +4131,21 @@ LogicalResult MaxPool2dOp::verify() {
   return success();
 }
 
+LogicalResult MaxPool2dAdaptiveOp::verify() {
+  if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
+                                    /* outType = */ getOutput().getType())))
+    return failure();
+
+  AdaptivePoolingConstShapeValues values;
+  extractAdaptivePoolingConstShapeOperands(*this, values);
+
+  if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
+                                 values.pad, getInput(), getOutput())))
+    return failure();
+
+  return success();
+}
+
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     DepthwiseConv2DOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 78bf700597c3c..01c85be4f704f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -88,6 +88,14 @@ ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dAdaptiveOp op) {
   return success();
 }
 
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::MaxPool2dAdaptiveOp op) {
+  addValue(op.getInput());
+  addValue(op.getOutput());
+  return success();
+}
+
 template <typename T>
 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
   addValue(op.getInput());
@@ -288,6 +296,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Variable)
   POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
   POPULATE_PROFILE_INFO_CUSTOM(Dim)
+  POPULATE_PROFILE_INFO_CUSTOM(MaxPool2dAdaptive)
 
   // For the most of tosa operators, all operands are profile/extension related
   // and hence are all considered in this profile-based compilance check.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6169003881487..8c00603d7abb4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -359,9 +359,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   template <typename T>
   static constexpr bool IsSupportedAdaptivePoolOp =
-      std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>
-      // || std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>
-      ;
+      std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
+      std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
 
   template <typename T, typename std::enable_if<IsSupportedAdaptivePoolOp<T>,
                                                 int>::type = 0>
@@ -817,6 +816,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_SIZES(MatMul);
   CHECK_SIZES(MatmulTBlockScaled);
   CHECK_SIZES(MaxPool2d);
+  CHECK_SIZES(MaxPool2dAdaptive);
   CHECK_SIZES(RFFT2d);
   // Scatter/Gather Operators
   CHECK_SIZES(Gather);
@@ -918,6 +918,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
       failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
       failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
       failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
+      failed(levelCheckAdaptivePool<tosa::MaxPool2dAdaptiveOp>(op)) ||
       failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
       failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
       failed(levelCheckConv2DBlockScaled(op))) {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index b3bdb02c20103..ca4d2dca0e7c9 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -929,6 +929,122 @@ func.func @test_maxpool2d_stride_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x3
 
 // -----
 
+func.func @test_maxpool2d_adaptive_kernel_y(%arg0: tensor<1x8194x32x8xf32>) -> tensor<1x2x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: kernel <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x8194x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x2x32x8xf32>
+  return %0 : tensor<1x2x32x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_kernel_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 8193]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: kernel <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_stride_y(%arg0: tensor<1x8194x32x8xf32>) -> tensor<1x2x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: stride <= MAX_STRIDE (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x8194x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x2x32x8xf32>
+  return %0 : tensor<1x2x32x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_stride_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 8193]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: stride <= MAX_STRIDE (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_first(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+  // This is a workaround for the above so that we can level check the padding.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+  
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[8193, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_second(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+  // This is a workaround for the above so that we can level check the padding.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 8193, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_third(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+  // This is a workaround for the above so that we can level check the padding.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 8193, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_forth(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+  // This is a workaround for the above so that we can level check the padding.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 8193]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
 func.func @test_rfft2d_input_h(%arg0: tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>) {
   // expected-error@+1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL (8192), got 16384}}
   %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>)
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e80d3d84a8105..b30e92c4a9621 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -253,6 +253,53 @@ func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8
   return %0 : tensor<1x32x32x8xf16>
 }
 
+// CHECK-LABEL: max_pool2d_adaptive_f32
+func.func @test_max_pool2d_adaptive_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_adaptive_bf16
+func.func @test_max_pool2d_adaptive_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xbf16>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xbf16>
+  return %0 : tensor<1x32x32x8xbf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_adaptive_f16
+func.func @test_max_pool2d_adaptive_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_...
[truncated]

@iliyan-georgiev-arm
Copy link
Copy Markdown
Contributor Author

@lhutton1 for review, thanks!

Copy link
Copy Markdown
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @iliyan-georgiev-arm!

Implements:
- Operator definition
- Operator verifier
- Validation
- Tests
- Adds NoMemoryEffect to AvgPool2dAdaptive

Signed-off-by: Iliyan Georgiev <Iliyan.Georgiev@arm.com>
Change-Id: I7550cc588ffc0da684605d67db71d989fb51da62
@iliyan-georgiev-arm iliyan-georgiev-arm force-pushed the max_pool2d_adaptive_redux branch from c5e2c03 to b042df7 Compare April 10, 2026 10:01
@lhutton1
Copy link
Copy Markdown
Contributor

Note that this change implements the specification change: arm/tosa-specification#23

@lhutton1 lhutton1 merged commit 3263854 into llvm:main Apr 10, 2026
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants