Skip to content

Commit

Permalink
[PGO]Implement metadata combine for 'branch_weights' of direct
Browse files Browse the repository at this point in the history
callsites when none of the instructions folds the rest away.

- Merge cases are added for simplify-cfg {sink,hoist}, based on https://gcc.godbolt.org/z/avGvc38W7 and https://gcc.godbolt.org/z/dbWbjGhaE
- When one instruction folds the others in, do not update branch_weights
  with sum (see test/Transforms/GVN/calls-readonly.ll)

Differential Revision: https://reviews.llvm.org/D148877
  • Loading branch information
minglotus-6 committed Apr 27, 2023
1 parent f478721 commit b3cb950
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 3 deletions.
10 changes: 10 additions & 0 deletions llvm/include/llvm/IR/Metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,11 @@ class MDNode : public Metadata {
template <class NodeTy>
static void dispatchResetHash(NodeTy *, std::false_type) {}

/// Merge branch weights from two direct callsites.
///
/// Returns a new "branch_weights" node whose single weight is the saturating
/// sum of the weights in \p A and \p B, or nullptr when either node is not a
/// "branch_weights" annotation. All four arguments must be non-null.
static MDNode *mergeDirectCallProfMetadata(MDNode *A, MDNode *B,
const Instruction *AInstr,
const Instruction *BInstr);

public:
using op_iterator = const MDOperand *;
using op_range = iterator_range<op_iterator>;
Expand Down Expand Up @@ -1319,6 +1324,11 @@ class MDNode : public Metadata {
static MDNode *getMostGenericRange(MDNode *A, MDNode *B);
static MDNode *getMostGenericAliasScope(MDNode *A, MDNode *B);
static MDNode *getMostGenericAlignmentOrDereferenceable(MDNode *A, MDNode *B);
/// Merge !prof metadata from two instructions.
/// Currently only implemented with direct callsites with branch weights.
///
/// If either \p A or \p B is null, the other one is returned unchanged
/// (possibly null). Otherwise \p A and \p B must be the !prof attachments of
/// \p AInstr and \p BInstr respectively (asserted). When both instructions
/// are CallInsts for which getCalledFunction() is non-null, the branch
/// weights are merged; every other combination yields nullptr.
static MDNode *getMergedProfMetadata(MDNode *A, MDNode *B,
const Instruction *AInstr,
const Instruction *BInstr);
};

/// Tuple of metadata.
Expand Down
64 changes: 64 additions & 0 deletions llvm/lib/IR/Metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,70 @@ MDNode *MDNode::getMostGenericFPMath(MDNode *A, MDNode *B) {
return B;
}

// Branch weights on call instructions are only used by SamplePGO, as
// documented in https://llvm.org/docs/BranchWeightMetadata.html#callinst.
//
// Combines the "branch_weights" of two direct callsites into a single node
// whose weight is the saturating sum of both inputs. Yields nullptr when
// either node is not a "branch_weights" annotation.
MDNode *MDNode::mergeDirectCallProfMetadata(MDNode *A, MDNode *B,
                                            const Instruction *AInstr,
                                            const Instruction *BInstr) {
  assert(A && B && AInstr && BInstr && "Caller should guarantee");

  // The LLVM IR verifier checks that !prof metadata has at least 2 operands.
  assert(A->getNumOperands() >= 2 && B->getNumOperands() >= 2 &&
         "!prof annotations should have no less than 2 operands");

  // The LLVM IR verifier checks that the first operand is an MDString.
  MDString *AMDS = dyn_cast<MDString>(A->getOperand(0));
  MDString *BMDS = dyn_cast<MDString>(B->getOperand(0));
  assert(AMDS != nullptr && BMDS != nullptr &&
         "first operand should be a non-null MDString");

  // Only the "branch_weights" flavour of !prof can be merged here.
  const bool BothAreBranchWeights = AMDS->getString() == "branch_weights" &&
                                    BMDS->getString() == "branch_weights";
  if (!BothAreBranchWeights)
    return nullptr;

  ConstantInt *AInstrWeight =
      mdconst::dyn_extract<ConstantInt>(A->getOperand(1));
  ConstantInt *BInstrWeight =
      mdconst::dyn_extract<ConstantInt>(B->getOperand(1));
  assert(AInstrWeight && BInstrWeight && "verified by LLVM verifier");

  // Saturate instead of wrapping so that two large counts cannot overflow.
  const uint64_t MergedWeight = SaturatingAdd(AInstrWeight->getZExtValue(),
                                              BInstrWeight->getZExtValue());

  // The merged weight is always emitted as an i64 constant.
  auto &Context = AInstr->getContext();
  MDBuilder Builder(Context);
  Metadata *MergedOps[] = {
      Builder.createString("branch_weights"),
      Builder.createConstant(
          ConstantInt::get(Type::getInt64Ty(Context), MergedWeight))};
  return MDNode::get(Context, MergedOps);
}

// Both the metadata nodes and their owning instructions are passed in: the
// instruction information (e.g., the instruction type) helps interpret the
// profile and keeps the implementation clear.
MDNode *MDNode::getMergedProfMetadata(MDNode *A, MDNode *B,
                                      const Instruction *AInstr,
                                      const Instruction *BInstr) {
  // With at most one node present there is nothing to merge; hand back
  // whichever one exists (or null when neither does).
  if (!A)
    return B;
  if (!B)
    return A;

  assert(AInstr->getMetadata(LLVMContext::MD_prof) == A &&
         "Caller should guarantee");
  assert(BInstr->getMetadata(LLVMContext::MD_prof) == B &&
         "Caller should guarantee");

  // A direct callsite is a CallInst whose callee is a known function.
  auto IsDirectCall = [](const CallInst *Call) {
    return Call && Call->getCalledFunction();
  };
  const CallInst *CallOfA = dyn_cast<CallInst>(AInstr);
  const CallInst *CallOfB = dyn_cast<CallInst>(BInstr);

  // The rest of the cases are not implemented but could be added
  // when there are use cases.
  if (!IsDirectCall(CallOfA) || !IsDirectCall(CallOfB))
    return nullptr;

  return mergeDirectCallProfMetadata(A, B, AInstr, BInstr);
}

// Returns true when the upper bound of one range equals the lower bound of
// the other, in either order.
static bool isContiguous(const ConstantRange &A, const ConstantRange &B) {
  if (A.getUpper() == B.getLower())
    return true;
  return A.getLower() == B.getUpper();
}
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Transforms/Utils/Local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2709,6 +2709,10 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J,
// Preserve !nontemporal if it is present on both instructions.
K->setMetadata(Kind, JMD);
break;
case LLVMContext::MD_prof:
if (DoesKMove)
K->setMetadata(Kind, MDNode::getMergedProfMetadata(KMD, JMD, K, J));
break;
}
}
// Set !invariant.group from J if J has it. If both instructions have it
Expand Down Expand Up @@ -2737,6 +2741,7 @@ void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J,
LLVMContext::MD_dereferenceable_or_null,
LLVMContext::MD_access_group,
LLVMContext::MD_preserve_access_index,
LLVMContext::MD_prof,
LLVMContext::MD_nontemporal,
LLVMContext::MD_noundef};
combineMetadata(K, J, KnownIDs, KDominatesJ);
Expand Down
9 changes: 6 additions & 3 deletions llvm/test/Transforms/GVN/calls-readonly.ll
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ target triple = "i386-apple-darwin7"

