diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index e5a934191e042..b0d5f0ccfbbfb 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -5967,14 +5967,18 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond( if (isa(C) && cast(C)->getAPInt().ule(1)) return getAddExpr(getUMaxExpr(X, C), Y); } - // x == 0 ? 0 : umin(..., x, ...) -> umin_seq(x, umin(...)) + // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...)) + // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...)) if (getTypeSizeInBits(LHS->getType()) == getTypeSizeInBits(I->getType()) && isa(RHS) && cast(RHS)->isZero() && isa(TrueVal) && cast(TrueVal)->isZero()) { const SCEV *X = getSCEV(LHS); - auto *UMin = dyn_cast(getSCEV(FalseVal)); - if (UMin && is_contained(UMin->operands(), X)) - return getUMinExpr(X, UMin, /*Sequential=*/true); + auto *FalseValExpr = dyn_cast(getSCEV(FalseVal)); + if (FalseValExpr && + (FalseValExpr->getSCEVType() == scUMinExpr || + FalseValExpr->getSCEVType() == scSequentialUMinExpr) && + is_contained(FalseValExpr->operands(), X)) + return getUMinExpr(X, FalseValExpr, /*Sequential=*/true); } break; default: diff --git a/llvm/test/Analysis/ScalarEvolution/logical-operations.ll b/llvm/test/Analysis/ScalarEvolution/logical-operations.ll index 1dde49a1a6c41..4456d2eb6a4b7 100644 --- a/llvm/test/Analysis/ScalarEvolution/logical-operations.ll +++ b/llvm/test/Analysis/ScalarEvolution/logical-operations.ll @@ -587,7 +587,7 @@ define i32 @umin_seq_x_x_y_z(i32 %x, i32 %y, i32 %z) { ; CHECK-NEXT: %r0 = select i1 %x.is.zero, i32 0, i32 %umin ; CHECK-NEXT: --> (%x umin_seq (%y umin %z)) U: full-set S: full-set ; CHECK-NEXT: %r = select i1 %x.is.zero, i32 0, i32 %r0 -; CHECK-NEXT: --> %r U: full-set S: full-set +; CHECK-NEXT: --> (%x umin_seq (%y umin %z)) U: full-set S: full-set ; CHECK-NEXT: Determining loop execution counts for: @umin_seq_x_x_y_z ; %umin0 = call i32 @llvm.umin(i32 %z, i32 %x)