[NFC][AArch64] Refactor AArch64LoopIdiomTransform in preparation for more idioms #78471

Open
wants to merge 2 commits into
base: main

david-arm
Contributor

@david-arm david-arm commented Jan 17, 2024

In a future patch I intend to add support for other types of C loops that do memory comparisons with early exits, e.g.

  for (unsigned long i = start; i < end; i++) {
    if (p1[i] != p2[i])
      break;
  }

where we compare the loaded values prior to incrementing the induction variable i. This requires first refactoring the pass ready to support the new loop structures. I've created a new MemCompareIdiom class with support routines that can be reused by new code required to recognise these new loops.

I also modified MemCompareIdiom::generateMemCompare to ensure it is ready to support arbitrary induction variable types as well as supporting new comparison predicates when comparing the memory values.
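
To make the two loop shapes concrete, here is a minimal C++ sketch that is not taken from the patch. The function names `findMismatchIncFirst` and `findMismatchCompareFirst` are hypothetical, and both functions assume the buffers are readable over the whole index range. The first models the shape the pass recognises today, where the index is incremented before the loads (as in the `while.cond`/`while.body` IR quoted in the source comments). The second models the new shape described above.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of the currently recognised shape: the index is
// incremented first, compared against the limit, and only then used to load.
// Precondition (an assumption of this sketch): start < n.
uint32_t findMismatchIncFirst(const uint8_t *p1, const uint8_t *p2,
                              uint32_t start, uint32_t n) {
  assert(start < n);
  uint32_t i = start;
  while (++i != n) {
    if (p1[i] != p2[i])
      break; // early exit on the first mismatching byte
  }
  return i;
}

// Hypothetical model of the new shape from the description: the loaded
// values are compared before the induction variable is incremented.
unsigned long findMismatchCompareFirst(const uint8_t *p1, const uint8_t *p2,
                                       unsigned long start,
                                       unsigned long end) {
  unsigned long i = start;
  for (; i < end; i++) {
    if (p1[i] != p2[i])
      break; // early exit on the first mismatching byte
  }
  return i;
}
```

Under these assumptions, `findMismatchIncFirst(p1, p2, start, n)` returns the same index as `findMismatchCompareFirst(p1, p2, start + 1, n)`. That offset is the role the `IncIdx` flag plays in `transform()`, which adds 1 to `StartIdx` before generating the compare.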

@llvmbot
Collaborator

llvmbot commented Jan 17, 2024

@llvm/pr-subscribers-backend-aarch64

Author: David Sherwood (david-arm)

Patch is 20.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78471.diff

1 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp (+214-148)
diff --git a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp b/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
index 6dfb2b9df7135d..5b48d5be159f75 100644
--- a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
+++ b/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
@@ -1,3 +1,4 @@
+
 //===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -97,14 +98,6 @@ class AArch64LoopIdiomTransform {
   bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
                       SmallVectorImpl<BasicBlock *> &ExitBlocks);
 
-  bool recognizeByteCompare();
-  Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
-                            GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
-                            Instruction *Index, Value *Start, Value *MaxLen);
-  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
-                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
-                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
-                            BasicBlock *EndBB);
   /// @}
 };
 
@@ -187,6 +180,38 @@ AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
 //
 //===----------------------------------------------------------------------===//
 
+struct MemCompareIdiom {
+private:
+  const TargetTransformInfo *TTI;
+  DominatorTree *DT;
+  LoopInfo *LI;
+  Loop *CurLoop;
+  GetElementPtrInst *GEPA, *GEPB;
+  PHINode *IndPhi;
+  Instruction *Index;
+  bool IncIdx;
+  Value *StartIdx, *MaxLen;
+  BasicBlock *FoundBB, *EndBB;
+  ICmpInst::Predicate CmpPred;
+
+public:
+  MemCompareIdiom(const TargetTransformInfo *TTI, DominatorTree *DT,
+                  LoopInfo *LI, Loop *L)
+      : TTI(TTI), DT(DT), LI(LI), CurLoop(L) {}
+
+  bool recognize();
+  void transform();
+
+private:
+  bool checkIterationCondition(BasicBlock *CheckItBB, BasicBlock *&LoopBB);
+  bool checkLoadCompareCondition(BasicBlock *LoadCmpBB, Value *&LoadA,
+                                 Value *&LoadB);
+  bool areValidLoads(Value *BaseIdx, Value *LoadA, Value *LoadB);
+  bool checkEndAndFoundBlockPhis(Value *IndVal);
+  bool recognizePostIncMemCompare();
+  Value *generateMemCompare(IRBuilder<> &Builder, DomTreeUpdater &DTU);
+};
+
 bool AArch64LoopIdiomTransform::run(Loop *L) {
   CurLoop = L;
 
@@ -202,32 +227,135 @@ bool AArch64LoopIdiomTransform::run(Loop *L) {
                     << CurLoop->getHeader()->getParent()->getName()
                     << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
 
-  return recognizeByteCompare();
+  MemCompareIdiom BCI(TTI, DT, LI, L);
+  if (BCI.recognize()) {
+    BCI.transform();
+    return true;
+  }
+
+  return false;
 }
 
-bool AArch64LoopIdiomTransform::recognizeByteCompare() {
-  // Currently the transformation only works on scalable vector types, although
-  // there is no fundamental reason why it cannot be made to work for fixed
-  // width too.
+bool MemCompareIdiom::areValidLoads(Value *BaseIdx, Value *LoadA,
+                                    Value *LoadB) {
+  Value *A, *B;
+  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+    return false;
 
-  // We also need to know the minimum page size for the target in order to
-  // generate runtime memory checks to ensure the vector version won't fault.
-  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
-      DisableByteCmp)
+  if (!cast<LoadInst>(LoadA)->isSimple() || !cast<LoadInst>(LoadB)->isSimple())
     return false;
 
-  BasicBlock *Header = CurLoop->getHeader();
+  GEPA = dyn_cast<GetElementPtrInst>(A);
+  GEPB = dyn_cast<GetElementPtrInst>(B);
 
-  // In AArch64LoopIdiomTransform::run we have already checked that the loop
-  // has a preheader so we can assume it's in a canonical form.
-  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
+  if (!GEPA || !GEPB)
     return false;
 
-  PHINode *PN = dyn_cast<PHINode>(&Header->front());
-  if (!PN || PN->getNumIncomingValues() != 2)
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Check we are loading i8 values from two loop invariant pointers
+  Type *MemType = GEPA->getResultElementType();
+  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
+      !MemType->isIntegerTy(8) || GEPB->getResultElementType() != MemType ||
+      cast<LoadInst>(LoadA)->getType() != MemType ||
+      cast<LoadInst>(LoadB)->getType() != MemType || PtrA == PtrB)
+    return false;
+
+  // Check that the index to the GEPs is the index we found earlier
+  if (GEPA->getNumIndices() != 1 || GEPB->getNumIndices() != 1)
+    return false;
+
+  Value *IdxA = GEPA->getOperand(1);
+  Value *IdxB = GEPB->getOperand(1);
+
+  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(BaseIdx))))
     return false;
 
+  return true;
+}
+
+static bool checkBlockSizes(Loop *CurLoop, unsigned Block1Size,
+                            unsigned Block2Size) {
   auto LoopBlocks = CurLoop->getBlocks();
+
+  auto BB1 = LoopBlocks[0]->instructionsWithoutDebug();
+  if (std::distance(BB1.begin(), BB1.end()) > Block1Size)
+    return false;
+
+  auto BB2 = LoopBlocks[1]->instructionsWithoutDebug();
+  if (std::distance(BB2.begin(), BB2.end()) > Block2Size)
+    return false;
+
+  return true;
+}
+
+bool MemCompareIdiom::checkIterationCondition(BasicBlock *CheckItBB,
+                                              BasicBlock *&LoopBB) {
+  ICmpInst::Predicate Pred;
+  if (!match(CheckItBB->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
+                  m_BasicBlock(EndBB), m_BasicBlock(LoopBB))) ||
+      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(LoopBB))
+    return false;
+
+  return true;
+}
+
+bool MemCompareIdiom::checkLoadCompareCondition(BasicBlock *LoadCmpBB,
+                                                Value *&LoadA, Value *&LoadB) {
+  BasicBlock *TrueBB;
+  ICmpInst::Predicate Pred;
+  // TODO: Support other predicates.
+  CmpPred = ICmpInst::Predicate::ICMP_EQ;
+  if (!match(LoadCmpBB->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Value(LoadA), m_Value(LoadB)),
+                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
+      Pred != CmpPred || TrueBB != CurLoop->getHeader())
+    return false;
+
+  return true;
+}
+
+bool MemCompareIdiom::checkEndAndFoundBlockPhis(Value *IndVal) {
+  // Ensure that when the Found and End blocks are identical the PHIs have the
+  // supported format. We don't currently allow cases like this:
+  // while.cond:
+  //   ...
+  //   br i1 %cmp.not, label %while.end, label %while.body
+  //
+  // while.body:
+  //   ...
+  //   br i1 %cmp.not2, label %while.cond, label %while.end
+  //
+  // while.end:
+  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
+  //
+  // Where the incoming values for %final_ptr are unique and from each of the
+  // loop blocks, but not actually defined in the loop. This requires extra
+  // work setting up the byte.compare block, i.e. by introducing a select to
+  // choose the correct value.
+  // TODO: We could add support for this in future.
+  if (FoundBB == EndBB) {
+    auto LoopBlocks = CurLoop->getBlocks();
+    for (PHINode &EndPN : EndBB->phis()) {
+      Value *V1 = EndPN.getIncomingValueForBlock(LoopBlocks[0]);
+      Value *V2 = EndPN.getIncomingValueForBlock(LoopBlocks[1]);
+
+      // The value of the index when leaving the while.cond block is always the
+      // same as the end value (MaxLen) so we permit either. Otherwise for any
+      // other value defined outside the loop we only allow values that are the
+      // same as the exit value for while.body.
+      if (V1 != V2 &&
+          ((V1 != IndVal && V1 != MaxLen) || (V2 != IndVal)))
+        return false;
+    }
+  }
+
+  return true;
+}
+
+bool MemCompareIdiom::recognizePostIncMemCompare() {
   // The first block in the loop should contain only 4 instructions, e.g.
   //
   //  while.cond:
@@ -236,10 +364,6 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   //   %cmp.not = icmp eq i32 %inc, %n
   //   br i1 %cmp.not, label %while.end, label %while.body
   //
-  auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
-  if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
-    return false;
-
   // The second block should contain 7 instructions, e.g.
   //
   // while.body:
@@ -251,145 +375,90 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
   //   br i1 %cmp.not.ld, label %while.cond, label %while.end
   //
-  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
-  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+  if (!checkBlockSizes(CurLoop, 4, 7))
     return false;
 
-  // The incoming value to the PHI node from the loop should be an add of 1.
-  Value *StartIdx = nullptr;
-  Instruction *Index = nullptr;
-  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
-    StartIdx = PN->getIncomingValue(0);
-    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
-  } else {
-    StartIdx = PN->getIncomingValue(1);
-    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
-  }
-
-  // Limit to 32-bit types for now
-  if (!Index || !Index->getType()->isIntegerTy(32) ||
-      !match(Index, m_c_Add(m_Specific(PN), m_One())))
-    return false;
-
-  // If we match the pattern, PN and Index will be replaced with the result of
-  // the cttz.elts intrinsic. If any other instructions are used outside of
-  // the loop, we cannot replace it.
-  for (BasicBlock *BB : LoopBlocks)
-    for (Instruction &I : *BB)
-      if (&I != PN && &I != Index)
-        for (User *U : I.users())
-          if (!CurLoop->contains(cast<Instruction>(U)))
-            return false;
-
   // Match the branch instruction for the header
-  ICmpInst::Predicate Pred;
-  Value *MaxLen;
-  BasicBlock *EndBB, *WhileBB;
-  if (!match(Header->getTerminator(),
-             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
-                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
-      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
+  BasicBlock *Header = CurLoop->getHeader();
+  BasicBlock *WhileBodyBB;
+  if (!checkIterationCondition(Header, WhileBodyBB))
     return false;
 
   // WhileBB should contain the pattern of load & compare instructions. Match
   // the pattern and find the GEP instructions used by the loads.
-  ICmpInst::Predicate WhilePred;
-  BasicBlock *FoundBB;
-  BasicBlock *TrueBB;
   Value *LoadA, *LoadB;
-  if (!match(WhileBB->getTerminator(),
-             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
-                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
-      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
+  if (!checkLoadCompareCondition(WhileBodyBB, LoadA, LoadB))
     return false;
 
-  Value *A, *B;
-  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+  if (!areValidLoads(Index, LoadA, LoadB))
     return false;
 
-  LoadInst *LoadAI = cast<LoadInst>(LoadA);
-  LoadInst *LoadBI = cast<LoadInst>(LoadB);
-  if (!LoadAI->isSimple() || !LoadBI->isSimple())
+  if (!checkEndAndFoundBlockPhis(Index))
     return false;
 
-  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
-  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
+  return true;
+}
 
-  if (!GEPA || !GEPB)
+bool MemCompareIdiom::recognize() {
+  // Currently the transformation only works on scalable vector types, although
+  // there is no fundamental reason why it cannot be made to work for fixed
+  // width too.
+
+  // We also need to know the minimum page size for the target in order to
+  // generate runtime memory checks to ensure the vector version won't fault.
+  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+      DisableByteCmp)
     return false;
 
-  Value *PtrA = GEPA->getPointerOperand();
-  Value *PtrB = GEPB->getPointerOperand();
+  BasicBlock *Header = CurLoop->getHeader();
 
-  // Check we are loading i8 values from two loop invariant pointers
-  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
-      !GEPA->getResultElementType()->isIntegerTy(8) ||
-      !GEPB->getResultElementType()->isIntegerTy(8) ||
-      !LoadAI->getType()->isIntegerTy(8) ||
-      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
+  // In AArch64LoopIdiomTransform::run we have already checked that the loop
+  // has a preheader so we can assume it's in a canonical form.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
     return false;
 
-  // Check that the index to the GEPs is the index we found earlier
-  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
+  // We only expect one PHI node for the index.
+  IndPhi = dyn_cast<PHINode>(&Header->front());
+  if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
     return false;
 
-  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
-  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
-  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
-    return false;
+  // The incoming value to the PHI node from the loop should be an add of 1.
+  StartIdx = nullptr;
+  Index = nullptr;
+  if (!CurLoop->contains(IndPhi->getIncomingBlock(0))) {
+    StartIdx = IndPhi->getIncomingValue(0);
+    Index = dyn_cast<Instruction>(IndPhi->getIncomingValue(1));
+  } else {
+    StartIdx = IndPhi->getIncomingValue(1);
+    Index = dyn_cast<Instruction>(IndPhi->getIncomingValue(0));
+  }
 
-  // We only ever expect the pre-incremented index value to be used inside the
-  // loop.
-  if (!PN->hasOneUse())
+  if (!Index || !Index->getType()->isIntegerTy(32) ||
+      !match(Index, m_c_Add(m_Specific(IndPhi), m_One())))
     return false;
 
-  // Ensure that when the Found and End blocks are identical the PHIs have the
-  // supported format. We don't currently allow cases like this:
-  // while.cond:
-  //   ...
-  //   br i1 %cmp.not, label %while.end, label %while.body
-  //
-  // while.body:
-  //   ...
-  //   br i1 %cmp.not2, label %while.cond, label %while.end
-  //
-  // while.end:
-  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
-  //
-  // Where the incoming values for %final_ptr are unique and from each of the
-  // loop blocks, but not actually defined in the loop. This requires extra
-  // work setting up the byte.compare block, i.e. by introducing a select to
-  // choose the correct value.
-  // TODO: We could add support for this in future.
-  if (FoundBB == EndBB) {
-    for (PHINode &EndPN : EndBB->phis()) {
-      Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
-      Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);
+  // If we match the pattern, IndPhi and Index will be replaced with the result
+  // of the mismatch. If any other instructions are used outside of the loop, we
+  // cannot replace it.
+  for (BasicBlock *BB : CurLoop->getBlocks())
+    for (Instruction &I : *BB)
+      if (&I != IndPhi && &I != Index)
+        for (User *U : I.users())
+          if (!CurLoop->contains(cast<Instruction>(U)))
+            return false;
 
-      // The value of the index when leaving the while.cond block is always the
-      // same as the end value (MaxLen) so we permit either. The value when
-      // leaving the while.body block should only be the index. Otherwise for
-      // any other values we only allow ones that are same for both blocks.
-      if (WhileCondVal != WhileBodyVal &&
-          ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
-           (WhileBodyVal != Index)))
-        return false;
-    }
-  }
+  if (recognizePostIncMemCompare())
+    IncIdx = true;
+  else
+    return false;
 
   LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
                     << *(EndBB->getParent()) << "\n\n");
-
-  // The index is incremented before the GEP/Load pair so we need to
-  // add 1 to the start value.
-  transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
-                       FoundBB, EndBB);
   return true;
 }
 
-Value *AArch64LoopIdiomTransform::expandFindMismatch(
-    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
-    GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
+Value *MemCompareIdiom::generateMemCompare(IRBuilder<> &Builder,
+                                           DomTreeUpdater &DTU) {
   Value *PtrA = GEPA->getPointerOperand();
   Value *PtrB = GEPB->getPointerOperand();
 
@@ -398,7 +467,7 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
   BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
   LLVMContext &Ctx = PHBranch->getContext();
   Type *LoadType = Type::getInt8Ty(Ctx);
-  Type *ResType = Builder.getInt32Ty();
+  Type *ResType = StartIdx->getType();
 
   // Split block in the original loop preheader.
   BasicBlock *EndBlock =
@@ -479,11 +548,11 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
 
   // Check the zero-extended iteration count > 0
   Builder.SetInsertPoint(MinItCheckBlock);
-  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
+  Value *ExtStart = Builder.CreateZExt(StartIdx, I64Type);
   Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
   // This check doesn't really cost us very much.
 
-  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
+  Value *LimitCheck = Builder.CreateICmpULE(StartIdx, MaxLen);
   BranchInst *MinItCheckBr =
       BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
   MinItCheckBr->setMetadata(
@@ -592,7 +661,8 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
   Value *SVERhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVERhsGep, Align(1),
                                                LoopPred, Passthru);
 
-  Value *SVEMatchCmp = Builder.CreateICmpNE(SVELhsLoad, SVERhsLoad);
+  Value *SVEMatchCmp = Builder.CreateICmp(
+      ICmpInst::getInversePredicate(CmpPred), SVELhsLoad, SVERhsLoad);
   SVEMatchCmp = Builder.CreateSelect(LoopPred, SVEMatchCmp, PFalse);
   Value *SVEMatchHasActiveLanes = Builder.CreateOrReduce(SVEMatchCmp);
   BranchInst *SVEEarlyExit = BranchInst::Create(
@@ -658,7 +728,7 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
 
   Builder.SetInsertPoint(LoopStartBlock);
   PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
-  IndexPhi->addIncoming(Start, LoopPreHeaderBlock);
+  IndexPhi->addIncoming(StartIdx, LoopPreHeaderBlock);
 
   // Otherwise compare the values
   // Load bytes from each array and compare them.
@@ -674,7 +744,7 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
     cast<GetElementPtrInst>(RhsGep)->setIsInBounds(true);
   Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);
 
-  Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
+  Value *MatchCmp = Builder.CreateICmp(CmpPred, LhsLoad, RhsLoad);
   // If we have a mismatch then exit the loop ...
   BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
   Builder.Insert(MatchCmpBr);
@@ -723,11 +793,7 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
   return FinalRes;
 }
 
-void AArch64LoopIdiomTransform::transformByteCompare(
-    GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, PHINode *IndPhi,
-    Value *MaxLen, Instruction *Index, Value *Start, bool IncIdx,
-    BasicBlock *FoundBB, BasicBlock *EndBB) {
-
+void MemCompareIdiom::transform() {
   // Insert the byte compare code at the end of the preheader block
   BasicBlock *Preheader = CurLoop->getLoopPreheader();
   BasicBlock *Header = CurLoop->getHeader();
@@ -738,10 +804,10 @@ void AArch64LoopIdiomTransform::transformByteCompare(
 
   // Increment the pointer if this was done before the loads in the loop.
   if (IncIdx)
-    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));
+    StartIdx =
+        Builder.CreateAdd(StartIdx, ConstantInt::get(StartIdx->getType(), 1));
 
-  Value *ByteCmpRes =
-      expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);
+  Value *ByteCmpRes = generateMemCompare(Builder, DTU);
 
   // Replaces uses of index & induction Phi with intrinsic (we already
   // checked that the the first instruction of Header is t...
[truncated]

github-actions bot commented Jan 17, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

bool MemCompareIdiom::checkIterationCondition(BasicBlock *CheckItBB,
BasicBlock *&LoopBB) {
ICmpInst::Predicate Pred;
if (!match(CheckItBB->getTerminator(),
Collaborator

Instead of pattern matching the icmp, I was wondering whether we can use this Loop API:

ICmpInst * getLatchCmpInst () const
Get the latch condition instruction.

Contributor Author

I think I might have looked at this before and it requires loops to have a single latch. The icmp in this case occurs in the first block, not the last block. I'll double check, but I think there was a reason I couldn't use that.

Contributor Author

So this function does return something, but if you don't mind I've not used it here because I'd still have to do the work to extract the blocks in the branch, check the icmp predicate, Index, MaxLen, etc. I think the code is more compact right now, whereas using the latch function would require something like this:

  // getLatchCmpInst() returns a pointer, which may be null.
  ICmpInst *IC = CurLoop->getLatchCmpInst();
  if (!IC ||
      !match(CheckItBB->getTerminator(),
             m_Br(m_Specific(IC), m_BasicBlock(EndBB), m_BasicBlock(LoopBB))) ||
      IC->getOperand(0) != Index ||
      IC->getPredicate() != ICmpInst::Predicate::ICMP_EQ ||
      !CurLoop->contains(LoopBB))
    return false;

  MaxLen = IC->getOperand(1);
  return true;

return false;

// Check that the index to the GEPs is the index we found earlier
if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
// We only expect one PHI node for the index.
Collaborator

I can't remember if I asked this on the original merge request, but I was wondering if we can use a Loop API here, e.g.:

PHINode * getInductionVariable (ScalarEvolution &SE) const
Return the loop induction variable if found, else return nullptr.

Contributor Author

Possibly? Although getInductionVariable calls getLatchCmpInst so I think we might have the same issue as mentioned above.

Contributor Author

So I think this is possible, but again if you don't mind I've kept the original code. The reason is that getInductionVariable requires initialising the ScalarEvolution class and doing work to create SCEVAddRecExpr and InductionDescriptor objects for nodes that we're going to throw away and ignore. It's also far more powerful than we need, since it supports more induction variables than the narrow set we care about, i.e. it supports FP inductions, increments != 1, non-linear inductions, etc. It feels like a shame to do all that work and then restrict the result to the very narrow set we actually care about for these tight loops.

* Change struct -> class.
* Update comments about outside users of loop instructions.
* In same loop as ^^ remove the check for IndPhi as this was
unnecessary.
* Fix formatting issue.
@JeffHDF

JeffHDF commented Jan 24, 2024

Hi, where do these additional scenarios come from? How are you approaching this? As far as I know, there are some such scenarios in the 557.xz sub-benchmark of SPEC2017 int. Will you consider supporting them?

@david-arm
Contributor Author

Hi, where do these additional scenarios come from? How are you approaching this? As far as I know, there are some such scenarios in the 557.xz sub-benchmark of SPEC2017 int. Will you consider supporting them?

Yes, some of these new scenarios are in the xz benchmark in SPEC2017. There are loops recognised in blender too, although these are not in hot code. Certainly there are lots of cases in the LLVM test suite where we will transform loops too. To be honest, one of the main benefits of going wide and supporting more loops is to get better test coverage, although hopefully we will see increased performance in more applications as well.

@shaojingzhi
Contributor

Hi, I'm interested in this pattern too. When do you plan to finish this pattern match? Would you just recognise the pattern, transform it into the original while-loop IR and then reuse the existing transform, or would you write something new for this pattern?

@david-arm
Contributor Author

Hi, I'm interested in this pattern too. When do you plan to finish this pattern match? Would you just recognise the pattern, transform it into the original while-loop IR and then reuse the existing transform, or would you write something new for this pattern?

I already have a second patch waiting to upstream once this patch lands. This patch just makes it a lot easier to recognise different forms of the same idiom and transform the loop using the same generation code. There are a surprisingly large number of variations of the same idiom out there in SPEC and other benchmarks!

@@ -1,3 +1,4 @@

Collaborator

Remove this extra newline?

@@ -592,7 +660,8 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
Value *SVERhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVERhsGep, Align(1),

Have you tried using normal loads and generating an epilogue loop after the SVE loop? I think that may give a further performance improvement.
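
One way to picture this suggestion is the scalar model below. It is only a sketch and is not part of the patch: `kChunk` and `findMismatchWithEpilogue` are hypothetical names, and `std::memcmp` over a fixed-size chunk stands in for a full-width, unmasked vector compare. The sketch assumes every byte in `[start, end)` of both buffers is readable. The real transform cannot assume that in general, because the original loop stops at the first mismatch, which is why the pass requires the minimum page size for its runtime memory checks.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for the vector length in bytes.
constexpr std::size_t kChunk = 16;

// Returns the first index in [start, end) where a and b differ, or end if
// they match throughout. Assumes both buffers are readable over [start, end).
std::size_t findMismatchWithEpilogue(const uint8_t *a, const uint8_t *b,
                                     std::size_t start, std::size_t end) {
  assert(start <= end);
  std::size_t i = start;

  // Main loop: only whole chunks are compared, so no masking is needed.
  for (; end - i >= kChunk; i += kChunk) {
    if (std::memcmp(a + i, b + i, kChunk) != 0)
      break; // a mismatch lies in this chunk; the loop below locates it
  }

  // Epilogue: byte-by-byte over the mismatching chunk or the leftover tail.
  for (; i < end; ++i) {
    if (a[i] != b[i])
      break;
  }
  return i;
}
```

The trade-off in this model is that the main loop avoids per-iteration predicates, at the cost of a separate scalar tail and of needing a guarantee that full-width loads stay within accessible memory.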
