Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][mesh] Add folding of ClusterShapeOp #77033

Merged
merged 2 commits into from Jan 9, 2024

Conversation

sogartar
Copy link
Contributor

@sogartar sogartar commented Jan 5, 2024

If the mesh has static size on some of the requested axes, the result is substituted with a constant.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jan 5, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 5, 2024

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

If the mesh has static size on some of the requested axes, the result is substituted with a constant.


Full diff: https://github.com/llvm/llvm-project/pull/77033.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h (+7)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+86)
  • (added) mlir/test/Dialect/Mesh/folding.mlir (+22)
  • (modified) mlir/test/lib/Dialect/Mesh/CMakeLists.txt (+2-1)
  • (added) mlir/test/lib/Dialect/Mesh/TestFolding.cpp (+52)
  • (modified) mlir/tools/mlir-opt/CMakeLists.txt (+1-1)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f70bdaa9de0a0f..f7096cfce634ee 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -19,6 +19,9 @@
 #include <utility>
 
 namespace mlir {
+
+class SymbolTable;
+
 namespace mesh {
 
 // If we have an algebraic op like "+" and a summing all-reduce,
@@ -103,6 +106,10 @@ void populateAllReduceEndomorphismSimplificationPatterns(
 }
 
 void populateSimplificationPatterns(RewritePatternSet &patterns);
+// It is invalid to change ops that declare symbols during the application of
+// these patterns, because symbolTable is used to cache them.
+void populateFoldingPatterns(RewritePatternSet &patterns,
+                             SymbolTableCollection &symbolTable);
 
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 643bd7b8e77c93..eab3bc88fd1d38 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -8,6 +8,17 @@
 
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <iterator>
+#include <numeric>
+#include <utility>
 
 namespace mlir {
 namespace mesh {
@@ -35,5 +46,80 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
   // TODO: add simplifications for all-gather and other collectives.
 }
 
+namespace {
+
+// This folding can not be done with an operation's fold method or
+// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
+// symbol tables.
+// We can't use DialectFoldInterface since the cache may be invalidated by some
+// pass changing the referenced ClusterOp ops.
+struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
+  template <typename... OpRewritePatternArgs>
+  ClusterShapeFolder(SymbolTableCollection &symbolTable,
+                     OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern(
+            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
+        symbolTable(symbolTable) {}
+  LogicalResult matchAndRewrite(ClusterShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+        op.getOperation(), op.getMeshAttr());
+    if (!mesh) {
+      return failure();
+    }
+    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
+    SmallVector<MeshAxis> opAxesIota;
+    if (opMeshAxes.empty()) {
+      opAxesIota.resize(mesh.getRank());
+      std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
+      opMeshAxes = opAxesIota;
+    }
+    if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
+          return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
+        })) {
+      // All mesh dimensions are dynamic. Nothing to fold.
+      return failure();
+    }
+
+    SmallVector<Value> newResults(op->getResults().size());
+    SmallVector<MeshAxis> newShapeOpMeshAxes;
+    SmallVector<size_t> newToOldResultsIndexMap;
+
+    for (size_t i = 0; i < opMeshAxes.size(); ++i) {
+      auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
+      if (ShapedType::isDynamic(meshAxisSize)) {
+        newToOldResultsIndexMap.push_back(i);
+        newShapeOpMeshAxes.push_back(opMeshAxes[i]);
+      } else {
+        // Fold static mesh axes.
+        newResults[i] = builder.create<arith::ConstantOp>(
+            builder.getIndexAttr(meshAxisSize));
+      }
+    }
+
+    // Leave only the dynamic mesh axes to be queried.
+    ClusterShapeOp newShapeOp =
+        builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
+    for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
+      newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
+    }
+
+    rewriter.replaceAllUsesWith(op.getResults(), newResults);
+
+    return success();
+  }
+
+private:
+  SymbolTableCollection &symbolTable;
+};
+
+} // namespace
+
+void populateFoldingPatterns(RewritePatternSet &patterns,
+                             SymbolTableCollection &symbolTable) {
+  patterns.add<ClusterShapeFolder>(symbolTable, patterns.getContext());
+}
+
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
new file mode 100644
index 00000000000000..1283353709ca3c
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
+mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+
+// CHECK-LABEL: func.func @cluster_shape_op_folding
+func.func @cluster_shape_op_folding() -> (index, index) {
+  // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
+  // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
+  %0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
+  // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
+func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
+  // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
+  // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
+  %0:2 = mesh.cluster_shape @mesh1 : index, index
+  // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
+  return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index f14d282857a1e0..3da64694ee2155 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
-add_mlir_library(MLIRMeshTestSimplifications
+add_mlir_library(MLIRMeshTest
+  TestFolding.cpp
   TestReshardingSpmdization.cpp
   TestSimplifications.cpp
 
diff --git a/mlir/test/lib/Dialect/Mesh/TestFolding.cpp b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
new file mode 100644
index 00000000000000..1cf436edea8e35
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
@@ -0,0 +1,52 @@
+//===- TestSimplification.cpp - Test simplification -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <memory>
+
+using namespace mlir;
+
+namespace {
+
+struct TestMeshFoldingPass
+    : public PassWrapper<TestMeshFoldingPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshFoldingPass)
+
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mesh::MeshDialect>();
+  }
+  StringRef getArgument() const final { return "test-mesh-folding"; }
+  StringRef getDescription() const final { return "Test mesh folding."; }
+};
+} // namespace
+
+void TestMeshFoldingPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  SymbolTableCollection symbolTables;
+  mesh::populateFoldingPatterns(patterns, symbolTables);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    getOperation()->emitError()
+        << "Rewrite patter application did not converge.";
+    return signalPassFailure();
+  }
+}
+
+namespace mlir {
+namespace test {
+void registerTestMeshFoldingPass() { PassRegistration<TestMeshFoldingPass>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index b6ada66d321880..a5da9390a0c5b3 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRLoopLikeInterfaceTestPasses
     MLIRMathTestPasses
     MLIRMemRefTestPasses
-    MLIRMeshTestSimplifications
+    MLIRMeshTest
     MLIRNVGPUTestPasses
     MLIRSCFTestPasses
     MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f7a5b3183b50b1..461163f671ce89 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestMeshFoldingPass();
 void registerTestMeshSimplificationsPass();
 void registerTestMeshReshardingSpmdizationPass();
 void registerTestNextAccessPass();
@@ -237,6 +238,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestMeshFoldingPass();
   mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestMeshReshardingSpmdizationPass();
   mlir::test::registerTestNextAccessPass();

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 5, 2024

@llvm/pr-subscribers-mlir-core

Author: Boian Petkantchin (sogartar)

Changes

If the mesh has static size on some of the requested axes, the result is substituted with a constant.


Full diff: https://github.com/llvm/llvm-project/pull/77033.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h (+7)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+86)
  • (added) mlir/test/Dialect/Mesh/folding.mlir (+22)
  • (modified) mlir/test/lib/Dialect/Mesh/CMakeLists.txt (+2-1)
  • (added) mlir/test/lib/Dialect/Mesh/TestFolding.cpp (+52)
  • (modified) mlir/tools/mlir-opt/CMakeLists.txt (+1-1)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f70bdaa9de0a0f..f7096cfce634ee 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -19,6 +19,9 @@
 #include <utility>
 
 namespace mlir {
+
+class SymbolTable;
+
 namespace mesh {
 
 // If we have an algebraic op like "+" and a summing all-reduce,
@@ -103,6 +106,10 @@ void populateAllReduceEndomorphismSimplificationPatterns(
 }
 
 void populateSimplificationPatterns(RewritePatternSet &patterns);
+// It is invalid to change ops that declare symbols during the application of
+// these patterns, because symbolTable is used to cache them.
+void populateFoldingPatterns(RewritePatternSet &patterns,
+                             SymbolTableCollection &symbolTable);
 
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 643bd7b8e77c93..eab3bc88fd1d38 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -8,6 +8,17 @@
 
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <iterator>
+#include <numeric>
+#include <utility>
 
 namespace mlir {
 namespace mesh {
@@ -35,5 +46,80 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
   // TODO: add simplifications for all-gather and other collectives.
 }
 
+namespace {
+
+// This folding can not be done with an operation's fold method or
+// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
+// symbol tables.
+// We can't use DialectFoldInterface since the cache may be invalidated by some
+// pass changing the referenced ClusterOp ops.
+struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
+  template <typename... OpRewritePatternArgs>
+  ClusterShapeFolder(SymbolTableCollection &symbolTable,
+                     OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern(
+            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
+        symbolTable(symbolTable) {}
+  LogicalResult matchAndRewrite(ClusterShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+        op.getOperation(), op.getMeshAttr());
+    if (!mesh) {
+      return failure();
+    }
+    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
+    SmallVector<MeshAxis> opAxesIota;
+    if (opMeshAxes.empty()) {
+      opAxesIota.resize(mesh.getRank());
+      std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
+      opMeshAxes = opAxesIota;
+    }
+    if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
+          return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
+        })) {
+      // All mesh dimensions are dynamic. Nothing to fold.
+      return failure();
+    }
+
+    SmallVector<Value> newResults(op->getResults().size());
+    SmallVector<MeshAxis> newShapeOpMeshAxes;
+    SmallVector<size_t> newToOldResultsIndexMap;
+
+    for (size_t i = 0; i < opMeshAxes.size(); ++i) {
+      auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
+      if (ShapedType::isDynamic(meshAxisSize)) {
+        newToOldResultsIndexMap.push_back(i);
+        newShapeOpMeshAxes.push_back(opMeshAxes[i]);
+      } else {
+        // Fold static mesh axes.
+        newResults[i] = builder.create<arith::ConstantOp>(
+            builder.getIndexAttr(meshAxisSize));
+      }
+    }
+
+    // Leave only the dynamic mesh axes to be queried.
+    ClusterShapeOp newShapeOp =
+        builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
+    for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
+      newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
+    }
+
+    rewriter.replaceAllUsesWith(op.getResults(), newResults);
+
+    return success();
+  }
+
+private:
+  SymbolTableCollection &symbolTable;
+};
+
+} // namespace
+
+void populateFoldingPatterns(RewritePatternSet &patterns,
+                             SymbolTableCollection &symbolTable) {
+  patterns.add<ClusterShapeFolder>(symbolTable, patterns.getContext());
+}
+
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
new file mode 100644
index 00000000000000..1283353709ca3c
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
+mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+
+// CHECK-LABEL: func.func @cluster_shape_op_folding
+func.func @cluster_shape_op_folding() -> (index, index) {
+  // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
+  // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
+  %0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
+  // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
+func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
+  // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
+  // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
+  %0:2 = mesh.cluster_shape @mesh1 : index, index
+  // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
+  return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index f14d282857a1e0..3da64694ee2155 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
-add_mlir_library(MLIRMeshTestSimplifications
+add_mlir_library(MLIRMeshTest
+  TestFolding.cpp
   TestReshardingSpmdization.cpp
   TestSimplifications.cpp
 
diff --git a/mlir/test/lib/Dialect/Mesh/TestFolding.cpp b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
new file mode 100644
index 00000000000000..1cf436edea8e35
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
@@ -0,0 +1,52 @@
+//===- TestSimplification.cpp - Test simplification -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <memory>
+
+using namespace mlir;
+
+namespace {
+
+struct TestMeshFoldingPass
+    : public PassWrapper<TestMeshFoldingPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshFoldingPass)
+
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mesh::MeshDialect>();
+  }
+  StringRef getArgument() const final { return "test-mesh-folding"; }
+  StringRef getDescription() const final { return "Test mesh folding."; }
+};
+} // namespace
+
+void TestMeshFoldingPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  SymbolTableCollection symbolTables;
+  mesh::populateFoldingPatterns(patterns, symbolTables);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    getOperation()->emitError()
+        << "Rewrite patter application did not converge.";
+    return signalPassFailure();
+  }
+}
+
+namespace mlir {
+namespace test {
+void registerTestMeshFoldingPass() { PassRegistration<TestMeshFoldingPass>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index b6ada66d321880..a5da9390a0c5b3 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRLoopLikeInterfaceTestPasses
     MLIRMathTestPasses
     MLIRMemRefTestPasses
-    MLIRMeshTestSimplifications
+    MLIRMeshTest
     MLIRNVGPUTestPasses
     MLIRSCFTestPasses
     MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f7a5b3183b50b1..461163f671ce89 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestMeshFoldingPass();
 void registerTestMeshSimplificationsPass();
 void registerTestMeshReshardingSpmdizationPass();
 void registerTestNextAccessPass();
@@ -237,6 +238,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestMeshFoldingPass();
   mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestMeshReshardingSpmdizationPass();
   mlir::test::registerTestNextAccessPass();

@sogartar
Copy link
Contributor Author

sogartar commented Jan 5, 2024

@yaochengji, could you review this PR?

@yaochengji
Copy link
Member

LGTM, thanks

If the mesh has static size on some of the requested axes,
the result is substituted with a constant.
@sogartar
Copy link
Contributor Author

sogartar commented Jan 9, 2024

I rebased before merging.

@sogartar sogartar merged commit ab59037 into llvm:main Jan 9, 2024
3 of 4 checks passed
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
If the mesh has static size on some of the requested axes, the result is
substituted with a constant.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants