Skip to content

Commit

Permalink
Merge 912ffbb into ff27b72
Browse files Browse the repository at this point in the history
  • Loading branch information
KerwinKai committed Mar 17, 2023
2 parents ff27b72 + 912ffbb commit 533ac42
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 2 deletions.
9 changes: 9 additions & 0 deletions docs/en/tutorials/hook.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ runner.train()
- Save the most recent checkpoints
- Save the best checkpoints
- Specify the path to save the checkpoints
- Automatically publish the best and the last checkpoints

For more features, please read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook).

Expand Down Expand Up @@ -120,6 +121,14 @@ The four features mentioned above are described below.
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))
```

- Automatically publish the best and the last checkpoints

  If you want to automatically publish the best and the last checkpoints after training, set the `published_keys` parameter. It selects which keys of the checkpoint are kept in the published files.

```python
  default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='greater', published_keys=['meta', 'state_dict']))
```

[LoggerHook](mmengine.hooks.LoggerHook) collects logs from different components of `Runner` and writes them to the terminal, JSON files, TensorBoard, wandb, etc.

If we want to output (or save) the logs every 20 iterations, we can set the `interval` parameter and configure it as follows.
Expand Down
9 changes: 9 additions & 0 deletions docs/zh_cn/tutorials/hook.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ runner.train()
- 保存最新的多个权重
- 保存最优权重
- 指定保存权重的路径
- 自动发布最好的和最后的权重

如需了解其他功能,请阅读 [CheckpointHook API 文档](mmengine.hooks.CheckpointHook)

Expand Down Expand Up @@ -121,6 +122,14 @@ runner.train()
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))
```

- 自动发布最好的和最后的权重

  如果你想在训练后发布最好的和最后的权重,你可以设置 `published_keys` 参数。通过这个参数,你可以选择要发布的权重中所要保留的键。

```python
  default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='greater', published_keys=['meta', 'state_dict']))
```

### LoggerHook

[LoggerHook](mmengine.hooks.LoggerHook) 负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端。
Expand Down
76 changes: 75 additions & 1 deletion mmengine/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ class CheckpointHook(Hook):
backend_args (dict, optional): Arguments to instantiate the
prefix of uri corresponding backend. Defaults to None.
New in v0.2.0.
published_keys (str, List[str], optional): If ``save_last`` is ``True``
or ``save_best`` is not ``None``, it will automatically
publish model with keys in the list after training.
Defaults to None.
Examples:
>>> # Save best based on single metric
>>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
Expand All @@ -95,6 +98,9 @@ class CheckpointHook(Hook):
>>> # Save best based on multi metrics with different comparison rule
>>> CheckpointHook(interval=2, by_epoch=True,
>>> save_best=['FID', 'IS'], rule=['less', 'greater'])
>>> # Save best based on single metric and publish model after training
>>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
>>> rule='less', published_keys=['meta', 'state_dict'])
"""
out_dir: str

Expand Down Expand Up @@ -128,6 +134,7 @@ def __init__(self,
file_client_args: Optional[dict] = None,
filename_tmpl: Optional[str] = None,
backend_args: Optional[dict] = None,
published_keys: Union[str, List[str], None] = None,
**kwargs) -> None:
self.interval = interval
self.by_epoch = by_epoch
Expand Down Expand Up @@ -218,6 +225,20 @@ def __init__(self,
else:
self.best_ckpt_path_dict: Dict = dict()

# published keys
if not (isinstance(published_keys, str)
or is_list_of(published_keys, str) or published_keys is None):
raise TypeError(
'"published_keys" should be a str or list of str or None, '
f'but got {type(published_keys)}')

if isinstance(published_keys, str):
published_keys = [published_keys]
elif isinstance(published_keys, list):
assert len(published_keys) == len(set(published_keys)), (
'Find duplicate elements in "published_keys".')
self.published_keys = published_keys

def before_train(self, runner) -> None:
"""Finish all operations, related to checkpoint.
Expand Down Expand Up @@ -304,6 +325,59 @@ def after_val_epoch(self, runner, metrics):

self._save_best_checkpoint(runner, metrics)

def after_train(self, runner) -> None:
    """Publish the last and/or best checkpoints after training.

    This is a no-op unless ``published_keys`` was given at construction
    time. When enabled, it publishes the last checkpoint (if
    ``save_last`` is set and one was recorded) and every best
    checkpoint tracked by this hook.

    Args:
        runner (Runner): The runner of the training process.
    """
    if self.published_keys is None:
        return

    # Publish the most recent checkpoint, if one was recorded in the
    # message hub during training.
    if self.save_last and 'last_ckpt' in runner.message_hub.runtime_info:
        last_ckpt = runner.message_hub.get_info('last_ckpt')
        self._publish_model(runner, last_ckpt)

    # Only one of ``best_ckpt_path`` / ``best_ckpt_path_dict`` is set in
    # ``__init__`` depending on whether ``save_best`` holds a single or
    # multiple metrics, hence the ``getattr`` guards.
    if getattr(self, 'best_ckpt_path', None) is not None:
        self._publish_model(runner, str(self.best_ckpt_path))
    if getattr(self, 'best_ckpt_path_dict', None) is not None:
        # Only the checkpoint paths are needed, not the metric names.
        for best_ckpt in self.best_ckpt_path_dict.values():
            self._publish_model(runner, best_ckpt)

def _publish_model(self, runner, ckpt_path: str) -> None:
    """Strip unwanted keys from a checkpoint and save the result.

    Loads the checkpoint at ``ckpt_path``, drops every top-level key
    not listed in ``self.published_keys``, saves the slimmed checkpoint
    next to the original with a ``-published-<timestamp>`` suffix, and
    records the new path in the runner's message hub under
    ``publish_ckpt_names``.

    Args:
        runner (Runner): The runner of the training process.
        ckpt_path (str): The checkpoint path that ought to be published.
    """
    from mmengine.runner import save_checkpoint
    from mmengine.runner.checkpoint import _load_checkpoint

    checkpoint = _load_checkpoint(ckpt_path)
    keep: Optional[Union[List[str], str]] = self.published_keys

    # Collect the keys to discard first, then drop them, so the log
    # message can report them all at once.
    removed_keys = [
        key for key in list(checkpoint.keys())
        if keep is not None and key not in keep
    ]
    for key in removed_keys:
        checkpoint.pop(key)
    if removed_keys:
        print_log(
            f'Key {removed_keys} will be removed because they are not '
            'found in published_keys. If you want to keep them, '
            f'please set `{removed_keys}` in published_keys',
            logger='current')

    final_path = osp.splitext(
        ckpt_path)[0] + f'-published-{runner.timestamp}.pth'
    save_checkpoint(checkpoint, final_path)
    print_log(
        f'The published model is saved at {final_path}.', logger='current')

    # Append to the running list of published paths in the message hub,
    # creating the list on first use.
    if 'publish_ckpt_names' in runner.message_hub.runtime_info:
        runner.message_hub.get_info('publish_ckpt_names').append(final_path)
    else:
        runner.message_hub.update_info('publish_ckpt_names', [final_path])

def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.
Expand Down
11 changes: 10 additions & 1 deletion tests/test_hooks/test_checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from mmengine.model import BaseModel
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner
from mmengine.runner.checkpoint import _load_checkpoint


class ToyModel(BaseModel):
Expand Down Expand Up @@ -465,7 +466,10 @@ def test_with_runner(self, tmp_path):
type='CheckpointHook',
interval=save_interval,
filename_tmpl=tmpl,
by_epoch=True)
by_epoch=True,
save_best='test/acc',
rule='less',
published_keys=['meta', 'state_dict'])
runner = Runner(
model=ToyModel(),
work_dir=work_dir,
Expand All @@ -492,3 +496,8 @@ def test_with_runner(self, tmp_path):
continue
path = osp.join(work_dir, tmpl.format(epoch))
assert osp.isfile(path=path)
for path in runner.message_hub.get_info('publish_ckpt_names'):
checkpoint = _load_checkpoint(path)
assert osp.isfile(path=path)
for key in list(checkpoint.keys()):
assert key in list(checkpoint_cfg['published_keys'])

0 comments on commit 533ac42

Please sign in to comment.