
[Feature] Support torch ZeroRedundancyOptimizer #551

Merged
zhouzaida merged 25 commits into open-mmlab:main from zero_1_optimizer on Oct 27, 2022

Conversation

@nijkah (Contributor) commented Sep 27, 2022

  • Device: V100 32GB (32510MiB) x4
  • Input Size: 512
  • Total Batch Size: 16

upernet_r50_4xb4-80k_ade20k-512x512

Method      Time (s/iter)    Memory (MiB)
w/o zero    0.611            7053
w/ zero     0.552            6868

Time may depend on the environment.


  • Check reduced memory
  • Reproduce Performance
  • Check saving optimizer's state
  • Check loading optimizer's state

Co-authored-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Hakjin Lee <nijkah@gmail.com>
codecov bot commented Sep 27, 2022

Codecov Report

Base: 78.07% // Head: 78.04% // Project coverage decreases by 0.03% ⚠️

Coverage data is based on head (7e91b88) compared to base (36af1f0).
Patch coverage: 61.11% of modified lines in pull request are covered.

❗ Current head 7e91b88 differs from pull request most recent head 00ed3b3. Consider uploading reports for the commit 00ed3b3 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #551      +/-   ##
==========================================
- Coverage   78.07%   78.04%   -0.04%     
==========================================
  Files         125      126       +1     
  Lines        8991     9009      +18     
  Branches     1845     1846       +1     
==========================================
+ Hits         7020     7031      +11     
- Misses       1659     1666       +7     
  Partials      312      312              
Flag Coverage Δ
unittests 78.04% <61.11%> (-0.04%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
mmengine/optim/optimizer/zero_optimizer.py 58.82% <58.82%> (ø)
mmengine/optim/optimizer/__init__.py 100.00% <100.00%> (ø)



@nijkah nijkah marked this pull request as draft September 27, 2022 04:29
@nijkah (Contributor Author) commented Sep 27, 2022

  File "/workspace/mmengine/mmengine/runner/runner.py", line 1718, in call_hook                                                                                                                                          
    getattr(hook, fn_name)(self, **kwargs)                                                                                                                                                                                        
  File "/workspace/mmengine/mmengine/hooks/checkpoint_hook.py", line 488, in after_train_iter
    self._save_checkpoint(runner)                                                                                                                                                                                                 
  File "/workspace/mmengine/mmengine/dist/utils.py", line 346, in wrapper                                                                                                                                                
    return func(*args, **kwargs)                                                                                                                                                                                                  
  File "/workspace/mmengine/mmengine/hooks/checkpoint_hook.py", line 294, in _save_checkpoint                                                                                                                            
    **self.args)                                                                                                                                                                                                                  
  File "/workspace/mmengine/mmengine/runner/runner.py", line 2109, in save_checkpoint                                                                                                                                    
    ) else self.optim_wrapper.state_dict()                                                                                                                                                                                        
  File "/workspace/mmengine/mmengine/optim/optimizer/optimizer_wrapper.py", line 199, in state_dict                                                                                                                      
    return self.optimizer.state_dict()                                                                                                                                                                                            
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py", line 1152, in state_dict                                                                                                    
    "Optimizer state has not been consolidated on this rank. "                                                                                                                                                                    
RuntimeError: Optimizer state has not been consolidated on this rank. Please call `consolidate_state_dict(to=0)` on all ranks beforehand if you meant to save the global state. 

I found that there is a problem when saving the optimizer's state_dict.

@nijkah nijkah marked this pull request as ready for review September 27, 2022 04:57
@nijkah (Contributor Author) commented Sep 27, 2022

Solved the problem by overriding self.state_dict.

@HAOCHENYE (Collaborator):
Thank you for your contributions! It would be better to add a unit test that uses ZeroRedundancyOptimizer 😊.

@nijkah nijkah changed the title [Feature] Support torch ZeRORedundancyOptimizer [Feature] Support torch ZeroRedundancyOptimizer Sep 27, 2022
Comment on lines 733 to 734
if ZeroRedundancyOptimizer is None:
    self.skipTest('ZeroRedundancyOptimizer is not available.')
Collaborator:

Is this check duplicated with the following?

@unittest.skipIf(
    digit_version(TORCH_VERSION) < digit_version('1.8.0'),
    reason='ZeRO needs Pytorch 1.8 or higher')

@nijkah (Contributor Author) commented Sep 29, 2022:

https://github.com/open-mmlab/mmengine/actions/runs/3134972777/jobs/5090129146#step:8:132
I found that importing ZeroRedundancyOptimizer failed in the Windows CPU CI with torch 1.8.1.
(The import failure made _ZeroRedundancyOptimizer fall back to object.)
So I added the duplicated skip code.

I'll check it again.
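For reference, a minimal sketch of the kind of import guard being discussed (the exact code in the PR may differ):

try:
    from torch.distributed.optim import ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
    # On platforms where the import fails (e.g. the Windows CPU CI with torch 1.8.1),
    # fall back to `object` so the module itself still imports.
    _ZeroRedundancyOptimizer = object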

@nijkah (Contributor Author):

I found that there is another condition: torch.distributed.rpc should be available. I removed the duplicated lines and clarified this condition.
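A rough sketch of what the clarified skip condition could look like (the import locations of digit_version and TORCH_VERSION are assumed to match those already used in the tests):

import unittest

from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION


def _zero_is_available() -> bool:
    """ZeRO needs PyTorch >= 1.8 and a working torch.distributed.rpc."""
    if digit_version(TORCH_VERSION) < digit_version('1.8.0'):
        return False
    try:
        from torch.distributed.rpc import is_available
    except ImportError:
        return False
    return is_available()


@unittest.skipIf(
    not _zero_is_available(),
    reason='ZeRO needs PyTorch >= 1.8 and torch.distributed.rpc')
class TestZeroRedundancyOptimizer(unittest.TestCase):
    ...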

@nijkah (Contributor Author) commented Sep 30, 2022

I found another bug.

Currently, saving the ZeroRedundancyOptimizer's state_dict raises an error.
consolidate_state_dict must be called on every process, and only the main process's optimizer should return the state_dict (a sketch follows below).

This constraint requires refactoring the CheckpointHook.
I'll have to think about it a little more.
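A minimal sketch of the saving pattern described above (this is not the PR's final code, and the helper name is illustrative):

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer


def save_zero_checkpoint(optimizer: ZeroRedundancyOptimizer, path: str) -> None:
    # Collective call: every rank must reach this line, otherwise the job hangs.
    optimizer.consolidate_state_dict(to=0)
    # Only the rank that received the consolidated state serializes it.
    if dist.get_rank() == 0:
        torch.save(optimizer.state_dict(), path)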

@nijkah (Contributor Author) commented Sep 30, 2022

Hi @HAOCHENYE @C1rN09,

I changed the CheckpointHook's logic to remove the master_only decorators when saving checkpoints.
I think this doesn't affect the behavior, since the save_checkpoint function in mmengine/runner/checkpoint.py is already wrapped by master_only.

Please let me know if any potential problem is expected.
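A hedged sketch of the resulting control flow (class and function bodies heavily simplified): the hook runs on every rank so ZeRO's collective consolidation can happen, while the actual disk write stays guarded by master_only.

from mmengine.dist import master_only


class CheckpointHookSketch:
    """Illustrative only; the real CheckpointHook has more arguments and logic."""

    def _save_checkpoint(self, runner):  # no @master_only any more
        # Every rank enters, so a ZeRO optimizer can consolidate its state.
        runner.save_checkpoint('iter_x.pth')


@master_only
def save_checkpoint(checkpoint, filename):
    """Stands in for save_checkpoint in mmengine/runner/checkpoint.py."""
    ...  # only the main process actually writes to disk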

@HAOCHENYE (Collaborator):
Hi @HAOCHENYE @C1rN09,

I changed the CheckpointHook's logic to remove the master_only decorators when saving checkpoints. I think this doesn't affect the behavior, since the save_checkpoint function in mmengine/runner/checkpoint.py is already wrapped by master_only.

Please let me know if any potential problem is expected.

#553 also removes the master_only decorator to collect all weights from different ranks. Does @C1rN09 have any thoughts about this?

mmengine/optim/optimizer/zero_optimizer.py: outdated review threads (resolved)

Comment on mmengine/optim/optimizer/zero_optimizer.py:
'`torch.distributed.optim.ZeroReundancyOptimizer` is only '
'available when pytorch version >= 1.8.')
assert is_available(), 'torch.distributed.rpc is not available.'
optimizer_class = getattr(torch.optim, optimizer_type)
Member:

Can it support custom Optimizer classes?

@nijkah (Contributor Author) commented Oct 24, 2022:

I'm still figuring it out. So far, it does not seem to have a specific dependency on torch's optimizers, so it may be possible to support custom Optimizer classes.
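For illustration, torch's ZeroRedundancyOptimizer itself accepts any optimizer class, so custom optimizers look feasible in principle; note that the wrapper in this PR resolves the class by name from torch.optim. A sketch (the custom class and model are hypothetical, and a process group must already be initialized):

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer


class MyCustomSGD(torch.optim.SGD):
    """A hypothetical custom optimizer; ZeRO only needs the class and its kwargs."""


model = torch.nn.Linear(8, 8)  # stand-in model
# Requires torch.distributed to be initialized before construction.
zero_optim = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=MyCustomSGD,
    lr=0.01,
    momentum=0.9)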

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
@zhouzaida (Member):
Hi @nijkah, the lint check failed.

@nijkah (Contributor Author) commented Oct 24, 2022

Hi @C1rN09 @zhouzaida, I'm concerned about the overlap_with_ddp option.

The public API docs state:

overlap_with_ddp (bool, optional) – if True, step() is overlapped with DistributedDataParallel's gradient synchronization; this requires (1) either a functional optimizer for the optimizer_class argument or one with a functional equivalent and (2) registering a DDP communication hook constructed from one of the functions in ddp_zero_hook.py; parameters are packed into buckets matching those in DistributedDataParallel, meaning that the parameters_as_bucket_view argument is ignored. If False, step() runs disjointly after the backward pass (per normal). (default: False)

So if someone wants to use the overlap_with_ddp option, some modifications are needed, such as registering a DDP communication hook, removing consolidate_state_dict, and so on.

Would it be better to fix it as overlap_with_ddp=False?

def state_dict(self):
    """Consolidate `state_dict`s from ranks to save the `state_dict`."""
    self.consolidate_state_dict()
    state_dict = super().state_dict() if is_main_process() else dict()
@nijkah (Contributor Author) commented Oct 24, 2022:

state_dict['loss_scaler'] = self.loss_scaler.state_dict()

Due to this line, using ZeroRedundancyOptimizer with AmpOptimWrapper raised an error like

TypeError: 'NoneType' object does not support item assignment in <mmengine.hooks.checkpoint_hook.CheckpointHook object at XXXXXX>

So I modified it to return dict() instead of None when the process is not the main process.
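A tiny self-contained illustration of the failure mode and why an empty dict fixes it:

# AmpOptimWrapper assigns 'loss_scaler' into whatever the wrapped optimizer's
# state_dict() returns, so None breaks while an empty dict works.
loss_scaler_state = {'scale': 65536.0}  # placeholder for self.loss_scaler.state_dict()

broken = None                      # what non-main ranks effectively returned before
try:
    broken['loss_scaler'] = loss_scaler_state
except TypeError as err:
    print(err)                     # 'NoneType' object does not support item assignment

fixed = dict()                     # what non-main ranks return now
fixed['loss_scaler'] = loss_scaler_state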

@C1rN09 (Collaborator) commented Oct 25, 2022

Hi @C1rN09 @zhouzaida, I'm concerned about the overlap_with_ddp option.

The public API docs state:

overlap_with_ddp (bool, optional) – if True, step() is overlapped with DistributedDataParallel's gradient synchronization; this requires (1) either a functional optimizer for the optimizer_class argument or one with a functional equivalent and (2) registering a DDP communication hook constructed from one of the functions in ddp_zero_hook.py; parameters are packed into buckets matching those in DistributedDataParallel, meaning that the parameters_as_bucket_view argument is ignored. If False, step() runs disjointly after the backward pass (per normal). (default: False)

So if someone wants to use the overlap_with_ddp option, some modifications are needed, such as registering a DDP communication hook, removing consolidate_state_dict, and so on.

Would it be better to fix it as overlap_with_ddp=False?

Hi! If there is no easy solution, I think it's acceptable to fix overlap_with_ddp=False in this PR. You may add a few lines of comments/TODO to explain why. Probably we could introduce this feature in the future.

@nijkah (Contributor Author) commented Oct 25, 2022

@C1rN09 In some versions (e.g. 1.8.0), ZeroRedundancyOptimizer does not accept an overlap_with_ddp keyword argument.
So I just added a TODO comment; a rough sketch of the guard is below.
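A hedged sketch of the guard implied above (variable names are illustrative): pass overlap_with_ddp only where the installed torch accepts it, and pin it to False until DDP communication hooks are supported.

import inspect

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

zero_kwargs = dict(lr=0.01, momentum=0.9, weight_decay=0.0005)
accepted = inspect.signature(ZeroRedundancyOptimizer.__init__).parameters
if 'overlap_with_ddp' in accepted:
    # TODO: supporting overlap_with_ddp=True requires registering a DDP
    # communication hook and dropping consolidate_state_dict; keep it False for now.
    zero_kwargs['overlap_with_ddp'] = False

model = torch.nn.Linear(8, 8)  # stand-in model; needs an initialized process group
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.SGD,
    **zero_kwargs)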

@HAOCHENYE HAOCHENYE added the enhancement New feature or request label Oct 27, 2022
@C1rN09 (Collaborator) commented Oct 27, 2022

Hi @nijkah, I tested your branch on my cluster and got some different results, shown below:

Method      Time (s/iter)    Memory (MiB)
w/o zero    ~0.24            7053
w/ zero     ~0.24            6868

What I'm confused about is the memory consumption. Since this model is only ~250MB, it seems that the maximum memory reduction that ZeroOptimizer can achieve with momentum SGD is 250MB * (1 + 1) * 0.75 ~ 375MB, which is close to my results. Could you provide your detailed modifications to models, configs, etc. (if any), to help me find out what I've missed?

My cluster & configuration, in case I missed something important:

  • Device: 4 x A100
  • Configs: same as upernet_r50_4xb4-80k_ade20k-512x512.py, except that I changed the dataset inheritance to cityscapes because my cluster has no ade20k dataset
  • Zero optimizer configs:
optimizer = dict(type='ZeroRedundancyOptimizer', optimizer_type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)

@nijkah (Contributor Author) commented Oct 27, 2022

Hi, @C1rN09. I apologize for the confusion. 😞
You are totally correct; I realized today that I was confused about this.

I was confused because I had only compared my result with the provided log on GitHub.

After running the experiments again, I got the same result as you. I'll fix the description.

@zhouzaida zhouzaida merged commit 0857f9f into open-mmlab:main Oct 27, 2022
@nijkah (Contributor Author) commented Oct 28, 2022

nit: Hi @C1rN09. Since SGD with momentum only stores a single momentum buffer (the same size as the model params),
should the expected memory reduction be 250MB * 1 * 0.75 ≈ 188MB, which is closer to the result?

@C1rN09 (Collaborator) commented Oct 28, 2022

nit: Hi @C1rN09. Since SGD with momentum only stores a single momentum buffer (the same size as the model params), should the expected memory reduction be 250MB * 1 * 0.75 ≈ 188MB, which is closer to the result?

Yes, I think you are right! From the experiment result, I guess it only shards the optimizer states, not optimizer states + gradients, as I previously thought it might.
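A back-of-the-envelope check of the numbers above (all values approximate):

# ZeRO-1 shards only the optimizer state; for momentum SGD that state is one
# buffer the same size as the parameters.
model_size_mb = 250              # approximate parameter size of upernet_r50
world_size = 4
momentum_state_mb = model_size_mb
per_rank_saving_mb = momentum_state_mb * (world_size - 1) / world_size
print(per_rank_saving_mb)        # ~187.5 MB, close to the observed 7053 - 6868 = 185 MB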

@nijkah nijkah deleted the zero_1_optimizer branch November 3, 2022 04:04
ly015 pushed a commit to ly015/mmengine that referenced this pull request Nov 9, 2022
* [Feature] Support torch ZeRORedundancyOptimizer

Co-authored-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Hakjin Lee <nijkah@gmail.com>

* lint

* Fix saving optimizer state_dict

* Fix handling import error

* Add test case

* fix UT

* Revert "fix UT"

This reverts commit dd64538.

* fix handling import in UT

* Fix saving zero checkpoint and delete redundant master_only

* lint

* test unittest

* Fix handling import error

* Fix UT condition

* Edit docstrings

* Fix typo

* Skip redundant procedure in checkpoint hook

* fix typo again

* Update mmengine/optim/optimizer/zero_optimizer.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Add api info

* lint

* Fix lint

* Handling AmpOptimWrapper case

* handling overlap_with_ddp

* Fix error

Signed-off-by: Junhwa Song <ethan9867@gmail.com>
Signed-off-by: Hakjin Lee <nijkah@gmail.com>
Co-authored-by: Junhwa Song <ethan9867@gmail.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
@twmht (Contributor) commented Mar 28, 2023

@nijkah

How can I use ZeroRedundancyOptimizer with AmpOptimizer in the training config?

@nijkah (Contributor Author) commented Mar 28, 2023

@nijkah

How can I use ZeroRedundancyOptimizer with AmpOptimizer in the training config?

I haven't tested that yet. Is there any specific reason to use AmpOptimizer for 'mixed precision training'?
Unless it is necessary, I recommend using AmpOptimWrapper. You can find a related configuration file and just change its optimizer part.

@twmht (Contributor) commented Mar 28, 2023

@nijkah

This is what I want:

optimizer = dict(type='ZeroRedundancyOptimizer', optimizer_type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='AmpOptimWrapper', optimizer=optimizer, clip_grad=None)

Did you test that?

@nijkah (Contributor Author) commented Mar 28, 2023

@twmht Yes, it should work! 😄
I tested it, and this is the related line I edited: #551 (comment)
