Skip to content

Commit

Permalink
[mlir][mesh] Add folding of ClusterShapeOp (llvm#77033)
Browse files Browse the repository at this point in the history
If the mesh has static size on some of the requested axes, the result is
substituted with a constant.
  • Loading branch information
sogartar authored and justinfargnoli committed Jan 28, 2024
1 parent f728ce3 commit 7c49140
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 6 deletions.
10 changes: 9 additions & 1 deletion mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
Expand Up @@ -19,6 +19,9 @@
#include <utility>

namespace mlir {

class SymbolTableCollection;

namespace mesh {

// If we have an algebraic op like "+" and a summing all-reduce,
Expand Down Expand Up @@ -102,7 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
AlgebraicOp::getOperationName(), 1, patterns.getContext()));
}

void populateSimplificationPatterns(RewritePatternSet &patterns);
// Adds the mesh simplification patterns to `patterns`, including the folding
// patterns added by populateFoldingPatterns.
// It is invalid to change ops that declare symbols during the application of
// these patterns, because symbolTableCollection is used to cache them.
void populateSimplificationPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
// Adds the mesh folding patterns to `patterns`. The restriction on ops that
// declare symbols applies here as well, since the folding patterns resolve
// symbols through symbolTableCollection.
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection);

} // namespace mesh
} // namespace mlir
Expand Down
93 changes: 92 additions & 1 deletion mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
Expand Up @@ -8,11 +8,23 @@

#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <iterator>
#include <numeric>
#include <utility>

namespace mlir {
namespace mesh {

void populateSimplificationPatterns(RewritePatternSet &patterns) {
void populateSimplificationPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
patterns, Partial::Sum);
populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
Expand All @@ -33,6 +45,85 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
patterns, Partial::Max);

// TODO: add simplifications for all-gather and other collectives.

populateFoldingPatterns(patterns, symbolTableCollection);
}

namespace {

// Folds the results of a ClusterShapeOp that correspond to statically sized
// mesh axes into arith.constant ops; only the dynamic axes stay queried.
//
// This folding can not be done with an operation's fold method or
// DialectFoldInterface: it needs a SymbolTableCollection to cache the symbol
// tables, and that cache may be invalidated by some pass changing the
// referenced ClusterOp ops. As a rewrite pattern, the cache is owned by the
// caller that applies the patterns.
struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
  // `symbolTableCollection` is held by reference and must outlive the pattern.
  // All remaining arguments are forwarded to the OpRewritePattern constructor.
  template <typename... OpRewritePatternArgs>
  ClusterShapeFolder(SymbolTableCollection &symbolTableCollection,
                     OpRewritePatternArgs &&...opRewritePatternArgs)
      : OpRewritePattern(
            // Forward every argument with its own deduced type. Expanding the
            // whole pack inside std::forward<...> only compiles when exactly
            // one argument is passed.
            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
        symbolTableCollection(symbolTableCollection) {}

  // Returns failure when the referenced mesh can not be resolved or when every
  // requested axis is dynamic. Otherwise replaces `op` with constants for the
  // static axes and, if any remain, a new ClusterShapeOp over the dynamic axes.
  LogicalResult matchAndRewrite(ClusterShapeOp op,
                                PatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
    ClusterOp mesh =
        symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
            op.getOperation(), op.getMeshAttr());
    if (!mesh) {
      return failure();
    }

    // An empty axes list stands for all mesh axes: materialize 0..rank-1.
    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
    SmallVector<MeshAxis> opAxesIota;
    if (opMeshAxes.empty()) {
      opAxesIota.resize(mesh.getRank());
      std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
      opMeshAxes = opAxesIota;
    }
    if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
          return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
        })) {
      // All mesh dimensions are dynamic. Nothing to fold.
      return failure();
    }

    SmallVector<Value> newResults(op->getResults().size());
    SmallVector<MeshAxis> newShapeOpMeshAxes;
    // Maps a result index of the new shape op to the result index of `op`.
    SmallVector<size_t> newToOldResultsIndexMap;

    for (size_t i = 0; i < opMeshAxes.size(); ++i) {
      auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
      if (ShapedType::isDynamic(meshAxisSize)) {
        newToOldResultsIndexMap.push_back(i);
        newShapeOpMeshAxes.push_back(opMeshAxes[i]);
      } else {
        // Fold static mesh axes.
        newResults[i] = builder.create<arith::ConstantOp>(
            builder.getIndexAttr(meshAxisSize));
      }
    }

    // Leave only the dynamic mesh axes to be queried. When every requested
    // axis was static there is nothing left to query; building a shape op with
    // an empty axes list would be wrong, because empty means "all axes".
    if (!newShapeOpMeshAxes.empty()) {
      ClusterShapeOp newShapeOp = builder.create<ClusterShapeOp>(
          mesh.getSymName(), newShapeOpMeshAxes);
      for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
        newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
      }
    }

    // Replace the op itself (not just its uses) so the matched op is erased
    // and can not be matched again.
    rewriter.replaceOp(op, newResults);

    return success();
  }

