45 changes: 40 additions & 5 deletions src/cpp/src/utils.cpp
@@ -74,7 +74,6 @@ std::optional<uint32_t> pop_int_and_cast(ov::AnyMap& config, const std::string&
}

void update_npu_config(ov::AnyMap& config,
- const std::shared_ptr<ov::Model>& model,
const ov::genai::utils::KVAxesPosition& kv_pos,
const ov::genai::utils::KVDesc& kv_desc) {
update_config(config, {"NPU_USE_NPUW", "YES"});
@@ -97,6 +96,26 @@ void update_npu_config(ov::AnyMap& config,
rename_key(config, "++SHARED_HEAD_CONFIG", "++NPUW_LLM_SHARED_HEAD_CONFIG");
}

void update_npu_config_whisper(ov::AnyMap& config,
const ov::genai::utils::KVAxesPosition& kv_pos,
const ov::genai::utils::KVDesc& kv_desc) {
update_config(config, {"NPU_USE_NPUW", "YES"});
update_config(config, {"NPUW_ONLINE_PIPELINE", "NONE"});
update_config(config, {"NPUW_FUNCALL_FOR_ALL", "NO"});
update_config(config, {"NPUW_FOLD", "NO"});
Comment on lines +104 to +105 (Contributor):

To enable weight sharing for FP16 models:

Suggested change:
- update_config(config, {"NPUW_FUNCALL_FOR_ALL", "NO"});
- update_config(config, {"NPUW_FOLD", "NO"});
+ update_config(config, {"NPUW_FUNCALL_FOR_ALL", "YES"});
+ update_config(config, {"NPUW_FOLD", "YES"});
+ update_config(config, {"NPUW_WEIGHTS_BANK", "whisper-shared"});

On top of that, for INT8-SYM:

    update_config(config, {"NPUW_DQ", "YES"});
    update_config(config, {"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES"});

For asym, we need to consider a third option (CWAI?).

update_config(config, {"NPUW_LLM", "YES"});
update_config(config, {"NPUW_WHISPER", "YES"});

update_config(config, {"NPUW_LLM_BATCH_DIM", kv_pos.batch});
update_config(config, {"NPUW_LLM_SEQ_LEN_DIM", kv_pos.seq_len});

update_config(config, {"NPUW_LLM_MAX_PROMPT_LEN", kv_desc.max_prompt_len});
update_config(config, {"NPUW_LLM_MIN_RESPONSE_LEN", kv_desc.min_response_len});

// To disable chunking
update_config(config, {"NPUW_LLM_PREFILL_HINT", "STATIC"});
}
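
For reference, the reviewer's weight-sharing suggestion above would fold into this function roughly as follows. This is only a sketch of the proposed direction, not code from the PR: the function name is hypothetical, the property values are taken verbatim from the comment, and the INT8-SYM additions are left commented out since the asymmetric case is still an open question there.

// Hypothetical variant (not part of the PR): update_npu_config_whisper with the
// reviewer's FP16 weight-sharing settings applied.
void update_npu_config_whisper_shared_weights(ov::AnyMap& config,
                                              const ov::genai::utils::KVAxesPosition& kv_pos,
                                              const ov::genai::utils::KVDesc& kv_desc) {
    update_config(config, {"NPU_USE_NPUW", "YES"});
    update_config(config, {"NPUW_ONLINE_PIPELINE", "NONE"});

    // Enable function folding and a shared weights bank so FP16 weights are reused.
    update_config(config, {"NPUW_FUNCALL_FOR_ALL", "YES"});
    update_config(config, {"NPUW_FOLD", "YES"});
    update_config(config, {"NPUW_WEIGHTS_BANK", "whisper-shared"});

    // On top of that, for INT8-SYM weights (per the same comment):
    // update_config(config, {"NPUW_DQ", "YES"});
    // update_config(config, {"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES"});

    update_config(config, {"NPUW_LLM", "YES"});
    update_config(config, {"NPUW_WHISPER", "YES"});

    update_config(config, {"NPUW_LLM_BATCH_DIM", kv_pos.batch});
    update_config(config, {"NPUW_LLM_SEQ_LEN_DIM", kv_pos.seq_len});

    update_config(config, {"NPUW_LLM_MAX_PROMPT_LEN", kv_desc.max_prompt_len});
    update_config(config, {"NPUW_LLM_MIN_RESPONSE_LEN", kv_desc.min_response_len});

    // To disable chunking, as in the original function.
    update_config(config, {"NPUW_LLM_PREFILL_HINT", "STATIC"});
}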

inline bool is_paged_attention_available() {
#if defined(OPENVINO_ARCH_X86_64) || defined(OPENVINO_ARCH_ARM64)
return true;
@@ -554,7 +573,8 @@ void print_scheduler_config_info(const SchedulerConfig &scheduler_config) {
std::pair<ov::CompiledModel, KVDesc>
compile_decoder_for_npu(const std::shared_ptr<ov::Model>& model,
const ov::AnyMap& config,
- const KVAxesPosition& kv_pos) {
const KVAxesPosition& kv_pos,
const bool is_whisper) {
ov::CompiledModel compiled;
ov::AnyMap properties = config;
KVDesc kv_desc;
@@ -575,9 +595,16 @@ compile_decoder_for_npu(const std::shared_ptr<ov::Model>& model,
kv_desc.max_prompt_len = compiled.get_property("NPUW_LLM_MAX_PROMPT_LEN").as<uint32_t>();
kv_desc.min_response_len = compiled.get_property("NPUW_LLM_MIN_RESPONSE_LEN").as<uint32_t>();
} else {
kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(1024u);
kv_desc.min_response_len = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(128u);
update_npu_config(properties, model, kv_pos, kv_desc);
if (is_whisper) {
kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(4u);
// kvcache size for Whisper = 448u (MAX_PROMPT_LEN + MIN_RESPONSE_LEN)
kv_desc.min_response_len = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(444u);
update_npu_config_whisper(properties, kv_pos, kv_desc);
} else {
kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(1024u);
kv_desc.min_response_len = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(128u);
update_npu_config(properties, kv_pos, kv_desc);
}
compiled = ov::genai::utils::singleton_core().compile_model(model, "NPU", properties);
// Also export compiled model if required
if (export_blob) {
@@ -813,6 +840,14 @@ void export_model(ov::CompiledModel& compiled_model, const std::filesystem::path
out.close();
}

bool has_input(const std::shared_ptr<ov::Model>& model, const std::string& name) {
auto inputs = model->inputs();
auto it = std::find_if(inputs.begin(), inputs.end(), [&](const auto& port) {
return port.get_names().count(name) != 0;
});
return it != inputs.end();
}

} // namespace utils
} // namespace genai
} // namespace ov
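
One note on the Whisper defaults in compile_decoder_for_npu above: 4 prompt tokens plus 444 response tokens gives the 448 decoder positions Whisper supports in total, which is why the fallback values look lopsided. A caller that wants the window spelled out explicitly could pass the same keys through the properties map; the helper below is a hypothetical sketch, not part of the PR.

// Hypothetical helper: compile a Whisper decoder for NPU with the KV window stated explicitly.
std::pair<ov::CompiledModel, ov::genai::utils::KVDesc>
compile_whisper_decoder_npu(const std::shared_ptr<ov::Model>& model, ov::AnyMap properties) {
    const auto kv_pos = ov::genai::utils::get_kv_axes_pos(model);
    properties["MAX_PROMPT_LEN"] = 4u;      // same values the defaults would pick
    properties["MIN_RESPONSE_LEN"] = 444u;  // 4 + 444 = 448, Whisper's decoder context length
    return ov::genai::utils::compile_decoder_for_npu(model, properties, kv_pos, /*is_whisper=*/true);
}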
8 changes: 7 additions & 1 deletion src/cpp/src/utils.hpp
@@ -193,7 +193,8 @@ struct KVDesc {

std::pair<ov::CompiledModel, KVDesc> compile_decoder_for_npu(const std::shared_ptr<ov::Model>& model,
const ov::AnyMap& config,
- const KVAxesPosition& kv_pos);
const KVAxesPosition& kv_pos,
const bool is_whisper = false);

/// @brief SharedOptional is a wrapper around a reference to an existing object and an optional shared alternative value.
/// The difference from std::optional is that the default state is not empty and contains a reference to an existing object outside the class.
@@ -308,6 +309,11 @@ ov::CompiledModel import_model(const std::filesystem::path& blob_path,
*/
void export_model(ov::CompiledModel& compiled_model, const std::filesystem::path& blob_path);

/**
* @brief Checks if the model has an input with the specified name.
*/
bool has_input(const std::shared_ptr<Model>& model, const std::string& name);

} // namespace utils
} // namespace genai
} // namespace ov
9 changes: 7 additions & 2 deletions src/cpp/src/whisper/models/decoder.cpp
@@ -12,14 +12,19 @@
namespace ov::genai {
std::shared_ptr<WhisperDecoder> WhisperDecoder::from_path(const std::filesystem::path& models_path,
const std::string& device,
- const ov::AnyMap& properties) {
const ov::AnyMap& properties,
const ov::PartialShape& lhs_shape) {
bool has_decoder_with_past = std::filesystem::exists(models_path / "openvino_decoder_with_past_model.xml");

if (has_decoder_with_past) {
if (device == "NPU") {
OPENVINO_THROW("For NPU, 3-model whisper pipeline works only with STATIC_PIPELINE : YES configuration "
"(which is default for NPU).");
}
return std::make_shared<WhisperWithPastDecoder>(models_path, device, properties);
}

- return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties);
return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties, lhs_shape);
}

std::pair<int64_t, float> WhisperDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
3 changes: 2 additions & 1 deletion src/cpp/src/whisper/models/decoder.hpp
@@ -13,7 +13,8 @@ class WhisperDecoder {
public:
static std::shared_ptr<WhisperDecoder> from_path(const std::filesystem::path& models_path,
const std::string& device,
- const ov::AnyMap& properties);
const ov::AnyMap& properties,
const ov::PartialShape& lhs_shape);

std::pair<int64_t, float> detect_language(const Tensor& encoder_hidden_state, const int64_t decoder_start_token_id);

40 changes: 34 additions & 6 deletions src/cpp/src/whisper/models/statefull_decoder.cpp
@@ -1,21 +1,45 @@
- // Copyright (C) 2024 Intel Corporation
// Copyright (C) 2024-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "statefull_decoder.hpp"

#include "utils.hpp"

namespace {
void reshape_hidden_states_to_static(std::shared_ptr<ov::Model> model, const ov::PartialShape& lhstates_shape) {
ov::PartialShape new_shape = model->input("encoder_hidden_states").get_partial_shape();
OPENVINO_ASSERT(new_shape.size() > 1 && lhstates_shape.size() > 1);
new_shape[1] = lhstates_shape[1];
std::map<std::string, ov::PartialShape> name_to_shape{{"encoder_hidden_states", new_shape}};
model->reshape(name_to_shape);
}

} // anonymous

namespace ov::genai {
WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& models_path,
const std::string& device,
- const ov::AnyMap& properties) {
const ov::AnyMap& properties,
const ov::PartialShape& lhs_shape) {
ov::Core core = utils::singleton_core();

auto model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);

- utils::apply_slice_before_matmul_transformation(model);
m_has_cache_position = utils::has_input(model, "cache_position");

ov::CompiledModel compiled_model;
if (device == "NPU") {
auto kv_pos = ov::genai::utils::get_kv_axes_pos(model);

reshape_hidden_states_to_static(model, lhs_shape);

- auto compiled_model = core.compile_model(model, device, properties);
utils::KVDesc kv_desc;
std::tie(compiled_model, kv_desc) = utils::compile_decoder_for_npu(model, properties, kv_pos, true);
} else {
utils::apply_slice_before_matmul_transformation(model);

compiled_model = core.compile_model(model, device, properties);
}

utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
m_request = compiled_model.create_infer_request();
@@ -29,7 +53,9 @@ void WhisperStatefullDecoder::start_async(const Tensor& encoder_hidden_state,

_set_encoder_hidden_states_tensor(encoder_hidden_state, batch_size, m_request);

- _set_cache_position_tensor(seq_len);
if (m_has_cache_position) {
_set_cache_position_tensor(seq_len);
}
m_request.set_tensor("input_ids", input_ids);
m_request.set_tensor("beam_idx", beam_idx);

@@ -58,7 +84,9 @@ Tensor WhisperStatefullDecoder::wait() {

void WhisperStatefullDecoder::reset_state() {
m_request.reset_state();
m_request.set_tensor("cache_position", create_host_tensor(ov::element::i64, {0}));
if (m_has_cache_position) {
m_request.set_tensor("cache_position", create_host_tensor(ov::element::i64, {0}));
}

Shape encoder_hidden_states_shape{m_request.get_tensor("encoder_hidden_states").get_shape()};
encoder_hidden_states_shape[0] = 0;
6 changes: 4 additions & 2 deletions src/cpp/src/whisper/models/statefull_decoder.hpp
@@ -1,4 +1,4 @@
- // Copyright (C) 2024 Intel Corporation
// Copyright (C) 2024-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once
@@ -12,7 +12,8 @@ class WhisperStatefullDecoder : public WhisperDecoder {
public:
WhisperStatefullDecoder(const std::filesystem::path& models_path,
const std::string& device,
- const ov::AnyMap& properties);
const ov::AnyMap& properties,
const ov::PartialShape& lhs_shape);
Comment (Collaborator):

Can we obtain the required dimensions from the decoder model itself, from its parameters or results? As I understand, you need this for updating the KV cache. Also, you could use the decoder's variables for this, couldn't you?
My point is to avoid passing the last hidden state shape from the encoder. I think we can deduce it from the decoder itself.


void start_async(const Tensor& encoder_hidden_state, const Tensor& input_ids, const Tensor& beam_idx) override;

@@ -27,5 +28,6 @@

private:
ov::InferRequest m_request;
bool m_has_cache_position = true;
};
} // namespace ov::genai
5 changes: 4 additions & 1 deletion src/cpp/src/whisper/models/with_past_decoder.cpp
@@ -86,12 +86,15 @@ WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& mode
"To obtain stateful decoder model use latest `optimum-intel` package:\n"
"pip install optimum-intel@git+https://github.com/huggingface/optimum-intel.git@main\n"
"optimum-cli export openvino --trust-remote-code --model openai/whisper-tiny whisper-tiny");

ov::Core core = utils::singleton_core();

auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);
utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
m_request_decoder = compiled_model.create_infer_request();

m_past_decoder_has_cache_position =
utils::has_input(core.read_model(models_path / "openvino_decoder_with_past_model.xml"), "cache_position");
compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, properties);
utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model");
m_request_decoder_with_past = compiled_model.create_infer_request();
@@ -109,7 +112,7 @@ void WhisperWithPastDecoder::start_async(const Tensor& encoder_hidden_state,
_set_encoder_hidden_states_tensor(encoder_hidden_state, batch_size, request);
request.set_tensor("input_ids", input_ids);

- if (!is_initial_step) {
if (!is_initial_step && m_past_decoder_has_cache_position) {
ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
cache_position_tensor.set_shape({1});
cache_position_tensor.data<int64_t>()[0] = m_cache_position;
1 change: 1 addition & 0 deletions src/cpp/src/whisper/models/with_past_decoder.hpp
@@ -26,6 +26,7 @@ class WhisperWithPastDecoder : public WhisperDecoder {
size_t m_cache_position = 0;
bool m_initial_past_key_value_set = false;
bool m_past_key_value_linked = false;
bool m_past_decoder_has_cache_position = true;

void _set_past_key_value(const Tensor& beam_idx);
};
40 changes: 36 additions & 4 deletions src/cpp/src/whisper/pipeline.cpp
@@ -42,6 +42,25 @@ ov::InferRequest init_model(ov::CompiledModel& compiled) {
}
}

void reshape_to_static_encoder(std::shared_ptr<ov::Model> model,
const size_t batch_size,
const size_t feature_size) {
Comment on lines +45 to +47 (Copilot AI, Oct 22, 2025):

This function duplicates existing functionality from reshape_to_static_encoder in pipeline_static.cpp, which already reshapes the encoder model. The only difference is the batch_size parameter. Consider refactoring to reuse the existing implementation or consolidate into a single utility function.

Reply (Contributor):

Thanks, I propose to do this as part of a refactoring task.

std::map<std::string, ov::PartialShape> new_shapes;
for (auto input : model->inputs()) {
const auto& input_name = input.get_any_name();
ov::PartialShape new_shape;
if (input_name.find("input_features") != std::string::npos) {
const auto& partial_shape = input.get_partial_shape();
OPENVINO_ASSERT(partial_shape.size() >= 3);
new_shape = partial_shape;
new_shape[0] = batch_size; // batch_dim
new_shape[1] = feature_size;
new_shapes.emplace(input_name, new_shape);
}
}
model->reshape(new_shapes);
}

} // namespace

namespace ov {
@@ -55,13 +74,20 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi
: WhisperPipelineImplBase{models_path},
m_sampler(m_tokenizer) {
ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model;
if (device == "NPU") {
auto encoder_model = core.read_model(models_path / "openvino_encoder_model.xml", {}, properties);
// NB: only batch_size == 1 is supported now for NPU
reshape_to_static_encoder(encoder_model, 1, m_feature_extractor.feature_size);
compiled_model = core.compile_model(encoder_model, "NPU", properties);
} else {
compiled_model = core.compile_model(models_path / "openvino_encoder_model.xml", device, properties);
}

- ov::CompiledModel compiled_model =
-     core.compile_model(models_path / "openvino_encoder_model.xml", device, properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "whisper encoder model");
m_encoder = init_model(compiled_model);

- m_decoder = WhisperDecoder::from_path(models_path, device, properties);
m_decoder = WhisperDecoder::from_path(models_path, device, properties, m_encoder.get_compiled_model().output("last_hidden_state").get_partial_shape());

// If eos_token_id was not provided, take value
if (m_generation_config.eos_token_id == -1) {
@@ -155,7 +181,13 @@ ov::genai::WhisperPipeline::WhisperPipeline(const std::filesystem::path& models_
const ov::AnyMap& properties) {
auto start_time = std::chrono::steady_clock::now();
if (device == "NPU") {
- m_impl = std::make_unique<StaticWhisperPipeline>(models_path, properties);
auto properties_copy = properties;
const bool use_static_pipeline = utils::pop_or_default(properties_copy, "STATIC_PIPELINE", true);
if (!use_static_pipeline) {
m_impl = std::make_unique<WhisperPipelineStatefulImpl>(models_path, device, properties_copy);
} else {
m_impl = std::make_unique<StaticWhisperPipeline>(models_path, properties_copy);
}
} else {
m_impl = std::make_unique<WhisperPipelineStatefulImpl>(models_path, device, properties);
}
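
A minimal usage sketch of the selection logic above, assuming an exported Whisper model in ./whisper-tiny and that STATIC_PIPELINE accepts a plain boolean (the key and its default of true are taken from the diff; everything else here is illustrative):

#include <iostream>
#include <vector>

#include "openvino/genai/whisper_pipeline.hpp"

int main() {
    // Opt out of the default static NPU pipeline and use the new stateful path.
    ov::AnyMap properties{{"STATIC_PIPELINE", false}};
    ov::genai::WhisperPipeline pipeline("whisper-tiny", "NPU", properties);

    // 16 kHz mono PCM; one second of silence stands in for real audio here.
    std::vector<float> raw_speech(16000, 0.0f);
    auto result = pipeline.generate(raw_speech);
    std::cout << result.texts[0] << std::endl;
    return 0;
}

On any other device the stateful implementation is chosen unconditionally, so the property only matters for NPU.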