[Fix] Fix FSDP bug #553
Conversation
I'm leaving a comment because I think FSDP will need extra logic to save the optimizer state and use it to resume training. I'm really sorry that the place where this comment is posted differs from the part where the new proposal is located. It worked fine in the environment we used, but if there are any problems, please leave a new comment. Thank you.
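For illustration, that extra logic might look like the following minimal sketch, assuming PyTorch >= 1.12 and FSDP's dedicated optimizer-state helpers; `model`, `optimizer`, and `rank` are placeholder names, not identifiers from this PR:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# FSDP shards optimizer state across ranks, so a plain `optimizer.state_dict()`
# only captures the local shard. A resumable checkpoint needs the consolidated
# state. `model` is assumed to be the FSDP-wrapped module.

# Saving: gather the full optimizer state (all ranks must participate),
# then write it to disk on rank 0 only.
full_osd = FSDP.full_optim_state_dict(model, optimizer)
if rank == 0:
    torch.save(full_osd, 'optim_state.pth')

# Resuming: load the full state dict and re-shard it for the local rank.
full_osd = torch.load('optim_state.pth')
sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, model)
optimizer.load_state_dict(sharded_osd)
```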
I'm really sorry for leaving a comment that doesn't quite fit this PR, but could I ask you to add a `sharding_strategy` option? (The sharding strategy option seems to have been added in PyTorch 1.12.) If this request is accepted, users could weigh the memory/time tradeoff between strategies. Currently, `shard_grad_op` and `full_shard` are supported. Thank you.
```diff
@@ -16,7 +19,7 @@
 @MODEL_WRAPPERS.register_module()
-class MMFullyShardedDataParallel(FullyShardedDataParallel):
+class MMFullyShardedDataParallel(FSDP):
```
```python
from torch.distributed.fsdp import ShardingStrategy

# Map the user-facing string option onto the ShardingStrategy enum.
if sharding_strategy is not None:
    if isinstance(sharding_strategy, str):
        assert sharding_strategy in ['shard_grad_op', 'full_shard']
        sharding_strategy = (
            ShardingStrategy.SHARD_GRAD_OP
            if sharding_strategy == 'shard_grad_op'
            else ShardingStrategy.FULL_SHARD)
    elif not isinstance(sharding_strategy, ShardingStrategy):
        raise TypeError('`sharding_strategy` should be `None`, `str` '
                        'or `ShardingStrategy`, but has type '
                        f'{type(sharding_strategy)}')
```
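For reference, a hypothetical usage sketch of how the resolved enum would be forwarded to the FSDP constructor, whose `sharding_strategy` parameter is available since PyTorch 1.12 (`module` is a placeholder, not code from this PR):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# ZeRO-2-style sharding: shard gradients and optimizer state, but keep the
# full parameters on each rank during forward/backward.
model = FSDP(module, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
```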
@yhna940 Great suggestions! First, I apologize for my late response because I was on vacation ^_^ Your suggestions are correct, and I am glad to see the community's interest in requesting this FSDP feature in MMEngine. However, this PR originally aims at resolving bugs for basic use cases without breaking BC, not at fully supporting FSDP. I have actually planned some new modifications:
These modifications are of relatively low priority because I am focusing on some other things (e.g. better Chinese/English documentation), and they will probably be carried out next month. But if you are interested in either of them, you are welcome to contribute by opening issues/PR(s) ^_^
mmengine/runner/runner.py (Outdated)

```python
# initialize the model weights
self._init_model_weights()
# make sure checkpoint-related hooks are triggered after `before_run`
```
The description "after `before_run`" is not consistent with the code.
Codecov Report

Base: 78.40% // Head: 78.10% // Decreases project coverage by -0.31%.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #553      +/-   ##
==========================================
- Coverage   78.40%   78.10%   -0.31%
==========================================
  Files         127      127
  Lines        9175     9212      +37
  Branches     1826     1838      +12
==========================================
+ Hits         7194     7195       +1
- Misses       1670     1702      +32
- Partials      311      315       +4
```
☔ View full report at Codecov.
Waiting for this feature...
```python
# the former depends on the latter in FSDP
self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
# Automatically scaling lr by linear scaling rule
self.scale_lr(self.optim_wrapper, self.auto_scale_lr)
```
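A side note on why the optimizer must be built after wrapping: FSDP flattens and shards the module's parameters, so an optimizer constructed over the unwrapped parameters would reference tensors that FSDP no longer updates. A minimal illustration (hypothetical, not this PR's code):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = nn.Linear(16, 16)
wrapped = FSDP(model)  # wrapping replaces parameters with flat shards
# Correct: the optimizer tracks the (sharded) parameters FSDP actually trains.
optimizer = torch.optim.SGD(wrapped.parameters(), lr=0.01)
```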
Summary of the bugs

Hi @C1rN09. I deeply appreciate your commitment! FSDP is critical for training large models. However, I used this commit to train my own model and found two bugs when resuming from checkpoints. The bugs are caused by changes in the execution order of some functions, and they are fixed in pull request C1rN09#1. Below are the details.

Bug 1: `RuntimeError` in `self.scale_lr()`

Description: Currently, the `scale_lr()` method must be called before building the `ParamScheduler`, or it will raise a `RuntimeError`. However, if there is a checkpoint, the `ParamScheduler` will be built in `self.load_or_resume()`, before we run `self.scale_lr(self.optim_wrapper, self.auto_scale_lr)`.

Fix: I think the `scale_lr()` method should be modified to fix the bug, because the execution order of the other functions cannot be changed. Concretely, since `wrap_model()` must be called after `load_or_resume()` to be compatible with FSDP, the order must be: `load_or_resume()` -> `wrap_model()` -> `build_optim_wrapper()` -> `scale_lr()`.
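In pseudocode, the required order would look roughly like this (a sketch of the proposed fix; method names follow this thread, and exact signatures may differ in MMEngine):

```python
# `scale_lr()` must tolerate ParamSchedulers that were already built during
# `load_or_resume()`, because the FSDP-compatible order is:
self.load_or_resume()                                    # may build ParamScheduler
self.model = self.wrap_model(model_wrapper_cfg, self.model)  # FSDP wrapping
self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
self.scale_lr(self.optim_wrapper, self.auto_scale_lr)
```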
Bug 2: state dict of `optim_wrapper` is loaded onto the CPU instead of CUDA

Description: When we call `self.load_or_resume()`, the state dict of the optim_wrapper is loaded onto the CPU, because `self.model` is still on the CPU at that point. The original code does not have this problem, because `self.load_or_resume()` is called after `self.wrap_model()`, where the model is moved to CUDA.

Fix: There could be many solutions, as long as `model = model.to(get_device())` is inserted before `self.load_or_resume()`. However, I could not find an insertion place that keeps the code as elegant as before, so I inserted it right before `self.load_or_resume()`.
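Concretely, the inserted line could look like the following, using MMEngine's `get_device()` helper (a sketch of the proposed fix, not the exact patch):

```python
from mmengine.device import get_device

# Move the model to the target device (e.g. cuda) *before* resuming, so the
# optimizer state dict from the checkpoint is loaded onto that device rather
# than onto the CPU copy of the model.
self.model = self.model.to(get_device())
self.load_or_resume()
```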
Thanks for your bug report & solutions! Honestly speaking, this PR does not solve all issues in the FSDP integration, especially the issue with checkpoint loading. The PR stalled because we found that FSDP and some other frameworks (DeepSpeed, ColossalAI, etc.) change the execution order of model setup, initialization, and checkpoint saving/loading, which spoils the whole thing; sometimes they may even conflict with each other. Therefore, we are working on a more elegant way to solve this issue by refactoring `Runner`, as discussed in the Discussion topic.
This refactor is a work in progress, and we'll take your bug reports into consideration. Hopefully it will come out soon, and you will be able to use FSDP in MMEngine then!
supported by #1213
Motivation
Fix bugs in MMEngine's FSDP wrapper
Modification
- `Runner`: changes to the execution order (see BC-breaking below)
- Modify `train_step`, `val_step` and `test_step` in `MMFullyShardedDataParallel` to be consistent with `BaseModel`
- Fix an `AssertionError` caused by a PyTorch issue when some stages are frozen.

BC-breaking (Optional)
The execution order of `Runner` has been changed, especially the hook point of `before_run`, as follows. This modification is potentially BC-breaking.
Use cases (Optional)
Checklist