
[shardformer] update hybrid parallel plugin and fix bugs #4612

Merged

merged 35 commits into main from feature/shardformer on Sep 5, 2023

Conversation

ver217
Member

@ver217 ver217 commented Sep 4, 2023

📌 Checklist before creating the PR

  • I have created an issue for this PR for traceability
  • The title follows the standard format: [doc/gemini/tensor/...]: A concise description
  • I have added relevant tags if possible for us to better distinguish different PRs

🚨 Issue number

Link this PR to your issue with words like fixed to automatically close the linked issue upon merge

e.g. fixed #1234, closed #1234, resolved #1234

Closes #4583, closes #4596

📝 What does this PR do?

Summarize your work here.
If you have any plots/diagrams/screenshots/tables, please attach them here.

  1. Implement all basic features of the hybrid parallel plugin, including TP, PP, SP and (ZeRO) DP (a usage sketch follows below).
  2. Fix bugs.
  3. Implement checkpoint I/O for hybrid parallelism.
  4. Update the BERT fine-tuning example.
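
For context, a minimal sketch of driving the plugin through the booster API. The constructor argument names below (tp_size, pp_size, num_microbatches, zero_stage, enable_sequence_parallelism) are assumptions inferred from this PR's feature list, not a guaranteed match for the merged signature:

```python
# Hedged sketch: drive the new HybridParallelPlugin via the Booster API.
# Argument names are assumptions based on the PR description.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,                         # tensor parallel degree
    pp_size=2,                         # pipeline parallel degree
    num_microbatches=4,                # microbatches for the pipeline schedule
    zero_stage=1,                      # ZeRO stage for the DP group
    enable_sequence_parallelism=True,  # SP on top of TP
)
booster = Booster(plugin=plugin)

# model/optimizer/criterion/dataloader/lr_scheduler are user-defined objects
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
    model, optimizer, criterion, dataloader, lr_scheduler
)
```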

💥 Checklist before requesting a review

  • I have linked my PR to an issue (instruction)
  • My issue clearly describes the problem/feature/proposal, with diagrams/charts/table/code if possible
  • I have performed a self-review of my code
  • I have added thorough tests.
  • I have added docstrings for all the functions/methods I implemented

⭐️ Do you enjoy contributing to Colossal-AI?

  • 🌝 Yes, I do.
  • 🌚 No, I don't.

Tell us more if you don't enjoy contributing to Colossal-AI.

FoolPlayer and others added 30 commits August 16, 2023 15:41
* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)

* [sequence parallel] add sequence parallel linear col/row support (#4336)

* add sequence parallel linear col/row support

* add annotation

* add annotation

* add support for gpt2 fused qkv linear layer

* support sequence parallel in GPT2

* add docstring and note

* add requirements

* remove unused flash-attn

* modify flash attn test

* modify flash attn setting

* modify flash attn code

* add assert before divide, rename forward function

* [shardformer/test] fix gpt2 test with seq-parallel
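
A rough sketch of the idea behind the sequence-parallel column linear commits above (not the merged code): activations arrive sharded along the sequence dimension and are all-gathered just before the column-sharded matmul. Function and variable names here are hypothetical.

```python
# Hedged sketch of a sequence-parallel column linear forward: each rank holds
# a shard of the activation along the sequence dim and a column shard of the
# weight; the full sequence is rebuilt by all-gather before the matmul.
import torch
import torch.distributed as dist

def seq_parallel_col_linear_forward(x_local, weight_col_shard, group):
    # x_local: (seq_len // world_size, batch, hidden_in)
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(x_local) for _ in range(world_size)]
    dist.all_gather(shards, x_local, group=group)
    x_full = torch.cat(shards, dim=0)  # (seq_len, batch, hidden_in)
    # weight_col_shard: (hidden_in, hidden_out // world_size)
    return torch.matmul(x_full, weight_col_shard)
```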

* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)

* overlap gather input / grad computing during col backward

* modify test for overlap

* simplify code

* fix code and modify cuda stream synchronize

* [shardformer/sequence parallel] polish code
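
A sketch of the overlap trick from #4401 as described above: the backward's all-gather of the saved input can run asynchronously while the input gradient, which needs only the weight, is computed. Names are hypothetical.

```python
# Hedged sketch of overlapping the input all-gather with grad computation in
# the column-linear backward. grad_input depends only on the weight, so it
# can proceed while the async all-gather is in flight.
import torch
import torch.distributed as dist

def col_linear_backward_overlapped(grad_output, x_local, weight_col_shard, group):
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(x_local) for _ in range(world_size)]
    handle = dist.all_gather(shards, x_local, group=group, async_op=True)

    # Overlapped work: grad wrt the input does not need the gathered input.
    grad_input = torch.matmul(grad_output, weight_col_shard.t())

    handle.wait()  # gathered input is needed only for the weight gradient
    x_full = torch.cat(shards, dim=0)
    grad_weight = torch.matmul(
        x_full.flatten(0, 1).t(), grad_output.flatten(0, 1)
    )
    return grad_input, grad_weight
```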
* support DDP for HybridPlugin/add tp+dp tests

* add docstring for HybridParallelPlugin
* [test] remove cpu marker

* [test] remove gpu marker

* [test] update pytest markers

* [ci] update unit test ci
* support interleaved pipeline

* fix unit test

* remove virtual stage test in stage mgr

* add dropped type hint and updated bwd
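
For the interleaved pipeline commits above, a small illustration of the chunk-to-stage assignment that interleaving implies (illustrative numbers, not the merged scheduler):

```python
# Hedged sketch: in an interleaved schedule each pipeline stage owns several
# non-contiguous model chunks ("virtual stages"), so stage i with p stages
# and v chunks holds chunks i, i + p, i + 2p, ...
def chunks_for_stage(stage: int, num_stages: int, num_model_chunks: int) -> list:
    return [stage + k * num_stages for k in range(num_model_chunks)]

# e.g. 4 stages, 2 chunks per stage: stage 0 runs chunks 0 and 4,
# shrinking pipeline bubbles compared with one contiguous block per stage.
assert chunks_for_stage(0, num_stages=4, num_model_chunks=2) == [0, 4]
```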
…tp (#4460)

* support gpt2 seq parallel with pp/dp/tp

* fix a bug when waiting for stream done

* delete unused gpt2_seq file
[shardformer] bloom support sequence parallel
* [shardformer] bert support sequence parallel

* [shardformer] bert support sequence parallel

* [shardformer] bert support sequence parallel
* add some base tests and policies

* finish whisper base model

* add conditional generation

* finish basic tests

* whisper

* finish whisper

* finish whisper

* del useless whisper test

* fix

* add argmin to replace

* finish revision
* support tp+zero/input type cast for hybridplugin

* add tp+zero tests

* fix bucket arguments
* [shardformer] chatglm support sequence parallel

* fix
[shardformer] chatglm support sequence parallel and some fix. (#4498)

* [shardformer] chatglm support sequence parallel

* fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* activate checks
* [shardformer] chatglm support sequence parallel

* fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* [shardformer] jit fused fix

* activate checks

* [Test] test ci

* test ci

* test ci

* test ci

* test ci

* test ci

* test ci

* fix
…lelPlugin (#4506)

* add APIs

* implement save_sharded_model

* add test for hybrid checkpointio

* implement naive loading for sharded model

* implement efficient sharded model loading

* open a new file for hybrid checkpoint_io

* small fix

* fix circular importing

* fix docstring

* arrange arguments and apis

* small fix
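
A hedged sketch of how the sharded-model checkpoint I/O added in #4506 is likely driven from user code; the keyword names (shard, size_per_shard) are assumptions:

```python
# Hedged sketch: saving/loading a sharded model checkpoint via the booster.
# With shard=True the checkpoint is split into size-bounded weight files plus
# an index, and loading reads only the shards a rank actually needs.
booster.save_model(model, "ckpt/", shard=True, size_per_shard=1024)  # MB per file
booster.load_model(model, "ckpt/")
```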
* pause

* finish pp+zero1

* Update test_shard_vit.py
[shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardconfig (#4516)

* fix overlap bug and support bert, add overlap as an option in shardconfig

* support overlap for chatglm and bloom
* add overlap support for gpt2

* remove unused code

* remove unused code
* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix
* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1
[shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)

* implement sharded optimizer saving

* add more param info

* finish implementation of sharded optimizer saving

* fix bugs in optimizer sharded saving

* add pp+zero test

* param group loading

* greedy loading of optimizer

* fix bug when loading

* implement optimizer sharded saving

* add optimizer test & arrange checkpointIO utils

* fix gemini sharding state_dict

* add verbose option

* add loading of master params

* fix typehint

* fix master/working mapping in fp16 amp
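
Correspondingly, a hedged sketch for the sharded optimizer checkpointing of #4540; the keyword names are again assumptions, and "greedy loading" above suggests shard files are scanned once with states dispatched to locally owned parameters:

```python
# Hedged sketch: sharded optimizer checkpointing through the booster.
# Saving also records master-param info for fp16 AMP (see the master/working
# mapping fix above); loading restores param groups plus per-param states.
booster.save_optimizer(optimizer, "optim_ckpt/", shard=True, size_per_shard=1024)
booster.load_optimizer(optimizer, "optim_ckpt/")
```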
…arallelPlugin (#4575)

* hybrid plugin support huggingface from_pretrained

* add huggingface compatibility tests

* add folder cleaning

* fix bugs
* pytree test

* test bert

* test bert

* test bert

* revise

* add register

* add register
…4584)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] fix epoch change

* [shardformer] broadcast add pp group

* [shardformer] fix opt test hanging

* fix

* test

* test

* [shardformer] zero1+pp and the corresponding tests (#4517)

* pause

* finish pp+zero1

* Update test_shard_vit.py

* [shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardconfig (#4516)

* fix overlap bug and support bert, add overlap as an option in shardconfig

* support overlap for chatglm and bloom

* [shardformer] fix emerged bugs after updating transformers (#4526)

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] Add overlap support for gpt2 (#4535)

* add overlap support for gpt2

* remove unused code

* remove unused code

* [shardformer] support pp+tp+zero1 tests (#4531)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] fix submodule replacement bug when enabling pp (#4544)

* [shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)

* implement sharded optimizer saving

* add more param info

* finish implementation of sharded optimizer saving

* fix bugs in optimizer sharded saving

* add pp+zero test

* param group loading

* greedy loading of optimizer

* fix bug when loading

* implement optimizer sharded saving

* add optimizer test & arrange checkpointIO utils

* fix gemini sharding state_dict

* add verbose option

* add loading of master params

* fix typehint

* fix master/working mapping in fp16 amp

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] fix epoch change

* [shardformer] broadcast add pp group

* rebase feature/shardformer

* update pipeline

* [shardformer] fix

* [shardformer] fix

* [shardformer] bert finetune fix

* [shardformer] add all_reduce operation to loss
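
The all-reduce-on-loss commit presumably averages the scalar loss over the data-parallel group so every rank reports the same value; a hedged sketch, where dp_group is a hypothetical handle to the plugin's DP process group:

```python
# Hedged sketch: average the loss across data-parallel ranks so logging and
# evaluation see one consistent number. dp_group is hypothetical.
import torch.distributed as dist

dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=dp_group)
loss = loss / dist.get_world_size(dp_group)
```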

* [shardformer] make compatible with pytree.
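
The pytree compatibility fix most likely registers HuggingFace ModelOutput classes with torch's pytree utilities so the pipeline schedule can flatten and rebuild stage outputs; a sketch using the private, version-dependent _register_pytree_node API:

```python
# Hedged sketch: teach torch's pytree to (un)flatten a HuggingFace ModelOutput
# so pipeline schedules can route its tensors between stages. The private
# _register_pytree_node API varies across torch versions; treat as a sketch.
import torch.utils._pytree as pytree
from transformers.modeling_outputs import BaseModelOutputWithPast

pytree._register_pytree_node(
    BaseModelOutputWithPast,
    lambda out: (list(out.values()), list(out.keys())),                       # flatten
    lambda values, keys: BaseModelOutputWithPast(**dict(zip(keys, values))),  # unflatten
)
```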

* [shardformer] disable tp

* [shardformer] add 3d plugin to ci test

* [shardformer] update num_microbatches to None

* [shardformer] update microbatchsize

* [shardformer] update assert

* update scheduler

* update scheduler

---------

Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
ver217 added labels on Sep 4, 2023: bug (Something isn't working), enhancement (New feature or request), example (example-related issue or pull request), shardformer
github-actions bot (Contributor) commented Sep 5, 2023

The code coverage for the changed files is 83%.

Complete report:
Name                                                                     Stmts   Miss  Cover
--------------------------------------------------------------------------------------------
colossalai/booster/plugin/gemini_plugin.py                                 123     12    90%
colossalai/booster/plugin/hybrid_parallel_plugin.py                        210     14    93%
colossalai/checkpoint_io/__init__.py                                         5      0   100%
colossalai/checkpoint_io/general_checkpoint_io.py                           91      8    91%
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py                  336     32    90%
colossalai/checkpoint_io/utils.py                                          318     44    86%
colossalai/cluster/process_group_mesh.py                                    73      1    99%
colossalai/pipeline/p2p.py                                                  96      7    93%
colossalai/pipeline/schedule/_utils.py                                      75      7    91%
colossalai/pipeline/schedule/interleaved_pp.py                             170     10    94%
colossalai/pipeline/schedule/one_f_one_b.py                                139      6    96%
colossalai/pipeline/stage_manager.py                                        50      0   100%
colossalai/shardformer/layer/_operation.py                                 298    142    52%
colossalai/shardformer/layer/linear.py                                     190     53    72%
colossalai/shardformer/layer/parallel_module.py                             72     20    72%
colossalai/shardformer/layer/qkv_fused_linear.py                           300     75    75%
colossalai/shardformer/modeling/bert.py                                    486    147    70%
colossalai/shardformer/modeling/bloom.py                                   458    141    69%
colossalai/shardformer/modeling/chatglm2.py                                183     41    78%
colossalai/shardformer/modeling/gpt2.py                                    397    307    23%
colossalai/shardformer/modeling/whisper.py                                 319     69    78%
colossalai/shardformer/policies/auto_policy.py                              27      2    93%
colossalai/shardformer/policies/base_policy.py                              85      7    92%
colossalai/shardformer/policies/bert.py                                    261      0   100%
colossalai/shardformer/policies/blip2.py                                    49      2    96%
colossalai/shardformer/policies/bloom.py                                   155      2    99%
colossalai/shardformer/policies/chatglm2.py                                107      6    94%
colossalai/shardformer/policies/gpt2.py                                    185     42    77%
colossalai/shardformer/policies/llama.py                                   118      3    97%
colossalai/shardformer/policies/opt.py                                     144      2    99%
colossalai/shardformer/policies/sam.py                                      32      0   100%
colossalai/shardformer/policies/t5.py                                      181      5    97%
colossalai/shardformer/policies/vit.py                                     112      1    99%
colossalai/shardformer/policies/whisper.py                                 200      9    96%
colossalai/shardformer/shard/shard_config.py                                36      2    94%
colossalai/shardformer/shard/sharder.py                                     96      3    97%
colossalai/zero/gemini/gemini_ddp.py                                       400     77    81%
colossalai/zero/gemini/gemini_optimizer.py                                 392     39    90%
colossalai/zero/low_level/low_level_optim.py                               340     24    93%
tests/kit/model_zoo/transformers/__init__.py                                12      0   100%
tests/kit/model_zoo/transformers/chatglm2.py                                20      0   100%
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py       87      0   100%
tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py          56      1    98%
tests/test_config/test_load_config.py                                        9      0   100%
tests/test_context/test_hybrid_parallel.py                                 105     25    76%
tests/test_data/test_cifar10_dataset.py                                     14      1    93%
tests/test_data/test_data_parallel_sampler.py                               35      1    97%
tests/test_data/test_deterministic_dataloader.py                            34      1    97%
tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py           17      0   100%
tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py      20      2    90%
tests/test_pipeline/test_schedule/test_interleaved.py                       99      1    99%
tests/test_pipeline/test_schedule/test_oneF_oneB.py                         80      2    98%
tests/test_pipeline/test_stage_manager.py                                   45      1    98%
tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py          94      1    99%
tests/test_shardformer/test_layer/test_linear_1d.py                        116      1    99%
tests/test_shardformer/test_model/_utils.py                                182     30    84%
tests/test_shardformer/test_model/test_shard_bert.py                        81     12    85%
tests/test_shardformer/test_model/test_shard_bloom.py                       80     12    85%
tests/test_shardformer/test_model/test_shard_chatglm2.py                    80     11    86%
tests/test_shardformer/test_model/test_shard_gpt2.py                        83     54    35%
tests/test_shardformer/test_model/test_shard_llama.py                       83     12    86%
tests/test_shardformer/test_model/test_shard_opt.py                         82     11    87%
tests/test_shardformer/test_model/test_shard_t5.py                          80     11    86%
tests/test_shardformer/test_model/test_shard_vit.py                         79     11    86%
tests/test_shardformer/test_model/test_shard_whisper.py                     88     14    84%
tests/test_utils/test_activation_checkpointing.py                           81      1    99%
--------------------------------------------------------------------------------------------
TOTAL                                                                     9151   1578    83%

github-actions bot (Contributor) commented Sep 5, 2023

The code coverage for the changed files is 82%.

Complete report:
Name                                                                    Stmts   Miss  Cover
-------------------------------------------------------------------------------------------
colossalai/booster/plugin/gemini_plugin.py                                123     12    90%
colossalai/booster/plugin/hybrid_parallel_plugin.py                       210     14    93%
colossalai/checkpoint_io/__init__.py                                        5      0   100%
colossalai/checkpoint_io/general_checkpoint_io.py                          91      9    90%
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py                 336     32    90%
colossalai/checkpoint_io/utils.py                                         318     59    81%
colossalai/cluster/process_group_mesh.py                                   73      1    99%
colossalai/pipeline/p2p.py                                                 96      7    93%
colossalai/pipeline/schedule/_utils.py                                     75      7    91%
colossalai/pipeline/schedule/interleaved_pp.py                            170     10    94%
colossalai/pipeline/schedule/one_f_one_b.py                               139      6    96%
colossalai/pipeline/stage_manager.py                                       50      0   100%
colossalai/shardformer/layer/_operation.py                                298    142    52%
colossalai/shardformer/layer/linear.py                                    190     53    72%
colossalai/shardformer/layer/parallel_module.py                            72     20    72%
colossalai/shardformer/layer/qkv_fused_linear.py                          300     75    75%
colossalai/shardformer/modeling/bert.py                                   486    147    70%
colossalai/shardformer/modeling/bloom.py                                  458    141    69%
colossalai/shardformer/modeling/chatglm2.py                               183     41    78%
colossalai/shardformer/modeling/gpt2.py                                   397    307    23%
colossalai/shardformer/modeling/whisper.py                                319     69    78%
colossalai/shardformer/policies/auto_policy.py                             27      2    93%
colossalai/shardformer/policies/base_policy.py                             85     10    88%
colossalai/shardformer/policies/bert.py                                   261      0   100%
colossalai/shardformer/policies/blip2.py                                   49      2    96%
colossalai/shardformer/policies/bloom.py                                  155      2    99%
colossalai/shardformer/policies/chatglm2.py                               107      6    94%
colossalai/shardformer/policies/gpt2.py                                   185     42    77%
colossalai/shardformer/policies/llama.py                                  118      3    97%
colossalai/shardformer/policies/opt.py                                    144      2    99%
colossalai/shardformer/policies/sam.py                                     32      0   100%
colossalai/shardformer/policies/t5.py                                     181      5    97%
colossalai/shardformer/policies/vit.py                                    112      1    99%
colossalai/shardformer/policies/whisper.py                                200      9    96%
colossalai/shardformer/shard/shard_config.py                               36      2    94%
colossalai/shardformer/shard/sharder.py                                    96      3    97%
colossalai/zero/gemini/gemini_ddp.py                                      400     77    81%
colossalai/zero/gemini/gemini_optimizer.py                                392     39    90%
colossalai/zero/low_level/low_level_optim.py                              340     24    93%
tests/kit/model_zoo/transformers/__init__.py                               12      0   100%
tests/kit/model_zoo/transformers/chatglm2.py                               20      0   100%
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py      87      0   100%
tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py         56      1    98%
tests/test_context/test_hybrid_parallel.py                                105     25    76%
tests/test_data/test_data_parallel_sampler.py                              35      1    97%
tests/test_data/test_deterministic_dataloader.py                           34      1    97%
tests/test_pipeline/test_schedule/test_interleaved.py                      99      1    99%
tests/test_pipeline/test_schedule/test_oneF_oneB.py                        80      2    98%
tests/test_pipeline/test_stage_manager.py                                  45      1    98%
tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py         94      1    99%
tests/test_shardformer/test_layer/test_linear_1d.py                       116      1    99%
tests/test_shardformer/test_model/_utils.py                               182     30    84%
tests/test_shardformer/test_model/test_shard_bert.py                       81     12    85%
tests/test_shardformer/test_model/test_shard_bloom.py                      80     12    85%
tests/test_shardformer/test_model/test_shard_chatglm2.py                   80     11    86%
tests/test_shardformer/test_model/test_shard_gpt2.py                       83     54    35%
tests/test_shardformer/test_model/test_shard_llama.py                      83     12    86%
tests/test_shardformer/test_model/test_shard_opt.py                        82     11    87%
tests/test_shardformer/test_model/test_shard_t5.py                         80     11    86%
tests/test_shardformer/test_model/test_shard_vit.py                        79     11    86%
tests/test_shardformer/test_model/test_shard_whisper.py                    88     14    84%
-------------------------------------------------------------------------------------------
TOTAL                                                                    9010   1593    82%

TongLi3701 self-requested a review on September 5, 2023 15:15
ver217 merged commit efba0f4 into main on Sep 5, 2023
9 of 10 checks passed
ver217 deleted the feature/shardformer branch on September 5, 2023 15:20