
[Fix] Fix FSDP bug #553

Closed
wants to merge 17 commits into from

Conversation

@C1rN09 (Collaborator) commented Sep 27, 2022

Thanks for your contribution; we appreciate it a lot. The following instructions will help make your pull request healthier and easier to review. If you do not understand some items, don't worry: just open the pull request and ask the maintainers for help.

Motivation

Fix bugs in MMEngine's FSDP wrapper

Modification

  1. Change execution order in Runner

Before change: wrap_model --> build_optimizer --> init_weight
After change: init_weight --> wrap_model --> build_optimizer

  2. Modify train_step, val_step and test_step in MMFullyShardedDataParallel to be consistent with BaseModel
  3. Fix the AssertionError caused by a PyTorch issue when some stages are frozen.

This fix is somewhat tricky. We manually set all parameters to requires_grad=True, but pass the frozen layers to FSDP via the ignored_modules argument, so they are excluded from FSDP's flat parameters and are never updated by the optimizer (see the sketch after this list).
This may not be a perfect solution and may need further improvement.

  4. Fix checkpoint saving issues.
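
Below is a minimal sketch of the workaround in point 3, assuming a toy model and an already-initialized distributed process group; the frozen stage and the direct use of PyTorch's FSDP class are illustrative only, not the exact MMFullyShardedDataParallel code:

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Toy model: the first sub-module stands in for a frozen backbone stage.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    frozen_modules = [model[0]]

    # Force requires_grad=True everywhere so FSDP's flat parameters do not mix
    # frozen and trainable tensors (the source of the upstream AssertionError),
    # then list the frozen stage in ignored_modules so FSDP neither shards it
    # nor exposes it to the optimizer.
    for param in model.parameters():
        param.requires_grad = True
    fsdp_model = FSDP(model, ignored_modules=frozen_modules)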

BC-breaking (Optional)

The execution order in Runner has been changed; in particular, the point at which before_run hooks are triggered has moved, as follows.

Before: wrap_model --> build_optimizer --> hooks.before_run --> init_weight
After: init_weight --> wrap_model --> build_optimizer --> hooks.before_run

This modification is potentially BC-breaking.
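
A simplified schematic of the new setup order inside Runner.train() (method names correspond to the arrows above; the real Runner interleaves other steps, and the wrap_model call shown here is an approximation):

    def train(self):
        self._init_model_weights()                        # init_weight
        self.model = self.wrap_model(                     # wrap_model (DDP/FSDP)
            self.cfg.get('model_wrapper_cfg'), self.model)
        self.optim_wrapper = self.build_optim_wrapper(    # build_optimizer
            self.optim_wrapper)
        self.call_hook('before_run')                      # hooks.before_run now last
        ...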

Use cases (Optional)

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMCls.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@C1rN09 changed the title from [WIP] Fix FSDP bug to [Fix] Fix FSDP bug on Sep 28, 2022
@yhna940 (Contributor) left a comment:
I'm leaving a comment because I think FSDP will need extra logic to save the optimizer state and use it to resume training. I'm sorry that the place where this comment is posted differs from the part the new proposal concerns. It worked fine in our environment, but if there are any problems, please leave a new comment. Thank you.

@yhna940 (Contributor) left a comment:

I'm really sorry for leaving a comment that doesn't quite fit this PR, but could I ask you to add a sharding_strategy option? (The sharding strategy option seems to have been added in PyTorch 1.12.) If this request is accepted, memory/time trade-offs could be taken into account. Currently, shard_grad_op and full_shard are supported. Thank you.

    @@ -16,7 +19,7 @@

     @MODEL_WRAPPERS.register_module()
    -class MMFullyShardedDataParallel(FullyShardedDataParallel):
    +class MMFullyShardedDataParallel(FSDP):
Review comment (Contributor):

    # Map a string option to the corresponding ShardingStrategy enum value;
    # otherwise require that a ShardingStrategy (or None) was passed.
    if sharding_strategy is not None:
        if isinstance(sharding_strategy, str):
            assert sharding_strategy in ['shard_grad_op', 'full_shard']
            sharding_strategy = (
                ShardingStrategy.SHARD_GRAD_OP if sharding_strategy
                == 'shard_grad_op' else ShardingStrategy.FULL_SHARD)
        elif not isinstance(sharding_strategy, ShardingStrategy):
            raise TypeError('`sharding_strategy` '
                            'should be `None`, `str` '
                            'or `ShardingStrategy`, but has type '
                            f'{type(sharding_strategy)}')
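
For illustration, a hypothetical call assuming MMFullyShardedDataParallel forwarded such a sharding_strategy keyword to FSDP (the constructor signature here is an assumption, not a merged API):

    from mmengine.model import MMFullyShardedDataParallel

    # 'shard_grad_op' shards gradients and optimizer state (ZeRO-2-like);
    # 'full_shard' additionally shards parameters (ZeRO-3-like).
    wrapped_model = MMFullyShardedDataParallel(
        module=model,
        sharding_strategy='shard_grad_op',
    )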

@C1rN09 (Collaborator, Author) commented Oct 6, 2022

@yhna940 Great suggestions! First, I apologize for my late response because I was on vacation ^_^ Your suggestions are correct, and I am glad to see the community's interest in the FSDP feature in MMEngine. However, this PR originally aims at fixing bugs for basic use cases without breaking BC, not at supporting the full FSDP feature set.

I have actually planned some further modifications:

  1. Support saving the full optimizer state to enable resuming training (exactly as you mentioned)
  2. Add config options for newly added FSDP arguments (e.g. the sharding_strategy you mentioned)
  3. Make FSDP arguments configurable in our config file

These modifications are of relatively low priority because I am currently focusing on other things (e.g. better Chinese/English documentation), so they will probably be carried out next month. But if you are interested in any of them, you are welcome to contribute by opening issues/PRs ^_^
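
For point 1, a minimal sketch of how the full optimizer state could be gathered with PyTorch's FSDP API; how this would be wired into MMEngine's checkpoint hooks is left open, and fsdp_model/optimizer are placeholders:

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Collect the full (unsharded) optimizer state dict on rank 0 so that it
    # can be saved with the checkpoint and used later to resume training.
    optim_state = FSDP.full_optim_state_dict(fsdp_model, optimizer)
    if dist.get_rank() == 0:
        torch.save({'optimizer': optim_state}, 'full_optim_state.pth')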

HAOCHENYE previously approved these changes Oct 11, 2022
@C1rN09 added the ready (ready to merge) label Oct 13, 2022
@C1rN09 requested a review from zhouzaida October 17, 2022 06:45
RangiLyu previously approved these changes Nov 2, 2022
yhna940 previously approved these changes Nov 3, 2022
@zhouzaida modified the milestones: 0.3.0 → 0.3.1 on Nov 8, 2022

    # initialize the model weights
    self._init_model_weights()
    # make sure checkpoint-related hooks are triggered after `before_run`
Review comment (Member):
The description "after before_run" is not consistent with the code.

@C1rN09 dismissed stale reviews from yhna940 and RangiLyu via 58a5263 November 8, 2022 07:17
codecov bot commented Nov 8, 2022

Codecov Report

Base: 78.40% // Head: 78.10% // Decreases project coverage by -0.30% ⚠️

Coverage data is based on head (e039fcc) compared to base (618a063).
Patch coverage: 19.67% of modified lines in pull request are covered.

❗ Current head e039fcc differs from pull request most recent head edc936b. Consider uploading reports for the commit edc936b to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #553      +/-   ##
==========================================
- Coverage   78.40%   78.10%   -0.31%     
==========================================
  Files         127      127              
  Lines        9175     9212      +37     
  Branches     1826     1838      +12     
==========================================
+ Hits         7194     7195       +1     
- Misses       1670     1702      +32     
- Partials      311      315       +4     
Flag Coverage Δ
unittests 78.10% <19.67%> (-0.31%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
mmengine/config/utils.py 58.49% <ø> (ø)
mmengine/hooks/checkpoint_hook.py 89.28% <0.00%> (ø)
mmengine/model/__init__.py 75.00% <0.00%> (ø)
mmengine/model/wrappers/__init__.py 70.00% <0.00%> (ø)
...engine/model/wrappers/fully_sharded_distributed.py 0.00% <0.00%> (ø)
mmengine/runner/runner.py 83.33% <47.61%> (-1.36%) ⬇️
mmengine/model/weight_init.py 35.50% <100.00%> (ø)
mmengine/version.py 58.33% <100.00%> (ø)


@SCZwangxiao (Contributor) commented:

Waiting for this feature...

    # the former depends on the latter in FSDP
    self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
    # Automatically scaling lr by linear scaling rule
    self.scale_lr(self.optim_wrapper, self.auto_scale_lr)
Review comment (Contributor):

Summary of the bugs

Hi @C1rN09. I deeply appreciate your commitment! FSDP is critical for training large models.

However, I used this commit to train my own model and found two bugs when resuming from checkpoints. The bugs are caused by the changes in the execution order of some functions, and they are fixed in pull request C1rN09#1. Below are the details.

Bug 1: RuntimeError in self.scale_lr()

Description. Currently, the scale_lr() method must be called before building the ParamScheduler, or it will raise a RuntimeError. However, if there is a checkpoint, the ParamScheduler is built in self.load_or_resume(), before we run self.scale_lr(self.optim_wrapper, self.auto_scale_lr).

Fix: I think the scale_lr() method should be modified to fix the bug, because the execution order of the other functions cannot be changed. Concretely, since wrap_model() should be called after load_or_resume() to be compatible with FSDP, the order must be:
load_or_resume() --> wrap_model() --> build_optim_wrapper() --> scale_lr().

Bug 2: the state dict of optim_wrapper is loaded onto CPU instead of CUDA

Description. When we call self.load_or_resume(), the state dict of optim_wrapper is loaded onto CPU, because self.model is still on CPU at that point. The original code does not have this problem, because self.load_or_resume() is called after self.wrap_model(), where the model is moved to CUDA.

Fix: There could be many solutions, as long as model = model.to(get_device()) is inserted before self.load_or_resume(). However, I cannot find an insertion place that keeps the code as elegant as before, so I inserted it right before self.load_or_resume().
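
A minimal sketch of the Bug 2 fix, assuming the Runner attributes used below; the helper name is hypothetical and only illustrates where the device move happens:

    from mmengine.device import get_device

    def _resume_on_device(runner):
        # Move the model to the target device *before* loading the checkpoint,
        # so the optimizer state dict is restored on CUDA rather than CPU.
        runner.model = runner.model.to(get_device())
        runner.load_or_resume()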

@C1rN09 (Collaborator, Author) replied:

Thanks for your bug report & solutions! Honestly speaking, this PR does not solve all the issues in the FSDP integration, especially checkpoint loading. The PR has stalled because we found that FSDP and some other frameworks (DeepSpeed, ColossalAI, etc.) each change the execution order of model setup, initialization, and checkpoint saving/loading, which spoils the whole design, and sometimes they even conflict with each other. Therefore, we are working on a more elegant solution by refactoring Runner, as discussed in the related Discussion topic.

This refactor is a work in progress, and we'll take your bug reports into consideration. Hopefully it will come out soon, and you will then be able to use FSDP in MMEngine!

@zhouzaida (Member) commented:

Supported by #1213

@zhouzaida closed this on Jul 26, 2023
Labels: Bug:P1, ready (ready to merge)
Projects: None yet
Development: None yet
7 participants