diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index deabacc592c7f..b409af483ad1e 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -810,11 +810,15 @@ Value *SimplifyCFGOpt::isValueEqualityComparison(Instruction *TI) { if (!SI->getParent()->hasNPredecessorsOrMore(128 / SI->getNumSuccessors())) CV = SI->getCondition(); } else if (BranchInst *BI = dyn_cast(TI)) - if (BI->isConditional() && BI->getCondition()->hasOneUse()) + if (BI->isConditional() && BI->getCondition()->hasOneUse()) { if (ICmpInst *ICI = dyn_cast(BI->getCondition())) { if (ICI->isEquality() && getConstantInt(ICI->getOperand(1), DL)) CV = ICI->getOperand(0); + } else if (auto *Trunc = dyn_cast(BI->getCondition())) { + if (Trunc->hasNoUnsignedWrap()) + CV = Trunc->getOperand(0); } + } // Unwrap any lossless ptrtoint cast. if (CV) { @@ -840,11 +844,20 @@ BasicBlock *SimplifyCFGOpt::getValueEqualityComparisonCases( } BranchInst *BI = cast(TI); - ICmpInst *ICI = cast(BI->getCondition()); - BasicBlock *Succ = BI->getSuccessor(ICI->getPredicate() == ICmpInst::ICMP_NE); - Cases.push_back(ValueEqualityComparisonCase( - getConstantInt(ICI->getOperand(1), DL), Succ)); - return BI->getSuccessor(ICI->getPredicate() == ICmpInst::ICMP_EQ); + Value *Cond = BI->getCondition(); + ICmpInst::Predicate Pred; + ConstantInt *C; + if (auto *ICI = dyn_cast(Cond)) { + Pred = ICI->getPredicate(); + C = getConstantInt(ICI->getOperand(1), DL); + } else { + Pred = ICmpInst::ICMP_NE; + auto *Trunc = cast(Cond); + C = ConstantInt::get(cast(Trunc->getOperand(0)->getType()), 0); + } + BasicBlock *Succ = BI->getSuccessor(Pred == ICmpInst::ICMP_NE); + Cases.push_back(ValueEqualityComparisonCase(C, Succ)); + return BI->getSuccessor(Pred == ICmpInst::ICMP_EQ); } /// Given a vector of bb/value pairs, remove any entries @@ -1106,7 +1119,10 @@ static void getBranchWeights(Instruction *TI, // default weight to be the first entry. if (BranchInst *BI = dyn_cast(TI)) { assert(Weights.size() == 2); - ICmpInst *ICI = cast(BI->getCondition()); + auto *ICI = dyn_cast(BI->getCondition()); + if (!ICI) + return; + if (ICI->getPredicate() == ICmpInst::ICMP_EQ) std::swap(Weights.front(), Weights.back()); } diff --git a/llvm/test/Transforms/SimplifyCFG/switch_create.ll b/llvm/test/Transforms/SimplifyCFG/switch_create.ll index f446d718f8206..a1533bdcffb43 100644 --- a/llvm/test/Transforms/SimplifyCFG/switch_create.ll +++ b/llvm/test/Transforms/SimplifyCFG/switch_create.ll @@ -1068,3 +1068,60 @@ if: else: ret void } + +define void @trunc_nuw_i1_condition(i32 %V) { +; CHECK-LABEL: @trunc_nuw_i1_condition( +; CHECK-NEXT: switch i32 [[V:%.*]], label [[F:%.*]] [ +; CHECK-NEXT: i32 2, label [[T:%.*]] +; CHECK-NEXT: i32 0, label [[T]] +; CHECK-NEXT: ] +; CHECK: common.ret: +; CHECK-NEXT: ret void +; CHECK: T: +; CHECK-NEXT: call void @foo1() +; CHECK-NEXT: br label [[COMMON_RET:%.*]] +; CHECK: F: +; CHECK-NEXT: call void @foo2() +; CHECK-NEXT: br label [[COMMON_RET]] +; + %C1 = icmp eq i32 %V, 2 + br i1 %C1, label %T, label %N +N: + %C2 = trunc nuw i32 %V to i1 + br i1 %C2, label %F, label %T +T: + call void @foo1( ) + ret void +F: + call void @foo2( ) + ret void +} + +define void @neg_trunc_i1_condition(i32 %V) { +; CHECK-LABEL: @neg_trunc_i1_condition( +; CHECK-NEXT: [[C1:%.*]] = icmp ne i32 [[V:%.*]], 2 +; CHECK-NEXT: [[C2:%.*]] = trunc i32 [[V]] to i1 +; CHECK-NEXT: [[OR_COND:%.*]] = and i1 [[C1]], [[C2]] +; CHECK-NEXT: br i1 [[OR_COND]], label [[F:%.*]], label [[T:%.*]] +; CHECK: common.ret: +; CHECK-NEXT: ret void +; CHECK: T: +; CHECK-NEXT: call void @foo1() +; CHECK-NEXT: br label [[COMMON_RET:%.*]] +; CHECK: F: +; CHECK-NEXT: call void @foo2() +; CHECK-NEXT: br label [[COMMON_RET]] +; + %C1 = icmp eq i32 %V, 2 + br i1 %C1, label %T, label %N +N: + %C2 = trunc i32 %V to i1 + br i1 %C2, label %F, label %T +T: + call void @foo1( ) + ret void +F: + call void @foo2( ) + ret void +} +