Skip to content

Conversation

fhahn
Copy link
Contributor

@fhahn fhahn commented Oct 13, 2025

Move URem matching to ScalarEvolutionPatternMatch.h so it can
be re-used together with other matchers.

Depends on #163169 (included in PR)

@fhahn fhahn requested review from artagnon and nikic October 13, 2025 10:09
@llvmbot llvmbot added llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Oct 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 13, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

Changes

Move URem matching to ScalarEvolutionPatternMatch.h so it can
be re-used together with other matchers.

Depends on #163169 (included in PR)


Full diff: https://github.com/llvm/llvm-project/pull/163170.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/ScalarEvolution.h (-4)
  • (modified) llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h (+80)
  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+28-100)
  • (modified) llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp (+1-1)
  • (modified) llvm/unittests/Analysis/ScalarEvolutionTest.cpp (+5-7)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 8876e4ed6ae4f..e5a6c8cc0a6aa 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -2316,10 +2316,6 @@ class ScalarEvolution {
   /// an add rec on said loop.
   void getUsedLoops(const SCEV *S, SmallPtrSetImpl<const Loop *> &LoopsUsed);
 
-  /// Try to match the pattern generated by getURemExpr(A, B). If successful,
-  /// Assign A and B to LHS and RHS, respectively.
-  LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);
-
   /// Look for a SCEV expression with type `SCEVType` and operands `Ops` in
   /// `UniqueSCEVs`.  Return if found, else nullptr.
   SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 164b46b54890b..871028de3163c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -182,6 +182,12 @@ m_scev_PtrToInt(const Op0_t &Op0) {
   return SCEVUnaryExpr_match<SCEVPtrToIntExpr, Op0_t>(Op0);
 }
 
+template <typename Op0_t>
+inline SCEVUnaryExpr_match<SCEVTruncateExpr, Op0_t>
+m_scev_Trunc(const Op0_t &Op0) {
+  return m_scev_Unary<SCEVTruncateExpr>(Op0);
+}
+
 /// Match a binary SCEV.
 template <typename SCEVTy, typename Op0_t, typename Op1_t,
           SCEV::NoWrapFlags WrapFlags = SCEV::FlagAnyWrap,
@@ -246,6 +252,80 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
   return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
 }
 
+/// Match unsigned remainder pattern.
+/// Matches patterns generated by getURemExpr.
+template <typename Op0_t, typename Op1_t> struct SCEVURem_match {
+  Op0_t Op0;
+  Op1_t Op1;
+  ScalarEvolution &SE;
+
+  SCEVURem_match(Op0_t Op0, Op1_t Op1, ScalarEvolution &SE)
+      : Op0(Op0), Op1(Op1), SE(SE) {}
+
+  bool match(const SCEV *Expr) const {
+    if (Expr->getType()->isPointerTy())
+      return false;
+
+    // Try to match 'zext (trunc A to iB) to iY', which is used
+    // for URem with constant power-of-2 second operands. Make sure the size of
+    // the operand A matches the size of the whole expressions.
+    const SCEV *LHS;
+    if (SCEVPatternMatch::match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
+      Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
+      // Bail out if the type of the LHS is larger than the type of the
+      // expression for now.
+      if (SE.getTypeSizeInBits(LHS->getType()) >
+          SE.getTypeSizeInBits(Expr->getType()))
+        return false;
+      if (LHS->getType() != Expr->getType())
+        LHS = SE.getZeroExtendExpr(LHS, Expr->getType());
+      const SCEV *RHS =
+          SE.getConstant(APInt(SE.getTypeSizeInBits(Expr->getType()), 1)
+                         << SE.getTypeSizeInBits(TruncTy));
+      return Op0.match(LHS) && Op1.match(RHS);
+    }
+    const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
+    if (Add == nullptr || Add->getNumOperands() != 2)
+      return false;
+
+    const SCEV *A = Add->getOperand(1);
+    const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
+
+    if (Mul == nullptr)
+      return false;
+
+    const auto MatchURemWithDivisor = [&](const SCEV *B) {
+      // (SomeExpr + (-(SomeExpr / B) * B)).
+      if (Expr == SE.getURemExpr(A, B))
+        return Op0.match(A) && Op1.match(B);
+      return false;
+    };
+
+    // (SomeExpr + (-1 * (SomeExpr / B) * B)).
+    if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
+      return MatchURemWithDivisor(Mul->getOperand(1)) ||
+             MatchURemWithDivisor(Mul->getOperand(2));
+
+    // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
+    if (Mul->getNumOperands() == 2)
+      return MatchURemWithDivisor(Mul->getOperand(1)) ||
+             MatchURemWithDivisor(Mul->getOperand(0)) ||
+             MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(1))) ||
+             MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(0)));
+    return false;
+  }
+};
+
+/// Match the mathematical pattern A - (A / B) * B, where A and B can be
+/// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
+/// for URem with constant power-of-2 second operands. It's not always easy, as
+/// A and B can be folded (imagine A is X / 2, and B is 4, A / B becomes X / 8).
+template <typename Op0_t, typename Op1_t>
+inline SCEVURem_match<Op0_t, Op1_t> m_scev_URem(Op0_t LHS, Op1_t RHS,
+                                                ScalarEvolution &SE) {
+  return SCEVURem_match<Op0_t, Op1_t>(LHS, RHS, SE);
+}
+
 inline class_match<const Loop> m_Loop() { return class_match<const Loop>(); }
 
 /// Match an affine SCEVAddRecExpr.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index b5b4cd978d7a8..332c29ff368f7 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -1774,7 +1774,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
   {
     const SCEV *LHS;
     const SCEV *RHS;
-    if (matchURem(Op, LHS, RHS))
+    if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
       return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
                          getZeroExtendExpr(RHS, Ty, Depth + 1));
   }
@@ -2699,17 +2699,13 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
   }
 
   // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
-  if (Ops.size() == 2) {
-    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
-    if (Mul && Mul->getNumOperands() == 2 &&
-        Mul->getOperand(0)->isAllOnesValue()) {
-      const SCEV *X;
-      const SCEV *Y;
-      if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
-        return getMulExpr(Y, getUDivExpr(X, Y));
-      }
-    }
-  }
+  const SCEV *X;
+  const SCEV *Y;
+  if (Ops.size() == 2 &&
+      match(Ops[0], m_scev_Mul(m_scev_AllOnes(),
+                               m_scev_URem(m_SCEV(X), m_SCEV(Y), *this))) &&
+      X == Ops[1])
+    return getMulExpr(Y, getUDivExpr(X, Y));
 
   // Skip past any other cast SCEVs.
   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
@@ -5419,20 +5415,15 @@ static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
   if (SourceBits != NewBits)
     return nullptr;
 
-  const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
-  const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
-  if (!SExt && !ZExt)
-    return nullptr;
-  const SCEVTruncateExpr *Trunc =
-      SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
-           : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
-  if (!Trunc)
-    return nullptr;
-  const SCEV *X = Trunc->getOperand();
-  if (X != SymbolicPHI)
-    return nullptr;
-  Signed = SExt != nullptr;
-  return Trunc->getType();
+  if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
+    Signed = true;
+    return cast<SCEVCastExpr>(Op)->getOperand()->getType();
+  }
+  if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
+    Signed = false;
+    return cast<SCEVCastExpr>(Op)->getOperand()->getType();
+  }
+  return nullptr;
 }
 
 static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
@@ -15415,67 +15406,6 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
     }
 }
 
-// Match the mathematical pattern A - (A / B) * B, where A and B can be
-// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
-// for URem with constant power-of-2 second operands.
-// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
-// 4, A / B becomes X / 8).
-bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
-                                const SCEV *&RHS) {
-  if (Expr->getType()->isPointerTy())
-    return false;
-
-  // Try to match 'zext (trunc A to iB) to iY', which is used
-  // for URem with constant power-of-2 second operands. Make sure the size of
-  // the operand A matches the size of the whole expressions.
-  if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
-    if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
-      LHS = Trunc->getOperand();
-      // Bail out if the type of the LHS is larger than the type of the
-      // expression for now.
-      if (getTypeSizeInBits(LHS->getType()) >
-          getTypeSizeInBits(Expr->getType()))
-        return false;
-      if (LHS->getType() != Expr->getType())
-        LHS = getZeroExtendExpr(LHS, Expr->getType());
-      RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
-                        << getTypeSizeInBits(Trunc->getType()));
-      return true;
-    }
-  const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
-  if (Add == nullptr || Add->getNumOperands() != 2)
-    return false;
-
-  const SCEV *A = Add->getOperand(1);
-  const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
-
-  if (Mul == nullptr)
-    return false;
-
-  const auto MatchURemWithDivisor = [&](const SCEV *B) {
-    // (SomeExpr + (-(SomeExpr / B) * B)).
-    if (Expr == getURemExpr(A, B)) {
-      LHS = A;
-      RHS = B;
-      return true;
-    }
-    return false;
-  };
-
-  // (SomeExpr + (-1 * (SomeExpr / B) * B)).
-  if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
-    return MatchURemWithDivisor(Mul->getOperand(1)) ||
-           MatchURemWithDivisor(Mul->getOperand(2));
-
-  // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
-  if (Mul->getNumOperands() == 2)
-    return MatchURemWithDivisor(Mul->getOperand(1)) ||
-           MatchURemWithDivisor(Mul->getOperand(0)) ||
-           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
-           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
-  return false;
-}
-
 ScalarEvolution::LoopGuards
 ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
   BasicBlock *Header = L->getHeader();
@@ -15696,20 +15626,18 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
       // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
       // explicitly express that.
-      const SCEV *URemLHS = nullptr;
+      const SCEVUnknown *URemLHS = nullptr;
       const SCEV *URemRHS = nullptr;
-      if (SE.matchURem(LHS, URemLHS, URemRHS)) {
-        if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
-          auto I = RewriteMap.find(LHSUnknown);
-          const SCEV *RewrittenLHS =
-              I != RewriteMap.end() ? I->second : LHSUnknown;
-          RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
-          const auto *Multiple =
-              SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
-          RewriteMap[LHSUnknown] = Multiple;
-          ExprsToRewrite.push_back(LHSUnknown);
-          return;
-        }
+      if (match(LHS,
+                m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
+        auto I = RewriteMap.find(URemLHS);
+        const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
+        RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
+        const auto *Multiple =
+            SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
+        RewriteMap[URemLHS] = Multiple;
+        ExprsToRewrite.push_back(URemLHS);
+        return;
       }
     }
 
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 45cee1e7da625..9035e58a707c4 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -526,7 +526,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
   // Recognize the canonical representation of an unsimplifed urem.
   const SCEV *URemLHS = nullptr;
   const SCEV *URemRHS = nullptr;
-  if (SE.matchURem(S, URemLHS, URemRHS)) {
+  if (match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), SE))) {
     Value *LHS = expand(URemLHS);
     Value *RHS = expand(URemRHS);
     return InsertBinop(Instruction::URem, LHS, RHS, SCEV::FlagAnyWrap,
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 1a68823b4f254..5d7eded06a760 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -11,6 +11,7 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ScalarEvolutionNormalization.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Constants.h"
@@ -26,6 +27,8 @@
 
 namespace llvm {
 
+using namespace SCEVPatternMatch;
+
 // We use this fixture to ensure that we clean up ScalarEvolution before
 // deleting the PassManager.
 class ScalarEvolutionsTest : public testing::Test {
@@ -64,11 +67,6 @@ static std::optional<APInt> computeConstantDifference(ScalarEvolution &SE,
   return SE.computeConstantDifference(LHS, RHS);
 }
 
-  static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS,
-                        const SCEV *&RHS) {
-    return SE.matchURem(Expr, LHS, RHS);
-  }
-
   static bool isImpliedCond(
       ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS,
       const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS,
@@ -1524,7 +1522,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
       auto *URemI = getInstructionByName(F, N);
       auto *S = SE.getSCEV(URemI);
       const SCEV *LHS, *RHS;
-      EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
+      EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
       EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0)));
       EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1)));
       EXPECT_EQ(LHS->getType(), S->getType());
@@ -1537,7 +1535,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
     auto *URem1 = getInstructionByName(F, "rem4");
     auto *S = SE.getSCEV(Ext);
     const SCEV *LHS, *RHS;
-    EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
+    EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
     EXPECT_NE(LHS, SE.getSCEV(URem1->getOperand(0)));
     // RHS and URem1->getOperand(1) have different widths, so compare the
     // integer values.

@llvmbot
Copy link
Member

llvmbot commented Oct 13, 2025

@llvm/pr-subscribers-llvm-analysis

Author: Florian Hahn (fhahn)

Changes

Move URem matching to ScalarEvolutionPatternMatch.h so it can
be re-used together with other matchers.

Depends on #163169 (included in PR)


Full diff: https://github.com/llvm/llvm-project/pull/163170.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/ScalarEvolution.h (-4)
  • (modified) llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h (+80)
  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+28-100)
  • (modified) llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp (+1-1)
  • (modified) llvm/unittests/Analysis/ScalarEvolutionTest.cpp (+5-7)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 8876e4ed6ae4f..e5a6c8cc0a6aa 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -2316,10 +2316,6 @@ class ScalarEvolution {
   /// an add rec on said loop.
   void getUsedLoops(const SCEV *S, SmallPtrSetImpl<const Loop *> &LoopsUsed);
 
-  /// Try to match the pattern generated by getURemExpr(A, B). If successful,
-  /// Assign A and B to LHS and RHS, respectively.
-  LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);
-
   /// Look for a SCEV expression with type `SCEVType` and operands `Ops` in
   /// `UniqueSCEVs`.  Return if found, else nullptr.
   SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 164b46b54890b..871028de3163c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -182,6 +182,12 @@ m_scev_PtrToInt(const Op0_t &Op0) {
   return SCEVUnaryExpr_match<SCEVPtrToIntExpr, Op0_t>(Op0);
 }
 
+template <typename Op0_t>
+inline SCEVUnaryExpr_match<SCEVTruncateExpr, Op0_t>
+m_scev_Trunc(const Op0_t &Op0) {
+  return m_scev_Unary<SCEVTruncateExpr>(Op0);
+}
+
 /// Match a binary SCEV.
 template <typename SCEVTy, typename Op0_t, typename Op1_t,
           SCEV::NoWrapFlags WrapFlags = SCEV::FlagAnyWrap,
@@ -246,6 +252,80 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
   return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
 }
 
+/// Match unsigned remainder pattern.
+/// Matches patterns generated by getURemExpr.
+template <typename Op0_t, typename Op1_t> struct SCEVURem_match {
+  Op0_t Op0;
+  Op1_t Op1;
+  ScalarEvolution &SE;
+
+  SCEVURem_match(Op0_t Op0, Op1_t Op1, ScalarEvolution &SE)
+      : Op0(Op0), Op1(Op1), SE(SE) {}
+
+  bool match(const SCEV *Expr) const {
+    if (Expr->getType()->isPointerTy())
+      return false;
+
+    // Try to match 'zext (trunc A to iB) to iY', which is used
+    // for URem with constant power-of-2 second operands. Make sure the size of
+    // the operand A matches the size of the whole expressions.
+    const SCEV *LHS;
+    if (SCEVPatternMatch::match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
+      Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
+      // Bail out if the type of the LHS is larger than the type of the
+      // expression for now.
+      if (SE.getTypeSizeInBits(LHS->getType()) >
+          SE.getTypeSizeInBits(Expr->getType()))
+        return false;
+      if (LHS->getType() != Expr->getType())
+        LHS = SE.getZeroExtendExpr(LHS, Expr->getType());
+      const SCEV *RHS =
+          SE.getConstant(APInt(SE.getTypeSizeInBits(Expr->getType()), 1)
+                         << SE.getTypeSizeInBits(TruncTy));
+      return Op0.match(LHS) && Op1.match(RHS);
+    }
+    const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
+    if (Add == nullptr || Add->getNumOperands() != 2)
+      return false;
+
+    const SCEV *A = Add->getOperand(1);
+    const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
+
+    if (Mul == nullptr)
+      return false;
+
+    const auto MatchURemWithDivisor = [&](const SCEV *B) {
+      // (SomeExpr + (-(SomeExpr / B) * B)).
+      if (Expr == SE.getURemExpr(A, B))
+        return Op0.match(A) && Op1.match(B);
+      return false;
+    };
+
+    // (SomeExpr + (-1 * (SomeExpr / B) * B)).
+    if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
+      return MatchURemWithDivisor(Mul->getOperand(1)) ||
+             MatchURemWithDivisor(Mul->getOperand(2));
+
+    // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
+    if (Mul->getNumOperands() == 2)
+      return MatchURemWithDivisor(Mul->getOperand(1)) ||
+             MatchURemWithDivisor(Mul->getOperand(0)) ||
+             MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(1))) ||
+             MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(0)));
+    return false;
+  }
+};
+
+/// Match the mathematical pattern A - (A / B) * B, where A and B can be
+/// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
+/// for URem with constant power-of-2 second operands. It's not always easy, as
+/// A and B can be folded (imagine A is X / 2, and B is 4, A / B becomes X / 8).
+template <typename Op0_t, typename Op1_t>
+inline SCEVURem_match<Op0_t, Op1_t> m_scev_URem(Op0_t LHS, Op1_t RHS,
+                                                ScalarEvolution &SE) {
+  return SCEVURem_match<Op0_t, Op1_t>(LHS, RHS, SE);
+}
+
 inline class_match<const Loop> m_Loop() { return class_match<const Loop>(); }
 
 /// Match an affine SCEVAddRecExpr.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index b5b4cd978d7a8..332c29ff368f7 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -1774,7 +1774,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
   {
     const SCEV *LHS;
     const SCEV *RHS;
-    if (matchURem(Op, LHS, RHS))
+    if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
       return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
                          getZeroExtendExpr(RHS, Ty, Depth + 1));
   }
