
[lazyinit] combine lazy tensor with dtensor #3204

Merged 22 commits into hpcaitech:main on Mar 23, 2023

Conversation

@ver217 (Member) commented on Mar 22, 2023

📌 Checklist before creating the PR

  • I have created an issue for this PR for traceability
  • The title follows the standard format: [doc/gemini/tensor/...]: A concise description
  • I have added relevant tags if possible for us to better distinguish different PRs

🚨 Issue number

Link this PR to your issue with words like fixed to automatically close the linked issue upon merge

e.g. fixed #1234, closed #1234, resolved #1234

Closes #3148, closes #3149

📝 What does this PR do?


Combine lazy tensor with DTensor. The context now provides a distribute() method, which shards each tensor according to a target layout.

Usage:

```python
ctx = LazyInitContext()
with ctx:
    deferred_model = model_fn()  # parameters/buffers are created as lazy tensors
layout_dict = generate_layout_dict(deferred_model, device_mesh)
ctx.distribute(deferred_model, layout_dict, verbose=True)  # shard using the target layouts
```

Other important changes:

  1. LazyTensor.materialize() is now in-place.
  2. Setting a LazyTensor's .data directly (e.g. x.data = torch.empty(10)) no longer triggers early initialization.
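The two behavior changes above can be illustrated with a minimal, hypothetical sketch (plain Python, not the actual Colossal-AI `LazyTensor` implementation):

```python
class LazyTensor:
    """Hypothetical sketch of the two semantics changed in this PR
    (not the real Colossal-AI class)."""

    def __init__(self, factory):
        self._factory = factory  # deferred constructor, e.g. lambda: [0.0] * 10
        self._data = None        # concrete storage, filled on materialize()

    def materialize(self):
        # Change 1: materialization happens in place -- the same object
        # ends up holding the concrete data; no new tensor is returned.
        if self._data is None:
            self._data = self._factory()
        return self

    @property
    def data(self):
        return self._data

    @data.setter
    def data(self, value):
        # Change 2: assigning .data just records the value; it no longer
        # forces early initialization via the factory.
        self._data = value


x = LazyTensor(lambda: [0.0] * 10)
x.data = [1.0, 2.0]          # does not trigger early initialization
assert x.materialize() is x  # in-place: same object after materialization
```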

Known issue: the embeddings of many Hugging Face models cannot be lazily initialized.

We also tested it against our model zoo. Here is the report:

| model class | param lazy rate | buffer lazy rate | non-lazy numel | non-lazy numel percent |
| --- | --- | --- | --- | --- |
| torchvision_alexnet | 16/16 | 0/0 | 0.000 M | 0.00% |
| torchvision_densenet | 364/364 | 363/363 | 0.000 M | 0.00% |
| torchvision_efficientnet_b | 213/213 | 147/147 | 0.000 M | 0.00% |
| torchvision_googlenet | 187/187 | 177/177 | 0.000 M | 0.00% |
| torchvision_inception_v | 292/292 | 288/288 | 0.000 M | 0.00% |
| torchvision_mobilenet_v | 158/158 | 156/156 | 0.000 M | 0.00% |
| torchvision_mobilenet_v | 142/142 | 102/102 | 0.000 M | 0.00% |
| torchvision_mnasnet | 158/158 | 156/156 | 0.000 M | 0.00% |
| torchvision_resnet | 62/62 | 60/60 | 0.000 M | 0.00% |
| torchvision_regnet_x_ | 215/215 | 213/213 | 0.000 M | 0.00% |
| torchvision_resnext | 161/161 | 159/159 | 0.000 M | 0.00% |
| torchvision_shufflenet_v | 170/170 | 168/168 | 0.000 M | 0.00% |
| torchvision_squeezenet | 52/52 | 0/0 | 0.000 M | 0.00% |
| torchvision_vgg | 22/22 | 0/0 | 0.000 M | 0.00% |
| torchvision_wide_resnet | 161/161 | 159/159 | 0.000 M | 0.00% |
| torchvision_vit_b_ | 152/152 | 0/0 | 0.000 M | 0.00% |
| torchvision_convnext_base | 344/344 | 0/0 | 0.000 M | 0.00% |
| torchvision_swin_s | 173/173 | 0/12 | 0.027 M | 0.10% |
| torchvision_efficientnet_v | 452/452 | 330/330 | 0.000 M | 0.00% |
| diffusers_auto_encoder_kl | 92/92 | 0/0 | 0.000 M | 0.00% |
| diffusers_vq_model | 93/93 | 0/0 | 0.000 M | 0.00% |
| diffusers_clip_model | 398/398 | 2/2 | 0.000 M | 0.00% |
| diffusers_clip_text_model | 196/196 | 1/1 | 0.000 M | 0.00% |
| diffusers_clip_vision_model | 199/199 | 1/1 | 0.000 M | 0.00% |
| diffusers_unet | 432/432 | 0/0 | 0.000 M | 0.00% |
| timm_resnet | 263/263 | 213/213 | 0.000 M | 0.00% |
| timm_beit | 199/199 | 24/24 | 0.000 M | 0.00% |
| timm_cait | 476/476 | 0/0 | 0.000 M | 0.00% |
| timm_convmixer | 262/262 | 195/195 | 0.000 M | 0.00% |
| timm_efficientnetv | 649/649 | 471/471 | 0.000 M | 0.00% |
| timm_resmlp | 150/150 | 0/0 | 0.000 M | 0.00% |
| timm_vision_transformer | 152/152 | 0/0 | 0.000 M | 0.00% |
| timm_deit | 155/155 | 0/0 | 0.000 M | 0.00% |
| timm_beitv | 199/199 | 24/24 | 0.000 M | 0.00% |
| timm_coat | 152/152 | 0/0 | 0.000 M | 0.00% |
| timm_deit | 176/176 | 0/0 | 0.000 M | 0.00% |
| timm_eca_nfnet | 128/185 | 0/0 | 20.765 M | 90.18% |
| timm_efficientformer | 181/181 | 99/100 | 0.002 M | 0.02% |
| timm_ese_vovnet | 93/93 | 69/69 | 0.000 M | 0.00% |
| timm_gmixer_ | 102/150 | 0/0 | 7.633 M | 63.02% |
| timm_gmlp_b | 306/306 | 0/0 | 0.000 M | 0.00% |
| timm_hardcorenas_a | 138/138 | 102/102 | 0.000 M | 0.00% |
| timm_hrnet_w | 279/279 | 273/273 | 0.000 M | 0.00% |
| timm_inception_v | 284/284 | 282/282 | 0.000 M | 0.00% |
| timm_mixer_b | 150/150 | 0/0 | 0.000 M | 0.00% |
| timm_nf_ecaresnet | 243/347 | 0/0 | 40.431 M | 95.16% |
| timm_nf_regnet_b | 174/228 | 0/0 | 3.946 M | 47.21% |
| timm_regnetv_ | 293/293 | 198/198 | 0.000 M | 0.00% |
| timm_skresnet | 118/118 | 108/108 | 0.000 M | 0.00% |
| timm_tnt_b_patch | 351/351 | 0/0 | 0.000 M | 0.00% |
| timm_wide_resnet | 161/161 | 159/159 | 0.000 M | 0.00% |
| timm_convit | 180/180 | 0/0 | 0.000 M | 0.00% |
| timm_dm_nfnet | 176/233 | 0/0 | 44.327 M | 65.02% |
| timm_convnext | 344/344 | 0/0 | 0.000 M | 0.00% |
| timm_vgg | 22/22 | 0/0 | 0.000 M | 0.00% |
| timm_dpn | 217/217 | 216/216 | 0.000 M | 0.00% |
| timm_densenet | 364/364 | 363/363 | 0.000 M | 0.00% |
| timm_rexnet | 227/227 | 186/186 | 0.000 M | 0.00% |
| timm_swin_transformer | 329/329 | 11/35 | 0.055 M | 0.07% |
| transformers_albert | 24/25 | 2/2 | 3.662 M | 94.29% |
| transformers_albert_for_pretraining | 31/32 | 2/2 | 3.662 M | 93.21% |
| transformers_albert_for_masked_lm | 27/28 | 2/2 | 3.662 M | 93.59% |
| transformers_albert_for_sequence_classification | 26/27 | 2/2 | 3.662 M | 94.28% |
| transformers_albert_for_token_classification | 24/25 | 2/2 | 3.662 M | 94.67% |
| transformers_albert_for_question_answering | 24/25 | 2/2 | 3.662 M | 94.67% |
| transformers_albert_for_multiple_choice | 26/27 | 2/2 | 3.662 M | 94.29% |
| transformers_bert | 38/39 | 2/2 | 3.726 M | 91.81% |
| transformers_bert_for_pretraining | 45/46 | 2/2 | 3.726 M | 90.79% |
| transformers_bert_lm_head_model | 41/42 | 2/2 | 3.726 M | 91.15% |
| transformers_bert_for_masked_lm | 41/42 | 2/2 | 3.726 M | 91.15% |
| transformers_bert_for_sequence_classification | 40/41 | 2/2 | 3.726 M | 91.80% |
| transformers_bert_for_token_classification | 38/39 | 2/2 | 3.726 M | 92.16% |
| transformers_bert_for_next_sentence | 40/41 | 2/2 | 3.726 M | 91.80% |
| transformers_bert_for_mcq | 40/41 | 2/2 | 3.726 M | 91.81% |
| transformers_gpt | 28/28 | 4/4 | 0.000 M | 0.00% |
| transformers_gpt_lm | 28/28 | 4/4 | 0.000 M | 0.00% |
| transformers_gpt_double_heads | 30/30 | 4/4 | 0.000 M | 0.00% |
| transformers_gpt_for_token_classification | 30/30 | 4/4 | 0.000 M | 0.00% |
| transformers_gpt_for_sequence_classification | 29/29 | 4/4 | 0.000 M | 0.00% |
| transformers_opt | 35/36 | 0/0 | 6.137 M | 76.52% |
| transformers_opt_for_causal_lm | 35/36 | 0/0 | 6.137 M | 76.52% |
| transformers_t | 47/47 | 0/0 | 0.000 M | 0.00% |
| transformers_t | 47/47 | 0/0 | 0.000 M | 0.00% |
| transformers_t | 19/19 | 0/0 | 0.000 M | 0.00% |
| torchaudio_conformer | 120/120 | 12/12 | 0.000 M | 0.00% |
| torchaudio_convtasnet | 343/343 | 0/0 | 0.000 M | 0.00% |
| torchaudio_deepspeech | 18/18 | 0/0 | 0.000 M | 0.00% |
| torchaudio_emformer | 64/64 | 0/0 | 0.000 M | 0.00% |
| torchaudio_wav | 24/24 | 0/0 | 0.000 M | 0.00% |
| torchaudio_wav | 22/22 | 0/0 | 0.000 M | 0.00% |
| torchaudio_wavernn | 36/36 | 15/15 | 0.000 M | 0.00% |
| torchaudio_tacotron | 60/60 | 24/24 | 0.000 M | 0.00% |
| deepfm_densearch | 4/4 | 0/0 | 0.000 M | 0.00% |
| deepfm_interactionarch | 2/2 | 0/0 | 0.000 M | 0.00% |
| deepfm_overarch | 2/2 | 0/0 | 0.000 M | 0.00% |
| deepfm_simpledeepfmnn | 10/10 | 0/0 | 0.000 M | 0.00% |
| deepfm_sparsearch | 2/2 | 0/0 | 0.000 M | 0.00% |
| dlrm | 10/10 | 0/0 | 0.000 M | 0.00% |
| dlrm_densearch | 4/4 | 0/0 | 0.000 M | 0.00% |
| dlrm_interactionarch | 0/0 | 0/0 | 0.000 M | 0.00% |
| dlrm_overarch | 4/4 | 0/0 | 0.000 M | 0.00% |
| dlrm_sparsearch | 2/2 | 0/0 | 0.000 M | 0.00% |
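For reference, the per-model metrics in the table can be computed along these lines. This is a minimal sketch; `lazy_report` and its `(numel, is_lazy)` input format are illustrative, not the actual test-harness API:

```python
def lazy_report(params):
    """Summarize lazy-init coverage for one model.

    params: list of (numel, is_lazy) pairs, one per parameter (or buffer).
    Returns the lazy rate ("lazy/total" count), the number of elements
    that could not be lazily initialized (in millions), and that number
    as a percentage of all elements.
    """
    total = len(params)
    lazy = sum(1 for _, is_lazy in params if is_lazy)
    non_lazy_numel = sum(n for n, is_lazy in params if not is_lazy)
    total_numel = sum(n for n, _ in params) or 1  # avoid 0/0 for empty models
    return {
        "lazy_rate": f"{lazy}/{total}",
        "non_lazy_numel_m": non_lazy_numel / 1e6,
        "non_lazy_percent": 100.0 * non_lazy_numel / total_numel,
    }
```

A model where one of two parameters (holding 75% of the elements) resists lazy init would report a lazy rate of 1/2 and a non-lazy numel percent of 75.00%, matching the shape of the rows above.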

Unit tests are skipped in CI until we upgrade torch to 1.12; we ran them locally:

(screenshot: local unit-test results)

💥 Checklist before requesting a review

  • I have linked my PR to an issue (instruction)
  • My issue clearly describes the problem/feature/proposal, with diagrams/charts/table/code if possible
  • I have performed a self-review of my code
  • I have added thorough tests.
  • I have added docstrings for all the functions/methods I implemented

⭐️ Do you enjoy contributing to Colossal-AI?

  • 🌝 Yes, I do.
  • 🌚 No, I don't.

Tell us more if you don't enjoy contributing to Colossal-AI.

@ver217 ver217 added enhancement New feature or request lazyinit Lazy initialization labels Mar 22, 2023
@FrankLeeeee FrankLeeeee merged commit f8289d4 into hpcaitech:main Mar 23, 2023
@ver217 ver217 deleted the feature/lazyinit-dist branch March 23, 2023 02:53
Linked issues closed by this PR: [lazyinit] add verification for distributed cases; [lazyinit] combine lazy tensor with dtensor.