-
Notifications
You must be signed in to change notification settings - Fork 11.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Add support for lowering masked tile_load ops #70915
[mlir][ArmSME] Add support for lowering masked tile_load ops #70915
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
d360236
to
300dc50
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme Author: Cullen Rhodes (c-rhodes) ChangesThis patch extends ArmSMEToSCF to support lowering of masked tile_load There are two lowerings, one for pad of constant zero and another for
The former (constant zero pad) is lowered as follows:
The tile is zeroed to satisfy the padding and only active rows are loaded. The latter (non-zero pad) is lowered as follows:
The scalar pad is broadcast to a 1-D vector and a regular Patch is 23.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70915.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 80da6ffda1ed2ea..46e81bb935c406a 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -80,9 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
if (tileLoadOp.getMask())
- // TODO: add masked patterns.
- return rewriter.notifyMatchFailure(
- tileLoadOp, "op has mask, needs masked pattern(s)");
+ return rewriter.notifyMatchFailure(tileLoadOp,
+ "op has mask, apply masked patterns");
OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();
@@ -142,6 +141,254 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
}
};
+/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+///
+/// BEFORE:
+/// ```mlir
+/// %pad = arith.constant 0 : i32
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+/// memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
+/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
+/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+/// %tile_update = arm_sme.load_tile_slice
+/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
+/// }
+/// ```
+///
+/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
+struct TileLoadOpWithMaskAndPadZeroConversion
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+ PatternRewriter &rewriter) const override {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = tileLoadOp.getLoc();
+ auto tileType = tileLoadOp.getVectorType();
+
+ auto maskOp = tileLoadOp.getMask();
+ if (!maskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has no mask, needs unmasked pattern");
+
+ auto padOp = tileLoadOp.getPadding();
+ assert(padOp && "expected padding when masking!");
+
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+
+ auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+ if (!constPadOp || constPadOp.getValue() !=
+ rewriter.getZeroAttr(tileType.getElementType()))
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
+
+ auto numRows = createMaskOp.getOperands()[0];
+ auto numCols = createMaskOp.getOperands()[1];
+
+ auto predicateType =
+ VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+ auto numColsOp =
+ rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+ // Initialize tile with zero to satisfy padding. Inactive cols will be
+ // zeroed anyway since the loads use zeroing predication. For inactive rows
+ // however, no load will occur so these need to be zeroed.
+ auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+
+ // Create a loop to load the active tile slices from memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto upperBound = numRows;
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+ // tile.
+ SmallVector<Value> memrefIndices;
+ auto tileSliceIndex = forOp.getInductionVar();
+ getMemrefIndices(tileLoadOp.getIndices(),
+ tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+ upperBound, memrefIndices, loc, rewriter);
+ rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
+ tileSliceIndex, tileLoadOp.getLayout());
+
+ rewriter.setInsertionPointAfter(forOp);
+
+ // Replace 'arm_sme.tile_load' with the tile.
+ rewriter.replaceOp(tileLoadOp, tile);
+
+ return success();
+ }
+};
+
+/// Lower `arm_sme.tile_load` with mask and non-zero pad.
+///
+/// BEFORE:
+/// ```mlir
+/// %pad = arith.constant 1 : i32
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+/// memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %tile_id = arm_sme.get_tile_id : i32
+/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+/// %vscale = vector.vscale
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %min_svl_s = arith.constant 4 : index
+/// %svl_s = arith.muli %min_svl_s, %vscale : index
+/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+/// %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
+/// %slice = scf.if %row_is_active -> vector<[4]xi32> {
+/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad
+/// : memref<?x?xi32>, vector<[4]xi1>,
+/// vector<[4]xi32> into vector<[4]xi32>
+/// scf.yield %slice : vector<[4]xi32>
+/// } else {
+/// scf.yield %pad_1d : vector<[4]xi32>
+/// }
+/// // Insert slice into tile
+/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
+/// : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// }
+/// ```
+struct TileLoadOpWithMaskAndPadNonZeroConversion
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+ PatternRewriter &rewriter) const override {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = tileLoadOp.getLoc();
+ auto tileType = tileLoadOp.getVectorType();
+ auto tileElementType = tileType.getElementType();
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+ auto maskOp = tileLoadOp.getMask();
+ if (!maskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has no mask, needs unmasked pattern");
+
+ auto padOp = tileLoadOp.getPadding();
+ assert(padOp && "expected padding when masking!");
+
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+
+ auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+ if (constPadOp &&
+ constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has constant zero pad, needs zero pad pattern");
+
+ auto numRows = createMaskOp.getOperands()[0];
+ auto numCols = createMaskOp.getOperands()[1];
+
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+ auto predicateType =
+ VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+ auto numColsOp =
+ rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+ // Create 'arm_sme.get_tile' op.
+ auto tileId = rewriter.create<arm_sme::GetTileID>(
+ loc, rewriter.getIntegerType(tileElementWidth));
+
+ // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
+ // use as input tile to 'arm_sme.load_tile_slice' ops.
+ auto tile =
+ rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+ // Create a loop that loads each ZA tile slice from memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ auto forOp =
+ rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ auto tileSliceIndex = forOp.getInductionVar();
+
+ auto rowIsActive = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+
+ SmallVector<Value> memrefIndices;
+ getMemrefIndices(tileLoadOp.getIndices(),
+ tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+ numTileSlices, memrefIndices, loc, rewriter);
+
+ // Splat pad into 1-D vector matching type of tile slice.
+ auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+
+ Operation *slice = rewriter.create<scf::IfOp>(
+ loc, rowIsActive,
+ [&](OpBuilder &b, Location loc) {
+ // If the row is active, emit a masked load where the predicate is
+ // 'numCols'. Pad is used for inactive elements, taken from
+ // passthru.
+ auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
+ loc, tileSliceType, tileLoadOp.getBase(), memrefIndices,
+ numColsOp, /*passthru=*/pad1DOp);
+ rewriter.create<scf::YieldOp>(loc, loadSlice->getResult(0));
+ },
+ [&](OpBuilder &b, Location loc) {
+ // Inactive rows are filled with pad.
+ rewriter.create<scf::YieldOp>(loc, pad1DOp.getResult());
+ });
+
+ // TODO: If the load is vertical the transpose can't be done in-flight with
+ // a regular (SVE) maskedload. Propagate layout to
+ // 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This
+ // is currently broken.
+
+ // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
+ rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, slice->getResult(0), tile, tileSliceIndex,
+ tileLoadOp.getLayout());
+
+ rewriter.setInsertionPointAfter(forOp);
+
+ // Replace 'arm_sme.tile_load' with the tile.
+ rewriter.replaceOp(tileLoadOp, tile);
+
+ return success();
+ }
+};
+
/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
/// slice using `arm_sme.store_tile_slice`.
///
@@ -273,7 +520,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
} // namespace
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+ patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+ TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
TileVectorPrintOpConversion>(patterns.getContext());
}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index d61f588941b408c..55ea56f42c96ed9 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -33,6 +33,62 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
return
}
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG: %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant 0 : i32
+ %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>,
+// CHECK-SAME: %[[PAD:.*]]: i32) {
+// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT: %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
+// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK: %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) {
+// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK: scf.yield %[[LOAD_SLICE]] : vector<[4]xi32>
+// CHECK: } else {
+// CHECK: scf.yield %[[PAD_1D]] : vector<[4]xi32>
+// CHECK: }
+// CHECK: arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.tile_store
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
new file mode 100644
index 000000000000000..644f90d950645b8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -0,0 +1,212 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE: -march=aarch64 -mattr=+sve,+sme \
+// DEFINE: -e %{entry_point} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// Vector load.
+func.func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c4 = arith.constant 4 : index
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %A[%base1, %base2], %pad {in_bounds=[true, true]} :
+ memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load + transpose.
+func.func @transfer_read_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %A[%base1, %base2], %pad
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+ : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0 : vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and pad of zero.
+func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant 0.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and pad of zero + transpose.
+func.func @transfer_read_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant 0.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+ : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and non-zero pad.
+func.func @transfer_read_2d_mask_non_zero_pad(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant -42.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and non-zero pad + transpose.
+func.func @transfer_read_2d_mask_non_zero_pad_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant -42.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+ : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Allocate heap memory of size 'd0' x 'd1' and initialize.
+//
+// Example:
+//
+// initialize_memory(%c4, %c5)
+//
+// 0, 1, 2, 3, 4
+// 10, 11, 12, 13, 14
+// 20, 21, 22, 23, 24
+// 30, 31, 32, 33, 34
+//
+// Returns dynamic memref. It's the caller's responsibility to free the returned
+// memref.
+func.func @initialize_memor...
[truncated]
|
@llvm/pr-subscribers-mlir-vector Author: Cullen Rhodes (c-rhodes) ChangesThis patch extends ArmSMEToSCF to support lowering of masked tile_load There are two lowerings, one for pad of constant zero and another for
The former (constant zero pad) is lowered as follows:
The tile is zeroed to satisfy the padding and only active rows are loaded. The latter (non-zero pad) is lowered as follows:
The scalar pad is broadcast to a 1-D vector and a regular Patch is 23.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70915.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 80da6ffda1ed2ea..46e81bb935c406a 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -80,9 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
if (tileLoadOp.getMask())
- // TODO: add masked patterns.
- return rewriter.notifyMatchFailure(
- tileLoadOp, "op has mask, needs masked pattern(s)");
+ return rewriter.notifyMatchFailure(tileLoadOp,
+ "op has mask, apply masked patterns");
OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();
@@ -142,6 +141,254 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
}
};
+/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+///
+/// BEFORE:
+/// ```mlir
+/// %pad = arith.constant 0 : i32
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+/// memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
+/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
+/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+/// %tile_update = arm_sme.load_tile_slice
+/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
+/// }
+/// ```
+///
+/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
+struct TileLoadOpWithMaskAndPadZeroConversion
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+ PatternRewriter &rewriter) const override {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = tileLoadOp.getLoc();
+ auto tileType = tileLoadOp.getVectorType();
+
+ auto maskOp = tileLoadOp.getMask();
+ if (!maskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has no mask, needs unmasked pattern");
+
+ auto padOp = tileLoadOp.getPadding();
+ assert(padOp && "expected padding when masking!");
+
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+
+ auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+ if (!constPadOp || constPadOp.getValue() !=
+ rewriter.getZeroAttr(tileType.getElementType()))
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
+
+ auto numRows = createMaskOp.getOperands()[0];
+ auto numCols = createMaskOp.getOperands()[1];
+
+ auto predicateType =
+ VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+ auto numColsOp =
+ rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+ // Initialize tile with zero to satisfy padding. Inactive cols will be
+ // zeroed anyway since the loads use zeroing predication. For inactive rows
+ // however, no load will occur so these need to be zeroed.
+ auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+
+ // Create a loop to load the active tile slices from memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto upperBound = numRows;
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+ // tile.
+ SmallVector<Value> memrefIndices;
+ auto tileSliceIndex = forOp.getInductionVar();
+ getMemrefIndices(tileLoadOp.getIndices(),
+ tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+ upperBound, memrefIndices, loc, rewriter);
+ rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
+ tileSliceIndex, tileLoadOp.getLayout());
+
+ rewriter.setInsertionPointAfter(forOp);
+
+ // Replace 'arm_sme.tile_load' with the tile.
+ rewriter.replaceOp(tileLoadOp, tile);
+
+ return success();
+ }
+};
+
+/// Lower `arm_sme.tile_load` with mask and non-zero pad.
+///
+/// BEFORE:
+/// ```mlir
+/// %pad = arith.constant 1 : i32
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+/// memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
+/// %num_rows = arith.constant 2 : index
+/// %num_cols = arith.constant 4 : index
+/// %tile_id = arm_sme.get_tile_id : i32
+/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+/// %vscale = vector.vscale
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %min_svl_s = arith.constant 4 : index
+/// %svl_s = arith.muli %min_svl_s, %vscale : index
+/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+/// %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
+/// %slice = scf.if %row_is_active -> vector<[4]xi32> {
+/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad
+/// : memref<?x?xi32>, vector<[4]xi1>,
+/// vector<[4]xi32> into vector<[4]xi32>
+/// scf.yield %slice : vector<[4]xi32>
+/// } else {
+/// scf.yield %pad_1d : vector<[4]xi32>
+/// }
+/// // Insert slice into tile
+/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
+/// : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// }
+/// ```
+struct TileLoadOpWithMaskAndPadNonZeroConversion
+ : public OpRewritePattern<arm_sme::TileLoadOp> {
+ using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+ PatternRewriter &rewriter) const override {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = tileLoadOp.getLoc();
+ auto tileType = tileLoadOp.getVectorType();
+ auto tileElementType = tileType.getElementType();
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+ auto maskOp = tileLoadOp.getMask();
+ if (!maskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has no mask, needs unmasked pattern");
+
+ auto padOp = tileLoadOp.getPadding();
+ assert(padOp && "expected padding when masking!");
+
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+
+ auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+ if (constPadOp &&
+ constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has constant zero pad, needs zero pad pattern");
+
+ auto numRows = createMaskOp.getOperands()[0];
+ auto numCols = createMaskOp.getOperands()[1];
+
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+ auto predicateType =
+ VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+ auto numColsOp =
+ rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+ // Create 'arm_sme.get_tile' op.
+ auto tileId = rewriter.create<arm_sme::GetTileID>(
+ loc, rewriter.getIntegerType(tileElementWidth));
+
+ // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
+ // use as input tile to 'arm_sme.load_tile_slice' ops.
+ auto tile =
+ rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+ // Create a loop that loads each ZA tile slice from memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ auto forOp =
+ rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ auto tileSliceIndex = forOp.getInductionVar();
+
+ auto rowIsActive = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+
+ SmallVector<Value> memrefIndices;
+ getMemrefIndices(tileLoadOp.getIndices(),
+ tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+ numTileSlices, memrefIndices, loc, rewriter);
+
+ // Splat pad into 1-D vector matching type of tile slice.
+ auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+
+ Operation *slice = rewriter.create<scf::IfOp>(
+ loc, rowIsActive,
+ [&](OpBuilder &b, Location loc) {
+ // If the row is active, emit a masked load where the predicate is
+ // 'numCols'. Pad is used for inactive elements, taken from
+ // passthru.
+ auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
+ loc, tileSliceType, tileLoadOp.getBase(), memrefIndices,
+ numColsOp, /*passthru=*/pad1DOp);
+ rewriter.create<scf::YieldOp>(loc, loadSlice->getResult(0));
+ },
+ [&](OpBuilder &b, Location loc) {
+ // Inactive rows are filled with pad.
+ rewriter.create<scf::YieldOp>(loc, pad1DOp.getResult());
+ });
+
+ // TODO: If the load is vertical the transpose can't be done in-flight with
+ // a regular (SVE) maskedload. Propagate layout to
+ // 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This
+ // is currently broken.
+
+ // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
+ rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, slice->getResult(0), tile, tileSliceIndex,
+ tileLoadOp.getLayout());
+
+ rewriter.setInsertionPointAfter(forOp);
+
+ // Replace 'arm_sme.tile_load' with the tile.
+ rewriter.replaceOp(tileLoadOp, tile);
+
+ return success();
+ }
+};
+
/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
/// slice using `arm_sme.store_tile_slice`.
///
@@ -273,7 +520,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
} // namespace
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+ patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+ TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
TileVectorPrintOpConversion>(patterns.getContext());
}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index d61f588941b408c..55ea56f42c96ed9 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -33,6 +33,62 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
return
}
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG: %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant 0 : i32
+ %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>,
+// CHECK-SAME: %[[PAD:.*]]: i32) {
+// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT: %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
+// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK: %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) {
+// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK: scf.yield %[[LOAD_SLICE]] : vector<[4]xi32>
+// CHECK: } else {
+// CHECK: scf.yield %[[PAD_1D]] : vector<[4]xi32>
+// CHECK: }
+// CHECK: arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.tile_store
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
new file mode 100644
index 000000000000000..644f90d950645b8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -0,0 +1,212 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE: -march=aarch64 -mattr=+sve,+sme \
+// DEFINE: -e %{entry_point} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// Vector load.
+func.func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c4 = arith.constant 4 : index
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %A[%base1, %base2], %pad {in_bounds=[true, true]} :
+ memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load + transpose.
+func.func @transfer_read_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %A[%base1, %base2], %pad
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+ : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0 : vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and pad of zero.
+func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant 0.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and pad of zero + transpose.
+func.func @transfer_read_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant 0.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+ : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and non-zero pad.
+func.func @transfer_read_2d_mask_non_zero_pad(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant -42.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Vector load with mask and non-zero pad + transpose.
+func.func @transfer_read_2d_mask_non_zero_pad_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %pad = arith.constant -42.0 : f32
+ %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+ : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+ vector.print str "TILE BEGIN:"
+ vector.print %0: vector<[4]x[4]xf32>
+
+ return
+}
+
+// Allocate heap memory of size 'd0' x 'd1' and initialize.
+//
+// Example:
+//
+// initialize_memory(%c4, %c5)
+//
+// 0, 1, 2, 3, 4
+// 10, 11, 12, 13, 14
+// 20, 21, 22, 23, 24
+// 30, 31, 32, 33, 34
+//
+// Returns dynamic memref. It's the caller's responsibility to free the returned
+// memref.
+func.func @initialize_memor...
[truncated]
|
Could we avoid the if statement by anding the if condition with the mask that always execute the masked load? |
I think and'ing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, my comments/suggestions are nits, thanks!
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
Outdated
Show resolved
Hide resolved
c805d05
to
676d299
Compare
This patch extends ArmSMEToSCF to support lowering of masked tile_load ops. Only masks created by 'vector.create_mask' are currently supported. There are two lowerings, one for pad of constant zero and another for non-zero pad. For the following example: %pad = arith.constant 0 : i32 %num_rows = arith.constant 2 : index %num_cols = arith.constant 4 : index %mask = vector.create_mask %num_rows, %num_cols : <[4]x[4]xi1> %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32> The former (constant zero pad) is lowered as follows: --------------------------------------------------------- %tile = arm_sme.zero : vector<[4]x[4]xi32> %num_cols = vector.create_mask %c4 : vector<[4]xi1> scf.for %slice_idx = %c0 to %num_rows step %c1 %tile_update = arm_sme.load_tile_slice %src[%slice_idx], %num_cols, %tile, %tile_slice_idx : memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32> The tile is zeroed to satisfy the padding and only active rows are loaded. The latter (non-zero pad) is lowered as follows: ------------------------------------------------ scf.for %slice_idx = %c0 to %num_tile_slices step %c1 { %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index %slice = scf.if %row_is_active -> vector<[4]xf32> { %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d : memref<?x?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> scf.yield %slice : vector<[4]xf32> } else { scf.yield %pad_1d : vector<[4]xf32> } arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx : vector<[4]xi32> into vector<[4]x[4]xi32> The scalar pad is broadcast to a 1-D vector and a regular 'vector.masked_load' (will be lowered to SVE, not SME) loads each slice for active rows, with padding specified as a passthru. For non-active rows the slice is the 1-D pad. The resulting slice is inserted into the tile with 'arm_sme.move_vector_to_tile_slice'.
accidentally pulled in a compiler-rt docs change I had on main when rebasing, will remove. |
676d299
to
9a5cdf0
Compare
This patch extends ArmSMEToSCF to support lowering of masked tile_load
ops. Only masks created by 'vector.create_mask' are currently supported.
There are two lowerings, one for pad of constant zero and another for
non-zero pad. For the following example:
The former (constant zero pad) is lowered as follows:
The tile is zeroed to satisfy the padding and only active rows are
loaded.
The latter (non-zero pad) is lowered as follows:
The scalar pad is broadcast to a 1-D vector and a regular
'vector.masked_load' (will be lowered to SVE, not SME) loads each slice
for active rows, with padding specified as a passthru. For non-active
rows the slice is the 1-D pad. The resulting slice is inserted into the
tile with 'arm_sme.move_vector_to_tile_slice'.