Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Introduce the RISCVLoopIdiomRecognizePass #92441

Closed

Conversation

mshockwave
Copy link
Member

RISCVLoopIdiomRecognize pattern matches a specific kind of byte compare loop that looks like:

while (i != max_len)
  if (a[i] != b[i])
      break;

... use index i ...

And effectively vectorizes it. This is similar to AArch64's AArch64LoopIdiomTransform Pass, except we're using VP intrinsic here.

This improves 557.xz_r in SPEC2017 by nearly 20% (on refrate workload) in terms of dynamic instruction counts.

Note that at this moment, we're not enabling this by default.

RISCVLoopIdiomRecognize pattern matches a specific kind of byte compare
loop that looks like:
```
while (i != max_len)
  if (a[i] != b[i])
      break;

... use index i ...
```
And effectively vectorizes it.

This is similar to AArch64's AArch64LoopIdiomTransform Pass, except
we're using VP intrinsic here.

This improves 557.xz_r in SPEC2017 by nearly 20% (on refrate workload)
in terms of dynamic instruction counts.
@llvmbot
Copy link
Collaborator

llvmbot commented May 16, 2024

@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-llvm-transforms

Author: Min-Yih Hsu (mshockwave)

Changes

RISCVLoopIdiomRecognize pattern matches a specific kind of byte compare loop that looks like:

while (i != max_len)
  if (a[i] != b[i])
      break;

... use index i ...

And effectively vectorizes it. This is similar to AArch64's AArch64LoopIdiomTransform Pass, except we're using VP intrinsic here.

This improves 557.xz_r in SPEC2017 by nearly 20% (on refrate workload) in terms of dynamic instruction counts.

Note that at this moment, we're not enabling this by default.


Patch is 162.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92441.diff

6 Files Affected:

  • (modified) llvm/lib/Target/RISCV/CMakeLists.txt (+1)
  • (added) llvm/lib/Target/RISCV/RISCVLoopIdiomRecognize.cpp (+752)
  • (added) llvm/lib/Target/RISCV/RISCVLoopIdiomRecognize.h (+25)
  • (modified) llvm/lib/Target/RISCV/RISCVTargetMachine.cpp (+20)
  • (modified) llvm/lib/Target/RISCV/RISCVTargetMachine.h (+3)
  • (added) llvm/test/Transforms/LoopIdiom/RISCV/byte-compare-index.ll (+1771)
diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt
index 8715403f3839a..8d9cb65940097 100644
--- a/llvm/lib/Target/RISCV/CMakeLists.txt
+++ b/llvm/lib/Target/RISCV/CMakeLists.txt
@@ -43,6 +43,7 @@ add_llvm_target(RISCVCodeGen
   RISCVInstrInfo.cpp
   RISCVISelDAGToDAG.cpp
   RISCVISelLowering.cpp
+  RISCVLoopIdiomRecognize.cpp
   RISCVMachineFunctionInfo.cpp
   RISCVMergeBaseOffset.cpp
   RISCVOptWInstrs.cpp
diff --git a/llvm/lib/Target/RISCV/RISCVLoopIdiomRecognize.cpp b/llvm/lib/Target/RISCV/RISCVLoopIdiomRecognize.cpp
new file mode 100644
index 0000000000000..331e0a3ea6534
--- /dev/null
+++ b/llvm/lib/Target/RISCV/RISCVLoopIdiomRecognize.cpp
@@ -0,0 +1,752 @@
+//===-------- RISCVLoopIdiomRecognize.cpp - Loop idiom recognition --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "RISCVLoopIdiomRecognize.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsRISCV.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/TargetParser/RISCVTargetParser.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "riscv-loop-idiom"
+
+static cl::opt<bool>
+    DisableAll("riscv-disable-all-loop-idiom", cl::Hidden, cl::init(true),
+               cl::desc("Disable RISCV Loop Idiom Recognize Pass."));
+
+static cl::opt<bool> DisableByteCmp(
+    "disable-riscv-loop-idiom-bytecmp", cl::Hidden, cl::init(false),
+    cl::desc("Proceed with RISCV Loop Idiom Recognize Pass, but do "
+             "not convert byte-compare loop(s)."));
+
+// CustomLoopIdiomLMUL can be used to customize LMUL for vectorizing loops.
+// It uses the exponent value to represent LMUL i.e. 0 -> LMUL 1, 1 -> LMUL 2, 2
+// -> LMUL 4, 3 -> LMUL 8, etc.
+static cl::opt<unsigned>
+    CustomLoopIdiomLMUL("riscv-loop-idiom-lmul", cl::Hidden, cl::init(1),
+                        cl::Optional,
+                        cl::desc("Customize LMUL for vector loop."));
+
+namespace {
+
+class RISCVLoopIdiomRecognize {
+  Loop *CurLoop = nullptr;
+  DominatorTree &DT;
+  LoopInfo &LI;
+  TargetLibraryInfo &TLI;
+  const TargetTransformInfo &TTI;
+  const DataLayout &DL;
+
+public:
+  explicit RISCVLoopIdiomRecognize(DominatorTree &DT, LoopInfo &LI,
+                                   TargetLibraryInfo &TLI,
+                                   const TargetTransformInfo &TTI,
+                                   const DataLayout &DL)
+      : DT(DT), LI(LI), TLI(TLI), TTI(TTI), DL(DL) {}
+
+  bool run(Loop *L);
+
+private:
+  /// \name Countable Loop Idiom Handling
+  /// @{
+
+  bool runOnCountableLoop();
+  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
+                      SmallVectorImpl<BasicBlock *> &ExitBlocks);
+
+  bool recognizeAndTransformByteCompare();
+  Value *expandFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA,
+                            GetElementPtrInst *GEPB, Instruction *Index,
+                            Value *Start, Value *MaxLen);
+  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
+                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
+                            BasicBlock *EndBB);
+
+  /// @}
+};
+} // end anonymous namespace
+
+static VectorType *getBestVectorTypeForLoopIdiom(LLVMContext &Ctx) {
+  unsigned LMULExp = std::min(3U, CustomLoopIdiomLMUL.getValue());
+  unsigned VF = (RISCV::RVVBitsPerBlock / 8) << LMULExp;
+  ElementCount EC = ElementCount::getScalable(VF);
+  return VectorType::get(Type::getInt8Ty(Ctx), EC);
+}
+
+PreservedAnalyses
+RISCVLoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
+                                 LoopStandardAnalysisResults &AR,
+                                 LPMUpdater &) {
+  if (DisableAll)
+    return PreservedAnalyses::all();
+
+  Function &F = *L.getHeader()->getParent();
+  if (F.hasFnAttribute(Attribute::NoImplicitFloat)) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
+                      << " due to its NoImplicitFloat attribute");
+    return PreservedAnalyses::all();
+  }
+
+  // Only enabled on RV64 for now.
+  if (L.getHeader()->getModule()->getDataLayout().getPointerSizeInBits() != 64)
+    return PreservedAnalyses::all();
+
+  // Only enabled when vector extension is present.
+  if (!AR.TTI.supportsScalableVectors())
+    return PreservedAnalyses::all();
+
+  const auto DL = L.getHeader()->getModule()->getDataLayout();
+
+  RISCVLoopIdiomRecognize LIR(AR.DT, AR.LI, AR.TLI, AR.TTI, DL);
+  if (!LIR.run(&L))
+    return PreservedAnalyses::all();
+
+  auto PA = PreservedAnalyses::none();
+  PA.preserve<DominatorTreeAnalysis>();
+  return PA;
+}
+
+//===----------------------------------------------------------------------===//
+//
+//          Implementation of RISCVLoopIdiomRecognize
+//
+//===----------------------------------------------------------------------===//
+
+bool RISCVLoopIdiomRecognize::run(Loop *L) {
+  CurLoop = L;
+
+  if (DisableAll)
+    return false;
+
+  // If the loop could not be converted to canonical form, it must have an
+  // indirectbr in it, just give up.
+  if (!L->getLoopPreheader())
+    return false;
+
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
+                    << CurLoop->getHeader()->getParent()->getName()
+                    << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
+
+  return recognizeAndTransformByteCompare();
+}
+
+/// Match loop-invariant value.
+template <typename SubPattern_t> struct match_LoopInvariant {
+  SubPattern_t SubPattern;
+  const Loop *L;
+
+  match_LoopInvariant(const SubPattern_t &SP, const Loop *L)
+      : SubPattern(SP), L(L) {}
+
+  template <typename ITy> bool match(ITy *V) {
+    return L->isLoopInvariant(V) && SubPattern.match(V);
+  }
+};
+
+/// Matches if the value is loop-invariant.
+template <typename Ty>
+inline match_LoopInvariant<Ty> m_LoopInvariant(const Ty &M, const Loop *L) {
+  return match_LoopInvariant<Ty>(M, L);
+}
+
+bool RISCVLoopIdiomRecognize::recognizeAndTransformByteCompare() {
+  if (DisableByteCmp)
+    return false;
+
+  BasicBlock *PH = CurLoop->getLoopPreheader();
+
+  // The preheader should only contain an unconditional branch.
+  if (!PH || &PH->front() != PH->getTerminator())
+    return false;
+
+  using namespace PatternMatch;
+
+  BasicBlock *Header;
+  if (!match(PH->getTerminator(), m_UnconditionalBr(Header)))
+    return false;
+
+  if (Header != CurLoop->getHeader())
+    return false;
+
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
+    return false;
+
+  auto *PN = dyn_cast<PHINode>(&Header->front());
+  if (!PN || PN->getNumIncomingValues() != 2)
+    return false;
+
+  auto LoopBlocks = CurLoop->getBlocks();
+  // The first block in the loop should contain only 4 instructions, e.g.
+  //
+  //  while.cond:
+  //   %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
+  //   %inc = add i32 %res.phi, 1
+  //   %cmp.not = icmp eq i32 %inc, %n
+  //   br i1 %cmp.not, label %while.end, label %while.body
+  //
+  auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
+  if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) != 4)
+    return false;
+
+  // The second block should contain 7 instructions, e.g.
+  //
+  // while.body:
+  //   %idx = zext i32 %inc to i64
+  //   %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
+  //   %load.a = load i8, ptr %idx.a
+  //   %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
+  //   %load.b = load i8, ptr %idx.b
+  //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
+  //   br i1 %cmp.not.ld, label %while.cond, label %while.end
+  //
+  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
+  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) != 7)
+    return false;
+
+  // The incoming value to the PHI node from the loop should be an add of 1.
+  Instruction *Index = nullptr;
+  Value *StartIdx = nullptr;
+  for (BasicBlock *BB : PN->blocks()) {
+    if (!CurLoop->contains(BB)) {
+      StartIdx = PN->getIncomingValueForBlock(BB);
+      continue;
+    }
+    Index = dyn_cast<Instruction>(PN->getIncomingValueForBlock(BB));
+    // Limit to 32-bit types for now
+    if (!Index || !Index->getType()->isIntegerTy(32) ||
+        !match(Index, m_c_Add(m_Specific(PN), m_One())))
+      return false;
+  }
+
+  for (BasicBlock *BB : LoopBlocks)
+    for (Instruction &I : *BB)
+      if (&I != PN && &I != Index)
+        for (User *U : I.users()) {
+          auto UI = dyn_cast<Instruction>(U);
+          if (!CurLoop->contains(UI))
+            return false;
+        }
+
+  // Match the branch instruction for the header
+  ICmpInst::Predicate Pred;
+  Value *MaxLen;
+  BasicBlock *EndBB, *WhileBB;
+  if (!match(Header->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
+                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))))
+    return false;
+
+  // Make sure Pred is comparing for equal
+  if (Pred != ICmpInst::ICMP_EQ)
+    return false;
+
+  // Make sure EndBB is outside the loop and WhileBB is inside the loop.
+  if (CurLoop->contains(EndBB) || !CurLoop->contains(WhileBB))
+    return false;
+
+  // WhileBB should contain the pattern of load & compare instructions. Match
+  // the pattern and find the GEP instructions used by the loads.
+  ICmpInst::Predicate WhilePred;
+  BasicBlock *FoundBB;
+  BasicBlock *TrueBB;
+  Value *A, *B;
+  if (!match(WhileBB->getTerminator(),
+             m_Br(m_ICmp(WhilePred, m_Load(m_Value(A)), m_Load(m_Value(B))),
+                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))))
+    return false;
+
+  // Make sure WhilePred is comparing for equal
+  if (WhilePred != ICmpInst::ICMP_EQ)
+    return false;
+
+  // Make sure TrueBB is the loop header and FoundBB is outside the loop.
+  if (CurLoop->getHeader() != TrueBB || CurLoop->contains(FoundBB))
+    return false;
+
+  auto *GEPA = dyn_cast<GetElementPtrInst>(A);
+  auto *GEPB = dyn_cast<GetElementPtrInst>(B);
+  if (!GEPA || !GEPB)
+    return false;
+
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Check PtrA and PtrB stride at i8.
+  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
+      !GEPA->getResultElementType()->isIntegerTy(8) ||
+      !GEPB->getResultElementType()->isIntegerTy(8) || PtrA == PtrB)
+    return false;
+
+  // Check loads from GEPA and GEPB are i8.
+  auto *LoadA = dyn_cast<LoadInst>(GEPA->getNextNode());
+  if (!LoadA || !LoadA->getType()->isIntegerTy(8))
+    return false;
+  auto *LoadB = dyn_cast<LoadInst>(GEPB->getNextNode());
+  if (!LoadB || !LoadB->getType()->isIntegerTy(8))
+    return false;
+
+  // Check that the index to the GEPs is the index we found earlier
+  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
+    return false;
+
+  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
+  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
+
+  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
+    return false;
+
+  // We only ever expect the pre-incremented index value to be used inside the
+  // loop.
+  if (!PN->hasOneUse())
+    return false;
+
+  // Ensure that when the Found and End blocks are identical the PHIs have the
+  // supported format. We don't currently allow cases like this:
+  // while.cond:
+  //   ...
+  //   br i1 %cmp.not, label %while.end, label %while.body
+  //
+  // while.body:
+  //   ...
+  //   br i1 %cmp.not2, label %while.cond, label %while.end
+  //
+  // while.end:
+  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
+  //
+  // Where the incoming values for %final_ptr are unique and from each of the
+  // loop blocks, but not actually defined in the loop. This requires extra
+  // work setting up the byte.compare block, i.e. by introducing a select to
+  // choose the correct value.
+  // TODO: We could add support for this in future.
+  if (FoundBB == EndBB) {
+    for (PHINode &EndPN : EndBB->phis()) {
+      Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
+      Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);
+
+      // The value of the index when leaving the while.cond block is always the
+      // same as the end value (MaxLen) so we permit either. Otherwise for any
+      // other value defined outside the loop we only allow values that are the
+      // same as the exit value for while.body.
+      if (WhileCondVal != WhileBodyVal &&
+          ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
+           (WhileBodyVal != Index && WhileBodyVal != MaxLen)))
+        return false;
+    }
+  }
+
+  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
+                    << *(EndBB->getParent()) << "\n\n");
+  transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, true, FoundBB,
+                       EndBB);
+  LLVM_DEBUG(dbgs() << "AFTER IDIOM TRANSFORMATION: \n"
+                    << *(EndBB->getParent()) << "\n\n");
+  return true;
+}
+
+Value *RISCVLoopIdiomRecognize::expandFindMismatch(
+    IRBuilder<> &Builder, GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+    Instruction *Index, Value *Start, Value *MaxLen) {
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Get the arguments and types for the intrinsic.
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  auto *PHBranch = cast<BranchInst>(Preheader->getTerminator());
+  LLVMContext &Ctx = PHBranch->getContext();
+  Type *LoadType = Type::getInt8Ty(Ctx);
+  Type *ResType = Builder.getInt32Ty();
+
+  // Split block at the original callsite, where the EndBlock continues from
+  // where the original call ended.
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  BasicBlock *EndBlock =
+      SplitBlock(Preheader, PHBranch, &DT, &LI, nullptr, "mismatch_end");
+
+  // Safeguard to check if we build the correct DomTree with DTU.
+  auto CheckDTU = llvm::make_scope_exit([&]() {
+    assert(DTU.getDomTree().verify() && "Ill-formed DomTree built by DTU");
+  });
+
+  // Create the blocks that we're going to need:
+  //  1. A block for checking the zero-extended length exceeds 0
+  //  2. A block to check that the start and end addresses of a given array
+  //     lie on the same page.
+  //  3. The RVV loop preheader i.e. vector_loop_preheader
+  //  4. The first RVV loop block i.e. vector_loop
+  //  5. The RVV loop increment block i.e. vector_loop_inc
+  //  6. A block we can jump to from the RVV loop when a mismatch is found i.e.
+  //  vector_loop_exit
+  //  7. The first block of the scalar loop itself, containing PHIs , loads
+  //  and cmp.
+  //  8. A scalar loop increment block to increment the PHIs and go back
+  //  around the loop.
+
+  BasicBlock *MinItCheckBlock = BasicBlock::Create(
+      Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);
+
+  // This DTU update is actually the only one we need to cover all control flow
+  // changes made in this function. Because the current DTU algorithm
+  // recaculates the whole sub-tree between a deleted edge. And the edge between
+  // Preheader and EndBlock happens to enclose all the blocks we inserted
+  // in this function.
+  DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
+                    {DominatorTree::Delete, Preheader, EndBlock}});
+
+  // Update the terminator added by SplitBlock to branch to the first block
+  Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);
+
+  BasicBlock *MemCheckBlock = BasicBlock::Create(
+      Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *RVVLoopPreheaderBlock = BasicBlock::Create(
+      Ctx, "mismatch_vector_loop_preheader", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *RVVLoopStartBlock = BasicBlock::Create(
+      Ctx, "mismatch_vector_loop", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *RVVLoopIncBlock = BasicBlock::Create(
+      Ctx, "mismatch_vector_loop_inc", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *RVVLoopMismatchBlock = BasicBlock::Create(
+      Ctx, "mismatch_vector_loop_found", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
+      Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *LoopStartBlock =
+      BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *LoopIncBlock = BasicBlock::Create(
+      Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);
+
+  // Update LoopInfo with the new RVV & scalar loops.
+  auto RVVLoop = LI.AllocateLoop();
+  auto ScalarLoop = LI.AllocateLoop();
+  if (CurLoop->getParentLoop()) {
+    CurLoop->getParentLoop()->addChildLoop(RVVLoop);
+    CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
+
+    CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, LI);
+    CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, LI);
+    CurLoop->getParentLoop()->addBasicBlockToLoop(RVVLoopPreheaderBlock, LI);
+    CurLoop->getParentLoop()->addBasicBlockToLoop(RVVLoopMismatchBlock, LI);
+    CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, LI);
+  } else {
+    LI.addTopLevelLoop(RVVLoop);
+    LI.addTopLevelLoop(ScalarLoop);
+  }
+
+  // Add the new basic blocks to their associated loops.
+  RVVLoop->addBasicBlockToLoop(RVVLoopStartBlock, LI);
+  RVVLoop->addBasicBlockToLoop(RVVLoopIncBlock, LI);
+
+  ScalarLoop->addBasicBlockToLoop(LoopStartBlock, LI);
+  ScalarLoop->addBasicBlockToLoop(LoopIncBlock, LI);
+
+  // Set up some types and constants that we intend to reuse.
+  Type *I64Type = Builder.getInt64Ty();
+  Type *I32Type = Builder.getInt32Ty();
+
+  // Check the zero-extended iteration count > 0
+  Builder.SetInsertPoint(MinItCheckBlock);
+  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
+  Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
+  // This check doesn't really cost us very much.
+
+  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
+  BranchInst *MinItCheckBr =
+      BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
+  MinItCheckBr->setMetadata(
+      LLVMContext::MD_prof,
+      MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
+  Builder.Insert(MinItCheckBr);
+
+  // For each of the arrays, check the start/end addresses are on the same
+  // page.
+  Builder.SetInsertPoint(MemCheckBlock);
+
+  // For each start address calculate the offset into the min architecturally
+  // allowed page size (4096). Then determine how many bytes there are left on
+  // the page and see if this is >= MaxLen.
+  Value *LhsStartPage = Builder.CreateLShr(
+      Builder.CreatePtrToInt(Builder.CreateGEP(LoadType, PtrA, ExtStart),
+                             I64Type),
+      uint64_t(12));
+  Value *LhsEndPage = Builder.CreateLShr(
+      Builder.CreatePtrToInt(Builder.CreateGEP(LoadType, PtrA, ExtEnd),
+                             I64Type),
+      uint64_t(12));
+  Value *RhsStartPage = Builder.CreateLShr(
+      Builder.CreatePtrToInt(Builder.CreateGEP(LoadType, PtrB, ExtStart),
+                             I64Type),
+      uint64_t(12));
+  Value *RhsEndPage = Builder.CreateLShr(
+      Builder.CreatePtrToInt(Builder.CreateGEP(LoadType, PtrB, ExtEnd),
+                             I64Type),
+      uint64_t(12));
+  Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
+  Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);
+
+  BranchInst *CombinedPageCmpCmpBr =
+      BranchInst::Create(LoopPreHeaderBlock, RVVLoopPreheaderBlock,
+                         Builder.CreateOr(LhsPageCmp, RhsPageCmp));
+  CombinedPageCmpCmpBr->setMetadata(
+      LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
+        ...
[truncated]

And some minor revisions.
// %cmp.not = icmp eq i32 %inc, %n
// br i1 %cmp.not, label %while.end, label %while.body
//
auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this equivalent to LoopBlocks[0]->sizeWithoutDebug(). I think this was lifted from AArch64 and could be simplified there too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have addressed this in the new PR(s).

// %cmp.not.ld = icmp eq i8 %load.a, %load.b
// br i1 %cmp.not.ld, label %while.cond, label %while.end
//
auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sizeWithoutDebug

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in the new PR(s).

for (Instruction &I : *BB)
if (&I != PN && &I != Index)
for (User *U : I.users()) {
auto *UI = dyn_cast<Instruction>(U);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a cast instead of a dyn_cast? I don't think contains wants a nullptr. It immediate calls getParent() on it.

Since we're checking the user of an Instruction, I think the user itself can only be another Instruction.

Looks like AArch64 uses a cast here


// Ensure all Phis in the successors of CmpBB have an incoming value from it.
FixSuccessorPhis(EndBB);
FixSuccessorPhis(FoundBB);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AArch64 has an if before the second FixSuccessorPhis


// Update the dominator tree with the new block.
DT.addNewBlock(CmpBB, Preheader);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AArch64 verifies isRecursivelyLCSSAForm here

}

bool RISCVLoopIdiomRecognize::recognizeAndTransformByteCompare() {
if (DisableByteCmp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we early out if there is no V instructions?

@preames
Copy link
Collaborator

preames commented May 20, 2024

Glancing at this change and the corresponding AArch64 pass, I find myself questioning the high level strategy a bit. Unless I'm missing something major, there doesn't really look to be much that is AArch64 specific to the AArch64 version of this pass.

As a starting point, have you considered simply generalizing the aarch64 pass into a shared pass which can run on both targets? This would get you code generation via masked.loads for this case - but not the VP & VL predicated form.

Once this was done, you could parameterize the code generation (only, not the matching) on whether the target preferred VL style predication or not.

@ppenzin
Copy link

ppenzin commented May 22, 2024

As a starting point, have you considered simply generalizing the aarch64 pass into a shared pass which can run on both targets? This would get you code generation via masked.loads for this case - but not the VP & VL predicated form.

I think it makes sense to have a pass separate from AArch64, the semantics of SVE and RVV are different enough that these passes are bound to diverge at some point.

@topperc
Copy link
Collaborator

topperc commented May 22, 2024

As a starting point, have you considered simply generalizing the aarch64 pass into a shared pass which can run on both targets? This would get you code generation via masked.loads for this case - but not the VP & VL predicated form.

I think it makes sense to have a pass separate from AArch64, the semantics of SVE and RVV are different enough that these passes are bound to diverge at some point.

All of the idiom detection code is currently identical. There are some minor differences that I think came from the timing of when we copied the code and later changes made by AArch64. It was not intended for them to diverge. This is one problem having a single pass solves.

There are differences in how we expand the code. Maybe we can use the TTI hooks the middle loop vectorizer uses to have different strategies for AArch64 and RISC-V. If that gets too complicated we could have a common base class that AArch64 and RISC-V both inherit their own passes from so they can share the code that makes sense.

@ppenzin
Copy link

ppenzin commented May 23, 2024

There are differences in how we expand the code. Maybe we can use the TTI hooks the middle loop vectorizer uses to have different strategies for AArch64 and RISC-V. If that gets too complicated we could have a common base class that AArch64 and RISC-V both inherit their own passes from so they can share the code that makes sense.

Yeah, that makes sense.

@mshockwave
Copy link
Member Author

Closed in favor of #94081 and #94082

@mshockwave mshockwave closed this May 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants