[mlir][tosa] Add canonicalization for adaptive to their non-adaptive variants#195865
Merged
Conversation
…variants This commit adds canonicalization patterns to convert adaptive pooling (max and avg) to their non-adaptive variants when their CTC inputs are constants. This is beneficial for backends that do not support the adaptive op variants. Change-Id: I9037438325a3b0071f14ebed0aa444acf66656df
|
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) ChangesThis commit adds canonicalization patterns to convert adaptive pooling (max and avg) to their non-adaptive variants when their CTC inputs are constants. This is beneficial for backends that do not support the adaptive op variants. Full diff: https://github.com/llvm/llvm-project/pull/195865.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 207618adc1352..71122ba8531c6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -161,6 +161,7 @@ def Tosa_AvgPool2dAdaptiveOp
}];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
@@ -536,7 +537,7 @@ def Tosa_MaxPool2dAdaptiveOp
This performs a max pooling over the given input tensor. A sliding window of
size given by <kernel size> is passed over the input tensor, with the
maximum value being placed in the output tensor.
- Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
+ Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
pad arguments as inputs rather than attributes.
}];
@@ -557,6 +558,7 @@ def Tosa_MaxPool2dAdaptiveOp
];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 1c186cd3ae122..642ee4b98e216 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -245,6 +245,36 @@ void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
context);
}
+struct AvgPool2dAdaptiveToAvgPool2d
+ : public OpRewritePattern<tosa::AvgPool2dAdaptiveOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op,
+ PatternRewriter &rewriter) const override {
+ llvm::SmallVector<int64_t> kernel;
+ llvm::SmallVector<int64_t> stride;
+ llvm::SmallVector<int64_t> pad;
+ if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
+ !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
+ !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
+ return rewriter.notifyMatchFailure(
+ op, "expected constant kernel, stride, and pad operands");
+
+ auto replacement = tosa::AvgPool2dOp::create(
+ rewriter, op.getLoc(), op.getType(), op.getInput(), op.getInputZp(),
+ op.getOutputZp(), rewriter.getDenseI64ArrayAttr(kernel),
+ rewriter.getDenseI64ArrayAttr(stride),
+ rewriter.getDenseI64ArrayAttr(pad), op.getAccTypeAttr());
+ rewriter.replaceOp(op, replacement.getOutput());
+ return success();
+ }
+};
+
+void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<AvgPool2dAdaptiveToAvgPool2d>(context);
+}
+
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
using OpRewritePattern::OpRewritePattern;
@@ -283,6 +313,36 @@ void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
context);
}
+struct MaxPool2dAdaptiveToMaxPool2d
+ : public OpRewritePattern<tosa::MaxPool2dAdaptiveOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op,
+ PatternRewriter &rewriter) const override {
+ llvm::SmallVector<int64_t> kernel;
+ llvm::SmallVector<int64_t> stride;
+ llvm::SmallVector<int64_t> pad;
+ if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
+ !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
+ !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
+ return rewriter.notifyMatchFailure(
+ op, "expected constant kernel, stride, and pad operands");
+
+ auto replacement = tosa::MaxPool2dOp::create(
+ rewriter, op.getLoc(), op.getType(), op.getInput(),
+ rewriter.getDenseI64ArrayAttr(kernel),
+ rewriter.getDenseI64ArrayAttr(stride),
+ rewriter.getDenseI64ArrayAttr(pad), op.getNanModeAttr());
+ rewriter.replaceOp(op, replacement.getOutput());
+ return success();
+ }
+};
+
+void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<MaxPool2dAdaptiveToMaxPool2d>(context);
+}
+
//===----------------------------------------------------------------------===//
// Data Layout / Memory Reinterpretation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 9a7fa3efc8d3c..19583e111ebef 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1649,3 +1649,61 @@ func.func @test_do_not_canonicalize_cast_from_cast_to_block_scaled_unranked(%arg
%1, %2 = tosa.cast_to_block_scaled %0 {block_size = BLOCK_SIZE_32} : (tensor<*xf32>) -> (tensor<*xf6E2M3FN>, tensor<*xf8E8M0FNU>)
return %1, %2 : tensor<*xf6E2M3FN>, tensor<*xf8E8M0FNU>
}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_max_pool2d_adaptive
+// CHECK: %[[POOL:.+]] = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, nan_mode = IGNORE, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+// CHECK: return %[[POOL]]
+func.func @canonicalize_max_pool2d_adaptive(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad {nan_mode = IGNORE} :
+ (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_avg_pool2d_adaptive
+// CHECK: %[[POOL:.+]] = tosa.avg_pool2d %arg0, %{{.*}}, %{{.*}} {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
+// CHECK: return %[[POOL]]
+func.func @canonicalize_avg_pool2d_adaptive(%arg0: tensor<1x7x7x9xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
+ %kernel = tosa.const_shape {values = dense<[3, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[1, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.avg_pool2d_adaptive %arg0, %input_zp, %output_zp, %kernel, %stride, %pad {acc_type = f32} :
+ (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x7x7x9xf32>
+ return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_non_const_max_pool2d_adaptive
+// CHECK: tosa.max_pool2d_adaptive
+func.func @dont_canonicalize_non_const_max_pool2d_adaptive(%arg0: tensor<1x?x?x8xf32>) -> tensor<1x?x?x8xf32> {
+ %dim1 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+ %dim2 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+ %kernel = tosa.concat_shape %dim1, %dim2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad {nan_mode = IGNORE} :
+ (tensor<1x?x?x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
+ return %0 : tensor<1x?x?x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_non_const_avg_pool2d_adaptive
+// CHECK: tosa.avg_pool2d_adaptive
+func.func @dont_canonicalize_non_const_avg_pool2d_adaptive(%arg0: tensor<1x?x?x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x?x?x8xf32> {
+ %dim1 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+ %dim2 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+ %kernel = tosa.concat_shape %dim1, %dim2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.avg_pool2d_adaptive %arg0, %input_zp, %output_zp, %kernel, %stride, %pad {acc_type = f32} :
+ (tensor<1x?x?x8xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
+ return %0 : tensor<1x?x?x8xf32>
+}
|
moar55
pushed a commit
to moar55/llvm-project
that referenced
this pull request
May 12, 2026
…variants (llvm#195865) This commit adds canonicalization patterns to convert adaptive pooling (max and avg) to their non-adaptive variants when their CTC inputs are constants. This is beneficial for backends that do not support the adaptive op variants.
pedroMVicente
pushed a commit
to pedroMVicente/llvm-project
that referenced
this pull request
May 19, 2026
…variants (llvm#195865) This commit adds canonicalization patterns to convert adaptive pooling (max and avg) to their non-adaptive variants when their CTC inputs are constants. This is beneficial for backends that do not support the adaptive op variants.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This commit adds canonicalization patterns to convert adaptive pooling (max and avg) to their non-adaptive variants when their CTC inputs are constants.
This is beneficial for backends that do not support the adaptive op variants.