Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Add arm_sme.streaming_vl operation #77321

Merged
merged 3 commits into from Jan 10, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jan 8, 2024

This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.

Example:

%svl_w = arm_sme.streaming_vl <word>

Created based on discussion here: #76086 (comment)

This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.

Example:
```mlir
%svl_w = arm_sme.streaming_vl <words>
```
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 9, 2024

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.

Example:

%svl_w = arm_sme.streaming_vl &lt;words&gt;

Created based on discussion here: #76086 (comment)


Full diff: https://github.com/llvm/llvm-project/pull/77321.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+43)
  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+44-3)
  • (modified) mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir (+42)
  • (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+36)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index f7cc1d3fe7517f..4060407d81c0fa 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -223,6 +223,21 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
   let defaultValue = "CombiningKind::Add";
 }
 
+def TypeSize : I32EnumAttr<"TypeSize", "Size of vector type", [
+  I32EnumAttrCase<"Bytes"      , 0, "bytes">,
+  I32EnumAttrCase<"HalfWords"  , 1, "half_words">,
+  I32EnumAttrCase<"Words"      , 2, "words">,
+  I32EnumAttrCase<"DoubleWords", 3, "double_words">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
+                                          "type_size"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME op definitions
 //===----------------------------------------------------------------------===//
@@ -768,4 +783,32 @@ let arguments = (ins
   }];
 }
 
+def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
+{
+  let summary = "Query the streaming vector length";
+
+  let description = [{
+    This operation returns the streaming vector length for a type size. Unlike
+    `vector.vscale` the value returned is invariant to the streaming mode.
+
+    Example:
+    ```mlir
+    // Streaming vector length in:
+    // - bytes (8-bit)
+    %svl_b = arm_sme.streaming_vl <bytes>
+    // - half words (16-bit)
+    %svl_h = arm_sme.streaming_vl <half_words>
+    // - words (32-bit)
+    %svl_w = arm_sme.streaming_vl <words>
+    // - double words (64-bit)
+    %svl_d = arm_sme.streaming_vl <double_words>
+    ```
+  }];
+
+  let arguments = (ins ArmSME_TypeSizeAttr: $type_size);
+  let results = (outs Index);
+
+  let assemblyFormat = "$type_size attr-dict";
+}
+
 #endif // ARMSME_OPS
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 0c6e2e80b88a3b..c5fdba6d00cc0f 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -518,6 +518,45 @@ struct OuterProductOpConversion
   }
 };
 
+/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
+///
+/// Example:
+///
+///   %0 = arm_sme.streaming_vl <half_words>
+///
+/// is converted to:
+///
+///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
+///   %0 = arith.index_cast %cnt : i64 to index
+///
+struct StreamingVLOpConversion
+    : public ConvertOpToLLVMPattern<arm_sme::StreamingVLOp> {
+  using ConvertOpToLLVMPattern<arm_sme::StreamingVLOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
+                  arm_sme::StreamingVLOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = streamingVlOp.getLoc();
+    auto i64Type = rewriter.getI64Type();
+    auto *intrOp = [&]() -> Operation * {
+      switch (streamingVlOp.getTypeSize()) {
+      case arm_sme::TypeSize::Bytes:
+        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
+      case arm_sme::TypeSize::HalfWords:
+        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
+      case arm_sme::TypeSize::Words:
+        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
+      case arm_sme::TypeSize::DoubleWords:
+        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
+      }
+    }();
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
+        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -555,7 +594,9 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
       arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
       arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
+      arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
+      arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
 }
@@ -572,8 +613,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
 
   patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
                MoveVectorToTileSliceConversion, StoreTileSliceConversion,
-               OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
-      converter);
+               OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
+               StreamingVLOpConversion>(converter);
 }
 
 std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index bd88da37bdf966..bab4fd60518c54 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -559,3 +559,45 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index)
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.streaming_vl
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_bytes
+// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
+// CHECK: return %[[INDEX_COUNT]] : index
+func.func @arm_sme_streaming_vl_bytes() -> index {
+  %svl_b = arm_sme.streaming_vl <bytes>
+  return %svl_b : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_half_words
+// CHECK: "arm_sme.intr.cntsh"() : () -> i64
+func.func @arm_sme_streaming_vl_half_words() -> index {
+  %svl_h = arm_sme.streaming_vl <half_words>
+  return %svl_h : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_words
+// CHECK: "arm_sme.intr.cntsw"() : () -> i64
+func.func @arm_sme_streaming_vl_words() -> index {
+  %svl_w = arm_sme.streaming_vl <words>
+  return %svl_w : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_double_words
+// CHECK: "arm_sme.intr.cntsd"() : () -> i64
+func.func @arm_sme_streaming_vl_double_words() -> index {
+  %svl_d = arm_sme.streaming_vl <double_words>
+  return %svl_d : index
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 58ff7ef4d8340e..eb4c7149b61f1d 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1095,3 +1095,39 @@ func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: v
   %result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
   return %result : vector<[16]x[16]xi8>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.streaming_vl
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_streaming_vl_bytes() -> index {
+  // CHECK: arm_sme.streaming_vl <bytes>
+  %svl_b = arm_sme.streaming_vl <bytes>
+  return %svl_b : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_half_words() -> index {
+  // CHECK: arm_sme.streaming_vl <half_words>
+  %svl_h = arm_sme.streaming_vl <half_words>
+  return %svl_h : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_words() -> index {
+  // CHECK: arm_sme.streaming_vl <words>
+  %svl_w = arm_sme.streaming_vl <words>
+  return %svl_w : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_double_words() -> index {
+  // CHECK: arm_sme.streaming_vl <double_words>
+  %svl_d = arm_sme.streaming_vl <double_words>
+  return %svl_d : index
+}

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks Ben, couple of nits but LGTM regardless, cheers

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td Outdated Show resolved Hide resolved
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a small suggestion.

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td Outdated Show resolved Hide resolved
@MacDue MacDue merged commit 53d4890 into llvm:main Jan 10, 2024
3 of 4 checks passed
@MacDue MacDue deleted the streaming_vl_op branch January 10, 2024 10:11
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.

Example:
```mlir
%svl_w = arm_sme.streaming_vl <word>
```

Created based on discussion here:
llvm#76086 (comment)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants