Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Add initial SME vector legalization pass #79152

Merged
merged 5 commits into from
Jan 31, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jan 23, 2024

This adds a new pass (-arm-sme-vector-legalization) which legalizes vector operations so that they can be lowered to ArmSME. This initial patch adds decomposition for vector.outerproduct, vector.transfer_read, and vector.transfer_write when they operate on vector types larger than a single SME tile. For example, a [8]x[8]xf32 outer product would be decomposed into four [4]x[4]xf32 outer products, which could then be lowered to ArmSME. These three ops have been picked as supporting them alone allows lowering matmuls that use all ZA accumulators to ArmSME.

For it to be possible to legalize a vector type it has to be a multiple of an SME tile size, but other than that any shape can be used. E.g. vector<[8]x[8]xf32>, vector<[4]x[16]xf32>, vector<[16]x[4]xf32> can all be lowered to four vector<[4]x[4]xf32> operations.

In future, this pass will be extended with more SME-specific rewrites to legalize unrolling the reduction dimension of matmuls (which is not type-decomposition), which is why the pass has quite a general name.

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 23, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

Changes

This adds a new pass (-arm-sme-vector-legalization) which legalizes vector operations so that they can be lowered to ArmSME. This initial patch adds decomposition for vector.outerproduct, vector.transfer_read, and vector.transfer_write when they operate on vector types larger than a single SME tile. For example, a [8]x[8]xf32 outer product would be decomposed into four [4]x[4]xf32 outer products, which could then be lowered to ArmSME. These three ops have been picked as supporting them alone allows lowering matmuls that use all ZA accumulators to ArmSME.

For it to be possible to legalize a vector type it has to be a multiple of an SME tile size, but other than that any shape can be used. E.g. vector&lt;[8]x[8]xf32&gt;, vector&lt;[4]x[16]xf32&gt;, vector&lt;[16]x[4]xf32&gt; can all be lowered to four vector&lt;[4]x[4]xf32&gt; operations.

In future, this pass will be extended with more SME-specific rewrites to legalize unrolling the reduction dimension of matmuls (which is not type-decomposition), which is why the pass has quite a general name.


Patch is 50.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79152.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+3)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+19)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+7)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+22)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt (+4-1)
  • (added) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+308)
  • (added) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+268)
  • (added) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir (+109)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir (+171)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index aef2959265a7cd..9ba8c43551257b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,6 +32,9 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
 /// Pass that allocates tile IDs to ArmSME operations.
 std::unique_ptr<Pass> createTileAllocationPass();
 
+/// Pass that legalizes vectors so they can be lowered to ArmSME.
+std::unique_ptr<Pass> createVectorLegalizationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8d1ba6ed34e805..9a6f5446de0094 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -122,4 +122,23 @@ def TileAllocation
   let dependentDialects = ["func::FuncDialect"];
 }
 
+def VectorLegalization
+  : Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
+  let summary = "Legalize vectors for ArmSME";
+  let description = [{
+    This pass legalizes vector operations so that they can be lowered to ArmSME.
+    This includes decomposing operations that operate on vector types larger
+    than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
+    tile-sized operations, as well as rewrites needed to get operations into
+    forms compatible with SME lowerings.
+  }];
+  let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
+  let dependentDialects = [
+    "func::FuncDialect",
+    "arm_sme::ArmSMEDialect",
+    "vector::VectorDialect",
+    "arith::ArithDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 41702008ee48fb..e89be8ed81e03f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -56,6 +56,13 @@ scf::ForOp createLoopOverTileSlices(
     PatternRewriter &rewriter, Location loc, Value initTile,
     std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
 
+/// Returns true if `vType` is a multiple of an SME tile size. Note returns
+/// false if the `vType` exactly matches the size of an SME tile.
+bool isMultipleOfSMETileVectorType(VectorType vType);
+
+/// Creates a vector type for the SME tile of `elementType`.
+VectorType getSMETileTypeForElement(Type elementType);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index d0391210555662..6a9e0221822267 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -94,4 +94,26 @@ scf::ForOp createLoopOverTileSlices(
   return forOp;
 }
 
+bool isMultipleOfSMETileVectorType(VectorType vType) {
+  if (vType.getRank() != 2 || !vType.allDimsScalable())
+    return false;
+
+  auto elementType = vType.getElementType();
+  if (!isValidSMETileElementType(elementType))
+    return false;
+
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+
+  int64_t vectorRows = vType.getDimSize(0);
+  int64_t vectorCols = vType.getDimSize(1);
+
+  return (vectorRows > minNumElts || vectorCols > minNumElts) &&
+         vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
+}
+
+VectorType getSMETileTypeForElement(Type elementType) {
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+  return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
+}
+
 } // namespace mlir::arm_sme
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 96eb5844204384..3c32fc2645ce1b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
   TileAllocation.cpp
+  VectorLegalization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -9,10 +10,12 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRArmSMETransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRPass
   MLIRArmSMEDialect
   MLIRFuncDialect
   MLIRLLVMCommonConversion
   MLIRVectorDialect
   MLIRSCFDialect
-  MLIRPass
+  MLIRSCFTransforms
+  MLIRFuncTransforms
   )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
new file mode 100644
index 00000000000000..a801ebe27413d5
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -0,0 +1,308 @@
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+#define DEBUG_TYPE "arm-sme-vector-legalization"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_VECTORLEGALIZATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+
+struct SMETile {
+  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
+  int row{0};
+  int col{0};
+  VectorType type;
+};
+
+/// Adds a constant scalable offset to `indices`. i.e. for 2D:
+/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
+SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
+                                                Location loc,
+                                                ValueRange indices,
+                                                ArrayRef<int> scalableOffset) {
+  auto vscale = builder.create<vector::VectorScaleOp>(loc);
+  return llvm::map_to_vector(
+      llvm::zip_equal(indices, scalableOffset), [&](auto pair) -> Value {
+        auto [index, base] = pair;
+        auto offset = builder.create<arith::MulIOp>(
+            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
+        return builder.create<arith::AddIOp>(loc, index, offset);
+      });
+}
+
+/// Remaps indices (e.g. from a load/store) for a larger vector type to indices
+/// for one of the SME tiles it will decompose into.
+SmallVector<Value, 2> remapIndicesForSMETile(OpBuilder &builder, Location loc,
+                                             ValueRange indices,
+                                             SMETile tileTile) {
+  return addConstantScalableOffset(builder, loc, indices,
+                                   {tileTile.row, tileTile.col});
+}
+
+/// Returns true if `mask` is generated by an operation that can be decomposed
+/// for SME. Currently, that is just no mask, or vector.create_mask.
+bool isSupportedMaskOp(Value mask) {
+  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
+}
+
+/// Extracts a mask for an SME tile from the mask of a larger vector type.
+Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
+                     SMETile tileTile) {
+  assert(isSupportedMaskOp(mask));
+  if (!mask)
+    return Value{};
+  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+  // The the operands of `vector.create_mask` (from a 2D perspective) are the
+  // coordinates where the mask ends. So we subtract where this tile starts,
+  // from the mask operands to get the parameters for this tile tile.
+  auto tileMaskDims = addConstantScalableOffset(
+      builder, loc, createMask.getOperands(), {-tileTile.row, -tileTile.col});
+  auto createTileMask = builder.create<vector::CreateMaskOp>(
+      loc, tileTile.type.clone(builder.getI1Type()), tileMaskDims);
+  return createTileMask.getResult();
+}
+
+/// Constructs an iterator that returns each SME tile (with coordinates)
+/// contained within a VectorType.
+auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
+                         VectorType smeTileType,
+                         bool transposeIndices = false) {
+  assert(isMultipleOfSMETileVectorType(type));
+  return llvm::map_range(
+      StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
+                                              smeTileType.getDimSize(1)}),
+      [=](auto indices) {
+        int row = int(indices[0]);
+        int col = int(indices[1]);
+        if (transposeIndices)
+          std::swap(row, col);
+        return SMETile{row, col, smeTileType};
+      });
+}
+
+/// Returns the number of SME tiles that fit into the a vector type.
+int getNumberOfSMETilesForVectorType(VectorType type) {
+  assert(isMultipleOfSMETileVectorType(type));
+  int64_t vectorRows = type.getDimSize(0);
+  int64_t vectorCols = type.getDimSize(1);
+  auto elementType = type.getElementType();
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
+}
+
+/// Legalize `vector.outerproduct` operations to fit within SME tiles.
+struct LegalizeVectorOuterProductOp
+    : public OneToNOpConversionPattern<vector::OuterProductOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = outerProductOp.getResultVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    Value mask;
+    Operation *rootOp = outerProductOp;
+    auto loc = outerProductOp.getLoc();
+    if (outerProductOp.isMasked()) {
+      auto maskOp = outerProductOp.getMaskingOp();
+      mask = maskOp.getMask();
+      rootOp = maskOp;
+    }
+
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    ValueRange accSMETiles = adaptor.getAcc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    VectorType sliceType = VectorType::Builder(tileType).dropDim(0);
+
+    SmallVector<Value> resultSMETiles;
+    for (auto [index, tileTile] :
+         llvm::enumerate(decomposeToSMETiles(rewriter, vectorType, tileType))) {
+
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto lhs = rewriter.create<vector::ScalableExtractOp>(
+          loc, sliceType, outerProductOp.getLhs(), tileTile.row);
+      auto rhs = rewriter.create<vector::ScalableExtractOp>(
+          loc, sliceType, outerProductOp.getRhs(), tileTile.col);
+      auto tileOuterProduct = rewriter.create<vector::OuterProductOp>(
+          loc, tileType, lhs, rhs,
+          !accSMETiles.empty() ? accSMETiles[index] : Value{},
+          outerProductOp.getKind());
+
+      auto maskedOuterProduct =
+          vector::maskOperation(rewriter, tileOuterProduct, tileMask);
+      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
+    }
+
+    rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+    return success();
+  }
+};
+
+// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
+// get the help of the type conversion), but doing so results in the type
+// conversion adding target materializations in the `vector.mask` region
+// (invalid). This pattern matches on `vector.mask` then calls into the
+// `vector.outerproduct` pattern to work around this issue.
+struct LegalizeMaskedVectorOuterProductOp
+    : public OneToNOpConversionPattern<vector::MaskOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (auto outerProductOp =
+            llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
+      LegalizeVectorOuterProductOp pattern(*getTypeConverter(), getContext());
+      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
+          outerProductOp, rewriter);
+    }
+    return failure();
+  }
+};
+
+/// Legalize `vector.transfer_read` operations to fit within SME tiles.
+struct LegalizeTransferReadOp
+    : public OneToNOpConversionPattern<vector::TransferReadOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = readOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    auto mask = readOp.getMask();
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    auto permutationMap = readOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return failure();
+
+    // Note: For 2D vector types the only non-identity permutation is a simple
+    // tranpose [1, 0].
+    bool transposed = !permutationMap.isIdentity();
+
+    auto loc = readOp.getLoc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+
+    SmallVector<Value> resultSMETiles;
+    for (SMETile tileTile :
+         decomposeToSMETiles(rewriter, vectorType, tileType, transposed)) {
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto transferRead = rewriter.create<vector::TransferReadOp>(
+          loc, tileType, readOp.getSource(),
+          remapIndicesForSMETile(rewriter, loc, readOp.getIndices(), tileTile),
+          readOp.getPermutationMapAttr(), readOp.getPadding(), tileMask,
+          readOp.getInBoundsAttr());
+      resultSMETiles.push_back(transferRead);
+    }
+
+    rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+    return success();
+  }
+};
+
+/// Legalize `vector.transfer_write` operations to fit within SME tiles.
+struct LegalizeTransferWriteOp
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return failure();
+
+    // Note: For 2D vector types the only non-identity permutation is a simple
+    // tranpose [1, 0].
+    bool transposed = !permutationMap.isIdentity();
+
+    auto loc = writeOp.getLoc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto inputSMETiles = adaptor.getVector();
+
+    Value destTensorOrMemref = writeOp.getSource();
+    for (auto [index, tileTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, tileType, transposed))) {
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto tileWrite = rewriter.create<vector::TransferWriteOp>(
+          loc, inputSMETiles[index], destTensorOrMemref,
+          remapIndicesForSMETile(rewriter, loc, writeOp.getIndices(), tileTile),
+          writeOp.getPermutationMapAttr(), tileMask, writeOp.getInBoundsAttr());
+      if (writeOp.hasPureTensorSemantics())
+        destTensorOrMemref = tileWrite.getResult();
+    }
+
+    if (writeOp.hasPureTensorSemantics())
+      rewriter.replaceOp(writeOp, destTensorOrMemref);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
+struct VectorLegalizationPass
+    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+    OneToNTypeConverter converter;
+    RewritePatternSet patterns(context);
+
+    converter.addConversion([](Type type) { return type; });
+    converter.addConversion(
+        [](VectorType vectorType,
+           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
+          if (!isMultipleOfSMETileVectorType(vectorType))
+            return std::nullopt;
+          auto tileTileCount = getNumberOfSMETilesForVectorType(vectorType);
+          auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+          types = SmallVector<Type>(tileTileCount, tileType);
+          return success();
+        });
+
+    // Note: High benefit to ensure masked outer products are lowered first.
+    patterns.add<LegalizeMaskedVectorOuterProductOp>(converter, context, 1024);
+    patterns.add<LegalizeVectorOuterProductOp, LegalizeTransferReadOp,
+                 LegalizeTransferWriteOp>(converter, context);
+    populateFuncTypeConversionPatterns(converter, patterns);
+    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
+  return std::make_unique<VectorLegalizationPass>();
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
new file mode 100644
index 00000000000000..a20abeefedcfd4
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -0,0 +1,268 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @outerproduct_f32_scalable_8x8_no_acc(
+// CHECK-SAME:                                        %[[LHS:.*]]: vector<[8]xf32>,
+// CHECK-SAME:                                        %[[RHS:.*]]: vector<[8]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_8x8_no_acc(%lhs: vector<[8]xf32>, %rhs: vector<[8]xf32>) -> vector<[8]x[8]xf32>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
+  return %0 : vector<[8]x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_scalable_4x16_acc(
+// CHECK-SAME:                                      %[[LHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                                      %[[RHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                      %[[ACC_0:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_1:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_2:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_3:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_4x16_acc(%lhs: vector<[4]xf32>, %rhs: vector<[16]xf32>, %acc: ve...
[truncated]

@MacDue
Copy link
Member Author

MacDue commented Jan 23, 2024

Note: The tests passing depend on #78983.

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h Outdated Show resolved Hide resolved
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
let description = [{
This pass legalizes vector operations so that they can be lowered to ArmSME.
This includes decomposing operations that operate on vector types larger
than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vector<[8]x[8]xf32> matches the size of ZA, right? How about vectors that are larger than this? Or smaller?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That matches the size of ZA. There's infinite bigger vector types (this pass does not care how many tiles the resulting type uses), and a few smaller sizes. There may also be cases when you want sizes smaller than ZA (but more than a tile), i.e. a matmul with a rather small width/height but a large reduction dimension. This pass allows the tiling to be fairly flexible.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right - but are there any limitations? I know that there are - those are effectively enforced by isMultipleOfSMETileVectorType, right? It would be good add a few more comments here, e.g.:

  • "legal" in this context means something that matches SMEs virtual tiles as available in hardware
  • "input" type must be a multiple of these tiles (so definitely 2d and scalable in both dimensions)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

legal - means something that can be lowered to ArmSME, as mentioned this pass will be broader than just type decomposition, and will include rewrites to eliminate illegal types like vector<[8]x4xf32> (which also falls under legalizing vector operations).

(I think the description already mentions that the decomposition breaks down larger vector types into multiple SME-sized operations?)

I'll add the input size limitation 👍

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

legal - means something that can be lowered to ArmSME

Perhaps you could try to defining that? I assume it means either:

  • 2d scalable vectors with sizes matching SME virtual tiles,
  • 1d scalable vectors
  • n-D vectors with at most the trailing dim scalable?

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for patch Ben this is good work! Not got through all of it yet, but left a few comments for now. Will take another pass later

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved

SmallVector<Value> resultSMETiles;
for (auto [index, tileTile] :
llvm::enumerate(decomposeToSMETiles(rewriter, vectorType, tileType))) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems a bit odd to be telling the function doing the legalization what the legal type is, the only benefit I can see is not having to initialise sliceType in the loop?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I follow? decomposeToSMETiles() is passed the tileType which it uses to work out the step of the iterator, and passes it to the SMETile too (which makes some helper functions less verbose).

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Posting these comments now as I need to be AFK for a while.

My general impression is that this is missing a clearly defined input and output formats. In the context of this pass, what is a tile?

  • SME tile that can me mapped onto hardware virtual tile?
  • Linalg tile (as in a generic tile after Linalg tiling)?

What do we get after this decomposition? Also tiles? Perhaps something along the lines:

Maps "generic" 2D scalable tiles (as produced by the Linalg tiling and vectorisation logic) into "SME" 2D tiles (smaller 2D scalable tiles) that match the size of SME virtual tiles (as available in hardware).

Basically, variable names like tileTile are a hint that we need to expand the dictionary a bit :-) Perhaps just use GenericTile and SMESubTile?

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
let description = [{
This pass legalizes vector operations so that they can be lowered to ArmSME.
This includes decomposing operations that operate on vector types larger
than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right - but are there any limitations? I know that there are - those are effectively enforced by isMultipleOfSMETileVectorType, right? It would be good add a few more comments here, e.g.:

  • "legal" in this context means something that matches SMEs virtual tiles as available in hardware
  • "input" type must be a multiple of these tiles (so definitely 2d and scalable in both dimensions)

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
@MacDue
Copy link
Member Author

MacDue commented Jan 26, 2024

Basically, variable names like tileTile are a hint that we need to expand the dictionary a bit :-) Perhaps just use GenericTile and > SMESubTile?

tileTile is just a rename typo, that was always meant to be called subTile or smeTile

(I think I accidentally did s/sub/tile/g)

@MacDue
Copy link
Member Author

MacDue commented Jan 26, 2024

Perhaps just use GenericTile and SMESubTile?

(all uses of the word 'tile' were intended to refer to SME tiles, and the generic case is just a vectorType).

@banach-space
Copy link
Contributor

Perhaps just use GenericTile and SMESubTile?

(all uses of the word 'tile' were intended to refer to SME tiles, and the generic case is just a vectorType).

Got it - I still think it would be good to use stricter names. For somebody focusing on SME, "tile" will normally mean "SME virtual tile". So all good. However, for people more used to Linalg level of abstractions, "tile" is just a slice of the iteration space. I just think that the term is a bit ambiguous and overloaded.

Another option that could help disambiguate:

  • ZATile to represent e.g. vector<[8]x[8]xi32>, and
  • ZASubTile for things like vector<[4]x[4]xi32>.

These would be different for different element types.

@MacDue MacDue force-pushed the arm_sme_vector_legalization branch from 1be013a to 0fc3fbd Compare January 26, 2024 11:15
@MacDue
Copy link
Member Author

MacDue commented Jan 26, 2024

I've added some more context to the header of the pass, that within this context tile always means SME tile (and renamed most things to say smeTile).

Another option that could help disambiguate:

  • ZATile to represent e.g. vector<[8]x[8]xi32>, and
  • ZASubTile for things like vector<[4]x[4]xi32>.

These would be different for different element types.

This pass is not considering ZA, the input vectorType does not have map to the size of ZA, it can be bigger, smaller, or the same size, but that's not the concern of this pass. This pass just wants to convert operations on vectorTypes that are some multiple of SME tiles, into multiple SME-sized operations.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extra comments are super helpful, thank you Ben!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing my comments - LGTM!

@dcaballe
Copy link
Contributor

Great work! I acknowledge the complexity of this PR but I was wondering how much the "vector register splitting" side could be generalized to make it reusable by other targets. Vector register legalization is something that has come up multiple times so having a pass that can be parametrized accordingly would be extremely valuable. What are your thoughts about this? :)

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great work Ben this is very well written! LGTM cheers

mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp Outdated Show resolved Hide resolved
@c-rhodes
Copy link
Collaborator

Great work! I acknowledge the complexity of this PR but I was wondering how much the "vector register splitting" side could be generalized to make it reusable by other targets. Vector register legalization is something that has come up multiple times so having a pass that can be parametrized accordingly would be extremely valuable. What are your thoughts about this? :)

I agree, made a similar observation offline and think we should look into this as a follow up.

@MacDue
Copy link
Member Author

MacDue commented Jan 30, 2024

Great work! I acknowledge the complexity of this PR but I was wondering how much the "vector register splitting" side could be generalized to make it reusable by other targets. Vector register legalization is something that has come up multiple times so having a pass that can be parametrized accordingly would be extremely valuable. What are your thoughts about this? :)

I think the logic here could be generalized (though, as written, I've made assumptions about scalability). I think it'd always be a rank-preserving legalization (i.e. 2D -> 2D, like here, or 1D to 1D, etc). Other transforms like unrolling fixed vectors can already be handled elsewhere (and would be complex to incorporate).

This adds a new pass (`-arm-sme-vector-legalization`) which legalizes
vector operations so that they can be lowered to ArmSME. This initial
patch adds decomposition for `vector.outerproduct`,
`vector.transfer_read`, and `vector.transfer_write` when they operate on
vector types larger than a single SME tile. For example, a [8]x[8]xf32
outer product would be decomposed into four [4]x[4]xf32 outer products,
which could then be lowered to ArmSME. These three ops have been picked
as supporting them alone allows lowering matmuls that use all ZA
accumulators to ArmSME.

For it to be possible to legalize a vector type it has to be a multiple
of an SME tile size, but other than that any shape can be used. E.g.
`vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, `vector<[16]x[4]xf32>`
can all be lowered to four `vector<[4]x[4]xf32>` operations.

In future, this pass will be extended with more SME-specific rewrites to
legalize unrolling the reduction dimension of matmuls (which is not
type-decomposition), which is why the pass has quite a general name.
@MacDue MacDue force-pushed the arm_sme_vector_legalization branch from 0450b4a to 5901316 Compare January 31, 2024 10:49
@MacDue MacDue merged commit 042800a into llvm:main Jan 31, 2024
3 of 4 checks passed
@MacDue MacDue deleted the arm_sme_vector_legalization branch January 31, 2024 11:55
JoelWee added a commit that referenced this pull request Jan 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants