-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][llvm] Fix import of branch weights with "expected" field #169776
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][llvm] Fix import of branch weights with "expected" field #169776
Conversation
This commit fixes the import of `branch_weights` metadata from LLVM IR to the LLVM dialect. Previously, `branch_weights` metadata containing the `!"expected"` field were rejected because the importer expected integer weights at operand 1, but found a string.
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Vadim Curcă (VadimCurca) ChangesThis commit fixes the import of Full diff: https://github.com/llvm/llvm-project/pull/169776.diff 2 Files Affected:
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 44732d5466f6d..7c19b3d57eb79 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -113,7 +113,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
return failure();
// Handle function entry count metadata.
- if (name->getString() == "function_entry_count") {
+ if (name->getString() == llvm::MDProfLabels::FunctionEntryCount) {
// TODO support function entry count metadata with GUID fields.
if (node->getNumOperands() != 2)
@@ -131,15 +131,28 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
<< "expected function_entry_count to be attached to a function";
}
- if (name->getString() != "branch_weights")
+ if (name->getString() != llvm::MDProfLabels::BranchWeights)
return failure();
+ ArrayRef<llvm::MDOperand> branchWeightOperands =
+ node->operands().drop_front();
+
+ // The branch_weights metadata must have at least 2 operands.
+ if (node->getNumOperands() < 2)
+ return failure();
+ if (auto *expected = dyn_cast<llvm::MDString>(node->getOperand(1))) {
+ if (expected->getString() != llvm::MDProfLabels::ExpectedBranchWeights)
+ return failure();
+ // The MLIR WeightedBranchOpInterface does not support the
+ // ExpectedBranchWeights field, so it is dropped.
+ branchWeightOperands = branchWeightOperands.drop_front();
+ }
// Handle branch weights metadata.
SmallVector<int32_t> branchWeights;
- branchWeights.reserve(node->getNumOperands() - 1);
- for (unsigned i = 1, e = node->getNumOperands(); i != e; ++i) {
+ branchWeights.reserve(branchWeightOperands.size());
+ for (const llvm::MDOperand &operand : branchWeightOperands) {
llvm::ConstantInt *branchWeight =
- llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(i));
+ llvm::mdconst::dyn_extract<llvm::ConstantInt>(operand);
if (!branchWeight)
return failure();
branchWeights.push_back(branchWeight->getZExtValue());
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
index c623df0b605b2..328062545ed63 100644
--- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
@@ -16,6 +16,22 @@ bb2:
; // -----
+; CHECK-LABEL: @cond_br_expected
+define i64 @cond_br_expected(i1 %arg1, i64 %arg2) {
+entry:
+ ; CHECK: llvm.cond_br
+ ; CHECK-SAME: weights([1, 2000])
+ br i1 %arg1, label %bb1, label %bb2, !prof !0
+bb1:
+ ret i64 %arg2
+bb2:
+ ret i64 %arg2
+}
+
+!0 = !{!"branch_weights", !"expected", i32 1, i32 2000}
+
+; // -----
+
; CHECK-LABEL: @simple_switch(
define i32 @simple_switch(i32 %arg1) {
; CHECK: llvm.switch
@@ -36,6 +52,26 @@ bbd:
; // -----
+; CHECK-LABEL: @simple_switch_expected(
+define i32 @simple_switch_expected(i32 %arg1) {
+ ; CHECK: llvm.switch
+ ; CHECK: {branch_weights = array<i32: 1, 1, 2000>}
+ switch i32 %arg1, label %bbd [
+ i32 0, label %bb1
+ i32 9, label %bb2
+ ], !prof !0
+bb1:
+ ret i32 %arg1
+bb2:
+ ret i32 %arg1
+bbd:
+ ret i32 %arg1
+}
+
+!0 = !{!"branch_weights", !"expected", i32 1, i32 1, i32 2000}
+
+; // -----
+
; Verify that a single weight attached to a call is not translated.
; The MLIR WeightedBranchOpInterface does not support this case.
|
293edaa to
3478db1
Compare
gysit
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the cleanup!
LGTM
Dinistro
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the fix.
This commit fixes the import of
branch_weightsmetadata from LLVM IR to the LLVM dialect. Previously,branch_weightsmetadata containing the!"expected"field were rejected because the importer expected integer weights at operand 1, but found a string.