amirBish

Add an aggressiveReduceConstant option to the TosaLayerwiseConstantFoldPass. When the option is
enabled, the constant optimization is always applied to reduce ops, i.e. regardless of how many
users the input of the reduce operation has.
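
To illustrate the intended behaviour, here is a minimal standalone sketch of the guard this change relaxes. It does not use MLIR: `shouldFoldReduce` and `numInputUsers` are hypothetical names that only model the `hasOneUse()` check in `ReduceConstantOptimization`.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the pattern's guard; not part of MLIR.
// numInputUsers models how many operations consume the constant that feeds
// the reduce op (the reduce op itself counts as one user).
bool shouldFoldReduce(std::size_t numInputUsers, bool aggressiveReduceConstant) {
  const bool hasOneUse = (numInputUsers == 1);
  // Models: if (!inputOp.hasOneUse() && !aggressiveReduceConstant) -> match failure.
  if (!hasOneUse && !aggressiveReduceConstant)
    return false;
  return true;
}
```

With the default (`false`), a constant shared by two reduce ops is left alone; with the option set, both reduce ops are folded, which can add extra `tosa.const` ops in exchange for less runtime computation.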

llvmbot commented Oct 11, 2023

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Amir Bishara (amirBish)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/68765.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+8)
  • (modified) mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h (+2-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+7)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp (+13-8)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp (+10-2)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+57)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 38b05c792d405ad..bb56c8d203d3c15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1082,6 +1082,14 @@ def TosaToLinalg
   }];
 
   let constructor = "tosa::createTosaToLinalg()";
+  let options = [
+    Option<"disableTosaDecompositions", "disable-tosa-decompositions",
+           "bool", /*default=*/"false",
+           "Disable tosa decompositions pass">,
+    Option<"aggressiveReduceConstant", "aggressive-reduce-constant",
+           "bool", /*default=*/"false",
+           "Always perform the reduce constant optimization">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 818d43ffe4e572e..8ffbd1238e5c6b3 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -33,7 +33,7 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
 /// pipeline succeeds.  The option to disable decompositions is available for
 /// benchmarking performance improvements from the canonicalizations.
 void addTosaToLinalgPasses(
-    OpPassManager &pm, bool disableTosaDecompositions = false,
+    OpPassManager &pm, const TosaToLinalgOptions& options,
     // Note: Default to 'none' level unless otherwise specified.
     tosa::ValidationOptions const &validationOptions =
         tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 6b5dd9c970703ee..8f3255ddaad6844 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -35,9 +35,10 @@ void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
 void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
                                                RewritePatternSet &patterns);
 void populateTosaConstantReduction(MLIRContext *ctx,
-                                   RewritePatternSet &patterns);
+                                   RewritePatternSet &patterns,bool aggressiveReduceConstant);
 
 std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
+std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass(const TosaLayerwiseConstantFoldPassOptions& options);
 std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 18402b3e70647a9..ac100a6d75c7c08 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -23,6 +23,13 @@ def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::
   }];
 
   let constructor = "createTosaLayerwiseConstantFoldPass()";
+
+  let options = [
+      Option<"aggressiveReduceConstant", "aggressive-reduce-constant", "bool",
+             /*default=*/"false",
+             "Always perform the reduce constant optimization"
+             "May add more tosa.const but would reduce runtime calculations">,
+   ];
 }
 
 def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index d7e867d92282395..e934d21fe065959 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -75,10 +75,10 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 }
 
 void mlir::tosa::addTosaToLinalgPasses(
-    OpPassManager &pm, bool disableTosaDecompositions,
+    OpPassManager &pm, const TosaToLinalgOptions& options,
     tosa::ValidationOptions const &validationOptions) {
   // Optional decompositions are designed to benefit linalg.
-  if (!disableTosaDecompositions)
+  if (!options.disableTosaDecompositions)
     pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
 
@@ -87,7 +87,7 @@ void mlir::tosa::addTosaToLinalgPasses(
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // TODO: Remove pass that operates on const tensor and enable optionality
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass({options.aggressiveReduceConstant}));
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
   pm.addNestedPass<func::FuncOp>(
       tosa::createTosaValidationPass(validationOptions));
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 0988759b82201df..3b417eda9e20dd4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -350,6 +350,9 @@ llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
 template <typename OperationType>
 struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
 
+  ReduceConstantOptimization(MLIRContext *context, bool aggressiveReduceConstant):
+  OpRewritePattern<OperationType>(context), aggressiveReduceConstant(aggressiveReduceConstant){}
+
   using OpRewritePattern<OperationType>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(OperationType op,
@@ -361,7 +364,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
       return rewriter.notifyMatchFailure(
           op, "reduce input must be const operation");
 
-    if (!inputOp.hasOneUse())
+    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
       return rewriter.notifyMatchFailure(
           op, "input operation has more than one user");
 
@@ -400,18 +403,20 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
     rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
     return success();
   }
+  const bool aggressiveReduceConstant;
 };
 
 } // namespace
 
 void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
-                                               RewritePatternSet &patterns) {
-  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(ctx);
-  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(ctx);
-  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(ctx);
-  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(ctx);
-  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(ctx);
-  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(ctx);
+                                               RewritePatternSet &patterns,
+                                               bool aggressiveReduceConstant) {
+  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(ctx, aggressiveReduceConstant);
+  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(ctx, aggressiveReduceConstant);
+  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(ctx, aggressiveReduceConstant);
+  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(ctx, aggressiveReduceConstant);
+  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(ctx, aggressiveReduceConstant);
+  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(ctx, aggressiveReduceConstant);
 }
 
 void mlir::tosa::populateTosaFoldConstantTransposePatterns(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index 90f15faf0108103..56bc53a4746da0d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -45,6 +45,10 @@ void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
 struct TosaLayerwiseConstantFoldPass
     : public tosa::impl::TosaLayerwiseConstantFoldPassBase<
           TosaLayerwiseConstantFoldPass> {
+  TosaLayerwiseConstantFoldPass(const TosaLayerwiseConstantFoldPassOptions& options):TosaLayerwiseConstantFoldPassBase(options)
+  {
+  }
+
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
@@ -52,7 +56,7 @@ struct TosaLayerwiseConstantFoldPass
 
     mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
-    mlir::tosa::populateTosaConstantReduction(ctx, patterns);
+    mlir::tosa::populateTosaConstantReduction(ctx, patterns, aggressiveReduceConstant);
     populateTosaOpsCanonicalizationPatterns(ctx, patterns);
 
     if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
@@ -63,5 +67,9 @@ struct TosaLayerwiseConstantFoldPass
 } // namespace
 
 std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
-  return std::make_unique<TosaLayerwiseConstantFoldPass>();
+return std::make_unique<TosaLayerwiseConstantFoldPass>(TosaLayerwiseConstantFoldPassOptions{false});
+}
+
+std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass(const TosaLayerwiseConstantFoldPassOptions& options) {
+  return std::make_unique<TosaLayerwiseConstantFoldPass>(options);
 }
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 56619fbc560e5fa..612e99f198515ae 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
 
+
+// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE
+
 // CHECK-LABEL: @transpose_fold
 func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
@@ -1051,3 +1054,57 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
   %0 = tosa.reduce_sum %arg2 {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
   return %0 : tensor<1x3xi32>
 }
+
+// -----
+
+func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+  // AGGRESIVE-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+  // AGGRESIVE:       %[[VAL_0:.*]] = "tosa.const"() <{value = dense<4> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
+  // AGGRESIVE:       return %[[VAL_0:.*]] : tensor<1x3xi32>
+  
+  // CHECK-LABEL:     func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+  // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+  // CHECK:           %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+  // CHECK:           %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32>
+  // CHECK:           return %[[VAL_3]] : tensor<1x3xi32>
+
+  %const = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %0 = tosa.reduce_sum %const {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+  %1 = tosa.reduce_sum %const {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+  %res = tosa.add %0, %1 : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32>
+  return %res : tensor<1x3xi32>
+}
+
+// -----
+
+func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+  // AGGRESIVE-LABEL:     func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+  // AGGRESIVE:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x2x3xi32>}> : () -> tensor<1x2x3xi32>
+  // AGGRESIVE:           %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
+  // AGGRESIVE:           %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+  // AGGRESIVE:           %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+  // AGGRESIVE:           %[[VAL_4:.*]] = tosa.argmax %[[VAL_1]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+  // AGGRESIVE:           %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // AGGRESIVE:           %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // AGGRESIVE:           return %[[VAL_6]] : tensor<2x3xi32>
+
+  // CHECK-LABEL:     func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+  // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
+  // CHECK:           %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
+  // CHECK:           %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_4:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK:           return %[[VAL_6]] : tensor<2x3xi32>
+
+  %const0 = "tosa.const"() {value = dense<1> : tensor<2x2x3xi32>} : () -> tensor<2x2x3xi32>
+  %const1 = "tosa.const"() {value = dense<2> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %reduce0 = tosa.reduce_sum %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
+  %argmax0 = tosa.argmax %reduce0 {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+  %argmax1 = tosa.argmax %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+  %res0 = tosa.add %argmax0, %const1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  return %res1 : tensor<2x3xi32>
+}
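
As a sanity check on the expected values in the first new test case, the arithmetic can be reproduced with a small standalone sketch. It assumes plain nested vectors instead of MLIR tensors, and `Matrix`, `reduceSumAxis0`, `add` and `foldFirstAggressiveTest` are hypothetical helper names, not part of the TOSA dialect.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Sums a rows x cols matrix along axis 0. The reduced axis is kept with
// size 1, matching the tensor<2x3xi32> -> tensor<1x3xi32> shapes in the test.
Matrix reduceSumAxis0(const Matrix &input) {
  const std::size_t cols = input.empty() ? 0 : input.front().size();
  Matrix result(1, std::vector<int>(cols, 0));
  for (const auto &row : input)
    for (std::size_t c = 0; c < cols; ++c)
      result[0][c] += row[c];
  return result;
}

// Element-wise addition of two matrices that have the same shape.
Matrix add(const Matrix &lhs, const Matrix &rhs) {
  Matrix out = lhs;
  for (std::size_t r = 0; r < out.size(); ++r)
    for (std::size_t c = 0; c < out[r].size(); ++c)
      out[r][c] += rhs[r][c];
  return out;
}

// Models @reduce_sum_constant_aggressive (first case): one dense<1> 2x3
// constant feeds two reduce_sum ops whose results are then added.
Matrix foldFirstAggressiveTest() {
  const Matrix ones(2, std::vector<int>(3, 1));
  return add(reduceSumAxis0(ones), reduceSumAxis0(ones));
}
```

Each reduce yields 2 per element and the add yields 4, which agrees with the `dense<4> : tensor<1x3xi32>` constant expected under the AGGRESIVE prefix. In the default run the shared constant has two users, so neither reduce is folded and the CHECK lines keep both `tosa.reduce_sum` ops.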

@amirBish amirBish force-pushed the mlir/tosa/aggressive_const_optimization branch from 24b94b5 to fd3fb47 Compare October 11, 2023 06:18

github-actions bot commented Oct 11, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@amirBish amirBish force-pushed the mlir/tosa/aggressive_const_optimization branch from fd3fb47 to 9fedcdd Compare October 11, 2023 06:22

amirBish commented

Adding @AviadCo @amrami as subscribers.

eric-k256 left a comment

This looks okay to me.

…educe optimization

Add an aggressiveReduceConstant option to the
TosaLayerwiseConstantFoldPass. When enabled, the
constant optimization is always applied to reduce
ops, i.e. regardless of how many users the input
of the reduce operation has.
@amirBish amirBish force-pushed the mlir/tosa/aggressive_const_optimization branch from 9fedcdd to c291b8b Compare October 12, 2023 05:37
@amirBish amirBish merged commit 9dd15f7 into llvm:main Oct 12, 2023