diff --git a/llvm/include/llvm/IR/Metadata.h b/llvm/include/llvm/IR/Metadata.h index 262a148039a94..1a584f2ba599b 100644 --- a/llvm/include/llvm/IR/Metadata.h +++ b/llvm/include/llvm/IR/Metadata.h @@ -1274,6 +1274,11 @@ class MDNode : public Metadata { template static void dispatchResetHash(NodeTy *, std::false_type) {} + /// Merge branch weights from two direct callsites. + static MDNode *mergeDirectCallProfMetadata(MDNode *A, MDNode *B, + const Instruction *AInstr, + const Instruction *BInstr); + public: using op_iterator = const MDOperand *; using op_range = iterator_range; @@ -1319,6 +1324,11 @@ class MDNode : public Metadata { static MDNode *getMostGenericRange(MDNode *A, MDNode *B); static MDNode *getMostGenericAliasScope(MDNode *A, MDNode *B); static MDNode *getMostGenericAlignmentOrDereferenceable(MDNode *A, MDNode *B); + /// Merge !prof metadata from two instructions. + /// Currently only implemented with direct callsites with branch weights. + static MDNode *getMergedProfMetadata(MDNode *A, MDNode *B, + const Instruction *AInstr, + const Instruction *BInstr); }; /// Tuple of metadata. diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp index cfcfcd762fdc3..6ffeec1f21d33 100644 --- a/llvm/lib/IR/Metadata.cpp +++ b/llvm/lib/IR/Metadata.cpp @@ -1072,6 +1072,70 @@ MDNode *MDNode::getMostGenericFPMath(MDNode *A, MDNode *B) { return B; } +// Call instructions with branch weights are only used in SamplePGO as +// documented in +/// https://llvm.org/docs/BranchWeightMetadata.html#callinst). +MDNode *MDNode::mergeDirectCallProfMetadata(MDNode *A, MDNode *B, + const Instruction *AInstr, + const Instruction *BInstr) { + assert(A && B && AInstr && BInstr && "Caller should guarantee"); + auto &Ctx = AInstr->getContext(); + MDBuilder MDHelper(Ctx); + + // LLVM IR verifier verifies !prof metadata has at least 2 operands. + assert(A->getNumOperands() >= 2 && B->getNumOperands() >= 2 && + "!prof annotations should have no less than 2 operands"); + MDString *AMDS = dyn_cast(A->getOperand(0)); + MDString *BMDS = dyn_cast(B->getOperand(0)); + // LLVM IR verfier verifies first operand is MDString. + assert(AMDS != nullptr && BMDS != nullptr && + "first operand should be a non-null MDString"); + StringRef AProfName = AMDS->getString(); + StringRef BProfName = BMDS->getString(); + if (AProfName.equals("branch_weights") && + BProfName.equals("branch_weights")) { + ConstantInt *AInstrWeight = + mdconst::dyn_extract(A->getOperand(1)); + ConstantInt *BInstrWeight = + mdconst::dyn_extract(B->getOperand(1)); + assert(AInstrWeight && BInstrWeight && "verified by LLVM verifier"); + return MDNode::get(Ctx, + {MDHelper.createString("branch_weights"), + MDHelper.createConstant(ConstantInt::get( + Type::getInt64Ty(Ctx), + SaturatingAdd(AInstrWeight->getZExtValue(), + BInstrWeight->getZExtValue())))}); + } + return nullptr; +} + +// Pass in both instructions and nodes. Instruction information (e.g., +// instruction type) helps interpret profiles and make implementation clearer. +MDNode *MDNode::getMergedProfMetadata(MDNode *A, MDNode *B, + const Instruction *AInstr, + const Instruction *BInstr) { + if (!(A && B)) { + return A ? A : B; + } + + assert(AInstr->getMetadata(LLVMContext::MD_prof) == A && + "Caller should guarantee"); + assert(BInstr->getMetadata(LLVMContext::MD_prof) == B && + "Caller should guarantee"); + + const CallInst *ACall = dyn_cast(AInstr); + const CallInst *BCall = dyn_cast(BInstr); + + // Both ACall and BCall are direct callsites. + if (ACall && BCall && ACall->getCalledFunction() && + BCall->getCalledFunction()) + return mergeDirectCallProfMetadata(A, B, AInstr, BInstr); + + // The rest of the cases are not implemented but could be added + // when there are use cases. + return nullptr; +} + static bool isContiguous(const ConstantRange &A, const ConstantRange &B) { return A.getUpper() == B.getLower() || A.getLower() == B.getUpper(); } diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index 53d1f8b62d1b8..0f5d2ce841f19 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -2709,6 +2709,10 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, // Preserve !nontemporal if it is present on both instructions. K->setMetadata(Kind, JMD); break; + case LLVMContext::MD_prof: + if (DoesKMove) + K->setMetadata(Kind, MDNode::getMergedProfMetadata(KMD, JMD, K, J)); + break; } } // Set !invariant.group from J if J has it. If both instructions have it @@ -2737,6 +2741,7 @@ void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J, LLVMContext::MD_dereferenceable_or_null, LLVMContext::MD_access_group, LLVMContext::MD_preserve_access_index, + LLVMContext::MD_prof, LLVMContext::MD_nontemporal, LLVMContext::MD_noundef}; combineMetadata(K, J, KnownIDs, KDominatesJ); diff --git a/llvm/test/Transforms/GVN/calls-readonly.ll b/llvm/test/Transforms/GVN/calls-readonly.ll index 5c24740c881b4..b4855e41a64f5 100644 --- a/llvm/test/Transforms/GVN/calls-readonly.ll +++ b/llvm/test/Transforms/GVN/calls-readonly.ll @@ -6,7 +6,7 @@ target triple = "i386-apple-darwin7" define ptr @test(ptr %P, ptr %Q, i32 %x, i32 %y) nounwind readonly { entry: - %0 = tail call i32 @strlen(ptr %P) ; [#uses=2] + %0 = tail call i32 @strlen(ptr %P), !prof !0 ; [#uses=2] %1 = icmp eq i32 %0, 0 ; [#uses=1] br i1 %1, label %bb, label %bb1 @@ -17,7 +17,7 @@ bb: ; preds = %entry bb1: ; preds = %bb, %entry %x_addr.0 = phi i32 [ %2, %bb ], [ %x, %entry ] ; [#uses=1] %3 = tail call ptr @strchr(ptr %Q, i32 97) ; [#uses=1] - %4 = tail call i32 @strlen(ptr %P) ; [#uses=1] + %4 = tail call i32 @strlen(ptr %P) , !prof !1 ; [#uses=1] %5 = add i32 %x_addr.0, %0 ; [#uses=1] %.sum = sub i32 %5, %4 ; [#uses=1] %6 = getelementptr i8, ptr %3, i32 %.sum ; [#uses=1] @@ -26,7 +26,7 @@ bb1: ; preds = %bb, %entry ; CHECK: define ptr @test(ptr %P, ptr %Q, i32 %x, i32 %y) #0 { ; CHECK: entry: -; CHECK-NEXT: %0 = tail call i32 @strlen(ptr %P) +; CHECK-NEXT: %0 = tail call i32 @strlen(ptr %P), !prof !0 ; CHECK-NEXT: %1 = icmp eq i32 %0, 0 ; CHECK-NEXT: br i1 %1, label %bb, label %bb1 ; CHECK: bb: @@ -43,3 +43,6 @@ bb1: ; preds = %bb, %entry declare i32 @strlen(ptr) nounwind readonly declare ptr @strchr(ptr, i32) nounwind readonly + +!0 = !{!"branch_weights", i32 95} +!1 = !{!"branch_weights", i32 95} diff --git a/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll b/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll new file mode 100644 index 0000000000000..e57033b345384 --- /dev/null +++ b/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-hoist.ll @@ -0,0 +1,62 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals --version 2 +; RUN: opt < %s -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=HOIST + +; Test case based on C++ code with manualy annotated !prof metadata. +; This is to test that when calls to 'func1' from 'if.then' block +; and 'if.else' block are hoisted, the branch_weights are merged and +; attached to merged call rather than dropped. +; +; int func1(int a, int b) ; +; int func2(int a, int b) ; + +; int func(int a, int b, bool c) { +; int sum= 0; +; if(c) { +; sum += func1(a, b); +; } else { +; sum += func1(a, b); +; sum -= func2(a, b); +; } +; return sum; +; } +define i32 @_Z4funciib(i32 %a, i32 %b, i1 %c) { +; HOIST-LABEL: define i32 @_Z4funciib +; HOIST-SAME: (i32 [[A:%.*]], i32 [[B:%.*]], i1 [[C:%.*]]) { +; HOIST-NEXT: entry: +; HOIST-NEXT: [[CALL:%.*]] = tail call i32 @_Z5func1ii(i32 [[A]], i32 [[B]]), !prof [[PROF0:![0-9]+]] +; HOIST-NEXT: br i1 [[C]], label [[IF_END:%.*]], label [[IF_ELSE:%.*]] +; HOIST: if.else: +; HOIST-NEXT: [[CALL3:%.*]] = tail call i32 @_Z5func2ii(i32 [[A]], i32 [[B]]) +; HOIST-NEXT: [[SUB:%.*]] = sub i32 [[CALL]], [[CALL3]] +; HOIST-NEXT: br label [[IF_END]] +; HOIST: if.end: +; HOIST-NEXT: [[SUM_0:%.*]] = phi i32 [ [[SUB]], [[IF_ELSE]] ], [ [[CALL]], [[ENTRY:%.*]] ] +; HOIST-NEXT: ret i32 [[SUM_0]] +; +entry: + br i1 %c, label %if.then, label %if.else + +if.then: ; preds = %entry + %call = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !0 + br label %if.end + +if.else: ; preds = %entry + %call1 = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !1 + %call3 = tail call i32 @_Z5func2ii(i32 %a, i32 %b) + %sub = sub i32 %call1, %call3 + br label %if.end + +if.end: ; preds = %if.else, %if.then + %sum.0 = phi i32 [ %call, %if.then ], [ %sub, %if.else ] + ret i32 %sum.0 +} + +declare i32 @_Z5func1ii(i32, i32) + +declare i32 @_Z5func2ii(i32, i32) + +!0 = !{!"branch_weights", i32 10} +!1 = !{!"branch_weights", i32 90} +;. +; HOIST: [[PROF0]] = !{!"branch_weights", i64 100} +;. diff --git a/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll b/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll new file mode 100644 index 0000000000000..3206746b13a33 --- /dev/null +++ b/llvm/test/Transforms/SimplifyCFG/merge-direct-call-branch-weights-in-sink.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals --version 2 +; RUN: opt < %s -passes='simplifycfg' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=SINK + + +; Test case based on the following C++ code with manualy annotated !prof metadata. +; This is to test that when calls to 'func1' from 'if.then' and 'if.else' are +; sinked, the branch weights are merged and attached to sinked call. +; +; int func1(int a, int b) ; +; int func2(int a, int b) ; + +; int func(int a, int b, bool c) { +; int sum = 0; +; if (c) { +; sum += func1(a,b); +; } else { +; b -= func2(a,b); +; sum += func1(a,b); +; } +; return sum; +; } + +define i32 @_Z4funciib(i32 %a, i32 %b, i1 %c) { +; SINK-LABEL: define i32 @_Z4funciib +; SINK-SAME: (i32 [[A:%.*]], i32 [[B:%.*]], i1 [[C:%.*]]) { +; SINK-NEXT: entry: +; SINK-NEXT: br i1 [[C]], label [[IF_END:%.*]], label [[IF_ELSE:%.*]] +; SINK: if.else: +; SINK-NEXT: [[CALL1:%.*]] = tail call i32 @_Z5func2ii(i32 [[A]], i32 [[B]]) +; SINK-NEXT: [[SUB:%.*]] = sub i32 [[B]], [[CALL1]] +; SINK-NEXT: br label [[IF_END]] +; SINK: if.end: +; SINK-NEXT: [[SUB_SINK:%.*]] = phi i32 [ [[SUB]], [[IF_ELSE]] ], [ [[B]], [[ENTRY:%.*]] ] +; SINK-NEXT: [[CALL2:%.*]] = tail call i32 @_Z5func1ii(i32 [[A]], i32 [[SUB_SINK]]), !prof [[PROF0:![0-9]+]] +; SINK-NEXT: ret i32 [[CALL2]] +; +entry: + br i1 %c, label %if.then, label %if.else + +if.then: ; preds = %entry + %call = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !0 + br label %if.end + +if.else: ; preds = %entry + %call1 = tail call i32 @_Z5func2ii(i32 %a, i32 %b) + %sub = sub i32 %b, %call1 + %call2 = tail call i32 @_Z5func1ii(i32 %a, i32 %sub), !prof !1 + br label %if.end + +if.end: ; preds = %if.else, %if.then + %sum.0 = phi i32 [ %call, %if.then ], [ %call2, %if.else ] + ret i32 %sum.0 +} + +declare i32 @_Z5func1ii(i32, i32) + +declare i32 @_Z5func2ii(i32, i32) + +!0 = !{!"branch_weights", i32 10} +!1 = !{!"branch_weights", i32 90} +;. +; SINK: [[PROF0]] = !{!"branch_weights", i64 100} +;.