Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Add tile slice layout attr to vector <-> tile ops #69186

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 22 additions & 14 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -441,21 +441,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
of a 2-D scalable vector tile at the given index. The type of the 1-D
scalable vector to be moved must match the type of the tile slice. A tile
slice is a 1-D vector of horizontally or vertically contiguous elements
within a ZA tile. Horizontal tile slices are currently assumed when
lowering to intrinsics. The updated tile is returned as the result.
within a ZA tile. The updated tile is returned as the result.

Example 1: Move a vector<[16]xi8> into tile at given index.
An optional tile slice layout attribute specifies whether the tile slice is
horizontal (default) or vertical.

Example 1: Move a vector<[16]xi8> into the tile horizontally (default) at the given index.
```mlir
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
```

Example 2: Move a vector<[2]xf64> into tile at given index.
Example 2: Move a vector<[2]xf64> into the tile vertically at the given index.
```mlir
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
```
}];
let arguments = (ins
SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout);
let results = (outs SMETile:$result);

let extraClassDeclaration = [{
Expand All @@ -465,7 +468,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}];

let assemblyFormat = [{
$vector `,` $tile `,` $tile_slice_index
$vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($vector) `into` type($result)
}];
}
Expand All @@ -480,29 +483,34 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
let description = [{
The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
scalable tile at the given index. A tile slice is a 1-D vector of
horizontally or vertically contiguous elements within a ZA tile. Horizontal
tile slices are currently assumed when lowering to intrinsics.
horizontally or vertically contiguous elements within a ZA tile.

An optional tile slice layout attribute specifies whether the tile slice is
horizontal (default) or vertical.

Example 1: Extract `vector<[16]xi8>` from tile at the given index.
Example 1: Extract `vector<[16]xi8>` from the tile horizontally (default) at the given index.
```mlir
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
```

Example 2: Extract `vector<[2]xf64>` from tile at the given index.
Example 2: Extract `vector<[2]xf64>` from the tile vertically at the given index.
```mlir
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
```
}];

let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
let arguments = (ins
SMETile:$tile, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SVEVector:$result);

let extraClassDeclaration = [{
VectorType getSliceType() { return getResult().getType(); }
}];

let assemblyFormat = [{
$tile `[` $tile_slice_index `]` attr-dict
$tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
`:` type($result) `from` type($tile)
}];
}
Expand Down
43 changes: 30 additions & 13 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,7 @@ struct StoreTileSliceToArmSMELowering
}
};

/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
/// tile slices are currently supported.
/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
struct MoveVectorToTileSliceToArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
using ConvertOpToLLVMPattern<
Expand Down Expand Up @@ -388,10 +387,19 @@ struct MoveVectorToTileSliceToArmSMELowering

auto tileI32 = castTileIDToI32(tile, loc, rewriter);

// Create 'arm_sme.intr.write.horiz' to write vector to tile slice.
rewriter.create<arm_sme::aarch64_sme_write_horiz>(
loc, tileI32, tileSliceI32, allActiveMask,
moveVectorToTileSliceOp.getVector());
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (moveVectorToTileSliceOp.getLayout()) {
case arm_sme::TileSliceLayout::Horizontal:
rewriter.create<arm_sme::aarch64_sme_write_horiz>(
loc, tileI32, tileSliceI32, allActiveMask,
moveVectorToTileSliceOp.getVector());
break;
case arm_sme::TileSliceLayout::Vertical:
rewriter.create<arm_sme::aarch64_sme_write_vert>(
loc, tileI32, tileSliceI32, allActiveMask,
moveVectorToTileSliceOp.getVector());
break;
}

// Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
// 'arm_sme.cast_tile_to_vector' to preserve dataflow.
Expand All @@ -402,8 +410,7 @@ struct MoveVectorToTileSliceToArmSMELowering
}
};

/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
/// tile slices are currently supported.
/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
struct MoveTileSliceToVectorArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
using ConvertOpToLLVMPattern<
Expand Down Expand Up @@ -435,10 +442,19 @@ struct MoveTileSliceToVectorArmSMELowering
auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI32Type(), sliceIndex);

// Create 'arm_sme.intr.read.horiz' to extract the tile slice.
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
tileIdI32, sliceIndexI32);
// Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
switch (moveTileSliceToVector.getLayout()) {
case arm_sme::TileSliceLayout::Horizontal:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
tileIdI32, sliceIndexI32);
break;
case arm_sme::TileSliceLayout::Vertical:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
tileIdI32, sliceIndexI32);
break;
}

return success();
}
Expand Down Expand Up @@ -680,7 +696,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
target.addIllegalOp<vector::OuterProductOp>();
Expand Down
32 changes: 32 additions & 0 deletions mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
return
}

//===----------------------------------------------------------------------===//
// arm_sme.move_vector_to_tile_slice
//===----------------------------------------------------------------------===//

// -----

// No `layout` attribute is given, so the default (horizontal) layout applies
// and the op is expected to lower to the horizontal write intrinsic.
// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
  return
}

// -----

// An explicit `layout<vertical>` attribute is expected to select the vertical
// write intrinsic instead of the default horizontal one.
// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
  return
}

//===----------------------------------------------------------------------===//
// arm_sme.move_tile_slice_to_vector
Expand Down Expand Up @@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}

// -----

// Extracting a slice with `layout<vertical>` from a 128-bit element tile is
// expected to lower to the vertical read intrinsic.
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
func.func @arm_sme_move_tile_slice_to_vector_ver_i128(
    %za_tile : vector<[1]x[1]xi128>,
    %slice_idx : index) -> vector<[1]xi128> {
  %vertical_slice = arm_sme.move_tile_slice_to_vector %za_tile[%slice_idx] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
  return %vertical_slice : vector<[1]xi128>
}
16 changes: 16 additions & 0 deletions mlir/test/Dialect/ArmSME/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
return
}

// -----

// Round-trip: the optional `layout<vertical>` attribute must be parsed and
// printed back unchanged in the custom assembly format.
func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
  // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
  return
}

//===----------------------------------------------------------------------===//
// arm_sme.move_tile_slice_to_vector
Expand Down Expand Up @@ -1135,3 +1143,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}

// -----

// Round-trip: the optional `layout<vertical>` attribute must be parsed and
// printed back unchanged in the custom assembly format.
func.func @arm_sme_move_tile_slice_to_vector_ver_f64(
    %za_tile : vector<[2]x[2]xf64>,
    %slice_idx : index) -> vector<[2]xf64> {
  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
  %vertical_slice = arm_sme.move_tile_slice_to_vector %za_tile[%slice_idx] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
  return %vertical_slice : vector<[2]xf64>
}