Improved performance of decoders #354

Merged: 12 commits, Jun 21, 2023
38 changes: 18 additions & 20 deletions optimum/intel/openvino/modeling_decoder.py
@@ -237,6 +237,11 @@ def reshape(self, batch_size: int, sequence_length: int):
logger.warning("Static shapes are not supported for causal language model.")
return self

def compile(self):
if self.request is None:
super().compile()
self.request = self.request.create_infer_request()


@add_start_docstrings(
"""
@@ -273,7 +278,7 @@ def forward(
if past_key_values is not None:
# Flatten the past_key_values
past_key_values = tuple(
np.array(past_key_value) for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)
# Add the past_key_values to the decoder inputs
inputs = dict(zip(self.key_value_input_names, past_key_values))
@@ -301,15 +306,14 @@ def forward(
inputs["attention_mask"] = np.array(attention_mask)

# Run inference
outputs = self.request(inputs, shared_memory=True)
self.request.start_async(inputs, shared_memory=True)
self.request.wait()

logits = torch.from_numpy(outputs["logits"]).to(self.device)
logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)

if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
past_key_values = tuple(
torch.from_numpy(outputs[key]).to(self.device) for key in self.key_value_output_names
)
past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
past_key_values = tuple(
past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
@@ -345,13 +349,13 @@ def _reorder_cache(
[`~PreTrainedModel.beam_sample`] is called.
This is required to match `past_key_values` with the correct beam_idx at every generation step.
"""

if self.config.model_type == "bloom":
return self._reorder_cache_bloom(past_key_values, beam_idx)

# from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past_key_values
tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
)

# Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
@@ -365,16 +369,10 @@ def _reorder_cache_bloom(
"""
standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))

# Get a copy of `beam_idx` on all the devices where we need those indices.
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device)
for layer_past in past_key_values
for past_state in layer_past
}
reordered_past = tuple(
(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
np.take(layer_past[0], beam_idx, 0),
np.take(layer_past[1], beam_idx, 0),
)
for layer_past in standardized_past
)
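
Since the cached key/value states are kept as OpenVINO output arrays rather than torch tensors after this change, beam reordering switches from `index_select` to `np.take` along the batch axis. A minimal sketch of the equivalent operation; the shapes and beam indices below are arbitrary and chosen only for illustration:

```python
import numpy as np
import torch

# Hypothetical cached key state: (batch, num_heads, seq_len, head_dim)
past_state = np.random.rand(4, 2, 5, 8).astype(np.float32)
beam_idx = torch.tensor([2, 2, 0, 1])  # beams chosen at this generation step

# np.take along axis 0 mirrors torch's index_select(0, beam_idx)
reordered = np.take(past_state, beam_idx.numpy(), 0)
assert np.array_equal(reordered[0], past_state[2])
```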
@@ -392,8 +390,8 @@ def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
layer_past[0].reshape((batch_size_times_num_heads, head_dim, seq_length)),
layer_past[1].reshape((batch_size_times_num_heads, seq_length, head_dim)),
)
for layer_past in past_key_value
)
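
With the cache held as NumPy arrays, the Bloom cache conversions above use `ndarray.reshape` where the torch implementation used `Tensor.view`. A shape-only sketch of the layout change, with arbitrary dimensions:

```python
import numpy as np

batch_size, num_heads, seq_length, head_dim = 2, 4, 5, 8

# Standard layout -> Bloom layout: fold batch and heads into one leading dimension.
key = np.zeros((batch_size, num_heads, head_dim, seq_length), dtype=np.float32)
value = np.zeros((batch_size, num_heads, seq_length, head_dim), dtype=np.float32)

bloom_key = key.reshape((batch_size * num_heads, head_dim, seq_length))
bloom_value = value.reshape((batch_size * num_heads, seq_length, head_dim))

# The reverse reshape recovers the standard cache layout.
assert bloom_key.reshape((batch_size, num_heads, head_dim, seq_length)).shape == key.shape
```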
@@ -414,8 +412,8 @@ def _convert_to_standard_cache(
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
layer_past[0].reshape((batch_size, num_heads, head_dim, seq_length)),
layer_past[1].reshape((batch_size, num_heads, seq_length, head_dim)),
)
for layer_past in past_key_value
)
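Taken together, the decoder changes compile the model once, keep a single infer request created by `create_infer_request()`, run it with `start_async`/`wait` and `shared_memory=True`, and read outputs through `get_tensor(...).data` instead of materializing a results dict. A rough standalone sketch of that pattern, assuming a hypothetical IR file with an `input_ids` input and a `logits` output (names and the `shared_memory` keyword follow the OpenVINO runtime version targeted by this PR):

```python
import numpy as np
from openvino.runtime import Core

core = Core()
# Placeholder model path; any decoder IR with matching input/output names would do.
compiled = core.compile_model("decoder_model.xml", "CPU")
request = compiled.create_infer_request()  # created once, reused for every generation step

inputs = {
    "input_ids": np.array([[101, 2023, 2003]], dtype=np.int64),
    "attention_mask": np.ones((1, 3), dtype=np.int64),
}

# shared_memory=True lets the runtime read the NumPy buffers without copying them.
request.start_async(inputs, shared_memory=True)
request.wait()

# Reading .data from the output tensor avoids an extra conversion step.
logits = request.get_tensor("logits").data
```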
18 changes: 9 additions & 9 deletions optimum/intel/openvino/modeling_seq2seq.py
@@ -16,6 +16,7 @@
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import openvino
import torch
import transformers
@@ -249,7 +250,7 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
for layer_past in past:
# Cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past

@@ -355,6 +356,8 @@ def __init__(self, model: openvino.runtime.Model, device: str, ov_config: Dict):
self.device = torch.device("cpu")
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
self.key_value_output_names = [key for key in self.output_names if "key_values" in key or "present" in key]
is_legacy = any("past_key_values" in key.get_any_name() for key in self.model.outputs)

if len(self.key_value_input_names) > 0 and not is_legacy:
@@ -399,16 +402,13 @@ def forward(
inputs["encoder_hidden_states"] = encoder_hidden_states

# Run inference
outputs = self.request(inputs, shared_memory=True)
logits = torch.from_numpy(outputs["logits"]).to(self.device)
self.request.start_async(inputs, shared_memory=True)
self.request.wait()
logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)

# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
# self-attention layer and 2 to the cross-attention layer)
out_past_key_values = tuple(
torch.from_numpy(outputs[next(iter(key))]).to(self.device)
for key in outputs.names()
if ("key_values" in next(iter(key)) or "present" in next(iter(key)))
)
out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)

# Tuple of tuple of length `n_layers`, with each tuple of length equal to:
# * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
@@ -432,4 +432,4 @@ def __call__(self, *args, **kwargs):
def _compile(self):
if self.request is None:
logger.info("Compiling the decoder...")
self.request = core.compile_model(self.model, self._device, self.ov_config)
self.request = core.compile_model(self.model, self._device, self.ov_config).create_infer_request()
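
Both decoder variants now collect the names of the `present`/`key_values` outputs up front and regroup the flat tuple of cached tensors into one tuple per layer. A small, self-contained illustration of that grouping step; the output names and shapes are made up:

```python
import numpy as np

# Hypothetical flat ordering: layer 0 key, layer 0 value, layer 1 key, layer 1 value, ...
key_value_output_names = ["present.0.key", "present.0.value", "present.1.key", "present.1.value"]
flat_past = tuple(np.zeros((1, 4, 5, 8), dtype=np.float32) for _ in key_value_output_names)

num_pkv = 2  # key/value pair per decoder self-attention layer
past_key_values = tuple(flat_past[i : i + num_pkv] for i in range(0, len(flat_past), num_pkv))

assert len(past_key_values) == len(key_value_output_names) // num_pkv
assert all(len(layer) == num_pkv for layer in past_key_values)
```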
27 changes: 23 additions & 4 deletions optimum/intel/openvino/quantization.py
@@ -17,7 +17,7 @@
import logging
from itertools import chain
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import nncf
import openvino
@@ -31,7 +31,7 @@
from nncf.torch.initialization import PTInitializingDataLoader
from nncf.torch.nncf_network import NNCFNetwork
from openvino._offline_transformations import compress_quantize_weights_transformation
from openvino.runtime import Core
from openvino.runtime import Core, Tensor
from torch.onnx import export as onnx_export
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader, RandomSampler
@@ -237,14 +237,33 @@ def __call__(self, *args, **kwargs):
data_cache.append(*args)
return self.request(*args, *kwargs)

def infer(self, inputs: Any = None, shared_memory: bool = False):
data_cache.append(inputs)
return self.request.infer(inputs, shared_memory)

def start_async(
self,
inputs: Any = None,
userdata: Any = None,
shared_memory: bool = False,
):
data_cache.append(inputs)
self.request.infer(inputs, shared_memory)

def wait(self):
pass

def get_tensor(self, name: str):
return Tensor(self.request.results[name])

def __getattr__(self, attr):
if attr in self.__dict__:
return getattr(self, attr)
return getattr(self.request, attr)

self.model.request = InferRequestWrapper(self.model.request)
for i, data in enumerate(calibration_dataloader):
self.model.generate(**data, max_new_tokens=10)
for _, data in enumerate(calibration_dataloader):
Collaborator:
Not related to the PR, but what do you think about unifying how quantization is applied to causal language models depending on whether the user gives a torch.nn.Module or an OVBaseDecoderModel (the number of generation steps is currently not the same)? We could also instantiate an OVModel in the from_pretrained method when the given model is a PreTrainedModel.

Collaborator (author):
It is hard to accomplish this with the current NNCF PTQ API implementation we have for PyTorch. I think we should deprecate PTQ for PyTorch at some point because it also introduces ambiguity for the user about what workflow to use for quantization.

self.model.generate(**data, max_new_tokens=100)
Collaborator:
Is this modification intended to reduce the accuracy degradation resulting from quantization? If yes, what did you observe when varying this parameter?

Collaborator (author):
Still in the process of measuring this; I will update a bit later.

if len(data_cache) >= subset_size:
break
self.model.request = self.model.request.request
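
The quantization change works by wrapping the model's infer request so that every inference issued by `generate()` also records its inputs into `data_cache`, which later serves as the calibration set, while `get_tensor` re-exposes the cached results so generation proceeds normally. A stripped-down sketch of the interception idea; the class and the stand-in backend below are illustrative, not the exact wrapper from this PR:

```python
from typing import Any, Callable, Dict, List


class RecordingBackend:
    """Wraps a callable inference backend and records every set of inputs it receives."""

    def __init__(self, backend: Callable[[Dict[str, Any]], Dict[str, Any]], cache: List[Dict[str, Any]]):
        self.backend = backend
        self.cache = cache

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        self.cache.append(inputs)    # keep the inputs for calibration later
        return self.backend(inputs)  # still run the real inference


# Usage sketch: swap the wrapper in, run a few generation steps, stop once enough samples are cached.
data_cache: List[Dict[str, Any]] = []
real_backend = lambda inputs: {"logits": [0.0]}  # stand-in for the compiled model
wrapped = RecordingBackend(real_backend, data_cache)

for step_inputs in ({"input_ids": [[1, 2, 3]]}, {"input_ids": [[1, 2, 3, 4]]}):
    wrapped(step_inputs)

assert len(data_cache) == 2  # both calls were recorded
```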