Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][mesh] Add all-scatter operation #81218

Merged
merged 1 commit into from
Feb 15, 2024
Merged

Conversation

sogartar
Copy link
Contributor

@sogartar sogartar commented Feb 9, 2024

This op is the inverse of all-gather. It is useful to have an explicit concise representation instead of having a blob of slicing logic.

Add lowering for the op that slices from the tensor based on the in-group process index.

Make resharding generate an all-scatter instead of inserting the slicing logic directly.

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 9, 2024

@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Boian Petkantchin (sogartar)

Changes

This op is the inverse of all-gather. It is useful to have an explicit concise representation instead of having a blob of slicing logic.

Add lowering for the op that slices from the tensor based on the in-group process index.

Make resharding generate an all-scatter instead of inserting the slicing logic directly.


Patch is 55.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81218.diff

26 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/Utils.h (+7)
  • (modified) mlir/include/mlir/Dialect/Arith/Utils/Utils.h (+17)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+6)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+63)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h (+18-1)
  • (modified) mlir/include/mlir/IR/Builders.h (+3)
  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+27)
  • (modified) mlir/lib/Dialect/Arith/Utils/Utils.cpp (+19)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+61-3)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+2)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+5-11)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+8-76)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+125-11)
  • (added) mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h (+35)
  • (modified) mlir/lib/IR/Builders.cpp (+18)
  • (added) mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir (+74)
  • (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+13)
  • (modified) mlir/test/Dialect/Mesh/invalid.mlir (+52)
  • (modified) mlir/test/Dialect/Mesh/ops.mlir (+24)
  • (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+6-25)
  • (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+6-8)
  • (modified) mlir/test/lib/Dialect/Mesh/CMakeLists.txt (+1-1)
  • (added) mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp (+79)
  • (removed) mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp (-54)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2-2)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+2)
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index b3ccbff3002fb1..441ce1952210bb 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -15,12 +15,14 @@
 
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/OpDefinition.h"
 #include <optional>
 
 namespace mlir {
 class DominanceInfo;
 class Operation;
 class PostDominanceInfo;
+class ImplicitLocOpBuilder;
 
 namespace func {
 class FuncOp;
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
                                                ArrayRef<Value> basis);
+// Generate IR that extracts the linear index form a multi-index according to
+// a shape.
+OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                  ArrayRef<OpFoldResult> shape,
+                                  ImplicitLocOpBuilder &builder);
 
 /// Ensure that all operations that could be executed after `start`
 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 402bd196f0736a..2111a7c5810294 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
+#include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
 
@@ -81,6 +82,22 @@ struct ArithBuilder {
   OpBuilder &b;
   Location loc;
 };
+
+namespace arith {
+
+// Build the product of a sequence.
+// If values = (v0, v1, ..., vn) than the returned
+// value is v0 * v1 * ... * vn.
+// All values must have the same type.
+//
+// The version without `resultType` must contain at least one element in values.
+// Then the result will have the same type as the elements in `values`.
+// If `values` is empty in the version with `resultType` returns 1 with type
+// `resultType`.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType);
+} // namespace arith
 } // namespace mlir
 
 #endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9154e6fd803102..fb9425b96e68e2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }
 
+template <typename MeshAxesRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
+  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
+                                    mesh.getShape());
+}
+
 // Get the size of a sharded dimension.
 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
   if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index da372706ec724c..7792aac784d4be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
 
   let builders = [
     OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
   let hasCanonicalizer = 1;
 }
 
+def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
+    Pure,
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
+  let description = [{
+    Scatter along the `scatter_axis` tensor axis.
+    This operation can be taught of as the inverse of all-gather.
+    Technically, it is not required that all processes have the same input tensor.
+    Each process will slice a piece of its local tensor based on its in-group device index.
+    The operation does not communicate data between devices. 
+
+    Example:
+    ```mlir
+    mesh.mesh @mesh0(shape = 2x2)
+    ...
+    %1 = mesh.all_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
+      : tensor<2x4xi8> -> tensor<2x2xi8>
+    ```
+    Input:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+    Result:
+    ```
+    gather tensor
+    axis 1
+    ------------>
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$scatter_axis
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>,
+    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>
+  ];
+}
+
 def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
     Pure,
     SameOperandsAndResultElementType,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index 10a965daac71b9..d398bdd65330b1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -9,17 +9,34 @@
 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
 namespace mlir {
 class RewritePatternSet;
 class SymbolTableCollection;
 class DialectRegistry;
+class ImplicitLocOpBuilder;
 namespace mesh {
 
 void processMultiIndexOpLoweringPopulatePatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-
 void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
 
+void allScatterOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry);
+
+void populateAllOpLoweringPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void registerAllOpLoweringDialects(DialectRegistry &registry);
+
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+                                 ImplicitLocOpBuilder &builder);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2fe1495b2b593b..43b6d2b3841690 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -118,6 +118,9 @@ class Builder {
   // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
   // ranked tensor of them. Returns null attribute otherwise.
   TypedAttr getZeroAttr(Type type);
+  // Returns a 1-valued attribute of the given `type`.
+  // Type constraints are the same as `getZeroAttr`.
+  TypedAttr getOneAttr(Type type);
 
   // Convenience methods for fixed types.
   FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d4adb94a9fc8d..41bcd2b9f33e6b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -20,9 +20,11 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <optional>
@@ -1869,3 +1871,28 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
   results.push_back(residual);
   return results;
 }
+
+OpFoldResult
+mlir::affine::linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                   ArrayRef<OpFoldResult> shape,
+                                   ImplicitLocOpBuilder &builder) {
+  assert(multiIndex.size() == shape.size());
+  SmallVector<AffineExpr> shapeAffine;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    shapeAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
+  }
+
+  SmallVector<AffineExpr> stridesAffine = computeStrides(shapeAffine);
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(stridesAffine.size());
+  llvm::transform(stridesAffine, std::back_inserter(strides),
+                  [&builder, &shape](AffineExpr strideExpr) {
+                    return affine::makeComposedFoldedAffineApply(
+                        builder, builder.getLoc(), strideExpr, shape);
+                  });
+
+  auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
+      OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
+  return affine::makeComposedFoldedAffineApply(
+      builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
+}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bf274d4ae27ed8..999cdbc5c10008 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include <numeric>
 
 using namespace mlir;
 
@@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
 }
+
+namespace mlir::arith {
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
+  return createProduct(builder, loc, values, values.front().getType());
+}
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType) {
+  Value one = builder.create<ConstantOp>(loc, resultType,
+                                         builder.getOneAttr(resultType));
+  ArithBuilder arithBuilder(builder, loc);
+  return std::accumulate(
+      values.begin(), values.end(), one,
+      [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
+}
+
+} // namespace mlir::arith
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index a65b8f2e5a2376..762725d2c56e66 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -252,14 +252,20 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                         MeshOp mesh) {
+  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+}
+
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        MeshOp mesh, ArrayRef<MeshAxis> axes) {
   build(odsBuilder, odsState,
-        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
-        mesh.getSymName(),
-        MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
+        SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+                          odsBuilder.getIndexType()),
+        mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                         StringRef mesh, ArrayRef<MeshAxis> axes) {
+  assert(!axes.empty());
   build(odsBuilder, odsState,
         SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -590,6 +596,22 @@ static LogicalResult verifyScatterOperandAndResultShape(
   return success();
 }
 
+static RankedTensorType scatterResultType(Type operandType, MeshOp mesh,
+                                          ArrayRef<MeshAxis> meshAxes,
+                                          int64_t scatterAxis) {
+  RankedTensorType operandRankedTensorType =
+      operandType.cast<RankedTensorType>();
+  DimensionSize operandScatterAxisSize =
+      operandRankedTensorType.getShape()[scatterAxis];
+  SmallVector<int64_t> resultShape =
+      llvm::to_vector(operandRankedTensorType.getShape());
+
+  resultShape[scatterAxis] =
+      operandScatterAxisSize /
+      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+  return operandRankedTensorType.clone(resultShape);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_gather op
 //===----------------------------------------------------------------------===//
@@ -625,6 +647,42 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.all_scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyScatterOperandAndResultShape(
+      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+      mesh.value().getShape());
+}
+
+void AllScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllScatterOp>>(context);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                         Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+                         int64_t scatterAxis) {
+  Type resultType =
+      scatterResultType(input.getType(), mesh, meshAxes, scatterAxis);
+  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+        scatterAxis);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                         Type resultType, Value input, StringRef mesh,
+                         ArrayRef<MeshAxis> meshAxes, int64_t scatterAxis) {
+  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+        APInt(sizeof(scatterAxis) * CHAR_BIT, scatterAxis));
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_to_all op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index dccb75848c94f0..28af820440076c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -13,7 +13,9 @@ add_mlir_dialect_library(MLIRMeshTransforms
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
+  MLIRAffineUtils
   MLIRArithDialect
+  MLIRArithUtils
   MLIRControlFlowDialect
   MLIRFuncDialect
   MLIRIR
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c0273cdaef7144..7fcac2312444f3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "TransformsDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -16,7 +17,6 @@
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include <iterator>
 #include <numeric>
 #include <utility>
 
@@ -56,13 +56,10 @@ namespace {
 // symbol tables.
 // We can't use DialectFoldInterface since the cache may be invalidated by some
 // pass changing the referenced MeshOp ops.
-struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
-  template <typename... OpRewritePatternArgs>
-  MeshShapeFolder(SymbolTableCollection &symbolTableCollection,
-                  OpRewritePatternArgs &&...opRewritePatternArgs)
-      : OpRewritePattern(
-            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
-        symbolTableCollection(symbolTableCollection) {}
+struct MeshShapeFolder
+    : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
   LogicalResult matchAndRewrite(MeshShapeOp op,
                                 PatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -113,9 +110,6 @@ struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
 
     return success();
   }
-
-private:
-  SymbolTableCollection &symbolTableCollection;
 };
 
 } // namespace
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b649157a9e46de..21e03d9572c590 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -128,92 +128,24 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
 }
 
-static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
-                                             int64_t splitTensorAxis,
-                                             int64_t splitCount) {
-  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
-  targetShape[splitTensorAxis] =
-      shardDimension(targetShape[splitTensorAxis], splitCount);
-  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
-}
-
 // Split a replicated tensor along a mesh axis.
 // e.g. [[0, 1]] -> [[0, 1, 2]].
 // Returns the spmdized target value with its sharding.
-//
-// The implementation is the extract the tensor slice corresponding
-// to the current device.
 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
 splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           MeshShardingAttr sourceSharding,
                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  MLIRContext *ctx = builder.getContext();
-  builder.setInsertionPointAfterValue(sourceShard);
-
-  Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
-
-  Value processIndexAlongAxis =
+  TypedValue<ShapedType> targetShard =
       builder
-          .create<ProcessMultiIndexOp>(mesh.getSymName(),
-                                       SmallVector<MeshAxis>({splitMeshAxis}))
-          .getResult()[0];
-
+          .create<AllScatterOp>(sourceShard, mesh,
+                                ArrayRef<MeshAxis>(splitMeshAxis),
+                                splitTensorAxis)
+          .ge...
[truncated]

@sogartar
Copy link
Contributor Author

sogartar commented Feb 9, 2024

I am not sure what is the best name for this operation. Scatter is the inverse of gather. It makes sense that the inverse of all-gather would be all-scatter. But then all-scatter implies that all devices scatter their tensor, when we are actually slicing.
Local split is anther possible name.
If anyone has other suggestions I can change the name.

@sogartar
Copy link
Contributor Author

sogartar commented Feb 9, 2024

@yaochengji, can you review this?
Thank you for suggesting to add this op.

@sogartar
Copy link
Contributor Author

Here is a bit more context from a previous PR
#80518 (comment)

@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
let hasCanonicalizer = 1;
}

def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if it's appropriate to include scatter in the operation's name, as scatter typically implies that the operation involves communication.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. Maybe all_slice or something like that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rename it to all_slice.

[&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
}

} // namespace mlir::arith
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: new line

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added one. I thought the formatter would take care of this.

@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<Value> basis);
// Generate IR that extracts the linear index form a multi-index according to
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: s/form/from/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<Value> basis);
// Generate IR that extracts the linear index form a multi-index according to
// a shape.
OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

linearizeIndex?

Copy link
Contributor Author

@sogartar sogartar Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed it. delinearizeIndex above it uses shape/basis as well. I changed the argument to basis for consistency.
There is another function computeLinearIndex in indexing utils that uses stride instead of basis.

@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
let hasCanonicalizer = 1;
}

def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. Maybe all_slice or something like that?

let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
let description = [{
Scatter along the `scatter_axis` tensor axis.
This operation can be taught of as the inverse of all-gather.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: s/taught/thought/ ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);

void allScatterOpLoweringPopulatePatterns(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

populateAllScatterOpLoweringPatterns to follow the baming convention?

ArrayRef<MeshAxis> meshAxes,
int64_t scatterAxis) {
RankedTensorType operandRankedTensorType =
operandType.cast<RankedTensorType>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This style of casting is deprecated? Prefer to use cast<RankedTensorType>(operandType).

@@ -46,6 +51,7 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
builder.setInsertionPointAfter(op.getOperation());
Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
llvm::errs() << "meshShape.size() = " << meshShape.size() << "\n";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug leftover?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. I removed it.

llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);

// extract slice
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to flesh out the comment here a bit explaining the indeixng logic below--easier for readers to follow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some documentation at the beginning of the function describing the algorithm.

This op is the inverse of all-gather. It is useful to have an explicit
concise representation instead of having a blob of slicing logic.

Add lowering for the op that slices from the tensor based on the in-group
process index.

Make resharding generate an all-slice instead of inserting the slicing logic
directly.
@sogartar
Copy link
Contributor Author

Thank you, I squashed and rebased before merging.

@sogartar sogartar merged commit dc3258c into llvm:main Feb 15, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants