Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix OV model for BLOOM architecture #340

Merged
merged 14 commits into from
Jun 9, 2023
8 changes: 8 additions & 0 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from ..utils.import_utils import is_transformers_version
from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
from .modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
from .utils import ONNX_WEIGHTS_NAME


Expand Down Expand Up @@ -155,6 +156,13 @@ def _from_transformers(
onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_constructor(model.config, use_past=use_cache)

# TODO : create ModelPatcher to patch each architecture
if model.config.model_type == "bloom":
model.transformer._prepare_attn_mask = _prepare_attn_mask

if model.config.model_type == "llama":
model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask

# Export the model to the ONNX format
export(model=model, config=onnx_config, output=save_dir_path / model_file_name)

Expand Down
91 changes: 91 additions & 0 deletions optimum/intel/openvino/modeling_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch


# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size,
device: torch.device,
past_key_values_length: int,
dtype: torch.dtype = torch.bool,
) -> torch.BoolTensor:
"""
Make causal mask used for bi-directional self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device)
seq_ids = torch.arange(target_length, device=device)

mask[:, past_key_values_length:] = (
(seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min
if torch.is_floating_point(mask)
else seq_ids[:, None] < seq_ids[None, :]
)

return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)


# Modified from transformers.models..bloom.modeling_bloom._prepare_attn_mask
def _prepare_attn_mask(
attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
from transformers.models.bloom.modeling_bloom import _expand_mask

# create causal mask
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape

combined_attention_mask = _make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]_prepare_decoder_attention_mask
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
)

return combined_attention_mask


# Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length):
from transformers.models.llama.modeling_llama import _expand_mask

# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None

combined_attention_mask = _make_causal_mask(
input_shape,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
dtype=inputs_embeds.dtype,
)

if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)

return combined_attention_mask
10 changes: 7 additions & 3 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,12 @@
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus",
"distilbert": "hf-internal-testing/tiny-random-distilbert",
# "gpt_bigcode": "bigcode/tiny_starcoder_py",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"llama": "trl-internal-testing/tiny-random-LlamaForCausalLM",
"marian": "sshleifer/tiny-marian-en-de",
"mbart": "hf-internal-testing/tiny-random-mbart",
"m2m_100": "valhalla/m2m100_tiny_random",
Expand Down Expand Up @@ -417,6 +419,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"gpt2",
"gpt_neo",
"gpt_neox",
# "llama",
)
GENERATION_LENGTH = 100
SPEEDUP_CACHE = 1.2
Expand All @@ -429,15 +432,16 @@ def test_compare_to_transformers(self, model_arch):
self.assertIsInstance(ov_model.config, PretrainedConfig)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample", return_tensors="pt")
tokens = tokenizer(
"This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
)
ov_outputs = ov_model(**tokens)
self.assertTrue("logits" in ov_outputs)
self.assertIsInstance(ov_outputs.logits, torch.Tensor)
with torch.no_grad():
transformers_outputs = transformers_model(**tokens)
# Compare tensor outputs
atol = 1e-1 if model_arch == "bloom" else 1e-4
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=atol))
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
Expand Down