diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index b0d5f0ccfbbfb2..b50e6f5697aa73 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -5887,6 +5887,44 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { return getUnknown(PN); } +bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, + SCEVTypes RootKind) { + struct FindClosure { + const SCEV *OperandToFind; + const SCEVTypes RootKind; // Must be a sequential min/max expression. + const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind. + + bool Found = false; + + bool canRecurseInto(SCEVTypes Kind) const { + // We can only recurse into the SCEV expression of the same effective type + // as the type of our root SCEV expression. + return RootKind == Kind || NonSequentialRootKind == Kind; + }; + + FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind) + : OperandToFind(OperandToFind), RootKind(RootKind), + NonSequentialRootKind( + SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType( + RootKind)) {} + + bool follow(const SCEV *S) { + if (isDone()) + return false; + + Found = S == OperandToFind; + + return !isDone() && canRecurseInto(S->getSCEVType()); + } + + bool isDone() const { return Found; } + }; + + FindClosure FC(OperandToFind, RootKind); + visitAll(Root, FC); + return FC.Found; +} + const SCEV *ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond( Instruction *I, ICmpInst *Cond, Value *TrueVal, Value *FalseVal) { // Try to match some simple smax or umax patterns. @@ -5969,15 +6007,14 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond( } // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...)) // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...)) + // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...) + // -> umin_seq(x, umin (..., umin_seq(...), ...)) if (getTypeSizeInBits(LHS->getType()) == getTypeSizeInBits(I->getType()) && isa(RHS) && cast(RHS)->isZero() && isa(TrueVal) && cast(TrueVal)->isZero()) { const SCEV *X = getSCEV(LHS); - auto *FalseValExpr = dyn_cast(getSCEV(FalseVal)); - if (FalseValExpr && - (FalseValExpr->getSCEVType() == scUMinExpr || - FalseValExpr->getSCEVType() == scSequentialUMinExpr) && - is_contained(FalseValExpr->operands(), X)) + const SCEV *FalseValExpr = getSCEV(FalseVal); + if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr)) return getUMinExpr(X, FalseValExpr, /*Sequential=*/true); } break; diff --git a/llvm/test/Analysis/ScalarEvolution/logical-operations.ll b/llvm/test/Analysis/ScalarEvolution/logical-operations.ll index 4456d2eb6a4b7f..682d99c2e349de 100644 --- a/llvm/test/Analysis/ScalarEvolution/logical-operations.ll +++ b/llvm/test/Analysis/ScalarEvolution/logical-operations.ll @@ -608,7 +608,7 @@ define i32 @umin_seq_x_y_z(i32 %x, i32 %y, i32 %z) { ; CHECK-NEXT: %r0 = select i1 %y.is.zero, i32 0, i32 %umin ; CHECK-NEXT: --> (%y umin_seq (%x umin %z)) U: full-set S: full-set ; CHECK-NEXT: %r = select i1 %x.is.zero, i32 0, i32 %r0 -; CHECK-NEXT: --> %r U: full-set S: full-set +; CHECK-NEXT: --> (%x umin_seq %y umin_seq %z) U: full-set S: full-set ; CHECK-NEXT: Determining loop execution counts for: @umin_seq_x_y_z ; %umin0 = call i32 @llvm.umin(i32 %z, i32 %x) @@ -632,7 +632,7 @@ define i32 @umin_seq_a_b_c_d(i32 %a, i32 %b, i32 %c, i32 %d) { ; CHECK-NEXT: %umin = call i32 @llvm.umin.i32(i32 %umin0, i32 %r1) ; CHECK-NEXT: --> ((%c umin_seq %d) umin %a umin %b) U: full-set S: full-set ; CHECK-NEXT: %r = select i1 %d.is.zero, i32 0, i32 %umin -; CHECK-NEXT: --> %r U: full-set S: full-set +; CHECK-NEXT: --> (%d umin_seq (%a umin %b umin %c)) U: full-set S: full-set ; CHECK-NEXT: Determining loop execution counts for: @umin_seq_a_b_c_d ; %umin1 = call i32 @llvm.umin(i32 %c, i32 %d)