
[mlir][tosa] Add canonicalization for adaptive to their non-adaptive variants#195865

Merged
lhutton1 merged 1 commit into llvm:main from lhutton1:adaptive-canonicalization
May 6, 2026
Conversation

Contributor

@lhutton1 lhutton1 commented May 5, 2026

This commit adds canonicalization patterns that convert the adaptive pooling ops (max and avg) to their non-adaptive variants when their CTC inputs (kernel, stride, and pad) are constants.

This is beneficial for backends that do not support the adaptive op variants.
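
The rule described above can be modelled outside MLIR as a small sketch. The types `ShapeOperand` and `StaticPool2d` and the function `foldAdaptivePool` are hypothetical stand-ins invented for illustration; they are not part of the TOSA dialect, and the actual patterns in this PR operate on `tosa::AvgPool2dAdaptiveOp` and `tosa::MaxPool2dAdaptiveOp` through `OpRewritePattern`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a shape operand: it either carries known constant
// values (as a tosa.const_shape would) or is only known at runtime.
struct ShapeOperand {
  std::optional<std::vector<int64_t>> constantValues;
};

// Hypothetical stand-in for the non-adaptive op, where kernel, stride, and pad
// are attributes rather than operands.
struct StaticPool2d {
  std::vector<int64_t> kernel;
  std::vector<int64_t> stride;
  std::vector<int64_t> pad;
};

// Returns the non-adaptive form only when all three operands are constant.
// Otherwise returns std::nullopt, meaning the adaptive op is left unchanged.
inline std::optional<StaticPool2d> foldAdaptivePool(const ShapeOperand &kernel,
                                                    const ShapeOperand &stride,
                                                    const ShapeOperand &pad) {
  if (!kernel.constantValues || !stride.constantValues || !pad.constantValues)
    return std::nullopt;

  // Ranks mirror the operand types used in the tests of this PR:
  // !tosa.shape<2> for kernel and stride, !tosa.shape<4> for pad.
  assert(kernel.constantValues->size() == 2 && "kernel must have 2 values");
  assert(stride.constantValues->size() == 2 && "stride must have 2 values");
  assert(pad.constantValues->size() == 4 && "pad must have 4 values");

  return StaticPool2d{*kernel.constantValues, *stride.constantValues,
                      *pad.constantValues};
}
```

Calling `foldAdaptivePool` with three constant operands yields a `StaticPool2d`, while passing a default-constructed `ShapeOperand` for any of them yields `std::nullopt`, which corresponds to the match-failure path in the real patterns.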

Change-Id: I9037438325a3b0071f14ebed0aa444acf66656df

llvmorg-github-actions Bot commented May 5, 2026

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)


Full diff: https://github.com/llvm/llvm-project/pull/195865.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+3-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+60)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+58)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 207618adc1352..71122ba8531c6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -161,6 +161,7 @@ def Tosa_AvgPool2dAdaptiveOp
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
@@ -536,7 +537,7 @@ def Tosa_MaxPool2dAdaptiveOp
     This performs a max pooling over the given input tensor. A sliding window of
     size given by <kernel size> is passed over the input tensor, with the
     maximum value being placed in the output tensor.
-    Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride, 
+    Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
     pad arguments as inputs rather than attributes.
   }];
 
@@ -557,6 +558,7 @@ def Tosa_MaxPool2dAdaptiveOp
   ];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 1c186cd3ae122..642ee4b98e216 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -245,6 +245,36 @@ void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
       context);
 }
 
+struct AvgPool2dAdaptiveToAvgPool2d
+    : public OpRewritePattern<tosa::AvgPool2dAdaptiveOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<int64_t> kernel;
+    llvm::SmallVector<int64_t> stride;
+    llvm::SmallVector<int64_t> pad;
+    if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
+        !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
+        !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant kernel, stride, and pad operands");
+
+    auto replacement = tosa::AvgPool2dOp::create(
+        rewriter, op.getLoc(), op.getType(), op.getInput(), op.getInputZp(),
+        op.getOutputZp(), rewriter.getDenseI64ArrayAttr(kernel),
+        rewriter.getDenseI64ArrayAttr(stride),
+        rewriter.getDenseI64ArrayAttr(pad), op.getAccTypeAttr());
+    rewriter.replaceOp(op, replacement.getOutput());
+    return success();
+  }
+};
+
+void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<AvgPool2dAdaptiveToAvgPool2d>(context);
+}
+
 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -283,6 +313,36 @@ void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
       context);
 }
 
+struct MaxPool2dAdaptiveToMaxPool2d
+    : public OpRewritePattern<tosa::MaxPool2dAdaptiveOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<int64_t> kernel;
+    llvm::SmallVector<int64_t> stride;
+    llvm::SmallVector<int64_t> pad;
+    if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
+        !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
+        !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant kernel, stride, and pad operands");
+
+    auto replacement = tosa::MaxPool2dOp::create(
+        rewriter, op.getLoc(), op.getType(), op.getInput(),
+        rewriter.getDenseI64ArrayAttr(kernel),
+        rewriter.getDenseI64ArrayAttr(stride),
+        rewriter.getDenseI64ArrayAttr(pad), op.getNanModeAttr());
+    rewriter.replaceOp(op, replacement.getOutput());
+    return success();
+  }
+};
+
+void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<MaxPool2dAdaptiveToMaxPool2d>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Data Layout / Memory Reinterpretation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 9a7fa3efc8d3c..19583e111ebef 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1649,3 +1649,61 @@ func.func @test_do_not_canonicalize_cast_from_cast_to_block_scaled_unranked(%arg
   %1, %2 = tosa.cast_to_block_scaled %0 {block_size = BLOCK_SIZE_32} : (tensor<*xf32>) -> (tensor<*xf6E2M3FN>, tensor<*xf8E8M0FNU>)
   return %1, %2 : tensor<*xf6E2M3FN>, tensor<*xf8E8M0FNU>
 }
+
+// -----
+
+// CHECK-LABEL: @canonicalize_max_pool2d_adaptive
+// CHECK: %[[POOL:.+]] = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, nan_mode = IGNORE, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+// CHECK: return %[[POOL]]
+func.func @canonicalize_max_pool2d_adaptive(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad {nan_mode = IGNORE} :
+         (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_avg_pool2d_adaptive
+// CHECK: %[[POOL:.+]] = tosa.avg_pool2d %arg0, %{{.*}}, %{{.*}} {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
+// CHECK: return %[[POOL]]
+func.func @canonicalize_avg_pool2d_adaptive(%arg0: tensor<1x7x7x9xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
+  %kernel = tosa.const_shape {values = dense<[3, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[1, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.avg_pool2d_adaptive %arg0, %input_zp, %output_zp, %kernel, %stride, %pad {acc_type = f32} :
+         (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_non_const_max_pool2d_adaptive
+// CHECK: tosa.max_pool2d_adaptive
+func.func @dont_canonicalize_non_const_max_pool2d_adaptive(%arg0: tensor<1x?x?x8xf32>) -> tensor<1x?x?x8xf32> {
+  %dim1 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %dim2 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %kernel = tosa.concat_shape %dim1, %dim2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad {nan_mode = IGNORE} :
+          (tensor<1x?x?x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
+  return %0 : tensor<1x?x?x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_non_const_avg_pool2d_adaptive
+// CHECK: tosa.avg_pool2d_adaptive
+func.func @dont_canonicalize_non_const_avg_pool2d_adaptive(%arg0: tensor<1x?x?x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x?x?x8xf32> {
+  %dim1 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %dim2 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %kernel = tosa.concat_shape %dim1, %dim2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.avg_pool2d_adaptive %arg0, %input_zp, %output_zp, %kernel, %stride, %pad {acc_type = f32} :
+          (tensor<1x?x?x8xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
+  return %0 : tensor<1x?x?x8xf32>
+}
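
The static shapes in the two positive tests above can be sanity-checked with the usual pooling output-size relation. The sketch below assumes `out = (in + padBefore + padAfter - kernel) / stride + 1` with an exact division; this relation and the helper name `pooledSize` are assumptions made for illustration, not something this PR adds or verifies.

```cpp
#include <cassert>
#include <cstdint>

// Assumed relation for one spatial dimension of a 2D pooling op:
//   out = (in + padBefore + padAfter - kernel) / stride + 1
// pooledSize is a hypothetical helper, not an MLIR or TOSA API.
inline int64_t pooledSize(int64_t in, int64_t kernel, int64_t stride,
                          int64_t padBefore, int64_t padAfter) {
  assert(kernel > 0 && stride > 0 && "kernel and stride must be positive");
  const int64_t span = in + padBefore + padAfter - kernel;
  // This sketch requires the window to fit and the stride to divide evenly.
  assert(span >= 0 && span % stride == 0 && "window must tile the input");
  return span / stride + 1;
}
```

Under that assumption, the max-pool test (input 32, kernel 1, stride 1, pad 0) gives `pooledSize(32, 1, 1, 0, 0) == 32`, and the avg-pool test (input 7, kernel 3, stride 1, pad 1) gives `pooledSize(7, 3, 1, 1, 1) == 7`, matching the unchanged `32x32` and `7x7` result types.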

Member

@sahas3 sahas3 left a comment

LGTM.

@lhutton1 lhutton1 merged commit 0e8a8d0 into llvm:main May 6, 2026
13 checks passed
moar55 pushed a commit to moar55/llvm-project that referenced this pull request May 12, 2026
pedroMVicente pushed a commit to pedroMVicente/llvm-project that referenced this pull request May 19, 2026