define ptr @test(ptr %P, ptr %Q, i32 %x, i32 %y) nounwind readonly {
entry:
%0 = tail call i32 @strlen(ptr %P) ; <i32> [#uses=2]
%0 = tail call i32 @strlen(ptr %P), !prof !0 ; <i32> [#uses=2]
%1 = icmp eq i32 %0, 0 ; <i1> [#uses=1]
br i1 %1, label %bb, label %bb1

Expand All @@ -17,7 +17,7 @@ bb: ; preds = %entry
bb1: ; preds = %bb, %entry
%x_addr.0 = phi i32 [ %2, %bb ], [ %x, %entry ] ; <i32> [#uses=1]
%3 = tail call ptr @strchr(ptr %Q, i32 97) ; <ptr> [#uses=1]
%4 = tail call i32 @strlen(ptr %P) ; <i32> [#uses=1]
%4 = tail call i32 @strlen(ptr %P) , !prof !1 ; <i32> [#uses=1]
%5 = add i32 %x_addr.0, %0 ; <i32> [#uses=1]
%.sum = sub i32 %5, %4 ; <i32> [#uses=1]
%6 = getelementptr i8, ptr %3, i32 %.sum ; <ptr> [#uses=1]
Expand All @@ -26,7 +26,7 @@ bb1: ; preds = %bb, %entry

; CHECK: define ptr @test(ptr %P, ptr %Q, i32 %x, i32 %y) #0 {
; CHECK: entry:
; CHECK-NEXT: %0 = tail call i32 @strlen(ptr %P)
; CHECK-NEXT: %0 = tail call i32 @strlen(ptr %P), !prof !0
; CHECK-NEXT: %1 = icmp eq i32 %0, 0
; CHECK-NEXT: br i1 %1, label %bb, label %bb1
; CHECK: bb:
Expand All @@ -43,3 +43,6 @@ bb1: ; preds = %bb, %entry
declare i32 @strlen(ptr) nounwind readonly

declare ptr @strchr(ptr, i32) nounwind readonly

!0 = !{!"branch_weights", i32 95}
!1 = !{!"branch_weights", i32 95}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals --version 2
; RUN: opt < %s -passes='simplifycfg<no-sink-common-insts;hoist-common-insts>' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=HOIST

; Test case based on C++ code with manually annotated !prof metadata.
; This is to test that when calls to 'func1' from 'if.then' block
; and 'if.else' block are hoisted, the branch_weights are merged and
; attached to merged call rather than dropped.
;
; int func1(int a, int b) ;
; int func2(int a, int b) ;

; int func(int a, int b, bool c) {
; int sum= 0;
; if(c) {
; sum += func1(a, b);
; } else {
; sum += func1(a, b);
; sum -= func2(a, b);
; }
; return sum;
; }
; Both 'if.then' and 'if.else' start with the same call to @_Z5func1ii, so the
; call can be hoisted into 'entry'. The HOIST lines expect a single call there
; that carries one merged !prof node.
define i32 @_Z4funciib(i32 %a, i32 %b, i1 %c) {
; HOIST-LABEL: define i32 @_Z4funciib
; HOIST-SAME: (i32 [[A:%.*]], i32 [[B:%.*]], i1 [[C:%.*]]) {
; HOIST-NEXT: entry:
; HOIST-NEXT: [[CALL:%.*]] = tail call i32 @_Z5func1ii(i32 [[A]], i32 [[B]]), !prof [[PROF0:![0-9]+]]
; HOIST-NEXT: br i1 [[C]], label [[IF_END:%.*]], label [[IF_ELSE:%.*]]
; HOIST: if.else:
; HOIST-NEXT: [[CALL3:%.*]] = tail call i32 @_Z5func2ii(i32 [[A]], i32 [[B]])
; HOIST-NEXT: [[SUB:%.*]] = sub i32 [[CALL]], [[CALL3]]
; HOIST-NEXT: br label [[IF_END]]
; HOIST: if.end:
; HOIST-NEXT: [[SUM_0:%.*]] = phi i32 [ [[SUB]], [[IF_ELSE]] ], [ [[CALL]], [[ENTRY:%.*]] ]
; HOIST-NEXT: ret i32 [[SUM_0]]
;
entry:
br i1 %c, label %if.then, label %if.else

if.then: ; preds = %entry
; First copy of the call; annotated with !0.
%call = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !0
br label %if.end

if.else: ; preds = %entry
; Second copy of the call with identical arguments; annotated with !1.
%call1 = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !1
; This call has no !prof and is expected to stay in 'if.else'.
%call3 = tail call i32 @_Z5func2ii(i32 %a, i32 %b)
%sub = sub i32 %call1, %call3
br label %if.end

if.end: ; preds = %if.else, %if.then
%sum.0 = phi i32 [ %call, %if.then ], [ %sub, %if.else ]
ret i32 %sum.0
}

declare i32 @_Z5func1ii(i32, i32)

declare i32 @_Z5func2ii(i32, i32)

!0 = !{!"branch_weights", i32 10}
!1 = !{!"branch_weights", i32 90}
;.
; HOIST: [[PROF0]] = !{!"branch_weights", i64 100}
;.
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals --version 2
; RUN: opt < %s -passes='simplifycfg<sink-common-insts;no-hoist-common-insts>' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefix=SINK


; Test case based on the following C++ code with manually annotated !prof metadata.
; This is to test that when calls to 'func1' from 'if.then' and 'if.else' are
; sunk, the branch weights are merged and attached to the sunk call.
;
; int func1(int a, int b) ;
; int func2(int a, int b) ;

; int func(int a, int b, bool c) {
; int sum = 0;
; if (c) {
; sum += func1(a,b);
; } else {
; b -= func2(a,b);
; sum += func1(a,b);
; }
; return sum;
; }

; Both 'if.then' and 'if.else' end with a call to @_Z5func1ii that differs only
; in its second argument, so the call can be sunk into 'if.end' behind a phi.
; The SINK lines expect a single call there that carries one merged !prof node.
define i32 @_Z4funciib(i32 %a, i32 %b, i1 %c) {
; SINK-LABEL: define i32 @_Z4funciib
; SINK-SAME: (i32 [[A:%.*]], i32 [[B:%.*]], i1 [[C:%.*]]) {
; SINK-NEXT: entry:
; SINK-NEXT: br i1 [[C]], label [[IF_END:%.*]], label [[IF_ELSE:%.*]]
; SINK: if.else:
; SINK-NEXT: [[CALL1:%.*]] = tail call i32 @_Z5func2ii(i32 [[A]], i32 [[B]])
; SINK-NEXT: [[SUB:%.*]] = sub i32 [[B]], [[CALL1]]
; SINK-NEXT: br label [[IF_END]]
; SINK: if.end:
; SINK-NEXT: [[SUB_SINK:%.*]] = phi i32 [ [[SUB]], [[IF_ELSE]] ], [ [[B]], [[ENTRY:%.*]] ]
; SINK-NEXT: [[CALL2:%.*]] = tail call i32 @_Z5func1ii(i32 [[A]], i32 [[SUB_SINK]]), !prof [[PROF0:![0-9]+]]
; SINK-NEXT: ret i32 [[CALL2]]
;
entry:
br i1 %c, label %if.then, label %if.else

if.then: ; preds = %entry
; First copy of the call; annotated with !0.
%call = tail call i32 @_Z5func1ii(i32 %a, i32 %b), !prof !0
br label %if.end

if.else: ; preds = %entry
; This call has no !prof and is expected to stay in 'if.else'.
%call1 = tail call i32 @_Z5func2ii(i32 %a, i32 %b)
%sub = sub i32 %b, %call1
; Second copy of the call, taking %sub instead of %b; annotated with !1.
%call2 = tail call i32 @_Z5func1ii(i32 %a, i32 %sub), !prof !1
br label %if.end

if.end: ; preds = %if.else, %if.then
%sum.0 = phi i32 [ %call, %if.then ], [ %call2, %if.else ]
ret i32 %sum.0
}

declare i32 @_Z5func1ii(i32, i32)

declare i32 @_Z5func2ii(i32, i32)

!0 = !{!"branch_weights", i32 10}
!1 = !{!"branch_weights", i32 90}
;.
; SINK: [[PROF0]] = !{!"branch_weights", i64 100}
;.

0 comments on commit b3cb950

Please sign in to comment.