Verify memory usage is not prohibitively high in the ONNX export #1012

Closed
fxmarty opened this issue Apr 24, 2023 · 6 comments
Labels: feature-request (New feature or request), onnx (Related to the ONNX export)

@fxmarty (Contributor) commented Apr 24, 2023

Feature request

I have not checked whether the ONNX export e.g. triples the memory usage for decoder models. I believe it should not, but it would be worth making sure we do not use substantially more RAM than vanilla torch.onnx.export, which would make exporting large models difficult.
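
One way to check is to profile peak RSS for both export paths on the same checkpoint; a minimal sketch assuming the memory_profiler package is available (the gpt2 checkpoint, dummy input shapes, and output paths are only illustrative):

from memory_profiler import memory_usage

import torch
from transformers import AutoModelForCausalLM

from optimum.exporters.onnx import main_export


def vanilla_export():
    # Baseline: plain torch.onnx.export of a small decoder checkpoint.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    dummy = {
        "input_ids": torch.ones(2, 8, dtype=torch.long),
        "attention_mask": torch.ones(2, 8, dtype=torch.long),
    }
    torch.onnx.export(
        model,
        (dummy,),
        f="gpt2_vanilla.onnx",
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
    )


def export_with_optimum():
    # Same checkpoint exported through Optimum, validation disabled so only the
    # export itself is measured.
    main_export("gpt2", "gpt2_onnx", task="text-generation", do_validation=False)


print("Peak RSS, torch.onnx.export: %s MiB" % memory_usage(vanilla_export, max_usage=True))
print("Peak RSS, optimum main_export: %s MiB" % memory_usage(export_with_optimum, max_usage=True))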

cc @xenova

Motivation

/

Your contribution

/

@xenova (Contributor) commented May 19, 2023

Encountered another OOM issue for EleutherAI/gpt-neo-1.3B, even with 25 GB of RAM on Google Colab.

Here's the Colab RAM graph (until the runtime was killed):
[image: Colab RAM usage graph]

@fxmarty (Contributor, Author) commented May 31, 2023

@xenova At which point do you OOM? I realize validation actually takes a fair amount of memory, as the model outputs (and especially the past key values) are held in memory and may be large. For the export itself, it does not appear we do worse than vanilla torch.onnx.export.

from memory_profiler import memory_usage

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from optimum.exporters.onnx import main_export


def f():
    # Baseline: vanilla torch.onnx.export of gpt2-large with a small padded batch.
    model = AutoModelForCausalLM.from_pretrained("gpt2-large")

    tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
    tokenizer.add_special_tokens({"pad_token": "a"})

    fake_inp = tokenizer(["This is me", "This is you and me"], padding=True, return_tensors="pt")
    fake_inp = {"input_ids": fake_inp["input_ids"], "attention_mask": fake_inp["attention_mask"]}

    torch.onnx.export(
        model,
        (fake_inp,),
        f="fake.onnx",
        input_names=["input_ids", "attention_mask"],
        output_names=["logits", "past_key_values"],
        dynamic_axes={
            "input_ids": {0: "batch_size"},
            "attention_mask": {0: "batch_size"},
            "logits": {0: "batch_size"},
            "past_key_values": {0: "batch_size"},
        },
    )


def optimum_export():
    # Optimum export of the same checkpoint, with validation disabled.
    main_export("gpt2-large", "gpt2_large_onnx", no_post_process=True, task="text-generation", do_validation=False)


# Profile one of the two export paths (swap in f to profile the vanilla export).
mem_usage = memory_usage(optimum_export)
print("Memory usage (in chunks of .1 seconds): %s" % mem_usage)
print("Maximum memory usage: %s" % max(mem_usage))
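
For comparison, the same wrapper can be pointed at a run with validation enabled; a minimal sketch reusing the imports above (do_validation=True and the output directory name are assumptions for illustration, not taken from the original snippet):

def optimum_export_with_validation():
    # Assumption: do_validation=True re-enables the parity check that holds both
    # the PyTorch and ONNX outputs (notably the past key values) in memory.
    main_export("gpt2-large", "gpt2_large_onnx_val", no_post_process=True, task="text-generation", do_validation=True)


mem_usage_val = memory_usage(optimum_export_with_validation)
print("Maximum memory usage with validation: %s" % max(mem_usage_val))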

@xenova (Contributor) commented May 31, 2023

Will test now 👍 It may be good enough to just skip validation.

@xenova (Contributor) commented May 31, 2023

Can confirm OOM occurs during validation. Skipping validation seems to be a suitable workaround for now.

With validation: [image: Colab RAM usage graph]

Without validation: [image: Colab RAM usage graph]


Let me convert some larger models I've had problems with before (like whisper-large-v2; xenova/transformers.js#102) and I'll get back to you.
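
For reference, the validation-skipping workaround applied to the EleutherAI/gpt-neo-1.3B checkpoint that OOM'd above would look roughly like this; a minimal sketch reusing the main_export call shown earlier (the output directory name is an arbitrary choice):

from optimum.exporters.onnx import main_export

# Sketch of the workaround: export the checkpoint that previously hit OOM,
# skipping the validation step that holds the model outputs in memory.
main_export(
    "EleutherAI/gpt-neo-1.3B",
    "gpt_neo_1.3B_onnx",
    no_post_process=True,
    task="text-generation",
    do_validation=False,
)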

@xenova (Contributor) commented May 31, 2023

@fxmarty (Contributor, Author) commented Jun 15, 2023

Fixed in #1111. Kind of shameful we had this bug...

fxmarty closed this as completed Jun 15, 2023