-
Notifications
You must be signed in to change notification settings - Fork 99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Apply new Python API features from OpenVINO 2023.0 release #265
Changes from 7 commits
ceacad8
70c7b0f
482c658
014a7ec
3ace1f7
74b42c7
ff95a0e
c9c8a6c
9fa6405
36398fb
8c2bb77
9a2e5b0
afe1354
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,11 +16,10 @@ | |
from pathlib import Path | ||
from typing import Dict, Optional, Tuple | ||
|
||
import numpy as np | ||
import openvino | ||
import torch | ||
import transformers | ||
from openvino.runtime import Core, Tensor | ||
from openvino.runtime import Core | ||
from transformers import AutoConfig, AutoModelForSeq2SeqLM | ||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward | ||
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput | ||
|
@@ -126,10 +125,6 @@ | |
""" | ||
|
||
|
||
def _contiguous_helper(tensor: np.ndarray) -> np.ndarray: | ||
return tensor if tensor.flags["C_CONTIGUOUS"] else np.ascontiguousarray(tensor) | ||
|
||
|
||
@add_start_docstrings( | ||
""" | ||
Sequence-to-sequence model with a language modeling head for OpenVINO inference. | ||
|
@@ -287,10 +282,10 @@ def clear_requests(self): | |
self.decoder_with_past.request = None | ||
|
||
def compile(self): | ||
self.encoder._create_inference_request() | ||
self.decoder._create_inference_request() | ||
self.encoder._compile() | ||
self.decoder._compile() | ||
if self.use_cache: | ||
self.decoder_with_past._create_inference_request() | ||
self.decoder_with_past._compile() | ||
|
||
|
||
class OVEncoder: | ||
|
@@ -318,36 +313,29 @@ def forward( | |
attention_mask: torch.LongTensor = None, | ||
**kwargs, | ||
) -> BaseModelOutput: | ||
self._create_inference_request() | ||
self._compile() | ||
|
||
# Check if inputs are c-like, if not - convert them | ||
input_ids = _contiguous_helper(np.array(input_ids)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks much clearer now ! Could this also be integrated for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I edited it as much as I could. Hopefully it's better now!:) I am sorry for the late response here. However, now OV is bumped in setup.py and features can be used from the official release. |
||
|
||
inputs = { | ||
"input_ids": Tensor(input_ids, shared_memory=True), | ||
} | ||
# Model inputs | ||
inputs = {"input_ids": input_ids} | ||
|
||
# Add the attention_mask inputs when needed | ||
if "attention_mask" in self.input_names: | ||
attention_mask = _contiguous_helper(np.array(attention_mask)) | ||
inputs["attention_mask"] = Tensor(attention_mask, shared_memory=True) | ||
inputs["attention_mask"] = attention_mask | ||
|
||
# Run inference | ||
self.request.start_async(inputs) | ||
self.request.wait() | ||
|
||
last_hidden_state = torch.from_numpy(self.request.get_tensor("last_hidden_state").data).to(self.device) | ||
last_hidden_state = torch.from_numpy(self.request(inputs, shared_memory=True)["last_hidden_state"]).to( | ||
self.device | ||
) | ||
|
||
return BaseModelOutput(last_hidden_state=last_hidden_state) | ||
|
||
def __call__(self, *args, **kwargs): | ||
return self.forward(*args, **kwargs) | ||
|
||
def _create_inference_request(self): | ||
def _compile(self): | ||
if self.request is None: | ||
logger.info("Compiling the encoder and creating the inference request ...") | ||
compiled_model = core.compile_model(self.model, self._device, self.ov_config) | ||
self.request = compiled_model.create_infer_request() | ||
logger.info("Compiling the encoder...") | ||
self.request = core.compile_model(self.model, self._device, self.ov_config) | ||
|
||
|
||
class OVDecoder: | ||
|
@@ -387,56 +375,39 @@ def forward( | |
encoder_attention_mask: Optional[torch.LongTensor] = None, | ||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | ||
) -> Seq2SeqLMOutput: | ||
self._create_inference_request() | ||
|
||
self._compile() | ||
# Model inputs | ||
inputs = {} | ||
|
||
if past_key_values is not None: | ||
# Flatten the past_key_values | ||
past_key_values = tuple( | ||
_contiguous_helper(np.array(past_key_value)) | ||
for pkv_per_layer in past_key_values | ||
for past_key_value in pkv_per_layer | ||
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer | ||
) | ||
|
||
# Add the past_key_values to the decoder inputs | ||
inputs = { | ||
input_name: Tensor(past_key_value, shared_memory=True) | ||
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values) | ||
} | ||
inputs = dict(zip(self.key_value_input_names, past_key_values)) | ||
|
||
# Check if inputs are c-like, if not - convert them | ||
input_ids = _contiguous_helper(np.array(input_ids)) | ||
inputs["input_ids"] = Tensor(input_ids, shared_memory=True) | ||
inputs["input_ids"] = input_ids | ||
|
||
# Add the encoder_attention_mask inputs when needed | ||
if "encoder_attention_mask" in self.input_names and encoder_attention_mask is not None: | ||
encoder_attention_mask = _contiguous_helper(np.array(encoder_attention_mask)) | ||
inputs["encoder_attention_mask"] = Tensor(encoder_attention_mask, shared_memory=True) | ||
inputs["encoder_attention_mask"] = encoder_attention_mask | ||
|
||
# Add the encoder_hidden_states inputs when needed | ||
if "encoder_hidden_states" in self.input_names and encoder_hidden_states is not None: | ||
encoder_hidden_states = _contiguous_helper(np.array(encoder_hidden_states)) | ||
inputs["encoder_hidden_states"] = Tensor(encoder_hidden_states, shared_memory=True) | ||
inputs["encoder_hidden_states"] = encoder_hidden_states | ||
|
||
# Run inference | ||
self.request.start_async(inputs) | ||
self.request.wait() | ||
|
||
outputs = {} | ||
for key, value in zip(self.request.model_outputs, self.request.outputs): | ||
output_names = key.get_names() | ||
output_name = "logits" if "logits" in output_names else next(iter(output_names)) | ||
outputs[output_name] = value.data | ||
|
||
outputs = self.request(inputs, shared_memory=True) | ||
logits = torch.from_numpy(outputs["logits"]).to(self.device) | ||
|
||
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the | ||
# self-attention layer and 2 to the cross-attention layer) | ||
out_past_key_values = tuple( | ||
torch.from_numpy(outputs[key]).to(self.device) | ||
for key in outputs | ||
if ("key_values" in key or "present" in key) | ||
torch.from_numpy(outputs[next(iter(key))]).to(self.device) | ||
for key in outputs.names() | ||
if ("key_values" in next(iter(key)) or "present" in next(iter(key))) | ||
) | ||
|
||
# Tuple of tuple of length `n_layers`, with each tuple of length equal to: | ||
|
@@ -458,8 +429,7 @@ def forward( | |
def __call__(self, *args, **kwargs): | ||
return self.forward(*args, **kwargs) | ||
|
||
def _create_inference_request(self): | ||
def _compile(self): | ||
if self.request is None: | ||
logger.info("Compiling the decoder and creating the inference request ...") | ||
compiled_model = core.compile_model(self.model, self._device, self.ov_config) | ||
self.request = compiled_model.create_infer_request() | ||
logger.info("Compiling the decoder...") | ||
self.request = core.compile_model(self.model, self._device, self.ov_config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be great to rename the
_create_inference_request
method to_compile
for stable diffusion models as well to keep uniformityThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed