
[mlir][ArmSME] Add support for lowering masked tile_load ops #70915

Merged
c-rhodes merged 4 commits into llvm:main from mlir-arm-sme-masked-tile-load-lowering on Nov 8, 2023

Conversation

@c-rhodes (Collaborator) commented Nov 1, 2023

This patch extends ArmSMEToSCF to support lowering of masked tile_load
ops. Only masks created by 'vector.create_mask' are currently supported.

There are two lowerings, one for a pad of constant zero and another for
a non-zero pad. For the following example:

  %pad = arith.constant 0 : i32
  %num_rows = arith.constant 2 : index
  %num_cols = arith.constant 4 : index
  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>,
                                                          vector<[4]x[4]xi32>

The former (constant zero pad) is lowered as follows:

  %tile = arm_sme.zero : vector<[4]x[4]xi32>
  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1
    %tile_update = arm_sme.load_tile_slice
      %src[%slice_idx], %num_cols, %tile, %slice_idx :
      memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>

The tile is zeroed to satisfy the padding and only active rows are
loaded.

The latter (non-zero pad) is lowered as follows:

  scf.for %slice_idx = %c0 to %num_tile_slices step %c1 {
    %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index
    %mask = // sext %row_is_active, then 'and' it with %num_cols
    %slice = vector.maskedload %src[%slice_idx, %c0], %mask, %pad_1d :
        memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
    arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx
      : vector<[4]xi32> into vector<[4]x[4]xi32>
  }

The scalar pad is broadcast to a 1-D vector and a regular
'vector.maskedload' (which lowers to SVE, not SME) loads each slice
for active rows, with the padding specified as a passthru. For inactive
rows the slice is the 1-D pad. The resulting slice is inserted into the
tile with 'arm_sme.move_vector_to_tile_slice'.
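
The '%mask' placeholder above can be materialized without any branching.
A minimal sketch (the '%num_cols_mask' and '%row_mask' names are
illustrative, not taken from the patch):

  // Build the 1-D column mask, broadcast the scalar i1 row predicate,
  // and AND them together; inactive rows yield an all-false mask.
  %num_cols_mask = vector.create_mask %num_cols : vector<[4]xi1>
  %row_mask = vector.splat %row_is_active : vector<[4]xi1>
  %mask = arith.andi %row_mask, %num_cols_mask : vector<[4]xi1>

With this combined mask the masked load can be emitted unconditionally:
for inactive rows every lane is masked off, so the result is the
'%pad_1d' passthru, matching the padding semantics.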

@c-rhodes (Collaborator, Author) commented Nov 1, 2023

This depends on #70814 and is the last patch from #69148 for masked loads.

github-actions bot commented Nov 1, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot (Collaborator) commented Nov 2, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Cullen Rhodes (c-rhodes)

Changes

This patch extends ArmSMEToSCF to support lowering of masked tile_load
ops. Only masks created by 'vector.create_mask' are currently supported.

There are two lowerings, one for a pad of constant zero and another for
a non-zero pad. For the following example:

  %pad = arith.constant 0 : i32
  %num_rows = arith.constant 2 : index
  %num_cols = arith.constant 4 : index
  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>,
                                                          vector<[4]x[4]xi32>

The former (constant zero pad) is lowered as follows:

  %tile = arm_sme.zero : vector<[4]x[4]xi32>
  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1
    %tile_update = arm_sme.load_tile_slice
      %src[%slice_idx], %num_cols, %tile, %slice_idx :
      memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>

The tile is zeroed to satisfy the padding and only active rows are
loaded.

The latter (non-zero pad) is lowered as follows:

  scf.for %slice_idx = %c0 to %num_tile_slices step %c1 {
    %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index
    %slice = scf.if %row_is_active -> vector<[4]xi32> {
      %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d :
        memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
      scf.yield %slice : vector<[4]xi32>
    } else {
      scf.yield %pad_1d : vector<[4]xi32>
    }
    arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx
      : vector<[4]xi32> into vector<[4]x[4]xi32>
  }

The scalar pad is broadcast to a 1-D vector and a regular
'vector.maskedload' (which lowers to SVE, not SME) loads each slice
for active rows, with the padding specified as a passthru. For inactive
rows the slice is the 1-D pad. The resulting slice is inserted into the
tile with 'arm_sme.move_vector_to_tile_slice'.


Patch is 23.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70915.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+252-4)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+56)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir (+212)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 80da6ffda1ed2ea..46e81bb935c406a 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -80,9 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
   LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
                                 PatternRewriter &rewriter) const override {
     if (tileLoadOp.getMask())
-      // TODO: add masked patterns.
-      return rewriter.notifyMatchFailure(
-          tileLoadOp, "op has mask, needs masked pattern(s)");
+      return rewriter.notifyMatchFailure(tileLoadOp,
+                                         "op has mask, apply masked patterns");
 
     OpBuilder::InsertionGuard g(rewriter);
     auto loc = tileLoadOp.getLoc();
@@ -142,6 +141,254 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
   }
 };
 
+/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+///
+///  BEFORE:
+///  ```mlir
+///  %pad = arith.constant 0 : i32
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+///  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+///    memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %tile = arm_sme.zero : vector<[4]x[4]xi32>
+///  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
+///  scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+///    %tile_update = arm_sme.load_tile_slice
+///      %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+///      memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+///  }
+///  ```
+///
+/// NOTE: Only masks created by a 'vector.create_mask' op are currently
+/// supported.
+struct TileLoadOpWithMaskAndPadZeroConversion
+    : public OpRewritePattern<arm_sme::TileLoadOp> {
+  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+
+    auto maskOp = tileLoadOp.getMask();
+    if (!maskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has no mask, needs unmasked pattern");
+
+    auto padOp = tileLoadOp.getPadding();
+    assert(padOp && "expected padding when masking!");
+
+    auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+                      "currently supported");
+
+    auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+    if (!constPadOp || constPadOp.getValue() !=
+                           rewriter.getZeroAttr(tileType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
+
+    auto numRows = createMaskOp.getOperands()[0];
+    auto numCols = createMaskOp.getOperands()[1];
+
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto numColsOp =
+        rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+    // Initialize tile with zero to satisfy padding. Inactive cols will be
+    // zeroed anyway since the loads use zeroing predication. For inactive rows
+    // however, no load will occur so these need to be zeroed.
+    auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+
+    // Create a loop to load the active tile slices from memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = numRows;
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+    // tile.
+    SmallVector<Value> memrefIndices;
+    auto tileSliceIndex = forOp.getInductionVar();
+    getMemrefIndices(tileLoadOp.getIndices(),
+                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+                     upperBound, memrefIndices, loc, rewriter);
+    rewriter.create<arm_sme::LoadTileSliceOp>(
+        loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
+        tileSliceIndex, tileLoadOp.getLayout());
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Replace 'arm_sme.tile_load' with the tile.
+    rewriter.replaceOp(tileLoadOp, tile);
+
+    return success();
+  }
+};
+
+/// Lower `arm_sme.tile_load` with mask and non-zero pad.
+///
+///  BEFORE:
+///  ```mlir
+///  %pad = arith.constant 1 : i32
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+///  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+///    memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %pad_1d = arith.constant dense<1> : vector<[4]xi32>
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %tile_id = arm_sme.get_tile_id : i32
+///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+///  %vscale = vector.vscale
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %min_svl_s = arith.constant 4 : index
+///  %svl_s = arith.muli %min_svl_s, %vscale : index
+///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///    %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
+///    %slice = scf.if %row_is_active -> vector<[4]xi32> {
+///      %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad
+///        : memref<?x?xi32>, vector<[4]xi1>,
+///          vector<[4]xi32> into vector<[4]xi32>
+///      scf.yield %slice : vector<[4]xi32>
+///    } else {
+///      scf.yield %pad_1d : vector<[4]xi32>
+///    }
+///    // Insert slice into tile
+///    arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
+///      : vector<[4]xi32> into vector<[4]x[4]xi32>
+///  }
+///  ```
+struct TileLoadOpWithMaskAndPadNonZeroConversion
+    : public OpRewritePattern<arm_sme::TileLoadOp> {
+  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+    auto maskOp = tileLoadOp.getMask();
+    if (!maskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has no mask, needs unmasked pattern");
+
+    auto padOp = tileLoadOp.getPadding();
+    assert(padOp && "expected padding when masking!");
+
+    auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+                      "currently supported");
+
+    auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+    if (constPadOp &&
+        constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has constant zero pad, needs zero pad pattern");
+
+    auto numRows = createMaskOp.getOperands()[0];
+    auto numCols = createMaskOp.getOperands()[1];
+
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto numColsOp =
+        rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+    // Create 'arm_sme.get_tile_id' op.
+    auto tileId = rewriter.create<arm_sme::GetTileID>(
+        loc, rewriter.getIntegerType(tileElementWidth));
+
+    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
+    // use as input tile to 'arm_sme.load_tile_slice' ops.
+    auto tile =
+        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+    // Create a loop that loads each ZA tile slice from memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+    auto vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+    auto forOp =
+        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    auto tileSliceIndex = forOp.getInductionVar();
+
+    auto rowIsActive = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+
+    SmallVector<Value> memrefIndices;
+    getMemrefIndices(tileLoadOp.getIndices(),
+                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+                     numTileSlices, memrefIndices, loc, rewriter);
+
+    // Splat pad into 1-D vector matching type of tile slice.
+    auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+
+    Operation *slice = rewriter.create<scf::IfOp>(
+        loc, rowIsActive,
+        [&](OpBuilder &b, Location loc) {
+          // If the row is active, emit a masked load where the predicate is
+          // 'numCols'. Pad is used for inactive elements, taken from
+          // passthru.
+          auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
+              loc, tileSliceType, tileLoadOp.getBase(), memrefIndices,
+              numColsOp, /*passthru=*/pad1DOp);
+          rewriter.create<scf::YieldOp>(loc, loadSlice->getResult(0));
+        },
+        [&](OpBuilder &b, Location loc) {
+          // Inactive rows are filled with pad.
+          rewriter.create<scf::YieldOp>(loc, pad1DOp.getResult());
+        });
+
+    // TODO: If the load is vertical the transpose can't be done in-flight with
+    // a regular (SVE) maskedload. Propagate layout to
+    // 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This
+    // is currently broken.
+
+    // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
+    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+        loc, tileType, slice->getResult(0), tile, tileSliceIndex,
+        tileLoadOp.getLayout());
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Replace 'arm_sme.tile_load' with the tile.
+    rewriter.replaceOp(tileLoadOp, tile);
+
+    return success();
+  }
+};
+
 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
 /// slice using `arm_sme.store_tile_slice`.
 ///
@@ -273,7 +520,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
 } // namespace
 
 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+               TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
                TileVectorPrintOpConversion>(patterns.getContext());
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index d61f588941b408c..55ea56f42c96ed9 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -33,6 +33,62 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
+// CHECK-SAME:                                                          %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG:     %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0 : i32
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
+// CHECK-SAME:                                                             %[[SRC:.*]]: memref<?x?xi32>,
+// CHECK-SAME:                                                             %[[PAD:.*]]: i32) {
+// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:        %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
+// CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK:             %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) {
+// CHECK:               %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK:               scf.yield %[[LOAD_SLICE]] : vector<[4]xi32>
+// CHECK:             } else {
+// CHECK:               scf.yield %[[PAD_1D]] : vector<[4]xi32>
+// CHECK:             }
+// CHECK:             arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.tile_store
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
new file mode 100644
index 000000000000000..644f90d950645b8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -0,0 +1,212 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:  -e %{entry_point} -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// Vector load.
+func.func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c4 = arith.constant 4 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %A[%base1, %base2], %pad {in_bounds=[true, true]} :
+    memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load + transpose.
+func.func @transfer_read_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %A[%base1, %base2], %pad
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0 : vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and pad of zero.
+func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and pad of zero + transpose.
+func.func @transfer_read_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and non-zero pad.
+func.func @transfer_read_2d_mask_non_zero_pad(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant -42.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and non-zero pad + transpose.
+func.func @transfer_read_2d_mask_non_zero_pad_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant -42.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Allocate heap memory of size 'd0' x 'd1' and initialize.
+//
+// Example:
+//
+// initialize_memory(%c4, %c5)
+//
+//    0,  1,  2,  3,  4
+//   10, 11, 12, 13, 14
+//   20, 21, 22, 23, 24
+//   30, 31, 32, 33, 34
+//
+// Returns dynamic memref. It's the caller's responsibility to free the returned
+// memref.
+func.func @initialize_memor...
[truncated]

@llvmbot (Collaborator) commented Nov 2, 2023

@llvm/pr-subscribers-mlir-vector


@dcaballe (Contributor) commented Nov 2, 2023

    %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index
    %slice = scf.if %row_is_active -> vector<[4]xi32> {
      %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d :
        memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
      scf.yield %slice : vector<[4]xi32>
    } else {
      scf.yield %pad_1d : vector<[4]xi32>
    }

Could we avoid the if statement by and'ing the if condition with the mask, so that we always execute the masked load?

@c-rhodes (Collaborator, Author) commented Nov 2, 2023

> Could we avoid the if statement by and'ing the if condition with the mask, so that we always execute the masked load?

I think and'ing %num_cols with a sign-extension of %row_is_active should work. Good spot, will update this, cheers.
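
For reference, a rough sketch of the loop body with this suggestion applied,
where %num_cols is the 1-D column mask from the lowering above (the
sign-extend is expressed as a vector splat of the i1 condition; the
'%row_mask' and '%combined_mask' names and the i32 types are illustrative,
not from the final patch):

  %row_is_active = arith.cmpi ult, %slice_idx, %num_rows : index
  %row_mask = vector.splat %row_is_active : vector<[4]xi1>
  %combined_mask = arith.andi %row_mask, %num_cols : vector<[4]xi1>
  %slice = vector.maskedload %src[%slice_idx, %c0], %combined_mask, %pad_1d :
      memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>

For inactive rows the combined mask is all-false, so %slice is %pad_1d from
the passthru, which is exactly what the scf.if's else branch yielded.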

@dcaballe (Contributor) left a comment:

Thanks!

@banach-space (Contributor) left a comment:

LGTM, my comments/suggestions are nits, thanks!

@c-rhodes force-pushed the mlir-arm-sme-masked-tile-load-lowering branch from c805d05 to 676d299 on November 7, 2023 15:45
@c-rhodes (Collaborator, Author) commented Nov 7, 2023

Accidentally pulled in a compiler-rt docs change I had on main when rebasing; will remove.

@c-rhodes force-pushed the mlir-arm-sme-masked-tile-load-lowering branch from 676d299 to 9a5cdf0 on November 7, 2023 15:47
@c-rhodes merged commit 9783cf4 into llvm:main on Nov 8, 2023
3 checks passed
@c-rhodes deleted the mlir-arm-sme-masked-tile-load-lowering branch on November 8, 2023 09:02