Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SimpleLoopUnswitch] Remove callbacks #73300

Merged
merged 3 commits into from
Nov 24, 2023

Conversation

boomanaiden154
Copy link
Contributor

After the removal of the legacyPM version of simple loop unswitch, there is no longer a need for the callback mechanism to handle PM specific tasks. This patch removes the callbacks to help simplify the code now that they're no longer needed.

After the removal of the legacyPM version of simple loop unswitch, there
is no longer a need for the callback mechanism to handle PM specific
tasks. This patch removes the callbacks to help simplify the code now
that they're no longer needed.
@llvmbot
Copy link
Collaborator

llvmbot commented Nov 24, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Aiden Grossman (boomanaiden154)

Changes

After the removal of the legacyPM version of simple loop unswitch, there is no longer a need for the callback mechanism to handle PM specific tasks. This patch removes the callbacks to help simplify the code now that they're no longer needed.


Full diff: https://github.com/llvm/llvm-project/pull/73300.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp (+74-88)
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index 55606473765239b..cd00ef1fbbfc86b 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -1682,13 +1682,12 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks,
     BB->eraseFromParent();
 }
 
-static void
-deleteDeadBlocksFromLoop(Loop &L,
-                         SmallVectorImpl<BasicBlock *> &ExitBlocks,
-                         DominatorTree &DT, LoopInfo &LI,
-                         MemorySSAUpdater *MSSAU,
-                         ScalarEvolution *SE,
-                         function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+static void deleteDeadBlocksFromLoop(Loop &L,
+                                     SmallVectorImpl<BasicBlock *> &ExitBlocks,
+                                     DominatorTree &DT, LoopInfo &LI,
+                                     MemorySSAUpdater *MSSAU,
+                                     ScalarEvolution *SE,
+                                     LPMUpdater &LoopUpdater) {
   // Find all the dead blocks tied to this loop, and remove them from their
   // successors.
   SmallSetVector<BasicBlock *, 8> DeadBlockSet;
@@ -1738,7 +1737,7 @@ deleteDeadBlocksFromLoop(Loop &L,
                         }) &&
            "If the child loop header is dead all blocks in the child loop must "
            "be dead as well!");
-    DestroyLoopCB(*ChildL, ChildL->getName());
+    LoopUpdater.markLoopAsDeleted(*ChildL, ChildL->getName());
     if (SE)
       SE->forgetBlockAndLoopDispositions();
     LI.destroy(ChildL);
@@ -2082,8 +2081,8 @@ static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks,
       ParentL->removeChildLoop(llvm::find(*ParentL, &L));
     else
       LI.removeLoop(llvm::find(LI, &L));
-    // markLoopAsDeleted for L should be triggered by the caller (it is typically
-    // done by using the UnswitchCB callback).
+    // markLoopAsDeleted for L should be triggered by the caller (it is
+    // typically done within postUnswitch).
     if (SE)
       SE->forgetBlockAndLoopDispositions();
     LI.destroy(&L);
@@ -2120,18 +2119,56 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) {
   } while (!DomWorklist.empty());
 }
 
