[Feature] Add colossalai strategy #1299

Merged · 14 commits · Aug 18, 2023
1 change: 0 additions & 1 deletion docs/en/api/optim.rst
@@ -25,7 +25,6 @@ Optimizer
OptimWrapperDict
DefaultOptimWrapperConstructor
ZeroRedundancyOptimizer
-DeepSpeedOptimWrapper

.. autosummary::
:toctree: generated
23 changes: 23 additions & 0 deletions docs/en/api/strategy.rst
@@ -16,3 +16,26 @@ mmengine._strategy
DDPStrategy
DeepSpeedStrategy
FSDPStrategy
+ColossalAIStrategy
+
+
+.. currentmodule:: mmengine._strategy.deepspeed
+
+.. autosummary::
+   :toctree: generated
+   :nosignatures:
+   :template: classtemplate.rst
+
+   MMDeepSpeedEngineWrapper
+   DeepSpeedOptimWrapper
+
+
+.. currentmodule:: mmengine._strategy.colossalai
+
+.. autosummary::
+   :toctree: generated
+   :nosignatures:
+   :template: classtemplate.rst
+
+   CollosalAIModelWrapper
+   ColossalAIOptimWrapper
2 changes: 1 addition & 1 deletion docs/en/common_usage/large_model_training.md
@@ -25,7 +25,7 @@ pip install deepspeed
After installing DeepSpeed, you need to configure the `strategy` and `optim_wrapper` parameters of FlexibleRunner as follows:

- strategy: Set `type='DeepSpeedStrategy'` and configure other parameters. See [DeepSpeedStrategy](mmengine._strategy.DeepSpeedStrategy) for more details.
-- optim_wrapper: Set `type='DeepSpeedOptimWrapper'` and configure other parameters. See [DeepSpeedOptimWrapper](mmengine.optim.DeepSpeedOptimWrapper) for more details.
+- optim_wrapper: Set `type='DeepSpeedOptimWrapper'` and configure other parameters. See [DeepSpeedOptimWrapper](mmengine._strategy.deepspeed.DeepSpeedOptimWrapper) for more details.

Here is an example configuration related to DeepSpeed:
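The repository's full example is collapsed in this diff view. A minimal sketch of such a configuration might look like the following; the ZeRO stage, fp16 settings, and optimizer choice are illustrative assumptions, not the file's actual values:

```python
# Hedged sketch of a FlexibleRunner setup using the DeepSpeed strategy.
# The specific option values below are assumptions for illustration.
strategy = dict(
    type='DeepSpeedStrategy',
    fp16=dict(enabled=True),              # mixed-precision training
    zero_optimization=dict(stage=3),      # ZeRO stage-3 parameter sharding
)

optim_wrapper = dict(
    type='DeepSpeedOptimWrapper',
    optimizer=dict(type='AdamW', lr=1e-4),
)

# These dicts would then be passed to FlexibleRunner, e.g.:
# runner = FlexibleRunner(model=model, strategy=strategy,
#                         optim_wrapper=optim_wrapper, ...)
```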

1 change: 0 additions & 1 deletion docs/zh_cn/api/optim.rst
@@ -25,7 +25,6 @@ Optimizer
OptimWrapperDict
DefaultOptimWrapperConstructor
ZeroRedundancyOptimizer
-DeepSpeedOptimWrapper

.. autosummary::
:toctree: generated
23 changes: 23 additions & 0 deletions docs/zh_cn/api/strategy.rst
@@ -16,3 +16,26 @@ mmengine._strategy
DDPStrategy
DeepSpeedStrategy
FSDPStrategy
+ColossalAIStrategy
+
+
+.. currentmodule:: mmengine._strategy.deepspeed
+
+.. autosummary::
+   :toctree: generated
+   :nosignatures:
+   :template: classtemplate.rst
+
+   MMDeepSpeedEngineWrapper
+   DeepSpeedOptimWrapper
+
+
+.. currentmodule:: mmengine._strategy.colossalai
+
+.. autosummary::
+   :toctree: generated
+   :nosignatures:
+   :template: classtemplate.rst
+
+   CollosalAIModelWrapper
+   ColossalAIOptimWrapper
2 changes: 1 addition & 1 deletion docs/zh_cn/common_usage/large_model_training.md
@@ -24,7 +24,7 @@ pip install deepspeed
After installing DeepSpeed, configure the `strategy` and `optim_wrapper` parameters of FlexibleRunner:

- strategy: set `type='DeepSpeedStrategy'` and configure its parameters. See [DeepSpeedStrategy](mmengine._strategy.DeepSpeedStrategy) for details.
-- optim_wrapper: set `type='DeepSpeedOptimWrapper'` and configure its parameters. See [DeepSpeedOptimWrapper](mmengine.optim.DeepSpeedOptimWrapper) for details.
+- optim_wrapper: set `type='DeepSpeedOptimWrapper'` and configure its parameters. See [DeepSpeedOptimWrapper](mmengine._strategy.deepspeed.DeepSpeedOptimWrapper) for details.

Here is an example configuration related to DeepSpeed:

4 changes: 3 additions & 1 deletion mmengine/_strategy/__init__.py
@@ -2,12 +2,14 @@
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from .base import BaseStrategy
+from .colossalai import ColossalAIStrategy
from .deepspeed import DeepSpeedStrategy
from .distributed import DDPStrategy
from .single_device import SingleDeviceStrategy

__all__ = [
-    'BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy', 'DeepSpeedStrategy'
+    'BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy', 'DeepSpeedStrategy',
+    'ColossalAIStrategy'
]

if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):