Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce PrecomputeLoopExpressionsPass. #90263

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

huihzhang
Copy link
Contributor

This patch is not for review, but to share the optimization to community.

This patch is not for review, but to share this optimization to community.
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 26, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Huihui Zhang (huihzhang)

Changes

This patch is not for review, but to share the optimization to community.


Patch is 35.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90263.diff

4 Files Affected:

  • (added) llvm/include/llvm/Transforms/Scalar/PrecomputeLoopExpressions.h (+27)
  • (modified) llvm/lib/Passes/PassBuilderPipelines.cpp (+7)
  • (modified) llvm/lib/Transforms/Scalar/CMakeLists.txt (+1)
  • (added) llvm/lib/Transforms/Scalar/PrecomputeLoop.cpp (+1107)
diff --git a/llvm/include/llvm/Transforms/Scalar/PrecomputeLoopExpressions.h b/llvm/include/llvm/Transforms/Scalar/PrecomputeLoopExpressions.h
new file mode 100644
index 00000000000000..42eea0a8d53e59
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Scalar/PrecomputeLoopExpressions.h
@@ -0,0 +1,27 @@
+//===-------------------- PrecomputeLoopExpressions.h ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_SCALAR_PRECOMPUTE_LOOP_EXPRESSIONS_H
+#define LLVM_TRANSFORMS_SCALAR_PRECOMPUTE_LOOP_EXPRESSIONS_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+class PrecomputeLoopExpressionsPass
+    : public PassInfoMixin<PrecomputeLoopExpressionsPass> {
+  unsigned TotalInitSize;
+
+public:
+  PrecomputeLoopExpressionsPass(unsigned TotalInitSize = 0)
+      : TotalInitSize(TotalInitSize) {}
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_SCALAR_PRECOMPUTE_LOOP_EXPRESSIONS_H
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 90ba3b541553e2..36351e0f510220 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -115,6 +115,7 @@
 #include "llvm/Transforms/Scalar/MemCpyOptimizer.h"
 #include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h"
 #include "llvm/Transforms/Scalar/NewGVN.h"
+#include "llvm/Transforms/Scalar/PrecomputeLoopExpressions.h"
 #include "llvm/Transforms/Scalar/Reassociate.h"
 #include "llvm/Transforms/Scalar/SCCP.h"
 #include "llvm/Transforms/Scalar/SROA.h"
@@ -302,6 +303,8 @@ namespace llvm {
 extern cl::opt<bool> EnableMemProfContextDisambiguation;
 
 extern cl::opt<bool> EnableInferAlignmentPass;
+
+extern cl::opt<bool> DisablePCLE;
 } // namespace llvm
 
 PipelineTuningOptions::PipelineTuningOptions() {
@@ -681,6 +684,10 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
   FPM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM1),
                                               /*UseMemorySSA=*/true,
                                               /*UseBlockFrequencyInfo=*/true));
+
+  if (!DisablePCLE)
+    FPM.addPass(PrecomputeLoopExpressionsPass());
+
   FPM.addPass(
       SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
   FPM.addPass(InstCombinePass());
diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt
index ba09ebf8b04c4c..8b5bd39d1242d0 100644
--- a/llvm/lib/Transforms/Scalar/CMakeLists.txt
+++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt
@@ -61,6 +61,7 @@ add_llvm_component_library(LLVMScalarOpts
   NewGVN.cpp
   PartiallyInlineLibCalls.cpp
   PlaceSafepoints.cpp
+  PrecomputeLoop.cpp
   Reassociate.cpp
   Reg2Mem.cpp
   RewriteStatepointsForGC.cpp
diff --git a/llvm/lib/Transforms/Scalar/PrecomputeLoop.cpp b/llvm/lib/Transforms/Scalar/PrecomputeLoop.cpp
new file mode 100644
index 00000000000000..3af149db11b28c
--- /dev/null
+++ b/llvm/lib/Transforms/Scalar/PrecomputeLoop.cpp
@@ -0,0 +1,1107 @@
+//===-------- PrecomputeLoop.cpp - Precompute expressions in a loop -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass detects and evaluates expressions based on loop induction
+// variables. Loops and induction variables will need to have compile-time known
+// trip count and increments. Then this pass will determine if the detected
+// expressions can benefit from being replaced with loads to precomputed table.
+// The precomputed table is initialized with values computed based on
+// expressions and iteration space.
+//
+// For example:
+//  int N = 36, sum;
+//  for (int p=0; p<N; p++){
+//    sum = 0;
+//    for (int m=0; m < N/2; m++)
+//      sum += cos_table[((2*p+1+N/2)*(2*m+1))%144];
+//    out[p] = sum;
+//  }
+//
+// Expression "(...*(2*m+1))%144" is detected and to be replaced with a single
+// load to precomputed table.
+// The precomputed table is created as a ConstantArray of size
+// [36 x [18 x i32]], and use the expression and iteration space to initialize.
+// E.g., array element at (p=0,m=2) is 95.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/PrecomputeLoopExpressions.h"
+
+#include <atomic>
+#include <deque>
+#include <map>
+#include <set>
+#include <vector>
+
+#define DEBUG_TYPE "pcle"
+
+using namespace llvm;
+
+static cl::opt<unsigned> MinCostThreshold("pcle-min-cost", cl::Hidden,
+                                          cl::init(8));
+
+static const int kByte = 1024;
+static const int MByte = 1024 * 1024;
+
+static cl::opt<unsigned> MaxSizeThreshold("pcle-max-size", cl::Hidden,
+                                          cl::init(512 * kByte));
+
+static cl::opt<unsigned> MaxTotalSizeThreshold("pcle-max-total-size",
+                                               cl::Hidden, cl::init(2 * MByte));
+
+namespace llvm {
+cl::opt<bool> DisablePCLE("disable-pcle", cl::Hidden, cl::init(false),
+                          cl::desc("Disable Precomputing Loop Expressions"));
+}
+
+namespace {
+typedef int32_t Integer;
+#define BitSize(T) (8 * sizeof(T))
+
+struct IVInfo {
+  IVInfo() : L(0) {}
+  IVInfo(Loop *Lp) : L(Lp) {}
+  Integer Start, End, Bump;
+  Loop *L;
+
+  bool EqualIterSpace(const IVInfo &I) const {
+    return Start == I.Start && End == I.End && Bump == I.Bump;
+  }
+};
+#ifndef NDEBUG
+raw_ostream &operator<<(raw_ostream &OS, const IVInfo &II) {
+  if (II.L)
+    OS << "Loop header: " << II.L->getHeader()->getName();
+  else
+    OS << "No loop";
+  OS << "   Start:" << II.Start << "  End:" << II.End << "  Bump:" << II.Bump;
+  return OS;
+}
+#endif
+
+typedef std::vector<Value *> ValueVect;
+typedef std::map<Value *, IVInfo> IVInfoMap;
+
+#ifndef NDEBUG
+raw_ostream &operator<<(raw_ostream &OS, const IVInfoMap &M) {
+  for (auto &I : M)
+    OS << I.first->getName() << " -> " << I.second << '\n';
+  return OS;
+}
+#endif
+
+class InitDescKey {
+public:
+  ArrayType *ATy;
+  ValueVect IVs;
+
+  InitDescKey() : ATy(0), IVs(), IVInfos(0) {}
+  InitDescKey(ArrayType *T, ValueVect &Vs, IVInfoMap &IVM)
+      : ATy(T), IVs(Vs), IVInfos(&IVM) {}
+
+  bool operator==(const InitDescKey &K) const {
+    if (ATy != K.ATy)
+      return false;
+
+    unsigned Dims = IVs.size();
+    if (Dims != K.IVs.size())
+      return false;
+
+    for (unsigned i = 0; i < Dims; ++i) {
+      IVInfo &I = (*IVInfos)[IVs[i]];
+      IVInfo &KI = (*K.IVInfos)[K.IVs[i]];
+      if (!I.EqualIterSpace(KI))
+        return false;
+    }
+    return true;
+  }
+  bool operator<(const InitDescKey &K) const {
+    unsigned Dims = IVs.size();
+    if (Dims != K.IVs.size())
+      return Dims < K.IVs.size();
+    // Dims are equal here.
+    if (ATy != K.ATy)
+      return uintptr_t(ATy) < uintptr_t(K.ATy);
+    // Types are the same.
+    for (unsigned i = 0; i < Dims; ++i) {
+      IVInfo &I = (*IVInfos)[IVs[i]];
+      IVInfo &KI = (*K.IVInfos)[K.IVs[i]];
+      if (I.Start != KI.Start)
+        return I.Start < KI.Start;
+      if (I.End != KI.End)
+        return I.End < KI.End;
+      if (I.Bump != KI.Bump)
+        return I.Bump < KI.Bump;
+    }
+    return false;
+  }
+
+private:
+  IVInfoMap *IVInfos;
+};
+
+class InitDescVal {
+public:
+  Value *Ex;
+  GlobalVariable *GV;
+  Constant *Init;
+  unsigned Seq;
+
+  InitDescVal() : Ex(0), GV(0), Init(0), Seq(0) {}
+  InitDescVal(Value *E, GlobalVariable *G, Constant *I)
+      : Ex(E), GV(G), Init(I), Seq(std::atomic_fetch_add(&SeqCounter, 1U)) {}
+
+  static std::atomic<unsigned> SeqCounter;
+};
+
+typedef std::vector<Integer> IntVect;
+typedef std::deque<Value *> ValueQueue;
+typedef std::set<Value *> ValueSet;
+typedef std::pair<GlobalVariable *, Value *> AdjustedInit;
+typedef std::multimap<InitDescKey, InitDescVal> InitializerCache;
+
+struct OrderMap {
+  OrderMap() {}
+  typedef std::map<Instruction *, unsigned> MapType;
+  MapType::mapped_type operator[](Instruction *In) {
+    if (Map.find(In) == Map.end())
+      recalculate(*In->getParent()->getParent());
+    assert(Map.find(In) != Map.end());
+    return Map[In];
+  }
+
+  void recalculate(Function &F);
+  MapType Map;
+};
+
+void OrderMap::recalculate(Function &F) {
+  Map.clear();
+  unsigned Ord = 0;
+  for (auto &B : F)
+    for (auto &I : B)
+      Map.insert(std::make_pair(&I, ++Ord));
+}
+
+#ifndef NDEBUG
+raw_ostream &operator<<(raw_ostream &OS, const ValueSet &S) {
+  OS << '{';
+  for (auto &I : S)
+    OS << ' ' << *I;
+  OS << " }";
+  return OS;
+}
+#endif
+
+class PrecomputeLoopExpressions {
+public:
+  PrecomputeLoopExpressions(DominatorTree *DT, LoopInfo *LI,
+                            ScalarEvolution *SE, TargetLibraryInfo *TLI,
+                            unsigned TotalInitSize)
+      : DT(DT), LI(LI), SE(SE), TLI(TLI), TotalInitSize(TotalInitSize) {};
+
+  bool run(Function &Fn);
+
+private:
+  bool isLoopValid(Loop *L);
+  bool processLatchForIV(Instruction *TrIn, Value *&IV, IVInfo &IVI);
+  bool processPHIForIV(Instruction *PIn, Value *IV, IVInfo &IVI);
+  void collectInductionVariables();
+
+  bool isAllowedOpcode(unsigned Opc);
+  bool verifyExpressionNode(Value *Ex, ValueSet &Valid);
+  bool verifyExpression(Value *Ex, ValueSet &Valid);
+  void extendExpression(Value *Ex, ValueSet &Valid, ValueSet &New);
+  unsigned computeInitializerSize(Value *V);
+  unsigned computeExpressionCost(Value *V, ValueSet &Vs);
+  void collectCandidateExpressions();
+
+  void extractInductionVariables(Value *Ex, ValueVect &IVs);
+  ArrayType *createTypeForArray(Type *ETy, ValueVect &IVs);
+  Integer evaluateExpression(Value *Ex, ValueVect &IVs, IntVect &C);
+  Constant *createInitializerForSlice(Value *Ex, unsigned Dim, ArrayType *ATy,
+                                      ValueVect &IVs, bool Zero, IntVect &C,
+                                      IntVect &Starts, IntVect &Ends,
+                                      IntVect &Bumps);
+  Constant *createInitializerForArray(Value *Ex, ArrayType *ATy,
+                                      ValueVect &IVs);
+  AdjustedInit getInitializerForArray(Value *Ex, ArrayType *ATy,
+                                      ValueVect &IVs);
+  Value *computeDifference(Value *A, Value *B);
+  bool rewriteExpression(Value *Ex, Value *Adj, ArrayType *ATy, ValueVect &IVs,
+                         GlobalVariable *GV);
+  bool processCandidateExpressions();
+
+  Function *F;
+  DominatorTree *DT;
+  LoopInfo *LI;
+  ScalarEvolution *SE;
+  TargetLibraryInfo *TLI;
+  OrderMap Order;
+
+  IVInfoMap IVInfos;
+  ValueSet IVEs;
+  InitializerCache InitCache;
+  unsigned TotalInitSize;
+
+  static std::atomic<unsigned> Counter;
+};
+} // namespace
+
+std::atomic<unsigned> InitDescVal::SeqCounter(0);
+std::atomic<unsigned> PrecomputeLoopExpressions::Counter(0);
+
+static unsigned Log2p(unsigned A) {
+  if (A == 0)
+    return 1;
+
+  unsigned L = 1;
+  while (A >>= 1)
+    L++;
+
+  return L;
+}
+
+bool PrecomputeLoopExpressions::isLoopValid(Loop *L) {
+  BasicBlock *H = L->getHeader();
+  if (!H)
+    return false;
+  BasicBlock *PH = L->getLoopPreheader();
+  if (!PH)
+    return false;
+  BasicBlock *EB = L->getExitingBlock();
+  if (!EB)
+    return false;
+
+  if (std::distance(pred_begin(H), pred_end(H)) != 2)
+    return false;
+
+  unsigned TC = SE->getSmallConstantTripCount(L, EB);
+  if (TC == 0)
+    return false;
+
+  return true;
+}
+
+bool PrecomputeLoopExpressions::processLatchForIV(Instruction *TrIn, Value *&IV,
+                                                  IVInfo &IVI) {
+  // Need a conditional branch.
+  BranchInst *Br = dyn_cast<BranchInst>(TrIn);
+  if (!Br || !Br->isConditional())
+    return false;
+
+  // The branch condition needs to be an integer compare.
+  Value *CV = Br->getCondition();
+  Instruction *CIn = dyn_cast<Instruction>(CV);
+  if (!CIn || CIn->getOpcode() != Instruction::ICmp)
+    return false;
+
+  // The comparison has to be less-than.
+  ICmpInst *ICIn = cast<ICmpInst>(CIn);
+  CmpInst::Predicate P = ICIn->getPredicate();
+  if (P != CmpInst::ICMP_ULT && P != CmpInst::ICMP_SLT)
+    return false;
+
+  // Less-than a constant int to be exact.
+  Value *CR = ICIn->getOperand(1);
+  if (!isa<ConstantInt>(CR))
+    return false;
+
+  // The int has to fit in 32 bits.
+  const APInt &U = cast<ConstantInt>(CR)->getValue();
+  if (!U.isSignedIntN(BitSize(Integer)))
+    return false;
+
+  // The value that is less-than the int needs to be an add.
+  Value *VC = ICIn->getOperand(0);
+  Instruction *VCIn = dyn_cast<Instruction>(VC);
+  if (!VCIn || VCIn->getOpcode() != Instruction::Add)
+    return false;
+
+  // An add of a constant int.
+  Value *ValA, *ValI;
+  if (isa<ConstantInt>(VCIn->getOperand(1))) {
+    ValA = VCIn->getOperand(0);
+    ValI = VCIn->getOperand(1);
+  } else {
+    ValA = VCIn->getOperand(1);
+    ValI = VCIn->getOperand(0);
+  }
+  if (!isa<ConstantInt>(ValI))
+    return false;
+
+  // The added int has to fit in 32 bits.
+  const APInt &B = cast<ConstantInt>(ValI)->getValue();
+  if (!B.isSignedIntN(BitSize(Integer)))
+    return false;
+
+  // Done...
+  IV = ValA;
+  IVI.End = (Integer)U.getSExtValue();
+  IVI.Bump = (Integer)B.getSExtValue();
+  return true;
+}
+
+bool PrecomputeLoopExpressions::processPHIForIV(Instruction *PIn, Value *IV,
+                                                IVInfo &IVI) {
+  if (IV != PIn)
+    return false;
+
+  // The PHI must only have two incoming blocks.
+  PHINode *P = cast<PHINode>(PIn);
+  if (P->getNumIncomingValues() != 2)
+    return false;
+
+  // The blocks have to be preheader and loop latch.
+  BasicBlock *PH = IVI.L->getLoopPreheader();
+  BasicBlock *LT = IVI.L->getLoopLatch();
+
+  if (P->getIncomingBlock(0) == PH) {
+    if (P->getIncomingBlock(1) != LT)
+      return false;
+  } else if (P->getIncomingBlock(1) == PH) {
+    if (P->getIncomingBlock(0) != LT)
+      return false;
+  } else {
+    return false;
+  }
+
+  // The value coming from the preheader needs to be a constant int.
+  Value *VPH = P->getIncomingValueForBlock(PH);
+  if (!isa<ConstantInt>(VPH))
+    return false;
+
+  // That int has to fit in 32 bits.
+  const APInt &S = cast<ConstantInt>(VPH)->getValue();
+  if (!S.isSignedIntN(BitSize(Integer)))
+    return false;
+
+  // All checks passed.
+  IVI.Start = static_cast<Integer>(S.getSExtValue());
+  return true;
+}
+
+void PrecomputeLoopExpressions::collectInductionVariables() {
+  IVInfos.clear();
+
+  typedef std::deque<Loop *> LoopQueue;
+  LoopQueue Work;
+
+  for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) {
+    Work.push_back(*I);
+  }
+
+  while (!Work.empty()) {
+    Loop *L = Work.front();
+    Work.pop_front();
+
+    for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) {
+      Work.push_back(*I);
+    }
+    if (!isLoopValid(L))
+      continue;
+
+    Value *IV;
+    IVInfo IVI(L);
+    Instruction *TrIn = L->getLoopLatch()->getTerminator();
+
+    bool LatchOk = processLatchForIV(TrIn, IV, IVI);
+    if (!LatchOk)
+      continue;
+
+    BasicBlock *H = L->getHeader();
+    for (BasicBlock::iterator PI = H->begin(); isa<PHINode>(PI); ++PI) {
+      Instruction *I = &*PI;
+      if (I == IV) {
+        bool PHIOk = processPHIForIV(I, IV, IVI);
+        if (PHIOk) {
+          IVInfos.insert(std::make_pair(IV, IVI));
+        }
+        break;
+      }
+    }
+  }
+}
+
+bool PrecomputeLoopExpressions::isAllowedOpcode(unsigned Opc) {
+  switch (Opc) {
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::URem:
+  case Instruction::SRem:
+  case Instruction::Shl:
+  case Instruction::AShr:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+    return true;
+  }
+  return false;
+}
+
+bool PrecomputeLoopExpressions::verifyExpressionNode(Value *Ex,
+                                                     ValueSet &Valid) {
+  Type *T = Ex->getType();
+  if (!T->isIntegerTy())
+    return false;
+  if (cast<IntegerType>(T)->getBitWidth() > BitSize(Integer))
+    return false;
+
+  Instruction *In = dyn_cast<Instruction>(Ex);
+  if (!In)
+    return false;
+  if (!isAllowedOpcode(In->getOpcode()))
+    return false;
+
+  return true;
+}
+
+bool PrecomputeLoopExpressions::verifyExpression(Value *Ex, ValueSet &Valid) {
+  if (Valid.count(Ex))
+    return true;
+  if (isa<ConstantInt>(Ex))
+    return true;
+
+  if (!verifyExpressionNode(Ex, Valid))
+    return false;
+
+  assert(isa<Instruction>(Ex) && "Should have checked for instruction");
+  Instruction *In = cast<Instruction>(Ex);
+  for (unsigned i = 0, n = In->getNumOperands(); i < n; ++i) {
+    bool ValidOp = verifyExpression(In->getOperand(i), Valid);
+    if (!ValidOp)
+      return false;
+  }
+  return true;
+}
+
+void PrecomputeLoopExpressions::extendExpression(Value *Ex, ValueSet &Valid,
+                                                 ValueSet &New) {
+  for (Value::user_iterator I = Ex->user_begin(), E = Ex->user_end(); I != E;
+       ++I) {
+    Value *U = *I;
+    if (Valid.count(U))
+      continue;
+    if (U->getType()->isVoidTy())
+      continue;
+
+    bool BadUser = false;
+
+    if (Instruction *In = dyn_cast<Instruction>(U)) {
+      if (In->getOpcode() == Instruction::PHI)
+        continue;
+      if (!verifyExpressionNode(U, Valid))
+        continue;
+
+      for (unsigned i = 0, n = In->getNumOperands(); i < n; ++i) {
+        Value *Op = In->getOperand(i);
+        if (Op != Ex && !verifyExpression(Op, Valid)) {
+          BadUser = true;
+          break;
+        }
+      }
+    } else {
+      BadUser = true;
+    }
+    if (BadUser)
+      continue;
+
+    New.insert(U);
+  }
+}
+
+unsigned PrecomputeLoopExpressions::computeExpressionCost(Value *V,
+                                                          ValueSet &Vs) {
+  if (Vs.count(V))
+    return 0;
+  Vs.insert(V);
+
+  unsigned C = 0;
+  if (Instruction *In = dyn_cast<Instruction>(V)) {
+    switch (In->getOpcode()) {
+    case Instruction::Add:
+    case Instruction::Sub:
+    case Instruction::Shl:
+    case Instruction::AShr:
+    case Instruction::And:
+    case Instruction::Or:
+    case Instruction::Xor:
+      C = 1;
+      break;
+    case Instruction::Mul:
+    case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::URem:
+    case Instruction::SRem:
+      C = 3;
+      break;
+    case Instruction::PHI:
+      return 0;
+    }
+
+    for (unsigned i = 0, n = In->getNumOperands(); i < n; ++i) {
+      C += computeExpressionCost(In->getOperand(i), Vs);
+    }
+  }
+
+  return C;
+}
+
+unsigned PrecomputeLoopExpressions::computeInitializerSize(Value *V) {
+  ValueVect IVs;
+
+  extractInductionVariables(V, IVs);
+
+  Type *T = V->getType();
+  assert(T->isIntegerTy());
+  unsigned Total = (cast<IntegerType>(T)->getBitWidth()) / 8;
+
+  for (unsigned i = 0, Dims = IVs.size(); i < Dims; ++i) {
+    IVInfo &IVI = IVInfos[IVs[i]];
+    unsigned D = std::abs(IVI.End - IVI.Start);
+    if (Log2p(D) + Log2p(Total) > 8 * sizeof(Integer))
+      return UINT_MAX;
+    Total *= D;
+  }
+
+  return Total;
+}
+
+void PrecomputeLoopExpressions::collectCandidateExpressions() {
+  ValueQueue Work;
+
+  IVEs.clear();
+
+  for (auto &KV : IVInfos) {
+    IVE...
[truncated]

@huihzhang huihzhang changed the title Introduce PrecomputLoopExpressionsPass. Introduce PrecomputeLoopExpressionsPass. Apr 26, 2024
Copy link

github-actions bot commented Apr 26, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 300340f656d762afa8bde5fc398757d2951560bf 3ef11e605edb2bfd11f3a384579fa1f7f6c6bc9a -- llvm/include/llvm/Transforms/Scalar/PrecomputeLoopExpressions.h llvm/lib/Transforms/Scalar/PrecomputeLoop.cpp llvm/lib/Passes/PassBuilderPipelines.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Transforms/Scalar/PrecomputeLoop.cpp b/llvm/lib/Transforms/Scalar/PrecomputeLoop.cpp
index b6d61232d0..9a526a9dfd 100644
--- a/llvm/lib/Transforms/Scalar/PrecomputeLoop.cpp
+++ b/llvm/lib/Transforms/Scalar/PrecomputeLoop.cpp
@@ -219,7 +219,7 @@ public:
   PrecomputeLoopExpressions(DominatorTree *DT, LoopInfo *LI,
                             ScalarEvolution *SE, TargetLibraryInfo *TLI,
                             unsigned TotalInitSize)
-      : DT(DT), LI(LI), SE(SE), TLI(TLI), TotalInitSize(TotalInitSize) {};
+      : DT(DT), LI(LI), SE(SE), TLI(TLI), TotalInitSize(TotalInitSize){};
 
   bool run(Function &Fn);
 

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants