Improve init_empty_weights to override tensor constructor #699

Merged

thomasw21 merged 6 commits into main from thomas/override_tensor_constructor
Sep 14, 2022

Conversation

@thomasw21
Contributor

Summary

init_empty_weights actually constructs tensors on CPU and then moves them to the meta device. Instead, we propose constructing tensors directly on the meta device by overriding the default constructors. This is inspired by https://github.com/microsoft/DeepSpeed/blob/c199edac8210e730acfd004c6e2bc3a98c0db903/deepspeed/utils/init_on_device.py and results in a faster loading mechanism when using the init_empty_weights context manager.
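
For illustration, a minimal sketch of that kind of constructor override could look like the following (the name and the exact mechanism here are illustrative, not the code in this PR):

from contextlib import contextmanager

import torch


@contextmanager
def tensors_on_meta():
    # Illustrative sketch: patch torch.empty (used by most nn.Module
    # constructors) so tensors are created directly on the meta device,
    # without allocating or initializing any CPU memory.
    old_empty = torch.empty

    def empty_on_meta(*args, **kwargs):
        kwargs["device"] = torch.device("meta")
        return old_empty(*args, **kwargs)

    torch.empty = empty_on_meta
    try:
        yield
    finally:
        # Always restore the original constructor when leaving the context.
        torch.empty = old_empty

Inside such a context, constructing e.g. torch.nn.Linear(10, 10) should yield parameters on the meta device, so no real memory is touched during model construction.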

Additionally, we override the loading mechanism to return an empty dictionary, as there's no reason to read the checkpoint since everything is on meta (this is a hack, as map_location="meta" doesn't work yet). Not sure if that's considered too hacky to be integrated into accelerate.
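
The torch.load part can be sketched along these lines (again an illustrative sketch rather than the exact code in the PR); since every parameter ends up on meta anyway, the state dict that would be read from disk is never needed:

from contextlib import contextmanager

import torch


@contextmanager
def skip_torch_load():
    # Illustrative hack: skip reading the checkpoint from disk and hand
    # back an empty state dict, since all weights live on meta anyway.
    old_load = torch.load

    def load_nothing(*args, **kwargs):
        return {}

    torch.load = load_nothing
    try:
        yield
    finally:
        torch.load = old_load

This hack is what accounts for the extra drop in the "+ torch.load hack" timing below, at the cost of monkey-patching a core torch function.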

Running the following benchmark gets faster:

from timeit import timeit

from accelerate import init_empty_weights
from transformers import AutoModel

def init_empty_model():
    with init_empty_weights():
        AutoModel.from_pretrained("gpt2")

def main():
    print(timeit(init_empty_model, number=10))

if __name__ == "__main__":
    main()
without this PR: 16.627875665999998
with this PR: 6.471049583
with this PR (+ torch.load hack): 4.145832625

@thomasw21 thomasw21 requested a review from sgugger September 14, 2022 13:08
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Sep 14, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger
Collaborator

sgugger commented Sep 14, 2022

I don't really understand why this is needed: why load a pretrained model inside the context manager and complain it takes time?

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("gpt2")

with init_empty_weights():
    model = AutoModel.from_config(config)

is way faster than 6s

@thomasw21
Contributor Author

Hum, it's doing it 10 times, so ~0.6 sec per load. Benchmarking your solution gives the same order of magnitude: 14.567796499999996 (I think the difference is that using the config avoids reading the checkpoint from disk).

Though your workaround removes the need to override torch.load, and especially the hack that I introduced.

Using config + this PR: 4.1660165419999995

@thomasw21
Contributor Author

Also, I'm not sure why this wasn't detected, but the test tests/test_big_modeling.py::BigModelingTester::test_init_empty_weights passes on my Mac, contrary to CI... Moving back to draft as I need to figure this one out.

@thomasw21 thomasw21 marked this pull request as draft September 14, 2022 13:34
@thomasw21 thomasw21 marked this pull request as ready for review September 14, 2022 14:02
@thomasw21
Contributor Author

Actually, if we activate this feature only when include_buffers=True (I'm guessing the assumption is that all PyTorch tensors are expected to be on meta), then that should be fine. Roughly, the gating could look like the sketch below.
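
A rough sketch of that gating (illustrative only, not the merged implementation):

from contextlib import contextmanager

import torch


@contextmanager
def init_empty_weights_sketch(include_buffers: bool = False):
    # Sketch of the gating discussed above, not the merged accelerate code.
    if not include_buffers:
        # Buffers are still expected on CPU, so leave the constructors alone
        # (the older create-on-CPU-then-move-to-meta path would run instead).
        yield
        return
    # With include_buffers=True every tensor is expected on meta, so the
    # constructor override is safe to enable.
    old_empty = torch.empty

    def empty_on_meta(*args, **kwargs):
        kwargs["device"] = torch.device("meta")
        return old_empty(*args, **kwargs)

    torch.empty = empty_on_meta
    try:
        yield
    finally:
        torch.empty = old_empty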

Collaborator

@sgugger sgugger left a comment

Thanks for the update!
