Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Implement Mesh's ShardingInterface for Linalg ops #82284

Merged
merged 6 commits into from
Mar 8, 2024

Conversation

sogartar
Copy link
Contributor

Allows linalg structured operations to be handled during spmdization and sharding propagation.

There is only support for projected permutation indexing maps.

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 19, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Boian Petkantchin (sogartar)

Changes

Allows linalg structured operations to be handled during spmdization and sharding propagation.

There is only support for projected permutation indexing maps.


Patch is 41.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82284.diff

16 Files Affected:

  • (added) mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h (+26)
  • (added) mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h (+20)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+6)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+4)
  • (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h (+18)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h (+6)
  • (modified) mlir/include/mlir/InitAllDialects.h (+2-8)
  • (added) mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp (+24)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+5)
  • (added) mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp (+336)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (-8)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+7)
  • (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+79)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+13)
  • (added) mlir/test/Dialect/Linalg/mesh-spmdization.mlir (+165)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+3)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
new file mode 100644
index 00000000000000..a69751e072b797
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
@@ -0,0 +1,26 @@
+//===- AllInterfaces.h - ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a common entry point for registering all external
+// interface implementations to the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
+#define MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerAllDialectInterfaceImplementations(DialectRegistry &registry);
+} // namespace linalg
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
new file mode 100644
index 00000000000000..c57501ea86b7ed
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index fc2acc70381ef7..9d9b5892e1a51f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -46,6 +46,12 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
   I32EnumAttrCase<"Sum", 1, "sum">,
   I32EnumAttrCase<"Max", 2, "max">,
   I32EnumAttrCase<"Min", 3, "min">,
+  I32EnumAttrCase<"Product", 4, "product">,
+  // Arithmetic mean.
+  I32EnumAttrCase<"Average", 5, "average">,
+  I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">,
+  I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">,
+  I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">,
   I32EnumAttrCase<"Generic", 100, "generic">
 ]> {
   let genSpecializedAttr = 0;
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ba7c111aea6bb..19020c29459821 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -340,6 +340,10 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasCanonicalizer = 1;
+    let builders = [
+    OpBuilder<(ins "Value":$input, "StringRef":$mesh,
+      "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
+  ];
 }
 
 def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index ffc9b6fb18be53..ab4df2ab028d43 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -22,6 +22,24 @@ class SymbolTableCollection;
 
 namespace mesh {
 
+// Retrieve the mesh axes corresponding to each operation loop iterator based
+// on the provided shardings for the op's operands and results.
+// Assumes that the indexingMaps are projected permutations.
+ShardingArray getMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<AffineMap> indexingMaps);
+
+bool isAtLeastOneReductionIteratorSharded(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+
+// Get the set of mesh axes that correspond to reduction loop iterators.
+SmallVector<MeshAxis> getReductionMeshAxes(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+
 // Inserts a clone of the operation that has all ranked tensor
 // arguments/results sharded.
 void spmdizeTriviallyShardableOperation(
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index aeab28961a4e1e..be82e2af399dc8 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -13,6 +13,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
 class RewritePatternSet;
@@ -37,6 +38,11 @@ TypedValue<IndexType>
 createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
                                  ImplicitLocOpBuilder &builder);
 
+// Get process linear index along the given mesh axes.
+TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
+                                               ArrayRef<MeshAxis> meshAxes,
+                                               ImplicitLocOpBuilder &builder);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index e508d51205f347..04fc0f906a8fc4 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -43,10 +43,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
@@ -155,10 +152,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   cf::registerBufferizableOpInterfaceExternalModels(registry);
   cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
-  linalg::registerBufferizableOpInterfaceExternalModels(registry);
-  linalg::registerSubsetOpInterfaceExternalModels(registry);
-  linalg::registerTilingInterfaceExternalModels(registry);
-  linalg::registerValueBoundsOpInterfaceExternalModels(registry);
+  linalg::registerAllDialectInterfaceImplementations(registry);
   memref::registerAllocationOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
new file mode 100644
index 00000000000000..cc9f8d23231ee1
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
@@ -0,0 +1,24 @@
+//===- AllInterfaces.cpp - --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+
+#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+
+void mlir::linalg::registerAllDialectInterfaceImplementations(
+    DialectRegistry &registry) {
+  registerBufferizableOpInterfaceExternalModels(registry);
+  registerMeshShardingInterfaceExternalModels(registry);
+  registerSubsetOpInterfaceExternalModels(registry);
+  registerTilingInterfaceExternalModels(registry);
+  registerValueBoundsOpInterfaceExternalModels(registry);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4f47e3b8718454..513c54de5d7bfc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
+  AllInterfaces.cpp
   BubbleUpExtractSlice.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
@@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   Padding.cpp
   Promotion.cpp
@@ -61,12 +63,15 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRIR
   MLIRMemRefDialect
   MLIRMemRefTransforms
+  MLIRMeshDialect
+  MLIRMeshTransforms
   MLIRLinalgDialect
   MLIRLinalgUtils
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRSCFUtils
   MLIRPass
+  MLIRShardingInterface
   MLIRSubsetOpInterface
   MLIRSparseTensorDialect
   MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
new file mode 100644
index 00000000000000..621885974b2ef3
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -0,0 +1,336 @@
+//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
+#include <optional>
+#include <utility>
+
+namespace mlir::linalg {
+
+using MeshAxis = mesh::MeshAxis;
+using ReductionKind = mesh::ReductionKind;
+using MeshShardingAttr = mesh::MeshShardingAttr;
+using ShardingArray = mesh::ShardingArray;
+using MeshOp = mesh::MeshOp;
+
+static ReductionKind getReductionKind(Operation *op) {
+  return llvm::TypeSwitch<Operation *, ReductionKind>(op)
+      // Floating-point operations.
+      .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
+      .Case([](arith::MulFOp op) { return ReductionKind::Product; })
+      .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
+      // Integer operations.
+      .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
+      .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
+      .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
+      .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
+      .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MulIOp op) { return ReductionKind::Product; })
+      .Default([](Operation *op) { return ReductionKind::Generic; });
+}
+
+static std::optional<Operation *> getReductionOp(LinalgOp op) {
+  SmallVector<Operation *> combinerOps;
+  Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
+  if (!reducedValue || combinerOps.size() != 1) {
+    return std::nullopt;
+  }
+
+  return combinerOps[0];
+}
+
+static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
+  std::optional<Operation *> reductionOp = getReductionOp(op);
+  if (!reductionOp) {
+    return ReductionKind::Generic;
+  }
+  return getReductionKind(reductionOp.value());
+}
+
+static MeshOp getMesh(Operation *op,
+                      ArrayRef<MeshShardingAttr> operandShardings,
+                      ArrayRef<MeshShardingAttr> resultShardings,
+                      SymbolTableCollection &symbolTable) {
+  for (MeshShardingAttr sharding : operandShardings) {
+    if (sharding) {
+      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+    }
+  }
+
+  for (MeshShardingAttr sharding : resultShardings) {
+    if (sharding) {
+      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+    }
+  }
+
+  assert(false);
+}
+
+// Choose the operand based on the current process index along the reduction
+// mesh axes.
+// We need to use the initial value only once to avoid including it in the
+// reduction multiple times.
+// In each process group only the leading process with linear index 0 would use
+// the original operand.
+// The other processes would use the reduction operation neutral tensor.
+static Value createDestinationPassingStyleInitOperand(
+    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
+    MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+  Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
+      meshOp.getSymName(), reductionMeshAxes, builder);
+  Value zero = builder.create<arith::ConstantIndexOp>(0);
+  Value isLeadProcess = builder.create<arith::CmpIOp>(
+      builder.getI1Type(), arith::CmpIPredicate::eq,
+      processLinearIndexInReductionGroup, zero);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
+                                             isLeadProcess, true, true);
+  // Then block.
+  {
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+    builder.create<scf::YieldOp>(spmdizedOperand);
+  }
+
+  // Else block.
+  {
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
+    SmallVector<OpFoldResult> shape =
+        tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
+    PartialReductionOpInterface partialReductionIface =
+        llvm::cast<PartialReductionOpInterface>(op.getOperation());
+    FailureOr<Operation *> reductionNeutralTensorOp =
+        partialReductionIface.generateInitialTensorForPartialReduction(
+            builder, builder.getLoc(), shape, {});
+    assert(succeeded(reductionNeutralTensorOp));
+    builder.create<scf::YieldOp>(
+        reductionNeutralTensorOp.value()->getResult(0));
+  }
+  return ifOp.getResult(0);
+}
+
+// Create the DPS init operands for the spmdized Linalg op.
+// Return all the new spmdized operands.
+static SmallVector<Value> createDestinationPassingStyleInitOperands(
+    LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
+    ImplicitLocOpBuilder &builder) {
+  // TODO: add support for multiple destination passing style initial value
+  // operands.
+  // PartialReductionOpInterface::generateInitialTensorForPartialReduction
+  // needs to also support multiple DPS initial operands.
+  SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
+  auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
+  Value spmdizedInitOperand =
+      spmdizationMap.lookup(op->getOperands()[operandIdx]);
+  newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
+      op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+  return newOperands;
+}
+
+static void createAllReduceForResultWithoutPartialSharding(
+    Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
+    MeshShardingAttr resultSharding, ReductionKind reductionKind,
+    IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
+  SmallVector<MeshAxis> allReduceMeshAxes;
+  llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
+                [&resultSharding](MeshAxis axis) {
+                  return !llvm::is_contained(resultSharding.getPartialAxes(),
+                                             axis);
+                });
+  if (allReduceMeshAxes.empty()) {
+    return;
+  }
+
+  Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
+  Value reducedValue = builder.create<mesh::AllReduceOp>(
+      spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
+      allReduceMeshAxes, reductionKind);
+  spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
+}
+
+static void createAllReduceForResultsWithoutPartialShardings(
+    LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    ImplicitLocOpBuilder &builder) {
+  ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
+  for (auto [unshardedLinalgOpResult, resultSharding] :
+       llvm::zip(unshardedOp->getResults(), resultShardings)) {
+    createAllReduceForResultWithoutPartialSharding(
+        unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
+        reductionKind, spmdizationMap, builder);
+  }
+}
+
+static void spmdizeLinalgOpWithShardedReduction(
+    LinalgOp op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
+    IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
+    ImplicitLocOpBuilder &builder) {
+  MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
+  SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
+      loopIteratorTypes, meshAxisAssignmentForLoopIterators);
+  SmallVector<Value> spmdizedLinalgOpOperands =
+      createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
+                                                reductionMeshAxes,
+                                                spmdizationMap, builder);
+  // We must not change the operand mappings of the original spmdizationMap as
+  // they are the mappings for the whole spmdization blob and may be used by
+  // others.
+  IRMapping internalSpmdizationMap;
+  for (auto...
[truncated]

@sogartar
Copy link
Contributor Author

@yaochengji could you check out this PR?

@sogartar
Copy link
Contributor Author

Let me know if the code is not clear. I will put in more explanation/comments.

@joker-eph
Copy link
Collaborator

You setup the interface as a dynamically registered one, we should implement a "promise" for it, see #78368

@sogartar
Copy link
Contributor Author

@joker-eph, thank you for mentioning this. I can augment this PR if #78368 gets merged in first. If this gets merged first, I will add the promise in another PR.

@joker-eph
Copy link
Collaborator

There is no dependency on #78368 I believe: can't you just add the promise here?

@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Feb 21, 2024
@sogartar
Copy link
Contributor Author

I added the promise. I copy-pasted one function from #78368 to register multiple ops at once.

Copy link
Member

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thanks for working this through! LGTM in general; I have some comments but not blocking.

@@ -0,0 +1,24 @@
//===- AllInterfaces.cpp - --------------------------------------*- C++ -*-===//
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: no need for C++ in the cpp file cause the file suffix is enough to tell. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed it.

@@ -216,6 +216,11 @@ class Dialect {
{TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
}

template <typename InterfaceT, typename... ConcreteT>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an important header. Can we add some documentation for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a comment.

SmallVector<MeshShardingAttr> operatorAndResultShardings;
operatorAndResultShardings.reserve(operandShardings.size() +
resultShardings.size());
operatorAndResultShardings.insert(operatorAndResultShardings.end(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

append_range?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

for (auto [loopIteratorType, meshAxisAssignment] :
llvm::zip(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
if (loopIteratorType == utils::IteratorType::reduction) {
meshAxes.insert(meshAxes.end(), meshAxisAssignment.begin(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

append_range?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

operandShardings.begin(),
operandShardings.end());
for (auto [sharding, affineMap] :
llvm::zip(operatorAndResultShardings, indexingMaps)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use zip_equal to be clear? Similarly for the others in this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

.Case([](arith::AndIOp op) { return ReductionKind::Sum; })
.Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
.Case([](arith::MinUIOp op) { return ReductionKind::Min; })
.Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No differentiation of signed/unsigned cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a comment in the code that it needs addressing.

using MeshShardingAttr = mesh::MeshShardingAttr;
using ShardingArray = mesh::ShardingArray;
using MeshOp = mesh::MeshOp;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// Returns the corresponding mesh reduction kind for the given arith op

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the comment.
Ultimately, here we would like to support any op that can be a reduction inside a Linalg op.

.Default([](Operation *op) { return ReductionKind::Generic; });
}

static std::optional<Operation *> getReductionOp(LinalgOp op) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd call it getCombinerOp to be consistent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

});
if (!allIndexingMapsAreProjectedPermutation) {
// TODO: handle non-projected permutations.
op->emitOpError()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can directly return op->emitOpError here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if (!allIndexingMapsAreProjectedPermutation) {
// TODO: handle non-projected permutations.
op->emitOpError()
<< "Only projected permutation indexing maps are supported.";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: typically error messages start with lower case so it composes well with the prefix. Here it would read "'linag.*' op only projected ..". (Can also adjust the error message to make it read more naturally.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it.

@sogartar
Copy link
Contributor Author

sogartar commented Mar 6, 2024

@antiagainst, thank you for you review, I addressed your comments. The bigger changes that are required I kind of left them with a TODO.

@yaochengji
Copy link
Member

LGTM, thanks @sogartar

@sogartar sogartar force-pushed the linalg-generic-mesh-sharding-interface branch from 1bb6ebf to d8b806a Compare March 7, 2024 15:34
@sogartar
Copy link
Contributor Author

sogartar commented Mar 7, 2024

Thank you for the reviews, I rebased to resolve the conflicts.

@sogartar sogartar merged commit fb582b6 into llvm:main Mar 8, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:linalg mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants