From 3478db19c1ad88ed15809ee7978bf2faba85c341 Mon Sep 17 00:00:00 2001 From: VadimCurca Date: Thu, 27 Nov 2025 09:02:40 +0100 Subject: [PATCH] [mlir][llvm] Fix import of branch weights with "expected" field This commit fixes the import of `branch_weights` metadata from LLVM IR to the LLVM dialect. Previously, `branch_weights` metadata containing the `!"expected"` field were rejected because the importer expected integer weights at operand 1, but found a string. --- .../LLVMIR/LLVMIRToLLVMTranslation.cpp | 23 +++++++++--- .../LLVMIR/Import/metadata-profiling.ll | 36 +++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp index 44732d5466f6d..81c9da1d98c40 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -113,7 +113,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, return failure(); // Handle function entry count metadata. - if (name->getString() == "function_entry_count") { + if (name->getString() == llvm::MDProfLabels::FunctionEntryCount) { // TODO support function entry count metadata with GUID fields. if (node->getNumOperands() != 2) @@ -131,15 +131,28 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, << "expected function_entry_count to be attached to a function"; } - if (name->getString() != "branch_weights") + if (name->getString() != llvm::MDProfLabels::BranchWeights) return failure(); + // The branch_weights metadata must have at least 2 operands. + if (node->getNumOperands() < 2) + return failure(); + + ArrayRef branchWeightOperands = + node->operands().drop_front(); + if (auto *mdString = dyn_cast(node->getOperand(1))) { + if (mdString->getString() != llvm::MDProfLabels::ExpectedBranchWeights) + return failure(); + // The MLIR WeightedBranchOpInterface does not support the + // ExpectedBranchWeights field, so it is dropped. + branchWeightOperands = branchWeightOperands.drop_front(); + } // Handle branch weights metadata. SmallVector branchWeights; - branchWeights.reserve(node->getNumOperands() - 1); - for (unsigned i = 1, e = node->getNumOperands(); i != e; ++i) { + branchWeights.reserve(branchWeightOperands.size()); + for (const llvm::MDOperand &operand : branchWeightOperands) { llvm::ConstantInt *branchWeight = - llvm::mdconst::dyn_extract(node->getOperand(i)); + llvm::mdconst::dyn_extract(operand); if (!branchWeight) return failure(); branchWeights.push_back(branchWeight->getZExtValue()); diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll index c623df0b605b2..328062545ed63 100644 --- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll @@ -16,6 +16,22 @@ bb2: ; // ----- +; CHECK-LABEL: @cond_br_expected +define i64 @cond_br_expected(i1 %arg1, i64 %arg2) { +entry: + ; CHECK: llvm.cond_br + ; CHECK-SAME: weights([1, 2000]) + br i1 %arg1, label %bb1, label %bb2, !prof !0 +bb1: + ret i64 %arg2 +bb2: + ret i64 %arg2 +} + +!0 = !{!"branch_weights", !"expected", i32 1, i32 2000} + +; // ----- + ; CHECK-LABEL: @simple_switch( define i32 @simple_switch(i32 %arg1) { ; CHECK: llvm.switch @@ -36,6 +52,26 @@ bbd: ; // ----- +; CHECK-LABEL: @simple_switch_expected( +define i32 @simple_switch_expected(i32 %arg1) { + ; CHECK: llvm.switch + ; CHECK: {branch_weights = array} + switch i32 %arg1, label %bbd [ + i32 0, label %bb1 + i32 9, label %bb2 + ], !prof !0 +bb1: + ret i32 %arg1 +bb2: + ret i32 %arg1 +bbd: + ret i32 %arg1 +} + +!0 = !{!"branch_weights", !"expected", i32 1, i32 1, i32 2000} + +; // ----- + ; Verify that a single weight attached to a call is not translated. ; The MLIR WeightedBranchOpInterface does not support this case.