Skip to content

Commit 9cb9b16

Browse files
authored
[mlir][llvm] Fix import of branch weights with "expected" field (#169776)
This commit fixes the import of `branch_weights` metadata from LLVM IR to the LLVM dialect. Previously, `branch_weights` metadata containing the `!"expected"` field were rejected because the importer expected integer weights at operand 1, but found a string.
1 parent 1c7ec06 commit 9cb9b16

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
113113
return failure();
114114

115115
// Handle function entry count metadata.
116-
if (name->getString() == "function_entry_count") {
116+
if (name->getString() == llvm::MDProfLabels::FunctionEntryCount) {
117117

118118
// TODO support function entry count metadata with GUID fields.
119119
if (node->getNumOperands() != 2)
@@ -131,15 +131,28 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
131131
<< "expected function_entry_count to be attached to a function";
132132
}
133133

134-
if (name->getString() != "branch_weights")
134+
if (name->getString() != llvm::MDProfLabels::BranchWeights)
135135
return failure();
136+
// The branch_weights metadata must have at least 2 operands.
137+
if (node->getNumOperands() < 2)
138+
return failure();
139+
140+
ArrayRef<llvm::MDOperand> branchWeightOperands =
141+
node->operands().drop_front();
142+
if (auto *mdString = dyn_cast<llvm::MDString>(node->getOperand(1))) {
143+
if (mdString->getString() != llvm::MDProfLabels::ExpectedBranchWeights)
144+
return failure();
145+
// The MLIR WeightedBranchOpInterface does not support the
146+
// ExpectedBranchWeights field, so it is dropped.
147+
branchWeightOperands = branchWeightOperands.drop_front();
148+
}
136149

137150
// Handle branch weights metadata.
138151
SmallVector<int32_t> branchWeights;
139-
branchWeights.reserve(node->getNumOperands() - 1);
140-
for (unsigned i = 1, e = node->getNumOperands(); i != e; ++i) {
152+
branchWeights.reserve(branchWeightOperands.size());
153+
for (const llvm::MDOperand &operand : branchWeightOperands) {
141154
llvm::ConstantInt *branchWeight =
142-
llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(i));
155+
llvm::mdconst::dyn_extract<llvm::ConstantInt>(operand);
143156
if (!branchWeight)
144157
return failure();
145158
branchWeights.push_back(branchWeight->getZExtValue());

mlir/test/Target/LLVMIR/Import/metadata-profiling.ll

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,22 @@ bb2:
1616

1717
; // -----
1818

19+
; CHECK-LABEL: @cond_br_expected
20+
define i64 @cond_br_expected(i1 %arg1, i64 %arg2) {
21+
entry:
22+
; CHECK: llvm.cond_br
23+
; CHECK-SAME: weights([1, 2000])
24+
br i1 %arg1, label %bb1, label %bb2, !prof !0
25+
bb1:
26+
ret i64 %arg2
27+
bb2:
28+
ret i64 %arg2
29+
}
30+
31+
!0 = !{!"branch_weights", !"expected", i32 1, i32 2000}
32+
33+
; // -----
34+
1935
; CHECK-LABEL: @simple_switch(
2036
define i32 @simple_switch(i32 %arg1) {
2137
; CHECK: llvm.switch
@@ -36,6 +52,26 @@ bbd:
3652

3753
; // -----
3854

55+
; CHECK-LABEL: @simple_switch_expected(
56+
define i32 @simple_switch_expected(i32 %arg1) {
57+
; CHECK: llvm.switch
58+
; CHECK: {branch_weights = array<i32: 1, 1, 2000>}
59+
switch i32 %arg1, label %bbd [
60+
i32 0, label %bb1
61+
i32 9, label %bb2
62+
], !prof !0
63+
bb1:
64+
ret i32 %arg1
65+
bb2:
66+
ret i32 %arg1
67+
bbd:
68+
ret i32 %arg1
69+
}
70+
71+
!0 = !{!"branch_weights", !"expected", i32 1, i32 1, i32 2000}
72+
73+
; // -----
74+
3975
; Verify that a single weight attached to a call is not translated.
4076
; The MLIR WeightedBranchOpInterface does not support this case.
4177

0 commit comments

Comments
 (0)