Enable inference with a merged decoder in ORTModelForCausalLM #647

Merged: 78 commits, Feb 15, 2023
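What this PR enables, in short: ORTModelForCausalLM can now run generation against a decoder exported as a single merged ONNX file, in which an If node switches between the without-past and with-past branches. A minimal usage sketch (the model id, output directory and generation arguments are illustrative and not taken from the PR):

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

# Directory produced by the ONNX exporter with its default post-processing,
# i.e. containing a merged decoder; "gpt2_onnx/" is an illustrative path.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = ORTModelForCausalLM.from_pretrained("gpt2_onnx/")

inputs = tokenizer("ONNX Runtime is", return_tensors="pt")
# generate() reuses past key values across steps; with a merged decoder the same
# ONNX session serves both the first (no-past) step and the subsequent steps.
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```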
Changes from 64 commits
Commits
78 commits
ec049f4
Add save option
JingyaHuang Dec 27, 2022
d899e34
Add test for saving
JingyaHuang Dec 27, 2022
eb0d2ef
Fix test path
JingyaHuang Dec 27, 2022
a8b98b5
Allow str path for merging
JingyaHuang Dec 30, 2022
d3a9a1d
Add Path and remove merged names
JingyaHuang Jan 2, 2023
e399e8e
Merge branch 'main' into enable-merged-modeling
JingyaHuang Jan 2, 2023
b76d3a1
Finish adapting ORTModelDecoder
JingyaHuang Jan 2, 2023
04ff464
Prepare extra inputs
JingyaHuang Jan 3, 2023
af3461b
do not store merged in place
JingyaHuang Jan 3, 2023
a1d422c
Support I/O binding for merged
JingyaHuang Jan 4, 2023
0a5dd30
Extend to multiple patterns
JingyaHuang Jan 4, 2023
a7ec6ef
Add test for inference
JingyaHuang Jan 4, 2023
85603ee
Fix test
JingyaHuang Jan 4, 2023
2a8f3ca
update test
JingyaHuang Jan 4, 2023
167ae30
Merge branch 'main' into enable-merged-modeling
JingyaHuang Jan 4, 2023
b5fe0a3
Remove prints
JingyaHuang Jan 4, 2023
68e0025
Merge branch 'master' into enable-merged-modeling
fxmarty Feb 6, 2023
babed4b
add back missing method
fxmarty Feb 6, 2023
86cfc0a
fix mess
fxmarty Feb 6, 2023
f5051fb
fix post merge
fxmarty Feb 6, 2023
fcec713
none defaut
fxmarty Feb 7, 2023
cd2deb2
fix
fxmarty Feb 8, 2023
102d0c8
Merge branch 'master' into enable-merged-modeling
fxmarty Feb 8, 2023
496395c
fix errors
fxmarty Feb 8, 2023
b0d9c9a
remove nonsense tests
fxmarty Feb 8, 2023
995a976
fix doc
fxmarty Feb 8, 2023
6bd97b6
ongoing
fxmarty Feb 8, 2023
5c3b11b
debug
fxmarty Feb 9, 2023
0bcd528
Merge branch 'master' into enable-merged-modeling
fxmarty Feb 10, 2023
44f7600
fix style
fxmarty Feb 10, 2023
698bd70
fix post merge
fxmarty Feb 10, 2023
5ada3ec
hopefully working!
fxmarty Feb 10, 2023
fb3feae
add tests
fxmarty Feb 10, 2023
37acd6b
rename
fxmarty Feb 10, 2023
0eeb428
add constants
fxmarty Feb 10, 2023
ea07a0e
fix test
fxmarty Feb 10, 2023
14b616d
fix names
fxmarty Feb 10, 2023
5cebc22
ort support
fxmarty Feb 10, 2023
194108c
wip
fxmarty Feb 10, 2023
9c92ecc
fix
fxmarty Feb 10, 2023
04db6c4
fix
fxmarty Feb 10, 2023
164af2a
tests for merged
fxmarty Feb 10, 2023
5edb255
Merge branch 'master' into enable-merged-modeling
fxmarty Feb 10, 2023
4673963
stype
fxmarty Feb 10, 2023
d11d1ec
fix
fxmarty Feb 10, 2023
df6ef1d
fix merge errors
fxmarty Feb 13, 2023
67874f3
fix tests
fxmarty Feb 13, 2023
61878ce
fix test
fxmarty Feb 13, 2023
ead4702
remove irrelevant test
fxmarty Feb 13, 2023
79aacee
Update optimum/exporters/onnx/__main__.py
fxmarty Feb 13, 2023
badee2b
Update optimum/exporters/onnx/base.py
fxmarty Feb 13, 2023
739d549
Update optimum/exporters/onnx/config.py
fxmarty Feb 13, 2023
c176a8c
Update optimum/exporters/onnx/config.py
fxmarty Feb 13, 2023
2b4fda5
Update optimum/exporters/onnx/config.py
fxmarty Feb 13, 2023
d3cba91
Update optimum/exporters/onnx/config.py
fxmarty Feb 13, 2023
edd8aab
Update optimum/exporters/onnx/config.py
fxmarty Feb 13, 2023
0481ab2
Update optimum/onnxruntime/modeling_decoder.py
fxmarty Feb 13, 2023
8f8873b
Update tests/exporters/onnx/test_exporters_onnx_cli.py
fxmarty Feb 13, 2023
4d0ef00
fix on suggestions
fxmarty Feb 13, 2023
d3f68eb
fix import of dummyinputgenerators
fxmarty Feb 13, 2023
2237300
skip unwanted tests
fxmarty Feb 13, 2023
4b15d40
fix diffusion model
fxmarty Feb 13, 2023
3734803
Merge branch 'master' into enable-merged-modeling
fxmarty Feb 13, 2023
d00c117
fix tests
fxmarty Feb 14, 2023
44df1c7
Update optimum/commands/export/onnx.py
fxmarty Feb 14, 2023
72beefc
Update optimum/onnxruntime/modeling_decoder.py
fxmarty Feb 14, 2023
f9a4c46
fix last tests
fxmarty Feb 14, 2023
2b44e6d
Update optimum/onnx/graph_transformations.py
fxmarty Feb 14, 2023
8619f7c
Update optimum/onnx/graph_transformations.py
fxmarty Feb 14, 2023
2fc5d46
Update optimum/onnx/graph_transformations.py
fxmarty Feb 14, 2023
7a1b5b0
Update optimum/onnx/graph_transformations.py
fxmarty Feb 14, 2023
aaa9501
Update optimum/onnxruntime/modeling_decoder.py
fxmarty Feb 14, 2023
adf349a
fix signature and docstrings
fxmarty Feb 15, 2023
81cfd98
add error message if post process fails
fxmarty Feb 15, 2023
0490518
Merge branch 'master' into enable-merged-modeling
fxmarty Feb 15, 2023
3452fa5
tiny fix
fxmarty Feb 15, 2023
e41a7c2
last fixes
fxmarty Feb 15, 2023
4f8d7d5
typo
fxmarty Feb 15, 2023
10 changes: 9 additions & 1 deletion optimum/commands/export/onnx.py
@@ -90,7 +90,15 @@ def parse_args_onnx(parser):
optional_group.add_argument(
"--trust-remote-code",
action="store_true",
help="Allow to use custom code for the modeling hosted in the model repository. This option should only be set for repositories you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the model repository.",
help="Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the model repository.",
)
optional_group.add_argument(
"--no-post-process",
action="store_true",
help=(
"Allows to disable any post-processing done by default on the exported ONNX models. This is for example the merging of decoder"
" and decoder-with-past into a single ONNX with If node."
),
)

input_group = parser.add_argument_group(
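The post-processing that `--no-post-process` disables is, for decoder models, the merge of the decoder and decoder-with-past ONNX files into a single graph guarded by an `If` node. A sketch of invoking that merge directly through the helper touched by this PR; the file names are illustrative, while `merge_decoders` and its `decoder` / `decoder_with_past` / `save_path` arguments are taken from the diff below:

```python
from pathlib import Path

from optimum.onnx import merge_decoders

export_dir = Path("gpt2_onnx")  # illustrative exporter output directory

# Merge the no-past and with-past decoders into one ONNX graph with an If node.
merge_decoders(
    decoder=export_dir / "decoder_model.onnx",
    decoder_with_past=export_dir / "decoder_with_past_model.onnx",
    save_path=export_dir / "decoder_model_merged.onnx",
)
```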
141 changes: 64 additions & 77 deletions optimum/exporters/onnx/__main__.py
@@ -24,7 +24,7 @@
from ..error_utils import AtolError, OutputMatchError, ShapeError
from ..tasks import TasksManager
from .base import OnnxConfigWithPast
from .convert import export, export_models, validate_model_outputs, validate_models_outputs
from .convert import export_models, validate_models_outputs
from .utils import (
get_decoder_models_for_export,
get_encoder_decoder_models_for_export,
@@ -43,10 +43,9 @@ def main():

# Retrieve CLI arguments
args = parser.parse_args()
args.output = args.output.joinpath("model.onnx")

if not args.output.parent.exists():
args.output.parent.mkdir(parents=True)
if not args.output.exists():
args.output.mkdir(parents=True)

if args.for_ort:
logger.warning(
@@ -88,10 +87,10 @@ def main():
else:
logger.info(
f"The task `{task}` was manually specified, and past key values will not be reused in the decoding."
f"Please pass `--task {task}-with-past` to export using the past key values."
f" if needed, please pass `--task {task}-with-past` to export using the past key values."
)

if task == "auto":
if args.task == "auto":
logger.info(f"Automatic task detection to {task}.")

if task != "stable-diffusion":
@@ -130,94 +129,82 @@ def main():
args.atol = args.atol[task.replace("-with-past", "")]

# Saving the model config and preprocessor as this is needed sometimes.
model.config.save_pretrained(args.output.parent)
maybe_save_preprocessors(args.model, args.output.parent)
model.config.save_pretrained(args.output)
maybe_save_preprocessors(args.model, args.output)

if task == "stable-diffusion":
onnx_files_subpaths = [
"text_encoder/model.onnx",
"unet/model.onnx",
"vae_encoder/model.onnx",
"vae_decoder/model.onnx",
]
models_and_onnx_configs = get_stable_diffusion_models_for_export(model)
# Saving the additional components needed to perform inference.
model.tokenizer.save_pretrained(args.output.joinpath("tokenizer"))
model.scheduler.save_pretrained(args.output.joinpath("scheduler"))
model.feature_extractor.save_pretrained(args.output.joinpath("feature_extractor"))
model.save_config(args.output)
else:
if model.config.is_encoder_decoder and task.startswith("causal-lm"):
raise ValueError(
f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report"
f"at https://github.com/huggingface/optimum, if --task was explicitely passed, make sure you selected the right task for the model,"
f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`."
)

if task == "stable-diffusion" or (
task.startswith(("causal-lm", "seq2seq-lm", "speech2seq-lm", "vision2seq-lm", "default-with-past"))
and not args.monolith
):
if task == "stable-diffusion":
output_names = [
"text_encoder/model.onnx",
"unet/model.onnx",
"vae_encoder/model.onnx",
"vae_decoder/model.onnx",
]
models_and_onnx_configs = get_stable_diffusion_models_for_export(model)
# Saving the additional components needed to perform inference.
model.tokenizer.save_pretrained(args.output.parent.joinpath("tokenizer"))
model.scheduler.save_pretrained(args.output.parent.joinpath("scheduler"))
model.feature_extractor.save_pretrained(args.output.parent.joinpath("feature_extractor"))
model.save_config(args.output.parent)
onnx_files_subpaths = None
if (
model.config.is_encoder_decoder
and task.startswith(("seq2seq-lm", "speech2seq-lm", "vision2seq-lm", "default-with-past"))
and not args.monolith
):
models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
elif task.startswith("causal-lm") and not args.monolith:
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
else:
if model.config.is_encoder_decoder and task.startswith("causal-lm"):
raise ValueError(
f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report"
f"at https://github.com/huggingface/optimum, if --task was explicitely passed, make sure you selected the right task for the model,"
f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`."
)
if model.config.is_encoder_decoder:
models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
else:
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
output_names = None
models_and_onnx_configs = {"model": (model, onnx_config)}

_, onnx_outputs = export_models(
models_and_onnx_configs=models_and_onnx_configs,
opset=args.opset,
output_dir=args.output,
output_names=onnx_files_subpaths,
input_shapes=input_shapes,
device=args.device,
)

onnx_inputs, onnx_outputs = export_models(
models_and_onnx_configs=models_and_onnx_configs,
opset=args.opset,
output_dir=args.output.parent,
output_names=output_names,
input_shapes=input_shapes,
device=args.device,
# Optionally post process the obtained ONNX file(s), for example to merge the decoder / decoder with past if any
# TODO: treating stable diffusion separately is quite ugly
if not args.no_post_process and task != "stable-diffusion":
models_and_onnx_configs, onnx_files_subpaths = onnx_config.post_process_exported_models(
args.output, models_and_onnx_configs, onnx_files_subpaths
)
else:
onnx_inputs, onnx_outputs = export(
model=model,
config=onnx_config,
output=args.output,
opset=args.opset,

try:
validate_models_outputs(
models_and_onnx_configs=models_and_onnx_configs,
onnx_named_outputs=onnx_outputs,
atol=args.atol,
output_dir=args.output,
onnx_files_subpaths=onnx_files_subpaths,
input_shapes=input_shapes,
device=args.device,
)

try:
if task == "stable-diffusion" or (
task.startswith(("causal-lm", "seq2seq-lm", "speech2seq-lm", "vision2seq-lm", "default-with-past"))
and not args.monolith
):
validate_models_outputs(
models_and_onnx_configs=models_and_onnx_configs,
onnx_named_outputs=onnx_outputs,
atol=args.atol,
output_dir=args.output.parent,
output_names=output_names,
device=args.device,
)
else:
validate_model_outputs(
config=onnx_config,
reference_model=model,
onnx_model=args.output,
onnx_named_outputs=onnx_outputs,
atol=args.atol,
device=args.device,
)

logger.info(f"The ONNX export succeeded and the exported model was saved at: {args.output.parent.as_posix()}")
logger.info(f"The ONNX export succeeded and the exported model was saved at: {args.output.as_posix()}")
except ShapeError as e:
raise e
except AtolError as e:
logger.warning(
f"The ONNX export succeeded with the warning: {e}.\n The exported model was saved at: {args.output.parent.as_posix()}"
f"The ONNX export succeeded with the warning: {e}.\n The exported model was saved at: {args.output.as_posix()}"
)
except OutputMatchError as e:
logger.warning(
f"The ONNX export succeeded with the warning: {e}.\n The exported model was saved at: {args.output.parent.as_posix()}"
f"The ONNX export succeeded with the warning: {e}.\n The exported model was saved at: {args.output.as_posix()}"
)
except Exception as e:
logger.error(
f"An error occured with the error message: {e}.\n The exported model was saved at: {args.output.parent.as_posix()}"
f"An error occured with the error message: {e}.\n The exported model was saved at: {args.output.as_posix()}"
)


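Condensing the new control flow above: for a decoder-only model exported with past key values, the exporter now always goes through `export_models`, then (unless `--no-post-process` is passed) lets the ONNX config merge the two decoders, and finally validates against the merged file. A programmatic sketch of that sequence, assuming `model` and `onnx_config` have already been built for a `causal-lm-with-past` task as in `main()`; the opset and tolerance values are illustrative:

```python
from pathlib import Path

from optimum.exporters.onnx.convert import export_models, validate_models_outputs
from optimum.exporters.onnx.utils import get_decoder_models_for_export

output = Path("gpt2_onnx")
output.mkdir(parents=True, exist_ok=True)

# model and onnx_config are assumed to be built beforehand (see main() above).
# Split the decoder into its without-past / with-past variants.
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)

_, onnx_outputs = export_models(
    models_and_onnx_configs=models_and_onnx_configs,
    opset=13,  # illustrative
    output_dir=output,
    output_names=None,
)

# Default post-processing: merge decoder / decoder-with-past into one ONNX with an If node.
models_and_onnx_configs, onnx_files_subpaths = onnx_config.post_process_exported_models(
    output, models_and_onnx_configs, None
)

# Both branches of the merged decoder are then validated against the reference model.
validate_models_outputs(
    models_and_onnx_configs=models_and_onnx_configs,
    onnx_named_outputs=onnx_outputs,
    atol=1e-4,  # illustrative
    output_dir=output,
    onnx_files_subpaths=onnx_files_subpaths,
)
```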
36 changes: 29 additions & 7 deletions optimum/exporters/onnx/base.py
@@ -25,7 +25,7 @@
import re
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import onnx
from onnxruntime import InferenceSession
@@ -46,10 +46,6 @@
logger = logging.get_logger(__name__)


# 2 Gb
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024


@dataclasses.dataclass
class PatchingSpec:
"""
@@ -462,14 +458,28 @@ def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str,
"""
Generates inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq
models which have the encoder and decoder exported as separate ONNX files.

Args:
reference_model_inputs ([`Dict[str, Tensor]`):
Reference inputs for the model.

Returns:
`Dict[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function
"""
return reference_model_inputs

def post_process_exported_models(
self, path: "Path", models_and_onnx_configs: Tuple, onnx_files_subpaths: List[str]
):
"""
Performs any model-specific post-processing on the ONNX.

Args:
path (`Path`):
Path to the directory of the stored ONNX model.
"""
"""
return models_and_onnx_configs, onnx_files_subpaths


class OnnxConfigWithPast(OnnxConfig, ABC):
"""
@@ -507,6 +517,8 @@ def __init__(
f"use_past = {use_past} is different than use_present_in_outputs = {use_present_in_outputs}, the value "
"of use_present_in_outputs value will be used for the outputs."
)
self.is_merged = False
self.use_cache_branch = None
Review comment from the PR author (Collaborator):

What's the difference between use_cache_branch, use_past and use_past_in_inputs? I mean that use_cache_branch must be for the merged-decoder case, but why do we need to distinguish them?

And does use_cache_branch imply use_past=True?

fxmarty (Collaborator) replied on Feb 14, 2023:

> does use_cache_branch imply use_past=True?

Yes, in other cases use_cache_branch does not make sense.

About the difference between use_past and use_past_in_inputs, it seems like legacy code that could be simplified. Or am I missing something @michaelbenayoun?

use_cache_branch is a flag indicating that, for the merged decoder case, we use the cache branch of the control flow. This flag is used in several places:
[screenshots omitted]

Reply from a Member:

use_past is the legacy here.
Basically you have two "use past":

  1. use_past_in_inputs: inputs will have past key values
  2. use_present_in_outputs: outputs will have past key values

If you set only use_past, it sets both.

super().__init__(config, task=task)

@classmethod
@@ -552,7 +564,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
if dummy_input_gen.supports_input(input_name):
# models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name
# while models from TextDecoderOnnxConfig use input_ids, hence the check for both
if self.use_past is True and input_name in ["decoder_input_ids", "input_ids"]:
if (
self.use_past is True
and self.use_cache_branch is not False
and input_name in ["decoder_input_ids", "input_ids"]
):
sequence_length = dummy_input_gen.sequence_length
if "sequence_length" in kwargs and kwargs["sequence_length"] != 1:
logger.info(
@@ -572,7 +588,12 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
)

# refer to https://github.com/huggingface/optimum/pull/764
if self.use_past_in_inputs and "attention_mask" in dummy_inputs and self.PAD_ATTENTION_MASK_TO_PAST:
if (
self.use_past_in_inputs
and self.PAD_ATTENTION_MASK_TO_PAST
and self.use_cache_branch is not False
and "attention_mask" in dummy_inputs
):
past_length = dummy_inputs["past_key_values"][0][0].shape[2]
dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"],
@@ -800,6 +821,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
input_name, _ = next(iter(self._onnx_config.inputs.items()))
batch_size = dummy_inputs[input_name].shape[0]

# TODO: doesn't this break attention_mask generation?
if isinstance(self._onnx_config, OnnxSeq2SeqConfigWithPast) and self._onnx_config.use_past_in_inputs is True:
kwargs["sequence_length"] = 1

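As the review discussion above explains, `use_cache_branch` only exists for the merged decoder: it is an extra boolean model input that selects which branch of the `If` node is executed. A small ONNX Runtime sketch for inspecting and choosing it; the file name and the other input names are illustrative, only `use_cache_branch` itself comes from the diff:

```python
import numpy as np
import onnxruntime

# Illustrative path to the merged decoder produced by the exporter.
session = onnxruntime.InferenceSession("gpt2_onnx/decoder_model_merged.onnx")

# Besides input_ids / attention_mask / past_key_values.*, the merged graph
# exposes the boolean input that picks the If branch.
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)

# First generation step (no real past yet): take the without-cache branch.
use_cache_branch = np.array([False])
# Subsequent steps (past key values available): take the cache branch.
use_cache_branch = np.array([True])
```

Note that, as the validation code below states, dummy past key values still have to be fed even on the without-cache branch, since optional inputs are not supported.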
62 changes: 59 additions & 3 deletions optimum/exporters/onnx/config.py
@@ -14,11 +14,15 @@
# limitations under the License.
"""Common ONNX configuration classes that handle most of the features for building model specific configurations."""

from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional

from ...onnx import merge_decoders
from ...utils import (
DummyAudioInputGenerator,
DummyBboxInputGenerator,
DummyInputGenerator,
DummyPastKeyValuesGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
@@ -27,13 +31,12 @@
logging,
)
from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME


if TYPE_CHECKING:
from transformers import PretrainedConfig

from ...utils import DummyInputGenerator

logger = logging.get_logger(__name__)


@@ -66,6 +69,59 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
}
return common_inputs

def post_process_exported_models(self, path: Path, models_and_onnx_configs, onnx_files_subpaths):
# Attempt to merge only if the decoder-only was exported separately without/with past
if self.use_past is True and len(models_and_onnx_configs) == 2:
if onnx_files_subpaths is not None:
decoder_path = Path(path, onnx_files_subpaths[0])
decoder_with_past_path = Path(path, onnx_files_subpaths[1])
else:
decoder_path = Path(path, ONNX_DECODER_NAME + ".onnx")
decoder_with_past_path = Path(path, ONNX_DECODER_WITH_PAST_NAME + ".onnx")
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
try:
merge_decoders(
decoder=decoder_path,
decoder_with_past=decoder_with_past_path,
save_path=decoder_merged_path,
)
except Exception as e:
raise Exception(f"Unable to merge decoders. Detailed error: {e}")
os.remove(decoder_path)
os.remove(decoder_with_past_path)

# In order to do the validation of the two branches on the same file
onnx_files_subpaths = [decoder_merged_path.name, decoder_merged_path.name]

# We validate the two branches of the decoder model then
models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False

# Past key values won't be generated by default, but added in the input
models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past = False
models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True

models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True

return models_and_onnx_configs, onnx_files_subpaths

def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
if self.is_merged is True and self.use_cache_branch is True:
reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=True)
elif self.is_merged is True and self.use_cache_branch is False:
reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=False)

# We don't support optional inputs for now, so even though the non-cache branch is used,
# dummy past key values are necessary
batch_size = reference_model_inputs["input_ids"].shape[0]
pkv_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1](
task=self.task, normalized_config=self._normalized_config, sequence_length=1, batch_size=batch_size
)
reference_model_inputs["past_key_values"] = pkv_generator.generate("past_key_values", framework="pt")

return reference_model_inputs


class TextSeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast):
"""
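A quick way to check that the post-processing above produced the expected merged graph is to look for the `If` node and the `use_cache_branch` input directly with the `onnx` package (the file name is illustrative):

```python
import onnx

model = onnx.load("gpt2_onnx/decoder_model_merged.onnx")

# The merged decoder wraps the without-past and with-past branches in a single If node.
if_nodes = [node for node in model.graph.node if node.op_type == "If"]
print(f"If nodes found: {len(if_nodes)}")

# The graph inputs should include the boolean use_cache_branch selector.
print([inp.name for inp in model.graph.input])
```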