+void postUnswitch(Loop &L, LPMUpdater &U, StringRef LoopName,
+                  bool CurrentLoopValid, bool PartiallyInvariant,
+                  bool InjectedCondition, ArrayRef<Loop *> NewLoops) {
+  // If we did a non-trivial unswitch, we have added new (cloned) loops.
+  if (!NewLoops.empty())
+    U.addSiblingLoops(NewLoops);
+
+  // If the current loop remains valid, we should revisit it to catch any
+  // other unswitch opportunities. Otherwise, we need to mark it as deleted.
+  if (CurrentLoopValid) {
+    if (PartiallyInvariant) {
+      // Mark the new loop as partially unswitched, to avoid unswitching on
+      // the same condition again.
+      auto &Context = L.getHeader()->getContext();
+      MDNode *DisableUnswitchMD = MDNode::get(
+          Context,
+          MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
+      MDNode *NewLoopID = makePostTransformationMetadata(
+          Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
+          {DisableUnswitchMD});
+      L.setLoopID(NewLoopID);
+    } else if (InjectedCondition) {
+      // Do the same for injection of invariant conditions.
+      auto &Context = L.getHeader()->getContext();
+      MDNode *DisableUnswitchMD = MDNode::get(
+          Context,
+          MDString::get(Context, "llvm.loop.unswitch.injection.disable"));
+      MDNode *NewLoopID = makePostTransformationMetadata(
+          Context, L.getLoopID(), {"llvm.loop.unswitch.injection"},
+          {DisableUnswitchMD});
+      L.setLoopID(NewLoopID);
+    } else
+      U.revisitCurrentLoop();
+  } else
+    U.markLoopAsDeleted(L, LoopName);
+}
+
 static void unswitchNontrivialInvariants(
     Loop &L, Instruction &TI, ArrayRef<Value *> Invariants,
     IVConditionInfo &PartialIVInfo, DominatorTree &DT, LoopInfo &LI,
-    AssumptionCache &AC,
-    function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
-    ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
-    function_ref<void(Loop &, StringRef)> DestroyLoopCB, bool InsertFreeze,
-    bool InjectedCondition) {
+    AssumptionCache &AC, ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
+    LPMUpdater &LoopUpdater, bool InsertFreeze, bool InjectedCondition) {
   auto *ParentBB = TI.getParent();
   BranchInst *BI = dyn_cast<BranchInst>(&TI);
   SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
 
+  // Save the current loop name in a variable so that we can report it even
+  // after it has been deleted.
+  std::string LoopName(L.getName());
+
   // We can only unswitch switches, conditional branches with an invariant
   // condition, or combining invariant conditions with an instruction or
   // partially invariant instructions.
@@ -2444,7 +2481,7 @@ static void unswitchNontrivialInvariants(
   // Now that our cloned loops have been built, we can update the original loop.
   // First we delete the dead blocks from it and then we rebuild the loop
   // structure taking these deletions into account.
-  deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE,DestroyLoopCB);
+  deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE, LoopUpdater);
 
   if (MSSAU && VerifyMemorySSA)
     MSSAU->getMemorySSA()->verifyMemorySSA();
@@ -2580,7 +2617,8 @@ static void unswitchNontrivialInvariants(
   for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops))
     if (UpdatedL->getParentLoop() == ParentL)
       SibLoops.push_back(UpdatedL);
-  UnswitchCB(IsStillLoop, PartiallyInvariant, InjectedCondition, SibLoops);
+  postUnswitch(L, LoopUpdater, LoopName, IsStillLoop, PartiallyInvariant,
+               InjectedCondition, SibLoops);
 
   if (MSSAU && VerifyMemorySSA)
     MSSAU->getMemorySSA()->verifyMemorySSA();
@@ -3427,12 +3465,11 @@ static bool shouldInsertFreeze(Loop &L, Instruction &TI, DominatorTree &DT,
       Cond, &AC, L.getLoopPreheader()->getTerminator(), &DT);
 }
 
-static bool unswitchBestCondition(
-    Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-    AAResults &AA, TargetTransformInfo &TTI,
-    function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
-    ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
-    function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+static bool unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
+                                  AssumptionCache &AC, AAResults &AA,
+                                  TargetTransformInfo &TTI, ScalarEvolution *SE,
+                                  MemorySSAUpdater *MSSAU,
+                                  LPMUpdater &LoopUpdater) {
   // Collect all invariant conditions within this loop (as opposed to an inner
   // loop which would be handled when visiting that inner loop).
   SmallVector<NonTrivialUnswitchCandidate, 4> UnswitchCandidates;
@@ -3495,8 +3532,8 @@ static bool unswitchBestCondition(
   LLVM_DEBUG(dbgs() << "  Unswitching non-trivial (cost = " << Best.Cost
                     << ") terminator: " << *Best.TI << "\n");
   unswitchNontrivialInvariants(L, *Best.TI, Best.Invariants, PartialIVInfo, DT,
-                               LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB,
-                               InsertFreeze, InjectedCondition);
+                               LI, AC, SE, MSSAU, LoopUpdater, InsertFreeze,
+                               InjectedCondition);
   return true;
 }
 
