
Make _fast_init fast again (by surely skipping model weights init)! #26258

Closed · poedator opened this issue Sep 19, 2023 · 12 comments · Fixed by #27709

@poedator (Contributor) commented Sep 19, 2023

I observed that loading a pre-trained model takes rather long, even when loading cached models from a fast SSD. It is especially noticeable when dealing with LLMs with billions of weights.
Apparently, the majority of the time is lost in this section of the code:

# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
# (...)
with ContextManagers(init_contexts):
    model = cls(config, *model_args, **model_kwargs)

The time spent on weight initialization (by torch.nn.init.kaiming_uniform_() and similar) is wasted, because the newly initialized weights are then replaced by the loaded ones. The no_init_weights context manager sets the _init_weights global variable, but it gets ignored by the model's code (tested on Llama_2_7B).
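As far as I understand (a minimal sketch of my own, not transformers code), no_init_weights only affects the model-level _init_weights pass, while the torch.nn module constructors still call reset_parameters(), so the init kernels run anyway when each layer is built:

import torch

# Constructing the layer alone already pays the init cost:
# Linear.__init__ -> reset_parameters() -> torch.nn.init.kaiming_uniform_
layer = torch.nn.Linear(4096, 4096)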

I recently discussed a similar issue with the PEFT team, but there it was easier to solve because the PEFT init code was dealing with a specific torch.nn layer; see huggingface/peft#871 and the linked PRs by @BenjaminBossan. Here we need a model-scale solution.

One option (not a perfectly elegant one) is to temporarily override functions like torch.nn.init.kaiming_uniform_(). It is used in our SpQR repo:

from contextlib import contextmanager

import torch


@contextmanager
def suspend_nn_inits():
    skip = lambda *args, **kwargs: None
    saved_inits = torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_  # saving
    torch.nn.init.kaiming_uniform_ = torch.nn.init.uniform_ = torch.nn.init.normal_ = skip  # replacing
    try:
        yield
    finally:
        torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_ = saved_inits  # restoring
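For illustration, a minimal usage sketch (the model name is just an example, and it assumes every weight is present in the checkpoint so nothing is left uninitialized):

from transformers import AutoModel

# The default init kernels are skipped; the checkpoint weights are loaded on top.
with suspend_nn_inits():
    model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")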

but there may be better approaches using some native torch tools.

I'd be glad to contribute a PR with the maintainers' blessing. Summoning @younesbelkada.

System Info

A100-80G + SSD + mucho RAM and kernels.

Who can help?

@younesbelkada

Reproduction

Load a model and measure the timing for this line.

Expected behavior

faster loading

@BenjaminBossan (Member)

The no_init_weights context manager sets _init_weights global variable, but it gets ignored by model's code (tested on Llama_2_7B).

Interesting, could you please describe how you tested this? This sounds like a bug.

@poedator (Contributor, Author)

Interesting, could you please describe how you tested this? This sounds like a bug.

Hi @BenjaminBossan,

This is how to test this slow-loading issue:
Select a model large enough for the effect to be noticeable. I tested with meta-llama/Llama-2-7b-hf: load it with AutoModel.from_pretrained(), then delete it; this fills the model cache.

Then try some or all of these:

  • Load it again and notice the time that passes before the Loading checkpoint shards: progress bar appears. Normally it should be a few seconds or less.
  • Compare the overall command run time with the Loading checkpoint shards: time. In my case it is 41s vs 2s. What takes the other 39s, if the model is already cached on the SSD?
  • Run AutoModel.from_pretrained() with a profiler and see that the uniform_ (i.e. weight-init) calls take most of the time, even though they are not needed for from_pretrained().
  • Try loading the model with weight inits disabled (using the context manager; see the notebook). In my case it reduced Llama2-7B loading time 10x (from 41s to 4s).

See my testing notebook as a gist here: https://gist.github.com/poedator/792d6c7528a1bc5a84acb550268777ed
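For the profiler step, a rough sketch of what I mean (the profiler setup here is my own illustration, not taken from the gist):

from torch.profiler import profile, ProfilerActivity
from transformers import AutoModel

with profile(activities=[ProfilerActivity.CPU]) as prof:
    model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")

# Init ops such as aten::uniform_ and aten::normal_ should dominate this table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))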

@BenjaminBossan (Member)

Thanks for providing the context and notebook. I could replicate your results and also confirmed that the model produces the same output in both cases. This smells like a bug to me, maybe @ArthurZucker can take a look.

@ArthurZucker (Collaborator)

definitely interesting, I'll have a look!

@younesbelkada (Contributor)

@poedator thanks a lot for the deep investigation. Do you observe the same behaviour with low_cpu_mem_usage=True? Looking at the gist, it seems you are calling from_pretrained without any additional arguments; we should maybe start thinking of using that argument as the default.
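For reference, the variant I mean is simply the following (a sketch; it requires accelerate to be installed):

from transformers import AutoModel

# Parameters are created on the meta device and the checkpoint weights are loaded
# directly, so no throwaway init kernels run.
model = AutoModel.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    low_cpu_mem_usage=True,
)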
I also went through the SpQR repository you shared. I have seen some community interest in supporting it natively in the HF ecosystem. I have not had a deep look into the repository yet, but I wanted to ask whether you think it is possible, design-wise, to integrate it into transformers? cc @SunMarc FYI

@poedator (Contributor, Author) commented Sep 25, 2023

@younesbelkada,
Whatever is behind low_cpu_mem_usage=True may be a good basis for the solution. I knew about it but hesitated to use it, because it does more magic than that (at least that was my impression from reading the docs). Please see how much of the low_cpu_mem_usage=True functionality can be included in the default options. Hopefully it is a small fix.

Thank you for your interest in supporting SpQR in the HF ecosystem. Let me discuss with my teammates the best way to do this, and then I will get back to you.

@jph00 commented Oct 10, 2023

One possible solution is mentioned here: #18505

@yuanenming

I ran into the same issue. I also have another specific scenario, where I want to randomly initialize a large model for debugging, so I just want a very fast initialization.

I tried:

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained(model_name)
model = AutoModelForCausalLM.from_config(config)

I found that it is even slower than just loading the weights:

model = AutoModelForCausalLM.from_pretrained(model_name, _fast_init=True, low_cpu_mem_usage=True)

So I wonder whether there is a way to quickly initialize a very large model (without running any initialization algorithm) using from_config?

Thank you very much!
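For pure shape or throughput debugging, I suspect something like this meta-device sketch might be enough (an assumption on my side, requiring PyTorch >= 2.0; note that the materialized weights are uninitialized memory rather than properly random):

import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained(model_name)

# No allocation and no init kernels: parameters are created on the meta device.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

# Materialize real (but uninitialized) storage on the target device.
model = model.to_empty(device="cuda")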

@ArthurZucker (Collaborator)

Ouch sorry about that! Was off for a bit, and it's planned! Will try to open a draft PR asap

@ArthurZucker (Collaborator)

Update 🤗
I'll tackle this, as I can indeed reproduce it, and although we have the low_cpu_mem_usage flag that requires accelerate, this seems like somewhat low-hanging fruit. We have to make sure the weights that are missing from the state dict (non-persistent buffers, etc.) still get initialized.

@pacman100 (Contributor)

On the main branch of Transformers, I observe the following:

  1. low_cpu_mem_usage should resolve the issue, coupled with _fast_init, which is True by default.
  2. low_cpu_mem_usage internally calls accelerate's init_empty_weights, which puts the weights on the meta device, so reset_parameters() becomes a no-op. If include_buffers=True, it directly uses the with torch.device("meta") context manager, as suggested by Horace in the other linked issue.
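A minimal sketch of point 2 (assuming accelerate is installed; the model name is just an example):

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

# Parameters are registered on the meta device, so reset_parameters() has nothing
# real to initialize.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

print(next(model.parameters()).device)  # meta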


@ArthurZucker (Collaborator)

The goal is to still have fast init without accelerate.
