-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][mesh] Add endomorphism simplification for all-reduce #73150
[mlir][mesh] Add endomorphism simplification for all-reduce #73150
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Boian Petkantchin (sogartar) ChangesDoes transformations like
Patch is 24.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73150.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae48..9b788d3f304c2c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -17,6 +17,8 @@ namespace func {
class FuncOp;
}
+class RewritePatternSet;
+
namespace mesh {
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
new file mode 100644
index 000000000000000..1af0f52114f10e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -0,0 +1,89 @@
+//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/EndomorphismSimplification.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+template <typename AlgebraicOp>
+void populateAllReduceEndomorphismSimplificationPatterns(
+ RewritePatternSet &patterns, Partial reduction) {
+ auto getEndomorphismOpOperand = [](Operation *op) {
+ auto allReduceOp = llvm::cast<AllReduceOp>(op);
+ return &allReduceOp.getInputMutable();
+ };
+ auto getEndomorphismOpResult = [](Operation *op) {
+ auto allReduceOp = llvm::cast<AllReduceOp>(op);
+ return allReduceOp->getResult(0);
+ };
+ auto getAlgebraicOpOperands = [](Operation *op,
+ SmallVector<OpOperand *> &operands) {
+ auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+ std::transform(algebraicOp->getOpOperands().begin(),
+ algebraicOp->getOpOperands().end(),
+ std::back_inserter(operands),
+ [](OpOperand &operand) { return &operand; });
+ };
+ auto getAlgebraicOpResult = [](Operation *op) {
+ auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+ return algebraicOp->getResult(0);
+ };
+ auto isEndomorphismOp = [reduction](Operation *op,
+ std::optional<Operation *> referenceOp) {
+ auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
+ if (!allReduceOp ||
+ allReduceOp.getInput().getType().getElementType() !=
+ allReduceOp.getResult().getType().getElementType() ||
+ allReduceOp.getReduction() != reduction) {
+ return false;
+ }
+
+ if (!referenceOp) {
+ return true;
+ }
+
+ auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
+ return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
+ allReduceOp.getInput().getType().getElementType() ==
+ refAllReduceOp.getInput().getType().getElementType();
+ };
+ auto isAlgebraicOp = [](Operation *op) {
+ return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
+ };
+
+ using ConcreteEndomorphismSimplification = EndomorphismSimplification<
+ std::decay_t<decltype(getEndomorphismOpOperand)>,
+ std::decay_t<decltype(getEndomorphismOpResult)>,
+ std::decay_t<decltype(getAlgebraicOpOperands)>,
+ std::decay_t<decltype(getAlgebraicOpResult)>,
+ std::decay_t<decltype(isEndomorphismOp)>,
+ std::decay_t<decltype(isAlgebraicOp)>>;
+ patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
+ std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
+ std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
+ std::move(isEndomorphismOp), std::move(isAlgebraicOp),
+ AlgebraicOp::getOperationName(), 1, patterns.getContext()));
+}
+
+void populateSimplificationPatterns(RewritePatternSet &patterns);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
diff --git a/mlir/include/mlir/Transforms/EndomorphismSimplification.h b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
new file mode 100644
index 000000000000000..b2bc7377a80fff3
--- /dev/null
+++ b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
@@ -0,0 +1,163 @@
+//===- RegionUtils.h - Region-related transformation utilities --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+#define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include <iterator>
+#include <optional>
+#include <type_traits>
+#include <utility>
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+
+// If `f` is an endomorphism with respect to the algebraic structure induced by
+// function `g`, transforms `g(f(x1), f(x2) ..., f(xn))` into
+// `f(g(x1, x2, ..., xn))`.
+// `g` is the algebraic operation and `f` is the endomorphism.
+//
+// Functors:
+// ---------
+// `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
+// Returns the operand relevant to the endomorphism.
+// There may be other operands that are not relevant.
+//
+// `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the endomorphism.
+//
+// `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
+// Populates into the vector the operands relevant to the endomorphism.
+//
+// `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+// Return the result relevant to the endomorphism.
+//
+// `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
+// Check if the operation is an endomorphism of the required type.
+// Additionally if the optional is present checks if the operations are
+// compatible endomorphisms.
+//
+// `IsAlgebraicOpFn`: `(Operation*) -> bool`
+// Check if the operation is an operation of the algebraic structure.
+template <typename GetEndomorphismOpOperandFn,
+ typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn,
+ typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn,
+ typename IsAlgebraicOpFn>
+struct EndomorphismSimplification : RewritePattern {
+ template <typename GetEndomorphismOpOperandFnArg,
+ typename GetEndomorphismOpResultFnArg,
+ typename GetAlgebraicOpOperandsFnArg,
+ typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg,
+ typename IsAlgebraicOpFnArg, typename... RewritePatternArgs>
+ EndomorphismSimplification(
+ GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand,
+ GetEndomorphismOpResultFnArg &&getEndomorphismOpResult,
+ GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands,
+ GetAlgebraicOpResultFnArg &&getAlgebraicOpResult,
+ IsEndomorphismOpFnArg &&isEndomorphismOp,
+ IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args)
+ : RewritePattern(std::forward<RewritePatternArgs>(args)...),
+ getEndomorphismOpOperand(std::forward<GetEndomorphismOpOperandFnArg>(
+ getEndomorphismOpOperand)),
+ getEndomorphismOpResult(std::forward<GetEndomorphismOpResultFnArg>(
+ getEndomorphismOpResult)),
+ getAlgebraicOpOperands(
+ std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands)),
+ getAlgebraicOpResult(
+ std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult)),
+ isEndomorphismOp(std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp)),
+ isAlgebraicOp(std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp)) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ if (failed(matchOp(op, algebraicOpOperands))) {
+ return failure();
+ }
+ return rewriteOp(op, algebraicOpOperands, rewriter);
+ }
+
+private:
+ LogicalResult matchOp(Operation *algebraicOp,
+ SmallVector<OpOperand *> &algebraicOpOperands) const {
+ if (!isAlgebraicOp(algebraicOp)) {
+ return failure();
+ }
+ algebraicOpOperands.clear();
+ getAlgebraicOpOperands(algebraicOp, algebraicOpOperands);
+ if (algebraicOpOperands.empty()) {
+ return failure();
+ }
+
+ Operation *firstEndomorphismOp =
+ algebraicOpOperands.front()->get().getDefiningOp();
+ if (!firstEndomorphismOp ||
+ !isEndomorphismOp(firstEndomorphismOp, std::nullopt)) {
+ return failure();
+ }
+ OpResult firstEndomorphismOpResult =
+ getEndomorphismOpResult(firstEndomorphismOp);
+ if (getEndomorphismOpResult(firstEndomorphismOp) !=
+ algebraicOpOperands.front()->get()) {
+ return failure();
+ }
+
+ for (auto operand : algebraicOpOperands) {
+ Operation *endomorphismOp = operand->get().getDefiningOp();
+ if (!endomorphismOp ||
+ !isEndomorphismOp(endomorphismOp, firstEndomorphismOp)) {
+ return failure();
+ }
+ }
+ return success();
+ }
+
+ LogicalResult rewriteOp(Operation *algebraicOp,
+ const SmallVector<OpOperand *> &algebraicOpOperands,
+ PatternRewriter &rewriter) const {
+ irMapping.clear();
+ for (auto operand : algebraicOpOperands) {
+ Operation *endomorphismOp = operand->get().getDefiningOp();
+ irMapping.map(operand->get(),
+ getEndomorphismOpOperand(endomorphismOp)->get());
+ }
+ Operation *newAlgebraicOp = rewriter.clone(*algebraicOp, irMapping);
+
+ irMapping.clear();
+ assert(!algebraicOpOperands.empty());
+ Operation *firstEndomorphismOp =
+ algebraicOpOperands[0]->get().getDefiningOp();
+ irMapping.map(getEndomorphismOpOperand(firstEndomorphismOp)->get(),
+ getAlgebraicOpResult(newAlgebraicOp));
+ Operation *newEndomorphismOp =
+ rewriter.clone(*firstEndomorphismOp, irMapping);
+ rewriter.replaceAllUsesWith(getAlgebraicOpResult(algebraicOp),
+ getEndomorphismOpResult(newEndomorphismOp));
+ return success();
+ }
+
+ GetEndomorphismOpOperandFn getEndomorphismOpOperand;
+ GetEndomorphismOpResultFn getEndomorphismOpResult;
+ GetAlgebraicOpOperandsFn getAlgebraicOpOperands;
+ GetAlgebraicOpResultFn getAlgebraicOpResult;
+ IsEndomorphismOpFn isEndomorphismOp;
+ IsAlgebraicOpFn isAlgebraicOp;
+ mutable SmallVector<OpOperand *> algebraicOpOperands;
+ mutable IRMapping irMapping;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index bcf45c4ea276080..044b8672c8c60cf 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRMeshTransforms
+ Simplifications.cpp
ShardingPropagation.cpp
ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
MLIRShardingInterface
LINK_LIBS PUBLIC
+ MLIRArithDialect
MLIRFuncDialect
MLIRIR
MLIRMeshDialect
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
new file mode 100644
index 000000000000000..1d241fe03a127b8
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -0,0 +1,37 @@
+//===- Patterns.cpp - Mesh Patterns -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+namespace mlir {
+namespace mesh {
+
+void populateSimplificationPatterns(RewritePatternSet &patterns) {
+ populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
+ patterns, Partial::Sum);
+ populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
+ patterns, Partial::Sum);
+
+ populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
+ patterns, Partial::Min);
+ populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
+ patterns, Partial::Min);
+ populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
+ patterns, Partial::Min);
+
+ populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
+ patterns, Partial::Max);
+ populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
+ patterns, Partial::Max);
+ populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
+ patterns, Partial::Max);
+}
+
+} // namespace mesh
+} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
new file mode 100644
index 000000000000000..2b305df6e0a97f1
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [4, 2])
+mesh.cluster @mesh1(rank = 1, dim_sizes = [4])
+
+// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
+// `all_reduce(x + y)`.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
+func.func @all_reduce_arith_addf_endomorphism(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+ // CHECK: return %[[ALL_REDUCE_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
+ %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
+func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = <max>
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
+func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf64> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf64>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf64>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf64>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf64>
+}
+
+// Checks that `min(all_reduce(x), all_reduce(y))` gets transformed to
+// `all_reduce(min(x, y))`.
+// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
+func.func @all_reduce_arith_minimumf_endomorphism(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+ : tensor<5xf32> -> tensor<5xf32>
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
+ %2 = arith.minimumf %0, %1 : tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+ // CHECK: return %[[ALL_REDUCE_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
+func.func @all_reduce_arith_minsi_endomorphism(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xi32>
+ %arg0: tensor<5xi32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
+ %arg1: tensor<5xi32>) -> tensor<5xi32> {
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+ : tensor<5xi32> -> tensor<5xi32>
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+ : tensor<5xi32> -> tensor<5xi32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
+ %2 = arith.minsi %0, %1 : tensor<5xi32>
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+ // CHECK: return %[[ALL_REDUCE_RES]]
+ return %2 : ten...
[truncated]
|
@yaochengji could you review this? |
363f6e5
to
8df8ae3
Compare
@@ -17,6 +17,8 @@ namespace func { | |||
class FuncOp; | |||
} | |||
|
|||
class RewritePatternSet; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems the forward declaration is not needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed it.
void populateAllReduceEndomorphismSimplificationPatterns( | ||
RewritePatternSet &patterns, Partial reduction) { | ||
auto getEndomorphismOpOperand = [](Operation *op) { | ||
auto allReduceOp = llvm::cast<AllReduceOp>(op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems all_gather
could also benefit from this. I suggest we could mark it as a TODO here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a comment in populateSimplificationPatterns
.
}; | ||
auto getAlgebraicOpOperands = [](Operation *op, | ||
SmallVector<OpOperand *> &operands) { | ||
auto algebraicOp = llvm::cast<AlgebraicOp>(op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if AlgebraicOp is a destination style op? I think the logic should be different, but we can add a TODO here first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean a scenario like this?
h
is the endomorphism and a
is the algebraic structure op.
Then
a(h(x), h(y), z) = h(a(x, y, z))
does not actually hold.
If that is the case you may be able to do homomorphism simplification where the target algebraic structure op b
is different. And we have
a(h(x), h(y), z) = h(b(x, y, z))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, what I meant is ops in linalg dialect which have the DestinationStyleOpInterface
. For a linalg.matmul op
%matmul = linalg.matmul ins(%0, %1 : tensor<1x1xi8>, tensor<1x?xi8>) outs(%2 : tensor<1x?xi32>) -> tensor<1x?xi32>
Here the %2
operand is different from %0
and %1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I expanded the documentation of the function.
@@ -0,0 +1,131 @@ | |||
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happened if the all_reduce
has multi users?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pass does not remove code it just adds. My expectation is to fix the side-effecting nature of the collectives and allow for DCE to remove the unused code. I think in the context of SPMD we ca reason that all_reduce
is not side-effecting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sogartar seems the multi-users case is not addressed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean not to do the simplification in that case since there may not be a performance benefit?
A more complete approach would be if somehow we knew if the other use will be also removed.
Maybe we can have another rewrite pattern that reverses the endomorphism simplification if the other use is still alive. Then we apply it in another pass to avoid infinite loop of rewrites.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The other option is to avoid doing the simplification in the first place, but how would we know in a generic way if the other use will be removed or not?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it's a complex problem. A common approach is to leave it unsimplified. XLA has a pass with similar logic, maybe we can check its logic there
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made it not do the simplification in such a scenario.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lol. Didn't think I was starting a review, so my comments were not submitted.
// CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] | ||
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] | ||
: tensor<5xf32> -> tensor<5xf32> | ||
// CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This simplification does not remove dead code. The previous all_reduce
s will still be there. They will not be used though.
We need to decide on the side-effecting nature of the collective operations. Whether we are in the context of SPMD or MPMD?
@@ -0,0 +1,131 @@ | |||
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pass does not remove code it just adds. My expectation is to fix the side-effecting nature of the collectives and allow for DCE to remove the unused code. I think in the context of SPMD we ca reason that all_reduce
is not side-effecting.
}; | ||
auto getAlgebraicOpOperands = [](Operation *op, | ||
SmallVector<OpOperand *> &operands) { | ||
auto algebraicOp = llvm::cast<AlgebraicOp>(op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean a scenario like this?
h
is the endomorphism and a
is the algebraic structure op.
Then
a(h(x), h(y), z) = h(a(x, y, z))
does not actually hold.
If that is the case you may be able to do homomorphism simplification where the target algebraic structure op b
is different. And we have
a(h(x), h(y), z) = h(b(x, y, z))
void populateAllReduceEndomorphismSimplificationPatterns( | ||
RewritePatternSet &patterns, Partial reduction) { | ||
auto getEndomorphismOpOperand = [](Operation *op) { | ||
auto allReduceOp = llvm::cast<AllReduceOp>(op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a comment in populateSimplificationPatterns
.
@@ -17,6 +17,8 @@ namespace func { | |||
class FuncOp; | |||
} | |||
|
|||
class RewritePatternSet; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed it.
@@ -0,0 +1,131 @@ | |||
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean not to do the simplification in that case since there may not be a performance benefit?
A more complete approach would be if somehow we knew if the other use will be also removed.
Maybe we can have another rewrite pattern that reverses the endomorphism simplification if the other use is still alive. Then we apply it in another pass to avoid infinite loop of rewrites.
@@ -0,0 +1,131 @@ | |||
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The other option is to avoid doing the simplification in the first place, but how would we know in a generic way if the other use will be removed or not?
@yaochengji could you take a look again? |
LGTM, thanks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, thanks
Does transformations like all_reduce(x) + all_reduce(y) -> all_reduce(x + y) max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y)) when the all_reduce element-wise op is max.
… context of the endomorphism
…rphismSimplification rewrite pattern
53b0d2e
to
fab21a1
Compare
Thank you for the review I rebased it before merging to check that all is well. |
Does transformations like
all_reduce(x) + all_reduce(y) -> all_reduce(x + y)
max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))
when the
all_reduce
element-wise op ismax
.In this PR I added a general rewrite pattern
EndomorphismSimplification
where I tried to isolate the general rewrite algorithm. I can split this in 2 PRs.I see a path to generalize this to homomorphisms to allow for generalizing patterns like
exp(x)*exp(y) -> exp(x + y)
,but have not done it in this PR.
Update:
I made the endomorphism rewrite pattern derive from homomorphism rewrite pattern.