
[mlir][mesh] Add collective communication operations #71960


Merged
merged 10 commits on Nov 21, 2023

Conversation

sogartar
Contributor

Add all-gather, all-reduce, all-to-all and reduce-scatter. These operations have device mesh semantics.

I have not included ops like reduce, gather, send and recv to see first if reviewers notice any systemic issues. Also this PR is already big enough.
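To make the group-forming semantics concrete, the partition rule described in the new Mesh.md doc can be sketched as follows (illustrative Python, not part of the patch; the function name and representation are invented):

```python
from itertools import product

def device_groups(mesh_shape, mesh_axes):
    """Partition every device multi-index of the mesh into groups.

    Devices that share the same coordinates on the axes *outside*
    mesh_axes fall into the same group; a collective operates within
    each group.
    """
    groups = {}
    for device in product(*(range(s) for s in mesh_shape)):
        # The group key is the device's coordinates on the non-partition axes.
        key = tuple(c for axis, c in enumerate(device) if axis not in mesh_axes)
        groups.setdefault(key, []).append(device)
    return groups

# 2x3x4x5 mesh partitioned over mesh axes {0, 1}:
# 4 * 5 = 20 groups, each containing 2 * 3 = 6 devices.
groups = device_groups((2, 3, 4, 5), {0, 1})
```

For example, devices (1, 0, 2, 3) and (1, 1, 2, 3) land in the group keyed by (2, 3), while (1, 0, 2, 4) lands in the group keyed by (2, 4), matching the doc's example.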

Add all-gather, all-reduce, all-to-all and reduce-scatter.
These operations have device mesh semantics.
@llvmbot llvmbot added the mlir label Nov 10, 2023
@sogartar sogartar requested a review from joker-eph November 10, 2023 17:18
@llvmbot
Member

llvmbot commented Nov 10, 2023

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

Add all-gather, all-reduce, all-to-all and reduce-scatter. These operations have device mesh semantics.

I have not included ops like reduce, gather, send and recv to see first if reviewers notice any systemic issues. Also this PR is already big enough.


Patch is 43.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71960.diff

8 Files Affected:

  • (added) mlir/docs/Dialects/Mesh.md (+34)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+5-3)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+2)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+216)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+417)
  • (added) mlir/test/Dialect/Mesh/canonicalization.mlir (+72)
  • (modified) mlir/test/Dialect/Mesh/invalid.mlir (+240)
  • (modified) mlir/test/Dialect/Mesh/ops.mlir (+119)
diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
new file mode 100644
index 000000000000000..6dd4f79022061ee
--- /dev/null
+++ b/mlir/docs/Dialects/Mesh.md
@@ -0,0 +1,34 @@
+# 'mesh' Dialect
+
+The `mesh` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device mesh
+cluster.
+
+[TOC]
+
+## Collective Communication Operations
+There are a number of operations in the Mesh dialect to facilitate
+communication between devices in a mesh.
+It is assumed that the user is familiar with collective operations.
+[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
+explanation.
+The main addition is that the collectives in this dialect have mesh
+semantics.
+The operation attributes `mesh` and `mesh_axes` specify a set of device mesh
+axes that partition the devices into disjoint groups.
+The collective operation is performed between devices in the same group.
+Devices that have the same coordinates outside of axes `mesh_axes` are in the
+same group.
+For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
+axes set is `{0, 1}` then devices are partitioned into the groups
+`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
+Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
+Device (1, 0, 2, 4) will be in another group.
+
+## Operations
+
+[include "Dialects/MeshOps.md"]
+
+## Attributes
+
+[include "Dialects/MeshAttributes.md"]
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a91ef569347bff1..9d39b1b3329fb4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
   let cppNamespace = "::mlir::mesh";
 
   let description = [{
-    The `mesh` dialect contains a set of attributes, operations, interfaces that
-    are useful for representing sharding and communication on device mesh
-    cluster.
+    See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
   }];
 
   let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
+def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 // Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
 // distributed tensors. Mesh_IteratorType annotates loops in an operation, while
 // Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 05eba66a89949b6..7698d60813a8f10 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,11 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <algorithm>
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a8aa0a694bee29f..15354babe870599 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -77,6 +79,15 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
       attr-dict
   }];
+  let extraClassDeclaration = [{
+    ::mlir::SmallVector<int64_t> canonicalDimSizes();
+
+    template <typename OutIt>
+    void canonicalDimSizes(OutIt outIt) {
+      std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
+      std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -171,4 +182,209 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+class Mesh_CollectiveCommunicationOpBase<
+    string mnemonic, list<Trait> traits = []> :
+    Mesh_Op<mnemonic,
+      !listconcat(traits,
+      [SymbolUserOpInterface])> {
+  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
+  code extraClassDeclarationBase = [{
+    ::mlir::LogicalResult verifySymbolUses(
+          ::mlir::SymbolTableCollection &symbolTable);
+  }];
+}
+
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-gather over a device mesh.";
+  let description = [{
+    Gathers along the `gather_axis` tensor axis.
+    The order of input tensors in the resulting tensor is the same as the
+    order of the corresponding devices' multi-indices in the mesh.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.all_gather %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+      } : tensor<2x2xi8> -> tensor<2x4xi8>
+    ```
+    Input:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    APIntAttr:$gather_axis
+  );
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+    SameOperandsAndResultShape]> {
+  let summary = "All-reduce over a device mesh.";
+  let description = [{
+    The accumulation element type is specified by the result type and
+    it does not need to match the input element type.
+    The input element is converted to the result element type before
+    performing the reduction.
+
+    Attributes:
+    `reduction`: Indicates the reduction method.
+
+    Example:
+    ```
+    %1 = mesh.all_reduce %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+      } : tensor<3x4xf32> -> tensor<3x4xf64>
+    ```
+  }];
+  let arguments = (ins
+    AnyRankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+  );
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank]> {
+  let summary = "All-to-all over a device mesh.";
+  let description = [{
+    Performs an all-to-all on tensor pieces split along `split_axis`.
+    The resulting pieces are concatenated along `concat_axis` on each device.
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+    ...
+    %1 = mesh.all_to_all %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0, concat_axis = 0
+      } : tensor<3x6xi8> -> tensor<3x6xi8>
+    ```
+    Input:
+    ```
+     device  device  device
+     (0)     (1)     (2)
+    +-------+-------+-------+
+    | 11 12 | 21 22 | 31 32 |
+    | 13 14 | 23 24 | 33 34 |
+    | 15 16 | 25 26 | 35 36 |
+    +-------+-------+-------+
+    ```
+    Result:
+    ```
+     device  device  device
+     (0)     (1)     (2)
+    +-------+-------+-------+
+    | 11 12 | 13 14 | 15 16 |
+    | 21 22 | 23 24 | 25 26 |
+    | 31 32 | 33 34 | 35 36 |
+    +-------+-------+-------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    APIntAttr:$split_axis,
+    APIntAttr:$concat_axis
+  );
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+    SameOperandsAndResultRank]> {
+  let summary = "Reduce-scatter over a device mesh.";
+  let description = [{
+    After the reduction, scatters the result within each device group.
+    The tensor is split along `scatter_axis` and the pieces distributed
+    across the device group.
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.reduce_scatter %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1>, reduction = #mesh.partial<max>, scatter_axis = 0
+      } : tensor<3x4xf32> -> tensor<1x4xf64>
+    ```
+    Input:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------+
+    |  6  8 | <- device (0, 0)
+    +-------+
+    | 10 12 | <- device (0, 1)
+    +-------+
+    | 22 24 | <- device (1, 0)
+    +-------+
+    | 26 28 | <- device (1, 1)
+    +-------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    APIntAttr:$scatter_axis
+  );
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 588704f24574f90..6efc4c4ecc326ad 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,10 +8,26 @@
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <utility>
 
 #define DEBUG_TYPE "mesh-ops"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -21,6 +37,60 @@ using namespace mlir::mesh;
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
 
+namespace {
+
+template <typename It>
+It canonicalizeSetAsArray(It begin, It end) {
+  std::sort(begin, end);
+  return std::unique(begin, end);
+}
+
+template <typename R>
+auto canonicalizeSetAsArray(R &&range) {
+  return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
+}
+
+template <typename T>
+SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
+  auto newEnd = canonicalizeSetAsArray(vec);
+  vec.resize(newEnd - vec.begin());
+  return vec;
+}
+
+template <typename DimSize>
+bool isMeshDimensionDynamic(DimSize size) {
+  return size <= DimSize(0);
+}
+
+using MeshAxis = int16_t;
+
+struct DimensionSize {
+  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
+  DimensionSize(int64_t val) : val(val) {}
+  int64_t value() const { return val; }
+  operator int64_t() const { return val; }
+  bool isDynamic() const { return ShapedType::isDynamic(val); }
+
+private:
+  int64_t val;
+};
+
+DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
+  if (lhs.isDynamic() || rhs.isDynamic()) {
+    return DimensionSize::dynamic();
+  }
+  return lhs.value() / rhs.value();
+}
+
+DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
+  if (lhs.isDynamic() || rhs.isDynamic()) {
+    return DimensionSize::dynamic();
+  }
+  return lhs.value() * rhs.value();
+}
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Mesh dialect
 //===----------------------------------------------------------------------===//
@@ -96,6 +166,12 @@ LogicalResult ClusterOp::verify() {
   return success();
 }
 
+SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
+  SmallVector<int64_t> result;
+  canonicalDimSizes(std::back_inserter(result));
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shard op
 //===----------------------------------------------------------------------===//
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  if (!attr) {
+    return std::nullopt;
+  }
+  SmallVector<int16_t> axes = llvm::to_vector(attr.asArrayRef());
+  canonicalizeSetAsVector(axes);
+  if (axes.empty()) {
+    return std::nullopt;
+  }
+  return DenseI16ArrayAttr::get(attr.getContext(), axes);
+}
+
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto canonicalMeshAxesAttr = canonicalizeAxesSetAttribute(
+        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr));
+    if (!canonicalMeshAxesAttr) {
+      op->removeAttr(axisSetAttr);
+    } else {
+      op->setAttr(axisSetAttr, canonicalMeshAxesAttr.value());
+    }
+    return success();
+  }
+
+  std::string axisSetAttr;
+};
+
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<AxesSetCanonicalizationPattern<Op>>(context, "mesh_axes");
+}
+
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr) {
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+  }
+  SymbolTableCollection symbolTableCollection;
+  mesh::ClusterOp mesh =
+      symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+          op.getOperation(), symbolAttr);
+  if (!mesh) {
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+  }
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes) {
+    return success();
+  }
+  MeshAxis rank = mesh.getRank();
+  for (auto axis : meshAxes.asArrayRef()) {
+    if (axis >= rank || axis < 0) {
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+template <typename It>
+auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  return std::accumulate(begin, end, ElementType(1),
+                         std::multiplies<ElementType>());
+}
+
+template <typename R>
+auto product(R &&range) {
+  return product(adl_begin(range), adl_end(range));
+}
+
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                  ArrayRef<int64_t> meshShape) {
+  int64_t res = 1;
+  for (MeshAxis axis = 0; axis < MeshAxis(meshShape.size()); ++axis) {
+    if (llvm::find(meshAxes, axis) == meshAxes.end()) {
+      continue;
+    }
+    if (isMeshDimensionDynamic(meshShape[axis])) {
+      return ShapedType::kDynamic;
+    }
+    res *= meshShape[axis];
+  }
+  return res;
+}
+
+LogicalResult verifyDimensionCompatibility(Location loc,
+                                           int64_t expectedDimSize,
+                                           int64_t resultDimSize,
+                                           int64_t resultAxis) {
+  if (!ShapedType::isDynamic(resultDimSize) &&
+      expectedDimSize != resultDimSize) {
+    return emitError(loc) << "Dimension size mismatch for result axis "
+                          << resultAxis << ". Expected "
+                          << (ShapedType::isDynamic(expectedDimSize)
+                                  ? Twine("dynamic")
+                                  : Twine(expectedDimSize))
+                          << ", but got " << resultDimSize << ".";
+  }
+
+  return success();
+}
+
+LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
+                                                int64_t gatherAxis,
+                                                ArrayRef<MeshAxis> meshAxes,
+                                                ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
+    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
+    auto expectedResultDimSize =
+        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
+    if (failed(verifyDimensionCompatibility(
+            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+  SymbolTableCollection symbolTableCollection;
+  if (failed(verifyMeshSymbolUses(op, symbolTableCollection))) {
+    // We need to check the symbol here since this runs before
+    // SymbolUserOpInterface.
+    return failure();
+  }
+  return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), op.getMeshAttr());
+}
+
+template <typename Op>
+LogicalResult verifyGather(Op op) {
+  auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
+  auto gatherAxis = op.getGatherAxis().getSExtValue();
+  if (gatherAxis < 0 || gatherAxis >= rank) {
+    return op.emitError() << "Gather axis " << gatherAxis
+                          << " is out of bounds [0, " << rank << ").";
+  }
+
+  auto mesh = getMesh(op);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyGatherOperandAndResultShape(op.getOperand(), op.get...
[truncated]

@@ -96,6 +166,12 @@ LogicalResult ClusterOp::verify() {
return success();
}

SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
SmallVector<int64_t> result;
Member
Suggested change
SmallVector<int64_t> result;
SmallVector<int64_t> result;
result.reserve(getRank());

Contributor Author
Done.

template <typename It>
auto product(It begin, It end) {
using ElementType = std::decay_t<decltype(*begin)>;
return std::accumulate(begin, end, ElementType(1),
Member
I'd suggest using static_cast<ElementType>(1) in case some classes have an explicit constructor.

Contributor Author
Done.

int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
ArrayRef<int64_t> meshShape) {
int64_t res = 1;
for (MeshAxis axis = 0; axis < MeshAxis(meshShape.size()); ++axis) {
Member
Given that there is no duplication in meshAxes, it could change to

for (MeshAxis axis : meshAxes) {
    if (isMeshDimensionDynamic(meshShape[axis])) {
        return ShapedType::kDynamic;
    }
    res *= meshShape[axis];
}

Contributor Author
Done.
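The simplified loop computes the product of the mesh dimension sizes along the selected axes, bailing out to a dynamic marker as soon as one of those dimensions is dynamic. A minimal sketch of the same logic (illustrative Python; the DYNAMIC sentinel stands in for ShapedType::kDynamic):

```python
DYNAMIC = -1  # stand-in for ShapedType::kDynamic

def collective_device_group_size(mesh_axes, mesh_shape):
    """Number of devices in each group: the product of the mesh
    dimensions selected by mesh_axes. Assumes mesh_axes holds no
    duplicates, as enforced by verification.
    """
    size = 1
    for axis in mesh_axes:
        if mesh_shape[axis] <= 0:  # dynamic mesh dimension
            return DYNAMIC
        size *= mesh_shape[axis]
    return size
```

For the 2x3x4x5 mesh with mesh_axes = [0, 1], each device group holds 2 * 3 = 6 devices.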

return success();
}

LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
Member
Could you make all the ccl ops inherit InferShapedTypeOpInterface and implement verification based on it?

Contributor Author
There is the problem that InferShapedTypeOpInterface does not assume that the operation has been created. You only have the operands (Values). But we need the operation to find the referenced ClusterOp. The resolution of the referenced ClusterOp depends on which block the collective operation is defined in.
I am not sure how to resolve this yet. It seems shape/type inference is unequipped to deal with this.

Member
Could you get the referenced ClusterOp by using symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(operands[0].getParentRegion()->getParentOp(), meshAttr)?

Contributor Author
Technically the operand can be defined higher up the tree than the cluster op.
Something like

{
  %operand = ...
  {
    mesh.cluster ...
    %res = mesh.all_reduce (%operand) ...
  }
}

Contributor Author
I am not sure what the restrictions on symbol resolution are. We may have to impose additional restrictions on the mesh operations to be able to use this.
It is likely no one else has run into this problem. func.call, for example, does not have type inference. It would have hit this issue if it did.

Member
func.call doesn't need to have type inference because its inference logic depends on its body.

@joker-eph do you have any suggestion? The problem is that here we need to look up the symbol without the operation.

namespace {

std::optional<DenseI16ArrayAttr>
canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
Member
Why do you regard mesh_axes as a set? I think duplicates should not be allowed in verification.

Contributor Author
I was thinking to change its meaning. In the Mesh.md doc it is stated that the order of devices is the same as the order of their multi-index in the mesh. This would make it a set.
We could change the meaning so that the order of devices is induced by the order in mesh_axes. For example, if mesh_axes = [2, 1], then device (i, j, k) would precede device (i, l, m) if k > m. This would influence operations like all-gather, where device order is important.

Member
But your reply does not seem to explain why we have to make mesh_axes support duplicates like [2, 1, 2].

Contributor Author
I wanted to use a set, but there is no such attribute, so I went for an array. In math, the set {2, 1, 2} is the same as {2, 1}. I can see why you want to forbid duplicates, as they may result from a programming error where the intention was something else.

Member
Yeah, I think not accepting {2, 1, 2} could make the semantics clearer.

Contributor Author
Done.
I also changed the meaning of mesh_axes to induce the order of devices in the device groups. I added documentation about this.
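One plausible ordering convention, sketched here purely for illustration (this is an assumption, not necessarily the exact rule the patch adopts): devices within a group are enumerated lexicographically by their coordinates taken in mesh_axes order, with the first listed axis varying slowest.

```python
from itertools import product

def group_device_order(mesh_shape, mesh_axes, group_key):
    """Order of devices within one group, induced by mesh_axes order.

    group_key gives the fixed coordinates on the axes outside mesh_axes.
    """
    other_axes = [a for a in range(len(mesh_shape)) if a not in mesh_axes]
    devices = []
    # Iterate in-group coordinates lexicographically in mesh_axes order.
    for in_group in product(*(range(mesh_shape[a]) for a in mesh_axes)):
        device = [0] * len(mesh_shape)
        for a, c in zip(mesh_axes, in_group):
            device[a] = c
        for a, c in zip(other_axes, group_key):
            device[a] = c
        devices.append(tuple(device))
    return devices

# 2x3 mesh, mesh_axes = [1, 0]: axis 1 varies slowest within the group.
order = group_device_order((2, 3), [1, 0], ())
```

With mesh_axes = [1, 0] the group order becomes (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), i.e. reversing mesh_axes reverses which axis is major.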

}

template <typename Op>
LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
Member
Could you make verifyMeshSymbolUses a non-template function? I'm afraid of code bloat.

The signature could change to verifyMeshSymbolUses(Operation*, FlatSymbolRefAttr, DenseI16ArrayAttr, SymbolTableCollection &).

Contributor Author
I think extracting attributes from the concrete operation type through class methods like op.getMeshAttr() does not incur an attribute lookup. Also, there is some benefit to inlining this function.

Contributor Author
If you think this is not that important, I can change it to accept Operation*.

Member
I might be a bit nit-picky. It's a good point that inlining + template can improve performance at the expense of binary size.

Collaborator
You could also add the specific attributes like getMeshAttr() to the API of the non-templated function?

Contributor Author
Made getMesh and verifyMeshSymbolUses non-template.

}

template <typename Op>
FailureOr<ClusterOp> getMesh(Op op) {
Member
Could you change getMesh to a non-template function to avoid code bloat?

Contributor Author
Same reasoning as in the comment about verifyMeshSymbolUses. Let's discuss there.

return success();
}

LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
Member
Could you change the name to verifyAllGatherOperandAndResultShape, because both gather and all_gather exist in ccl?

Contributor Author

@sogartar sogartar Nov 14, 2023
I was thinking ahead, since the verification of gather and all-gather will be the same. I renamed it to verifyAllGatherOperandAndResultShape to avoid confusion.
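The shape relation the renamed helper checks is simple: the gather axis of the result is the operand's gather axis scaled by the device-group size, and every other axis must match. A sketch (illustrative Python; DYNAMIC stands in for ShapedType::kDynamic and the function name is invented):

```python
DYNAMIC = -1  # stand-in for ShapedType::kDynamic

def expected_all_gather_shape(operand_shape, gather_axis, group_size):
    """Expected result shape of an all-gather: the gather axis grows
    by the device-group size; all other axes are unchanged."""
    result = list(operand_shape)
    if group_size == DYNAMIC or result[gather_axis] == DYNAMIC:
        result[gather_axis] = DYNAMIC  # dynamic propagates
    else:
        result[gather_axis] *= group_size
    return result

# Matches the doc example: tensor<2x2xi8> all-gathered along axis 1
# over a group of 2 devices yields tensor<2x4xi8>.
```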

}

template <typename Op>
LogicalResult verifyGather(Op op) {
Member
Could you change the name to verifyAllGather?

Contributor Author
Done.

@sogartar sogartar requested a review from yaochengji November 14, 2023 19:00

LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes) {
SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
std::sort(sorted.begin(), sorted.end());
Member:

I'd suggest checking std::set(axes.begin(), axes.end()).size() == axes.size() instead, so we don't need to sort the axes, and the code will be more concise.

Contributor Author:

Usually constructing a std::set is more expensive than doing a copy + sort + linear search for consecutive duplicate elements. I have not measured it, but my intuition is that for small sizes like < 10 this holds.

Member:

Makes sense, but I'd suggest using std::unique(begin, end) == end instead of implementing isUnique ourselves.

Contributor Author:

I would even speculate that for all sizes a vector would be faster.

Member:

OK, this is fine to me.

RewritePatternSet &patterns, MLIRContext *context) {
populateMeshAxesSetCanonicalizationPatterns<AllReduceOp>(patterns, context);
LogicalResult mlir::mesh::AllReduceOp::verify() {
return verifyMeshAxes(getLoc(), getMeshAxes());
Member:

We should verify the shape of all_reduce.

Contributor Author (@sogartar, Nov 14, 2023):

It has the trait SameOperandsAndResultShape. It should take care of that verification.
What is not verified here is whether the referenced mesh has the rank required by the mesh_axes attribute. I will add it.

Contributor Author:

Never mind, I actually do the check in verifyMeshSymbolUses.

Member:

Do you mean this check?

if (!meshAxes) {
    return success();
  }

But an ArrayAttr with an empty array is not nullptr.

Contributor Author:

I removed this check since I made the mesh_axes attribute non-optional. It still has a default value.

return %0 : tensor<4xf64>
}

// CHECK-LABEL: func @all_reduce_empty_mesh_axes_and_default_reduction
Member:

If an empty mesh_axes is not allowed, we'd better check it in the verifyMeshAxes method.

@yaochengji
Copy link
Member

@sogartar it seems you missed the 9 hidden conversations.

@sogartar sogartar requested a review from yaochengji November 14, 2023 22:24
::mlir::SmallVector<int64_t> canonicalDimSizes();

template <typename OutIt>
void canonicalDimSizes(OutIt outIt) {
Collaborator:

Can you document these methods? What is the "canonical" form?

Contributor Author:

Done.

}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
APIntAttr:$gather_axis
Collaborator:

I32Attr?

Contributor Author:

I changed all tensor axes to IndexAttr on all occasions.

SymbolTableCollection symbolTableCollection;
mesh::ClusterOp mesh =
symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
op.getOperation(), symbolAttr);
Collaborator:

This will be costly, what about verifySymbolUses?

Contributor Author (@sogartar, Nov 15, 2023):

This function is called from verifySymbolUses and getMesh. We need the mesh shape for verification. Do you mean to put all verification related to the mesh symbol in verifySymbolUses?

Collaborator:

Yes: verifySymbolUses is meant to let you build a symbol table once and for all, to avoid repeated traversal of the IR.

Contributor Author:

All verifications are now in verifySymbolUses.

return failure();
}
return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
op.getOperation(), op.getMeshAttr());
Collaborator:

(same as above)

Contributor Author:

All verifications are now in verifySymbolUses.

}

template <typename Op>
LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
Collaborator:

You could also add the specific attributes likes getMeshAttr() to the API of the non-templated function?

@sogartar sogartar requested a review from joker-eph November 16, 2023 02:06
Member (@yaochengji):

Thanks, LGTM

@sogartar sogartar merged commit 5f7c8c1 into llvm:main Nov 21, 2023