
[CodeCamp2023-470] Runner supports setting the number of iterations for each epoch #1292

Merged Oct 8, 2023 · 56 commits

Changes from 12 commits

Commits
b89259c
This change adds the num_batch_per_epoch feature
ShuRaymond Aug 4, 2023
be7230b
[Feature]Add num_batch_per_epoch
ShuRaymond Aug 4, 2023
75d8dff
[Feature] Add num_batch_per_epoch
ShuRaymond Aug 4, 2023
d7df66a
[Feature] Add num_batch_per_epoch
ShuRaymond Aug 4, 2023
b3f0032
[Feature] Add num_batch_per_epoch
ShuRaymond Aug 4, 2023
40c5252
add tests and edit md
ShuRaymond Aug 7, 2023
e025c53
fix bugs
ShuRaymond Aug 7, 2023
ea3db4f
fix bugs
ShuRaymond Aug 7, 2023
3a06ed8
fix bugs
ShuRaymond Aug 7, 2023
5c444c7
fix bugs
ShuRaymond Aug 7, 2023
2f40dc0
fix bugs
ShuRaymond Aug 7, 2023
8ef9a26
fix bugs
ShuRaymond Aug 7, 2023
a446035
add tests
ShuRaymond Aug 8, 2023
f9b27e4
add tests
ShuRaymond Aug 8, 2023
d598fe1
fix bugs
ShuRaymond Aug 8, 2023
7ee46ed
fix bugs
ShuRaymond Aug 8, 2023
f88b104
fix bugs
ShuRaymond Aug 8, 2023
8bb15b6
modify metrics
ShuRaymond Aug 9, 2023
e0aa78a
modify docstring
ShuRaymond Aug 9, 2023
143ab59
modify unit tests
ShuRaymond Aug 9, 2023
4ed72c9
modify unit tests
ShuRaymond Aug 9, 2023
bd035d7
modify unit tests
ShuRaymond Aug 9, 2023
05087cd
modify unit tests
ShuRaymond Aug 9, 2023
d6e220b
modify unit tests
ShuRaymond Aug 9, 2023
3311ebe
modify unit tests
ShuRaymond Aug 9, 2023
b4d45a3
rerun ci
ShuRaymond Aug 13, 2023
9ea37fc
rerun ci
ShuRaymond Aug 13, 2023
90a23bc
change method to support num_batch_per_epoch
ShuRaymond Aug 18, 2023
22345a9
delete invalid tests
ShuRaymond Aug 18, 2023
f54f1c1
delete invalid tests
ShuRaymond Aug 18, 2023
5b4b2b4
delete invalid tests
ShuRaymond Aug 18, 2023
4d3e2f0
delete invalid tests
ShuRaymond Aug 18, 2023
8ba88da
update the documentation
ShuRaymond Aug 18, 2023
96ec4d6
update the documentation
ShuRaymond Aug 19, 2023
5c57054
fix
ShuRaymond Aug 19, 2023
0050ff7
Modify the variable name
ShuRaymond Aug 23, 2023
77bcad1
solve the conflicts
ShuRaymond Aug 27, 2023
d1a5456
Merge branch 'main' into dev
ShuRaymond Aug 27, 2023
36519f6
Update debug_tricks.md
ShuRaymond Aug 27, 2023
3fa23ec
Update debug_tricks.md
ShuRaymond Aug 27, 2023
deef4ad
modify the doc and runner.py
ShuRaymond Aug 30, 2023
7e31da4
modify the doc and runner.py
ShuRaymond Aug 30, 2023
988d790
Merge remote-tracking branch 'origin/dev' into dev
ShuRaymond Aug 30, 2023
adb92ef
Merge branch 'open-mmlab:main' into dev
ShuRaymond Aug 30, 2023
f01570b
modify the doc and runner.py
ShuRaymond Aug 30, 2023
de6fd78
modify the doc and runner.py
ShuRaymond Aug 30, 2023
53bc8e0
modify the doc and runner.py
ShuRaymond Aug 30, 2023
ac6e046
Merge remote-tracking branch 'origin/dev' into dev
ShuRaymond Aug 30, 2023
2881405
modify the doc and runner.py
ShuRaymond Aug 30, 2023
4eb168f
Update debug_tricks.md
zhouzaida Sep 1, 2023
9540f21
Update distributed_training.py
zhouzaida Sep 1, 2023
15f9d85
Update debug_tricks.md
zhouzaida Sep 1, 2023
379600b
Update tests/test_runner/test_runner.py
zhouzaida Sep 1, 2023
36ee77d
Minor refine
HAOCHENYE Oct 7, 2023
9ccbb3f
Merge remote-tracking branch 'origin/main' into dev
HAOCHENYE Oct 7, 2023
686ebfb
Merge remote-tracking branch 'origin/main' into dev
HAOCHENYE Oct 8, 2023
101 changes: 101 additions & 0 deletions docs/en/common_usage/debug_tricks.md
@@ -49,3 +49,104 @@ As we can see, the number of iterations has changed to `313`. Compared to before
02/20 14:44:59 - mmengine - INFO - Epoch(train) [1][200/313] lr: 1.0000e-01 eta: 0:23:18 time: 0.0143 data_time: 0.0002 memory: 214 loss: 2.0424
02/20 14:45:01 - mmengine - INFO - Epoch(train) [1][300/313] lr: 1.0000e-01 eta: 0:20:39 time: 0.0143 data_time: 0.0003 memory: 214 loss: 1.814
```

## Add cfg parameter
Collaborator

Suggested change
## Add cfg parameter
## Training for a fixed number of iterations (epoch-based training)

Contributor Author

Thanks! I have made these changes.


During the process of debugging code, sometimes it is necessary to train for several epochs, such as debugging the validation process or checking whether the checkpoint saving meets expectations. However, if the dataset is too large, it may take a long time to complete one epoch, in which case the cfg parameter can be added.
Collaborator

Suggested change
During the process of debugging code, sometimes it is necessary to train for several epochs, such as debugging the validation process or checking whether the checkpoint saving meets expectations. However, if the dataset is too large, it may take a long time to complete one epoch, in which case the cfg parameter can be added.
During the process of debugging code, sometimes it is necessary to train for several epochs, such as debugging the validation process or checking whether the checkpoint saving meets expectations. However, if the dataset is too large, it may take a long time to complete one epoch, in which case the `num_batch_per_epoch` could be configured:

Take `MMEngine` as an example (refer to the [documentation](https://mmengine.readthedocs.io/zh_CN/latest/get_started/installation.html) for installing MMEngine).

Example of a training script
Collaborator

Suggested change
Take `MMEngine` as an example (refer to the [documentation](https://mmengine.readthedocs.io/zh_CN/latest/get_started/installation.html) for installing MMEngine).
Example of a training script


```python
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import FlexibleRunner


class MMResNet50(BaseModel):
def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()

def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return {'loss': F.cross_entropy(x, labels)}
elif mode == 'predict':
return x, labels


class Accuracy(BaseMetric):
def process(self, data_batch, data_samples):
score, gt = data_samples
self.results.append({
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
})

def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
return dict(accuracy=100 * total_correct / total_size)


norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
shuffle=True,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))

val_dataloader = DataLoader(batch_size=32,
shuffle=False,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))

runner = FlexibleRunner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1, num_batch_per_epoch=10),
val_dataloader=val_dataloader,
val_cfg=dict(num_batch_per_epoch=10),
val_evaluator=dict(type=Accuracy)
)
runner.train()
```

Fast debugging is achieved by adding the `num_batch_per_epoch` parameter to `train_cfg` and `val_cfg`.
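For quick reference, here are just the two fields involved, excerpted from the script above (a condensed sketch of the same configuration, nothing new):

```python
# Condensed excerpt of the example above: only the debugging-related
# fields are shown.
train_cfg = dict(
    by_epoch=True,
    max_epochs=5,
    val_interval=1,
    num_batch_per_epoch=10)  # truncate each training epoch after ~10 batches
val_cfg = dict(num_batch_per_epoch=10)  # likewise truncate validation
```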

Run the training script. You can see that each epoch ends after only 10 batches have run. Compared to before, debugging is faster and more flexible.

```
08/07 16:35:45 - mmengine - INFO - Epoch(train) [1][ 10/1563] lr: 1.0000e-03 eta: 1:15:03 time: 0.5770 data_time: 0.0075 memory: 477 loss: 5.0847
08/07 16:35:45 - mmengine - INFO - Saving checkpoint at 1 epochs
08/07 16:35:46 - mmengine - INFO - Epoch(val) [1][ 10/313] eta: 0:00:03 time: 0.0131 data_time: 0.0037 memory: 477
08/07 16:35:46 - mmengine - INFO - Epoch(val) [1][313/313] accuracy: 13.0682 data_time: 0.0038 time: 0.0130
08/07 16:35:46 - mmengine - INFO - Epoch(train) [2][ 10/1563] lr: 1.0000e-03 eta: 0:38:13 time: 0.0360 data_time: 0.0066 memory: 477 loss: 2.7406
08/07 16:35:46 - mmengine - INFO - Saving checkpoint at 2 epochs
08/07 16:35:47 - mmengine - INFO - Epoch(val) [2][ 10/313] eta: 0:00:03 time: 0.0104 data_time: 0.0036 memory: 477
08/07 16:35:47 - mmengine - INFO - Epoch(val) [2][313/313] accuracy: 12.5000 data_time: 0.0036 time: 0.0117
```
101 changes: 101 additions & 0 deletions docs/zh_cn/common_usage/debug_tricks.md
@@ -49,3 +49,104 @@ python tools/train.py configs/resnet/resnet18_8xb16_cifar10.py
02/20 14:44:59 - mmengine - INFO - Epoch(train) [1][200/313] lr: 1.0000e-01 eta: 0:23:18 time: 0.0143 data_time: 0.0002 memory: 214 loss: 2.0424
02/20 14:45:01 - mmengine - INFO - Epoch(train) [1][300/313] lr: 1.0000e-01 eta: 0:20:39 time: 0.0143 data_time: 0.0003 memory: 214 loss: 1.814
```

## Add cfg parameter

During the process of debugging code, sometimes it is necessary to train for several epochs, such as debugging the validation process or checking whether the checkpoint saving meets expectations. However, if the dataset is too large, it may take a long time to complete one epoch, in which case the cfg parameter can be added.
Take `MMEngine` as an example (refer to the [documentation](https://mmengine.readthedocs.io/zh_CN/latest/get_started/installation.html) for installing MMEngine).

Example of a training script

```python
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import FlexibleRunner


class MMResNet50(BaseModel):
def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()

def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return {'loss': F.cross_entropy(x, labels)}
elif mode == 'predict':
return x, labels


class Accuracy(BaseMetric):
def process(self, data_batch, data_samples):
score, gt = data_samples
self.results.append({
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
})

def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
return dict(accuracy=100 * total_correct / total_size)


norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
shuffle=True,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))

val_dataloader = DataLoader(batch_size=32,
shuffle=False,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))

runner = FlexibleRunner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1, num_batch_per_epoch=10),
val_dataloader=val_dataloader,
val_cfg=dict(num_batch_per_epoch=10),
val_evaluator=dict(type=Accuracy)
)
runner.train()
```

Fast debugging is achieved by adding the `num_batch_per_epoch` parameter to `train_cfg` and `val_cfg`.

Run the training script. You can see that each epoch ends after only 10 batches have run. Compared to before, debugging is faster and more flexible.

```
08/07 16:35:45 - mmengine - INFO - Epoch(train) [1][ 10/1563] lr: 1.0000e-03 eta: 1:15:03 time: 0.5770 data_time: 0.0075 memory: 477 loss: 5.0847
08/07 16:35:45 - mmengine - INFO - Saving checkpoint at 1 epochs
08/07 16:35:46 - mmengine - INFO - Epoch(val) [1][ 10/313] eta: 0:00:03 time: 0.0131 data_time: 0.0037 memory: 477
08/07 16:35:46 - mmengine - INFO - Epoch(val) [1][313/313] accuracy: 13.0682 data_time: 0.0038 time: 0.0130
08/07 16:35:46 - mmengine - INFO - Epoch(train) [2][ 10/1563] lr: 1.0000e-03 eta: 0:38:13 time: 0.0360 data_time: 0.0066 memory: 477 loss: 2.7406
08/07 16:35:46 - mmengine - INFO - Saving checkpoint at 2 epochs
08/07 16:35:47 - mmengine - INFO - Epoch(val) [2][ 10/313] eta: 0:00:03 time: 0.0104 data_time: 0.0036 memory: 477
08/07 16:35:47 - mmengine - INFO - Epoch(val) [2][313/313] accuracy: 12.5000 data_time: 0.0036 time: 0.0117
```
1 change: 1 addition & 0 deletions mmengine/runner/_flexible_runner.py
@@ -298,6 +298,7 @@ def __init__(
f'train_dataloader={train_dataloader}, '
f'train_cfg={train_cfg}, '
f'optim_wrapper={optim_wrapper}.')

Member

Suggested change

self._train_dataloader = train_dataloader
self._train_loop = train_cfg

15 changes: 15 additions & 0 deletions mmengine/runner/loops.py
@@ -40,6 +40,7 @@ def __init__(
max_epochs: int,
val_begin: int = 1,
val_interval: int = 1,
num_batch_per_epoch: Optional[int] = None,
Member

Adding a new parameter in a middle position may cause a backward-compatibility (BC) issue for callers that pass arguments positionally. Suggest moving it to the end.
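To illustrate the concern (a hypothetical caller, not code from this repository): a call that passes arguments positionally against the old signature silently binds them to the wrong parameters once a new one is inserted in the middle.

```python
# Hypothetical sketch of the backward-compatibility hazard.
def old_init(max_epochs, val_begin=1, val_interval=1,
             dynamic_intervals=None):
    return val_begin, val_interval, dynamic_intervals

# New signature with num_batch_per_epoch inserted before the last parameter.
def new_init(max_epochs, val_begin=1, val_interval=1,
             num_batch_per_epoch=None, dynamic_intervals=None):
    return val_begin, val_interval, dynamic_intervals

args = (3, 1, 1, [(8, 2)])          # positional call written for old_init
assert old_init(*args) == (1, 1, [(8, 2)])
print(new_init(*args))              # (1, 1, None): [(8, 2)] now silently
                                    # lands on num_batch_per_epoch
```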

dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader)
self._max_epochs = int(max_epochs)
@@ -50,6 +51,7 @@
self._iter = 0
self.val_begin = val_begin
self.val_interval = val_interval
self._num_batch_per_epoch = num_batch_per_epoch
Member

Suggested change
self._num_batch_per_epoch = num_batch_per_epoch
self.num_batch_per_epoch = num_batch_per_epoch

# This attribute will be updated by `EarlyStoppingHook`
# when it is enabled.
self.stop_training = False
@@ -109,6 +111,9 @@ def run_epoch(self) -> None:
self.runner.call_hook('before_train_epoch')
self.runner.model.train()
for idx, data_batch in enumerate(self.dataloader):
if self._num_batch_per_epoch is not None:
Member

Suggested change
if self._num_batch_per_epoch is not None:
if self.num_batch_per_epoch is not None:

if idx > self._num_batch_per_epoch:
Member

Suggested change
if idx > self._num_batch_per_epoch:
if idx > self.num_batch_per_epoch:

break
self.run_iter(idx, data_batch)

self.runner.call_hook('after_train_epoch')
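As a standalone sketch of the boundary produced by the condition above (assuming the 0-based indices that `enumerate` yields): `idx > n` admits indices `0..n`, i.e. `n + 1` batches, while `idx >= n` would admit exactly `n`.

```python
# Standalone sketch (not repository code) of the break condition above.
num_batch_per_epoch = 3
dataloader = range(10)  # stand-in for a DataLoader

ran = []
for idx, batch in enumerate(dataloader):
    if num_batch_per_epoch is not None:
        if idx > num_batch_per_epoch:  # condition as written in the diff
            break
    ran.append(idx)
print(ran)  # [0, 1, 2, 3] -- four batches pass for num_batch_per_epoch=3
```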
@@ -331,6 +336,7 @@ def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
num_batch_per_epoch: Optional[int] = None,
fp16: bool = False) -> None:
super().__init__(runner, dataloader)

@@ -352,6 +358,7 @@
'visualizer will be None.',
logger='current',
level=logging.WARNING)
self._num_batch_per_epoch = num_batch_per_epoch
self.fp16 = fp16

def run(self) -> dict:
Expand All @@ -360,6 +367,9 @@ def run(self) -> dict:
self.runner.call_hook('before_val_epoch')
self.runner.model.eval()
for idx, data_batch in enumerate(self.dataloader):
if self._num_batch_per_epoch is not None:
if idx > self._num_batch_per_epoch:
break
self.run_iter(idx, data_batch)

# compute metrics
@@ -406,6 +416,7 @@ def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
num_batch_per_epoch: Optional[int] = None,
fp16: bool = False):
super().__init__(runner, dataloader)

@@ -424,6 +435,7 @@
'visualizer will be None.',
logger='current',
level=logging.WARNING)
self._num_batch_per_epoch = num_batch_per_epoch
self.fp16 = fp16

def run(self) -> dict:
Expand All @@ -432,6 +444,9 @@ def run(self) -> dict:
self.runner.call_hook('before_test_epoch')
self.runner.model.eval()
for idx, data_batch in enumerate(self.dataloader):
if self._num_batch_per_epoch is not None:
if idx > self._num_batch_per_epoch:
break
self.run_iter(idx, data_batch)

# compute metrics
31 changes: 31 additions & 0 deletions tests/test_runner/test_runner.py
@@ -1347,6 +1347,17 @@ def test_build_train_loop(self):
loop = runner.build_train_loop(cfg)
self.assertIsInstance(loop, CustomTrainLoop)

# input is a dict and contains num_batch_per_epoch
cfg = dict(
type='EpochBasedTrainLoop', max_epochs=3, num_batch_per_epoch=5)
loop = runner.build_train_loop(cfg)
self.assertIsInstance(loop, EpochBasedTrainLoop)

# input is a dict and does not contain type key
cfg = dict(by_epoch=True, max_epochs=3, num_batch_per_epoch=5)
loop = runner.build_train_loop(cfg)
self.assertIsInstance(loop, EpochBasedTrainLoop)

def test_build_val_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_val_loop'
@@ -1374,6 +1385,16 @@ def test_build_val_loop(self):
loop = runner.build_val_loop(cfg)
self.assertIsInstance(loop, CustomValLoop)

# input is a dict and contains type key and num_batch_per_epoch
cfg = dict(type='ValLoop', num_batch_per_epoch=5)
loop = runner.build_val_loop(cfg)
self.assertIsInstance(loop, ValLoop)

# input is a dict and contains num_batch_per_epoch
cfg = dict(num_batch_per_epoch=5)
loop = runner.build_val_loop(cfg)
self.assertIsInstance(loop, ValLoop)

def test_build_test_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_test_loop'
@@ -1401,6 +1422,16 @@ def test_build_val_loop(self):
loop = runner.build_val_loop(cfg)
self.assertIsInstance(loop, CustomTestLoop)

# input is a dict and contains type key and num_batch_per_epoch
cfg = dict(type='TestLoop', num_batch_per_epoch=5)
loop = runner.build_test_loop(cfg)
self.assertIsInstance(loop, TestLoop)

# input is a dict and contains num_batch_per_epoch
cfg = dict(num_batch_per_epoch=5)
loop = runner.build_test_loop(cfg)
self.assertIsInstance(loop, TestLoop)

def test_build_log_processor(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_log_processor'