diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 87000a1c36eef..640a01fb7275e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/OverflowInstAnalysis.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" @@ -42,6 +43,7 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include +#include #include #define DEBUG_TYPE "instcombine" @@ -1448,10 +1450,16 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, return nullptr; }; - if (Instruction *R = ReplaceOldOpWithNewOp(CmpLHS, CmpRHS)) - return R; - if (Instruction *R = ReplaceOldOpWithNewOp(CmpRHS, CmpLHS)) - return R; + bool CanReplaceCmpLHSWithRHS = canReplacePointersIfEqual(CmpLHS, CmpRHS, DL); + if (CanReplaceCmpLHSWithRHS) { + if (Instruction *R = ReplaceOldOpWithNewOp(CmpLHS, CmpRHS)) + return R; + } + bool CanReplaceCmpRHSWithLHS = canReplacePointersIfEqual(CmpRHS, CmpLHS, DL); + if (CanReplaceCmpRHSWithLHS) { + if (Instruction *R = ReplaceOldOpWithNewOp(CmpRHS, CmpLHS)) + return R; + } auto *FalseInst = dyn_cast(FalseVal); if (!FalseInst) @@ -1466,12 +1474,14 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // Example: // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 SmallVector DropFlags; - if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, - /* AllowRefinement */ false, - &DropFlags) == TrueVal || - simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, - /* AllowRefinement */ false, - &DropFlags) == TrueVal) { + if ((CanReplaceCmpLHSWithRHS && + simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, + /* AllowRefinement */ false, + &DropFlags) == TrueVal) || + (CanReplaceCmpRHSWithLHS && + simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, + /* AllowRefinement */ false, + &DropFlags) == TrueVal)) { for (Instruction *I : DropFlags) { I->dropPoisonGeneratingAnnotations(); Worklist.add(I); diff --git a/llvm/test/Transforms/InstCombine/select-gep.ll b/llvm/test/Transforms/InstCombine/select-gep.ll index dd8dffba11b05..718133699a8a7 100644 --- a/llvm/test/Transforms/InstCombine/select-gep.ll +++ b/llvm/test/Transforms/InstCombine/select-gep.ll @@ -286,3 +286,35 @@ define <2 x ptr> @test7(<2 x ptr> %p1, i64 %idx, <2 x i1> %cc) { %select = select <2 x i1> %cc, <2 x ptr> %p1, <2 x ptr> %gep ret <2 x ptr> %select } + +define ptr @ptr_eq_replace_freeze1(ptr %p, ptr %q) { +; CHECK-LABEL: @ptr_eq_replace_freeze1( +; CHECK-NEXT: [[Q_FR:%.*]] = freeze ptr [[Q:%.*]] +; CHECK-NEXT: [[Q_FR1:%.*]] = freeze ptr [[Q1:%.*]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq ptr [[Q_FR]], [[Q_FR1]] +; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[CMP]], ptr [[Q_FR]], ptr [[Q_FR1]] +; CHECK-NEXT: ret ptr [[SELECT]] +; + %p.fr = freeze ptr %p + %q.fr = freeze ptr %q + %cmp = icmp eq ptr %p.fr, %q.fr + %select = select i1 %cmp, ptr %p.fr, ptr %q.fr + ret ptr %select +} + +define ptr @ptr_eq_replace_freeze2(ptr %p, ptr %q) { +; CHECK-LABEL: @ptr_eq_replace_freeze2( +; CHECK-NEXT: [[P_FR:%.*]] = freeze ptr [[P:%.*]] +; CHECK-NEXT: [[P_FR1:%.*]] = freeze ptr [[P1:%.*]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq ptr [[P_FR1]], [[P_FR]] +; CHECK-NEXT: [[SELECT_V:%.*]] = select i1 [[CMP]], ptr [[P_FR1]], ptr [[P_FR]] +; CHECK-NEXT: [[SELECT:%.*]] = getelementptr i8, ptr [[SELECT_V]], i64 16 +; CHECK-NEXT: ret ptr [[SELECT]] +; + %gep1 = getelementptr i32, ptr %p, i64 4 + %gep2 = getelementptr i32, ptr %q, i64 4 + %cmp = icmp eq ptr %p, %q + %cmp.fr = freeze i1 %cmp + %select = select i1 %cmp.fr, ptr %gep1, ptr %gep2 + ret ptr %select +}