From 7de88ddf518d05010cc24e0c691c80a9b7c36b5d Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Mon, 13 Oct 2025 10:49:27 +0100 Subject: [PATCH 1/2] [SCEV] Move URem matching to ScalarEvolutionPatternMatch.h --- llvm/include/llvm/Analysis/ScalarEvolution.h | 4 - .../Analysis/ScalarEvolutionPatternMatch.h | 74 +++++++++++++ llvm/lib/Analysis/ScalarEvolution.cpp | 103 ++++-------------- .../Utils/ScalarEvolutionExpander.cpp | 2 +- .../Analysis/ScalarEvolutionTest.cpp | 12 +- 5 files changed, 99 insertions(+), 96 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 8876e4ed6ae4f..e5a6c8cc0a6aa 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -2316,10 +2316,6 @@ class ScalarEvolution { /// an add rec on said loop. void getUsedLoops(const SCEV *S, SmallPtrSetImpl &LoopsUsed); - /// Try to match the pattern generated by getURemExpr(A, B). If successful, - /// Assign A and B to LHS and RHS, respectively. - LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS); - /// Look for a SCEV expression with type `SCEVType` and operands `Ops` in /// `UniqueSCEVs`. Return if found, else nullptr. SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef Ops); diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h index 07a482d4f166a..871028de3163c 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -252,6 +252,80 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) { return m_scev_Binary(Op0, Op1); } +/// Match unsigned remainder pattern. +/// Matches patterns generated by getURemExpr. 
+template struct SCEVURem_match { + Op0_t Op0; + Op1_t Op1; + ScalarEvolution &SE; + + SCEVURem_match(Op0_t Op0, Op1_t Op1, ScalarEvolution &SE) + : Op0(Op0), Op1(Op1), SE(SE) {} + + bool match(const SCEV *Expr) const { + if (Expr->getType()->isPointerTy()) + return false; + + // Try to match 'zext (trunc A to iB) to iY', which is used + // for URem with constant power-of-2 second operands. Make sure the size of + // the operand A matches the size of the whole expression. + const SCEV *LHS; + if (SCEVPatternMatch::match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) { + Type *TruncTy = cast(Expr)->getOperand()->getType(); + // Bail out if the type of the LHS is larger than the type of the + // expression for now. + if (SE.getTypeSizeInBits(LHS->getType()) > + SE.getTypeSizeInBits(Expr->getType())) + return false; + if (LHS->getType() != Expr->getType()) + LHS = SE.getZeroExtendExpr(LHS, Expr->getType()); + const SCEV *RHS = + SE.getConstant(APInt(SE.getTypeSizeInBits(Expr->getType()), 1) + << SE.getTypeSizeInBits(TruncTy)); + return Op0.match(LHS) && Op1.match(RHS); + } + const auto *Add = dyn_cast(Expr); + if (Add == nullptr || Add->getNumOperands() != 2) + return false; + + const SCEV *A = Add->getOperand(1); + const auto *Mul = dyn_cast(Add->getOperand(0)); + + if (Mul == nullptr) + return false; + + const auto MatchURemWithDivisor = [&](const SCEV *B) { + // (SomeExpr + (-(SomeExpr / B) * B)). + if (Expr == SE.getURemExpr(A, B)) + return Op0.match(A) && Op1.match(B); + return false; + }; + + // (SomeExpr + (-1 * (SomeExpr / B) * B)). + if (Mul->getNumOperands() == 3 && isa(Mul->getOperand(0))) + return MatchURemWithDivisor(Mul->getOperand(1)) || + MatchURemWithDivisor(Mul->getOperand(2)); + + // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
+ if (Mul->getNumOperands() == 2) + return MatchURemWithDivisor(Mul->getOperand(1)) || + MatchURemWithDivisor(Mul->getOperand(0)) || + MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(1))) || + MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(0))); + return false; + } +}; + +/// Match the mathematical pattern A - (A / B) * B, where A and B can be +/// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used +/// for URem with constant power-of-2 second operands. It's not always easy, as +/// A and B can be folded (imagine A is X / 2, and B is 4, A / B becomes X / 8). +template +inline SCEVURem_match m_scev_URem(Op0_t LHS, Op1_t RHS, + ScalarEvolution &SE) { + return SCEVURem_match(LHS, RHS, SE); +} + inline class_match m_Loop() { return class_match(); } /// Match an affine SCEVAddRecExpr. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 00c3dbbf3e800..332c29ff368f7 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -1774,7 +1774,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, { const SCEV *LHS; const SCEV *RHS; - if (matchURem(Op, LHS, RHS)) + if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this))) return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1), getZeroExtendExpr(RHS, Ty, Depth + 1)); } @@ -2699,17 +2699,13 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, } // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y) - if (Ops.size() == 2) { - const SCEVMulExpr *Mul = dyn_cast(Ops[0]); - if (Mul && Mul->getNumOperands() == 2 && - Mul->getOperand(0)->isAllOnesValue()) { - const SCEV *X; - const SCEV *Y; - if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) { - return getMulExpr(Y, getUDivExpr(X, Y)); - } - } - } + const SCEV *X; + const SCEV *Y; + if (Ops.size() == 2 && + match(Ops[0], m_scev_Mul(m_scev_AllOnes(), + m_scev_URem(m_SCEV(X), m_SCEV(Y), *this))) && + X == Ops[1]) + 
return getMulExpr(Y, getUDivExpr(X, Y)); // Skip past any other cast SCEVs. while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr) @@ -15410,65 +15406,6 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const { } } -// Match the mathematical pattern A - (A / B) * B, where A and B can be -// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used -// for URem with constant power-of-2 second operands. -// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is -// 4, A / B becomes X / 8). -bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS, - const SCEV *&RHS) { - if (Expr->getType()->isPointerTy()) - return false; - - // Try to match 'zext (trunc A to iB) to iY', which is used - // for URem with constant power-of-2 second operands. Make sure the size of - // the operand A matches the size of the whole expressions. - if (match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) { - Type *TruncTy = cast(Expr)->getOperand()->getType(); - // Bail out if the type of the LHS is larger than the type of the - // expression for now. - if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(Expr->getType())) - return false; - if (LHS->getType() != Expr->getType()) - LHS = getZeroExtendExpr(LHS, Expr->getType()); - RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1) - << getTypeSizeInBits(TruncTy)); - return true; - } - const auto *Add = dyn_cast(Expr); - if (Add == nullptr || Add->getNumOperands() != 2) - return false; - - const SCEV *A = Add->getOperand(1); - const auto *Mul = dyn_cast(Add->getOperand(0)); - - if (Mul == nullptr) - return false; - - const auto MatchURemWithDivisor = [&](const SCEV *B) { - // (SomeExpr + (-(SomeExpr / B) * B)). - if (Expr == getURemExpr(A, B)) { - LHS = A; - RHS = B; - return true; - } - return false; - }; - - // (SomeExpr + (-1 * (SomeExpr / B) * B)). 
- if (Mul->getNumOperands() == 3 && isa(Mul->getOperand(0))) - return MatchURemWithDivisor(Mul->getOperand(1)) || - MatchURemWithDivisor(Mul->getOperand(2)); - - // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)). - if (Mul->getNumOperands() == 2) - return MatchURemWithDivisor(Mul->getOperand(1)) || - MatchURemWithDivisor(Mul->getOperand(0)) || - MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) || - MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0))); - return false; -} - ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { BasicBlock *Header = L->getHeader(); @@ -15689,20 +15626,18 @@ void ScalarEvolution::LoopGuards::collectFromBlock( if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) { // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to // explicitly express that. - const SCEV *URemLHS = nullptr; + const SCEVUnknown *URemLHS = nullptr; const SCEV *URemRHS = nullptr; - if (SE.matchURem(LHS, URemLHS, URemRHS)) { - if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { - auto I = RewriteMap.find(LHSUnknown); - const SCEV *RewrittenLHS = - I != RewriteMap.end() ? I->second : LHSUnknown; - RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS); - const auto *Multiple = - SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS); - RewriteMap[LHSUnknown] = Multiple; - ExprsToRewrite.push_back(LHSUnknown); - return; - } + if (match(LHS, + m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) { + auto I = RewriteMap.find(URemLHS); + const SCEV *RewrittenLHS = I != RewriteMap.end() ? 
I->second : URemLHS; + RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS); + const auto *Multiple = + SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS); + RewriteMap[URemLHS] = Multiple; + ExprsToRewrite.push_back(URemLHS); + return; } } diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 45cee1e7da625..9035e58a707c4 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -526,7 +526,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { // Recognize the canonical representation of an unsimplifed urem. const SCEV *URemLHS = nullptr; const SCEV *URemRHS = nullptr; - if (SE.matchURem(S, URemLHS, URemRHS)) { + if (match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), SE))) { Value *LHS = expand(URemLHS); Value *RHS = expand(URemRHS); return InsertBinop(Instruction::URem, LHS, RHS, SCEV::FlagAnyWrap, diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index 1a68823b4f254..5d7eded06a760 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -11,6 +11,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ScalarEvolutionNormalization.h" +#include "llvm/Analysis/ScalarEvolutionPatternMatch.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/Constants.h" @@ -26,6 +27,8 @@ namespace llvm { +using namespace SCEVPatternMatch; + // We use this fixture to ensure that we clean up ScalarEvolution before // deleting the PassManager. 
class ScalarEvolutionsTest : public testing::Test { @@ -64,11 +67,6 @@ static std::optional computeConstantDifference(ScalarEvolution &SE, return SE.computeConstantDifference(LHS, RHS); } - static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS, - const SCEV *&RHS) { - return SE.matchURem(Expr, LHS, RHS); - } - static bool isImpliedCond( ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, @@ -1524,7 +1522,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) { auto *URemI = getInstructionByName(F, N); auto *S = SE.getSCEV(URemI); const SCEV *LHS, *RHS; - EXPECT_TRUE(matchURem(SE, S, LHS, RHS)); + EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE))); EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0))); EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1))); EXPECT_EQ(LHS->getType(), S->getType()); @@ -1537,7 +1535,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) { auto *URem1 = getInstructionByName(F, "rem4"); auto *S = SE.getSCEV(Ext); const SCEV *LHS, *RHS; - EXPECT_TRUE(matchURem(SE, S, LHS, RHS)); + EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE))); EXPECT_NE(LHS, SE.getSCEV(URem1->getOperand(0))); // RHS and URem1->getOperand(1) have different widths, so compare the // integer values. From 3da6689f9ccea1e7eb342b6f9de64d88b72cf0a2 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Mon, 13 Oct 2025 19:44:56 +0100 Subject: [PATCH 2/2] !fixup use m_scev_Specific. 
--- llvm/lib/Analysis/ScalarEvolution.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 332c29ff368f7..3fab6b0572cb7 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -2699,13 +2699,12 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, } // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y) - const SCEV *X; const SCEV *Y; if (Ops.size() == 2 && - match(Ops[0], m_scev_Mul(m_scev_AllOnes(), - m_scev_URem(m_SCEV(X), m_SCEV(Y), *this))) && - X == Ops[1]) - return getMulExpr(Y, getUDivExpr(X, Y)); + match(Ops[0], + m_scev_Mul(m_scev_AllOnes(), + m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this)))) + return getMulExpr(Y, getUDivExpr(Ops[1], Y)); // Skip past any other cast SCEVs. while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)