Skip to content

Commit

Permalink
[Experimental] Add FlexibleRunner and Strategies (#1183)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida committed Jun 27, 2023
1 parent 04b0ffe commit 1c3f9f7
Show file tree
Hide file tree
Showing 26 changed files with 3,906 additions and 145 deletions.
17 changes: 17 additions & 0 deletions docs/en/api/strategy.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
.. role:: hidden
    :class: hidden-section

mmengine._strategy
===================================

.. currentmodule:: mmengine._strategy

.. autosummary::
   :toctree: generated
   :nosignatures:
   :template: classtemplate.rst

   BaseStrategy
   SingleDeviceStrategy
   DDPStrategy
   DeepSpeedStrategy
1 change: 1 addition & 0 deletions docs/en/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ You can switch between Chinese and English documents in the lower-left corner of
mmengine.registry <api/registry>
mmengine.config <api/config>
mmengine.runner <api/runner>
mmengine._strategy <api/strategy>
mmengine.hooks <api/hooks>
mmengine.model <api/model>
mmengine.optim <api/optim>
Expand Down
17 changes: 17 additions & 0 deletions docs/zh_cn/api/strategy.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
.. role:: hidden
    :class: hidden-section

mmengine._strategy
===================================

.. currentmodule:: mmengine._strategy

.. autosummary::
   :toctree: generated
   :nosignatures:
   :template: classtemplate.rst

   BaseStrategy
   SingleDeviceStrategy
   DDPStrategy
   DeepSpeedStrategy
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
mmengine.registry <api/registry>
mmengine.config <api/config>
mmengine.runner <api/runner>
mmengine._strategy <api/strategy>
mmengine.hooks <api/hooks>
mmengine.model <api/model>
mmengine.optim <api/optim>
Expand Down
127 changes: 127 additions & 0 deletions examples/distributed_training_with_flexible_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner._flexible_runner import FlexibleRunner


class MMResNet50(BaseModel):
    """ResNet-50 classifier exposed through the ``BaseModel`` interface.

    The wrapped network is ``torchvision.models.resnet50()`` built with its
    default arguments.
    """

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        """Run the network and shape the output according to ``mode``.

        Args:
            imgs: Input batch passed to the ResNet.
            labels: Ground-truth labels matching ``imgs``.
            mode (str): ``'loss'`` or ``'predict'``.

        Returns:
            ``{'loss': <cross entropy>}`` when ``mode == 'loss'``, the tuple
            ``(logits, labels)`` when ``mode == 'predict'``, and ``None`` for
            any other mode.
        """
        logits = self.resnet(imgs)
        if mode == 'predict':
            return logits, labels
        if mode == 'loss':
            return dict(loss=F.cross_entropy(logits, labels))
        # Any other mode falls through without a result.
        return None


class Accuracy(BaseMetric):
    """Top-1 accuracy metric accumulated batch by batch."""

    def process(self, data_batch, data_samples):
        """Record the size and number of correct predictions of one batch.

        Args:
            data_batch: The input batch (unused here).
            data_samples: A ``(score, gt)`` pair; ``score`` is reduced with
                ``argmax(dim=1)`` and compared element-wise against ``gt``.
        """
        scores, targets = data_samples
        batch_size = len(targets)
        # Move the count to CPU so stored results do not stay on the device.
        num_correct = (scores.argmax(dim=1) == targets).sum().cpu()
        self.results.append({
            'batch_size': batch_size,
            'correct': num_correct,
        })

    def compute_metrics(self, results):
        """Aggregate the per-batch records into an accuracy percentage.

        Args:
            results: Iterable of dicts with ``'correct'`` and
                ``'batch_size'`` entries, as stored by :meth:`process`.

        Returns:
            dict: ``{'accuracy': 100 * total_correct / total_size}``.
        """
        total_correct = 0
        total_size = 0
        for record in results:
            # Plain ``+`` (not ``+=``) so no stored value is modified in place.
            total_correct = total_correct + record['correct']
            total_size = total_size + record['batch_size']
        return {'accuracy': 100 * total_correct / total_size}


def parse_args():
    """Parse the command-line options of this example.

    Options:
        --local_rank / --local-rank (int, default 0): stored as
            ``local_rank``.
        --use-deepspeed: flag stored as ``use_deepspeed``; ``False`` unless
            given.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    cli = argparse.ArgumentParser(description='Distributed Training')
    # Both spellings are accepted; the destination comes from the first one.
    cli.add_argument('--local_rank', '--local-rank', type=int, default=0)
    cli.add_argument('--use-deepspeed', action='store_true')
    return cli.parse_args()


def main():
    """Train and validate ``MMResNet50`` on CIFAR-10 with ``FlexibleRunner``.

    Builds the CIFAR-10 train/validation datasets (downloaded into
    ``data/cifar10``), assembles the dataloader, strategy and optimizer
    wrapper configs, then runs training. With ``--use-deepspeed`` a
    ``DeepSpeedStrategy`` config and a ``DeepSpeedOptimWrapper`` are used;
    otherwise ``strategy`` is ``None`` and a plain optimizer wrapper config
    is passed.
    """
    args = parse_args()

    data_root = 'data/cifar10'
    batch_size = 128
    norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])

    # Training pipeline: random crop/flip augmentation, then normalization.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**norm_cfg),
    ])
    train_set = torchvision.datasets.CIFAR10(
        data_root, train=True, download=True, transform=train_transform)

    # Validation pipeline: no augmentation, only tensor conversion and
    # normalization.
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(**norm_cfg),
    ])
    valid_set = torchvision.datasets.CIFAR10(
        data_root, train=False, download=True, transform=valid_transform)

    train_dataloader = {
        'batch_size': batch_size,
        'dataset': train_set,
        'sampler': {'type': 'DefaultSampler', 'shuffle': True},
        'collate_fn': {'type': 'default_collate'},
    }
    val_dataloader = {
        'batch_size': batch_size,
        'dataset': valid_set,
        'sampler': {'type': 'DefaultSampler', 'shuffle': False},
        'collate_fn': {'type': 'default_collate'},
    }

    # Both branches use the same SGD settings; only the wrapper differs.
    optimizer_cfg = dict(type=SGD, lr=0.001, momentum=0.9)

    if not args.use_deepspeed:
        strategy = None
        optim_wrapper = dict(optimizer=optimizer_cfg)
    else:
        bucket_size = 50000000
        fp16_cfg = {
            'enabled': True,
            'fp16_master_weights_and_grads': False,
            'loss_scale': 0,
            'loss_scale_window': 500,
            'hysteresis': 2,
            'min_loss_scale': 1,
            'initial_scale_power': 15,
        }
        zero_cfg = {
            'stage': 0,
            'allgather_partitions': True,
            'reduce_scatter': True,
            'allgather_bucket_size': bucket_size,
            'reduce_bucket_size': bucket_size,
            'overlap_comm': True,
            'contiguous_gradients': True,
            'cpu_offload': False,
        }
        strategy = {
            'type': 'DeepSpeedStrategy',
            'fp16': fp16_cfg,
            'inputs_to_half': [0],
            'zero_optimization': zero_cfg,
        }
        optim_wrapper = {
            'type': 'DeepSpeedOptimWrapper',
            'optimizer': optimizer_cfg,
        }

    runner = FlexibleRunner(
        model=MMResNet50(),
        work_dir='./work_dirs',
        strategy=strategy,
        train_dataloader=train_dataloader,
        optim_wrapper=optim_wrapper,
        param_scheduler=dict(type='LinearLR'),
        train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
        val_dataloader=val_dataloader,
        val_cfg=dict(),
        val_evaluator=dict(type=Accuracy))
    runner.train()


# Run the example only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
11 changes: 11 additions & 0 deletions mmengine/_strategy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import is_installed
from .base import BaseStrategy
from .distributed import DDPStrategy
from .single_device import SingleDeviceStrategy

# Strategies exported unconditionally.
__all__ = ['BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy']

# DeepSpeedStrategy is imported and exported only when the ``deepspeed``
# package is installed.
# NOTE(review): presumably this keeps ``import mmengine._strategy`` working
# in environments without deepspeed - confirm.
if is_installed('deepspeed'):
    from .deepspeed import DeepSpeedStrategy  # noqa: F401
    __all__.append('DeepSpeedStrategy')

0 comments on commit 1c3f9f7

Please sign in to comment.