Create infer request per inference to enable concurrency #494

Draft: wants to merge 3 commits into base: main
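The change replaces the single cached `self.request` (a compiled model or shared infer request) with a `self.compiled_model` attribute and creates a fresh `InferRequest` inside each `forward`/`__call__`. The point is that one `openvino.runtime.CompiledModel` can safely hand out many infer requests, so independent callers no longer serialize on a single shared request. A minimal standalone sketch of the pattern (the model path and output name are illustrative, not taken from this PR):

```python
import openvino.runtime as ov

core = ov.Core()
# Illustrative IR path; any compiled model works the same way.
compiled_model = core.compile_model("model.xml", "CPU")

def run_once(inputs: dict):
    # One InferRequest per call: each request owns its input/output tensors,
    # so concurrent callers cannot overwrite each other's buffers.
    request = compiled_model.create_infer_request()
    request.start_async(inputs, shared_memory=True)  # same call style as this PR
    request.wait()
    return request.get_tensor("logits").data  # assumes an output named "logits"
```

Creating a request per call trades a small per-call setup cost for thread safety; pooling requests (see the sketch after the modeling_base.py diff) is one way to recover that cost.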
81 changes: 56 additions & 25 deletions optimum/intel/openvino/modeling.py
@@ -15,6 +15,7 @@
import os
from pathlib import Path
from typing import Optional, Union
import queue

import numpy as np
import openvino
@@ -130,7 +131,8 @@ def to(self, device: str):
be in upper or lower case. To speed up first inference, call `.compile()` after `.to()`.
"""
self._device = device.upper()
self.request = None
self.compiled_model = None

return self

def forward(self, *args, **kwargs):
@@ -197,8 +199,11 @@ def forward(
inputs["token_type_ids"] = token_type_ids

# Run inference
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
logits = torch.from_numpy(infer_request.outputs["logits"]).to(self.device) if not np_inputs else infer_request.outputs["logits"]

return SequenceClassifierOutput(logits=logits)


@@ -263,13 +268,16 @@ def forward(
inputs["token_type_ids"] = token_type_ids

# Run inference
outputs = self.request(inputs)
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
start_logits = (
torch.from_numpy(outputs["start_logits"]).to(self.device) if not np_inputs else outputs["start_logits"]
torch.from_numpy(infer_request.outputs["start_logits"]).to(self.device) if not np_inputs else infer_request.outputs["start_logits"]
)
end_logits = (
torch.from_numpy(outputs["end_logits"]).to(self.device) if not np_inputs else outputs["end_logits"]
torch.from_numpy(infer_request.outputs["end_logits"]).to(self.device) if not np_inputs else infer_request.outputs["end_logits"]
)

return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)


@@ -333,8 +341,11 @@ def forward(
inputs["token_type_ids"] = token_type_ids

# Run inference
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
logits = torch.from_numpy(infer_request.outputs["logits"]).to(self.device) if not np_inputs else infer_request.outputs["logits"]

return TokenClassifierOutput(logits=logits)


@@ -398,12 +409,15 @@ def forward(
inputs["token_type_ids"] = token_type_ids

# Run inference
outputs = self.request(inputs)
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
last_hidden_state = (
torch.from_numpy(outputs["last_hidden_state"]).to(self.device)
torch.from_numpy(infer_request.outputs["last_hidden_state"]).to(self.device)
if not np_inputs
else outputs["last_hidden_state"]
else infer_request.outputs["last_hidden_state"]
)

return BaseModelOutput(last_hidden_state=last_hidden_state)


@@ -468,8 +482,11 @@ def forward(
inputs["token_type_ids"] = token_type_ids

# Run inference
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
logits = torch.from_numpy(infer_request.outputs["logits"]).to(self.device) if not np_inputs else infer_request.outputs["logits"]

return MaskedLMOutput(logits=logits)


@@ -595,8 +612,11 @@ def forward(
}

# Run inference
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
logits = torch.from_numpy(infer_request.outputs["logits"]).to(self.device) if not np_inputs else infer_request.outputs["logits"]

return ImageClassifierOutput(logits=logits)


@@ -660,8 +680,11 @@ def forward(
inputs["attention_mask"] = attention_mask

# Run inference
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
logits = torch.from_numpy(infer_request.outputs["logits"]).to(self.device) if not np_inputs else infer_request.outputs["logits"]

return SequenceClassifierOutput(logits=logits)


@@ -732,8 +755,11 @@ def forward(
inputs["attention_mask"] = attention_mask

# Run inference
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
logits = torch.from_numpy(infer_request.outputs["logits"]).to(self.device) if not np_inputs else infer_request.outputs["logits"]

return CausalLMOutput(logits=logits)


@@ -813,11 +839,14 @@ def forward(
inputs["attention_mask"] = attention_mask

# Run inference
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
logits = torch.from_numpy(infer_request.outputs["logits"]).to(self.device) if not np_inputs else infer_request.outputs["logits"]
embeddings = (
torch.from_numpy(outputs["embeddings"]).to(self.device) if not np_inputs else outputs["embeddings"]
torch.from_numpy(infer_request.outputs["embeddings"]).to(self.device) if not np_inputs else infer_request.outputs["embeddings"]
)


return XVectorOutput(logits=logits, embeddings=embeddings)

@@ -890,7 +919,9 @@ def forward(
inputs["attention_mask"] = attention_mask

# Run inference
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]

infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
logits = torch.from_numpy(infer_request.outputs["logits"]).to(self.device) if not np_inputs else infer_request.outputs["logits"]

return TokenClassifierOutput(logits=logits)
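With every task-specific `forward` above rewritten this way, a single exported model can, in principle, serve several threads at once. A hedged usage sketch (the checkpoint name and thread count are illustrative, not part of this PR):

```python
from concurrent.futures import ThreadPoolExecutor

from transformers import AutoTokenizer
from optimum.intel import OVModelForSequenceClassification

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForSequenceClassification.from_pretrained(model_id, export=True)

def classify(text: str):
    inputs = tokenizer(text, return_tensors="pt")
    # Each call creates its own InferRequest internally (per this PR),
    # so concurrent calls do not share input/output tensors.
    return model(**inputs).logits

texts = ["great movie", "terrible plot", "just okay"] * 4
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(classify, texts))
```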
14 changes: 9 additions & 5 deletions optimum/intel/openvino/modeling_base.py
@@ -17,6 +17,7 @@
from pathlib import Path
from tempfile import TemporaryDirectory, gettempdir
from typing import Dict, Optional, Union
import queue

import openvino
from huggingface_hub import hf_hub_download
@@ -88,7 +89,8 @@ def __init__(
self.output_names = output_names

self.model = model
self.request = None
self.compiled_model = None

if enable_compilation:
self.compile()

@@ -337,15 +339,15 @@ def _to_load(
)

def compile(self):
if self.request is None:
if self.compiled_model is None:
logger.info(f"Compiling the model to {self._device} ...")
ov_config = {**self.ov_config}
if "CACHE_DIR" not in self.ov_config.keys() and not str(self.model_save_dir).startswith(gettempdir()):
# Set default CACHE_DIR only if it is not set, and if the model is not in a temporary directory
cache_dir = Path(self.model_save_dir).joinpath("model_cache")
ov_config["CACHE_DIR"] = str(cache_dir)
logger.info(f"Setting OpenVINO CACHE_DIR to {str(cache_dir)}")
self.request = core.compile_model(self.model, self._device, ov_config)
self.compiled_model = core.compile_model(self.model, self._device, ov_config)

def _reshape(
self,
@@ -383,7 +385,8 @@ def reshape(self, batch_size: int, sequence_length: int, height: int = None, wid
"""
self.is_dynamic = True if batch_size == -1 and sequence_length == -1 else False
self.model = self._reshape(self.model, batch_size, sequence_length, height, width)
self.request = None
self.compiled_model = None

return self

def half(self):
@@ -392,7 +395,8 @@ def half(self):
"""
apply_moc_transformations(self.model, cf=False)
compress_model_transformation(self.model)
self.request = None
self.compiled_model = None

return self

def forward(self, *args, **kwargs):
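Note that the diff adds `import queue` in several modules without using it yet. One plausible follow-up, purely an assumption on my part and not something this PR implements, is to pool infer requests so hot paths avoid the per-call creation cost while keeping concurrency. A sketch of such a pool:

```python
import queue


class InferRequestPool:
    """Hypothetical pool of InferRequests over one CompiledModel (not part of this PR)."""

    def __init__(self, compiled_model, size: int = 4):
        self._outputs = compiled_model.outputs
        self._free = queue.Queue()
        for _ in range(size):
            self._free.put(compiled_model.create_infer_request())

    def infer(self, inputs: dict) -> dict:
        request = self._free.get()  # blocks until a request is available
        try:
            request.start_async(inputs, shared_memory=True)
            request.wait()
            return {out.get_any_name(): request.get_tensor(out).data for out in self._outputs}
        finally:
            self._free.put(request)  # hand the request back to the pool
```

OpenVINO's own `AsyncInferQueue` covers a similar need for batched asynchronous workloads.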
17 changes: 10 additions & 7 deletions optimum/intel/openvino/modeling_decoder.py
@@ -17,6 +17,7 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, Optional, Tuple, Union
import queue

import numpy as np
import openvino
@@ -182,7 +183,8 @@ def update_pkv_precision(self, force_fp32=False):
self.model = self._original_model.clone()
if self.is_dynamic:
self.model = self._reshape(self.model, -1, -1)
self.request = None
self.compiled_model = None


def _save_pretrained(self, save_directory: Union[str, Path]):
"""
@@ -283,9 +285,8 @@ def reshape(self, batch_size: int, sequence_length: int):
return self

def compile(self):
if self.request is None:
if self.compiled_model is None:
super().compile()
self.request = self.request.create_infer_request()


@add_start_docstrings(
Expand Down Expand Up @@ -385,13 +386,15 @@ def forward(
inputs["position_ids"] = position_ids

# Run inference
self.request.start_async(inputs, shared_memory=True)
self.request.wait()
logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
logits = torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device)

if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
past_key_values = tuple(infer_request.get_tensor(key).data for key in self.key_value_output_names)

if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
past_key_values = tuple(
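For the decoder, `compile()` previously wrapped the compiled model in one shared `InferRequest`; after this change each `forward` builds its own, so concurrent `generate()` calls on a single `OVModelForCausalLM` should no longer contend for one request. A usage sketch under that assumption (the checkpoint and generation settings are illustrative):

```python
from concurrent.futures import ThreadPoolExecutor

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "gpt2"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

def complete(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=16)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

with ThreadPoolExecutor(max_workers=2) as pool:
    print(list(pool.map(complete, ["OpenVINO is", "Concurrent inference lets"])))
```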
40 changes: 29 additions & 11 deletions optimum/intel/openvino/modeling_diffusion.py
@@ -19,6 +19,7 @@
from pathlib import Path
from tempfile import TemporaryDirectory, gettempdir
from typing import Any, Dict, List, Optional, Union
import copy
Contributor commented on the added "import copy": Unused?


import numpy as np
import openvino
@@ -528,7 +529,8 @@ def __init__(
for inputs in self.model.inputs
}
self.ov_config = ov_config or {**self.parent_model.ov_config}
self.request = None
self.compiled_model = None

self._model_name = model_name
self._model_dir = Path(model_dir or parent_model._model_save_dir)
config_path = self._model_dir / model_name / self.CONFIG_NAME
@@ -537,9 +539,9 @@ def __init__(
self.ov_config["CACHE_DIR"] = os.path.join(self._model_dir, self._model_name, "model_cache")

def _compile(self):
if self.request is None:
if self.compiled_model is None:
logger.info(f"Compiling the {self._model_name} to {self.device} ...")
self.request = core.compile_model(self.model, self.device, self.ov_config)
self.compiled_model = core.compile_model(self.model, self.device, self.ov_config)

@property
def device(self):
@@ -562,8 +564,12 @@ def __call__(self, input_ids: np.ndarray):
inputs = {
"input_ids": input_ids,
}
outputs = self.request(inputs, shared_memory=True)
return list(outputs.values())

infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
outputs = [output.data for output in infer_request.outputs]
return outputs


class OVModelUnet(OVModelPart):
@@ -596,8 +602,12 @@ def __call__(
if timestep_cond is not None:
inputs["timestep_cond"] = timestep_cond

outputs = self.request(inputs, shared_memory=True)
return list(outputs.values())
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
outputs = [output.data for output in infer_request.outputs]

return outputs


class OVModelVaeDecoder(OVModelPart):
@@ -612,8 +622,12 @@ def __call__(self, latent_sample: np.ndarray):
inputs = {
"latent_sample": latent_sample,
}
outputs = self.request(inputs, shared_memory=True)
return list(outputs.values())
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
outputs = [output.data for output in infer_request.outputs]

return outputs

def _compile(self):
if "GPU" in self.device:
@@ -633,8 +647,12 @@ def __call__(self, sample: np.ndarray):
inputs = {
"sample": sample,
}
outputs = self.request(inputs, shared_memory=True)
return list(outputs.values())
infer_request = self.compiled_model.create_infer_request()
infer_request.start_async(inputs, shared_memory=True)
infer_request.wait()
outputs = [output.data for output in infer_request.outputs]

return outputs

def _compile(self):
if "GPU" in self.device: