Apply new Python API features from OpenVINO 2023.0 release #265

Merged (13 commits, Jun 1, 2023)
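
For context (this aside is not part of the diff): the OpenVINO 2023.0 Python API makes the CompiledModel returned by core.compile_model() directly callable, and the results it returns can be indexed by output name. That is what lets this PR drop the explicit create_infer_request() / infer() / get_any_name() plumbing throughout the modeling classes. A minimal sketch of the before/after pattern, assuming a placeholder model file "model.xml" with an "input_ids" input and a "logits" output (both names are illustrative only):

import numpy as np
from openvino.runtime import Core

core = Core()
# "model.xml" and the "input_ids"/"logits" names are placeholders for illustration.
compiled_model = core.compile_model("model.xml", "CPU")
inputs = {"input_ids": np.ones((1, 8), dtype=np.int64)}

# Pre-2023.0 pattern: explicit infer request, then map outputs back to their names.
request = compiled_model.create_infer_request()
outputs = request.infer(inputs)
outputs = {key.get_any_name(): value for key, value in outputs.items()}
logits = outputs["logits"]

# 2023.0 pattern: call the compiled model directly; results are indexable by name.
outputs = compiled_model(inputs)
logits = outputs["logits"]

The callable form manages the inference request internally, which is why the classes in this PR can now store the CompiledModel itself in self.request.
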
31 changes: 12 additions & 19 deletions optimum/intel/openvino/modeling.py
@@ -174,8 +174,7 @@ def forward(
inputs["token_type_ids"] = token_type_ids

# Run inference
outputs = self.request.infer(inputs)
outputs = {key.get_any_name(): value for key, value in outputs.items()}
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
return SequenceClassifierOutput(logits=logits)

@@ -241,8 +240,7 @@ def forward(
inputs["token_type_ids"] = token_type_ids

# Run inference
outputs = self.request.infer(inputs)
outputs = {key.get_any_name(): value for key, value in outputs.items()}
outputs = self.request(inputs)
start_logits = (
torch.from_numpy(outputs["start_logits"]).to(self.device) if not np_inputs else outputs["start_logits"]
)
@@ -312,8 +310,7 @@ def forward(
inputs["token_type_ids"] = token_type_ids

# Run inference
outputs = self.request.infer(inputs)
outputs = {key.get_any_name(): value for key, value in outputs.items()}
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
return TokenClassifierOutput(logits=logits)

@@ -378,13 +375,12 @@ def forward(
inputs["token_type_ids"] = token_type_ids

# Run inference
outputs = self.request.infer(inputs)
outputs = {key.get_any_name(): value for key, value in outputs.items()}

last_hidden_state = outputs["last_hidden_state"]
if not np_inputs:
last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)

outputs = self.request(inputs)
last_hidden_state = (
torch.from_numpy(outputs["last_hidden_state"]).to(self.device)
if not np_inputs
else outputs["last_hidden_state"]
)
return BaseModelOutput(last_hidden_state=last_hidden_state)


@@ -449,8 +445,7 @@ def forward(
inputs["token_type_ids"] = token_type_ids

# Run inference
outputs = self.request.infer(inputs)
outputs = {key.get_any_name(): value for key, value in outputs.items()}
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
return MaskedLMOutput(logits=logits)

@@ -508,8 +503,7 @@ def forward(
}

# Run inference
outputs = self.request.infer(inputs)
outputs = {key.get_any_name(): value for key, value in outputs.items()}
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
return ImageClassifierOutput(logits=logits)

@@ -574,7 +568,6 @@ def forward(
inputs["attention_mask"] = attention_mask

# Run inference
outputs = self.request.infer(inputs)
outputs = {key.get_any_name(): value for key, value in outputs.items()}
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
return SequenceClassifierOutput(logits=logits)
5 changes: 2 additions & 3 deletions optimum/intel/openvino/modeling_base.py
@@ -306,14 +306,13 @@ def _from_transformers(

def compile(self):
if self.request is None:
logger.info("Compiling the model and creating the inference request ...")
logger.info("Compiling the model...")
# Only enable CACHE_DIR for GPU because CACHE_DIR fails with some INT8 models on CPU with 2022.3
ov_config = self.ov_config.copy()
if self._device == "GPU":
cache_dir = Path(self.model_save_dir).joinpath("model_cache")
ov_config["CACHE_DIR"] = str(cache_dir)
compiled_model = core.compile_model(self.model, self._device, ov_config)
self.request = compiled_model.create_infer_request()
self.request = core.compile_model(self.model, self._device, ov_config)

def _reshape(
self,
11 changes: 5 additions & 6 deletions optimum/intel/openvino/modeling_diffusion.py
@@ -447,9 +447,8 @@ def __init__(

def _create_inference_request(self):
if self.request is None:
logger.info("Compiling the encoder and creating the inference request ...")
Collaborator:
It would be great to rename the _create_inference_request method to _compile for stable diffusion models as well, to keep uniformity.

Contributor Author:
Fixed

compiled_model = core.compile_model(self.model, self.device, self.ov_config)
self.request = compiled_model.create_infer_request()
logger.info("Compiling the encoder...")
self.request = core.compile_model(self.model, self.device, self.ov_config)

@property
def device(self):
@@ -463,7 +462,7 @@ def __call__(self, input_ids: np.ndarray):
inputs = {
"input_ids": input_ids,
}
outputs = self.request.infer(inputs)
outputs = self.request(inputs)
return list(outputs.values())


@@ -477,7 +476,7 @@ def __call__(self, sample: np.ndarray, timestep: np.ndarray, encoder_hidden_stat
"encoder_hidden_states": encoder_hidden_states,
}

outputs = self.request.infer(inputs)
outputs = self.request(inputs)
return list(outputs.values())


@@ -488,7 +487,7 @@ def __call__(self, latent_sample: np.ndarray):
inputs = {
"latent_sample": latent_sample,
}
outputs = self.request.infer(inputs)
outputs = self.request(inputs)
return list(outputs.values())


86 changes: 28 additions & 58 deletions optimum/intel/openvino/modeling_seq2seq.py
@@ -16,11 +16,10 @@
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import openvino
import torch
import transformers
from openvino.runtime import Core, Tensor
from openvino.runtime import Core
from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
@@ -126,10 +125,6 @@
"""


def _contiguous_helper(tensor: np.ndarray) -> np.ndarray:
return tensor if tensor.flags["C_CONTIGUOUS"] else np.ascontiguousarray(tensor)


@add_start_docstrings(
"""
Sequence-to-sequence model with a language modeling head for OpenVINO inference.
@@ -287,10 +282,10 @@ def clear_requests(self):
self.decoder_with_past.request = None

def compile(self):
self.encoder._create_inference_request()
self.decoder._create_inference_request()
self.encoder._compile()
self.decoder._compile()
if self.use_cache:
self.decoder_with_past._create_inference_request()
self.decoder_with_past._compile()


class OVEncoder:
@@ -318,36 +313,29 @@ def forward(
attention_mask: torch.LongTensor = None,
**kwargs,
) -> BaseModelOutput:
self._create_inference_request()
self._compile()

# Check if inputs are c-like, if not - convert them
input_ids = _contiguous_helper(np.array(input_ids))
Collaborator:
It looks much clearer now! Could this also be integrated for OVModelForCausalLM from modeling_decoder.py as well? Happy to do it in a following PR if not.

Contributor Author:
I edited it as much as I could. Hopefully it's better now! :)

I am sorry for the late response here. However, OV is now bumped in setup.py and the features can be used from the official release.
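
Aside (not part of the diff): a minimal sketch of the 2023.0 call pattern the new encoder/decoder code relies on, assuming compiled_model is a CompiledModel obtained from core.compile_model(...) and that the input names below match the model. Plain NumPy arrays can be passed directly, shared_memory=True shares the input buffers with the runtime instead of copying them, and the returned OVDict exposes names(), where each entry is the set of tensor names attached to one output.

import numpy as np

# `compiled_model` is assumed to exist, e.g. compiled_model = core.compile_model(model, "CPU").
inputs = {
    "input_ids": np.ones((1, 8), dtype=np.int64),
    "attention_mask": np.ones((1, 8), dtype=np.int64),
}

# shared_memory=True shares the NumPy buffers with the runtime rather than copying them.
outputs = compiled_model(inputs, shared_memory=True)

# names() yields one set of tensor names per output; pick any name from the set to index the result.
for name_set in outputs.names():
    name = next(iter(name_set))
    print(name, outputs[name].shape)

This is the same names()/next(iter(...)) access used below to rebuild out_past_key_values in the decoder.
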


inputs = {
"input_ids": Tensor(input_ids, shared_memory=True),
}
# Model inputs
inputs = {"input_ids": input_ids}

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
attention_mask = _contiguous_helper(np.array(attention_mask))
inputs["attention_mask"] = Tensor(attention_mask, shared_memory=True)
inputs["attention_mask"] = attention_mask

# Run inference
self.request.start_async(inputs)
self.request.wait()

last_hidden_state = torch.from_numpy(self.request.get_tensor("last_hidden_state").data).to(self.device)
last_hidden_state = torch.from_numpy(self.request(inputs, shared_memory=True)["last_hidden_state"]).to(
self.device
)

return BaseModelOutput(last_hidden_state=last_hidden_state)

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def _create_inference_request(self):
def _compile(self):
if self.request is None:
logger.info("Compiling the encoder and creating the inference request ...")
compiled_model = core.compile_model(self.model, self._device, self.ov_config)
self.request = compiled_model.create_infer_request()
logger.info("Compiling the encoder...")
self.request = core.compile_model(self.model, self._device, self.ov_config)


class OVDecoder:
@@ -387,56 +375,39 @@ def forward(
encoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
) -> Seq2SeqLMOutput:
self._create_inference_request()

self._compile()
# Model inputs
inputs = {}

if past_key_values is not None:
# Flatten the past_key_values
past_key_values = tuple(
_contiguous_helper(np.array(past_key_value))
for pkv_per_layer in past_key_values
for past_key_value in pkv_per_layer
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)

# Add the past_key_values to the decoder inputs
inputs = {
input_name: Tensor(past_key_value, shared_memory=True)
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values)
}
inputs = dict(zip(self.key_value_input_names, past_key_values))

# Check if inputs are c-like, if not - convert them
input_ids = _contiguous_helper(np.array(input_ids))
inputs["input_ids"] = Tensor(input_ids, shared_memory=True)
inputs["input_ids"] = input_ids

# Add the encoder_attention_mask inputs when needed
if "encoder_attention_mask" in self.input_names and encoder_attention_mask is not None:
encoder_attention_mask = _contiguous_helper(np.array(encoder_attention_mask))
inputs["encoder_attention_mask"] = Tensor(encoder_attention_mask, shared_memory=True)
inputs["encoder_attention_mask"] = encoder_attention_mask

# Add the encoder_hidden_states inputs when needed
if "encoder_hidden_states" in self.input_names and encoder_hidden_states is not None:
encoder_hidden_states = _contiguous_helper(np.array(encoder_hidden_states))
inputs["encoder_hidden_states"] = Tensor(encoder_hidden_states, shared_memory=True)
inputs["encoder_hidden_states"] = encoder_hidden_states

# Run inference
self.request.start_async(inputs)
self.request.wait()

outputs = {}
for key, value in zip(self.request.model_outputs, self.request.outputs):
output_names = key.get_names()
output_name = "logits" if "logits" in output_names else next(iter(output_names))
outputs[output_name] = value.data

outputs = self.request(inputs, shared_memory=True)
logits = torch.from_numpy(outputs["logits"]).to(self.device)

# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
# self-attention layer and 2 to the cross-attention layer)
out_past_key_values = tuple(
torch.from_numpy(outputs[key]).to(self.device)
for key in outputs
if ("key_values" in key or "present" in key)
torch.from_numpy(outputs[next(iter(key))]).to(self.device)
for key in outputs.names()
if ("key_values" in next(iter(key)) or "present" in next(iter(key)))
)

# Tuple of tuple of length `n_layers`, with each tuple of length equal to:
@@ -458,8 +429,7 @@ def forward(
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def _create_inference_request(self):
def _compile(self):
if self.request is None:
logger.info("Compiling the decoder and creating the inference request ...")
compiled_model = core.compile_model(self.model, self._device, self.ov_config)
self.request = compiled_model.create_infer_request()
logger.info("Compiling the decoder...")
self.request = core.compile_model(self.model, self._device, self.ov_config)