diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 2d84b4ae1ba5c..216bdf4eb9efb 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -84,7 +84,6 @@ #include #include #include -#include #include #include #include @@ -6356,25 +6355,25 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, if (DefaultResult) { Value *ValueCompare = Builder.CreateICmpEQ(Condition, SecondCase, "switch.selectcmp"); - SelectInst *SelectValueInst = cast(Builder.CreateSelect( - ValueCompare, ResultVector[1].first, DefaultResult, "switch.select")); - SelectValue = SelectValueInst; - if (HasBranchWeights) { + SelectValue = Builder.CreateSelect(ValueCompare, ResultVector[1].first, + DefaultResult, "switch.select"); + if (auto *SI = dyn_cast(SelectValue); + SI && HasBranchWeights) { // We start with 3 probabilities, where the numerator is the // corresponding BranchWeights[i], and the denominator is the sum over // BranchWeights. We want the probability and negative probability of // Condition == SecondCase. assert(BranchWeights.size() == 3); - setBranchWeights(SelectValueInst, BranchWeights[2], + setBranchWeights(SI, BranchWeights[2], BranchWeights[0] + BranchWeights[1], /*IsExpected=*/false); } } Value *ValueCompare = Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp"); - SelectInst *Ret = cast(Builder.CreateSelect( - ValueCompare, ResultVector[0].first, SelectValue, "switch.select")); - if (HasBranchWeights) { + Value *Ret = Builder.CreateSelect(ValueCompare, ResultVector[0].first, + SelectValue, "switch.select"); + if (auto *SI = dyn_cast(Ret); SI && HasBranchWeights) { // We may have had a DefaultResult. Base the position of the first and // second's branch weights accordingly. Also the proability that Condition // != FirstCase needs to take that into account. @@ -6382,7 +6381,7 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, size_t FirstCasePos = (Condition != nullptr); size_t SecondCasePos = FirstCasePos + 1; uint32_t DefaultCase = (Condition != nullptr) ? BranchWeights[0] : 0; - setBranchWeights(Ret, BranchWeights[FirstCasePos], + setBranchWeights(SI, BranchWeights[FirstCasePos], DefaultCase + BranchWeights[SecondCasePos], /*IsExpected=*/false); } @@ -6422,13 +6421,13 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, Value *And = Builder.CreateAnd(Condition, AndMask); Value *Cmp = Builder.CreateICmpEQ( And, Constant::getIntegerValue(And->getType(), AndMask)); - SelectInst *Ret = cast( - Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult)); - if (HasBranchWeights) { + Value *Ret = + Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult); + if (auto *SI = dyn_cast(Ret); SI && HasBranchWeights) { // We know there's a Default case. We base the resulting branch // weights off its probability. assert(BranchWeights.size() >= 2); - setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0), + setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0), BranchWeights[0], /*IsExpected=*/false); } return Ret; @@ -6448,11 +6447,11 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, Value *And = Builder.CreateAnd(Condition, ~BitMask, "switch.and"); Value *Cmp = Builder.CreateICmpEQ( And, Constant::getNullValue(And->getType()), "switch.selectcmp"); - SelectInst *Ret = cast( - Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult)); - if (HasBranchWeights) { + Value *Ret = + Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult); + if (auto *SI = dyn_cast(Ret); SI && HasBranchWeights) { assert(BranchWeights.size() >= 2); - setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0), + setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0), BranchWeights[0], /*IsExpected=*/false); } return Ret; @@ -6466,11 +6465,11 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, Value *Cmp2 = Builder.CreateICmpEQ(Condition, CaseValues[1], "switch.selectcmp.case2"); Value *Cmp = Builder.CreateOr(Cmp1, Cmp2, "switch.selectcmp"); - SelectInst *Ret = cast( - Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult)); - if (HasBranchWeights) { + Value *Ret = + Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult); + if (auto *SI = dyn_cast(Ret); SI && HasBranchWeights) { assert(BranchWeights.size() >= 2); - setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0), + setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0), BranchWeights[0], /*IsExpected=*/false); } return Ret; diff --git a/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll b/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll index 39703e9b53b6b..9d78b97c204a8 100644 --- a/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll +++ b/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll @@ -755,6 +755,25 @@ bb3: ret i1 %phi } +define i32 @negative_constfold_select() { +; CHECK-LABEL: @negative_constfold_select( +; CHECK-NEXT: entry: +; CHECK-NEXT: ret i32 poison +; +entry: + switch i32 poison, label %default [ + i32 0, label %bb + i32 2, label %bb + ] + +bb: + br label %default + +default: + %ret = phi i32 [ poison, %entry ], [ poison, %bb ] + ret i32 %ret +} + !0 = !{!"function_entry_count", i64 1000} !1 = !{!"branch_weights", i32 3, i32 5, i32 7} !2 = !{!"branch_weights", i32 3, i32 5, i32 7, i32 11, i32 13}