[mlir][mesh] Add endomorphism simplification for all-reduce #73150

Merged (11 commits) on Dec 12, 2023

@sogartar sogartar commented Nov 22, 2023

Adds transformations such as
all_reduce(x) + all_reduce(y) -> all_reduce(x + y)

max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))
when the all_reduce reduction is max.

In this PR I added a general rewrite pattern, EndomorphismSimplification, where I tried to isolate the general rewrite algorithm. I can split this into 2 PRs.
I see a path to generalize this to homomorphisms, which would cover patterns like
exp(x)*exp(y) -> exp(x + y),
but I have not done it in this PR.
Update:
I made the endomorphism rewrite pattern derive from the homomorphism rewrite pattern.
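
To make the rewrite condition concrete, here is a minimal, self-contained sketch that models all_reduce on plain integer vectors. It is only a toy model and not the MLIR implementation: `Shards`, `allReduce`, `elementwise` and `commutes` are hypothetical names introduced for illustration. It checks whether applying the algebraic op after the reductions gives the same result as reducing once after the algebraic op.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// One tensor shard per device; all shards have the same length.
using Shards = std::vector<std::vector<int>>;
using BinaryOp = std::function<int(int, int)>;

// Toy all_reduce: combines the shards element-wise with `reduction` and
// hands every device a copy of the combined result.
Shards allReduce(const Shards &in, const BinaryOp &reduction) {
  assert(!in.empty() && "need at least one device");
  std::vector<int> acc = in.front();
  for (std::size_t d = 1; d < in.size(); ++d)
    for (std::size_t i = 0; i < acc.size(); ++i)
      acc[i] = reduction(acc[i], in[d][i]);
  return Shards(in.size(), acc);
}

// Element-wise binary op applied device by device (stands in for an
// arith op such as addi or maxsi).
Shards elementwise(const Shards &a, const Shards &b, const BinaryOp &op) {
  Shards out = a;
  for (std::size_t d = 0; d < a.size(); ++d)
    for (std::size_t i = 0; i < a[d].size(); ++i)
      out[d][i] = op(a[d][i], b[d][i]);
  return out;
}

// True if op(all_reduce(x), all_reduce(y)) == all_reduce(op(x, y))
// for these particular inputs.
bool commutes(const Shards &x, const Shards &y, const BinaryOp &algebraicOp,
              const BinaryOp &reduction) {
  return elementwise(allReduce(x, reduction), allReduce(y, reduction),
                     algebraicOp) ==
         allReduce(elementwise(x, y, algebraicOp), reduction);
}
```

In this model, addition paired with a sum reduction and max paired with a max reduction both commute, while addition paired with a max reduction generally does not. That matches the requirement that the algebraic op and the reduction kind agree. Checking a few inputs demonstrates the identity but does not prove it.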

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Nov 22, 2023
llvmbot commented Nov 22, 2023

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

Adds transformations such as
all_reduce(x) + all_reduce(y) -> all_reduce(x + y)

max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))
when the all_reduce reduction is max.


Patch is 24.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73150.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h (+2)
  • (added) mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h (+89)
  • (added) mlir/include/mlir/Transforms/EndomorphismSimplification.h (+163)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+37)
  • (added) mlir/test/Dialect/Mesh/simplifications.mlir (+131)
  • (modified) mlir/test/lib/Dialect/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Dialect/Mesh/CMakeLists.txt (+11)
  • (added) mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp (+43)
  • (modified) mlir/tools/mlir-opt/CMakeLists.txt (+1)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae48..9b788d3f304c2c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -17,6 +17,8 @@ namespace func {
 class FuncOp;
 }
 
+class RewritePatternSet;
+
 namespace mesh {
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
new file mode 100644
index 000000000000000..1af0f52114f10e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -0,0 +1,89 @@
+//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/EndomorphismSimplification.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+template <typename AlgebraicOp>
+void populateAllReduceEndomorphismSimplificationPatterns(
+    RewritePatternSet &patterns, Partial reduction) {
+  auto getEndomorphismOpOperand = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return &allReduceOp.getInputMutable();
+  };
+  auto getEndomorphismOpResult = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return allReduceOp->getResult(0);
+  };
+  auto getAlgebraicOpOperands = [](Operation *op,
+                                   SmallVector<OpOperand *> &operands) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+    std::transform(algebraicOp->getOpOperands().begin(),
+                   algebraicOp->getOpOperands().end(),
+                   std::back_inserter(operands),
+                   [](OpOperand &operand) { return &operand; });
+  };
+  auto getAlgebraicOpResult = [](Operation *op) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+    return algebraicOp->getResult(0);
+  };
+  auto isEndomorphismOp = [reduction](Operation *op,
+                                      std::optional<Operation *> referenceOp) {
+    auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
+    if (!allReduceOp ||
+        allReduceOp.getInput().getType().getElementType() !=
+            allReduceOp.getResult().getType().getElementType() ||
+        allReduceOp.getReduction() != reduction) {
+      return false;
+    }
+
+    if (!referenceOp) {
+      return true;
+    }
+
+    auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
+    return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
+           allReduceOp.getInput().getType().getElementType() ==
+               refAllReduceOp.getInput().getType().getElementType();
+  };
+  auto isAlgebraicOp = [](Operation *op) {
+    return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
+  };
+
+  using ConcreteEndomorphismSimplification = EndomorphismSimplification<
+      std::decay_t<decltype(getEndomorphismOpOperand)>,
+      std::decay_t<decltype(getEndomorphismOpResult)>,
+      std::decay_t<decltype(getAlgebraicOpOperands)>,
+      std::decay_t<decltype(getAlgebraicOpResult)>,
+      std::decay_t<decltype(isEndomorphismOp)>,
+      std::decay_t<decltype(isAlgebraicOp)>>;
+  patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
+      std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
+      std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
+      std::move(isEndomorphismOp), std::move(isAlgebraicOp),
+      AlgebraicOp::getOperationName(), 1, patterns.getContext()));
+}
+
+void populateSimplificationPatterns(RewritePatternSet &patterns);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
diff --git a/mlir/include/mlir/Transforms/EndomorphismSimplification.h b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
new file mode 100644
index 000000000000000..b2bc7377a80fff3
--- /dev/null
+++ b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
@@ -0,0 +1,163 @@
+//===- RegionUtils.h - Region-related transformation utilities --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+#define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include <iterator>
+#include <optional>
+#include <type_traits>
+#include <utility>
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+
+// If `f` is an endomorphism with respect to the algebraic structure induced by
+// function `g`, transforms `g(f(x1), f(x2) ..., f(xn))` into
+// `f(g(x1, x2, ..., xn))`.
+// `g` is the algebraic operation and `f` is the endomorphism.
+//
+// Functors:
+// ---------
+// `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
+// Returns the operand relevant to the endomorphism.
+// There may be other operands that are not relevant.
+//
+// `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the endomorphism.
+//
+// `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
+// Populates into the vector the operands relevant to the endomorphism.
+//
+// `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+//  Return the result relevant to the endomorphism.
+//
+// `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
+// Check if the operation is an endomorphism of the required type.
+// Additionally if the optional is present checks if the operations are
+// compatible endomorphisms.
+//
+// `IsAlgebraicOpFn`: `(Operation*) -> bool`
+// Check if the operation is an operation of the algebraic structure.
+template <typename GetEndomorphismOpOperandFn,
+          typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn,
+          typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn,
+          typename IsAlgebraicOpFn>
+struct EndomorphismSimplification : RewritePattern {
+  template <typename GetEndomorphismOpOperandFnArg,
+            typename GetEndomorphismOpResultFnArg,
+            typename GetAlgebraicOpOperandsFnArg,
+            typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg,
+            typename IsAlgebraicOpFnArg, typename... RewritePatternArgs>
+  EndomorphismSimplification(
+      GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand,
+      GetEndomorphismOpResultFnArg &&getEndomorphismOpResult,
+      GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands,
+      GetAlgebraicOpResultFnArg &&getAlgebraicOpResult,
+      IsEndomorphismOpFnArg &&isEndomorphismOp,
+      IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args)
+      : RewritePattern(std::forward<RewritePatternArgs>(args)...),
+        getEndomorphismOpOperand(std::forward<GetEndomorphismOpOperandFnArg>(
+            getEndomorphismOpOperand)),
+        getEndomorphismOpResult(std::forward<GetEndomorphismOpResultFnArg>(
+            getEndomorphismOpResult)),
+        getAlgebraicOpOperands(
+            std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands)),
+        getAlgebraicOpResult(
+            std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult)),
+        isEndomorphismOp(std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp)),
+        isAlgebraicOp(std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp)) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(matchOp(op, algebraicOpOperands))) {
+      return failure();
+    }
+    return rewriteOp(op, algebraicOpOperands, rewriter);
+  }
+
+private:
+  LogicalResult matchOp(Operation *algebraicOp,
+                        SmallVector<OpOperand *> &algebraicOpOperands) const {
+    if (!isAlgebraicOp(algebraicOp)) {
+      return failure();
+    }
+    algebraicOpOperands.clear();
+    getAlgebraicOpOperands(algebraicOp, algebraicOpOperands);
+    if (algebraicOpOperands.empty()) {
+      return failure();
+    }
+
+    Operation *firstEndomorphismOp =
+        algebraicOpOperands.front()->get().getDefiningOp();
+    if (!firstEndomorphismOp ||
+        !isEndomorphismOp(firstEndomorphismOp, std::nullopt)) {
+      return failure();
+    }
+    OpResult firstEndomorphismOpResult =
+        getEndomorphismOpResult(firstEndomorphismOp);
+    if (getEndomorphismOpResult(firstEndomorphismOp) !=
+        algebraicOpOperands.front()->get()) {
+      return failure();
+    }
+
+    for (auto operand : algebraicOpOperands) {
+      Operation *endomorphismOp = operand->get().getDefiningOp();
+      if (!endomorphismOp ||
+          !isEndomorphismOp(endomorphismOp, firstEndomorphismOp)) {
+        return failure();
+      }
+    }
+    return success();
+  }
+
+  LogicalResult rewriteOp(Operation *algebraicOp,
+                          const SmallVector<OpOperand *> &algebraicOpOperands,
+                          PatternRewriter &rewriter) const {
+    irMapping.clear();
+    for (auto operand : algebraicOpOperands) {
+      Operation *endomorphismOp = operand->get().getDefiningOp();
+      irMapping.map(operand->get(),
+                    getEndomorphismOpOperand(endomorphismOp)->get());
+    }
+    Operation *newAlgebraicOp = rewriter.clone(*algebraicOp, irMapping);
+
+    irMapping.clear();
+    assert(!algebraicOpOperands.empty());
+    Operation *firstEndomorphismOp =
+        algebraicOpOperands[0]->get().getDefiningOp();
+    irMapping.map(getEndomorphismOpOperand(firstEndomorphismOp)->get(),
+                  getAlgebraicOpResult(newAlgebraicOp));
+    Operation *newEndomorphismOp =
+        rewriter.clone(*firstEndomorphismOp, irMapping);
+    rewriter.replaceAllUsesWith(getAlgebraicOpResult(algebraicOp),
+                                getEndomorphismOpResult(newEndomorphismOp));
+    return success();
+  }
+
+  GetEndomorphismOpOperandFn getEndomorphismOpOperand;
+  GetEndomorphismOpResultFn getEndomorphismOpResult;
+  GetAlgebraicOpOperandsFn getAlgebraicOpOperands;
+  GetAlgebraicOpResultFn getAlgebraicOpResult;
+  IsEndomorphismOpFn isEndomorphismOp;
+  IsAlgebraicOpFn isAlgebraicOp;
+  mutable SmallVector<OpOperand *> algebraicOpOperands;
+  mutable IRMapping irMapping;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index bcf45c4ea276080..044b8672c8c60cf 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMeshTransforms
+  Simplifications.cpp
   ShardingPropagation.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
   MLIRShardingInterface
 
   LINK_LIBS PUBLIC
+  MLIRArithDialect
   MLIRFuncDialect
   MLIRIR
   MLIRMeshDialect
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
new file mode 100644
index 000000000000000..1d241fe03a127b8
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -0,0 +1,37 @@
+//===- Patterns.cpp - Mesh Patterns -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+namespace mlir {
+namespace mesh {
+
+void populateSimplificationPatterns(RewritePatternSet &patterns) {
+  populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
+      patterns, Partial::Sum);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
+      patterns, Partial::Sum);
+
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
+      patterns, Partial::Min);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
+      patterns, Partial::Min);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
+      patterns, Partial::Min);
+
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
+      patterns, Partial::Max);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
+      patterns, Partial::Max);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
+      patterns, Partial::Max);
+}
+
+} // namespace mesh
+} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
new file mode 100644
index 000000000000000..2b305df6e0a97f1
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [4, 2])
+mesh.cluster @mesh1(rank = 1, dim_sizes = [4])
+
+// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
+// `all_reduce(x + y)`.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
+func.func @all_reduce_arith_addf_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
+  %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
+func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = <max>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
+func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf64> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf64>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf64>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf64>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf64>
+}
+
+// Checks that `min(all_reduce(x), all_reduce(y))` gets transformed to
+// `all_reduce(min(x, y))`.
+// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
+func.func @all_reduce_arith_minimumf_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
+  %2 = arith.minimumf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
+func.func @all_reduce_arith_minsi_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xi32>
+    %arg0: tensor<5xi32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
+    %arg1: tensor<5xi32>) -> tensor<5xi32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xi32> -> tensor<5xi32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xi32> -> tensor<5xi32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
+  %2 = arith.minsi %0, %1 : tensor<5xi32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : ten...
[truncated]

@sogartar
Contributor Author

@yaochengji could you review this?

@sogartar sogartar force-pushed the simplify-all-reduce-endomorphism branch 3 times, most recently from 363f6e5 to 8df8ae3 Compare November 22, 2023 23:41
@@ -17,6 +17,8 @@ namespace func {
class FuncOp;
}

class RewritePatternSet;
Member

Seems the forward declaration is not needed?

Contributor Author

Removed it.

void populateAllReduceEndomorphismSimplificationPatterns(
RewritePatternSet &patterns, Partial reduction) {
auto getEndomorphismOpOperand = [](Operation *op) {
auto allReduceOp = llvm::cast<AllReduceOp>(op);
Member

It seems all_gather could also benefit from this. I suggest we mark it as a TODO here.

Contributor Author

I added a comment in populateSimplificationPatterns.

};
auto getAlgebraicOpOperands = [](Operation *op,
SmallVector<OpOperand *> &operands) {
auto algebraicOp = llvm::cast<AlgebraicOp>(op);
Member

What if AlgebraicOp is a destination style op? I think the logic should be different, but we can add a TODO here first.

Contributor Author

Do you mean a scenario like this?
h is the endomorphism and a is the algebraic structure op.
Then

a(h(x), h(y), z) = h(a(x, y, z))

does not actually hold.

If that is the case you may be able to do homomorphism simplification where the target algebraic structure op b is different. And we have

a(h(x), h(y), z) = h(b(x, y, z))
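
As a two-operand illustration of the homomorphism case, where the algebraic op changes from a to b, the exp example from the PR description can be checked numerically. This is a sketch only: the function names are hypothetical, h is exp, a is multiplication and b is addition. In floating point the two forms round differently, so the comparison uses a relative tolerance rather than exact equality.

```cpp
#include <cassert>
#include <cmath>

// a(h(x), h(y)) with h = exp and a = multiplication.
double productOfExps(double x, double y) { return std::exp(x) * std::exp(y); }

// h(b(x, y)) with h = exp and b = addition.
double expOfSum(double x, double y) { return std::exp(x + y); }

// Relative comparison, because the two forms are not bit-identical.
bool nearlyEqual(double lhs, double rhs, double relTol = 1e-12) {
  assert(relTol >= 0.0 && "tolerance must be non-negative");
  double scale = std::fmax(std::fabs(lhs), std::fabs(rhs));
  return std::fabs(lhs - rhs) <= relTol * scale;
}
```

Because the results agree only up to rounding, a rewrite of this kind on floating-point values changes the computed bits slightly. Any such pattern would need to account for that.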

Member

No, what I meant is ops in linalg dialect which have the DestinationStyleOpInterface. For a linalg.matmul op

%matmul = linalg.matmul ins(%0, %1 : tensor<1x1xi8>, tensor<1x?xi8>) outs(%2 : tensor<1x?xi32>) -> tensor<1x?xi32>

Here the %2 operand is different from %0 and %1

Contributor Author

I expanded the documentation of the function.

@@ -0,0 +1,131 @@
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
Member

What happens if the all_reduce has multiple users?

Contributor Author

This pass does not remove code; it only adds it. My expectation is that we fix the side-effecting nature of the collectives and let DCE remove the unused code. I think in the context of SPMD we can reason that all_reduce is not side-effecting.

Member

@sogartar seems the multi-users case is not addressed?

Contributor Author

@sogartar sogartar Dec 4, 2023

Do you mean not to do the simplification in that case, since there may not be a performance benefit?
A more complete approach would be possible if we somehow knew whether the other use will also be removed.
Maybe we can have another rewrite pattern that reverses the endomorphism simplification if the other use is still alive. We would then apply it in a separate pass to avoid an infinite loop of rewrites.

Contributor Author

The other option is to avoid doing the simplification in the first place, but how would we know in a generic way if the other use will be removed or not?

Member

Yes, it's a complex problem. A common approach is to leave it unsimplified. XLA has a pass with similar logic; maybe we can check how it handles this.

Contributor Author

I made it not do the simplification in such a scenario.
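
A minimal sketch of such a guard on a toy value model follows. It is not taken from the merged patch: `ToyValue` and `shouldSimplify` are hypothetical names, and the actual change may implement the check differently (in MLIR, the analogous query on a value is `Value::hasOneUse()`). The idea is to rewrite only when every all_reduce result feeding the algebraic op has that op as its single user, so that the original all_reduce ops become dead after the rewrite.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for an SSA value that feeds the algebraic op.
struct ToyValue {
  bool producedByAllReduce; // defined by a compatible all_reduce
  int numUses;              // how many operations consume this value
};

// Rewrite only if every operand comes from an all_reduce whose result has
// exactly one use (the algebraic op itself). Otherwise the original
// all_reduce stays alive and the rewrite would add a collective instead
// of removing one.
bool shouldSimplify(const std::vector<ToyValue> &operands) {
  if (operands.empty())
    return false;
  for (const ToyValue &v : operands) {
    assert(v.numUses >= 0 && "use counts cannot be negative");
    if (!v.producedByAllReduce || v.numUses != 1)
      return false;
  }
  return true;
}
```

With this guard, a case where one all_reduce result is also returned or consumed elsewhere is left unsimplified. This is the conservative option discussed in the thread.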

@sogartar sogartar self-assigned this Dec 6, 2023
Contributor Author

@sogartar sogartar left a comment

Lol. Didn't think I was starting a review, so my comments were not submitted.

mlir/include/mlir/Transforms/EndomorphismSimplification.h (outdated comment thread, resolved)
// CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
Contributor Author

This simplification does not remove dead code. The previous all_reduces will still be there, but they will no longer be used.
We need to decide on the side-effecting nature of the collective operations, i.e. whether we are in the context of SPMD or MPMD.

sogartar commented Dec 6, 2023

Lol. Didn't think I was starting a review, so my comments were not submitted.

@yaochengji could you take a look again?

@yaochengji
Member

LGTM, thanks

Collaborator

@joker-eph joker-eph left a comment


LG, thanks

@sogartar sogartar force-pushed the simplify-all-reduce-endomorphism branch from 53b0d2e to fab21a1 Compare December 12, 2023 14:48
@sogartar
Contributor Author

Thank you for the review. I rebased it before merging to check that all is well.

@sogartar sogartar merged commit 4b34467 into llvm:main Dec 12, 2023
4 checks passed