diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index 8b9076aff8fa9..375aa4e2cd440 100644 --- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -746,6 +746,12 @@ void State::addInfoFor(BasicBlock &BB) { WorkList.emplace_back(DT.getNode(Br->getSuccessor(1)), CmpI, true); } +static Constant *getScalarConstOrSplat(ConstantInt *C, Type *Ty) { + if (auto *VTy = dyn_cast(Ty)) + return ConstantVector::getSplat(VTy->getElementCount(), C); + return C; +} + static bool checkAndReplaceCondition(CmpInst *Cmp, ConstraintInfo &Info) { LLVM_DEBUG(dbgs() << "Checking " << *Cmp << "\n"); CmpInst::Predicate Pred = Cmp->getPredicate(); @@ -780,7 +786,9 @@ static bool checkAndReplaceCondition(CmpInst *Cmp, ConstraintInfo &Info) { dbgs() << "Condition " << *Cmp << " implied by dominating constraints\n"; dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned)); }); - Cmp->replaceUsesWithIf(ConstantInt::getTrue(Ctx), [](Use &U) { + Constant *TrueC = + getScalarConstOrSplat(ConstantInt::getTrue(Ctx), Cmp->getType()); + Cmp->replaceUsesWithIf(TrueC, [](Use &U) { // Conditions in an assume trivially simplify to true. Skip uses // in assume calls to not destroy the available information. auto *II = dyn_cast(U.getUser()); @@ -797,7 +805,9 @@ static bool checkAndReplaceCondition(CmpInst *Cmp, ConstraintInfo &Info) { dbgs() << "Condition !" << *Cmp << " implied by dominating constraints\n"; dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned)); }); - Cmp->replaceAllUsesWith(ConstantInt::getFalse(Ctx)); + Constant *FalseC = + getScalarConstOrSplat(ConstantInt::getFalse(Ctx), Cmp->getType()); + Cmp->replaceAllUsesWith(FalseC); NumCondsRemoved++; Changed = true; } diff --git a/llvm/test/Transforms/ConstraintElimination/geps-ptrvector.ll b/llvm/test/Transforms/ConstraintElimination/geps-ptrvector.ll index e30830fff7c76..df915653e08e1 100644 --- a/llvm/test/Transforms/ConstraintElimination/geps-ptrvector.ll +++ b/llvm/test/Transforms/ConstraintElimination/geps-ptrvector.ll @@ -12,3 +12,25 @@ define <2 x i1> @test.vectorgep(<2 x ptr> %vec) { %cond = icmp ule <2 x ptr> %gep, zeroinitializer ret <2 x i1> %cond } + +define <2 x i1> @test.vectorgep.ult.true(<2 x ptr> %vec) { +; CHECK-LABEL: @test.vectorgep.ult.true( +; CHECK-NEXT: [[GEP_1:%.*]] = getelementptr inbounds i32, <2 x ptr> [[VEC:%.*]], i64 1 +; CHECK-NEXT: [[T_1:%.*]] = icmp ult <2 x ptr> [[VEC]], [[GEP_1]] +; CHECK-NEXT: ret <2 x i1> +; + %gep.1 = getelementptr inbounds i32, <2 x ptr> %vec, i64 1 + %t.1 = icmp ult <2 x ptr> %vec, %gep.1 + ret <2 x i1> %t.1 +} + +define <2 x i1> @test.vectorgep.ult.false(<2 x ptr> %vec) { +; CHECK-LABEL: @test.vectorgep.ult.false( +; CHECK-NEXT: [[GEP_1:%.*]] = getelementptr inbounds i32, <2 x ptr> [[VEC:%.*]], i64 1 +; CHECK-NEXT: [[T_1:%.*]] = icmp ult <2 x ptr> [[GEP_1]], [[VEC]] +; CHECK-NEXT: ret <2 x i1> zeroinitializer +; + %gep.1 = getelementptr inbounds i32, <2 x ptr> %vec, i64 1 + %t.1 = icmp ult <2 x ptr> %gep.1, %vec + ret <2 x i1> %t.1 +}