Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][mesh] Add spmdization pass #80518

Merged
merged 1 commit into from
Feb 7, 2024
Merged

Conversation

sogartar
Copy link
Contributor

@sogartar sogartar commented Feb 3, 2024

Add a pass that converts a function that has sharding annotations into SPMD form.

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 3, 2024

@llvm/pr-subscribers-mlir-func
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Boian Petkantchin (sogartar)

Changes

Add a pass that converts a function that has sharding annotations into SPMD form.


Patch is 49.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80518.diff

17 Files Affected:

  • (added) mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h (+23)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+43)
  • (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+10-1)
  • (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td (+46-1)
  • (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h (+125)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+75)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h (+20-4)
  • (modified) mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp (+2)
  • (modified) mlir/lib/Dialect/Func/Extensions/CMakeLists.txt (+15)
  • (added) mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp (+24)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+32)
  • (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+59-6)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+1-5)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+241-50)
  • (modified) mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (+2-28)
  • (added) mlir/test/Dialect/Mesh/spmdization.mlir (+131)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+6-2)
diff --git a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
new file mode 100644
index 0000000000000..9b7abbca5d762
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FUNC_IR_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_FUNC_IR_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace func {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace func
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FUNC_IR_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 8e5e0f541ba5e..a819157fef147 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,11 +10,13 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/MathExtras.h"
 
 namespace mlir {
 namespace mesh {
@@ -48,6 +50,11 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 Partial getPartialTypeFromReduction(IteratorType iType);
 
+// Is the same tensor replicated on all processes.
+inline bool isFullReplication(MeshShardingAttr attr) {
+  return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
+}
+
 inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
                             SymbolTableCollection &symbolTableCollection) {
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
@@ -60,6 +67,13 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
   return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
 }
 
+template <>
+inline mesh::MeshOp
+getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
+  return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
+                 symbolTableCollection);
+}
+
 // Get the number of processes that participate in each group
 // induced by `meshAxes`.
 template <typename MeshAxesRange, typename MeshShapeRange>
@@ -78,6 +92,35 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }
 
+inline int64_t shardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  assert(dim % shardCount == 0);
+  return ceilDiv(dim, shardCount);
+}
+
+inline int64_t unshardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  return dim * shardCount;
+}
+
+// Return the sharded shape `shape` acording ot sharding `sharding`.
+// The shape for the tensor on each device in the mesh.
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
+// result in a shape for each shard of ?x2x?.
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
+                           MeshShardingAttr sharding);
+
+// If ranked tensor type return its sharded counterpart.
+//
+// If not ranked tensor type return `type`.
+// `sharding` in that case must be null.
+Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 3bef7e6babdec..cc90ddd40a622 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -10,11 +10,14 @@
 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
 
 class Operation;
+class IRMapping;
+class SymbolTableCollection;
 
 namespace mesh {
 
@@ -58,8 +61,14 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
 
 } // namespace detail
 
-} // namespace mesh
+// Assumes full replication on all ranked tensor arguments and results.
+void spmdizeFullyReplicatedOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder);
 
+} // namespace mesh
 } // namespace mlir
 
 /// Include the ODS generated interface header files.
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 21b6c8d4f599a..4afb1c36a72f7 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -88,7 +88,52 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
           return detail::defaultAddShardingAnnotations(
             $_op.getOperation(), b, shardingOption);
         }]
-      >
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Convert self to SPMD form.
+          This method is used during the spmdization pass of a program fully
+          annotated with shardings.
+
+          The spmdization algorithm would read the surrounding sharding
+          annotations from the IR for each argument/result and prepare
+          `operandShardings` and `resultShardings`.
+          Values that are not ranked tensors do not have sharding annotations.
+          In this case their corresponding MeshShardingAttr is null.
+
+          For convenience it will also prepare `spmdizedOperands`, although
+          they can be retrieved from the `spmdizationMap`.
+
+          The `spmdizationMap` contains a mapping from unsharded to
+          sharded/spmdized values that are constructed during the spmdization
+          pass. The interface implementation must populate `spmdizationMap`
+          with the mapping for this op's results.
+
+          `builder` is set to insert new operations in the appropriate point.
+          The implementation should not return the builder to the original
+          insertion point.
+          It should leave it as is after all insertions are done.
+
+          The default implementation does full replication.
+          This assumes that all sharding annotations are for full replication.
+        }],
+        /*retTy=*/"LogicalResult",
+        /*methodName=*/"spmdize",
+        /*args=*/(ins
+          "ArrayRef<Value>": $spmdizedOperands,
+          "ArrayRef<MeshShardingAttr>": $operandShardings,
+          "ArrayRef<MeshShardingAttr>": $resultShardings,
+          "IRMapping&": $spmdizationMap,
+          "SymbolTableCollection &": $symbolTableCollection,
+          "OpBuilder &":$builder
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          spmdizeFullyReplicatedOperation(
+            *$_op.getOperation(), spmdizedOperands, operandShardings,
+              resultShardings, spmdizationMap, symbolTableCollection, builder);
+          return success();
+        }]>
     ];
 
     let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
new file mode 100644
index 0000000000000..8108386c2e043
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -0,0 +1,125 @@
+//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir {
+
+class Operation;
+class IRMapping;
+class SymbolTableCollection;
+
+namespace mesh {
+
+// Inserts a clone of the operation that has all ranked tensor
+// arguments/results sharded.
+void spmdizeTriviallyShardableOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder);
+
+// All ranked tensor argument and result dimensions have
+// independent parallel loop iterators.
+template <typename Op>
+struct IndependentParallelIteratorDomainShardingInterface
+    : public ShardingInterface::ExternalModel<
+          IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
+    SmallVector<IteratorType> iterTypes;
+    for (Type t : operation->getOperandTypes()) {
+      populateIteratorTypes(t, iterTypes);
+    }
+    for (Type t : operation->getResultTypes()) {
+      populateIteratorTypes(t, iterTypes);
+    }
+    return iterTypes;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    // TODO: implement.
+    return SmallVector<AffineMap>();
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshShardingAttr> operandShardings,
+                        ArrayRef<MeshShardingAttr> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+                                       resultShardings, spmdizationMap,
+                                       symbolTable, builder);
+    return success();
+  }
+
+private:
+  void populateIteratorTypes(Type t,
+                             SmallVector<IteratorType> &iterTypes) const {
+    RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
+    if (!rankedTensorType) {
+      return;
+    }
+
+    iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
+    for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
+      iterTypes.push_back(IteratorType::Parallel);
+    }
+  }
+};
+
+// Sharding of elementwise operations like tensor addition and multiplication.
+template <typename ElemwiseOp>
+struct ElementwiseShardingInterface
+    : public ShardingInterface::ExternalModel<
+          ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+    Value val = op->getOperand(0);
+    auto type = val.getType().dyn_cast<RankedTensorType>();
+    if (!type)
+      return {};
+    SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
+    return types;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    MLIRContext *ctx = op->getContext();
+    Value val = op->getOperand(0);
+    auto type = val.getType().dyn_cast<RankedTensorType>();
+    if (!type)
+      return {};
+    int64_t rank = type.getRank();
+    int64_t num = op->getNumOperands() + op->getNumResults();
+    SmallVector<AffineMap> maps(num,
+                                AffineMap::getMultiDimIdentityMap(rank, ctx));
+    return maps;
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshShardingAttr> operandShardings,
+                        ArrayRef<MeshShardingAttr> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+                                       resultShardings, spmdizationMap,
+                                       symbolTable, builder);
+    return success();
+  }
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index c09cf3e710d42..3758b617bc14e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -29,4 +29,79 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
   ];
 }
 
+def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+  let summary = "Partition a function into SPMD form.";
+  let description = [{
+    This pass fits in right after a pass that annotates the function with
+    shardings like the `ShardingPropagation` pass.
+    It operates on a fully annotated IR.
+
+    A fully annotated IR required that all ranked tensor operands, results and
+    block arguments are annotated with the `mesh.shard` operation.
+  
+    All direct descendant operations in the function must implement the
+    `ShardingInterface` interface or all their ranked tensor operands and
+    results must have full replication sharding.
+
+    The input IR must have sharding annotations such that each operation
+    that implements `ShardingInterface` can handle during spmdization with
+    its `spmdize` method.
+    This can be achieved with the `ShardingPropagation` pass.
+
+    If the function has multiple terminating blocks,
+    it is the responsibility of the the one who annotates the function with
+    shardings to make sure that all returns would be consisted that is,
+    have the same sharding.
+
+    Example:
+    ```mlir
+    mesh.mesh @mesh_1d(shape = 2)
+
+    func.func @f(
+      %arg0: tensor<2xi8>
+    // CHECK-SAME: -> tensor<2xi8> {
+    ) -> tensor<2xi8> {
+      %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
+      %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+      %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+      %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+      %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+      return %4 : tensor<2xi8>
+    }
+    ```
+    Spmdizing the above would result in 
+    * resharding the fully replicated input into splitting it along the only
+    tensor axis.
+    * Performing the element-wise `abs` operation on each device.
+    * Resharding back to full replication with an all-gather.
+
+    ```mlir
+    mesh.mesh @mesh_1d(shape = 2)
+    func.func @f(%arg0: tensor<2xi8>) -> tensor<2xi8> {
+      // Reshard [[]] -> [[0]]
+      %c0 = arith.constant 0 : index
+      %c2 = arith.constant 2 : index
+      %0 = mesh.process_multi_index on @mesh_1d axes = [0] : index
+      %1 = mesh.mesh_shape @mesh_1d axes = [0] : index
+      %2 = arith.remui %c2, %1 : index
+      %3 = arith.cmpi eq, %2, %c0 : index
+      cf.assert %3, "Sharding a tensor with axis size that is not exactly divisible by the mesh axis size is not supported."
+      %4 = arith.divui %c2, %1 : index
+      %5 = arith.muli %4, %0 : index
+      %extracted_slice = tensor.extract_slice %arg0[%5] [1] [1] : tensor<2xi8> to tensor<1xi8>
+
+      %6 = tosa.abs %extracted_slice : (tensor<1xi8>) -> tensor<1xi8>
+
+      // Reshard [[0]] -> [[]]
+      %7 = mesh.all_gather %6 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+
+      return %7 : tensor<2xi8>
+    }
+    ```
+  }];
+  let dependentDialects = [
+    "mesh::MeshDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
index 7cb992aac019b..f847ce30a1b40 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -16,16 +16,32 @@
 namespace mlir {
 namespace mesh {
 
-// Return the sharded shape `shape` acording ot sharding `sharding`.
-ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
-                           MeshShardingAttr sharding);
-
 // Insert resharding spmdization of the value `sourceShardValue`
 // from sharding `source` to sharding `target`.
 // `sourceShardValue` is the already sharded value according to `source`.
+//
+// Example
+//
+// ```mlir
+//   mesh.mesh @mesh_1d(shape = 2)
+//   ...
+//   %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+//   %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+// ```
+//
+// Will result in
+//
+// ```mlir
+//   %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 :
+//     tensor<1xi8> -> tensor<2xi8>
+// ```
 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                                ShardOp target,
                                TypedValue<ShapedType> sourceShardValue);
+TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
+                               ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue,
+                               SymbolTableCollection &symbolTableCollection);
 
 void reshardingRegisterDependentDialects(DialectRegistry &registry);
 
diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
index b3a3a37663a21..eb6b59bb00f1b 100644
--- a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
@@ -8,9 +8,11 @@
 
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 #include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
+#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
 
 using namespace mlir;
 
 void mlir::func::registerAllExtensions(DialectRegistry &registry) {
   registerInlinerExtension(registry);
+  registerShardingInterfaceExternalModels(registry);
 }
diff --git a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
index 8712b1583719b..47363f48f95cc 100644
--- a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
@@ -1,6 +1,7 @@
 set(LLVM_OPTIONAL_SOURCES
   AllExtensions.cpp
   InlinerExtension.cpp
+  MeshShardingExtensions.cpp
   )
 
 add_mlir_extension_library(MLIRFuncInlinerExtension
@@ -16,6 +17,19 @@ add_mlir_extension_library(MLIRFuncInlinerExtension
   MLIRFuncDialect
   )
 
+add_mlir_extension_library(MLIRFuncMeshShardingExtensions
+  MeshShardingExtensions.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRIR
+  MLIRShardingInterface
+  )
+
+
 add_mlir_extension_library(MLIRFuncAllExtensions
   AllExtensions.cpp
 
@@ -24,4 +38,5 @@ add_mlir_extension_library(MLIRFuncAllExtensions
 
   LINK_LIBS PUBLIC
   MLIRFuncInlinerExtension
+  MLIRFuncMeshShardingExtensions
   )
diff --git a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
new file mode 100644
index 0000000000000..da508cc95bfe1
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
@@ -0,0 +1,24 @@
+//===- MeshShardingExtensions.cpp - ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-------------------------------...
[truncated]

@sogartar
Copy link
Contributor Author

sogartar commented Feb 3, 2024

@yaochengji, could review this PR?

return ceilDiv(dim, shardCount);
}

inline int64_t unshardDimension(int64_t dim, int64_t shardCount) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure unshard is a suitable name or not. Or call it gatherDimension?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed it.

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h Outdated Show resolved Hide resolved

```mlir
mesh.mesh @mesh_1d(shape = 2)
func.func @f(%arg0: tensor<2xi8>) -> tensor<2xi8> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be useful that the argument and result could be sharded optionally.

Copy link
Contributor Author

@sogartar sogartar Feb 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this mean the sharding special case of full replication? This is covered by the pass.
I taught that the passes before that are responsible for producing a complete sharding annotations, so that this pass does not need to deal with these special cases.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I meant the function type could change.

E.g. the shard of the argument tensor<2xi8> of a func is set to <@mesh_1d, [[0]]>, then after spmdization this argument could be of type tensor<1xi8>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean. I changed the example.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't see the option to control whether we should change the func signature. In some real cases every device has the full tensor but only use part of it.

I think a similar option as https://github.com/openxla/xla/blob/main/xla/service/spmd/spmd_partitioner.h#L66C8-L66C37 would be useful, and we can add it in the later PR.

@@ -0,0 +1,131 @@
// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add spmdiazation/reshard of a partial tensor?

E.g. a partial tensor resharded to a sharded tensor, then a reduce-scatter would be inserted.

Copy link
Contributor Author

@sogartar sogartar Feb 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is already a test that reshards from partial axis -> full replication at

The one you are proposing currently produces an unoptimal result.

mesh.mesh @mesh_1d(shape = 2)

func.func @partial_axis_to_split_axis(
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
  %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}

results in

mesh.mesh @mesh_1d(shape = 2)
func.func @partial_axis_to_split_axis(%arg0: tensor<10x14xf32>) -> tensor<5x14xf32> {
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
  %1 = mesh.process_multi_index on @mesh_1d axes = [0] : index
  %2 = mesh.mesh_shape @mesh_1d axes = [0] : index
  %3 = arith.remui %c10, %2 : index
  %4 = arith.cmpi eq, %3, %c0 : index
  cf.assert %4, "Sharding a tensor with axis size that is not exactly divisible by the mesh axis size is not supported."
  %5 = arith.divui %c10, %2 : index
  %6 = arith.muli %5, %1 : index
  %extracted_slice = tensor.extract_slice %0[%6, 0] [5, 14] [1, 1] : tensor<10x14xf32> to tensor<5x14xf32>
  return %extracted_slice : tensor<5x14xf32>
}

It first all-reduces to get to full replication than slices to split the axis.
There are 2 approaches

  1. Match this pattern and optimize it to a reduce-scatter. This would benefit if there are other cases where we get to this result. Unfortunately, then it will be brittle with respect to future changes of the full-to-split slicing.
  2. Detect the resharding pattern at the level of sharding annotations and produce the optimal reduce-scatter at the first place.

I would like to add this optimization in another PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to add some canonicalization pattern on the mesh dialect later.

E.g. mesh.all_reduce + mesh.local_split might be optimized to mesh.reduce_scatter

Copy link
Contributor Author

@sogartar sogartar Feb 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mesh.local_split would be a useful op to abstract the above blob of ops. I will add it in subsequent PR and make resharding produce it instead. Then have another rewrite pattern to lower it to the slicing logic.

return success();
}

struct Spmdization : public impl::SpmdizationBase<Spmdization> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just drive by comment: please wrap the class in an anonymous namespace.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, fixed it.

Copy link
Member

@yaochengji yaochengji left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

Add a pass that converts a function that has sharding annotations into SPMD
form.
@sogartar
Copy link
Contributor Author

sogartar commented Feb 7, 2024

Thank you for reviewing. I squashed and rebased before merging.

@sogartar sogartar merged commit adbf21f into llvm:main Feb 7, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants