-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][tosa] Add aggressiveReduceConstant argument for the constant reduce optimization #68765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tosa] Add aggressiveReduceConstant argument for the constant reduce optimization #68765
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-tosa Author: Amir Bishara (amirBish) ChangesAdding the argument of aggressiveReduceConstant to the TosaLayerwiseConstantFoldPass which would Full diff: https://github.com/llvm/llvm-project/pull/68765.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 38b05c792d405ad..bb56c8d203d3c15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1082,6 +1082,14 @@ def TosaToLinalg
}];
let constructor = "tosa::createTosaToLinalg()";
+ let options = [
+ Option<"disableTosaDecompositions", "disable-tosa-decompositions",
+ "bool", /*default=*/"false",
+ "Disable tosa decompositions pass">,
+ Option<"aggressiveReduceConstant", "aggressive-reduce-constant",
+ "bool", /*default=*/"false",
+ "Always perform the reduce constant optimization">
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 818d43ffe4e572e..8ffbd1238e5c6b3 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -33,7 +33,7 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
/// pipeline succeeds. The option to disable decompositions is available for
/// benchmarking performance improvements from the canonicalizations.
void addTosaToLinalgPasses(
- OpPassManager &pm, bool disableTosaDecompositions = false,
+ OpPassManager &pm, const TosaToLinalgOptions& options,
// Note: Default to 'none' level unless otherwise specified.
tosa::ValidationOptions const &validationOptions =
tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 6b5dd9c970703ee..8f3255ddaad6844 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -35,9 +35,10 @@ void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaConstantReduction(MLIRContext *ctx,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns,bool aggressiveReduceConstant);
std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
+std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass(const TosaLayerwiseConstantFoldPassOptions& options);
std::unique_ptr<Pass> createTosaInferShapesPass();
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 18402b3e70647a9..ac100a6d75c7c08 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -23,6 +23,13 @@ def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::
}];
let constructor = "createTosaLayerwiseConstantFoldPass()";
+
+ let options = [
+ Option<"aggressiveReduceConstant", "aggressive-reduce-constant", "bool",
+ /*default=*/"false",
+ "Always perform the reduce constant optimization"
+ "May add more tosa.const but would reduce runtime calculations">,
+ ];
}
def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index d7e867d92282395..e934d21fe065959 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -75,10 +75,10 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
}
void mlir::tosa::addTosaToLinalgPasses(
- OpPassManager &pm, bool disableTosaDecompositions,
+ OpPassManager &pm, const TosaToLinalgOptions& options,
tosa::ValidationOptions const &validationOptions) {
// Optional decompositions are designed to benefit linalg.
- if (!disableTosaDecompositions)
+ if (!options.disableTosaDecompositions)
pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
@@ -87,7 +87,7 @@ void mlir::tosa::addTosaToLinalgPasses(
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// TODO: Remove pass that operates on const tensor and enable optionality
- pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
+ pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass({options.aggressiveReduceConstant}));
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
pm.addNestedPass<func::FuncOp>(
tosa::createTosaValidationPass(validationOptions));
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 0988759b82201df..3b417eda9e20dd4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -350,6 +350,9 @@ llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
+ ReduceConstantOptimization(MLIRContext *context, bool aggressiveReduceConstant):
+ OpRewritePattern<OperationType>(context), aggressiveReduceConstant(aggressiveReduceConstant){}
+
using OpRewritePattern<OperationType>::OpRewritePattern;
LogicalResult matchAndRewrite(OperationType op,
@@ -361,7 +364,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
return rewriter.notifyMatchFailure(
op, "reduce input must be const operation");
- if (!inputOp.hasOneUse())
+ if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
return rewriter.notifyMatchFailure(
op, "input operation has more than one user");
@@ -400,18 +403,20 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
return success();
}
+ const bool aggressiveReduceConstant;
};
} // namespace
void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
- RewritePatternSet &patterns) {
- patterns.add<ReduceConstantOptimization<ReduceAllOp>>(ctx);
- patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(ctx);
- patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(ctx);
- patterns.add<ReduceConstantOptimization<ReduceMinOp>>(ctx);
- patterns.add<ReduceConstantOptimization<ReduceProdOp>>(ctx);
- patterns.add<ReduceConstantOptimization<ReduceSumOp>>(ctx);
+ RewritePatternSet &patterns,
+ bool aggressiveReduceConstant) {
+ patterns.add<ReduceConstantOptimization<ReduceAllOp>>(ctx, aggressiveReduceConstant);
+ patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(ctx, aggressiveReduceConstant);
+ patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(ctx, aggressiveReduceConstant);
+ patterns.add<ReduceConstantOptimization<ReduceMinOp>>(ctx, aggressiveReduceConstant);
+ patterns.add<ReduceConstantOptimization<ReduceProdOp>>(ctx, aggressiveReduceConstant);
+ patterns.add<ReduceConstantOptimization<ReduceSumOp>>(ctx, aggressiveReduceConstant);
}
void mlir::tosa::populateTosaFoldConstantTransposePatterns(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index 90f15faf0108103..56bc53a4746da0d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -45,6 +45,10 @@ void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
struct TosaLayerwiseConstantFoldPass
: public tosa::impl::TosaLayerwiseConstantFoldPassBase<
TosaLayerwiseConstantFoldPass> {
+ TosaLayerwiseConstantFoldPass(const TosaLayerwiseConstantFoldPassOptions& options):TosaLayerwiseConstantFoldPassBase(options)
+ {
+ }
+
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
@@ -52,7 +56,7 @@ struct TosaLayerwiseConstantFoldPass
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
- mlir::tosa::populateTosaConstantReduction(ctx, patterns);
+ mlir::tosa::populateTosaConstantReduction(ctx, patterns, aggressiveReduceConstant);
populateTosaOpsCanonicalizationPatterns(ctx, patterns);
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
@@ -63,5 +67,9 @@ struct TosaLayerwiseConstantFoldPass
} // namespace
std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
- return std::make_unique<TosaLayerwiseConstantFoldPass>();
+return std::make_unique<TosaLayerwiseConstantFoldPass>(TosaLayerwiseConstantFoldPassOptions{false});
+}
+
+std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass(const TosaLayerwiseConstantFoldPassOptions& options) {
+ return std::make_unique<TosaLayerwiseConstantFoldPass>(options);
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 56619fbc560e5fa..612e99f198515ae 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -1,5 +1,8 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
+
+// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE
+
// CHECK-LABEL: @transpose_fold
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
@@ -1051,3 +1054,57 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
%0 = tosa.reduce_sum %arg2 {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
return %0 : tensor<1x3xi32>
}
+
+// -----
+
+func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+ // AGGRESIVE-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+ // AGGRESIVE: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<4> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
+ // AGGRESIVE: return %[[VAL_0:.*]] : tensor<1x3xi32>
+
+ // CHECK-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+ // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+ // CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+ // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+ // CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32>
+ // CHECK: return %[[VAL_3]] : tensor<1x3xi32>
+
+ %const = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+ %0 = tosa.reduce_sum %const {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+ %1 = tosa.reduce_sum %const {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+ %res = tosa.add %0, %1 : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32>
+ return %res : tensor<1x3xi32>
+}
+
+// -----
+
+func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+ // AGGRESIVE-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+ // AGGRESIVE: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x2x3xi32>}> : () -> tensor<1x2x3xi32>
+ // AGGRESIVE: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
+ // AGGRESIVE: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+ // AGGRESIVE: %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+ // AGGRESIVE: %[[VAL_4:.*]] = tosa.argmax %[[VAL_1]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+ // AGGRESIVE: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ // AGGRESIVE: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ // AGGRESIVE: return %[[VAL_6]] : tensor<2x3xi32>
+
+ // CHECK-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+ // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
+ // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+ // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
+ // CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+ // CHECK: %[[VAL_4:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+ // CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ // CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ // CHECK: return %[[VAL_6]] : tensor<2x3xi32>
+
+ %const0 = "tosa.const"() {value = dense<1> : tensor<2x2x3xi32>} : () -> tensor<2x2x3xi32>
+ %const1 = "tosa.const"() {value = dense<2> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+ %reduce0 = tosa.reduce_sum %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
+ %argmax0 = tosa.argmax %reduce0 {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+ %argmax1 = tosa.argmax %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+ %res0 = tosa.add %argmax0, %const1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ %res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ return %res1 : tensor<2x3xi32>
+}
|
24b94b5
to
fd3fb47
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
fd3fb47
to
9fedcdd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks okay to me.
…educe optimization Adding the argument of aggressiveReduceConstant to the TosaLayerwiseConstantFoldPass which would allow performing the constant optimizations on the reduce ops always. (e.g. without considering the number of users of the input of the reduce operation)
9fedcdd
to
c291b8b
Compare
Adding the argument of aggressiveReduceConstant to the TosaLayerwiseConstantFoldPass which would
allow performing the constant optimizations on the reduce ops always. (e.g. without considering the
number of users of the input of the reduce operation)