Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Add arm_sme.streaming_vl operation #77321

Merged
merged 3 commits into from Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 44 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Expand Up @@ -223,6 +223,21 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
let defaultValue = "CombiningKind::Add";
}

def TypeSize : I32EnumAttr<"TypeSize", "Size of a vector element type", [
I32EnumAttrCase<"Byte" , 0, "byte">,
I32EnumAttrCase<"Half" , 1, "half">,
I32EnumAttrCase<"Word" , 2, "word">,
I32EnumAttrCase<"Double", 3, "double">,
]> {
let cppNamespace = "::mlir::arm_sme";
let genSpecializedAttr = 0;
}

def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
"type_size"> {
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -768,4 +783,33 @@ let arguments = (ins
}];
}

def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
{
let summary = "Query the streaming vector length";

let description = [{
This operation returns the streaming vector length (SVL) for a given type
size. Unlike `vector.vscale` the value returned is invariant to the
streaming mode.

Example:
```mlir
// Streaming vector length in:
// - bytes (8-bit, SVL.B)
%svl_b = arm_sme.streaming_vl <byte>
// - half words (16-bit, SVL.H)
%svl_h = arm_sme.streaming_vl <half>
// - words (32-bit, SVL.W)
%svl_w = arm_sme.streaming_vl <word>
// - double words (64-bit, SVL.D)
%svl_d = arm_sme.streaming_vl <double>
```
}];

let arguments = (ins ArmSME_TypeSizeAttr: $type_size);
let results = (outs Index);

let assemblyFormat = "$type_size attr-dict";
}

#endif // ARMSME_OPS
47 changes: 44 additions & 3 deletions mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
Expand Up @@ -518,6 +518,45 @@ struct OuterProductOpConversion
}
};

/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
///
/// Example:
///
/// %0 = arm_sme.streaming_vl <half>
///
/// is converted to:
///
/// %cnt = "arm_sme.intr.cntsh"() : () -> i64
/// %0 = arith.index_cast %cnt : i64 to index
///
struct StreamingVLOpConversion
: public ConvertOpToLLVMPattern<arm_sme::StreamingVLOp> {
using ConvertOpToLLVMPattern<arm_sme::StreamingVLOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
arm_sme::StreamingVLOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = streamingVlOp.getLoc();
auto i64Type = rewriter.getI64Type();
auto *intrOp = [&]() -> Operation * {
switch (streamingVlOp.getTypeSize()) {
case arm_sme::TypeSize::Byte:
return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
case arm_sme::TypeSize::Half:
return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
case arm_sme::TypeSize::Word:
return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
case arm_sme::TypeSize::Double:
return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
}
}();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
return success();
}
};

} // namespace

namespace {
Expand Down Expand Up @@ -555,7 +594,9 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
}
Expand All @@ -572,8 +613,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,

patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
converter);
OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
StreamingVLOpConversion>(converter);
}

std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
Expand Up @@ -559,3 +559,45 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index)
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
return %slice : vector<[1]xi128>
}

//===----------------------------------------------------------------------===//
// arm_sme.streaming_vl
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @arm_sme_streaming_vl_bytes
// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
// CHECK: return %[[INDEX_COUNT]] : index
func.func @arm_sme_streaming_vl_bytes() -> index {
%svl_b = arm_sme.streaming_vl <byte>
return %svl_b : index
}

// -----

// CHECK-LABEL: @arm_sme_streaming_vl_half_words
// CHECK: "arm_sme.intr.cntsh"() : () -> i64
func.func @arm_sme_streaming_vl_half_words() -> index {
%svl_h = arm_sme.streaming_vl <half>
return %svl_h : index
}

// -----

// CHECK-LABEL: @arm_sme_streaming_vl_words
// CHECK: "arm_sme.intr.cntsw"() : () -> i64
func.func @arm_sme_streaming_vl_words() -> index {
%svl_w = arm_sme.streaming_vl <word>
return %svl_w : index
}

// -----

// CHECK-LABEL: @arm_sme_streaming_vl_double_words
// CHECK: "arm_sme.intr.cntsd"() : () -> i64
func.func @arm_sme_streaming_vl_double_words() -> index {
%svl_d = arm_sme.streaming_vl <double>
return %svl_d : index
}
36 changes: 36 additions & 0 deletions mlir/test/Dialect/ArmSME/roundtrip.mlir
Expand Up @@ -1095,3 +1095,39 @@ func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: v
%result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
return %result : vector<[16]x[16]xi8>
}

//===----------------------------------------------------------------------===//
// arm_sme.streaming_vl
//===----------------------------------------------------------------------===//

// -----

func.func @arm_sme_streaming_vl_bytes() -> index {
// CHECK: arm_sme.streaming_vl <byte>
%svl_b = arm_sme.streaming_vl <byte>
return %svl_b : index
}

// -----

func.func @arm_sme_streaming_vl_half_words() -> index {
// CHECK: arm_sme.streaming_vl <half>
%svl_h = arm_sme.streaming_vl <half>
return %svl_h : index
}

// -----

func.func @arm_sme_streaming_vl_words() -> index {
// CHECK: arm_sme.streaming_vl <word>
%svl_w = arm_sme.streaming_vl <word>
return %svl_w : index
}

// -----

func.func @arm_sme_streaming_vl_double_words() -> index {
// CHECK: arm_sme.streaming_vl <double>
%svl_d = arm_sme.streaming_vl <double>
return %svl_d : index
}