
Handling ONNX models with external data #586

Merged
merged 59 commits into huggingface:main on Dec 22, 2022

Conversation

@NouamaneTazi (Member) commented Dec 13, 2022

This PR aims to handle loading and exporting ONNX models with external data, both locally and from the Hub. We can now also use FORCE_ONNX_EXTERNAL_DATA=1 to force the external data format even for small models.

  • Saving/loading a model with external data locally
  • Saving external data in a single file (ends with .onnx_data for easy loading from hub)
  • Saving/loading a model with external data from the hub
  • Writing tests
  • Apply the same changes for other models besides seq2seq

cc @fxmarty @mht-sharma @michaelbenayoun

Fixes #254 and #377
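
For illustration, a minimal sketch of forcing the external data format on a small model. It assumes FORCE_ONNX_EXTERNAL_DATA is read from the environment (possibly at import time), so the variable is set before optimum is imported; the output path is arbitrary.

import os

# Sketch only: force the external-data format even for a tiny model.
os.environ["FORCE_ONNX_EXTERNAL_DATA"] = "1"

from optimum.onnxruntime import ORTModelForSeq2SeqLM

model = ORTModelForSeq2SeqLM.from_pretrained("sshleifer/tiny-mbart", from_transformers=True)
# with the flag set, weights are written to external *.onnx_data files next to the *.onnx files
model.save_pretrained("tiny-mbart-onnx")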

@NouamaneTazi (Member Author)

Saving is correctly done as we discussed @fxmarty, but loading deserves some more discussion.
I'm trying to load the different submodels from the different subfolders (for example here). I think an easy (but bloated) solution would be to have multiple subfolder arguments, like we have for file_name

        encoder_file_name: str = ONNX_ENCODER_NAME,
        decoder_file_name: str = ONNX_DECODER_NAME,
        decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME,
        subfolder: str = "",
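
For illustration only, the "bloated" alternative with one subfolder argument per submodel might look like the hypothetical signature below (none of these subfolder parameters exist in optimum; the file-name constants mirror the ones above):

# Hypothetical sketch, not part of this PR: a subfolder argument per submodel,
# mirroring the per-submodel file_name arguments shown above.
ONNX_ENCODER_NAME = "encoder_model.onnx"
ONNX_DECODER_NAME = "decoder_model.onnx"
ONNX_DECODER_WITH_PAST_NAME = "decoder_with_past_model.onnx"

def _from_pretrained_sketch(
    model_id: str,
    encoder_file_name: str = ONNX_ENCODER_NAME,
    decoder_file_name: str = ONNX_DECODER_NAME,
    decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME,
    encoder_subfolder: str = "",
    decoder_subfolder: str = "",
    decoder_with_past_subfolder: str = "",
):
    ...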

Wdyt?

@HuggingFaceDocBuilderDev commented Dec 13, 2022

The documentation is not available anymore as the PR was closed or merged.

@fxmarty (Collaborator) commented Dec 13, 2022

So we can maybe align on:

  • use ONNX files from the subfolders encoder/, decoder/, etc., with hard-coded folder names, if those subfolders are found
  • otherwise use top-level (backward compatible)

Possible directory layouts:

t5_model/ (subfolder="onnx")
    onnx/
        encoder/
        decoder/
        decoder_with_past/
t5_model/ (subfolder="")
    encoder/
    decoder/
    decoder_with_past/
t5_model/ (subfolder="")        (ONNX files at the top level, backward compatible)
t5_model/ (subfolder="onnx")
    onnx/                       (ONNX files directly under onnx/, no per-submodel folders)
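
A rough sketch of that resolution rule (illustrative only, not the PR's actual code):

from pathlib import Path

# Prefer the hard-coded per-submodel folder under `subfolder` if it exists,
# otherwise fall back to the backward-compatible top-level layout.
def resolve_submodel_dir(model_dir: str, submodel: str, subfolder: str = "") -> Path:
    root = Path(model_dir) / subfolder if subfolder else Path(model_dir)
    candidate = root / submodel  # "encoder", "decoder" or "decoder_with_past"
    return candidate if candidate.is_dir() else root

For example, resolve_submodel_dir("t5_model", "encoder", subfolder="onnx") returns t5_model/onnx/encoder when that folder exists, and t5_model/onnx otherwise.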

@NouamaneTazi (Member Author) commented Dec 13, 2022

This should work now:

import shutil
from pathlib import Path

from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BertForSequenceClassification,
    MBartForConditionalGeneration,
)
from transformers.modeling_utils import no_init_weights

from huggingface_hub import HfApi
from optimum.onnxruntime import ORTModelForCausalLM, ORTModelForSeq2SeqLM, ORTModelForSequenceClassification


# model_ckpt = "hf-internal-testing/tiny-bert"
# model_ckpt = "facebook/mbart-large-en-ro"
model_ckpt = "sshleifer/tiny-mbart"
save_path = Path(f"saved_model/{model_ckpt}")
save_path.mkdir(parents=True, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained(model_ckpt, use_auth_token=True)

config = AutoConfig.from_pretrained(model_ckpt, use_auth_token=True)
with no_init_weights():
    model = MBartForConditionalGeneration(config)

# save to local folder
model.save_pretrained(save_path)

model = ORTModelForSeq2SeqLM.from_pretrained(save_path, from_transformers=True)
# save onnx to local folder
model.save_pretrained(save_path / "onnx")

# gives:
# .
#  |-tiny-mbart
#  | |-special_tokens_map.json
#  | |-sentencepiece.bpe.model
#  | |-tokenizer.json
#  | |-tokenizer_config.json
#  | |-onnx
#  | | |-special_tokens_map.json
#  | | |-decoder_with_past_model
#  | | | |-decoder_with_past_model.onnx
#  | | |-sentencepiece.bpe.model
#  | | |-tokenizer.json
#  | | |-decoder_model
#  | | | |-decoder_model.onnx
#  | | |-tokenizer_config.json
#  | | |-encoder_model
#  | | | |-encoder_model.onnx
#  | | |-config.json
#  | |-pytorch_model.bin
#  | |-config.json

@fxmarty (Collaborator) commented Dec 14, 2022

Edit: never mind this comment, this PR changes only save_pretrained. How different is it from #255?

Great! I tried it, and in this case, I get several warnings:

The ONNX file encoder_model/encoder_model.onnx is not a regular name used in optimum.onnxruntime, the ORTModelForConditionalGeneration might not behave as expected.
The ONNX file decoder_model/decoder_model.onnx is not a regular name used in optimum.onnxruntime, the ORTModelForConditionalGeneration might not behave as expected.
The ONNX file decoder_with_past_model/decoder_with_past_model.onnx is not a regular name used in optimum.onnxruntime, the ORTModelForConditionalGeneration might not behave as expected.


@PoodleWang

note: issue open here #605 @PoodleWang

Let me create a new issue for it. They are different.

New issue here: #606

@NouamaneTazi (Member Author) commented Dec 18, 2022

Added some tests for saving/loading from a local folder and from the Hub.
The tests which save to the Hub, such as test_push_ort_model_with_external_data_to_hub, seem to fail when run locally because they save to my personal repo on the Hub, then try to load it from hf-internal-testing. It's because of this part in push_to_hub. I'm wondering if it works fine on the CI and if we should do anything about it? 🤔

Otherwise the PR should be good to merge once all tests pass

To run the external-data-specific tests:

pytest ./tests/onnxruntime/test_modeling.py::ORTModelIntegrationTest -k "external"

Comment on lines 495 to 511
def test_save_seq2seq_model_with_external_data(self):
    with tempfile.TemporaryDirectory() as tmpdirname:
        # randomly initialize large model
        config = AutoConfig.from_pretrained(self.LARGE_ONNX_SEQ2SEQ_MODEL_ID)
        with no_init_weights():
            model = MBartForConditionalGeneration(config)

        # save transformers model to be able to load it with `ORTModel...`
        model.save_pretrained(tmpdirname)

        model = ORTModelForSeq2SeqLM.from_pretrained(tmpdirname, from_transformers=True)
        model.save_pretrained(tmpdirname + "/onnx")

        # Verify config and ONNX exported encoder, decoder and decoder with past are present each in their own folder
        folder_contents = os.listdir(tmpdirname + "/onnx")
        self.assertTrue(CONFIG_NAME in folder_contents)

@NouamaneTazi (Member Author):

This test may slow down the CI a little. We could consider adding @slow decorators for such tests

cc @michaelbenayoun @fxmarty @mht-sharma

@fxmarty (Collaborator) commented Dec 18, 2022

yes, please do!

@NouamaneTazi (Member Author):

This seems to be the only failing test. It probably fails because it tries to load the whole model and convert it to external data when saving. We're getting the error:

worker 'gw0' crashed while running 'tests/onnxruntime/test_modeling.py::ORTModelIntegrationTest::test_save_seq2seq_model_with_external_data'

I'm wondering if we should just use a smaller model instead and use FORCE_ONNX_EXTERNAL_DATA?

Member:

I think it would be good, yes; that way we both test the feature and get a faster (and less OOM-prone) CI.
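
A rough sketch of what the revised test could look like (hypothetical; note that FORCE_ONNX_EXTERNAL_DATA may need to be set before optimum is imported, e.g. when launching pytest):

# Hypothetical test sketch: use a tiny checkpoint and force the external-data format
# instead of exporting a genuinely large (>2GB) model. Assumes FORCE_ONNX_EXTERNAL_DATA=1
# is already set in the environment when optimum is imported.
def test_save_seq2seq_model_with_external_data_forced(self):
    with tempfile.TemporaryDirectory() as tmpdirname:
        model = ORTModelForSeq2SeqLM.from_pretrained("sshleifer/tiny-mbart", from_transformers=True)
        model.save_pretrained(tmpdirname)

        # external data should be written as *.onnx_data next to each exported *.onnx file
        has_external_data = any(
            name.endswith(".onnx_data") for _, _, files in os.walk(tmpdirname) for name in files
        )
        self.assertTrue(has_external_data)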

Comment on lines 332 to 351
onnx_model = onnx.load(str(output), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)

if model_uses_external_data or FORCE_ONNX_EXTERNAL_DATA:
    logger.info("Saving external data to one file...")

    # try to free model memory
    del model
    del onnx_model

    onnx_model = onnx.load(
        str(output), load_external_data=True
    )  # TODO: this will probably be too memory heavy, shall we free `model` memory?
    onnx.save(
        onnx_model,
        str(output),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=output.name + "_data",
        size_threshold=1024 if not FORCE_ONNX_EXTERNAL_DATA else 0,
    )
Contributor:

Should we create a function for this? It would probably be cleaner.

Not sure how tf2onnx handles files >2GB. Could this be used in the export_tensorflow?
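
For instance, a factored-out helper could look roughly like this (a sketch; the function name and placement are made up):

import os

import onnx

# Illustrative helper: reload the exported model together with its tensors and re-save it
# with all external data gathered into a single <model>.onnx_data file.
def save_model_with_external_data(onnx_file: str, force: bool = False) -> None:
    onnx_model = onnx.load(onnx_file, load_external_data=True)
    onnx.save(
        onnx_model,
        onnx_file,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=os.path.basename(onnx_file) + "_data",
        size_threshold=0 if force else 1024,
    )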

@NouamaneTazi (Member Author):

I was testing mainly with PyTorch. It would be better if somebody else made another PR to apply the same modifications to export_tensorflow.

@fxmarty (Collaborator) commented Dec 20, 2022

@NouamaneTazi Do you think it is likely to be merged today or tomorrow? I think it should be in the release.

@NouamaneTazi (Member Author)

Should be ready to merge once all tests pass @fxmarty 🙌

@fxmarty (Collaborator) left a comment

LGTM, just waiting for the tests

@NouamaneTazi merged commit 6da9e1a into huggingface:main on Dec 22, 2022
@fxmarty mentioned this pull request on Dec 22, 2022
Successfully merging this pull request may close these issues.

Saving external data for > 2GB models