diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index b03fb6213d61c..511c8d3c0b69e 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -80,6 +80,7 @@ #include #include #include +#include #include #include #include @@ -7632,7 +7633,33 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder, auto *DefaultCaseBB = SI->getDefaultDest(); BasicBlock *SplitBB = SplitBlock(OrigBB, SI, DTU); auto It = OrigBB->getTerminator()->getIterator(); + SmallVector Weights; + auto HasWeights = + !ProfcheckDisableMetadataFixes && extractBranchWeights(*SI, Weights); auto *BI = BranchInst::Create(SplitBB, DefaultCaseBB, IsPow2, It); + if (HasWeights && any_of(Weights, [](const auto &V) { return V != 0; })) { + // IsPow2 covers a subset of the cases in which we'd go to the default + // label. The other is those powers of 2 that don't appear in the case + // statement. We don't know the distribution of the values coming in, so + // the safest is to split 50-50 the original probability to `default`. + uint64_t OrigDenominator = sum_of(map_range( + Weights, [](const auto &V) { return static_cast(V); })); + SmallVector NewWeights(2); + NewWeights[1] = Weights[0] / 2; + NewWeights[0] = OrigDenominator - NewWeights[1]; + setFittedBranchWeights(*BI, NewWeights, /*IsExpected=*/false); + + // For the original switch, we reduce the weight of the default by the + // amount by which the previous branch contributes to getting to default, + // and then make sure the remaining weights have the same relative ratio + // wrt eachother. + uint64_t CasesDenominator = OrigDenominator - Weights[0]; + Weights[0] /= 2; + for (auto &W : drop_begin(Weights)) + W = NewWeights[0] * static_cast(W) / CasesDenominator; + + setBranchWeights(*SI, Weights, /*IsExpected=*/false); + } // BI is handling the default case for SI, and so should share its DebugLoc. BI->setDebugLoc(SI->getDebugLoc()); It->eraseFromParent(); diff --git a/llvm/test/Transforms/SimplifyCFG/X86/switch-of-powers-of-two.ll b/llvm/test/Transforms/SimplifyCFG/X86/switch-of-powers-of-two.ll index aa95b3fd235e5..d818335f075e5 100644 --- a/llvm/test/Transforms/SimplifyCFG/X86/switch-of-powers-of-two.ll +++ b/llvm/test/Transforms/SimplifyCFG/X86/switch-of-powers-of-two.ll @@ -1,8 +1,13 @@ -; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 5 ; RUN: opt -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -S < %s | FileCheck %s target triple = "x86_64-unknown-linux-gnu" +;. +; CHECK: @switch.table.switch_of_powers_two = private unnamed_addr constant [7 x i32] [i32 3, i32 poison, i32 poison, i32 2, i32 1, i32 0, i32 42], align 4 +; CHECK: @switch.table.switch_of_powers_two_default_reachable = private unnamed_addr constant [7 x i32] [i32 3, i32 5, i32 5, i32 2, i32 1, i32 0, i32 42], align 4 +; CHECK: @switch.table.switch_of_powers_two_default_reachable_multipreds = private unnamed_addr constant [7 x i32] [i32 3, i32 poison, i32 poison, i32 2, i32 1, i32 0, i32 42], align 4 +;. define i32 @switch_of_powers_two(i32 %arg) { ; CHECK-LABEL: define i32 @switch_of_powers_two( ; CHECK-SAME: i32 [[ARG:%.*]]) { @@ -35,17 +40,17 @@ return: ret i32 %phi } -define i32 @switch_of_powers_two_default_reachable(i32 %arg) { +define i32 @switch_of_powers_two_default_reachable(i32 %arg) !prof !0 { ; CHECK-LABEL: define i32 @switch_of_powers_two_default_reachable( -; CHECK-SAME: i32 [[ARG:%.*]]) { +; CHECK-SAME: i32 [[ARG:%.*]]) !prof [[PROF0:![0-9]+]] { ; CHECK-NEXT: [[ENTRY:.*]]: ; CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.ctpop.i32(i32 [[ARG]]) ; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[TMP0]], 1 -; CHECK-NEXT: br i1 [[TMP1]], label %[[ENTRY_SPLIT:.*]], label %[[RETURN:.*]] +; CHECK-NEXT: br i1 [[TMP1]], label %[[ENTRY_SPLIT:.*]], label %[[RETURN:.*]], !prof [[PROF1:![0-9]+]] ; CHECK: [[ENTRY_SPLIT]]: ; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.cttz.i32(i32 [[ARG]], i1 true) ; CHECK-NEXT: [[TMP3:%.*]] = icmp ult i32 [[TMP2]], 7 -; CHECK-NEXT: br i1 [[TMP3]], label %[[SWITCH_LOOKUP:.*]], label %[[RETURN]] +; CHECK-NEXT: br i1 [[TMP3]], label %[[SWITCH_LOOKUP:.*]], label %[[RETURN]], !prof [[PROF2:![0-9]+]] ; CHECK: [[SWITCH_LOOKUP]]: ; CHECK-NEXT: [[TMP4:%.*]] = zext nneg i32 [[TMP2]] to i64 ; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [7 x i32], ptr @switch.table.switch_of_powers_two_default_reachable, i64 0, i64 [[TMP4]] @@ -62,7 +67,7 @@ entry: i32 16, label %bb3 i32 32, label %bb4 i32 64, label %bb5 - ] + ], !prof !1 default_case: br label %return bb1: br label %return @@ -128,3 +133,13 @@ return: %phi = phi i32 [ 3, %bb1 ], [ 2, %bb2 ], [ 1, %bb3 ], [ 0, %bb4 ], [ 42, %bb5 ], [ %pn, %default_case ] ret i32 %phi } + +!0 = !{!"function_entry_count", i32 10} +!1 = !{!"branch_weights", i32 10, i32 5, i32 7, i32 11, i32 13, i32 17} +;. +; CHECK: attributes #[[ATTR0:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } +;. +; CHECK: [[PROF0]] = !{!"function_entry_count", i32 10} +; CHECK: [[PROF1]] = !{!"branch_weights", i32 58, i32 5} +; CHECK: [[PROF2]] = !{!"branch_weights", i32 56, i32 5} +;. diff --git a/llvm/utils/profcheck-xfail.txt b/llvm/utils/profcheck-xfail.txt index aef7c0987fda7..83bffc70574a8 100644 --- a/llvm/utils/profcheck-xfail.txt +++ b/llvm/utils/profcheck-xfail.txt @@ -1317,8 +1317,6 @@ Transforms/SimpleLoopUnswitch/pr60736.ll Transforms/SimpleLoopUnswitch/trivial-unswitch-freeze-individual-conditions.ll Transforms/SimpleLoopUnswitch/trivial-unswitch.ll Transforms/SimpleLoopUnswitch/trivial-unswitch-logical-and-or.ll -Transforms/SimplifyCFG/RISCV/switch-of-powers-of-two.ll -Transforms/SimplifyCFG/X86/switch-of-powers-of-two.ll Transforms/StackProtector/cross-dso-cfi-stack-chk-fail.ll Transforms/StructurizeCFG/AMDGPU/uniform-regions.ll Transforms/StructurizeCFG/hoist-zerocost.ll