
Handling ONNX models with external data #586

Merged Dec 22, 2022 · 59 commits (changes shown from 43 commits)

Commits
031c70e
attempt at fixing saving onnx model with external data
NouamaneTazi Jul 1, 2022
6fec76f
styling
NouamaneTazi Jul 1, 2022
7c59c2e
fix: `cache_dir` wasn't used when loading from transformers
NouamaneTazi Jul 2, 2022
387fe6c
separate onnx_cache_dir argument from model's cache_dir
NouamaneTazi Jul 2, 2022
cce8e90
we can now load large ONNX models by specifying external's data direc…
NouamaneTazi Jul 2, 2022
5902565
Merge branch 'main' of https://github.com/huggingface/optimum into pr…
NouamaneTazi Dec 7, 2022
845f788
Fix saving external data for large models (seq2seq)
NouamaneTazi Dec 7, 2022
1f2687f
fix saving external data for all ORT models
NouamaneTazi Dec 7, 2022
cd2babd
make style
NouamaneTazi Dec 7, 2022
16d2e24
typing
NouamaneTazi Dec 7, 2022
0a5a52c
apply suggestions
NouamaneTazi Dec 7, 2022
b3f7bbb
Merge branch 'main' of https://github.com/huggingface/optimum into pr…
NouamaneTazi Dec 7, 2022
2ec8009
make style
NouamaneTazi Dec 7, 2022
2c1e907
Merge branch 'main' of https://github.com/huggingface/optimum into pr…
NouamaneTazi Dec 12, 2022
af664f8
apply suggestion
NouamaneTazi Dec 12, 2022
a961d8b
export onnx models to separate subfolders when multiple models
NouamaneTazi Dec 13, 2022
ca7a38b
this should get us correct file_names but not correct subfolder!!
NouamaneTazi Dec 13, 2022
a7cb11e
export_models is only used for multiple submodels
NouamaneTazi Dec 13, 2022
9479c97
we can now load seq2seq model from local dir (multiple submodels)
NouamaneTazi Dec 13, 2022
c9f1eb5
Merge branch 'save-large-models' of https://github.com/nouamanetazi/o…
NouamaneTazi Dec 13, 2022
840bc44
fix save_pretrained
NouamaneTazi Dec 13, 2022
690ac69
* `infer_onnx_filename` now returns a path
NouamaneTazi Dec 15, 2022
0fdfd87
try saving to single file
NouamaneTazi Dec 15, 2022
f708b58
didn't work, reverting "try saving to single file"
NouamaneTazi Dec 15, 2022
3aa28c0
add test for seq2seq model with external data
NouamaneTazi Dec 15, 2022
89e64ed
quick fix
NouamaneTazi Dec 15, 2022
d56f8c3
try saving to a single file again
NouamaneTazi Dec 16, 2022
2070d3c
Revert "quick fix"
NouamaneTazi Dec 16, 2022
2ad9e83
quick fix test
NouamaneTazi Dec 16, 2022
990c373
save external data in a single file
NouamaneTazi Dec 16, 2022
b315f85
save_pretrained now moves model instead of copying from temp directory
NouamaneTazi Dec 16, 2022
f4f2997
Revert "save_pretrained now moves model instead of copying from temp …
NouamaneTazi Dec 16, 2022
9f15395
add push to hub test
NouamaneTazi Dec 16, 2022
ee04cd0
add FORCE_ONNX_EXTERNAL_DATA env and faster test to push to hub
NouamaneTazi Dec 16, 2022
da65c03
quick fix
NouamaneTazi Dec 16, 2022
99bef3c
we can now save and load large seq2seq models to hub + added test
NouamaneTazi Dec 16, 2022
4e70fd7
we no longer save to subfolders, as we use a single file for external…
NouamaneTazi Dec 16, 2022
40e6abe
make style
NouamaneTazi Dec 18, 2022
c332b73
apply same fixes to `modeling_decoder.py`
NouamaneTazi Dec 18, 2022
16bd118
apply same fixes to `modeling_ort.py`
NouamaneTazi Dec 18, 2022
cf1e8ed
add tests
NouamaneTazi Dec 18, 2022
2511eaa
fix auth token in tests
NouamaneTazi Dec 18, 2022
8f51d89
Merge branch 'main' of https://github.com/huggingface/optimum into ex…
NouamaneTazi Dec 18, 2022
4c6bc60
add **kwargs to all `_save_pretrained`
NouamaneTazi Dec 18, 2022
2e61b70
quick fix
NouamaneTazi Dec 18, 2022
2013d57
make style
NouamaneTazi Dec 18, 2022
9648d85
try reducing memory footprint when exporting onnx
NouamaneTazi Dec 18, 2022
1d41c20
replace large seq2seq model with small one to make tests pass
NouamaneTazi Dec 21, 2022
cd384f1
Merge branch 'main' of https://github.com/huggingface/optimum into ex…
NouamaneTazi Dec 21, 2022
36af9a8
fix merge
NouamaneTazi Dec 21, 2022
222a8a7
Merge branch 'main' of https://github.com/huggingface/optimum into ex…
NouamaneTazi Dec 21, 2022
27bc039
we no longer export models to subfolders. instead we regroup external…
NouamaneTazi Dec 21, 2022
a08f697
util from last commit
NouamaneTazi Dec 21, 2022
78703f2
empty commit
fxmarty Dec 21, 2022
5fc415d
fix import
fxmarty Dec 22, 2022
b796f9d
add onnx utils
fxmarty Dec 22, 2022
f1ef9e9
fix import2
fxmarty Dec 22, 2022
b262c46
better tests
fxmarty Dec 22, 2022
e550834
parameterized and skip order
fxmarty Dec 22, 2022
39 changes: 34 additions & 5 deletions optimum/exporters/onnx/convert.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""ONNX model check and export functions."""

import os
from inspect import signature
from itertools import chain
from pathlib import Path
@@ -22,6 +23,9 @@
import numpy as np
from transformers.utils import is_tf_available, is_torch_available

import onnx

from ...onnxruntime.utils import check_model_uses_external_data
from ...utils import logging
from .base import OnnxConfig
from .utils import (
@@ -278,6 +282,7 @@ def export_pytorch(
from torch.utils._pytree import tree_map

logger.info(f"Using framework PyTorch: {torch.__version__}")
FORCE_ONNX_EXTERNAL_DATA = os.getenv("FORCE_ONNX_EXTERNAL_DATA", "0") == "1"

with torch.no_grad():
model.config.return_dict = True
@@ -323,6 +328,29 @@
opset_version=opset,
)

# check if external data was exported
onnx_model = onnx.load(str(output), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)

if model_uses_external_data or FORCE_ONNX_EXTERNAL_DATA:
logger.info("Saving external data to one file...")
onnx_model = onnx.load(
str(output), load_external_data=True
) # TODO: this will probably be too memory heavy, shall we free `model` memory?
onnx.save(
onnx_model,
str(output),
save_as_external_data=True,
all_tensors_to_one_file=True,
location=output.name + "_data",
size_threshold=1024 if not FORCE_ONNX_EXTERNAL_DATA else 0,
)

# delete previous external data (all files besides model.onnx and model.onnx_data)
for file in os.listdir(output.parent):
if file != output.name and file != output.name + "_data":
os.remove(os.path.join(output.parent, file))

config.restore_ops()

return input_names, output_names
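The post-export cleanup in the hunk above can be sketched in isolation. This is a simplified stand-in (the function name `cleanup_external_data_files` and the flat-folder layout are illustrative, not the exact optimum implementation): after consolidating all tensors into a single `<model>.onnx_data` file, every other file left in the export folder is a stale per-tensor shard from the first `torch.onnx.export` pass and can be deleted.

```python
import os
from pathlib import Path


def cleanup_external_data_files(output: Path) -> None:
    # Keep only the model file and its consolidated external-data companion;
    # everything else in the folder is a stale per-tensor shard.
    keep = {output.name, output.name + "_data"}
    for file in os.listdir(output.parent):
        if file not in keep:
            os.remove(os.path.join(output.parent, file))
```

Note this assumes the model was exported into a folder of its own, as `export_models` arranges below; in a shared folder this loop would delete unrelated files.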
@@ -443,11 +471,12 @@ def export_models(

for i, model_name in enumerate(models_for_export.keys()):
submodel, sub_onnx_config = models_for_export[model_name]
output_path = (
output_dir.joinpath(output_names[i])
if output_names is not None
else output_dir.joinpath(model_name + ".onnx")
)
output_name = output_names[i] if output_names is not None else Path(model_name + ".onnx")

# when the model uses several ONNX files, save each in subfolders to avoid conflicting external files
output_path = output_dir / model_name / output_name
output_path.parent.mkdir(parents=True, exist_ok=True)

outputs.append(
export(
submodel,
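The new output-path logic in `export_models` can be illustrated with a minimal sketch (the helper name `submodel_output_path` is made up for the example): each submodel is exported into its own subfolder, so the `<file>_data` external-data files of one submodel cannot collide with another's.

```python
from pathlib import Path
from typing import Optional


def submodel_output_path(output_dir: Path, model_name: str, output_name: Optional[str] = None) -> Path:
    # Each submodel gets its own subfolder so that external-data files
    # (which share the ``<file>_data`` naming scheme) cannot collide.
    name = output_name if output_name is not None else model_name + ".onnx"
    path = output_dir / model_name / name
    path.parent.mkdir(parents=True, exist_ok=True)
    return path
```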
2 changes: 1 addition & 1 deletion optimum/modeling_base.py
@@ -120,7 +120,7 @@ def save_pretrained(
self.config.save_pretrained(save_directory)
for preprocessor in self.preprocessors:
preprocessor.save_pretrained(save_directory)
self._save_pretrained(save_directory, **kwargs)
self._save_pretrained(save_directory)

if push_to_hub:
return self.push_to_hub(save_directory, **kwargs)
56 changes: 44 additions & 12 deletions optimum/onnxruntime/modeling_decoder.py
@@ -27,6 +27,7 @@

import onnxruntime
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

from ..exporters import TasksManager
from ..exporters.onnx import export_models, get_decoder_models_for_export
@@ -35,7 +36,13 @@
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .io_binding import TypeHelper
from .modeling_ort import ORTModel
from .utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, get_provider_for_device, parse_device
from .utils import (
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
_get_external_data_paths,
get_provider_for_device,
parse_device,
)


if TYPE_CHECKING:
@@ -475,12 +482,16 @@ def _save_pretrained(
"""
src_paths = [self.decoder_model_path]
dst_file_names = [decoder_file_name]

if self.use_cache:
src_paths.append(self.decoder_with_past_model_path)
dst_file_names.append(decoder_with_past_file_name)

# add external data paths in case of large models
src_paths, dst_file_names = _get_external_data_paths(src_paths, dst_file_names)

for src_path, dst_file_name in zip(src_paths, dst_file_names):
dst_path = Path(save_directory).joinpath(dst_file_name)
dst_path = Path(save_directory) / dst_file_name
shutil.copyfile(src_path, dst_path)

@classmethod
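The `_get_external_data_paths` helper used in the hunk above lives in `optimum/onnxruntime/utils.py` and is not shown in this diff. A plausible sketch of its behavior, assuming the single-file `<model>.onnx_data` convention used elsewhere in this PR:

```python
from pathlib import Path
from typing import List, Tuple


def get_external_data_paths(src_paths: List[Path], dst_file_names: List[str]) -> Tuple[List[Path], List[str]]:
    # For every model that has a ``<name>_data`` companion on disk,
    # schedule that companion for copying alongside the model file.
    data_paths, data_names = [], []
    for src_path, dst_name in zip(src_paths, dst_file_names):
        candidate = src_path.parent / (src_path.name + "_data")
        if candidate.is_file():
            data_paths.append(candidate)
            data_names.append(dst_name + "_data")
    return src_paths + data_paths, dst_file_names + data_names
```

With this shape, small models pass through unchanged, and the copy loop in `_save_pretrained` needs no special-casing for large models.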
@@ -506,25 +517,27 @@ def _from_pretrained(
model_path = Path(model_id)

if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision):
decoder_file_name = ORTModelDecoder.infer_onnx_filename(
decoder_path = ORTModelDecoder.infer_onnx_filename(
model_id,
DECODER_ONNX_FILE_PATTERN,
"decoder_file_name",
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
)
else:
decoder_path = model_path / subfolder / decoder_file_name
decoder_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(ONNX_DECODER_NAME)
if decoder_file_name not in decoder_regular_onnx_filenames:
if decoder_path.name not in decoder_regular_onnx_filenames:
logger.warning(
f"The ONNX file {decoder_file_name} is not a regular name used in optimum.onnxruntime that are {decoder_regular_onnx_filenames}, the "
f"The ONNX file {decoder_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_regular_onnx_filenames}, the "
f"{cls.__name__} might not behave as expected."
)

decoder_with_past_path = None
if use_cache is True:
if not validate_file_exists(model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision):
decoder_with_past_file_name = ORTModelDecoder.infer_onnx_filename(
decoder_with_past_path = ORTModelDecoder.infer_onnx_filename(
model_id,
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
"decoder_with_past_file_name",
@@ -533,23 +546,25 @@
revision=revision,
fail_if_not_found=use_cache,
)
else:
decoder_with_past_path = model_path / subfolder / decoder_with_past_file_name

decoder_with_past_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(
ONNX_DECODER_WITH_PAST_NAME
)

if decoder_with_past_file_name not in decoder_with_past_regular_onnx_filenames:
if decoder_with_past_path.name not in decoder_with_past_regular_onnx_filenames:
logger.warning(
f"The ONNX file {decoder_with_past_file_name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, "
f"The ONNX file {decoder_with_past_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, "
f"the {cls.__name__} might not behave as expected."
)

decoder_with_past_path = model_path / decoder_with_past_file_name if use_cache else None
decoder_with_past_path = decoder_with_past_path if use_cache else None

preprocessors = None
if model_path.is_dir():
model = cls.load_model(
decoder_path=model_path / decoder_file_name,
decoder_path=decoder_path,
decoder_with_past_path=decoder_with_past_path,
provider=provider,
session_options=session_options,
@@ -559,8 +574,8 @@
preprocessors = maybe_load_preprocessors(model_id)
else:
attribute_name_to_filename = {
"last_decoder_model_name": decoder_file_name,
"last_decoder_with_past_model_name": decoder_with_past_file_name if use_cache else None,
"last_decoder_model_name": decoder_path.name,
"last_decoder_with_past_model_name": decoder_with_past_path.name if use_cache else None,
}
paths = {}
for attr_name, filename in attribute_name_to_filename.items():
@@ -576,6 +591,23 @@
force_download=force_download,
local_files_only=local_files_only,
)

# try to download the external data file, if any
try:
model_data_cache_path = hf_hub_download(
repo_id=model_id,
subfolder=subfolder,
filename=filename + "_data",
use_auth_token=use_auth_token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except EntryNotFoundError:
# model doesn't use external data
pass

paths[attr_name] = Path(model_cache_path).name
new_model_save_dir = Path(model_cache_path).parent
preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
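The hub-download logic added above treats the `_data` file as optional: a missing repository entry is swallowed rather than raised. The pattern can be sketched with a stand-in `download` callable in place of `hf_hub_download`, and `FileNotFoundError` standing in for `huggingface_hub`'s `EntryNotFoundError`:

```python
from typing import Callable, Optional


def try_download_external_data(download: Callable[[str], str], filename: str) -> Optional[str]:
    # The ``_data`` companion is optional: a missing hub entry is not an error.
    try:
        return download(filename + "_data")
    except FileNotFoundError:  # stands in for huggingface_hub's EntryNotFoundError
        # model doesn't use external data
        return None
```

Because `hf_hub_download` caches both files side by side, the ONNX Runtime session later resolves the relative `_data` reference from the cached model's folder without further bookkeeping.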
33 changes: 29 additions & 4 deletions optimum/onnxruntime/modeling_ort.py
@@ -45,6 +45,7 @@

import onnxruntime as ort
from huggingface_hub import HfApi, HfFolder, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

from ..exporters import TasksManager
from ..exporters.onnx import export
@@ -54,6 +55,7 @@
from .io_binding import IOBindingHelper, TypeHelper
from .utils import (
ONNX_WEIGHTS_NAME,
_get_external_data_paths,
get_device_for_provider,
get_provider_for_device,
parse_device,
@@ -303,9 +305,15 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = ON
file_name (`str`, *optional*, defaults to the value of `optimum.onnxruntime.utils.ONNX_WEIGHTS_NAME`):
The filename to use when saving the model.
"""
# TODO: support models with external data
dst_path = Path(save_directory).joinpath(file_name)
shutil.copyfile(self.model_path, dst_path)
src_paths = [self.model_path]
dst_file_names = [file_name]

# add external data paths in case of large models
src_paths, dst_file_names = _get_external_data_paths(src_paths, dst_file_names)

for src_path, dst_file_name in zip(src_paths, dst_file_names):
dst_path = Path(save_directory) / dst_file_name
shutil.copyfile(src_path, dst_path)

@staticmethod
def _generate_regular_names_for_filename(filename: str):
@@ -345,7 +353,7 @@ def infer_onnx_filename(
f"Too many ONNX model files were found in {path}, specify which one to load by using the "
f"{argument_name} argument."
)
return onnx_files[0].name
return onnx_files[0]

@classmethod
def _from_pretrained(
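`infer_onnx_filename` now returns a full `Path` instead of a bare file name (see `return onnx_files[0]` above), so callers such as `_from_pretrained` can load from subfolders directly. A stripped-down sketch of the selection logic, with a simple glob standing in for the real regex patterns like `DECODER_ONNX_FILE_PATTERN`:

```python
from pathlib import Path


def infer_onnx_filename(directory: Path, pattern: str, argument_name: str) -> Path:
    # Return the unique ONNX file matching ``pattern``; fail loudly on 0 or >1 matches.
    onnx_files = sorted(directory.glob(pattern))
    if len(onnx_files) == 0:
        raise FileNotFoundError(f"Could not find any ONNX file matching {pattern} in {directory}.")
    if len(onnx_files) > 1:
        raise RuntimeError(
            f"Too many ONNX model files were found in {directory}, specify which one to load "
            f"by using the {argument_name} argument."
        )
    return onnx_files[0]
```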
@@ -417,6 +425,23 @@ def _from_pretrained(
force_download=force_download,
local_files_only=local_files_only,
)

# try to download the external data file, if any
try:
model_data_cache_path = hf_hub_download(
repo_id=model_id,
subfolder=subfolder,
filename=file_name + "_data",
use_auth_token=use_auth_token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except EntryNotFoundError:
# model doesn't use external data
pass

model = ORTModel.load_model(
model_cache_path, provider=provider, session_options=session_options, provider_options=provider_options
)