diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h index e1e7da14376e6..75ce4e38df166 100644 --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -299,7 +299,7 @@ Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q, /// return null. /// AllowRefinement specifies whether the simplification can be a refinement /// (e.g. 0 instead of poison), or whether it needs to be strictly identical. -Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, +Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const SimplifyQuery &Q, bool AllowRefinement); /// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively. diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index c72670b901fe5..e6baed1779cd1 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -3852,10 +3852,12 @@ Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); } -static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, +static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const SimplifyQuery &Q, bool AllowRefinement, unsigned MaxRecurse) { + assert(!Op->getType()->isVectorTy() && "This is not safe for vectors"); + // Trivial replacement. if (V == Op) return RepOp; @@ -3965,10 +3967,10 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); } -Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, +Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const SimplifyQuery &Q, bool AllowRefinement) { - return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement, + return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement, RecursionLimit); } @@ -4090,17 +4092,17 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, // Note that the equivalence/replacement opportunity does not hold for vectors // because each element of a vector select is chosen independently. if (Pred == ICmpInst::ICMP_EQ && !CondVal->getType()->isVectorTy()) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, + if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, /* AllowRefinement */ false, MaxRecurse) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, + simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, /* AllowRefinement */ false, MaxRecurse) == TrueVal) return FalseVal; - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, + if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, /* AllowRefinement */ true, MaxRecurse) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, + simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, /* AllowRefinement */ true, MaxRecurse) == FalseVal) return FalseVal; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 68124a188f716..f48cc901b7428 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1142,7 +1142,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); if (TrueVal != CmpLHS && isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) { - if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ, + if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ, /* AllowRefinement */ true)) return replaceOperand(Sel, Swapped ? 2 : 1, V); @@ -1164,7 +1164,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, } if (TrueVal != CmpRHS && isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT)) - if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ, + if (Value *V = simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ, /* AllowRefinement */ true)) return replaceOperand(Sel, Swapped ? 2 : 1, V); @@ -1195,9 +1195,9 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // We have an 'EQ' comparison, so the select's false value will propagate. // Example: // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, + if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, /* AllowRefinement */ false) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, + simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, /* AllowRefinement */ false) == TrueVal) { return replaceInstUsesWith(Sel, FalseVal); } @@ -2714,12 +2714,14 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) return replaceOperand(SI, 0, A); - if (Value *S = SimplifyWithOpReplaced(TrueVal, CondVal, One, SQ, - /* AllowRefinement */ true)) - return replaceOperand(SI, 1, S); - if (Value *S = SimplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ, - /* AllowRefinement */ true)) - return replaceOperand(SI, 2, S); + if (!SelType->isVectorTy()) { + if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ, + /* AllowRefinement */ true)) + return replaceOperand(SI, 1, S); + if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ, + /* AllowRefinement */ true)) + return replaceOperand(SI, 2, S); + } if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) { Use *Y = nullptr; diff --git a/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll b/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll index c15a64ee73152..fef4081c0bb6e 100644 --- a/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll +++ b/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll @@ -468,3 +468,28 @@ define i1 @bor_lor_right2(i1 %A, i1 %B) { ret i1 %res } +; Value equivalence substitution does not account for vector +; transforms, so it needs a scalar condition operand. +; For example, this would miscompile if %a = {1, 0}. + +define <2 x i1> @PR50500_trueval(<2 x i1> %a, <2 x i1> %b) { +; CHECK-LABEL: @PR50500_trueval( +; CHECK-NEXT: [[S:%.*]] = shufflevector <2 x i1> [[A:%.*]], <2 x i1> poison, <2 x i32> +; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[A]], <2 x i1> [[S]], <2 x i1> [[B:%.*]] +; CHECK-NEXT: ret <2 x i1> [[R]] +; + %s = shufflevector <2 x i1> %a, <2 x i1> poison, <2 x i32> + %r = select <2 x i1> %a, <2 x i1> %s, <2 x i1> %b + ret <2 x i1> %r +} + +define <2 x i1> @PR50500_falseval(<2 x i1> %a, <2 x i1> %b) { +; CHECK-LABEL: @PR50500_falseval( +; CHECK-NEXT: [[S:%.*]] = shufflevector <2 x i1> [[A:%.*]], <2 x i1> poison, <2 x i32> +; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[A]], <2 x i1> [[B:%.*]], <2 x i1> [[S]] +; CHECK-NEXT: ret <2 x i1> [[R]] +; + %s = shufflevector <2 x i1> %a, <2 x i1> poison, <2 x i32> + %r = select <2 x i1> %a, <2 x i1> %b, <2 x i1> %s + ret <2 x i1> %r +}