Skip to content

Commit

Permalink
[mlir][mesh] Add folding of ClusterShapeOp
Browse files Browse the repository at this point in the history
If the mesh has static size on some of the requested axes,
the result is substituted with a constant.
  • Loading branch information
sogartar committed Jan 9, 2024
1 parent 03a0bfa commit 3706d6f
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 2 deletions.
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
#include <utility>

namespace mlir {

class SymbolTable;

namespace mesh {

// If we have an algebraic op like "+" and a summing all-reduce,
Expand Down Expand Up @@ -103,6 +106,10 @@ void populateAllReduceEndomorphismSimplificationPatterns(
}

void populateSimplificationPatterns(RewritePatternSet &patterns);
// It is invalid to change ops that declare symbols during the application of
// these patterns, because symbolTable is used to cache them.
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTable);

} // namespace mesh
} // namespace mlir
Expand Down
86 changes: 86 additions & 0 deletions mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@

#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <iterator>
#include <numeric>
#include <utility>

namespace mlir {
namespace mesh {
Expand Down Expand Up @@ -35,5 +46,80 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
// TODO: add simplifications for all-gather and other collectives.
}

namespace {

// This folding can not be done with an operation's fold method or
// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
// symbol tables.
// We can't use DialectFoldInterface since the cache may be invalidated by some
// pass changing the referenced ClusterOp ops.
struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
template <typename... OpRewritePatternArgs>
ClusterShapeFolder(SymbolTableCollection &symbolTable,
OpRewritePatternArgs &&...opRewritePatternArgs)
: OpRewritePattern(
std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
symbolTable(symbolTable) {}
LogicalResult matchAndRewrite(ClusterShapeOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
op.getOperation(), op.getMeshAttr());
if (!mesh) {
return failure();
}
ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
SmallVector<MeshAxis> opAxesIota;
if (opMeshAxes.empty()) {
opAxesIota.resize(mesh.getRank());
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
opMeshAxes = opAxesIota;
}
if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
})) {
// All mesh dimensions are dynamic. Nothing to fold.
return failure();
}

SmallVector<Value> newResults(op->getResults().size());
SmallVector<MeshAxis> newShapeOpMeshAxes;
SmallVector<size_t> newToOldResultsIndexMap;

for (size_t i = 0; i < opMeshAxes.size(); ++i) {
auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
if (ShapedType::isDynamic(meshAxisSize)) {
newToOldResultsIndexMap.push_back(i);
newShapeOpMeshAxes.push_back(opMeshAxes[i]);
} else {
// Fold static mesh axes.
newResults[i] = builder.create<arith::ConstantOp>(
builder.getIndexAttr(meshAxisSize));
}
}

// Leave only the dynamic mesh axes to be queried.
ClusterShapeOp newShapeOp =
builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
}

rewriter.replaceAllUsesWith(op.getResults(), newResults);

return success();
}

private:
SymbolTableCollection &symbolTable;
};

} // namespace

void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTable) {
patterns.add<ClusterShapeFolder>(symbolTable, patterns.getContext());
}

} // namespace mesh
} // namespace mlir
22 changes: 22 additions & 0 deletions mlir/test/Dialect/Mesh/folding.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s

mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)

// CHECK-LABEL: func.func @cluster_shape_op_folding
func.func @cluster_shape_op_folding() -> (index, index) {
// CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
// CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
%0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
// CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
return %0#0, %0#1 : index, index
}

// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
// CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
// CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
%0:2 = mesh.cluster_shape @mesh1 : index, index
// CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
return %0#0, %0#1 : index, index
}
3 changes: 2 additions & 1 deletion mlir/test/lib/Dialect/Mesh/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMeshTestSimplifications
add_mlir_library(MLIRMeshTest
TestFolding.cpp
TestReshardingSpmdization.cpp
TestSimplifications.cpp

Expand Down
52 changes: 52 additions & 0 deletions mlir/test/lib/Dialect/Mesh/TestFolding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===- TestSimplification.cpp - Test simplification -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>

using namespace mlir;

namespace {

struct TestMeshFoldingPass
: public PassWrapper<TestMeshFoldingPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshFoldingPass)

void runOnOperation() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mesh::MeshDialect>();
}
StringRef getArgument() const final { return "test-mesh-folding"; }
StringRef getDescription() const final { return "Test mesh folding."; }
};
} // namespace

void TestMeshFoldingPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTables;
mesh::populateFoldingPatterns(patterns, symbolTables);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
getOperation()->emitError()
<< "Rewrite patter application did not converge.";
return signalPassFailure();
}
}

namespace mlir {
namespace test {
void registerTestMeshFoldingPass() { PassRegistration<TestMeshFoldingPass>(); }
} // namespace test
} // namespace mlir
2 changes: 1 addition & 1 deletion mlir/tools/mlir-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRLoopLikeInterfaceTestPasses
MLIRMathTestPasses
MLIRMemRefTestPasses
MLIRMeshTestSimplifications
MLIRMeshTest
MLIRNVGPUTestPasses
MLIRSCFTestPasses
MLIRShapeTestPasses
Expand Down
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMeshFoldingPass();
void registerTestMeshSimplificationsPass();
void registerTestMeshReshardingSpmdizationPass();
void registerTestNextAccessPass();
Expand Down Expand Up @@ -240,6 +241,7 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
mlir::test::registerTestMeshFoldingPass();
mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();
Expand Down

0 comments on commit 3706d6f

Please sign in to comment.