@@ -2699,17 +2699,13 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
   }
 
   // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
-  if (Ops.size() == 2) {
-    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
-    if (Mul && Mul->getNumOperands() == 2 &&
-        Mul->getOperand(0)->isAllOnesValue()) {
-      const SCEV *X;
-      const SCEV *Y;
-      if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
-        return getMulExpr(Y, getUDivExpr(X, Y));
-      }
-    }
-  }
+  const SCEV *X;
+  const SCEV *Y;
+  if (Ops.size() == 2 &&
+      match(Ops[0], m_scev_Mul(m_scev_AllOnes(),
+                               m_scev_URem(m_SCEV(X), m_SCEV(Y), *this))) &&
+      X == Ops[1])
+    return getMulExpr(Y, getUDivExpr(X, Y));
 
   // Skip past any other cast SCEVs.
   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
@@ -5419,20 +5415,15 @@ static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
   if (SourceBits != NewBits)
     return nullptr;
 
-  const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
-  const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
-  if (!SExt && !ZExt)
-    return nullptr;
-  const SCEVTruncateExpr *Trunc =
-      SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
-           : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
-  if (!Trunc)
-    return nullptr;
-  const SCEV *X = Trunc->getOperand();
-  if (X != SymbolicPHI)
-    return nullptr;
-  Signed = SExt != nullptr;
-  return Trunc->getType();
+  if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
+    Signed = true;
+    return cast<SCEVCastExpr>(Op)->getOperand()->getType();
+  }
+  if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
+    Signed = false;
+    return cast<SCEVCastExpr>(Op)->getOperand()->getType();
+  }
+  return nullptr;
 }
 
 static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
@@ -15415,67 +15406,6 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
     }
 }
 
-// Match the mathematical pattern A - (A / B) * B, where A and B can be
-// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
-// for URem with constant power-of-2 second operands.
-// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
-// 4, A / B becomes X / 8).
-bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
-                                const SCEV *&RHS) {
-  if (Expr->getType()->isPointerTy())
-    return false;
-
-  // Try to match 'zext (trunc A to iB) to iY', which is used
-  // for URem with constant power-of-2 second operands. Make sure the size of
-  // the operand A matches the size of the whole expressions.
-  if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
-    if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
-      LHS = Trunc->getOperand();
-      // Bail out if the type of the LHS is larger than the type of the
-      // expression for now.
-      if (getTypeSizeInBits(LHS->getType()) >
-          getTypeSizeInBits(Expr->getType()))
-        return false;
-      if (LHS->getType() != Expr->getType())
-        LHS = getZeroExtendExpr(LHS, Expr->getType());
-      RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
-                        << getTypeSizeInBits(Trunc->getType()));
-      return true;
-    }
-  const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
-  if (Add == nullptr || Add->getNumOperands() != 2)
-    return false;
-
-  const SCEV *A = Add->getOperand(1);
-  const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
-
-  if (Mul == nullptr)
-    return false;
-
-  const auto MatchURemWithDivisor = [&](const SCEV *B) {
-    // (SomeExpr + (-(SomeExpr / B) * B)).
-    if (Expr == getURemExpr(A, B)) {
-      LHS = A;
-      RHS = B;
-      return true;
-    }
-    return false;
-  };
-
-  // (SomeExpr + (-1 * (SomeExpr / B) * B)).
-  if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
-    return MatchURemWithDivisor(Mul->getOperand(1)) ||
-           MatchURemWithDivisor(Mul->getOperand(2));
-
-  // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
-  if (Mul->getNumOperands() == 2)
-    return MatchURemWithDivisor(Mul->getOperand(1)) ||
-           MatchURemWithDivisor(Mul->getOperand(0)) ||
-           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
-           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
-  return false;
-}
-
 ScalarEvolution::LoopGuards
 ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
   BasicBlock *Header = L->getHeader();
@@ -15696,20 +15626,18 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
       // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
       // explicitly express that.
-      const SCEV *URemLHS = nullptr;
+      const SCEVUnknown *URemLHS = nullptr;
       const SCEV *URemRHS = nullptr;
-      if (SE.matchURem(LHS, URemLHS, URemRHS)) {
-        if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
-          auto I = RewriteMap.find(LHSUnknown);
-          const SCEV *RewrittenLHS =
-              I != RewriteMap.end() ? I->second : LHSUnknown;
-          RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
-          const auto *Multiple =
-              SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
-          RewriteMap[LHSUnknown] = Multiple;
-          ExprsToRewrite.push_back(LHSUnknown);
-          return;
-        }
+      if (match(LHS,
+                m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
+        auto I = RewriteMap.find(URemLHS);
+        const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
+        RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
+        const auto *Multiple =
+            SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
+        RewriteMap[URemLHS] = Multiple;
+        ExprsToRewrite.push_back(URemLHS);
+        return;
       }
     }
 
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 45cee1e7da625..9035e58a707c4 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -526,7 +526,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
   // Recognize the canonical representation of an unsimplifed urem.
   const SCEV *URemLHS = nullptr;
   const SCEV *URemRHS = nullptr;
-  if (SE.matchURem(S, URemLHS, URemRHS)) {
+  if (match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), SE))) {
     Value *LHS = expand(URemLHS);
     Value *RHS = expand(URemRHS);
     return InsertBinop(Instruction::URem, LHS, RHS, SCEV::FlagAnyWrap,
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 1a68823b4f254..5d7eded06a760 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -11,6 +11,7 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ScalarEvolutionNormalization.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Constants.h"
@@ -26,6 +27,8 @@
 
 namespace llvm {
 
+using namespace SCEVPatternMatch;
+
 // We use this fixture to ensure that we clean up ScalarEvolution before
 // deleting the PassManager.
 class ScalarEvolutionsTest : public testing::Test {
@@ -64,11 +67,6 @@ static std::optional<APInt> computeConstantDifference(ScalarEvolution &SE,
   return SE.computeConstantDifference(LHS, RHS);
 }
 
-  static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS,
-                        const SCEV *&RHS) {
-    return SE.matchURem(Expr, LHS, RHS);
-  }
-
   static bool isImpliedCond(
       ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS,
       const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS,
@@ -1524,7 +1522,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
       auto *URemI = getInstructionByName(F, N);
       auto *S = SE.getSCEV(URemI);
       const SCEV *LHS, *RHS;
-      EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
+      EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
       EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0)));
       EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1)));
       EXPECT_EQ(LHS->getType(), S->getType());
@@ -1537,7 +1535,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
     auto *URem1 = getInstructionByName(F, "rem4");
     auto *S = SE.getSCEV(Ext);
     const SCEV *LHS, *RHS;
-    EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
+    EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
     EXPECT_NE(LHS, SE.getSCEV(URem1->getOperand(0)));
     // RHS and URem1->getOperand(1) have different widths, so compare the
     // integer values.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

return Op0.match(LHS) && Op1.match(RHS);
}
const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
if (Add == nullptr || Add->getNumOperands() != 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to use m_scev_Add here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this is something I have a follow up in progress to combine the Mul/Add checks using pattern matching

if (Ops.size() == 2 &&
match(Ops[0], m_scev_Mul(m_scev_AllOnes(),
m_scev_URem(m_SCEV(X), m_SCEV(Y), *this))) &&
X == Ops[1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can use m_scev_Specific here instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks. It means we need to use Ops[1] below instead of X.

@fhahn fhahn force-pushed the scev-urem-patternmatch branch from c86bb59 to 3da6689 Compare October 13, 2025 18:46
@fhahn fhahn enabled auto-merge (squash) October 13, 2025 18:48
@fhahn fhahn merged commit 7f04ee1 into llvm:main Oct 13, 2025
9 of 10 checks passed
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Oct 13, 2025
… (#163170)

Move URem matching to ScalarEvolutionPatternMatch.h so it can
be re-used together with other matchers.

Depends on llvm/llvm-project#163169

PR: llvm/llvm-project#163170
akadutta pushed a commit to akadutta/llvm-project that referenced this pull request Oct 14, 2025
Move URem matching to ScalarEvolutionPatternMatch.h so it can 
be re-used together with other matchers.

Depends on llvm#163169

PR: llvm#163170
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants