-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Add initial SME vector legalization pass #79152
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme Author: Benjamin Maxwell (MacDue) ChangesThis adds a new pass ( For it to be possible to legalize a vector type it has to be a multiple of an SME tile size, but other than that any shape can be used. E.g. In future, this pass will be extended with more SME-specific rewrites to legalize unrolling the reduction dimension of matmuls (which is not type-decomposition), which is why the pass has quite a general name. Patch is 50.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79152.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index aef2959265a7cd..9ba8c43551257b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,6 +32,9 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
/// Pass that allocates tile IDs to ArmSME operations.
std::unique_ptr<Pass> createTileAllocationPass();
+/// Pass that legalizes vectors so they can be lowered to ArmSME.
+std::unique_ptr<Pass> createVectorLegalizationPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8d1ba6ed34e805..9a6f5446de0094 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -122,4 +122,23 @@ def TileAllocation
let dependentDialects = ["func::FuncDialect"];
}
+def VectorLegalization
+ : Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
+ let summary = "Legalize vectors for ArmSME";
+ let description = [{
+ This pass legalizes vector operations so that they can be lowered to ArmSME.
+ This includes decomposing operations that operate on vector types larger
+ than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
+ tile-sized operations, as well as rewrites needed to get operations into
+ forms compatible with SME lowerings.
+ }];
+ let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
+ let dependentDialects = [
+ "func::FuncDialect",
+ "arm_sme::ArmSMEDialect",
+ "vector::VectorDialect",
+ "arith::ArithDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 41702008ee48fb..e89be8ed81e03f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -56,6 +56,13 @@ scf::ForOp createLoopOverTileSlices(
PatternRewriter &rewriter, Location loc, Value initTile,
std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
+/// Returns true if `vType` is a multiple of an SME tile size. Note returns
+/// false if the `vType` exactly matches the size of an SME tile.
+bool isMultipleOfSMETileVectorType(VectorType vType);
+
+/// Creates a vector type for the SME tile of `elementType`.
+VectorType getSMETileTypeForElement(Type elementType);
+
} // namespace mlir::arm_sme
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index d0391210555662..6a9e0221822267 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -94,4 +94,26 @@ scf::ForOp createLoopOverTileSlices(
return forOp;
}
+bool isMultipleOfSMETileVectorType(VectorType vType) {
+ if (vType.getRank() != 2 || !vType.allDimsScalable())
+ return false;
+
+ auto elementType = vType.getElementType();
+ if (!isValidSMETileElementType(elementType))
+ return false;
+
+ unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+
+ int64_t vectorRows = vType.getDimSize(0);
+ int64_t vectorCols = vType.getDimSize(1);
+
+ return (vectorRows > minNumElts || vectorCols > minNumElts) &&
+ vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
+}
+
+VectorType getSMETileTypeForElement(Type elementType) {
+ unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+ return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
+}
+
} // namespace mlir::arm_sme
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 96eb5844204384..3c32fc2645ce1b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArmSMETransforms
EnableArmStreaming.cpp
TileAllocation.cpp
+ VectorLegalization.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -9,10 +10,12 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRArmSMETransformsIncGen
LINK_LIBS PUBLIC
+ MLIRPass
MLIRArmSMEDialect
MLIRFuncDialect
MLIRLLVMCommonConversion
MLIRVectorDialect
MLIRSCFDialect
- MLIRPass
+ MLIRSCFTransforms
+ MLIRFuncTransforms
)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
new file mode 100644
index 00000000000000..a801ebe27413d5
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -0,0 +1,308 @@
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+#define DEBUG_TYPE "arm-sme-vector-legalization"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_VECTORLEGALIZATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+
+struct SMETile {
+ // Note: The units of (row, col) are vscale (as SME tiles are scalable).
+ int row{0};
+ int col{0};
+ VectorType type;
+};
+
+/// Adds a constant scalable offset to `indices`. i.e. for 2D:
+/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
+SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
+ Location loc,
+ ValueRange indices,
+ ArrayRef<int> scalableOffset) {
+ auto vscale = builder.create<vector::VectorScaleOp>(loc);
+ return llvm::map_to_vector(
+ llvm::zip_equal(indices, scalableOffset), [&](auto pair) -> Value {
+ auto [index, base] = pair;
+ auto offset = builder.create<arith::MulIOp>(
+ loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
+ return builder.create<arith::AddIOp>(loc, index, offset);
+ });
+}
+
+/// Remaps indices (e.g. from a load/store) for a larger vector type to indices
+/// for one of the SME tiles it will decompose into.
+SmallVector<Value, 2> remapIndicesForSMETile(OpBuilder &builder, Location loc,
+ ValueRange indices,
+ SMETile tileTile) {
+ return addConstantScalableOffset(builder, loc, indices,
+ {tileTile.row, tileTile.col});
+}
+
+/// Returns true if `mask` is generated by an operation that can be decomposed
+/// for SME. Currently, that is just no mask, or vector.create_mask.
+bool isSupportedMaskOp(Value mask) {
+ return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
+}
+
+/// Extracts a mask for an SME tile from the mask of a larger vector type.
+Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
+ SMETile tileTile) {
+ assert(isSupportedMaskOp(mask));
+ if (!mask)
+ return Value{};
+ auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+ // The the operands of `vector.create_mask` (from a 2D perspective) are the
+ // coordinates where the mask ends. So we subtract where this tile starts,
+ // from the mask operands to get the parameters for this tile tile.
+ auto tileMaskDims = addConstantScalableOffset(
+ builder, loc, createMask.getOperands(), {-tileTile.row, -tileTile.col});
+ auto createTileMask = builder.create<vector::CreateMaskOp>(
+ loc, tileTile.type.clone(builder.getI1Type()), tileMaskDims);
+ return createTileMask.getResult();
+}
+
+/// Constructs an iterator that returns each SME tile (with coordinates)
+/// contained within a VectorType.
+auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
+ VectorType smeTileType,
+ bool transposeIndices = false) {
+ assert(isMultipleOfSMETileVectorType(type));
+ return llvm::map_range(
+ StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
+ smeTileType.getDimSize(1)}),
+ [=](auto indices) {
+ int row = int(indices[0]);
+ int col = int(indices[1]);
+ if (transposeIndices)
+ std::swap(row, col);
+ return SMETile{row, col, smeTileType};
+ });
+}
+
+/// Returns the number of SME tiles that fit into the a vector type.
+int getNumberOfSMETilesForVectorType(VectorType type) {
+ assert(isMultipleOfSMETileVectorType(type));
+ int64_t vectorRows = type.getDimSize(0);
+ int64_t vectorCols = type.getDimSize(1);
+ auto elementType = type.getElementType();
+ unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+ return (vectorRows * vectorCols) / (minNumElts * minNumElts);
+}
+
+/// Legalize `vector.outerproduct` operations to fit within SME tiles.
+struct LegalizeVectorOuterProductOp
+ : public OneToNOpConversionPattern<vector::OuterProductOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ auto vectorType = outerProductOp.getResultVectorType();
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return failure();
+
+ Value mask;
+ Operation *rootOp = outerProductOp;
+ auto loc = outerProductOp.getLoc();
+ if (outerProductOp.isMasked()) {
+ auto maskOp = outerProductOp.getMaskingOp();
+ mask = maskOp.getMask();
+ rootOp = maskOp;
+ }
+
+ if (!isSupportedMaskOp(mask))
+ return failure();
+
+ ValueRange accSMETiles = adaptor.getAcc();
+ auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+ VectorType sliceType = VectorType::Builder(tileType).dropDim(0);
+
+ SmallVector<Value> resultSMETiles;
+ for (auto [index, tileTile] :
+ llvm::enumerate(decomposeToSMETiles(rewriter, vectorType, tileType))) {
+
+ auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+ auto lhs = rewriter.create<vector::ScalableExtractOp>(
+ loc, sliceType, outerProductOp.getLhs(), tileTile.row);
+ auto rhs = rewriter.create<vector::ScalableExtractOp>(
+ loc, sliceType, outerProductOp.getRhs(), tileTile.col);
+ auto tileOuterProduct = rewriter.create<vector::OuterProductOp>(
+ loc, tileType, lhs, rhs,
+ !accSMETiles.empty() ? accSMETiles[index] : Value{},
+ outerProductOp.getKind());
+
+ auto maskedOuterProduct =
+ vector::maskOperation(rewriter, tileOuterProduct, tileMask);
+ resultSMETiles.push_back(maskedOuterProduct->getResult(0));
+ }
+
+ rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+ return success();
+ }
+};
+
+// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
+// get the help of the type conversion), but doing so results in the type
+// conversion adding target materializations in the `vector.mask` region
+// (invalid). This pattern matches on `vector.mask` then calls into the
+// `vector.outerproduct` pattern to work around this issue.
+struct LegalizeMaskedVectorOuterProductOp
+ : public OneToNOpConversionPattern<vector::MaskOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ if (auto outerProductOp =
+ llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
+ LegalizeVectorOuterProductOp pattern(*getTypeConverter(), getContext());
+ return static_cast<RewritePattern &>(pattern).matchAndRewrite(
+ outerProductOp, rewriter);
+ }
+ return failure();
+ }
+};
+
+/// Legalize `vector.transfer_read` operations to fit within SME tiles.
+struct LegalizeTransferReadOp
+ : public OneToNOpConversionPattern<vector::TransferReadOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ auto vectorType = readOp.getVectorType();
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return failure();
+
+ auto mask = readOp.getMask();
+ if (!isSupportedMaskOp(mask))
+ return failure();
+
+ auto permutationMap = readOp.getPermutationMap();
+ if (!permutationMap.isPermutation())
+ return failure();
+
+ // Note: For 2D vector types the only non-identity permutation is a simple
+ // tranpose [1, 0].
+ bool transposed = !permutationMap.isIdentity();
+
+ auto loc = readOp.getLoc();
+ auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+
+ SmallVector<Value> resultSMETiles;
+ for (SMETile tileTile :
+ decomposeToSMETiles(rewriter, vectorType, tileType, transposed)) {
+ auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+ auto transferRead = rewriter.create<vector::TransferReadOp>(
+ loc, tileType, readOp.getSource(),
+ remapIndicesForSMETile(rewriter, loc, readOp.getIndices(), tileTile),
+ readOp.getPermutationMapAttr(), readOp.getPadding(), tileMask,
+ readOp.getInBoundsAttr());
+ resultSMETiles.push_back(transferRead);
+ }
+
+ rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+ return success();
+ }
+};
+
+/// Legalize `vector.transfer_write` operations to fit within SME tiles.
+struct LegalizeTransferWriteOp
+ : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ auto vectorType = writeOp.getVectorType();
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return failure();
+
+ auto mask = writeOp.getMask();
+ if (!isSupportedMaskOp(mask))
+ return failure();
+
+ auto permutationMap = writeOp.getPermutationMap();
+ if (!permutationMap.isPermutation())
+ return failure();
+
+ // Note: For 2D vector types the only non-identity permutation is a simple
+ // tranpose [1, 0].
+ bool transposed = !permutationMap.isIdentity();
+
+ auto loc = writeOp.getLoc();
+ auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+ auto inputSMETiles = adaptor.getVector();
+
+ Value destTensorOrMemref = writeOp.getSource();
+ for (auto [index, tileTile] : llvm::enumerate(
+ decomposeToSMETiles(rewriter, vectorType, tileType, transposed))) {
+ auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+ auto tileWrite = rewriter.create<vector::TransferWriteOp>(
+ loc, inputSMETiles[index], destTensorOrMemref,
+ remapIndicesForSMETile(rewriter, loc, writeOp.getIndices(), tileTile),
+ writeOp.getPermutationMapAttr(), tileMask, writeOp.getInBoundsAttr());
+ if (writeOp.hasPureTensorSemantics())
+ destTensorOrMemref = tileWrite.getResult();
+ }
+
+ if (writeOp.hasPureTensorSemantics())
+ rewriter.replaceOp(writeOp, destTensorOrMemref);
+ else
+ rewriter.eraseOp(writeOp);
+
+ return success();
+ }
+};
+
+struct VectorLegalizationPass
+ : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
+ void runOnOperation() override {
+ auto *context = &getContext();
+ OneToNTypeConverter converter;
+ RewritePatternSet patterns(context);
+
+ converter.addConversion([](Type type) { return type; });
+ converter.addConversion(
+ [](VectorType vectorType,
+ SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return std::nullopt;
+ auto tileTileCount = getNumberOfSMETilesForVectorType(vectorType);
+ auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+ types = SmallVector<Type>(tileTileCount, tileType);
+ return success();
+ });
+
+ // Note: High benefit to ensure masked outer products are lowered first.
+ patterns.add<LegalizeMaskedVectorOuterProductOp>(converter, context, 1024);
+ patterns.add<LegalizeVectorOuterProductOp, LegalizeTransferReadOp,
+ LegalizeTransferWriteOp>(converter, context);
+ populateFuncTypeConversionPatterns(converter, patterns);
+ scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
+
+ if (failed(applyPartialOneToNConversion(getOperation(), converter,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
+ return std::make_unique<VectorLegalizationPass>();
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
new file mode 100644
index 00000000000000..a20abeefedcfd4
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -0,0 +1,268 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @outerproduct_f32_scalable_8x8_no_acc(
+// CHECK-SAME: %[[LHS:.*]]: vector<[8]xf32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<[8]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_8x8_no_acc(%lhs: vector<[8]xf32>, %rhs: vector<[8]xf32>) -> vector<[8]x[8]xf32>
+{
+ // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+ // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+ // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+ // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+ // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+ %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
+ return %0 : vector<[8]x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_scalable_4x16_acc(
+// CHECK-SAME: %[[LHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME: %[[ACC_0:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[ACC_1:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[ACC_2:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[ACC_3:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_4x16_acc(%lhs: vector<[4]xf32>, %rhs: vector<[16]xf32>, %acc: ve...
[truncated]
|
Note: The tests passing depend on #78983. |
let description = [{ | ||
This pass legalizes vector operations so that they can be lowered to ArmSME. | ||
This includes decomposing operations that operate on vector types larger | ||
than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vector<[8]x[8]xf32>
matches the size of ZA
, right? How about vectors that are larger than this? Or smaller?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That matches the size of ZA. There's infinite bigger vector types (this pass does not care how many tiles the resulting type uses), and a few smaller sizes. There may also be cases when you want sizes smaller than ZA (but more than a tile), i.e. a matmul with a rather small width/height but a large reduction dimension. This pass allows the tiling to be fairly flexible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right - but are there any limitations? I know that there are - those are effectively enforced by isMultipleOfSMETileVectorType
, right? It would be good add a few more comments here, e.g.:
- "legal" in this context means something that matches SMEs virtual tiles as available in hardware
- "input" type must be a multiple of these tiles (so definitely 2d and scalable in both dimensions)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
legal
- means something that can be lowered to ArmSME, as mentioned this pass will be broader than just type decomposition, and will include rewrites to eliminate illegal types like vector<[8]x4xf32>
(which also falls under legalizing vector operations).
(I think the description already mentions that the decomposition breaks down larger vector types into multiple SME-sized operations?)
I'll add the input size limitation 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
legal - means something that can be lowered to ArmSME
Perhaps you could try to defining that? I assume it means either:
- 2d scalable vectors with sizes matching SME virtual tiles,
- 1d scalable vectors
- n-D vectors with at most the trailing dim scalable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for patch Ben this is good work! Not got through all of it yet, but left a few comments for now. Will take another pass later
|
||
SmallVector<Value> resultSMETiles; | ||
for (auto [index, tileTile] : | ||
llvm::enumerate(decomposeToSMETiles(rewriter, vectorType, tileType))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems a bit odd to be telling the function doing the legalization what the legal type is, the only benefit I can see is not having to initialise sliceType
in the loop?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I follow? decomposeToSMETiles()
is passed the tileType
which it uses to work out the step of the iterator, and passes it to the SMETile
too (which makes some helper functions less verbose).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Posting these comments now as I need to be AFK for a while.
My general impression is that this is missing a clearly defined input and output formats. In the context of this pass, what is a tile?
- SME tile that can me mapped onto hardware virtual tile?
- Linalg tile (as in a generic tile after Linalg tiling)?
What do we get after this decomposition? Also tiles? Perhaps something along the lines:
Maps "generic" 2D scalable tiles (as produced by the Linalg tiling and vectorisation logic) into "SME" 2D tiles (smaller 2D scalable tiles) that match the size of SME virtual tiles (as available in hardware).
Basically, variable names like tileTile
are a hint that we need to expand the dictionary a bit :-) Perhaps just use GenericTile
and SMESubTile
?
let description = [{ | ||
This pass legalizes vector operations so that they can be lowered to ArmSME. | ||
This includes decomposing operations that operate on vector types larger | ||
than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right - but are there any limitations? I know that there are - those are effectively enforced by isMultipleOfSMETileVectorType
, right? It would be good add a few more comments here, e.g.:
- "legal" in this context means something that matches SMEs virtual tiles as available in hardware
- "input" type must be a multiple of these tiles (so definitely 2d and scalable in both dimensions)
(I think I accidentally did |
(all uses of the word 'tile' were intended to refer to SME tiles, and the generic case is just a |
Got it - I still think it would be good to use stricter names. For somebody focusing on SME, "tile" will normally mean "SME virtual tile". So all good. However, for people more used to Linalg level of abstractions, "tile" is just a slice of the iteration space. I just think that the term is a bit ambiguous and overloaded. Another option that could help disambiguate:
These would be different for different element types. |
1be013a
to
0fc3fbd
Compare
I've added some more context to the header of the pass, that within this context
This pass is not considering ZA, the input |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The extra comments are super helpful, thank you Ben!
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing my comments - LGTM!
Great work! I acknowledge the complexity of this PR but I was wondering how much the "vector register splitting" side could be generalized to make it reusable by other targets. Vector register legalization is something that has come up multiple times so having a pass that can be parametrized accordingly would be extremely valuable. What are your thoughts about this? :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great work Ben this is very well written! LGTM cheers
I agree, made a similar observation offline and think we should look into this as a follow up. |
I think the logic here could be generalized (though, as written, I've made assumptions about scalability). I think it'd always be a rank-preserving legalization (i.e. 2D -> 2D, like here, or 1D to 1D, etc). Other transforms like unrolling fixed vectors can already be handled elsewhere (and would be complex to incorporate). |
This adds a new pass (`-arm-sme-vector-legalization`) which legalizes vector operations so that they can be lowered to ArmSME. This initial patch adds decomposition for `vector.outerproduct`, `vector.transfer_read`, and `vector.transfer_write` when they operate on vector types larger than a single SME tile. For example, a [8]x[8]xf32 outer product would be decomposed into four [4]x[4]xf32 outer products, which could then be lowered to ArmSME. These three ops have been picked as supporting them alone allows lowering matmuls that use all ZA accumulators to ArmSME. For it to be possible to legalize a vector type it has to be a multiple of an SME tile size, but other than that any shape can be used. E.g. `vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, `vector<[16]x[4]xf32>` can all be lowered to four `vector<[4]x[4]xf32>` operations. In future, this pass will be extended with more SME-specific rewrites to legalize unrolling the reduction dimension of matmuls (which is not type-decomposition), which is why the pass has quite a general name.
0450b4a
to
5901316
Compare
This adds a new pass (
-arm-sme-vector-legalization
) which legalizes vector operations so that they can be lowered to ArmSME. This initial patch adds decomposition forvector.outerproduct
,vector.transfer_read
, andvector.transfer_write
when they operate on vector types larger than a single SME tile. For example, a [8]x[8]xf32 outer product would be decomposed into four [4]x[4]xf32 outer products, which could then be lowered to ArmSME. These three ops have been picked as supporting them alone allows lowering matmuls that use all ZA accumulators to ArmSME.For it to be possible to legalize a vector type it has to be a multiple of an SME tile size, but other than that any shape can be used. E.g.
vector<[8]x[8]xf32>
,vector<[4]x[16]xf32>
,vector<[16]x[4]xf32>
can all be lowered to fourvector<[4]x[4]xf32>
operations.In future, this pass will be extended with more SME-specific rewrites to legalize unrolling the reduction dimension of matmuls (which is not type-decomposition), which is why the pass has quite a general name.