Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add FlexibleRunner and Strategies #1183

Merged
merged 55 commits into from Jun 27, 2023

Conversation

zhouzaida
Copy link
Member

@zhouzaida zhouzaida commented Jun 2, 2023

Background

In the process of supporting FSDP, DeepSpeed, and ColossalAI, Runner's scalability has encountered challenges, mainly manifested in the following three aspects:

  1. Incompatibility with existing fixed training processes and new training methods like ZeRO.

    MMEngine achieves the unification of single-gpu training and DDP (Distributed Data Parallel) training processes, and this unified and fixed training process is written in code within the Runner. However, when supporting ZeRO series methods (FSDP, ColossalAI ZeroDDP, DeepSpeed ZeRO), this fixed training process becomes incompatible and requires adjustments in the order.

    For example, after model = FSDP(model) is applied, the parameters and buffers in the model are split across different GPUs, resulting in incompleteness. Any direct modification operations on the model (such as init_weights in MMEngine) will result in errors.

    Furthermore, even within the ZeRO series methods, there are differences in the implementation across different frameworks, requiring flexibility to adjust the order and dispatch based on different frameworks. For example, load_checkpoint in FSDP must be called before model = FSDP(model); while DeepSpeed and ColossalAI require it to be called after model = initialize(model).

  2. Coupling between training components in other frameworks (DeepSpeed, ColossalAI).

    In MMEngine, the Model Wrapper and Optim Wrapper are independent, while in DeepSpeed and ColossalAI, there is a coupling relationship between the model and optimizer, requiring mutual access to accomplish certain tasks. This can be observed in the colossalai.initialize and deepspeed.initialize interfaces.

  3. The unified save/load checkpoint function cannot meet the requirements of different models and frameworks.

    The current save_checkpoint and load_checkpoint are independent functions in Runner, with no association with the model, framework, etc., which is counterintuitive. For example, the FSDP training method requires collecting model parameters and optimizer states to GPU 0 before saving the model, while ColossalAI and DeepSpeed have their own complex logic for weight saving and loading.

Design

To avoid impacting the existing Runner, we will re-implement a FlexibleRunner and introduce a new abstract Strategy.

The Strategy is primarily responsible for:

  • Constructing and initializing training components such as the model, optimizer, parameter scheduler, etc.
  • Initializing the distributed training environment.
  • Saving and loading the model, optimizer state, etc.

This PR will support three types of strategies:

  • SingleDeviceStrategy
  • DDPStrategy
  • DeepSpeedStrategy

Note: This is an experimental feature, and the interface is subject to change.

Environment

  • PyTorch: 2.0.0
  • deepspeed: 0.9.3+d755b9d6
  • CUDA: 11.7
  • GPU: 8 * A100, 80G

Validation

  • resume
  • load_from

Experiment

MMPreTrain

  • vit-huge-p14_8xb128-coslr-50e_in1k.py

    DDP: Out of memory
    strategy = dict(
        type='DDPStrategy',
    )
    DDP + fp16: 58G per GPU
    optim_wrapper = dict(
        type='AmpOptimWrapper',
        optimizer=dict(
            type='AdamW',
            lr=0.004,
            weight_decay=0.05,
            eps=1e-08,
            betas=(0.9, 0.999)),
        paramwise_cfg=dict(
            norm_decay_mult=0.0,
            bias_decay_mult=0.0,
            flat_decay_mult=0.0,
            custom_keys=dict({
                '.absolute_pos_embed': dict(decay_mult=0.0),
                '.relative_position_bias_table': dict(decay_mult=0.0),
                '.ln': dict(decay_mult=0.0),
                '.bias': dict(decay_mult=0.0),
                '.cls_token': dict(decay_mult=0.0),
                '.pos_embed': dict(decay_mult=0.0)
            }),
            layer_decay_rate=0.75),
        constructor='LearningRateDecayOptimWrapperConstructor')
    strategy = dict(
        type='DDPStrategy',
    )
    DeepSpeed ZeRO1 + fp16: 44G per GPU
    strategy = dict(
        type='DeepSpeedStrategy',
        fp16=dict(
            enabled=True,
            fp16_master_weights_and_grads=False,
            loss_scale=0,
            loss_scale_window=500,
            hysteresis=2,
            min_loss_scale=1,
            initial_scale_power=15,
        ),
        inputs_to_half=['inputs'],
        zero_optimization=dict(
            stage=3,
            allgather_partitions=True,
            reduce_scatter=True,
            allgather_bucket_size=50000000,
            reduce_bucket_size=50000000,
            overlap_comm=True,
            contiguous_gradients=True,
            cpu_offload=False,
        )
    )
    DeepSpeed ZeRO3 + fp16:
    strategy = dict(
        type='DeepSpeedStrategy',
        fp16=dict(
            enabled=True,
            fp16_master_weights_and_grads=False,
            loss_scale=0,
            loss_scale_window=500,
            hysteresis=2,
            min_loss_scale=1,
            initial_scale_power=15,
        ),
        inputs_to_half=['inputs'],
        zero_optimization=dict(
            stage=3,
            allgather_partitions=True,
            reduce_scatter=True,
            allgather_bucket_size=50000000,
            reduce_bucket_size=50000000,
            overlap_comm=True,
            contiguous_gradients=True,
            cpu_offload=False,
        )
    )
  • vit-large-p16_8xb128-coslr-50e_in1k.py

    Deespeed ZeRO1+fp16: 21G per GPU, accuracy: 85.6040
    strategy = dict(
        type='DeepSpeedStrategy',
        fp16=dict(
            enabled=True,
            fp16_master_weights_and_grads=False,
            loss_scale=0,
            loss_scale_window=500,
            hysteresis=2,
            min_loss_scale=1,
            initial_scale_power=15,
        ),
        inputs_to_half=['inputs'],
        zero_optimization=dict(
            stage=3,
            allgather_partitions=True,
            reduce_scatter=True,
            allgather_bucket_size=50000000,
            reduce_bucket_size=50000000,
            overlap_comm=True,
            contiguous_gradients=True,
            cpu_offload=False,
        )
    )
    Deespeed ZeRO3
    strategy = dict(
        type='DeepSpeedStrategy',
        fp16=dict(
            enabled=True,
            fp16_master_weights_and_grads=False,
            loss_scale=0,
            loss_scale_window=500,
            hysteresis=2,
            min_loss_scale=1,
            initial_scale_power=15,
        ),
        inputs_to_half=['inputs'],
        zero_optimization=dict(
            stage=3,
            allgather_partitions=True,
            reduce_scatter=True,
            allgather_bucket_size=50000000,
            reduce_bucket_size=50000000,
            overlap_comm=True,
            contiguous_gradients=True,
            cpu_offload=False,
        )
    )

MMDet

config mAP time
FlexibleRunner + DDPStrategy Runner + DDP FlexibleRunner + DDPStrategy Runner + DDP
atss_r50_fpn_1x_coco.py 39.3 39.4
faster-rcnn_r50_fpn_1x_coco.py 37.4 37.4

@zhouzaida zhouzaida marked this pull request as draft June 2, 2023 06:24
@zhouzaida zhouzaida changed the base branch from flexible-runner to main June 2, 2023 08:20
@zhouzaida zhouzaida changed the base branch from main to flexible-runner June 2, 2023 08:21
examples/distributed_training_with_deepspeed.py Outdated Show resolved Hide resolved
examples/distributed_training_with_deepspeed.py Outdated Show resolved Hide resolved
examples/distributed_training_with_flexible_runner.py Outdated Show resolved Hide resolved
mmengine/_strategy/base.py Outdated Show resolved Hide resolved
mmengine/_strategy/base.py Outdated Show resolved Hide resolved
mmengine/_strategy/base.py Show resolved Hide resolved
mmengine/_strategy/deepspeed.py Show resolved Hide resolved
mmengine/_strategy/deepspeed.py Outdated Show resolved Hide resolved
mmengine/_strategy/deepspeed.py Outdated Show resolved Hide resolved
mmengine/_strategy/deepspeed.py Outdated Show resolved Hide resolved
mmengine/model/wrappers/_deepspeed.py Outdated Show resolved Hide resolved
mmengine/_strategy/single_device.py Outdated Show resolved Hide resolved
mmengine/_strategy/single_device.py Show resolved Hide resolved
mmengine/optim/optimizer/_deepspeed.py Outdated Show resolved Hide resolved
mmengine/_strategy/base.py Show resolved Hide resolved
mmengine/_strategy/base.py Show resolved Hide resolved
.gitignore Outdated Show resolved Hide resolved
mmengine/runner/_flexible_runner.py Show resolved Hide resolved
mmengine/runner/_flexible_runner.py Outdated Show resolved Hide resolved
mmengine/runner/_flexible_runner.py Outdated Show resolved Hide resolved
mmengine/_strategy/distributed.py Show resolved Hide resolved
mmengine/runner/_flexible_runner.py Outdated Show resolved Hide resolved
map_location: Union[str, Callable] = 'cpu',
strict: bool = False,
revise_keys: list = [(r'^module.', '')],
callback: Optional[Callable] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

callback is not used in DeepSpeedStrategy, is that expected?

@zhouzaida zhouzaida merged commit 1c3f9f7 into open-mmlab:flexible-runner Jun 27, 2023
16 of 19 checks passed
@zhouzaida zhouzaida deleted the flexible-runner branch July 3, 2023 02:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants