
Fix inference for causal lm models #351

Merged: 14 commits, Jun 15, 2023
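This PR extends `OVModelForCausalLM` to additional decoder architectures (bart, blenderbot, blenderbot-small, bloom, gpt2, gpt_neo, gpt_neox, llama, marian, opt, pegasus). The exported config is forced to `is_decoder=True` / `is_encoder_decoder=False`, unsupported model types now raise an explicit error, and the attention-mask preparation of bloom, llama, and the bart-like decoders is patched before the ONNX export. Tests and test dependencies are extended accordingly.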
29 changes: 26 additions & 3 deletions optimum/intel/openvino/modeling_decoder.py
@@ -71,6 +71,21 @@
```
"""

_SUPPORTED_ARCHITECTURES = (
"bart",
"blenderbot",
"blenderbot-small",
"bloom",
# "codegen",
"gpt2",
"gpt_neo",
"gpt_neox",
"llama",
"marian",
"opt",
"pegasus",
)


@add_start_docstrings(
"""
@@ -153,15 +168,23 @@ def _from_transformers(
"trust_remote_code": trust_remote_code,
}
model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
config.is_decoder = True
config.is_encoder_decoder = False
if config.model_type not in _SUPPORTED_ARCHITECTURES:
raise ValueError(
f"Unrecognized architecture : {config.model_type}, only :{', '.join(_SUPPORTED_ARCHITECTURES)} architectures are supported."
)

onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_constructor(model.config, use_past=use_cache)

# TODO: create a ModelPatcher to patch each architecture
if config.model_type == "bloom":
model.transformer._prepare_attn_mask = _prepare_attn_mask

elif config.model_type == "llama":
model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
elif config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask

# Export the model to the ONNX format
export(model=model, config=onnx_config, output=save_dir_path / model_file_name)
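For context, a minimal usage sketch of the resulting API, mirroring what the integration tests below exercise. The checkpoint name is only an example; any architecture listed in `_SUPPORTED_ARCHITECTURES` should work the same way:

```python
# Sketch: export a decoder-only checkpoint to OpenVINO on the fly and generate.
# "facebook/opt-125m" is an example id, not one used by this PR's tests.
from transformers import AutoTokenizer
from optimum.intel.openvino import OVModelForCausalLM

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, from_transformers=True)

tokens = tokenizer("This is a sample input", return_tensors="pt")
outputs = model.generate(**tokens, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```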
12 changes: 11 additions & 1 deletion setup.py
@@ -20,7 +20,17 @@
"accelerate", # transformers 4.29 require accelerate for PyTorch
]

TESTS_REQUIRE = ["pytest", "parameterized", "Pillow", "evaluate", "diffusers", "py-cpuinfo"]
TESTS_REQUIRE = [
"pytest",
"parameterized",
"Pillow",
"evaluate",
"diffusers",
"py-cpuinfo",
"sacremoses",
"torchaudio",
"rjieba",
]

QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"]

151 changes: 137 additions & 14 deletions tests/openvino/test_modeling.py
@@ -38,6 +38,7 @@
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoTokenizer,
GenerationConfig,
PretrainedConfig,
pipeline,
set_seed,
@@ -76,27 +77,71 @@


MODEL_NAMES = {
"bart": "hf-internal-testing/tiny-random-bart",
"albert": "hf-internal-testing/tiny-random-albert",
"audio_spectrogram_transformer": "Ericwang/tiny-random-ast",
"beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
"bert": "hf-internal-testing/tiny-random-bert",
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"bart": "hf-internal-testing/tiny-random-bart",
"bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus",
"blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel",
"blenderbot": "hf-internal-testing/tiny-random-blenderbot",
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"camembert": "hf-internal-testing/tiny-random-camembert",
"convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification",
"codegen": "hf-internal-testing/tiny-random-CodeGenModel",
"data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel",
"data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel",
"data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
"deberta": "hf-internal-testing/tiny-random-deberta",
"deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
"deit": "hf-internal-testing/tiny-random-deit",
"convnext": "hf-internal-testing/tiny-random-convnext",
"distilbert": "hf-internal-testing/tiny-random-distilbert",
"electra": "hf-internal-testing/tiny-random-electra",
"flaubert": "hf-internal-testing/tiny-random-flaubert",
# "gpt_bigcode": "bigcode/tiny_starcoder_py",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"llama": "trl-internal-testing/tiny-random-LlamaForCausalLM",
"marian": "sshleifer/tiny-marian-en-de",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"ibert": "hf-internal-testing/tiny-random-ibert",
"levit": "hf-internal-testing/tiny-random-LevitModel",
"longt5": "hf-internal-testing/tiny-random-LongT5Model",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"opt": "hf-internal-testing/tiny-random-OPTModel",
"marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
"mbart": "hf-internal-testing/tiny-random-mbart",
"m2m_100": "valhalla/m2m100_tiny_random",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet_v1": "google/mobilenet_v1_0.75_192",
"mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"mt5": "hf-internal-testing/tiny-random-mt5",
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"pegasus": "hf-internal-testing/tiny-random-pegasus",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-roberta",
"roformer": "hf-internal-testing/tiny-random-roformer",
"segformer": "hf-internal-testing/tiny-random-SegformerModel",
"squeezebert": "hf-internal-testing/tiny-random-squeezebert",
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
"sew": "hf-internal-testing/tiny-random-SEWModel",
"sew_d": "hf-internal-testing/tiny-random-SEWDModel",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"t5": "hf-internal-testing/tiny-random-t5",
"unispeech": "hf-internal-testing/tiny-random-unispeech",
"unispeech_sat": "hf-internal-testing/tiny-random-UnispeechSatModel",
"vit": "hf-internal-testing/tiny-random-vit",
"wavlm": "hf-internal-testing/tiny-random-WavlmModel",
"wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",
"wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer",
"xlm": "hf-internal-testing/tiny-random-xlm",
"xlm_roberta": "hf-internal-testing/tiny-xlm-roberta",
}


TENSOR_ALIAS_TO_TYPE = {
"pt": torch.Tensor,
"np": np.ndarray,
@@ -207,9 +252,23 @@ def test_load_from_hub_and_save_stable_diffusion_model(self):

class OVModelForSequenceClassificationIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
"albert",
"bert",
# "camembert",
"convbert",
# "data2vec_text",
# "deberta_v2",
"distilbert",
"electra",
"flaubert",
"ibert",
# "mobilebert",
# "nystromformer",
"roberta",
"roformer",
"squeezebert",
"xlm",
# "xlm_roberta",
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -415,11 +474,19 @@ def test_pipeline(self, model_arch):

class OVModelForCausalLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
"bart",
"blenderbot",
"blenderbot-small",
"bloom",
# "codegen",
# "data2vec-text", # TODO : enable when enabled in exporters
"gpt2",
"gpt_neo",
"gpt_neox",
# "llama",
"marian",
"opt",
"pegasus",
)
GENERATION_LENGTH = 100
SPEEDUP_CACHE = 1.2
@@ -449,6 +516,7 @@ def test_pipeline(self, model_arch):
model_id = MODEL_NAMES[model_arch]
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, from_transformers=True, use_cache=False, compile=False)
model.config.encoder_no_repeat_ngram_size = 0
model.to("cpu")
model.half()
model.compile()
@@ -467,7 +535,8 @@ def test_multiple_inputs(self, model_arch):
tokenizer.pad_token = tokenizer.eos_token
texts = ["this is a simple input", "this is a second simple input", "this is a third simple input"]
tokens = tokenizer(texts, padding=True, return_tensors="pt")
generation_config = GenerationConfig(encoder_no_repeat_ngram_size=0, max_new_tokens=20, num_beams=2)
outputs = model.generate(**tokens, generation_config=generation_config)
self.assertIsInstance(outputs, torch.Tensor)
self.assertEqual(outputs.shape[0], 3)
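The `GenerationConfig` change above addresses a side effect of the export: checkpoints from encoder-decoder families (bart, blenderbot, pegasus, marian) ship generation defaults such as `encoder_no_repeat_ngram_size=3`, and since the exported model now reports `is_encoder_decoder=False`, transformers' `generate()` rejects that option for decoder-only architectures. Resetting it to 0, here and via `model.config` in `test_pipeline` above, keeps generation working.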

@@ -511,9 +580,23 @@ def test_compare_with_and_without_past_key_values(self):

class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
# "albert",
"bert",
# "camembert",
# "convbert",
# "data2vec_text",
# "deberta",
# "deberta_v2",
"distilbert",
# "electra",
# "flaubert",
# "ibert",
# "mobilebert",
"roberta",
# "roformer",
# "squeezebert",
# "xlm",
# "xlm_roberta",
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -550,7 +633,21 @@ def test_pipeline(self, model_arch):


class OVModelForImageClassificationIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("vit",)
SUPPORTED_ARCHITECTURES = (
"beit",
"convnext",
# "data2vec_vision",
# "deit",
"levit",
"mobilenet_v1",
"mobilenet_v2",
"mobilevit",
# "poolformer",
"resnet",
# "segformer",
# "swin",
"vit",
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
@@ -590,9 +687,15 @@ def test_pipeline(self, model_arch):
class OVModelForSeq2SeqLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
"bart",
# "bigbird_pegasus",
"blenderbot",
"blenderbot-small",
# "longt5",
"m2m_100",
"marian",
"mbart",
"m2m_100",
# "mt5",
"pegasus",
"t5",
)

@@ -710,7 +813,26 @@ def test_compare_with_and_without_past_key_values(self):


class OVModelForAudioClassificationIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("wav2vec2",)
SUPPORTED_ARCHITECTURES = (
# "audio_spectrogram_transformer",
# "data2vec_audio",
# "hubert",
# "sew",
# "sew_d",
# "wav2vec2-conformer",
"unispeech",
# "unispeech_sat",
# "wavlm",
"wav2vec2",
# "wav2vec2-conformer",
)

def _generate_random_audio_data(self):
np.random.seed(10)
t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False)
# generate pure sine wave at 220 Hz
audio_data = 0.5 * np.sin(2 * np.pi * 220 * t)
return audio_data
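Generating a fixed 220 Hz sine wave instead of the previous `np.random` input keeps the audio tests deterministic and feeds the identical waveform to both the OpenVINO and PyTorch models for the `pt` and `np` comparison below.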

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
@@ -720,12 +842,13 @@ def test_compare_to_transformers(self, model_arch):
self.assertIsInstance(ov_model.config, PretrainedConfig)
transformers_model = AutoModelForAudioClassification.from_pretrained(model_id)
preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
inputs = preprocessor(self._generate_random_audio_data(), return_tensors="pt")

with torch.no_grad():
transformers_outputs = transformers_model(**inputs)

for input_type in ["pt", "np"]:
inputs = preprocessor(self._generate_random_audio_data(), return_tensors=input_type)
ov_outputs = ov_model(**inputs)
self.assertIn("logits", ov_outputs)
self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])