# [CodeCamp2023-470] Runner supports setting the number of iterations for each epoch #1292
@@ -49,3 +49,104 @@ As we can see, the number of iterations has changed to `313`. Compared to before
02/20 14:44:59 - mmengine - INFO - Epoch(train) [1][200/313] lr: 1.0000e-01 eta: 0:23:18 time: 0.0143 data_time: 0.0002 memory: 214 loss: 2.0424
02/20 14:45:01 - mmengine - INFO - Epoch(train) [1][300/313] lr: 1.0000e-01 eta: 0:20:39 time: 0.0143 data_time: 0.0003 memory: 214 loss: 1.814
```

## Add cfg parameter

While debugging code, it is sometimes necessary to train for only a few epochs, for example to debug the validation process or to check whether checkpoints are saved as expected. However, if the dataset is too large, completing even one epoch may take a long time; in that case, the cfg parameter can be added to limit the number of batches per epoch.

Take `MMEngine` as an example (refer to the [documentation](https://mmengine.readthedocs.io/zh_CN/latest/get_started/installation.html) for installing MMEngine).

Here is an example training script:
```python
import os

# Work around "duplicate OpenMP runtime" crashes on some platforms.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import FlexibleRunner


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels


class Accuracy(BaseMetric):

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)


norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))

runner = FlexibleRunner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # num_batch_per_epoch caps each training epoch at 10 batches.
    train_cfg=dict(
        by_epoch=True, max_epochs=5, val_interval=1, num_batch_per_epoch=10),
    val_dataloader=val_dataloader,
    # The validation loop is capped at 10 batches as well.
    val_cfg=dict(num_batch_per_epoch=10),
    val_evaluator=dict(type=Accuracy))
runner.train()
```

Fast debugging is achieved by adding the `num_batch_per_epoch` parameter to `train_cfg` and `val_cfg`.
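The relevant change is just these two entries, extracted verbatim from the script above:

```python
# Cap both the training and the validation loop at 10 batches per epoch.
train_cfg = dict(by_epoch=True, max_epochs=5, val_interval=1,
                 num_batch_per_epoch=10)
val_cfg = dict(num_batch_per_epoch=10)
```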

Run the training script. You can see that each epoch now ends after just 10 batches. Compared to before, debugging is faster and more flexible.

```
08/07 16:35:45 - mmengine - INFO - Epoch(train) [1][ 10/1563] lr: 1.0000e-03 eta: 1:15:03 time: 0.5770 data_time: 0.0075 memory: 477 loss: 5.0847
08/07 16:35:45 - mmengine - INFO - Saving checkpoint at 1 epochs
08/07 16:35:46 - mmengine - INFO - Epoch(val) [1][ 10/313] eta: 0:00:03 time: 0.0131 data_time: 0.0037 memory: 477
08/07 16:35:46 - mmengine - INFO - Epoch(val) [1][313/313] accuracy: 13.0682 data_time: 0.0038 time: 0.0130
08/07 16:35:46 - mmengine - INFO - Epoch(train) [2][ 10/1563] lr: 1.0000e-03 eta: 0:38:13 time: 0.0360 data_time: 0.0066 memory: 477 loss: 2.7406
08/07 16:35:46 - mmengine - INFO - Saving checkpoint at 2 epochs
08/07 16:35:47 - mmengine - INFO - Epoch(val) [2][ 10/313] eta: 0:00:03 time: 0.0104 data_time: 0.0036 memory: 477
08/07 16:35:47 - mmengine - INFO - Epoch(val) [2][313/313] accuracy: 12.5000 data_time: 0.0036 time: 0.0117
```
@@ -298,6 +298,7 @@ def __init__(
                f'train_dataloader={train_dataloader}, '
                f'train_cfg={train_cfg}, '
                f'optim_wrapper={optim_wrapper}.')

        self._train_dataloader = train_dataloader
        self._train_loop = train_cfg
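For context, here is a simplified sketch of how extra keys in a `train_cfg` dict, such as `num_batch_per_epoch`, can end up as keyword arguments of the loop. This is not the actual mmengine implementation, and `build_train_loop` is a hypothetical helper name:

```python
from mmengine.runner import EpochBasedTrainLoop

def build_train_loop(runner, dataloader, train_cfg):
    """Hypothetical sketch: forward train_cfg keys to the loop."""
    cfg = dict(train_cfg)
    if cfg.pop('by_epoch'):
        # The remaining keys (max_epochs, val_interval,
        # num_batch_per_epoch, ...) become keyword arguments
        # of EpochBasedTrainLoop.__init__.
        return EpochBasedTrainLoop(runner=runner, dataloader=dataloader, **cfg)
    raise NotImplementedError('iteration-based variant omitted in this sketch')
```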
@@ -40,6 +40,7 @@ def __init__(
            max_epochs: int,
            val_begin: int = 1,
            val_interval: int = 1,
            num_batch_per_epoch: Optional[int] = None,
> **Review comment:** Adding a new parameter in the middle position may cause a BC (backward compatibility) issue. Suggest moving it to the end. (A sketch of the breakage follows this hunk.)
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)
        self._max_epochs = int(max_epochs)

@@ -50,6 +51,7 @@ def __init__(
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        self._num_batch_per_epoch = num_batch_per_epoch
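To illustrate the reviewer's concern, here is a minimal, hypothetical sketch (not mmengine's actual signatures) of how inserting a parameter mid-signature silently breaks callers that pass arguments positionally:

```python
# Before: a caller may pass dynamic_intervals positionally.
def init_v1(max_epochs, val_begin=1, val_interval=1, dynamic_intervals=None):
    return dynamic_intervals

assert init_v1(5, 1, 1, [(5, 2)]) == [(5, 2)]

# After inserting num_batch_per_epoch in the middle, the same call
# binds [(5, 2)] to num_batch_per_epoch instead -- no error is raised.
def init_v2(max_epochs, val_begin=1, val_interval=1,
            num_batch_per_epoch=None, dynamic_intervals=None):
    return dynamic_intervals

assert init_v2(5, 1, 1, [(5, 2)]) is None  # behaviour changed silently

# Appending the new parameter at the end keeps old calls working.
def init_v3(max_epochs, val_begin=1, val_interval=1,
            dynamic_intervals=None, num_batch_per_epoch=None):
    return dynamic_intervals

assert init_v3(5, 1, 1, [(5, 2)]) == [(5, 2)]
```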
        # This attribute will be updated by `EarlyStoppingHook`
        # when it is enabled.
        self.stop_training = False

@@ -109,6 +111,9 @@ def run_epoch(self) -> None:
        self.runner.call_hook('before_train_epoch')
        self.runner.model.train()
        for idx, data_batch in enumerate(self.dataloader):
            if self._num_batch_per_epoch is not None:
                if idx >= self._num_batch_per_epoch:
                    break
            self.run_iter(idx, data_batch)

        self.runner.call_hook('after_train_epoch')
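As a side note, an equivalent way to express this truncation (a sketch, not what the PR does) is to slice the dataloader's iterator. Since `itertools.islice(it, None)` iterates everything, the uncapped case needs no extra branch:

```python
import itertools

def run_epoch_sketch(loop):
    # islice stops after num_batch_per_epoch items; a cap of None
    # means "no limit", matching the uncapped behaviour above.
    capped = itertools.islice(loop.dataloader, loop._num_batch_per_epoch)
    for idx, data_batch in enumerate(capped):
        loop.run_iter(idx, data_batch)
```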
@@ -331,6 +336,7 @@ def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 num_batch_per_epoch: Optional[int] = None,
                 fp16: bool = False) -> None:
        super().__init__(runner, dataloader)
@@ -352,6 +358,7 @@ def __init__(self,
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        self._num_batch_per_epoch = num_batch_per_epoch
        self.fp16 = fp16

    def run(self) -> dict:
@@ -360,6 +367,9 @@ def run(self) -> dict:
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            if self._num_batch_per_epoch is not None:
                if idx >= self._num_batch_per_epoch:
                    break
            self.run_iter(idx, data_batch)

        # compute metrics
@@ -406,6 +416,7 @@ def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 num_batch_per_epoch: Optional[int] = None,
                 fp16: bool = False):
        super().__init__(runner, dataloader)
@@ -424,6 +435,7 @@ def __init__(self,
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        self._num_batch_per_epoch = num_batch_per_epoch
        self.fp16 = fp16

    def run(self) -> dict:
@@ -432,6 +444,9 @@ def run(self) -> dict:
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            if self._num_batch_per_epoch is not None:
                if idx >= self._num_batch_per_epoch:
                    break
            self.run_iter(idx, data_batch)

        # compute metrics
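Since `TestLoop` gains the same parameter, the cap also applies when testing. A usage sketch, reusing objects from the documentation script above (hypothetical wiring; a real run would normally also pass a checkpoint via `load_from`):

```python
# Hypothetical usage: cap the test loop at 10 batches as well.
runner = FlexibleRunner(
    model=MMResNet50(),
    work_dir='./work_dir',
    test_dataloader=val_dataloader,  # reuse the CIFAR-10 val loader
    test_cfg=dict(num_batch_per_epoch=10),
    test_evaluator=dict(type=Accuracy))
runner.test()
```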
> **Author reply:** Thanks! I have done these changes.