Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add FlexibleRunner and Strategies #1183

Merged
merged 55 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
d06b9db
Add FlexibleRunner and Strategy
zhouzaida Apr 18, 2023
f8881a1
refine Strategy and add DeepSpeedStrategy
zhouzaida May 5, 2023
5bb8ae4
refine docstring
zhouzaida May 6, 2023
3a0d04e
load_optim_state_dict only accpets the its state
zhouzaida May 6, 2023
59d8ad1
BaseStrategy supports resuming ckpt
zhouzaida May 8, 2023
835f05a
fix error
zhouzaida May 8, 2023
ef6e12e
add inputs_to_half method
zhouzaida May 12, 2023
7f8b9bf
delete the master_only decorator of save_checkpoints
zhouzaida May 23, 2023
f40377c
ds supports loading and saving checkpoints
zhouzaida May 30, 2023
ee2a84d
fix path
zhouzaida May 30, 2023
165c3e7
handle deepspeed version
zhouzaida May 30, 2023
d323d2f
ds does not save latest
zhouzaida May 30, 2023
431b929
move the calling of setup_up to __init__
zhouzaida May 31, 2023
43bfca2
remove ckpt
zhouzaida Jun 1, 2023
64c085e
rename filenames
zhouzaida Jun 2, 2023
dd6e689
refine prepare interface
zhouzaida Jun 2, 2023
1f8de6c
merge main branch
zhouzaida Jun 2, 2023
5a7822c
merge main branch
zhouzaida Jun 2, 2023
04b6709
rename filenames
zhouzaida Jun 2, 2023
f26aeff
update
zhouzaida Jun 2, 2023
5803608
import is_installed
zhouzaida Jun 2, 2023
275be9e
refactor the deepspeed wrapper
zhouzaida Jun 6, 2023
35f35fc
refactor the inheritance
zhouzaida Jun 6, 2023
31820d6
fix lint
zhouzaida Jun 6, 2023
f4855b4
support gradient_clipping for deepspeed
zhouzaida Jun 9, 2023
af70855
add deepspeed optimizers
zhouzaida Jun 12, 2023
61faa87
remove wrap_model from BaseStrategy
zhouzaida Jun 12, 2023
320d295
format docstring
zhouzaida Jun 15, 2023
118cc25
fix lint
zhouzaida Jun 15, 2023
7bd1a07
fix lint
zhouzaida Jun 15, 2023
22f608d
also infer launcher in Strategy
zhouzaida Jun 15, 2023
b104033
resume can be a string
zhouzaida Jun 15, 2023
c57b4ce
fix batch_size error
zhouzaida Jun 16, 2023
9e4d698
move the logic of loading ckpt behind prepare
zhouzaida Jun 17, 2023
5afd217
rename wrap_model to _wrap_model
zhouzaida Jun 17, 2023
d11256b
Merge branch 'main' of github.com:open-mmlab/mmengine into flexible-r…
zhouzaida Jun 17, 2023
2c2b287
add BaseOptimWrapper
zhouzaida Jun 18, 2023
1719810
refine
zhouzaida Jun 18, 2023
b5c2e89
refine
zhouzaida Jun 19, 2023
82b72e6
update docstring
zhouzaida Jun 19, 2023
bd1dd98
make the calling consistent
zhouzaida Jun 19, 2023
35f8a5a
update deepspeed docstring
zhouzaida Jun 19, 2023
4301de6
refine
zhouzaida Jun 20, 2023
9088abc
resolve comments
zhouzaida Jun 20, 2023
8436d00
rename DSOptimWrapper to DeepSpeedOptimWrapper
zhouzaida Jun 20, 2023
9d38c31
register FlexibleRunner to RUNNERS
zhouzaida Jun 21, 2023
d79636c
refactor the BaseOptimWrapper
zhouzaida Jun 25, 2023
10a6320
refactor the BaseOptimWrapper
zhouzaida Jun 25, 2023
1cedbe1
minor refine
zhouzaida Jun 25, 2023
9f54663
load_checkpoint also can initialize optim_wrapper status
zhouzaida Jun 25, 2023
f7c21a6
move _scale_lr to single device strategy
zhouzaida Jun 26, 2023
1d9ac21
rename setup_env to _setup_env
zhouzaida Jun 26, 2023
ac68242
deepspeed supports slurm
zhouzaida Jun 26, 2023
e613794
move convert_model to single device strategy
zhouzaida Jun 26, 2023
08cb129
move scale_lr to base strategy
zhouzaida Jun 27, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/en/api/strategy.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
.. role:: hidden
:class: hidden-section

mmengine._strategy
===================================

.. currentmodule:: mmengine._strategy

.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst

BaseStrategy
SingleDeviceStrategy
DDPStrategy
DeepSpeedStrategy
1 change: 1 addition & 0 deletions docs/en/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ You can switch between Chinese and English documents in the lower-left corner of
mmengine.registry <api/registry>
mmengine.config <api/config>
mmengine.runner <api/runner>
mmengine._strategy <api/strategy>
mmengine.hooks <api/hooks>
mmengine.model <api/model>
mmengine.optim <api/optim>
Expand Down
17 changes: 17 additions & 0 deletions docs/zh_cn/api/strategy.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
.. role:: hidden
:class: hidden-section

mmengine._strategy
===================================

.. currentmodule:: mmengine._strategy

.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst

BaseStrategy
SingleDeviceStrategy
DDPStrategy
DeepSpeedStrategy
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
mmengine.registry <api/registry>
mmengine.config <api/config>
mmengine.runner <api/runner>
mmengine._strategy <api/strategy>
mmengine.hooks <api/hooks>
mmengine.model <api/model>
mmengine.optim <api/optim>
Expand Down
122 changes: 122 additions & 0 deletions examples/distributed_training_with_deepspeed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner._flexible_runner import FlexibleRunner


class MMResNet50(BaseModel):

def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()

def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return {'loss': F.cross_entropy(x, labels.long())}
elif mode == 'predict':
return x, labels


class Accuracy(BaseMetric):

def process(self, data_batch, data_samples):
score, gt = data_samples
self.results.append({
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
})

def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
return dict(accuracy=100 * total_correct / total_size)


def parse_args():
parser = argparse.ArgumentParser(description='Distributed Training')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default=None,
help='job launcher')
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

args = parser.parse_args()
return args


def main():
args = parse_args()
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_set = torchvision.datasets.CIFAR10(
'/nvme/data/zhouzaida/codebases/data/cifar10',
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
]))
valid_set = torchvision.datasets.CIFAR10(
'/nvme/data/zhouzaida/codebases/data/cifar10',
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(**norm_cfg)]))
train_dataloader = dict(
batch_size=128,
dataset=train_set,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'))
val_dataloader = dict(
batch_size=128,
dataset=valid_set,
sampler=dict(type='DefaultSampler', shuffle=False),
collate_fn=dict(type='default_collate'))
strategy = dict(
type='DeepSpeedStrategy',
fp16=dict(
enabled=True,
fp16_master_weights_and_grads=False,
loss_scale=0,
loss_scale_window=500,
hysteresis=2,
min_loss_scale=1,
initial_scale_power=15,
),
inputs_to_half=[0],
zero_optimization=dict(
stage=0,
allgather_partitions=True,
reduce_scatter=True,
allgather_bucket_size=50000000,
reduce_bucket_size=50000000,
overlap_comm=True,
contiguous_gradients=True,
cpu_offload=False))
runner = FlexibleRunner(
model=MMResNet50(),
work_dir='./work_dir',
strategy=strategy,
train_dataloader=train_dataloader,
optim_wrapper=dict(type='DSOptimWrapper', optimizer=dict(type='Adam')),
param_scheduler=dict(type='LinearLR'),
train_cfg=dict(by_epoch=True, max_epochs=20, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
launcher=args.launcher)
runner.train()


if __name__ == '__main__':
main()
102 changes: 102 additions & 0 deletions examples/distributed_training_with_flexible_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner._flexible_runner import FlexibleRunner


class MMResNet50(BaseModel):

def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()

def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return {'loss': F.cross_entropy(x, labels)}
elif mode == 'predict':
return x, labels


class Accuracy(BaseMetric):

def process(self, data_batch, data_samples):
score, gt = data_samples
self.results.append({
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
})

def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
return dict(accuracy=100 * total_correct / total_size)


def parse_args():
parser = argparse.ArgumentParser(description='Distributed Training')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
help='job launcher')
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

args = parser.parse_args()
return args


def main():
args = parse_args()
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_set = torchvision.datasets.CIFAR10(
'/nvme/data/zhouzaida/codebases/data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
]))
valid_set = torchvision.datasets.CIFAR10(
'/nvme/data/zhouzaida/codebases/data/cifar10',
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(**norm_cfg)]))
train_dataloader = dict(
batch_size=128,
dataset=train_set,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'))
val_dataloader = dict(
batch_size=128,
dataset=valid_set,
sampler=dict(type='DefaultSampler', shuffle=False),
collate_fn=dict(type='default_collate'))
runner = FlexibleRunner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
param_scheduler=dict(type='LinearLR'),
train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
launcher=args.launcher,
resume=True,
load_from='./work_dir/epoch_3.pth')
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
runner.train()


if __name__ == '__main__':
main()
11 changes: 11 additions & 0 deletions mmengine/_strategy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import is_installed
from .base import BaseStrategy
from .distributed import DDPStrategy
from .single_device import SingleDeviceStrategy

__all__ = ['BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy']

if is_installed('deepspeed'):
from .deepspeed import DeepSpeedStrategy # noqa: F401
__all__.append('DeepSpeedStrategy')