diff --git a/clang/include/clang/Basic/arm_sme.td b/clang/include/clang/Basic/arm_sme.td
index a4eb92e76968c..5f6a6eaab80a3 100644
--- a/clang/include/clang/Basic/arm_sme.td
+++ b/clang/include/clang/Basic/arm_sme.td
@@ -156,16 +156,10 @@ let SMETargetGuard = "sme2p1" in {
 ////////////////////////////////////////////////////////////////////////////////
 // SME - Counting elements in a streaming vector
 
-multiclass ZACount<string n_suffix> {
-  def NAME : SInst<"sv" # n_suffix, "nv", "", MergeNone,
-                   "aarch64_sme_" # n_suffix,
-                   [IsOverloadNone, IsStreamingCompatible]>;
-}
-
-defm SVCNTSB : ZACount<"cntsb">;
-defm SVCNTSH : ZACount<"cntsh">;
-defm SVCNTSW : ZACount<"cntsw">;
-defm SVCNTSD : ZACount<"cntsd">;
+def SVCNTSB : SInst<"svcntsb", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>;
+def SVCNTSH : SInst<"svcntsh", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>;
+def SVCNTSW : SInst<"svcntsw", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>;
+def SVCNTSD : SInst<"svcntsd", "nv", "", MergeNone, "aarch64_sme_cntsd", [IsOverloadNone, IsStreamingCompatible]>;
 
 ////////////////////////////////////////////////////////////////////////////////
 // SME - ADDHA/ADDVA
diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
index 60413e7b18e85..734d925c0bb7c 100644
--- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
@@ -4304,9 +4304,11 @@ Value *CodeGenFunction::EmitSMELd1St1(const SVETypeFlags &TypeFlags,
   // size in bytes.
   if (Ops.size() == 5) {
     Function *StreamingVectorLength =
-        CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsb);
+        CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd);
     llvm::Value *StreamingVectorLengthCall =
-        Builder.CreateCall(StreamingVectorLength);
+        Builder.CreateMul(Builder.CreateCall(StreamingVectorLength),
+                          llvm::ConstantInt::get(Int64Ty, 8), "svl",
+                          /* HasNUW */ true, /* HasNSW */ true);
     llvm::Value *Mulvl = Builder.CreateMul(StreamingVectorLengthCall, Ops[4],
                                            "mulvl");
     // The type of the ptr parameter is void *, so use Int8Ty here.
@@ -4918,6 +4920,26 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
   // Handle builtins which require their multi-vector operands to be swapped
   swapCommutativeSMEOperands(BuiltinID, Ops);
 
+  auto isCntsBuiltin = [&]() {
+    switch (BuiltinID) {
+    default:
+      return 0;
+    case SME::BI__builtin_sme_svcntsb:
+      return 8;
+    case SME::BI__builtin_sme_svcntsh:
+      return 4;
+    case SME::BI__builtin_sme_svcntsw:
+      return 2;
+    }
+  };
+
+  if (auto Mul = isCntsBuiltin()) {
+    llvm::Value *Cntd =
+        Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd));
+    return Builder.CreateMul(Cntd, llvm::ConstantInt::get(Int64Ty, Mul),
+                             "mulsvl", /* HasNUW */ true, /* HasNSW */ true);
+  }
+
   // Should not happen!
   if (Builtin->LLVMIntrinsic == 0)
     return nullptr;
diff --git a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c
index c0b3e1a06b0ff..049c1742e5a9d 100644
--- a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c
+++ b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c
@@ -6,49 +6,55 @@
 
 #include <arm_sme.h>
 
-// CHECK-C-LABEL: define dso_local i64 @test_svcntsb(
+// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsb(
 // CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT:    ret i64 [[TMP0]]
+// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-C-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+// CHECK-C-NEXT:    ret i64 [[MULSVL]]
 //
-// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntsbv(
+// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntsbv(
 // CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT:    ret i64 [[TMP0]]
+// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-CXX-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+// CHECK-CXX-NEXT:    ret i64 [[MULSVL]]
 //
 uint64_t test_svcntsb() {
   return svcntsb();
 }
 
-// CHECK-C-LABEL: define dso_local i64 @test_svcntsh(
+// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsh(
 // CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsh()
-// CHECK-C-NEXT:    ret i64 [[TMP0]]
+// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-C-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 2
+// CHECK-C-NEXT:    ret i64 [[MULSVL]]
 //
-// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntshv(
+// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntshv(
 // CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsh()
-// CHECK-CXX-NEXT:    ret i64 [[TMP0]]
+// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-CXX-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 2
+// CHECK-CXX-NEXT:    ret i64 [[MULSVL]]
 //
 uint64_t test_svcntsh() {
   return svcntsh();
 }
 
-// CHECK-C-LABEL: define dso_local i64 @test_svcntsw(
+// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsw(
 // CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsw()
-// CHECK-C-NEXT:    ret i64 [[TMP0]]
+// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-C-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 1
+// CHECK-C-NEXT:    ret i64 [[MULSVL]]
 //
-// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntswv(
+// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntswv(
 // CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsw()
-// CHECK-CXX-NEXT:    ret i64 [[TMP0]]
+// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-CXX-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 1
+// CHECK-CXX-NEXT:    ret i64 [[MULSVL]]
 //
 uint64_t test_svcntsw() {
   return svcntsw();
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 6d53bf8b172d8..7c9aef52b3acf 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -3147,13 +3147,8 @@ let TargetPrefix = "aarch64" in {
   // Counting elements
   //
 
-  class AdvSIMD_SME_CNTSB_Intrinsic
-    : DefaultAttrsIntrinsic<[llvm_i64_ty], [], [IntrNoMem]>;
-
-  def int_aarch64_sme_cntsb : AdvSIMD_SME_CNTSB_Intrinsic;
-  def int_aarch64_sme_cntsh : AdvSIMD_SME_CNTSB_Intrinsic;
-  def int_aarch64_sme_cntsw : AdvSIMD_SME_CNTSB_Intrinsic;
-  def int_aarch64_sme_cntsd : AdvSIMD_SME_CNTSB_Intrinsic;
+  def int_aarch64_sme_cntsd
+    : DefaultAttrsIntrinsic<[llvm_i64_ty], [], [IntrNoMem]>;
 
   //
   // PSTATE Functions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index bc786f415b554..1c20a8240d688 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -71,6 +71,9 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
   template <signed Low, signed High, signed Scale>
   bool SelectRDVLImm(SDValue N, SDValue &Imm);
 
+  template <signed Low, signed High>
+  bool SelectRDSVLShiftImm(SDValue N, SDValue &Imm);
+
   bool SelectArithExtendedRegister(SDValue N, SDValue &Reg, SDValue &Shift);
   bool SelectArithUXTXRegister(SDValue N, SDValue &Reg, SDValue &Shift);
   bool SelectArithImmed(SDValue N, SDValue &Val, SDValue &Shift);
@@ -937,6 +940,21 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) {
   return false;
 }
 
+// Returns a suitable RDSVL multiplier from a left shift.
+template <signed Low, signed High>
+bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
+  if (!isa<ConstantSDNode>(N))
+    return false;
+
+  int64_t MulImm = 1 << cast<ConstantSDNode>(N)->getSExtValue();
+  if (MulImm >= Low && MulImm <= High) {
+    Imm = CurDAG->getSignedTargetConstant(MulImm, SDLoc(N), MVT::i32);
+    return true;
+  }
+
+  return false;
+}
+
 /// SelectArithExtendedRegister - Select a "extended register" operand.  This
 /// operand folds in an extend followed by an optional left shift.
 bool AArch64DAGToDAGISel::SelectArithExtendedRegister(SDValue N, SDValue &Reg,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 23328ed57fb36..b1b40c6b570d6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6266,25 +6266,11 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   case Intrinsic::aarch64_sve_clz:
     return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, DL, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
-  case Intrinsic::aarch64_sme_cntsb:
-    return DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
-                       DAG.getConstant(1, DL, MVT::i32));
-  case Intrinsic::aarch64_sme_cntsh: {
-    SDValue One = DAG.getConstant(1, DL, MVT::i32);
-    SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(), One);
-    return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes, One);
-  }
-  case Intrinsic::aarch64_sme_cntsw: {
-    SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
-                                DAG.getConstant(1, DL, MVT::i32));
-    return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes,
-                       DAG.getConstant(2, DL, MVT::i32));
-  }
   case Intrinsic::aarch64_sme_cntsd: {
     SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
                                 DAG.getConstant(1, DL, MVT::i32));
     return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes,
