Skip to content

Commit

Permalink
Handling ONNX models with external data (#586)
Browse files Browse the repository at this point in the history
* attempt at fixing saving onnx model with external data

* styling

* fix: `cache_dir` wasn't used when loading from transformers

* separate onnx_cache_dir argument from model's cache_dir

* we can now load large ONNX models by specifying external's data directory

* Fix saving external data for large models (seq2seq)

* fix saving external data for all ORT models

* make style

* typing

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* apply suggestions

* make style

* apply suggestion

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* export onnx models to separate subfolders when multiple models

* this should get us correct file_names but not correct subfolder!!

* export_models is only used for multiple submodels

* we can now load seq2seq model from local dir (multiple submodels)

* fix save_pretrained

* * `infer_onnx_filename` now returns a path
* refactor `_from_pretrained`

* try saving to single file

* didn't work, reverting "try saving to single file"

This reverts commit 0fdfd87.

* add test for seq2seq model with external data

* quick fix

* try saving to a single file again

* Revert "quick fix"

This reverts commit 89e64ed.

* quick fix test

* save external data in a single file

* save_pretrained now moves model instead of copying from temp directory

* Revert "save_pretrained now moves model instead of copying from temp directory"

This reverts commit b315f85.

* add push to hub test

* add FORCE_ONNX_EXTERNAL_DATA env and faster test to push to hub

* quick fix

* we can now save and load large seq2seq models to hub + added test

* we no longer save to subfolders, as we use a singla file for external data

* make style

* apply same fixes to `modeling_decoder.py`

* apply same fixes to `modeling_ort.py`

* add tests

* fix auth token in tests

* add **kwargs to all `_save_pretrained`

* quick fix

* make style

* try reducing memory footprint when exporting onnx

* replace large seq2seq model with small on to make tests pass

* fix merge

* we no longer export models to subfolders. instead we regroup external data in a single data file

* util from last commit

* empty commit

* fix import

* add onnx utils

* fix import2

* better tests

* parameterized and skip order

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com>
  • Loading branch information
3 people committed Dec 22, 2022
1 parent 0a02ea6 commit 6da9e1a
Show file tree
Hide file tree
Showing 8 changed files with 426 additions and 43 deletions.
43 changes: 38 additions & 5 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
"""ONNX model check and export functions."""

import os
from inspect import signature
from itertools import chain
from pathlib import Path
Expand All @@ -22,6 +23,9 @@
import numpy as np
from transformers.utils import is_tf_available, is_torch_available

import onnx

from ...onnx.utils import _get_onnx_external_data_tensors, check_model_uses_external_data
from ...utils import TORCH_MINIMUM_VERSION, is_diffusers_available, is_torch_onnx_support_available, logging
from .base import OnnxConfig

Expand Down Expand Up @@ -307,6 +311,7 @@ def export_pytorch(
from torch.utils._pytree import tree_map

logger.info(f"Using framework PyTorch: {torch.__version__}")
FORCE_ONNX_EXTERNAL_DATA = os.getenv("FORCE_ONNX_EXTERNAL_DATA", "0") == "1"

with torch.no_grad():
model.config.return_dict = True
Expand Down Expand Up @@ -355,6 +360,34 @@ def export_pytorch(
opset_version=opset,
)

# check if external data was exported
onnx_model = onnx.load(str(output), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)

if model_uses_external_data or FORCE_ONNX_EXTERNAL_DATA:
tensors_paths = _get_onnx_external_data_tensors(onnx_model)
logger.info("Saving external data to one file...")

# try free model memory
del model
del onnx_model

onnx_model = onnx.load(
str(output), load_external_data=True
) # this will probably be too memory heavy for large models
onnx.save(
onnx_model,
str(output),
save_as_external_data=True,
all_tensors_to_one_file=True,
location=output.name + "_data",
size_threshold=1024 if not FORCE_ONNX_EXTERNAL_DATA else 0,
)

# delete previous external data
for tensor in tensors_paths:
os.remove(output.parent / tensor)

config.restore_ops()

return input_names, output_names
Expand Down Expand Up @@ -476,11 +509,11 @@ def export_models(

for i, model_name in enumerate(models_and_onnx_configs.keys()):
submodel, sub_onnx_config = models_and_onnx_configs[model_name]
output_path = (
output_dir.joinpath(output_names[i])
if output_names is not None
else output_dir.joinpath(model_name + ".onnx")
)
output_name = output_names[i] if output_names is not None else Path(model_name + ".onnx")

output_path = output_dir / output_name
output_path.parent.mkdir(parents=True, exist_ok=True)

outputs.append(
export(
model=submodel,
Expand Down
69 changes: 69 additions & 0 deletions optimum/onnx/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import List, Tuple

import onnx
from onnx.external_data_helper import ExternalDataInfo, _get_initializer_tensors


def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]:
"""
Get the paths of the external data tensors in the model.
Note: make sure you load the model with load_external_data=False.
"""
model_tensors = _get_initializer_tensors(model)
model_tensors_ext = [
ExternalDataInfo(tensor).location
for tensor in model_tensors
if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
]
return model_tensors_ext


def _get_external_data_paths(src_paths: List[Path], dst_file_names: List[str]) -> Tuple[List[Path], List[str]]:
"""
Get external data paths from the model and add them to the list of files to copy.
"""
model_paths = src_paths.copy()
for model_path in model_paths:
model = onnx.load(str(model_path), load_external_data=False)
model_tensors = _get_initializer_tensors(model)
# filter out tensors that are not external data
model_tensors_ext = [
ExternalDataInfo(tensor).location
for tensor in model_tensors
if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
]
if len(set(model_tensors_ext)) == 1:
# if external data was saved in a single file
src_paths.append(model_path.parent / model_tensors_ext[0])
dst_file_names.append(model_tensors_ext[0])
else:
# if external data doesnt exist or was saved in multiple files
src_paths.extend([model_path.parent / tensor_name for tensor_name in model_tensors_ext])
dst_file_names.extend(model_path.parent.name + "/" + tensor_name for tensor_name in model_tensors_ext)
return src_paths, dst_file_names


def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
"""
Check if the model uses external data.
"""
model_tensors = _get_initializer_tensors(model)
return any(
tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
for tensor in model_tensors
)
53 changes: 42 additions & 11 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@

import onnxruntime
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

from ..exporters import TasksManager
from ..exporters.onnx import export_models, get_decoder_models_for_export
from ..onnx.utils import _get_external_data_paths
from ..utils import NormalizedConfigManager, check_if_transformers_greater
from ..utils.file_utils import validate_file_exists
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
Expand Down Expand Up @@ -457,6 +459,7 @@ def _save_pretrained(
save_directory: Union[str, Path],
decoder_file_name: str = ONNX_DECODER_NAME,
decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME,
**kwargs,
):
"""
Saves the model decoder and decoder with past key values as well as its configuration file to a
Expand All @@ -475,12 +478,16 @@ def _save_pretrained(
"""
src_paths = [self.decoder_model_path]
dst_file_names = [decoder_file_name]

if self.use_cache:
src_paths.append(self.decoder_with_past_model_path)
dst_file_names.append(decoder_with_past_file_name)

# add external data paths in case of large models
src_paths, dst_file_names = _get_external_data_paths(src_paths, dst_file_names)

for src_path, dst_file_name in zip(src_paths, dst_file_names):
dst_path = Path(save_directory).joinpath(dst_file_name)
dst_path = Path(save_directory) / dst_file_name
shutil.copyfile(src_path, dst_path)

@classmethod
Expand All @@ -506,25 +513,27 @@ def _from_pretrained(
model_path = Path(model_id)

if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision):
decoder_file_name = ORTModelDecoder.infer_onnx_filename(
decoder_path = ORTModelDecoder.infer_onnx_filename(
model_id,
DECODER_ONNX_FILE_PATTERN,
"decoder_file_name",
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
)
else:
decoder_path = model_path / subfolder / decoder_file_name
decoder_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(ONNX_DECODER_NAME)
if decoder_file_name not in decoder_regular_onnx_filenames:
if decoder_path.name not in decoder_regular_onnx_filenames:
logger.warning(
f"The ONNX file {decoder_file_name} is not a regular name used in optimum.onnxruntime that are {decoder_regular_onnx_filenames}, the "
f"The ONNX file {decoder_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_regular_onnx_filenames}, the "
f"{cls.__name__} might not behave as expected."
)

decoder_with_past_path = None
if use_cache is True:
if not validate_file_exists(model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision):
decoder_with_past_file_name = ORTModelDecoder.infer_onnx_filename(
decoder_with_past_path = ORTModelDecoder.infer_onnx_filename(
model_id,
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
"decoder_with_past_file_name",
Expand All @@ -533,23 +542,28 @@ def _from_pretrained(
revision=revision,
fail_if_not_found=use_cache,
)
else:
decoder_with_past_path = model_path / subfolder / decoder_with_past_file_name

decoder_with_past_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(
ONNX_DECODER_WITH_PAST_NAME
)

if decoder_with_past_file_name not in decoder_with_past_regular_onnx_filenames:
if (
decoder_with_past_path is not None
and decoder_with_past_path.name not in decoder_with_past_regular_onnx_filenames
):
logger.warning(
f"The ONNX file {decoder_with_past_file_name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, "
f"The ONNX file {decoder_with_past_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, "
f"the {cls.__name__} might not behave as expected."
)

decoder_with_past_path = model_path / decoder_with_past_file_name if use_cache else None
decoder_with_past_path = decoder_with_past_path if use_cache else None

preprocessors = None
if model_path.is_dir():
model = cls.load_model(
decoder_path=model_path / decoder_file_name,
decoder_path=decoder_path,
decoder_with_past_path=decoder_with_past_path,
provider=provider,
session_options=session_options,
Expand All @@ -559,8 +573,8 @@ def _from_pretrained(
preprocessors = maybe_load_preprocessors(model_id)
else:
attribute_name_to_filename = {
"last_decoder_model_name": decoder_file_name,
"last_decoder_with_past_model_name": decoder_with_past_file_name if use_cache else None,
"last_decoder_model_name": decoder_path.name,
"last_decoder_with_past_model_name": decoder_with_past_path.name if use_cache else None,
}
paths = {}
for attr_name, filename in attribute_name_to_filename.items():
Expand All @@ -576,6 +590,23 @@ def _from_pretrained(
force_download=force_download,
local_files_only=local_files_only,
)

# try download external data
try:
model_data_cache_path = hf_hub_download(
repo_id=model_id,
subfolder=subfolder,
filename=filename + "_data",
use_auth_token=use_auth_token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except EntryNotFoundError:
# model doesn't use external data
pass

paths[attr_name] = Path(model_cache_path).name
new_model_save_dir = Path(model_cache_path).parent
preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
Expand Down
35 changes: 30 additions & 5 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@

import onnxruntime as ort
from huggingface_hub import HfApi, HfFolder, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

from ..exporters import TasksManager
from ..exporters.onnx import export
from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
from ..onnx.utils import _get_external_data_paths
from ..utils.file_utils import find_files_matching_pattern
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .io_binding import IOBindingHelper, TypeHelper
Expand Down Expand Up @@ -294,7 +296,7 @@ def load_model(
provider_options=None if provider_options is None else [provider_options],
)

def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = ONNX_WEIGHTS_NAME):
def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = ONNX_WEIGHTS_NAME, **kwargs):
"""
Saves a model and its configuration file to a directory, so that it can be re-loaded using the
[`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
Expand All @@ -306,9 +308,15 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = ON
file_name (`str`, *optional*, defaults to the value of `optimum.onnxruntime.utils.ONNX_WEIGHTS_NAME`):
The filename to use when saving the model.
"""
# TODO: support models with external data
dst_path = Path(save_directory).joinpath(file_name)
shutil.copyfile(self.model_path, dst_path)
src_paths = [self.model_path]
dst_file_names = [file_name]

# add external data paths in case of large models
src_paths, dst_file_names = _get_external_data_paths(src_paths, dst_file_names)

for src_path, dst_file_name in zip(src_paths, dst_file_names):
dst_path = Path(save_directory) / dst_file_name
shutil.copyfile(src_path, dst_path)

@staticmethod
def _generate_regular_names_for_filename(filename: str):
Expand Down Expand Up @@ -348,7 +356,7 @@ def infer_onnx_filename(
f"Too many ONNX model files were found in {path}, specify which one to load by using the "
f"{argument_name} argument."
)
return onnx_files[0].name
return onnx_files[0]

@classmethod
def _from_pretrained(
Expand Down Expand Up @@ -420,6 +428,23 @@ def _from_pretrained(
force_download=force_download,
local_files_only=local_files_only,
)

# try download external data
try:
model_data_cache_path = hf_hub_download(
repo_id=model_id,
subfolder=subfolder,
filename=file_name + "_data",
use_auth_token=use_auth_token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except EntryNotFoundError:
# model doesn't use external data
pass

model = ORTModel.load_model(
model_cache_path, provider=provider, session_options=session_options, provider_options=provider_options
)
Expand Down

0 comments on commit 6da9e1a

Please sign in to comment.