Skip to content

Commit

Permalink
Add folding patterns to all simplification patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
sogartar committed Jan 9, 2024
1 parent 3706d6f commit aad2791
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 70 deletions.
9 changes: 5 additions & 4 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
Expand Up @@ -20,7 +20,7 @@

namespace mlir {

class SymbolTable;
class SymbolTableCollection;

namespace mesh {

Expand Down Expand Up @@ -105,11 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
AlgebraicOp::getOperationName(), 1, patterns.getContext()));
}

void populateSimplificationPatterns(RewritePatternSet &patterns);
// It is invalid to change ops that declare symbols during the application of
// these patterns, because symbolTable is used to cache them.
// these patterns, because symbolTableCollection is used to cache them.
void populateSimplificationPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTable);
SymbolTableCollection &symbolTableCollection);

} // namespace mesh
} // namespace mlir
Expand Down
21 changes: 13 additions & 8 deletions mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
Expand Up @@ -23,7 +23,8 @@
namespace mlir {
namespace mesh {

void populateSimplificationPatterns(RewritePatternSet &patterns) {
void populateSimplificationPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
patterns, Partial::Sum);
populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
Expand All @@ -44,6 +45,8 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
patterns, Partial::Max);

// TODO: add simplifications for all-gather and other collectives.

populateFoldingPatterns(patterns, symbolTableCollection);
}

namespace {
Expand All @@ -55,16 +58,17 @@ namespace {
// pass changing the referenced ClusterOp ops.
struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
template <typename... OpRewritePatternArgs>
ClusterShapeFolder(SymbolTableCollection &symbolTable,
ClusterShapeFolder(SymbolTableCollection &symbolTableCollection,
OpRewritePatternArgs &&...opRewritePatternArgs)
: OpRewritePattern(
std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
symbolTable(symbolTable) {}
symbolTableCollection(symbolTableCollection) {}
LogicalResult matchAndRewrite(ClusterShapeOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
op.getOperation(), op.getMeshAttr());
ClusterOp mesh =
symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
op.getOperation(), op.getMeshAttr());
if (!mesh) {
return failure();
}
Expand Down Expand Up @@ -111,14 +115,15 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
}

private:
SymbolTableCollection &symbolTable;
SymbolTableCollection &symbolTableCollection;
};

} // namespace

void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTable) {
patterns.add<ClusterShapeFolder>(symbolTable, patterns.getContext());
SymbolTableCollection &symbolTableCollection) {
patterns.add<ClusterShapeFolder>(symbolTableCollection,
patterns.getContext());
}

} // namespace mesh
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Mesh/folding.mlir
@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s

mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
Expand Down
1 change: 0 additions & 1 deletion mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,6 +1,5 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMeshTest
TestFolding.cpp
TestReshardingSpmdization.cpp
TestSimplifications.cpp

Expand Down
52 changes: 0 additions & 52 deletions mlir/test/lib/Dialect/Mesh/TestFolding.cpp

This file was deleted.

8 changes: 6 additions & 2 deletions mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand All @@ -30,8 +31,11 @@ struct TestMeshSimplificationsPass

void TestMeshSimplificationsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
mesh::populateSimplificationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
SymbolTableCollection symbolTableCollection;
mesh::populateSimplificationPatterns(patterns, symbolTableCollection);
LogicalResult status =
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
assert(succeeded(status) && "Rewrite patters application did not converge.");
}

namespace mlir {
Expand Down
2 changes: 0 additions & 2 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Expand Up @@ -118,7 +118,6 @@ void registerTestMathAlgebraicSimplificationPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMeshFoldingPass();
void registerTestMeshSimplificationsPass();
void registerTestMeshReshardingSpmdizationPass();
void registerTestNextAccessPass();
Expand Down Expand Up @@ -241,7 +240,6 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
mlir::test::registerTestMeshFoldingPass();
mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();
Expand Down

0 comments on commit aad2791

Please sign in to comment.