Handling ONNX models with external data #586

Merged: 59 commits, merged on Dec 22, 2022. Changes shown below are from 21 commits.

Commits
031c70e
attempt at fixing saving onnx model with external data
NouamaneTazi Jul 1, 2022
6fec76f
styling
NouamaneTazi Jul 1, 2022
7c59c2e
fix: `cache_dir` wasn't used when loading from transformers
NouamaneTazi Jul 2, 2022
387fe6c
separate onnx_cache_dir argument from model's cache_dir
NouamaneTazi Jul 2, 2022
cce8e90
we can now load large ONNX models by specifying external's data direc…
NouamaneTazi Jul 2, 2022
5902565
Merge branch 'main' of https://github.com/huggingface/optimum into pr…
NouamaneTazi Dec 7, 2022
845f788
Fix saving external data for large models (seq2seq)
NouamaneTazi Dec 7, 2022
1f2687f
fix saving external data for all ORT models
NouamaneTazi Dec 7, 2022
cd2babd
make style
NouamaneTazi Dec 7, 2022
16d2e24
typing
NouamaneTazi Dec 7, 2022
0a5a52c
apply suggestions
NouamaneTazi Dec 7, 2022
b3f7bbb
Merge branch 'main' of https://github.com/huggingface/optimum into pr…
NouamaneTazi Dec 7, 2022
2ec8009
make style
NouamaneTazi Dec 7, 2022
2c1e907
Merge branch 'main' of https://github.com/huggingface/optimum into pr…
NouamaneTazi Dec 12, 2022
af664f8
apply suggestion
NouamaneTazi Dec 12, 2022
a961d8b
export onnx models to separate subfolders when multiple models
NouamaneTazi Dec 13, 2022
ca7a38b
this should get us correct file_names but not correct subfolder!!
NouamaneTazi Dec 13, 2022
a7cb11e
export_models is only used for multiple submodels
NouamaneTazi Dec 13, 2022
9479c97
we can now load seq2seq model from local dir (multiple submodels)
NouamaneTazi Dec 13, 2022
c9f1eb5
Merge branch 'save-large-models' of https://github.com/nouamanetazi/o…
NouamaneTazi Dec 13, 2022
840bc44
fix save_pretrained
NouamaneTazi Dec 13, 2022
690ac69
* `infer_onnx_filename` now returns a path
NouamaneTazi Dec 15, 2022
0fdfd87
try saving to single file
NouamaneTazi Dec 15, 2022
f708b58
didn't work, reverting "try saving to single file"
NouamaneTazi Dec 15, 2022
3aa28c0
add test for seq2seq model with external data
NouamaneTazi Dec 15, 2022
89e64ed
quick fix
NouamaneTazi Dec 15, 2022
d56f8c3
try saving to a single file again
NouamaneTazi Dec 16, 2022
2070d3c
Revert "quick fix"
NouamaneTazi Dec 16, 2022
2ad9e83
quick fix test
NouamaneTazi Dec 16, 2022
990c373
save external data in a single file
NouamaneTazi Dec 16, 2022
b315f85
save_pretrained now moves model instead of copying from temp directory
NouamaneTazi Dec 16, 2022
f4f2997
Revert "save_pretrained now moves model instead of copying from temp …
NouamaneTazi Dec 16, 2022
9f15395
add push to hub test
NouamaneTazi Dec 16, 2022
ee04cd0
add FORCE_ONNX_EXTERNAL_DATA env and faster test to push to hub
NouamaneTazi Dec 16, 2022
da65c03
quick fix
NouamaneTazi Dec 16, 2022
99bef3c
we can now save and load large seq2seq models to hub + added test
NouamaneTazi Dec 16, 2022
4e70fd7
we no longer save to subfolders, as we use a singla file for external…
NouamaneTazi Dec 16, 2022
40e6abe
make style
NouamaneTazi Dec 18, 2022
c332b73
apply same fixes to `modeling_decoder.py`
NouamaneTazi Dec 18, 2022
16bd118
apply same fixes to `modeling_ort.py`
NouamaneTazi Dec 18, 2022
cf1e8ed
add tests
NouamaneTazi Dec 18, 2022
2511eaa
fix auth token in tests
NouamaneTazi Dec 18, 2022
8f51d89
Merge branch 'main' of https://github.com/huggingface/optimum into ex…
NouamaneTazi Dec 18, 2022
4c6bc60
add **kwargs to all `_save_pretrained`
NouamaneTazi Dec 18, 2022
2e61b70
quick fix
NouamaneTazi Dec 18, 2022
2013d57
make style
NouamaneTazi Dec 18, 2022
9648d85
try reducing memory footprint when exporting onnx
NouamaneTazi Dec 18, 2022
1d41c20
replace large seq2seq model with small on to make tests pass
NouamaneTazi Dec 21, 2022
cd384f1
Merge branch 'main' of https://github.com/huggingface/optimum into ex…
NouamaneTazi Dec 21, 2022
36af9a8
fix merge
NouamaneTazi Dec 21, 2022
222a8a7
Merge branch 'main' of https://github.com/huggingface/optimum into ex…
NouamaneTazi Dec 21, 2022
27bc039
we no longer export models to subfolders. instead we regroup external…
NouamaneTazi Dec 21, 2022
a08f697
util from last commit
NouamaneTazi Dec 21, 2022
78703f2
empty commit
fxmarty Dec 21, 2022
5fc415d
fix import
fxmarty Dec 22, 2022
b796f9d
add onnx utils
fxmarty Dec 22, 2022
f1ef9e9
fix import2
fxmarty Dec 22, 2022
b262c46
better tests
fxmarty Dec 22, 2022
e550834
parameterized and skip order
fxmarty Dec 22, 2022
11 changes: 6 additions & 5 deletions optimum/exporters/onnx/convert.py
@@ -444,11 +444,12 @@ def export_models(

for i, model_name in enumerate(models_for_export.keys()):
submodel, sub_onnx_config = models_for_export[model_name]
output_path = (
output_dir.joinpath(output_names[i])
if output_names is not None
else output_dir.joinpath(model_name + ".onnx")
)
output_name = output_names[i] if output_names is not None else Path(model_name + ".onnx")

# when the model uses several ONNX files, save each in subfolders to avoid conflicting external files
output_path = output_dir / model_name / output_name
output_path.parent.mkdir(parents=True, exist_ok=True)

outputs.append(
export(
submodel,
14 changes: 12 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
@@ -35,7 +35,13 @@
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .io_binding import TypeHelper
from .modeling_ort import ORTModel
from .utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, get_provider_for_device, parse_device
from .utils import (
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
_get_external_data_paths,
get_provider_for_device,
parse_device,
)


if TYPE_CHECKING:
@@ -475,12 +481,16 @@ def _save_pretrained(
"""
src_paths = [self.decoder_model_path]
dst_file_names = [decoder_file_name]

if self.use_cache:
src_paths.append(self.decoder_with_past_model_path)
dst_file_names.append(decoder_with_past_file_name)

# add external data paths in case of large models
src_paths, dst_file_names = _get_external_data_paths(src_paths, dst_file_names)

for src_path, dst_file_name in zip(src_paths, dst_file_names):
dst_path = Path(save_directory).joinpath(dst_file_name)
dst_path = Path(save_directory) / dst_file_name
shutil.copyfile(src_path, dst_path)

@classmethod
13 changes: 10 additions & 3 deletions optimum/onnxruntime/modeling_ort.py
@@ -52,6 +52,7 @@
from .io_binding import IOBindingHelper, TypeHelper
from .utils import (
ONNX_WEIGHTS_NAME,
_get_external_data_paths,
get_device_for_provider,
get_provider_for_device,
parse_device,
@@ -301,9 +302,15 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = ON
file_name (`str`, *optional*, defaults to the value of `optimum.onnxruntime.utils.ONNX_WEIGHTS_NAME`):
The filename to use when saving the model.
"""
# TODO: support models with external data
dst_path = Path(save_directory).joinpath(file_name)
shutil.copyfile(self.model_path, dst_path)
src_paths = [self.model_path]
dst_file_names = [file_name]

# add external data paths in case of large models
src_paths, dst_file_names = _get_external_data_paths(src_paths, dst_file_names)

for src_path, dst_file_name in zip(src_paths, dst_file_names):
dst_path = Path(save_directory) / dst_file_name
shutil.copyfile(src_path, dst_path)

@staticmethod
def _generate_regular_names_for_filename(filename: str):
53 changes: 47 additions & 6 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -36,7 +36,7 @@
from ..exporters.onnx import export_models, get_encoder_decoder_models_for_export
from ..exporters.tasks import TasksManager
from ..utils import NormalizedConfigManager, check_if_transformers_greater
from ..utils.file_utils import validate_file_exists
from ..utils.file_utils import validate_file_exists, find_files_matching_pattern
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .io_binding import TypeHelper
from .modeling_decoder import ORTDecoder
@@ -45,6 +45,7 @@
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
ONNX_ENCODER_NAME,
_get_external_data_paths,
get_provider_for_device,
parse_device,
validate_provider_availability,
@@ -900,14 +901,18 @@ def _save_pretrained(
The decoder with past key values model file name overwriting the default file name, allowing to save
the decoder model with a different name.
"""
src_file_names = [self.encoder_model_path, self.decoder_model_path]
dst_file_names = [encoder_file_name, decoder_file_name]
src_paths = [self.encoder_model_path, self.decoder_model_path]
dst_file_names = ["encoder_model/" + encoder_file_name, "decoder_model/" + decoder_file_name]
if self.use_cache:
src_file_names.append(self.decoder_with_past_model_path)
dst_file_names.append(decoder_with_past_file_name)
src_paths.append(self.decoder_with_past_model_path)
dst_file_names.append("decoder_with_past_model/" + decoder_with_past_file_name)

for src_path, dst_file_name in zip(src_file_names, dst_file_names):
# add external data paths in case of large models
src_paths, dst_file_names = _get_external_data_paths(src_paths, dst_file_names)

for src_path, dst_file_name in zip(src_paths, dst_file_names):
dst_path = Path(save_directory) / dst_file_name
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(src_path, dst_path)

@classmethod
@@ -1152,6 +1157,42 @@ def to(self, device: Union[torch.device, str, int]):
return self


@staticmethod
def infer_onnx_filename(
model_name_or_path: Union[str, Path],
pattern: str,
argument_name: str,
subfolder: str = "",
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
fail_if_not_found: bool = True,
) -> Optional[str]:
onnx_files = find_files_matching_pattern(
model_name_or_path,
pattern,
glob_pattern="**/*.onnx",
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
)

path = model_name_or_path
if subfolder != "":
path = f"{path}/{subfolder}"

if len(onnx_files) == 0:
if fail_if_not_found:
raise FileNotFoundError(f"Could not find any ONNX model file in {path}")
return None
elif len(onnx_files) > 1:
if argument_name is not None:
raise RuntimeError(
f"Too many ONNX model files were found in {path}, specify which one to load by using the "
f"{argument_name} argument."
)
return onnx_files[0].parent.name + "/" + onnx_files[0].name

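For a local directory, the resolution logic of `infer_onnx_filename` can be sketched without the Hub round-trip. `pick_onnx_file` is a hypothetical stand-in, and `find_files_matching_pattern`'s matching semantics are simplified here to a regex on the file name:

```python
import re
from pathlib import Path
from typing import Optional


def pick_onnx_file(directory: Path, pattern: str, argument_name: Optional[str] = None,
                   fail_if_not_found: bool = True) -> Optional[str]:
    # glob every .onnx file under the directory, keep those whose name matches `pattern`
    onnx_files = sorted(p for p in directory.glob("**/*.onnx") if re.search(pattern, p.name))
    if len(onnx_files) == 0:
        if fail_if_not_found:
            raise FileNotFoundError(f"Could not find any ONNX model file in {directory}")
        return None
    if len(onnx_files) > 1 and argument_name is not None:
        raise RuntimeError(
            f"Too many ONNX model files were found in {directory}, specify which one to load "
            f"by using the {argument_name} argument."
        )
    # keep the immediate parent folder, e.g. "decoder_model/model.onnx",
    # since submodels now live in per-model subfolders
    return onnx_files[0].parent.name + "/" + onnx_files[0].name
```

As in the diff, the ambiguity error is raised only when an `argument_name` is supplied; otherwise the first match wins.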

class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin):
"""
Sequence-to-sequence model with a language modeling head for ONNX Runtime inference.
23 changes: 22 additions & 1 deletion optimum/onnxruntime/utils.py
@@ -16,7 +16,8 @@
import importlib.util
import os
from enum import Enum
from typing import Dict, Tuple, Union
from pathlib import Path
from typing import Dict, List, Tuple, Union

import torch
from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
@@ -25,6 +26,7 @@
import onnx
import onnxruntime as ort
import pkg_resources
from onnx.external_data_helper import ExternalDataInfo, _get_initializer_tensors

from ..onnx import OnnxConfigWithLoss, OnnxConfigWithPastAndLoss, OnnxSeq2SeqConfigWithPastAndLoss

@@ -270,3 +272,22 @@ class ORTQuantizableOperator(Enum):
Resize = "Resize"
AveragePool = "AveragePool"
Concat = "Concat"


def _get_external_data_paths(src_paths: List[Path], dst_file_names: List[str]) -> Tuple[List[Path], List[str]]:
"""
Get external data paths from the model and add them to the list of files to copy.
"""
model_paths = src_paths.copy()
for model_path in model_paths:
model = onnx.load(str(model_path), load_external_data=False)
# filter out tensors that are not external data
model_tensors = _get_initializer_tensors(model)
model_tensors_ext = [
ExternalDataInfo(tensor).location
for tensor in model_tensors
if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
]
src_paths.extend([model_path.parent / tensor_name for tensor_name in model_tensors_ext])
dst_file_names.extend(f"{model_path.parent.name}/{tensor_name}" for tensor_name in model_tensors_ext)
return src_paths, dst_file_names