Handling ONNX models with external data (#586)
* attempt at fixing saving onnx model with external data
* styling
* fix: `cache_dir` wasn't used when loading from transformers
* separate onnx_cache_dir argument from model's cache_dir
* we can now load large ONNX models by specifying external's data directory
* Fix saving external data for large models (seq2seq)
* fix saving external data for all ORT models
* make style
* typing
  Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
* apply suggestions
* make style
* apply suggestion
  Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
* export onnx models to separate subfolders when multiple models
* this should get us correct file_names but not correct subfolder!!
* export_models is only used for multiple submodels
* we can now load seq2seq model from local dir (multiple submodels)
* fix save_pretrained
* `infer_onnx_filename` now returns a path
* refactor `_from_pretrained`
* try saving to single file
* didn't work, reverting "try saving to single file"
  This reverts commit 0fdfd87.
* add test for seq2seq model with external data
* quick fix
* try saving to a single file again
* Revert "quick fix"
  This reverts commit 89e64ed.
* quick fix test
* save external data in a single file
* save_pretrained now moves model instead of copying from temp directory
* Revert "save_pretrained now moves model instead of copying from temp directory"
  This reverts commit b315f85.
* add push to hub test
* add FORCE_ONNX_EXTERNAL_DATA env and faster test to push to hub
* quick fix
* we can now save and load large seq2seq models to hub + added test
* we no longer save to subfolders, as we use a single file for external data
* make style
* apply same fixes to `modeling_decoder.py`
* apply same fixes to `modeling_ort.py`
* add tests
* fix auth token in tests
* add `**kwargs` to all `_save_pretrained`
* quick fix
* make style
* try reducing memory footprint when exporting onnx
* replace large seq2seq model with small one to make tests pass
* fix merge
* we no longer export models to subfolders. instead we regroup external data in a single data file
* util from last commit
* empty commit
* fix import
* add onnx utils
* fix import2
* better tests
* parameterized and skip order

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com>
1 parent 0a02ea6, commit 6da9e1a
Showing 8 changed files with 426 additions and 43 deletions.
@@ -0,0 +1,69 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import List, Tuple

import onnx
from onnx.external_data_helper import ExternalDataInfo, _get_initializer_tensors


def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]:
    """
    Get the paths of the external data tensors in the model.
    Note: make sure you load the model with load_external_data=False.
    """
    model_tensors = _get_initializer_tensors(model)
    model_tensors_ext = [
        ExternalDataInfo(tensor).location
        for tensor in model_tensors
        if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
    ]
    return model_tensors_ext


def _get_external_data_paths(src_paths: List[Path], dst_file_names: List[str]) -> Tuple[List[Path], List[str]]:
    """
    Get external data paths from the model and add them to the list of files to copy.
    """
    model_paths = src_paths.copy()
    for model_path in model_paths:
        model = onnx.load(str(model_path), load_external_data=False)
        model_tensors = _get_initializer_tensors(model)
        # filter out tensors that are not external data
        model_tensors_ext = [
            ExternalDataInfo(tensor).location
            for tensor in model_tensors
            if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
        ]
        if len(set(model_tensors_ext)) == 1:
            # if external data was saved in a single file
            src_paths.append(model_path.parent / model_tensors_ext[0])
            dst_file_names.append(model_tensors_ext[0])
        else:
            # if external data doesn't exist or was saved in multiple files
            src_paths.extend([model_path.parent / tensor_name for tensor_name in model_tensors_ext])
            dst_file_names.extend(model_path.parent.name + "/" + tensor_name for tensor_name in model_tensors_ext)
    return src_paths, dst_file_names


def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
    """
    Check if the model uses external data.
    """
    model_tensors = _get_initializer_tensors(model)
    return any(
        tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
        for tensor in model_tensors
    )
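
The branch in `_get_external_data_paths` that chooses between a bare file name and a subfolder-prefixed name is easy to misread, so here is a minimal standalone sketch of just that bookkeeping. It is not the code from this commit: `plan_external_data_copies` and the file names (`model.onnx_data`, `w0.bin`, `w1.bin`) are hypothetical, and the ONNX loading step is replaced by a plain dict that maps each model file to the `location` values its external tensors would report. Unlike the helper in the diff, the sketch deduplicates repeated locations before extending the lists, and it returns only the external data files rather than appending to the caller's lists.

```python
from pathlib import Path
from typing import Dict, List, Tuple


def plan_external_data_copies(
    locations_by_model: Dict[Path, List[str]],
) -> Tuple[List[Path], List[str]]:
    """Hypothetical stand-in for the path bookkeeping in `_get_external_data_paths`.

    `locations_by_model` maps a model file to the external-data `location`
    strings of its tensors (one entry per tensor, so repeats are expected).
    """
    src_paths: List[Path] = []
    dst_file_names: List[str] = []
    for model_path, locations in locations_by_model.items():
        # Drop repeated locations while keeping first-seen order.
        # (An addition in this sketch; the helper in the diff does not deduplicate.)
        unique_locations = list(dict.fromkeys(locations))
        if len(unique_locations) == 1:
            # All tensors share one data file: keep its bare file name.
            src_paths.append(model_path.parent / unique_locations[0])
            dst_file_names.append(unique_locations[0])
        else:
            # No external data (empty list) or several data files:
            # prefix each destination with the model's parent folder name.
            for location in unique_locations:
                src_paths.append(model_path.parent / location)
                dst_file_names.append(f"{model_path.parent.name}/{location}")
    return src_paths, dst_file_names


if __name__ == "__main__":
    sources, destinations = plan_external_data_copies(
        {
            Path("export/encoder/model.onnx"): ["model.onnx_data", "model.onnx_data"],
            Path("export/decoder/model.onnx"): ["w0.bin", "w1.bin"],
            Path("export/small/model.onnx"): [],
        }
    )
    for src, dst in zip(sources, destinations):
        print(src.as_posix(), "->", dst)
```

A model without external data contributes nothing to either list, because the `else` branch then loops over an empty list. That is the same reason the helper in the diff can handle "doesn't exist" and "saved in multiple files" in one branch.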