Skip to content

Commit

Permalink
[mlir][ArmSME] Add tile slice layout attr to vector <-> tile ops (#69186
Browse files Browse the repository at this point in the history
)

This is used in #69148 when lowering masked tile_store with non-zero
pad, see #69148

This updates:
 * `arm_sme.move_vector_to_tile_slice`
 * `arm_sme.move_tile_slice_to_vector`
  • Loading branch information
c-rhodes committed Oct 25, 2023
1 parent d9cfb82 commit 2f055dd
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 27 deletions.
36 changes: 22 additions & 14 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -441,21 +441,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
of a 2-D scalable vector tile at the given index. The type of the 1-D
scalable vector to be moved must match the type of the tile slice. A tile
slice is a 1-D vector of horizontally or vertically contiguous elements
within a ZA tile. Horizontal tile slices are currently assumed when
lowering to intrinsics. The updated tile is returned as the result.
within a ZA tile. The updated tile is returned as the result.

Example 1: Move a vector<[16]xi8> into tile at given index.
An optional tile slice layout attribute specifies whether the tile slice is
horizontal (default) or vertical.

Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
```

Example 2: Move a vector<[2]xf64> into tile at given index.
Example 2: Move a vector<[2]xf64> into tile vertically at given index.
```mlir
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
```
}];
let arguments = (ins
SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout);
let results = (outs SMETile:$result);

let extraClassDeclaration = [{
Expand All @@ -465,7 +468,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}];

let assemblyFormat = [{
$vector `,` $tile `,` $tile_slice_index
$vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($vector) `into` type($result)
}];
}
Expand All @@ -480,29 +483,34 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
let description = [{
The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
scalable tile at the given index. A tile slice is a 1-D vector of
horizontally or vertically contiguous elements within a ZA tile. Horizontal
tile slices are currently assumed when lowering to intrinsics.
horizontally or vertically contiguous elements within a ZA tile.

An optional tile slice layout attribute specifies whether the tile slice is
horizontal (default) or vertical.

Example 1: Extract `vector<[16]xi8>` from tile at the given index.
Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index.
```mlir
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
```

Example 2: Extract `vector<[2]xf64>` from tile at the given index.
Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
```mlir
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
```
}];

let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
let arguments = (ins
SMETile:$tile, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SVEVector:$result);

let extraClassDeclaration = [{
VectorType getSliceType() { return getResult().getType(); }
}];

let assemblyFormat = [{
$tile `[` $tile_slice_index `]` attr-dict
$tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
`:` type($result) `from` type($tile)
}];
}
Expand Down
43 changes: 30 additions & 13 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,7 @@ struct StoreTileSliceToArmSMELowering
}
};

/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
/// tile slices are currently supported.
/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
struct MoveVectorToTileSliceToArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
using ConvertOpToLLVMPattern<
Expand Down Expand Up @@ -388,10 +387,19 @@ struct MoveVectorToTileSliceToArmSMELowering

auto tileI32 = castTileIDToI32(tile, loc, rewriter);

// Create 'arm_sme.intr.write.horiz' to write vector to tile slice.
rewriter.create<arm_sme::aarch64_sme_write_horiz>(
loc, tileI32, tileSliceI32, allActiveMask,
moveVectorToTileSliceOp.getVector());
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (moveVectorToTileSliceOp.getLayout()) {
case arm_sme::TileSliceLayout::Horizontal:
rewriter.create<arm_sme::aarch64_sme_write_horiz>(
loc, tileI32, tileSliceI32, allActiveMask,
moveVectorToTileSliceOp.getVector());
break;
case arm_sme::TileSliceLayout::Vertical:
rewriter.create<arm_sme::aarch64_sme_write_vert>(
loc, tileI32, tileSliceI32, allActiveMask,
moveVectorToTileSliceOp.getVector());
break;
}

// Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
// 'arm_sme.cast_tile_to_vector' to preserve dataflow.
Expand All @@ -402,8 +410,7 @@ struct MoveVectorToTileSliceToArmSMELowering
}
};

/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
/// tile slices are currently supported.
/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
struct MoveTileSliceToVectorArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
using ConvertOpToLLVMPattern<
Expand Down Expand Up @@ -435,10 +442,19 @@ struct MoveTileSliceToVectorArmSMELowering
auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI32Type(), sliceIndex);

// Create 'arm_sme.intr.read.horiz' to extract the tile slice.
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
tileIdI32, sliceIndexI32);
// Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
switch (moveTileSliceToVector.getLayout()) {
case arm_sme::TileSliceLayout::Horizontal:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
tileIdI32, sliceIndexI32);
break;
case arm_sme::TileSliceLayout::Vertical:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
tileIdI32, sliceIndexI32);
break;
}

return success();
}
Expand Down Expand Up @@ -680,7 +696,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
target.addIllegalOp<vector::OuterProductOp>();
Expand Down
32 changes: 32 additions & 0 deletions mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
return
}

//===----------------------------------------------------------------------===//
// arm_sme.move_vector_to_tile_slice
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
%c0 = arith.constant 0 : index
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
return
}

// -----

// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
%c0 = arith.constant 0 : index
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
return
}

//===----------------------------------------------------------------------===//
// arm_sme.move_tile_slice_to_vector
Expand Down Expand Up @@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}

// -----

// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
return %slice : vector<[1]xi128>
}
16 changes: 16 additions & 0 deletions mlir/test/Dialect/ArmSME/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
return
}

// -----

func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
// CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
return
}

//===----------------------------------------------------------------------===//
// arm_sme.move_tile_slice_to_vector
Expand Down Expand Up @@ -1135,3 +1143,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}

// -----

func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
// CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}

0 comments on commit 2f055dd

Please sign in to comment.