From c3b963207b1af767b7441f7eeabbeb634f039e75 Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Fri, 19 Jan 2024 11:12:40 +0000 Subject: [PATCH] [AArch64] NFC: Simplify discombobulating 'requiresSMChange' interface Having it return a `std::optional` is unnecessarily confusing. This patch changes it to a simple 'bool'. This patch also removes the 'BodyOverridesInterface' operand because there is only a single use for this which is easily rewritten. --- .../Target/AArch64/AArch64ISelLowering.cpp | 15 +++-- .../AArch64/AArch64TargetTransformInfo.cpp | 5 +- .../AArch64/Utils/AArch64SMEAttributes.cpp | 20 ++----- .../AArch64/Utils/AArch64SMEAttributes.h | 9 +-- .../Target/AArch64/SMEAttributesTest.cpp | 59 ++++--------------- 5 files changed, 28 insertions(+), 80 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 8a6f1dc7487ba..b0548cc64fcf3 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7650,8 +7650,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, } SDValue PStateSM; - std::optional RequiresSMChange = - CallerAttrs.requiresSMChange(CalleeAttrs); + bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs); if (RequiresSMChange) { if (CallerAttrs.hasStreamingInterfaceOrBody()) PStateSM = DAG.getConstant(1, DL, MVT::i64); @@ -7925,8 +7924,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, SDValue InGlue; if (RequiresSMChange) { - SDValue NewChain = changeStreamingMode(DAG, DL, *RequiresSMChange, Chain, - InGlue, PStateSM, true); + SDValue NewChain = + changeStreamingMode(DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, + InGlue, PStateSM, true); Chain = NewChain.getValue(0); InGlue = NewChain.getValue(1); } @@ -8076,8 +8076,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (RequiresSMChange) { assert(PStateSM && "Expected a PStateSM to be set"); - Result = changeStreamingMode(DAG, DL, !*RequiresSMChange, Result, InGlue, - PStateSM, false); + Result = changeStreamingMode(DAG, DL, !CalleeAttrs.hasStreamingInterface(), + Result, InGlue, PStateSM, false); } if (RequiresLazySave) { @@ -25479,8 +25479,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const { if (auto *Base = dyn_cast(&Inst)) { auto CallerAttrs = SMEAttrs(*Inst.getFunction()); auto CalleeAttrs = SMEAttrs(*Base); - if (CallerAttrs.requiresSMChange(CalleeAttrs, - /*BodyOverridesInterface=*/false) || + if (CallerAttrs.requiresSMChange(CalleeAttrs) || CallerAttrs.requiresLazySave(CalleeAttrs)) return true; } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index d358a5c8bd949..08ae536fe9bbf 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -236,8 +236,9 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, return false; if (CallerAttrs.requiresLazySave(CalleeAttrs) || - CallerAttrs.requiresSMChange(CalleeAttrs, - /*BodyOverridesInterface=*/true)) { + (CallerAttrs.requiresSMChange(CalleeAttrs) && + (!CallerAttrs.hasStreamingInterfaceOrBody() || + !CalleeAttrs.hasStreamingBody()))) { if (hasPossibleIncompatibleOps(Callee)) return false; } diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index ccdec78d78086..9693b6a664be2 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -82,27 +82,17 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { Bitmask |= encodeZT0State(StateValue::New); } -std::optional -SMEAttrs::requiresSMChange(const SMEAttrs &Callee, - bool BodyOverridesInterface) const { - // If the transition is not through a call (e.g. when considering inlining) - // and Callee has a streaming body, then we can ignore the interface of - // Callee. - if (BodyOverridesInterface && Callee.hasStreamingBody()) { - return hasStreamingInterfaceOrBody() ? std::nullopt - : std::optional(true); - } - +bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const { if (Callee.hasStreamingCompatibleInterface()) - return std::nullopt; + return false; // Both non-streaming if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface()) - return std::nullopt; + return false; // Both streaming if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface()) - return std::nullopt; + return false; - return Callee.hasStreamingInterface(); + return true; } diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index af2854856fb97..6f622f1996a3a 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -75,14 +75,7 @@ class SMEAttrs { /// \return true if a call from Caller -> Callee requires a change in /// streaming mode. - /// If \p BodyOverridesInterface is true and Callee has a streaming body, - /// then requiresSMChange considers a call to Callee as having a Streaming - /// interface. This can be useful when considering e.g. inlining, where we - /// explicitly want the body to overrule the interface (because after inlining - /// the interface is no longer relevant). - std::optional - requiresSMChange(const SMEAttrs &Callee, - bool BodyOverridesInterface = false) const; + bool requiresSMChange(const SMEAttrs &Callee) const; // Interfaces to query PSTATE.ZA bool hasNewZABody() const { return Bitmask & ZA_New; } diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp index 2f7201464ba2f..294e557181424 100644 --- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp +++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp @@ -193,86 +193,51 @@ TEST(SMEAttributes, Transitions) { ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal))); // Normal -> Normal + LocallyStreaming ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal | SA::SM_Body))); - ASSERT_EQ(*SA(SA::Normal) - .requiresSMChange(SA(SA::Normal | SA::SM_Body), - /*BodyOverridesInterface=*/true), - true); // Normal -> Streaming - ASSERT_EQ(*SA(SA::Normal).requiresSMChange(SA(SA::SM_Enabled)), true); + ASSERT_TRUE(SA(SA::Normal).requiresSMChange(SA(SA::SM_Enabled))); // Normal -> Streaming + LocallyStreaming - ASSERT_EQ(*SA(SA::Normal).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)), - true); - ASSERT_EQ(*SA(SA::Normal) - .requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body), - /*BodyOverridesInterface=*/true), - true); + ASSERT_TRUE( + SA(SA::Normal).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body))); // Normal -> Streaming-compatible ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::SM_Compatible))); // Normal -> Streaming-compatible + LocallyStreaming ASSERT_FALSE( SA(SA::Normal).requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body))); - ASSERT_EQ(*SA(SA::Normal) - .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body), - /*BodyOverridesInterface=*/true), - true); // Streaming -> Normal - ASSERT_EQ(*SA(SA::SM_Enabled).requiresSMChange(SA(SA::Normal)), false); + ASSERT_TRUE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::Normal))); // Streaming -> Normal + LocallyStreaming - ASSERT_EQ(*SA(SA::SM_Enabled).requiresSMChange(SA(SA::Normal | SA::SM_Body)), - false); - ASSERT_FALSE(SA(SA::SM_Enabled) - .requiresSMChange(SA(SA::Normal | SA::SM_Body), - /*BodyOverridesInterface=*/true)); + ASSERT_TRUE( + SA(SA::SM_Enabled).requiresSMChange(SA(SA::Normal | SA::SM_Body))); // Streaming -> Streaming ASSERT_FALSE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Enabled))); // Streaming -> Streaming + LocallyStreaming ASSERT_FALSE( SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body))); - ASSERT_FALSE(SA(SA::SM_Enabled) - .requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body), - /*BodyOverridesInterface=*/true)); // Streaming -> Streaming-compatible ASSERT_FALSE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Compatible))); // Streaming -> Streaming-compatible + LocallyStreaming ASSERT_FALSE( SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body))); - ASSERT_FALSE(SA(SA::SM_Enabled) - .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body), - /*BodyOverridesInterface=*/true)); // Streaming-compatible -> Normal - ASSERT_EQ(*SA(SA::SM_Compatible).requiresSMChange(SA(SA::Normal)), false); - ASSERT_EQ( - *SA(SA::SM_Compatible).requiresSMChange(SA(SA::Normal | SA::SM_Body)), - false); - ASSERT_EQ(*SA(SA::SM_Compatible) - .requiresSMChange(SA(SA::Normal | SA::SM_Body), - /*BodyOverridesInterface=*/true), - true); + ASSERT_TRUE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::Normal))); + ASSERT_TRUE( + SA(SA::SM_Compatible).requiresSMChange(SA(SA::Normal | SA::SM_Body))); // Streaming-compatible -> Streaming - ASSERT_EQ(*SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Enabled)), true); + ASSERT_TRUE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Enabled))); // Streaming-compatible -> Streaming + LocallyStreaming - ASSERT_EQ( - *SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)), - true); - ASSERT_EQ(*SA(SA::SM_Compatible) - .requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body), - /*BodyOverridesInterface=*/true), - true); + ASSERT_TRUE( + SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body))); // Streaming-compatible -> Streaming-compatible ASSERT_FALSE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Compatible))); // Streaming-compatible -> Streaming-compatible + LocallyStreaming ASSERT_FALSE(SA(SA::SM_Compatible) .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body))); - ASSERT_EQ(*SA(SA::SM_Compatible) - .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body), - /*BodyOverridesInterface=*/true), - true); }