Enable WhisperStatefulImpl for NPU, fix Whisper pipelines for transformers 4.53.3 & 4.55 #2126
```diff
@@ -1,4 +1,4 @@
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2024-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0

 #pragma once

@@ -12,7 +12,8 @@ class WhisperStatefullDecoder : public WhisperDecoder {
 public:
     WhisperStatefullDecoder(const std::filesystem::path& models_path,
                             const std::string& device,
-                            const ov::AnyMap& properties);
+                            const ov::AnyMap& properties,
+                            const ov::PartialShape& lhs_shape);
```
Collaborator:

Can we obtain the required dimensions from the decoder model, from its parameters or results? As I understand it, you need them for updating the KV cache. Also, you could use the decoder's variables for this, couldn't you?
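A sketch of what the reviewer suggests: read the KV-cache dimensions from the decoder model's state variables instead of threading `lhs_shape` through the constructor. This is an illustration only; the helper name and the `past_key_values` variable-id substring are assumptions, not code from this PR.

```cpp
#include "openvino/openvino.hpp"

// Hypothetical helper: derive the KV-cache shape from a stateful decoder's
// ReadValue/Assign variables rather than from an extra constructor argument.
ov::PartialShape kv_cache_shape_from_variables(const std::shared_ptr<ov::Model>& decoder_model) {
    for (const auto& variable : decoder_model->get_variables()) {
        const auto& info = variable->get_info();
        // Assumption: KV-cache variables follow the HF "past_key_values" naming.
        if (info.variable_id.find("past_key_values") != std::string::npos) {
            return info.data_shape;  // e.g. [batch, num_heads, seq_len, head_dim]
        }
    }
    OPENVINO_THROW("No KV-cache variable found in the decoder model");
}
```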
```diff
     void start_async(const Tensor& encoder_hidden_state, const Tensor& input_ids, const Tensor& beam_idx) override;

@@ -27,5 +28,6 @@ class WhisperStatefullDecoder : public WhisperDecoder {

 private:
     ov::InferRequest m_request;
+    bool m_has_cache_position = true;
 };
 }  // namespace ov::genai
```
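The new `m_has_cache_position` flag presumably tracks whether the exported decoder exposes a `cache_position` input, which changed across transformers/optimum export versions (hence the 4.53.3/4.55 fixes in the title). A minimal sketch of how such a flag could be derived; this is an assumption for illustration, not the PR's actual detection logic:

```cpp
#include "openvino/openvino.hpp"

// Hypothetical detection: scan the decoder's input ports for a
// "cache_position" tensor name.
bool model_has_cache_position(const std::shared_ptr<const ov::Model>& decoder_model) {
    for (const auto& input : decoder_model->inputs()) {
        for (const auto& name : input.get_names()) {
            if (name.find("cache_position") != std::string::npos) {
                return true;
            }
        }
    }
    return false;
}
```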
```diff
@@ -42,6 +42,25 @@ ov::InferRequest init_model(ov::CompiledModel& compiled) {
     }
 }

+void reshape_to_static_encoder(std::shared_ptr<ov::Model> model,
+                               const size_t batch_size,
+                               const size_t feature_size) {
+    std::map<std::string, ov::PartialShape> new_shapes;
+    for (auto input : model->inputs()) {
+        const auto& input_name = input.get_any_name();
+        ov::PartialShape new_shape;
+        if (input_name.find("input_features") != std::string::npos) {
+            const auto& partial_shape = input.get_partial_shape();
+            OPENVINO_ASSERT(partial_shape.size() >= 3);
+            new_shape = partial_shape;
+            new_shape[0] = batch_size;  // batch_dim
+            new_shape[1] = feature_size;
+            new_shapes.emplace(input_name, new_shape);
+        }
+    }
+    model->reshape(new_shapes);
+}
+
 }  // namespace

 namespace ov {
```
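For illustration, what the helper does under whisper-base-style assumptions (80 mel bins, a fixed 3000-frame window; neither number comes from this diff). Note that only the batch and feature dimensions are pinned, so the time dimension must already be static in the exported encoder:

```cpp
// Hypothetical call site; encoder_model is a freshly read openvino_encoder_model.xml.
reshape_to_static_encoder(encoder_model, /*batch_size=*/1, /*feature_size=*/80);
// "input_features" goes from, e.g., [?, 80, 3000] to [1, 80, 3000].
```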
```diff
@@ -55,13 +74,20 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi
     : WhisperPipelineImplBase{models_path},
       m_sampler(m_tokenizer) {
     ov::Core core = utils::singleton_core();
+    ov::CompiledModel compiled_model;
+    if (device == "NPU") {
+        auto encoder_model = core.read_model(models_path / "openvino_encoder_model.xml", {}, properties);
+        // NB: only batch_size == 1 is supported now for NPU
+        reshape_to_static_encoder(encoder_model, 1, m_feature_extractor.feature_size);
+        compiled_model = core.compile_model(encoder_model, "NPU", properties);
+    } else {
+        compiled_model = core.compile_model(models_path / "openvino_encoder_model.xml", device, properties);
+    }

-    ov::CompiledModel compiled_model =
-        core.compile_model(models_path / "openvino_encoder_model.xml", device, properties);
     ov::genai::utils::print_compiled_model_properties(compiled_model, "whisper encoder model");
     m_encoder = init_model(compiled_model);

-    m_decoder = WhisperDecoder::from_path(models_path, device, properties);
+    m_decoder = WhisperDecoder::from_path(models_path, device, properties, m_encoder.get_compiled_model().output("last_hidden_state").get_partial_shape());

     // If eos_token_id was not provided, take value
     if (m_generation_config.eos_token_id == -1) {
```
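For concreteness, the shape now handed to the decoder is fully static after the NPU reshape; for whisper-base it would look like the following (the numbers are an illustration, not taken from this PR):

```cpp
ov::PartialShape lhs_shape =
    m_encoder.get_compiled_model().output("last_hidden_state").get_partial_shape();
// e.g. {1, 1500, 512} for whisper-base: [batch, encoder frames, hidden size]
```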
```diff
@@ -155,7 +181,13 @@ ov::genai::WhisperPipeline::WhisperPipeline(const std::filesystem::path& models_
                                             const ov::AnyMap& properties) {
     auto start_time = std::chrono::steady_clock::now();
     if (device == "NPU") {
-        m_impl = std::make_unique<StaticWhisperPipeline>(models_path, properties);
+        auto properties_copy = properties;
+        const bool use_static_pipeline = utils::pop_or_default(properties_copy, "STATIC_PIPELINE", true);
+        if (!use_static_pipeline) {
+            m_impl = std::make_unique<WhisperPipelineStatefulImpl>(models_path, device, properties_copy);
+        } else {
+            m_impl = std::make_unique<StaticWhisperPipeline>(models_path, properties_copy);
+        }
     } else {
         m_impl = std::make_unique<WhisperPipelineStatefulImpl>(models_path, device, properties);
     }
```
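With this change, NPU users can opt into the stateful pipeline by disabling the static one via the new `STATIC_PIPELINE` property (default `true`, preserving the old behavior). A usage sketch; the model path is a placeholder and the plain-bool AnyMap value is an assumption:

```cpp
#include "openvino/genai/whisper_pipeline.hpp"

int main() {
    // Default on NPU: the static pipeline, as before this PR.
    ov::genai::WhisperPipeline static_pipe("whisper-base", "NPU");

    // Opt into the stateful implementation on NPU.
    ov::genai::WhisperPipeline stateful_pipe("whisper-base", "NPU",
                                             ov::AnyMap{{"STATIC_PIPELINE", false}});
    return 0;
}
```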
Review comment:

To enable weight sharing for FP16 models:
On top of that, for INT8-SYM:
For asym, we need to consider a third option (CWAI?)