Issue to use GPT2 ONNX export with past key values #552

Closed
2 of 4 tasks
jplu opened this issue Dec 6, 2022 · 13 comments · Fixed by #553 or #554
Labels
bug Something isn't working


@jplu
Contributor

jplu commented Dec 6, 2022

System Info

python: 3.10.6
platform: Ubuntu 22.10
optimum version: 1.5.1
onnxruntime: 1.13.1

Who can help?

@JingyaHuang @ec

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Command line to export a GPT2 model:

python -m optimum.exporters.onnx --model gpt2 --task causal-lm-with-past output/

Gives the following output logs:

Framework not specified. Using pt to export to ONNX.
Using framework PyTorch: 1.13.0+cu117
Overriding 2 configuration item(s)
	- use_cache -> True
	- pad_token_id -> 0
/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:796: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if batch_size <= 0:
/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:185: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  attn_weights = attn_weights / torch.tensor(
/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:185: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  attn_weights = attn_weights / torch.tensor(
/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:200: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
Validating ONNX model...
	-[✓] ONNX model output names match reference model (present.1.value, present.0.key, present.6.key, present.6.value, present.5.value, present.8.key, present.0.value, present.2.key, present.5.key, present.10.key, present.9.value, present.10.value, logits, present.4.value, present.7.key, present.11.value, present.3.value, present.3.key, present.4.key, present.2.value, present.1.key, present.9.key, present.11.key, present.8.value, present.7.value)
	- Validating ONNX Model output "logits":
		-[✓] (2, 16, 50257) matches (2, 16, 50257)
		-[x] values not close enough, max diff: 0.0013427734375 (atol: 1e-05)
	- Validating ONNX Model output "present.0.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.0.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.1.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.1.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.2.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.2.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.3.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.3.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.4.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.4.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.5.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.5.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.6.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.6.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.7.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.7.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.8.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.8.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.9.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.9.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.10.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.10.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.11.key":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "present.11.value":
		-[✓] (2, 12, 32, 64) matches (2, 12, 32, 64)
		-[✓] all values close (atol: 1e-05)
An error occured, but the model was saved at: model_repository/gpt2/1/model.onnx

Even though there is an error in the values-close validation, that's OK. Now I would like to run the model with the following Python code:

from optimum.onnxruntime import ORTModelForCausalLM
from transformers import GPT2Tokenizer

model = ORTModelForCausalLM.from_pretrained("output/", from_transformers=False, use_cache=True)
tokenizer = GPT2Tokenizer.from_pretrained("output/")
tokens = tokenizer("My name is Julien and I like", return_tensors="pt")
outputs_model = model.generate(**tokens)

And I get the following error:

/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/generation_utils.py:1359: UserWarning: Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to 20 (`self.config.max_length`). Controlling `max_length` via the config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/generation_utils.py", line 1490, in generate
    return self.greedy_search(
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/generation_utils.py", line 2233, in greedy_search
    outputs = self(
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/optimum/modeling_base.py", line 60, in __call__
    return self.forward(*args, **kwargs)
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/optimum/onnxruntime/modeling_ort.py", line 1454, in forward
    outputs = self.model.run(None, onnx_inputs)
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 196, in run
    raise ValueError("Model requires {} inputs. Input Feed contains {}".format(num_required_inputs, num_inputs))
ValueError: Model requires 26 inputs. Input Feed contains 2

Do I have to feed the past_key_values.X.key and past_key_values.X.value inputs myself with random values?

When I try to do this directly with onnxruntime, I also get an error. Here is what I do:

import onnxruntime as ort
from transformers import GPT2Tokenizer
import numpy as np

sess = ort.InferenceSession('output/model.onnx', providers=["CPUExecutionProvider"])
tokenizer = GPT2Tokenizer.from_pretrained("output/")
tokens = dict(tokenizer("My name is Julien and I like", return_tensors="np"))
shape = (1, 12, len(tokens["input_ids"][0]), 64)

for i in range(12):
    tokens[f"past_key_values.{i}.key"] = np.random.uniform(0, 1, shape).astype(np.float32)
    tokens[f"past_key_values.{i}.value"] = np.random.uniform(0, 1, shape).astype(np.float32)

sess.run(None, tokens)

And I get the following error:

2022-12-06 16:42:17.603173515 [E:onnxruntime:, sequential_executor.cc:369 Execute] Non-zero status code returned while running Add node. Name:'/transformer/h.0/attn/Add' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:503 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 16

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Add node. Name:'/transformer/h.0/attn/Add' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:503 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 16

Expected behavior

I expect generation to work properly, both through optimum and directly with onnxruntime. The final goal is to use it through a Triton server.

I'm certainly missing something, but the documentation is not clear on how to properly use seq2seq and causal-lm models with past key values, either directly with onnxruntime or with optimum.

Thanks a lot in advance for any advice you can provide :)

@jplu jplu added the bug Something isn't working label Dec 6, 2022
@fxmarty
Collaborator

fxmarty commented Dec 6, 2022

@jplu Are you using the Optimum release version (1.5.1)? If so, directly feeding a model exported with optimum.exporters.onnx to ORTModelForCausalLM will not work.

Can you instead try directly: model = ORTModelForCausalLM.from_pretrained("pytorch-model-repo-or-folder", from_transformers=True, use_cache=True). It is the from_transformers=True that matters here; it will internally handle the export from PyTorch to ONNX (encoder and decoder separately).
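
A minimal sketch of this from_transformers=True path, assuming the public gpt2 checkpoint from the Hub (the save_pretrained calls to keep the exported files around are an assumption, mirroring other ORTModels):

import numpy as np
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import GPT2Tokenizer

# Export from the PyTorch weights on the fly and load the result into ONNX Runtime.
model = ORTModelForCausalLM.from_pretrained("gpt2", from_transformers=True, use_cache=True)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Optionally persist the exported ONNX files and the tokenizer for later reuse
# (assumption: ORTModelForCausalLM exposes save_pretrained like other ORTModels).
model.save_pretrained("gpt2_onnx/")
tokenizer.save_pretrained("gpt2_onnx/")

tokens = tokenizer("My name is Julien and I like", return_tensors="pt")
outputs = model.generate(**tokens)
print(tokenizer.decode(outputs[0]))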

If you would rather go through optimum.exporters.onnx first and then load with from_transformers=False (i.e. directly from ONNX files), I suggest waiting for the next release of Optimum, which will include #497 and allow using the --for-ort argument to export seq2seq models to ONNX files that can be smoothly consumed by the ORTModel integration.

So using your command, that would look like:

python -m optimum.exporters.onnx --model gpt2 --for-ort --task causal-lm-with-past output/

@jplu
Contributor Author

jplu commented Dec 6, 2022

Oh I see! This is crystal clear, thanks a lot for the clarification. I will wait for the next release then. Any ETA?

Last question: will it be possible to use the exported ONNX files generated by the last command you gave directly through ONNX Runtime? I guess the error I get now comes from the same problem, right?

@fxmarty
Collaborator

fxmarty commented Dec 6, 2022

I think this week or next week is a good bet!

Yes, you'll be able to use the exported ONNX files directly through ONNX Runtime. What the ORTModel integration helps with is the generation loop, which can be a bit involved to reimplement yourself, but you definitely can! A bit more detail here: https://discuss.huggingface.co/t/export-m2m100-model-to-onnx/17694/11?u=fxmarty
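
As a rough illustration of what that reimplementation can look like, here is a minimal greedy-decoding sketch directly on top of ONNX Runtime. It assumes a GPT-2 export without past key values (e.g. --task causal-lm, saved under a hypothetical gpt2_no_past/ folder), so the whole sequence is re-run at every step; the with-past variant is faster but needs the cache bookkeeping discussed later in this thread.

import numpy as np
import onnxruntime as ort
from transformers import GPT2Tokenizer

sess = ort.InferenceSession("gpt2_no_past/model.onnx", providers=["CPUExecutionProvider"])  # hypothetical path
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

feed = dict(tokenizer("My name is Julien and I like", return_tensors="np"))
logits_index = [o.name for o in sess.get_outputs()].index("logits")

for _ in range(20):  # generate at most 20 new tokens
    logits = sess.run(None, feed)[logits_index]
    next_token = logits[:, -1, :].argmax(-1)[:, None].astype(np.int64)  # greedy pick
    feed["input_ids"] = np.concatenate([feed["input_ids"], next_token], axis=1)
    feed["attention_mask"] = np.ones_like(feed["input_ids"])
    if next_token.item() == tokenizer.eos_token_id:
        break

print(tokenizer.decode(feed["input_ids"][0]))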

Longer term, we think it could be useful to have an ONNX export that handles generation end-to-end: #526

@jplu
Contributor Author

jplu commented Dec 6, 2022

Perfect! Waiting a single week is perfectly OK 👌 Out of curiosity I will test with the main branch to see if I can get it to work, and I will let you know in this thread if I encounter any issues.

Indeed, generation is the hardest part to handle. On my side, I basically host all my ONNX models in a Triton server, and I have TritonModelForXXXX classes, like your ORTModelForXXXX, that handle the gRPC calls and can be used with pipelines. It does the job, but the downside is that it generates a lot of network calls. That's why I want to investigate using the Triton Python backend with optimum to see if it works better.

The ideal world, the dream, would indeed be a true end-to-end model that handles tokenization+inference for simple encoders, and tokenization+inference+generation for decoder and encoder-decoder models.

@jplu
Contributor Author

jplu commented Dec 7, 2022

I will wait for the official release; the main branch seems to be a bit unstable for now:

python -m optimum.exporters.onnx --model gpt2 --for-ort --task causal-lm-with-past output/

gives:

Traceback (most recent call last):
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/optimum/exporters/onnx/__main__.py", line 23, in <module>
    from ...utils.save_utils import maybe_save_tokenizer_or_processor_or_feature_extractor
ImportError: cannot import name 'maybe_save_tokenizer_or_processor_or_feature_extractor' from 'optimum.utils.save_utils' (/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/optimum/utils/save_utils.py)

The package was installed with:

pip install -U git+https://github.com/huggingface/optimum.git@main

@fxmarty
Collaborator

fxmarty commented Dec 7, 2022

cc @michaelbenayoun we should add tests for the CLI

@jplu
Contributor Author

jplu commented Dec 7, 2022

Thanks @fxmarty for the fix! Nevertheless, these two pieces of code:

from optimum.onnxruntime import ORTModelForCausalLM
from transformers import GPT2Tokenizer

model = ORTModelForCausalLM.from_pretrained("output/", from_transformers=False, use_cache=True)
tokenizer = GPT2Tokenizer.from_pretrained("output/")
tokens = tokenizer("My name is Julien and I like", return_tensors="pt")
outputs_model = model.generate(**tokens)

And

import onnxruntime as ort
from transformers import GPT2Tokenizer
import numpy as np

sess = ort.InferenceSession('output/model.onnx', providers=["CPUExecutionProvider"])
tokenizer = GPT2Tokenizer.from_pretrained("output/")
tokens = dict(tokenizer("My name is Julien and I like", return_tensors="np"))
shape = (1, 12, len(tokens["input_ids"][0]), 64)

for i in range(12):
    tokens[f"past_key_values.{i}.key"] = np.random.uniform(0, 1, shape).astype(np.float32)
    tokens[f"past_key_values.{i}.value"] = np.random.uniform(0, 1, shape).astype(np.float32)

sess.run(None, tokens)

Still raise errors:

Traceback (most recent call last):
  File "/home/jplu/dev/buster/cluster-inference/test_optimum.py", line 4, in <module>
    model = ORTModelForCausalLM.from_pretrained("output/", from_transformers=False, use_cache=True)
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/optimum/onnxruntime/modeling_ort.py", line 487, in from_pretrained
    return super().from_pretrained(
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/optimum/modeling_base.py", line 325, in from_pretrained
    return from_pretrained_method(
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/optimum/onnxruntime/modeling_decoder.py", line 545, in _from_pretrained
    decoder_file_name = infer_filename(r"(.*)?decoder((?!with_past).)*?\.onnx", "decoder_file_name")
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/optimum/onnxruntime/modeling_decoder.py", line 534, in infer_filename
    raise FileNotFoundError(f"Could not find any ONNX model file in {path}")
FileNotFoundError: Could not find any ONNX model file in output

And:

2022-12-07 11:41:07.446372853 [E:onnxruntime:, sequential_executor.cc:369 Execute] Non-zero status code returned while running Add node. Name:'/transformer/h.0/attn/Add' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:503 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 16

Traceback (most recent call last):
  File "/home/jplu/dev/buster/cluster-inference/test_ort.py", line 14, in <module>
    sess.run(None, tokens)
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Add node. Name:'/transformer/h.0/attn/Add' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:503 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 16

With the ONNX model generated by:

python -m optimum.exporters.onnx --model gpt2 --for-ort --task causal-lm-with-past output/

@fxmarty
Collaborator

fxmarty commented Dec 7, 2022

Yes, apologies, merging the previous PR closed this issue automatically!

Basically, gpt2 is decoder-only, and support for that was not yet implemented: #554

However, if you try for example

python -m optimum.exporters.onnx --model valhalla/m2m100_tiny_random --for-ort m2m100_tiny_onnx_ort

or with a larger model:

python -m optimum.exporters.onnx --model facebook/m2m100_418M --task seq2seq-lm-with-past --for-ort m2m100_onnx_ort

you will see the different files for encoder / decoder / decoder with past. Those can be fed directly into an ORTModel:

from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("/path/to/m2m100_onnx_ort")

model = ORTModelForSeq2SeqLM.from_pretrained("/path/to/m2m100_onnx_ort", from_transformers=False, use_cache=True)
tokens = tokenizer("My name is Felix and I like you", return_tensors="pt")

outputs_model = model.generate(**tokens, forced_bos_token_id=tokenizer.get_lang_id("fr"))

print(tokenizer.decode(outputs_model[0]))

@jplu
Contributor Author

jplu commented Dec 7, 2022

It is OK, no worries!

I tried with the model you suggested and, indeed, I get all three files. Each works like a charm, even in pure ORT with:

import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

sess_encoder = ort.InferenceSession('output/encoder_model.onnx', providers=["CPUExecutionProvider"])
sess_decoder = ort.InferenceSession('output/decoder_model.onnx', providers=["CPUExecutionProvider"])
sess_decoder_pkv = ort.InferenceSession('output/decoder_with_past_model.onnx', providers=["CPUExecutionProvider"])
tokenizer = AutoTokenizer.from_pretrained("output/")
inputs_encoder = dict(tokenizer("My name is Julien and I like", return_tensors="np"))
outputs_encoder = sess_encoder.run(None, inputs_encoder)
inputs_decoder = {
    "encoder_hidden_states": outputs_encoder[0],
    "encoder_attention_mask": inputs_encoder["attention_mask"],
    "input_ids": inputs_encoder["input_ids"]
}

sess_decoder.run(None, inputs_decoder)

inputs_decoder_pkv = inputs_decoder

shape = (1, 16, len(inputs_encoder["input_ids"][0]), 64)

for i in range(12):
    inputs_decoder_pkv[f"past_key_values.{i}.encoder.key"] = np.random.uniform(0, 1, shape).astype(np.float32)
    inputs_decoder_pkv[f"past_key_values.{i}.encoder.value"] = np.random.uniform(0, 1, shape).astype(np.float32)
    inputs_decoder_pkv[f"past_key_values.{i}.decoder.key"] = np.random.uniform(0, 1, shape).astype(np.float32)
    inputs_decoder_pkv[f"past_key_values.{i}.decoder.value"] = np.random.uniform(0, 1, shape).astype(np.float32)

sess_decoder_pkv.run(None, inputs_decoder_pkv)

I'll keep this issue open as it was mostly about decoder-only models, but I'm sure it will be OK once your PR is merged!

@fxmarty
Collaborator

fxmarty commented Dec 8, 2022

Hi @jplu, #554 is merged, and hopefully optimum.exporters.onnx and ORTModel are now gracefully interfaced for decoder-only models as well.

We'll do a release next week!

@jplu
Contributor Author

jplu commented Dec 8, 2022

Hi @fxmarty!! Thanks a lot for the addition, I have updated the package. This piece of code:

from optimum.onnxruntime import ORTModelForCausalLM
from transformers import GPT2Tokenizer

model = ORTModelForCausalLM.from_pretrained("output/", from_transformers=False, use_cache=True)
tokenizer = GPT2Tokenizer.from_pretrained("output/")
tokens = tokenizer("My name is Julien and I like", return_tensors="pt")
outputs_model = model.generate(**tokens)

Now works perfectly! But, unfortunately, this one:

import onnxruntime as ort
from transformers import GPT2Tokenizer
import numpy as np

sess = ort.InferenceSession('output/decoder_with_past_model.onnx', providers=["CPUExecutionProvider"])
tokenizer = GPT2Tokenizer.from_pretrained("output/")
tokens = dict(tokenizer("My name is Julien and I like", return_tensors="np"))
shape = (1, 12, len(tokens["input_ids"][0]), 64)

for i in range(12):
    tokens[f"past_key_values.{i}.key"] = np.random.uniform(0, 1, shape).astype(np.float32)
    tokens[f"past_key_values.{i}.value"] = np.random.uniform(0, 1, shape).astype(np.float32)

sess.run(None, tokens)

Still raises the exact same error for me:

2022-12-07 11:41:07.446372853 [E:onnxruntime:, sequential_executor.cc:369 Execute] Non-zero status code returned while running Add node. Name:'/transformer/h.0/attn/Add' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:503 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 16

Traceback (most recent call last):
  File "/home/jplu/dev/buster/cluster-inference/test_ort.py", line 14, in <module>
    sess.run(None, tokens)
  File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Add node. Name:'/transformer/h.0/attn/Add' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:503 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 16

The models are still generated with:

python -m optimum.exporters.onnx --model gpt2 --for-ort --task causal-lm-with-past output

Am I doing something wrong?

@fxmarty
Collaborator

fxmarty commented Dec 9, 2022

Yes, I think this is expected.

Looking at the shapes in generate()'s beam search or greedy search: https://github.com/huggingface/transformers/blob/9a6c6ef97fa5df4b1fb8dbc9e8c10ee3a9ed7e2a/src/transformers/generation/utils.py#L2285 (with vanilla transformers model)

dict_keys(['input_ids', 'past_key_values', 'use_cache', 'position_ids', 'attention_mask', 'token_type_ids'])
input_ids torch.Size([1, 8])
past_key_values None
use_cache None
position_ids torch.Size([1, 8])
attention_mask torch.Size([1, 8])
token_type_ids None
dict_keys(['input_ids', 'past_key_values', 'use_cache', 'position_ids', 'attention_mask', 'token_type_ids'])
input_ids torch.Size([1, 1])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
         past_key_values torch.Size([1, 12, 8, 64])
use_cache None
position_ids torch.Size([1, 1])
attention_mask torch.Size([1, 9])
token_type_ids None
dict_keys(['input_ids', 'past_key_values', 'use_cache', 'position_ids', 'attention_mask', 'token_type_ids'])
input_ids torch.Size([1, 1])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
         past_key_values torch.Size([1, 12, 9, 64])
use_cache None
position_ids torch.Size([1, 1])
attention_mask torch.Size([1, 10])
token_type_ids None

So this code works:

import onnxruntime as ort
from transformers import GPT2Tokenizer
import numpy as np

sess = ort.InferenceSession("/home/fxmarty/hf_internship/optimum/gpt2_onnx/decoder_with_past_model.onnx", providers=["CPUExecutionProvider"])
tokenizer = GPT2Tokenizer.from_pretrained("/home/fxmarty/hf_internship/optimum/gpt2_onnx")
tokens = dict(tokenizer("My name is Julien and I like", return_tensors="np"))
shape = (1, 12, len(tokens["input_ids"][0]) - 1, 64)  # past length = prompt length minus the one new token below

tokens["input_ids"] = np.array([[4]], dtype=np.int64)  # a single (arbitrary) new token id; the past covers the rest

for i in range(12):
    tokens[f"past_key_values.{i}.key"] = np.random.uniform(0, 1, shape).astype(np.float32)
    tokens[f"past_key_values.{i}.value"] = np.random.uniform(0, 1, shape).astype(np.float32)

sess.run(None, tokens)
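
For completeness, here is a hedged sketch of how the cache is meant to flow when the past values are not random: run decoder_model.onnx once over the prompt, then feed its present.* outputs back as past_key_values.* together with a single new token. Input/output names follow the export above; the attention_mask handling and the filtering by session input names are assumptions to keep the example robust.

import numpy as np
import onnxruntime as ort
from transformers import GPT2Tokenizer

sess_prompt = ort.InferenceSession("output/decoder_model.onnx", providers=["CPUExecutionProvider"])
sess_with_past = ort.InferenceSession("output/decoder_with_past_model.onnx", providers=["CPUExecutionProvider"])
tokenizer = GPT2Tokenizer.from_pretrained("output/")

prompt = dict(tokenizer("My name is Julien and I like", return_tensors="np"))
prompt_feed = {k: v for k, v in prompt.items() if k in {i.name for i in sess_prompt.get_inputs()}}

# First pass over the full prompt: outputs are logits plus present.{i}.key / present.{i}.value.
outputs = dict(zip([o.name for o in sess_prompt.get_outputs()], sess_prompt.run(None, prompt_feed)))
next_token = outputs["logits"][:, -1, :].argmax(-1)[:, None].astype(np.int64)  # greedy pick of the next token

# Second pass: a single new token plus the cached keys/values from the first pass.
feed = {
    "input_ids": next_token,
    # assumption: the mask covers the past tokens plus the one new token
    "attention_mask": np.ones((1, prompt["input_ids"].shape[1] + 1), dtype=np.int64),
}
for name, value in outputs.items():
    if name.startswith("present."):
        feed[name.replace("present.", "past_key_values.")] = value
feed = {k: v for k, v in feed.items() if k in {i.name for i in sess_with_past.get_inputs()}}

sess_with_past.run(None, feed)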

@jplu
Contributor Author

jplu commented Dec 9, 2022

Oh I missed that part! Thanks a lot for correcting me.