private:
  SymbolTableCollection &symbolTableCollection;
};

} // namespace

// Registers the mesh folding patterns (currently ClusterShapeFolder) in
// `patterns`. The registered patterns keep a reference to
// `symbolTableCollection`, so it has to stay alive while they are in use.
void populateFoldingPatterns(RewritePatternSet &patterns,
                             SymbolTableCollection &symbolTableCollection) {
  auto *context = patterns.getContext();
  patterns.add<ClusterShapeFolder>(symbolTableCollection, context);
}

} // namespace mesh
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Dialect/Mesh/folding.mlir
@@ -0,0 +1,22 @@
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s

mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)

// Queries axes [2, 1] of @mesh0 (dim_sizes = 4x?x2). Axis 2 has static size 2
// and is expected to fold to a constant; axis 1 is dynamic and is expected to
// remain a cluster_shape query restricted to that single axis.
// CHECK-LABEL: func.func @cluster_shape_op_folding
func.func @cluster_shape_op_folding() -> (index, index) {
// CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
// CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
%0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
// CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
return %0#0, %0#1 : index, index
}

// Queries @mesh1 (dim_sizes = 2x3) without an axes list, i.e. all of its axes.
// Both axes are static, so both results are expected to fold to constants.
// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
// CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
// CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
%0:2 = mesh.cluster_shape @mesh1 : index, index
// CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
return %0#0, %0#1 : index, index
}
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,5 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMeshTestSimplifications
add_mlir_library(MLIRMeshTest
TestReshardingSpmdization.cpp
TestSimplifications.cpp

Expand Down
8 changes: 6 additions & 2 deletions mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand All @@ -30,8 +31,11 @@ struct TestMeshSimplificationsPass

// Greedily applies the mesh simplification patterns (which include the
// folding patterns) to the operation this pass runs on.
void TestMeshSimplificationsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  // The patterns look symbols up through this collection, so it has to stay
  // alive until the patterns have been applied.
  SymbolTableCollection symbolTableCollection;
  mesh::populateSimplificationPatterns(patterns, symbolTableCollection);
  // `status` is only read by the assert below; [[maybe_unused]] keeps builds
  // with assertions disabled free of unused-variable warnings.
  [[maybe_unused]] LogicalResult status =
      applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  assert(succeeded(status) && "Rewrite patterns application did not converge.");
}

namespace mlir {
Expand Down
2 changes: 1 addition & 1 deletion mlir/tools/mlir-opt/CMakeLists.txt
Expand Up @@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRLoopLikeInterfaceTestPasses
MLIRMathTestPasses
MLIRMemRefTestPasses
MLIRMeshTestSimplifications
MLIRMeshTest
MLIRNVGPUTestPasses
MLIRSCFTestPasses
MLIRShapeTestPasses
Expand Down

0 comments on commit 7c49140

Please sign in to comment.