Skip to content
145 changes: 116 additions & 29 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the MIT License

#include "core/providers/openvino/ov_stateful_patch_utils.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/common/common.h"

namespace onnxruntime {
namespace openvino_ep {
Expand Down Expand Up @@ -132,50 +134,135 @@
manager.run_passes(ov_model);
}

// Converted to C++ from below reference URL:
// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Helper function to extract KV patterns from output names dynamically
//
// Example: Given output names ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1", "logits"]
// key_value_output_names = ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1"]
// unique_patterns = {"key_cross", "value_cross"}
std::pair<std::vector<std::string>, std::unordered_set<std::string>> ExtractKVPatternsFromOutputs(const std::shared_ptr<ov::Model>& model) {
std::vector<std::string> key_value_output_names;
std::unordered_set<std::string> unique_patterns;

const std::string prefix = "present_";
const size_t prefix_len = prefix.length();
for (const ov::Output<ov::Node>& output : model->outputs()) {
const auto& names = output.get_names();
for (const auto& name : names) {
if (name.find(prefix) == 0 && name.length() > prefix_len) {
size_t last_underscore_pos = name.rfind('_');
// Extract pattern between "present_" and the last underscore
if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) {
std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len);
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern extraction assumes the format 'present__' but doesn't validate that the suffix after the last underscore is actually numeric. This could lead to incorrect pattern extraction if output names have different formats (e.g., 'present_key_cross_layer_0'). Consider validating the numeric suffix before extracting the pattern.

Copilot uses AI. Check for mistakes.
if (!pattern.empty()) {
unique_patterns.insert(pattern);
key_value_output_names.push_back(name);
}
}
break;
}
}
}

if (unique_patterns.size() > 2) {
ORT_THROW("More than two unique KV patterns found in output names.");
}
Comment on lines +166 to +168
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hardcoded limit of 2 patterns contradicts the PR's goal of making the code more general. Consider either removing this restriction or making it configurable, as models with different architectures might legitimately have more than two KV pattern types.

Suggested change
if (unique_patterns.size() > 2) {
ORT_THROW("More than two unique KV patterns found in output names.");
}

Copilot uses AI. Check for mistakes.
return std::make_pair(key_value_output_names, unique_patterns);
}

// Main function to extract KV tensors using dynamic pattern matching
//
// Example: Given input names ["input_ids", "attention_mask", "past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"]
// kv_patterns = {"key_cross", "value_cross"}
//
// key_value_input_names = ["past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"]
// not_kv_inputs = ["input_ids", "attention_mask"]
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTensors(
const std::shared_ptr<ov::Model>& model, const std::unordered_set<std::string>& kv_patterns) {

Check warning on line 180 in onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc:180: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]

std::vector<std::string> key_value_input_names;
std::vector<std::string> not_kv_inputs;

if (kv_patterns.empty()) {
// Fallback: use original substring matching
for (const ov::Output<ov::Node>& input : model->inputs()) {
const auto& names = input.get_names();
const std::string input_name = input.get_any_name();

bool is_kv_input = false;
for (const auto& name : names) {
if (name.find("key_values") != std::string::npos ||
name.find("keys") != std::string::npos ||
name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
is_kv_input = true;
break;
}
}

if (!is_kv_input) {
not_kv_inputs.push_back(input_name);
}
}

return std::make_pair(key_value_input_names, not_kv_inputs);
}

// Inline helper function to check if name is matched with provided pattern followed by "_%d"
auto matches_pattern = [](const std::string& name, const std::string& pattern) -> bool {
size_t pos = name.find(pattern);
if (pos == std::string::npos) {
return false;
}

size_t after_pattern = pos + pattern.length();
if (after_pattern >= name.length() || name[after_pattern] != '_') {
return false;
}

std::string suffix = name.substr(after_pattern + 1);
return !suffix.empty() && std::all_of(suffix.begin(), suffix.end(), ::isdigit);
};

for (const ov::Output<ov::Node>& input : model->inputs()) {
auto& names = input.get_names();

bool found = false;
for (auto& name : names) {
if (name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("keys") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;

// Check if any input name contains either key or value pattern
for (const auto& name : names) {
for (const auto& pattern : kv_patterns) {
if (matches_pattern(name, pattern)) {
key_value_input_names.push_back(name);
found = true;
break;
}
}
if (found) break;
}

if (!found) {
not_kv_inputs.push_back(input.get_any_name());
}
}

std::vector<std::string> key_value_output_names;
for (const ov::Output<ov::Node>& output : model->outputs()) {
auto& names = output.get_names();
for (auto& name : names) {
if (name.find("present") != std::string::npos) {
key_value_output_names.push_back(name);
break;
}
}
}
return std::make_pair(key_value_input_names, not_kv_inputs);
}

// Updated PatchStatefulDecoder function
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Use the dynamic pattern-based extraction logic
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);

if (key_value_input_names.empty() || key_value_output_names.empty()) {
std::cout << "no key_value_input_names or key_value_output_names found" << std::endl;
return;
ORT_THROW("No key_value_input_names or key_value_output_names found");
}

if (key_value_input_names.size() != key_value_output_names.size()) {
ORT_THROW("Found different sizes between key_value_input_names (",
key_value_input_names.size(),
") and key_value_output_names (",
key_value_output_names.size(),
"). They couldn't be paired.");
}

// By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch
Expand Down
Loading