-                       DAG.getConstant(3, DL, MVT::i32));
+                       DAG.getConstant(3, DL, MVT::i32), SDNodeFlags::Exact);
   }
   case Intrinsic::aarch64_sve_cnt: {
     SDValue Data = Op.getOperand(3);
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 0d8cb3a76d0be..6313aba9a435e 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -127,11 +127,16 @@ def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
 def SDT_AArch64RDSVL  : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
 def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>;
 
+def sme_rdsvl_shl_imm : ComplexPattern">;
+
 let Predicates = [HasSMEandIsNonStreamingSafe] in {
 def RDSVLI_XI  : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>;
 def ADDSPL_XXI : sve_int_arith_vl<0b1, "addspl", /*streaming_sve=*/0b1>;
 def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;
 
+def : Pat<(i64 (shl (AArch64rdsvl (i32 1)), (sme_rdsvl_shl_imm i64:$imm))),
+          (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
+
 def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 490f6391c15a0..d4c7cb11a70a3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2102,15 +2102,15 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
 }
 
 static std::optional<Instruction *>
-instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts,
-                       const AArch64Subtarget *ST) {
+instCombineSMECntsd(InstCombiner &IC, IntrinsicInst &II,
+                    const AArch64Subtarget *ST) {
   if (!ST->isStreaming())
     return std::nullopt;
 
-  // In streaming-mode, aarch64_sme_cnts is equivalent to aarch64_sve_cnt
+  // In streaming-mode, aarch64_sme_cntsd is equivalent to aarch64_sve_cntd
   // with SVEPredPattern::all
-  Value *Cnt = IC.Builder.CreateElementCount(
-      II.getType(), ElementCount::getScalable(NumElts));
+  Value *Cnt =
+      IC.Builder.CreateElementCount(II.getType(), ElementCount::getScalable(2));
   Cnt->takeName(&II);
   return IC.replaceInstUsesWith(II, Cnt);
 }
@@ -2825,13 +2825,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
   case Intrinsic::aarch64_sve_cntb:
     return instCombineSVECntElts(IC, II, 16);
   case Intrinsic::aarch64_sme_cntsd:
-    return instCombineSMECntsElts(IC, II, 2, ST);
-  case Intrinsic::aarch64_sme_cntsw:
-    return instCombineSMECntsElts(IC, II, 4, ST);
-  case Intrinsic::aarch64_sme_cntsh:
-    return instCombineSMECntsElts(IC, II, 8, ST);
-  case Intrinsic::aarch64_sme_cntsb:
-    return instCombineSMECntsElts(IC, II, 16, ST);
+    return instCombineSMECntsd(IC, II, ST);
   case Intrinsic::aarch64_sve_ptest_any:
   case Intrinsic::aarch64_sve_ptest_first:
   case Intrinsic::aarch64_sve_ptest_last:
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
index 5d10d7e13da14..06c53d8070781 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
@@ -1,46 +1,89 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs < %s | FileCheck %s
 
-define i64 @sme_cntsb() {
-; CHECK-LABEL: sme_cntsb:
+define i64 @cntsb() {
+; CHECK-LABEL: cntsb:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x0, #1
 ; CHECK-NEXT:    ret
-  %v = call i64 @llvm.aarch64.sme.cntsb()
-  ret i64 %v
+  %1 = call i64 @llvm.aarch64.sme.cntsd()
+  %res = shl nuw nsw i64 %1, 3
+  ret i64 %res
 }
 
-define i64 @sme_cntsh() {
-; CHECK-LABEL: sme_cntsh:
+define i64 @cntsh() {
+; CHECK-LABEL: cntsh:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    lsr x0, x8, #1
 ; CHECK-NEXT:    ret
-  %v = call i64 @llvm.aarch64.sme.cntsh()
-  ret i64 %v
+  %1 = call i64 @llvm.aarch64.sme.cntsd()
+  %res = shl nuw nsw i64 %1, 2
+  ret i64 %res
 }
 
-define i64 @sme_cntsw() {
-; CHECK-LABEL: sme_cntsw:
+define i64 @cntsw() {
+; CHECK-LABEL: cntsw:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    lsr x0, x8, #2
 ; CHECK-NEXT:    ret
-  %v = call i64 @llvm.aarch64.sme.cntsw()
-  ret i64 %v
+  %1 = call i64 @llvm.aarch64.sme.cntsd()
+  %res = shl nuw nsw i64 %1, 1
+  ret i64 %res
 }
 
-define i64 @sme_cntsd() {
-; CHECK-LABEL: sme_cntsd:
+define i64 @cntsd() {
+; CHECK-LABEL: cntsd:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    lsr x0, x8, #3
+; CHECK-NEXT:    ret
+  %res = call i64 @llvm.aarch64.sme.cntsd()
+  ret i64 %res
+}
+
+define i64 @sme_cntsb_mul() {
+; CHECK-LABEL: sme_cntsb_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x0, #4
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 3
+  %res = mul nuw nsw i64 %shl, 4
+  ret i64 %res
+}
+
+define i64 @sme_cntsh_mul() {
+; CHECK-LABEL: sme_cntsh_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x0, #4
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 2
+  %res = mul nuw nsw i64 %shl, 8
+  ret i64 %res
+}
+
+define i64 @sme_cntsw_mul() {
+; CHECK-LABEL: sme_cntsw_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x0, #4
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 1
+  %res = mul nuw nsw i64 %shl, 16
+  ret i64 %res
+}
+
+define i64 @sme_cntsd_mul() {
+; CHECK-LABEL: sme_cntsd_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x0, #4
 ; CHECK-NEXT:    ret
   %v = call i64 @llvm.aarch64.sme.cntsd()
-  ret i64 %v
+  %res = mul nuw nsw i64 %v, 32
+  ret i64 %res
 }
 
-declare i64 @llvm.aarch64.sme.cntsb()
-declare i64 @llvm.aarch64.sme.cntsh()
-declare i64 @llvm.aarch64.sme.cntsw()
 declare i64 @llvm.aarch64.sme.cntsd()
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll
index e1a474d898233..2806f864c7b25 100644
--- a/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll
@@ -76,14 +76,14 @@ entry:
   %Data1 = alloca <vscale x 16 x i8>, align 16
  %Data2 = alloca <vscale x 16 x i8>, align 16
   %Data3 = alloca <vscale x 16 x i8>, align 16
-  %0 = tail call i64 @llvm.aarch64.sme.cntsb()
+  %0 = tail call i64 @llvm.aarch64.sme.cntsd()
   call void @foo(ptr noundef nonnull %Data1, ptr noundef nonnull %Data2, ptr noundef nonnull %Data3, i64 noundef %0)
   %1 = load <vscale x 16 x i8>, ptr %Data1, align 16
   %vecext = extractelement <vscale x 16 x i8> %1, i64 0
   ret i8 %vecext
 }
 
-declare i64 @llvm.aarch64.sme.cntsb()
+declare i64 @llvm.aarch64.sme.cntsd()
 
 declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef)
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
index 8c4d57e244e03..505a40c16653b 100644
--- a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
@@ -366,9 +366,10 @@ define i8 @call_to_non_streaming_pass_sve_objects(ptr nocapture noundef readnone
 ; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
 ; CHECK-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
 ; CHECK-NEXT:    addvl sp, sp, #-3
-; CHECK-NEXT:    rdsvl x3, #1
+; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    addvl x0, sp, #2
 ; CHECK-NEXT:    addvl x1, sp, #1
+; CHECK-NEXT:    lsr x3, x8, #3
 ; CHECK-NEXT:    mov x2, sp
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    bl foo
@@ -386,7 +387,7 @@ entry:
   %Data1 = alloca <vscale x 16 x i8>, align 16
   %Data2 = alloca <vscale x 16 x i8>, align 16
   %Data3 = alloca <vscale x 16 x i8>, align 16
-  %0 = tail call i64 @llvm.aarch64.sme.cntsb()
+  %0 = tail call i64 @llvm.aarch64.sme.cntsd()
   call void @foo(ptr noundef nonnull %Data1, ptr noundef nonnull %Data2, ptr noundef nonnull %Data3, i64 noundef %0)
   %1 = load <vscale x 16 x i8>, ptr %Data1, align 16
   %vecext = extractelement <vscale x 16 x i8> %1, i64 0
@@ -421,7 +422,7 @@ entry:
   ret void
 }
 
-declare i64 @llvm.aarch64.sme.cntsb()
+declare i64 @llvm.aarch64.sme.cntsd()
 
 declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef)
 declare void @bar(ptr noundef, i64 noundef, i64 noundef, i32 noundef, i32 noundef, float noundef, float noundef, double noundef, double noundef)
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll b/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll
index f213c0b53f6ef..c1d12b825b72c 100644
--- a/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll
@@ -5,48 +5,6 @@
 
 target triple = "aarch64-unknown-linux-gnu"
 
-define i64 @cntsb() {
-; CHECK-LABEL: @cntsb(
-; CHECK-NEXT:    [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsb()
-; CHECK-NEXT:    ret i64 [[OUT]]
-;
-; CHECK-STREAMING-LABEL: @cntsb(
-; CHECK-STREAMING-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-STREAMING-NEXT:    [[OUT:%.*]] = shl nuw i64 [[TMP1]], 4
-; CHECK-STREAMING-NEXT:    ret i64 [[OUT]]
-;
-  %out = call i64 @llvm.aarch64.sme.cntsb()
-  ret i64 %out
-}
-
-define i64 @cntsh() {
-; CHECK-LABEL: @cntsh(
-; CHECK-NEXT:    [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsh()
-; CHECK-NEXT:    ret i64 [[OUT]]
-;
-; CHECK-STREAMING-LABEL: @cntsh(
-; CHECK-STREAMING-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-STREAMING-NEXT:    [[OUT:%.*]] = shl nuw i64 [[TMP1]], 3
-; CHECK-STREAMING-NEXT:    ret i64 [[OUT]]
-;
-  %out = call i64 @llvm.aarch64.sme.cntsh()
-  ret i64 %out
-}
-
-define i64 @cntsw() {
-; CHECK-LABEL: @cntsw(
-; CHECK-NEXT:    [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsw()
-; CHECK-NEXT:    ret i64 [[OUT]]
-;
-; CHECK-STREAMING-LABEL: @cntsw(
-; CHECK-STREAMING-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-STREAMING-NEXT:    [[OUT:%.*]] = shl nuw i64 [[TMP1]], 2
-; CHECK-STREAMING-NEXT:    ret i64 [[OUT]]
-;
-  %out = call i64 @llvm.aarch64.sme.cntsw()
-  ret i64 %out
-}
-
 define i64 @cntsd() {
 ; CHECK-LABEL: @cntsd(
 ; CHECK-NEXT:    [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsd()
@@ -61,8 +19,5 @@ define i64 @cntsd() {
   ret i64 %out
 }
 
-declare i64 @llvm.aarch64.sve.cntsb()
-declare i64 @llvm.aarch64.sve.cntsh()
-declare i64 @llvm.aarch64.sve.cntsw()
 declare i64 @llvm.aarch64.sve.cntsd()
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index 06fb8511774e8..4d19fa5415ef0 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -201,9 +201,6 @@ class ArmSME_IntrCountOp
     /*traits*/[PredOpTrait<"`res` is i64", TypeIsPred<"res", I64>>],
     /*numResults=*/1, /*overloadedResults=*/[]>;
 
-def LLVM_aarch64_sme_cntsb : ArmSME_IntrCountOp<"cntsb">;
-def LLVM_aarch64_sme_cntsh : ArmSME_IntrCountOp<"cntsh">;
-def LLVM_aarch64_sme_cntsw : ArmSME_IntrCountOp<"cntsw">;
 def LLVM_aarch64_sme_cntsd : ArmSME_IntrCountOp<"cntsd">;
 
 #endif // ARMSME_INTRINSIC_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..b57b27de4e1de 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -32,6 +32,9 @@ namespace mlir::arm_sme {
 
 constexpr unsigned MinStreamingVectorLengthInBits = 128;
 
+/// Return the size represented by arm_sme::TypeSize in bytes.
+unsigned getSizeInBytes(TypeSize type);
+
 /// Return minimum number of elements for the given element `type` in
 /// a vector of SVL bits.
 unsigned getSMETileSliceMinNumElts(Type type);
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 8a2e3b639aaa7..033e9ae1f4d4c 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -822,7 +822,7 @@ struct OuterProductWideningOpConversion
   }
 };
 
-/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
+/// Lower `arm_sme.streaming_vl` to SME CNTSD intrinsic.
 ///
 /// Example:
 ///
@@ -830,8 +830,10 @@ struct OuterProductWideningOpConversion
 ///
 /// is converted to:
 ///
-///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
-///   %0 = arith.index_cast %cnt : i64 to index
+///   %cnt = "arm_sme.intr.cntsd"() : () -> i64
+///   %scale = arith.constant 4 : index
+///   %cntIndex = arith.index_cast %cnt : i64 to index
+///   %0 = arith.muli %cntIndex, %scale : index
 ///
 struct StreamingVLOpConversion : public ConvertArmSMEOpToLLVMPattern Operation * {
-      switch (streamingVlOp.getTypeSize()) {
-      case arm_sme::TypeSize::Byte:
-        return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type);
-      case arm_sme::TypeSize::Half:
-        return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type);
-      case arm_sme::TypeSize::Word:
-        return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type);
-      case arm_sme::TypeSize::Double:
-        return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
-      }
-      llvm_unreachable("unknown type size in StreamingVLOpConversion");
-    }();
-    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
-        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
+    auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
+    auto cntsdIdx = arith::IndexCastOp::create(rewriter, loc,
+                                               rewriter.getIndexType(), cntsd);
+    auto scale = arith::ConstantIndexOp::create(
+        rewriter, loc,
+        8 / arm_sme::getSizeInBytes(streamingVlOp.getTypeSize()));
+    rewriter.replaceOpWithNewOp<arith::MulIOp>(streamingVlOp, cntsdIdx, scale);
     return success();
   }
 };
@@ -964,9 +958,7 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
       arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
       arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
      arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
-      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
-      arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
-      arm_sme::aarch64_sme_cntsd>();
+      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsd>();
   target.addLegalDialect vec
 // -----
 
 // CHECK-LABEL: @arm_sme_streaming_vl_bytes
-// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
-// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
-// CHECK: return %[[INDEX_COUNT]] : index
+// CHECK: %[[CONST:.*]] = arith.constant 8 : index
+// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
+// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
 func.func @arm_sme_streaming_vl_bytes() -> index {
   %svl_b = arm_sme.streaming_vl <byte>
   return %svl_b : index
 }
@@ -597,7 +598,10 @@ func.func @arm_sme_streaming_vl_bytes() -> index {
 
 // -----
 
 // CHECK-LABEL: @arm_sme_streaming_vl_half_words
-// CHECK: "arm_sme.intr.cntsh"() : () -> i64
+// CHECK: %[[CONST:.*]] = arith.constant 4 : index
+// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
+// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
 func.func @arm_sme_streaming_vl_half_words() -> index {
   %svl_h = arm_sme.streaming_vl <half>
   return %svl_h : index
 }
@@ -606,7 +610,10 @@ func.func @arm_sme_streaming_vl_words() -> index {
 
 // -----
 
 // CHECK-LABEL: @arm_sme_streaming_vl_words
-// CHECK: "arm_sme.intr.cntsw"() : () -> i64
+// CHECK: %[[CONST:.*]] = arith.constant 2 : index
+// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
+// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
 func.func @arm_sme_streaming_vl_words() -> index {
   %svl_w = arm_sme.streaming_vl <word>
   return %svl_w : index
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index 14821da838726..6f5b1d8c5d93d 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -36,6 +36,6 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
 
 llvm.func @arm_sme_streaming_vl_invalid_return_type() -> i32 {
   // expected-error @+1 {{failed to verify that `res` is i64}}
-  %res = "arm_sme.intr.cntsb"() : () -> i32
+  %res = "arm_sme.intr.cntsd"() : () -> i32
   llvm.return %res : i32
 }
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index aedb6730b06bb..0a13a75618a23 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -419,12 +419,6 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
 // -----
 
 llvm.func @arm_sme_streaming_vl() {
-  // CHECK: call i64 @llvm.aarch64.sme.cntsb()
-  %svl_b = "arm_sme.intr.cntsb"() : () -> i64
-  // CHECK: call i64 @llvm.aarch64.sme.cntsh()
-  %svl_h = "arm_sme.intr.cntsh"() : () -> i64
-  // CHECK: call i64 @llvm.aarch64.sme.cntsw()
-  %svl_w = "arm_sme.intr.cntsw"() : () -> i64
   // CHECK: call i64 @llvm.aarch64.sme.cntsd()
   %svl_d = "arm_sme.intr.cntsd"() : () -> i64
   llvm.return