New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][mesh] Add folding of ClusterShapeOp #77033
Conversation
@llvm/pr-subscribers-mlir Author: Boian Petkantchin (sogartar) ChangesIf the mesh has static size on some of the requested axes, the result is substituted with a constant. Full diff: https://github.com/llvm/llvm-project/pull/77033.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f70bdaa9de0a0f..f7096cfce634ee 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -19,6 +19,9 @@
#include <utility>
namespace mlir {
+
+class SymbolTable;
+
namespace mesh {
// If we have an algebraic op like "+" and a summing all-reduce,
@@ -103,6 +106,10 @@ void populateAllReduceEndomorphismSimplificationPatterns(
}
void populateSimplificationPatterns(RewritePatternSet &patterns);
+// It is invalid to change ops that declare symbols during the application of
+// these patterns, because symbolTable is used to cache them.
+void populateFoldingPatterns(RewritePatternSet &patterns,
+ SymbolTableCollection &symbolTable);
} // namespace mesh
} // namespace mlir
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 643bd7b8e77c93..eab3bc88fd1d38 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -8,6 +8,17 @@
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <iterator>
+#include <numeric>
+#include <utility>
namespace mlir {
namespace mesh {
@@ -35,5 +46,80 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
// TODO: add simplifications for all-gather and other collectives.
}
+namespace {
+
+// This folding can not be done with an operation's fold method or
+// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
+// symbol tables.
+// We can't use DialectFoldInterface since the cache may be invalidated by some
+// pass changing the referenced ClusterOp ops.
+struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
+ template <typename... OpRewritePatternArgs>
+ ClusterShapeFolder(SymbolTableCollection &symbolTable,
+ OpRewritePatternArgs &&...opRewritePatternArgs)
+ : OpRewritePattern(
+ std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
+ symbolTable(symbolTable) {}
+ LogicalResult matchAndRewrite(ClusterShapeOp op,
+ PatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+ op.getOperation(), op.getMeshAttr());
+ if (!mesh) {
+ return failure();
+ }
+ ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
+ SmallVector<MeshAxis> opAxesIota;
+ if (opMeshAxes.empty()) {
+ opAxesIota.resize(mesh.getRank());
+ std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
+ opMeshAxes = opAxesIota;
+ }
+ if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
+ return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
+ })) {
+ // All mesh dimensions are dynamic. Nothing to fold.
+ return failure();
+ }
+
+ SmallVector<Value> newResults(op->getResults().size());
+ SmallVector<MeshAxis> newShapeOpMeshAxes;
+ SmallVector<size_t> newToOldResultsIndexMap;
+
+ for (size_t i = 0; i < opMeshAxes.size(); ++i) {
+ auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
+ if (ShapedType::isDynamic(meshAxisSize)) {
+ newToOldResultsIndexMap.push_back(i);
+ newShapeOpMeshAxes.push_back(opMeshAxes[i]);
+ } else {
+ // Fold static mesh axes.
+ newResults[i] = builder.create<arith::ConstantOp>(
+ builder.getIndexAttr(meshAxisSize));
+ }
+ }
+
+ // Leave only the dynamic mesh axes to be queried.
+ ClusterShapeOp newShapeOp =
+ builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
+ for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
+ newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
+ }
+
+ rewriter.replaceAllUsesWith(op.getResults(), newResults);
+
+ return success();
+ }
+
+private:
+ SymbolTableCollection &symbolTable;
+};
+
+} // namespace
+
+void populateFoldingPatterns(RewritePatternSet &patterns,
+ SymbolTableCollection &symbolTable) {
+ patterns.add<ClusterShapeFolder>(symbolTable, patterns.getContext());
+}
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
new file mode 100644
index 00000000000000..1283353709ca3c
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
+mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+
+// CHECK-LABEL: func.func @cluster_shape_op_folding
+func.func @cluster_shape_op_folding() -> (index, index) {
+ // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
+ %0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
+ // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
+func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
+ // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
+ %0:2 = mesh.cluster_shape @mesh1 : index, index
+ // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index f14d282857a1e0..3da64694ee2155 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
-add_mlir_library(MLIRMeshTestSimplifications
+add_mlir_library(MLIRMeshTest
+ TestFolding.cpp
TestReshardingSpmdization.cpp
TestSimplifications.cpp
diff --git a/mlir/test/lib/Dialect/Mesh/TestFolding.cpp b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
new file mode 100644
index 00000000000000..1cf436edea8e35
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
@@ -0,0 +1,52 @@
+//===- TestSimplification.cpp - Test simplification -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <memory>
+
+using namespace mlir;
+
+namespace {
+
+struct TestMeshFoldingPass
+ : public PassWrapper<TestMeshFoldingPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshFoldingPass)
+
+ void runOnOperation() override;
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<mesh::MeshDialect>();
+ }
+ StringRef getArgument() const final { return "test-mesh-folding"; }
+ StringRef getDescription() const final { return "Test mesh folding."; }
+};
+} // namespace
+
+void TestMeshFoldingPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ SymbolTableCollection symbolTables;
+ mesh::populateFoldingPatterns(patterns, symbolTables);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+ getOperation()->emitError()
+ << "Rewrite patter application did not converge.";
+ return signalPassFailure();
+ }
+}
+
+namespace mlir {
+namespace test {
+void registerTestMeshFoldingPass() { PassRegistration<TestMeshFoldingPass>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index b6ada66d321880..a5da9390a0c5b3 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRLoopLikeInterfaceTestPasses
MLIRMathTestPasses
MLIRMemRefTestPasses
- MLIRMeshTestSimplifications
+ MLIRMeshTest
MLIRNVGPUTestPasses
MLIRSCFTestPasses
MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f7a5b3183b50b1..461163f671ce89 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
+void registerTestMeshFoldingPass();
void registerTestMeshSimplificationsPass();
void registerTestMeshReshardingSpmdizationPass();
void registerTestNextAccessPass();
@@ -237,6 +238,7 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
+ mlir::test::registerTestMeshFoldingPass();
mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();
|
@llvm/pr-subscribers-mlir-core Author: Boian Petkantchin (sogartar) ChangesIf the mesh has static size on some of the requested axes, the result is substituted with a constant. Full diff: https://github.com/llvm/llvm-project/pull/77033.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index f70bdaa9de0a0f..f7096cfce634ee 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -19,6 +19,9 @@
#include <utility>
namespace mlir {
+
+class SymbolTable;
+
namespace mesh {
// If we have an algebraic op like "+" and a summing all-reduce,
@@ -103,6 +106,10 @@ void populateAllReduceEndomorphismSimplificationPatterns(
}
void populateSimplificationPatterns(RewritePatternSet &patterns);
+// It is invalid to change ops that declare symbols during the application of
+// these patterns, because symbolTable is used to cache them.
+void populateFoldingPatterns(RewritePatternSet &patterns,
+ SymbolTableCollection &symbolTable);
} // namespace mesh
} // namespace mlir
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 643bd7b8e77c93..eab3bc88fd1d38 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -8,6 +8,17 @@
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <iterator>
+#include <numeric>
+#include <utility>
namespace mlir {
namespace mesh {
@@ -35,5 +46,80 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
// TODO: add simplifications for all-gather and other collectives.
}
+namespace {
+
+// This folding can not be done with an operation's fold method or
+// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
+// symbol tables.
+// We can't use DialectFoldInterface since the cache may be invalidated by some
+// pass changing the referenced ClusterOp ops.
+struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
+ template <typename... OpRewritePatternArgs>
+ ClusterShapeFolder(SymbolTableCollection &symbolTable,
+ OpRewritePatternArgs &&...opRewritePatternArgs)
+ : OpRewritePattern(
+ std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
+ symbolTable(symbolTable) {}
+ LogicalResult matchAndRewrite(ClusterShapeOp op,
+ PatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+ op.getOperation(), op.getMeshAttr());
+ if (!mesh) {
+ return failure();
+ }
+ ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
+ SmallVector<MeshAxis> opAxesIota;
+ if (opMeshAxes.empty()) {
+ opAxesIota.resize(mesh.getRank());
+ std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
+ opMeshAxes = opAxesIota;
+ }
+ if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
+ return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
+ })) {
+ // All mesh dimensions are dynamic. Nothing to fold.
+ return failure();
+ }
+
+ SmallVector<Value> newResults(op->getResults().size());
+ SmallVector<MeshAxis> newShapeOpMeshAxes;
+ SmallVector<size_t> newToOldResultsIndexMap;
+
+ for (size_t i = 0; i < opMeshAxes.size(); ++i) {
+ auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
+ if (ShapedType::isDynamic(meshAxisSize)) {
+ newToOldResultsIndexMap.push_back(i);
+ newShapeOpMeshAxes.push_back(opMeshAxes[i]);
+ } else {
+ // Fold static mesh axes.
+ newResults[i] = builder.create<arith::ConstantOp>(
+ builder.getIndexAttr(meshAxisSize));
+ }
+ }
+
+ // Leave only the dynamic mesh axes to be queried.
+ ClusterShapeOp newShapeOp =
+ builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
+ for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
+ newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
+ }
+
+ rewriter.replaceAllUsesWith(op.getResults(), newResults);
+
+ return success();
+ }
+
+private:
+ SymbolTableCollection &symbolTable;
+};
+
+} // namespace
+
+void populateFoldingPatterns(RewritePatternSet &patterns,
+ SymbolTableCollection &symbolTable) {
+ patterns.add<ClusterShapeFolder>(symbolTable, patterns.getContext());
+}
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
new file mode 100644
index 00000000000000..1283353709ca3c
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
+mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+
+// CHECK-LABEL: func.func @cluster_shape_op_folding
+func.func @cluster_shape_op_folding() -> (index, index) {
+ // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
+ %0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
+ // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
+func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
+ // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
+ %0:2 = mesh.cluster_shape @mesh1 : index, index
+ // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index f14d282857a1e0..3da64694ee2155 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
-add_mlir_library(MLIRMeshTestSimplifications
+add_mlir_library(MLIRMeshTest
+ TestFolding.cpp
TestReshardingSpmdization.cpp
TestSimplifications.cpp
diff --git a/mlir/test/lib/Dialect/Mesh/TestFolding.cpp b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
new file mode 100644
index 00000000000000..1cf436edea8e35
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestFolding.cpp
@@ -0,0 +1,52 @@
+//===- TestSimplification.cpp - Test simplification -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <memory>
+
+using namespace mlir;
+
+namespace {
+
+struct TestMeshFoldingPass
+ : public PassWrapper<TestMeshFoldingPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshFoldingPass)
+
+ void runOnOperation() override;
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<mesh::MeshDialect>();
+ }
+ StringRef getArgument() const final { return "test-mesh-folding"; }
+ StringRef getDescription() const final { return "Test mesh folding."; }
+};
+} // namespace
+
+void TestMeshFoldingPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ SymbolTableCollection symbolTables;
+ mesh::populateFoldingPatterns(patterns, symbolTables);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+ getOperation()->emitError()
+ << "Rewrite patter application did not converge.";
+ return signalPassFailure();
+ }
+}
+
+namespace mlir {
+namespace test {
+void registerTestMeshFoldingPass() { PassRegistration<TestMeshFoldingPass>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index b6ada66d321880..a5da9390a0c5b3 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRLoopLikeInterfaceTestPasses
MLIRMathTestPasses
MLIRMemRefTestPasses
- MLIRMeshTestSimplifications
+ MLIRMeshTest
MLIRNVGPUTestPasses
MLIRSCFTestPasses
MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f7a5b3183b50b1..461163f671ce89 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
+void registerTestMeshFoldingPass();
void registerTestMeshSimplificationsPass();
void registerTestMeshReshardingSpmdizationPass();
void registerTestNextAccessPass();
@@ -237,6 +238,7 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
+ mlir::test::registerTestMeshFoldingPass();
mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();
|
@yaochengji, could you review this PR? |
LGTM, thanks |
If the mesh has static size on some of the requested axes, the result is substituted with a constant.
30fe35c
to
aad2791
Compare
I rebased before merging. |
If the mesh has static size on some of the requested axes, the result is substituted with a constant.
If the mesh has static size on some of the requested axes, the result is substituted with a constant.