@@ -3515,20 +3552,18 @@ static bool unswitchBestCondition(
 /// true, we will attempt to do non-trivial unswitching as well as trivial
 /// unswitching.
 ///
-/// The `UnswitchCB` callback provided will be run after unswitching is
-/// complete, with the first parameter set to `true` if the provided loop
-/// remains a loop, and a list of new sibling loops created.
+/// The `postUnswitch` function will be run after unswitching is complete
+/// with information on whether or not the provided loop remains a loop and
+/// a list of new sibling loops created.
 ///
 /// If `SE` is non-null, we will update that analysis based on the unswitching
 /// done.
-static bool
-unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-             AAResults &AA, TargetTransformInfo &TTI, bool Trivial,
-             bool NonTrivial,
-             function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
-             ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
-             ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI,
-             function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI,
+                         AssumptionCache &AC, AAResults &AA,
+                         TargetTransformInfo &TTI, bool Trivial,
+                         bool NonTrivial, ScalarEvolution *SE,
+                         MemorySSAUpdater *MSSAU, ProfileSummaryInfo *PSI,
+                         BlockFrequencyInfo *BFI, LPMUpdater &LoopUpdater) {
   assert(L.isRecursivelyLCSSAForm(DT, LI) &&
          "Loops must be in LCSSA form before unswitching.");
 
@@ -3540,8 +3575,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
   if (Trivial && unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) {
     // If we unswitched successfully we will want to clean up the loop before
     // processing it further so just mark it as unswitched and return.
-    UnswitchCB(/*CurrentLoopValid*/ true, /*PartiallyInvariant*/ false,
-               /*InjectedCondition*/ false, {});
+    postUnswitch(L, LoopUpdater, L.getName(), true, false, false, {});
     return true;
   }
 
@@ -3610,8 +3644,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
 
   // Try to unswitch the best invariant condition. We prefer this full unswitch to
   // a partial unswitch when possible below the threshold.
-  if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU,
-                            DestroyLoopCB))
+  if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, SE, MSSAU, LoopUpdater))
     return true;
 
   // No other opportunities to unswitch.
@@ -3631,52 +3664,6 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
   LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L
                     << "\n");
 
-  // Save the current loop name in a variable so that we can report it even
-  // after it has been deleted.
-  std::string LoopName = std::string(L.getName());
-
-  auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid,
-                                        bool PartiallyInvariant,
-                                        bool InjectedCondition,
-                                        ArrayRef<Loop *> NewLoops) {
-    // If we did a non-trivial unswitch, we have added new (cloned) loops.
-    if (!NewLoops.empty())
-      U.addSiblingLoops(NewLoops);
-
-    // If the current loop remains valid, we should revisit it to catch any
-    // other unswitch opportunities. Otherwise, we need to mark it as deleted.
-    if (CurrentLoopValid) {
-      if (PartiallyInvariant) {
-        // Mark the new loop as partially unswitched, to avoid unswitching on
-        // the same condition again.
-        auto &Context = L.getHeader()->getContext();
-        MDNode *DisableUnswitchMD = MDNode::get(
-            Context,
-            MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
-        MDNode *NewLoopID = makePostTransformationMetadata(
-            Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
-            {DisableUnswitchMD});
-        L.setLoopID(NewLoopID);
-      } else if (InjectedCondition) {
-        // Do the same for injection of invariant conditions.
-        auto &Context = L.getHeader()->getContext();
-        MDNode *DisableUnswitchMD = MDNode::get(
-            Context,
-            MDString::get(Context, "llvm.loop.unswitch.injection.disable"));
-        MDNode *NewLoopID = makePostTransformationMetadata(
-            Context, L.getLoopID(), {"llvm.loop.unswitch.injection"},
-            {DisableUnswitchMD});
-        L.setLoopID(NewLoopID);
-      } else
-        U.revisitCurrentLoop();
-    } else
-      U.markLoopAsDeleted(L, LoopName);
-  };
-
-  auto DestroyLoopCB = [&U](Loop &L, StringRef Name) {
-    U.markLoopAsDeleted(L, Name);
-  };
-
   std::optional<MemorySSAUpdater> MSSAU;
   if (AR.MSSA) {
     MSSAU = MemorySSAUpdater(AR.MSSA);
@@ -3684,8 +3671,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
       AR.MSSA->verifyMemorySSA();
   }
   if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial,
-                    UnswitchCB, &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI,
-                    DestroyLoopCB))
+                    &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI, U))
     return PreservedAnalyses::all();
 
   if (AR.MSSA && VerifyMemorySSA)

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp Outdated Show resolved Hide resolved
Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks for the cleanup!

@boomanaiden154 boomanaiden154 merged commit 4233176 into llvm:main Nov 24, 2023
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants