Skip to content

Commit

Permalink
[mlir][ArmSME] Add arm_sme.streaming_vl operation (#77321)
Browse files Browse the repository at this point in the history
This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This is most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.

Example:
```mlir
%svl_w = arm_sme.streaming_vl <word>
```

Created based on discussion here:
#76086 (comment)
  • Loading branch information
MacDue committed Jan 10, 2024
1 parent 38394a3 commit 53d4890
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 3 deletions.
44 changes: 44 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Expand Up @@ -223,6 +223,21 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
let defaultValue = "CombiningKind::Add";
}

// Enumerates the vector element sizes accepted by `arm_sme.streaming_vl`.
// Going by that op's documented examples, the cases correspond to 8-bit
// (byte), 16-bit (half), 32-bit (word) and 64-bit (double) elements.
def TypeSize : I32EnumAttr<"TypeSize", "Size of a vector element type", [
I32EnumAttrCase<"Byte" , 0, "byte">,
I32EnumAttrCase<"Half" , 1, "half">,
I32EnumAttrCase<"Word" , 2, "word">,
I32EnumAttrCase<"Double", 3, "double">,
]> {
// C++ namespace for the generated `TypeSize` enum.
let cppNamespace = "::mlir::arm_sme";
// No specialized attribute class is requested for this enum; the dialect
// attribute is declared separately as `ArmSME_TypeSizeAttr`.
let genSpecializedAttr = 0;
}

// ArmSME dialect attribute wrapping the `TypeSize` enum, registered under the
// `type_size` mnemonic.
def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
"type_size"> {
// Written as the enum keyword in angle brackets, e.g. `<word>`.
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -768,4 +783,33 @@ let arguments = (ins
}];
}

// `arm_sme.streaming_vl`: takes a `type_size` attribute and produces a single
// `index` result holding the streaming vector length for that element size.
// The op is declared with the `Pure` trait.
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
{
let summary = "Query the streaming vector length";

let description = [{
This operation returns the streaming vector length (SVL) for a given type
size. Unlike `vector.vscale` the value returned is invariant to the
streaming mode.

Example:
```mlir
// Streaming vector length in:
// - bytes (8-bit, SVL.B)
%svl_b = arm_sme.streaming_vl <byte>
// - half words (16-bit, SVL.H)
%svl_h = arm_sme.streaming_vl <half>
// - words (32-bit, SVL.W)
%svl_w = arm_sme.streaming_vl <word>
// - double words (64-bit, SVL.D)
%svl_d = arm_sme.streaming_vl <double>
```
}];

// The only input is the element-size attribute; there are no SSA operands.
let arguments = (ins ArmSME_TypeSizeAttr: $type_size);
// A single result of `index` type.
let results = (outs Index);

// Textual form: the `type_size` attribute followed by the attribute dictionary.
let assemblyFormat = "$type_size attr-dict";
}

#endif // ARMSME_OPS
47 changes: 44 additions & 3 deletions mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
Expand Up @@ -518,6 +518,45 @@ struct OuterProductOpConversion
}
};

/// Lowers `arm_sme.streaming_vl` to the SME CNTS intrinsic that matches the
/// op's type size (cntsb / cntsh / cntsw / cntsd) and casts the resulting i64
/// count to `index`.
///
/// Example:
///
///    %0 = arm_sme.streaming_vl <half>
///
/// is converted to:
///
///    %cnt = "arm_sme.intr.cntsh"() : () -> i64
///    %0 = arith.index_cast %cnt : i64 to index
///
struct StreamingVLOpConversion
    : public ConvertOpToLLVMPattern<arm_sme::StreamingVLOp> {
  using ConvertOpToLLVMPattern<arm_sme::StreamingVLOp>::ConvertOpToLLVMPattern;

  /// Replaces `streamingVlOp` with an `arith.index_cast` of the matching CNTS
  /// intrinsic. Returns success() once the op has been replaced, or a match
  /// failure (leaving the IR untouched) if the type size is not one of the
  /// known enumerators.
  LogicalResult
  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
                  arm_sme::StreamingVLOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = streamingVlOp.getLoc();
    auto i64Type = rewriter.getI64Type();

    // Start from null so that every path out of the switch is well defined,
    // rather than flowing off the end of a value-returning lambda. There is
    // deliberately no `default:` so the compiler can still warn when a new
    // enumerator is added without a case here.
    Operation *intrOp = nullptr;
    switch (streamingVlOp.getTypeSize()) {
    case arm_sme::TypeSize::Byte:
      intrOp = rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
      break;
    case arm_sme::TypeSize::Half:
      intrOp = rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
      break;
    case arm_sme::TypeSize::Word:
      intrOp = rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
      break;
    case arm_sme::TypeSize::Double:
      intrOp = rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
      break;
    }
    if (!intrOp)
      return rewriter.notifyMatchFailure(streamingVlOp,
                                         "unknown streaming_vl type size");

    // The intrinsics produce an i64, while the op's result is `index`.
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
    return success();
  }
};

} // namespace

namespace {
Expand Down Expand Up @@ -555,7 +594,9 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
}
Expand All @@ -572,8 +613,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,

patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
converter);
OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
StreamingVLOpConversion>(converter);
}

std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
Expand Up @@ -559,3 +559,45 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index)
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
return %slice : vector<[1]xi128>
}

//===----------------------------------------------------------------------===//
// arm_sme.streaming_vl
//===----------------------------------------------------------------------===//

// -----

// Verifies the byte-sized query: the cntsb intrinsic is emitted, its i64
// result is cast to index, and that index value is what gets returned.
// CHECK-LABEL: @arm_sme_streaming_vl_bytes
// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
// CHECK: return %[[INDEX_COUNT]] : index
func.func @arm_sme_streaming_vl_bytes() -> index {
%svl_b = arm_sme.streaming_vl <byte>
return %svl_b : index
}

// -----

// Verifies the half-word query: the cntsh intrinsic is emitted, its i64
// result is cast to index, and that index value is what gets returned.
// CHECK-LABEL: @arm_sme_streaming_vl_half_words
// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsh"() : () -> i64
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
// CHECK: return %[[INDEX_COUNT]] : index
func.func @arm_sme_streaming_vl_half_words() -> index {
  %svl_h = arm_sme.streaming_vl <half>
  return %svl_h : index
}

// -----

// Verifies the word query: the cntsw intrinsic is emitted, its i64 result is
// cast to index, and that index value is what gets returned.
// CHECK-LABEL: @arm_sme_streaming_vl_words
// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsw"() : () -> i64
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
// CHECK: return %[[INDEX_COUNT]] : index
func.func @arm_sme_streaming_vl_words() -> index {
  %svl_w = arm_sme.streaming_vl <word>
  return %svl_w : index
}

// -----

// Verifies the double-word query: the cntsd intrinsic is emitted, its i64
// result is cast to index, and that index value is what gets returned.
// CHECK-LABEL: @arm_sme_streaming_vl_double_words
// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsd"() : () -> i64
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
// CHECK: return %[[INDEX_COUNT]] : index
func.func @arm_sme_streaming_vl_double_words() -> index {
  %svl_d = arm_sme.streaming_vl <double>
  return %svl_d : index
}
36 changes: 36 additions & 0 deletions mlir/test/Dialect/ArmSME/roundtrip.mlir
Expand Up @@ -1095,3 +1095,39 @@ func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: v
%result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
return %result : vector<[16]x[16]xi8>
}

//===----------------------------------------------------------------------===//
// arm_sme.streaming_vl
//===----------------------------------------------------------------------===//

// -----

// Round-trip of the `<byte>` form: the op is expected to appear in the output
// with the same textual spelling.
func.func @arm_sme_streaming_vl_bytes() -> index {
// CHECK: arm_sme.streaming_vl <byte>
%svl_b = arm_sme.streaming_vl <byte>
return %svl_b : index
}

// -----

// Round-trip of the `<half>` form: the op is expected to appear in the output
// with the same textual spelling.
func.func @arm_sme_streaming_vl_half_words() -> index {
// CHECK: arm_sme.streaming_vl <half>
%svl_h = arm_sme.streaming_vl <half>
return %svl_h : index
}

// -----

// Round-trip of the `<word>` form: the op is expected to appear in the output
// with the same textual spelling.
func.func @arm_sme_streaming_vl_words() -> index {
// CHECK: arm_sme.streaming_vl <word>
%svl_w = arm_sme.streaming_vl <word>
return %svl_w : index
}

// -----

// Round-trip of the `<double>` form: the op is expected to appear in the
// output with the same textual spelling.
func.func @arm_sme_streaming_vl_double_words() -> index {
// CHECK: arm_sme.streaming_vl <double>
%svl_d = arm_sme.streaming_vl <double>
return %svl_d : index
}

0 comments on commit 53d4890

Please sign in to comment.