[Feature] Resume from the latest checkpoint automatically. (#61)
* support auto-resume

Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
HIT-cwh and pppppM committed Mar 8, 2022
1 parent 366fd0f commit 81e0e34
Showing 9 changed files with 119 additions and 1 deletion.
7 changes: 7 additions & 0 deletions mmrazor/apis/mmcls/train.py
@@ -16,6 +16,7 @@
 from mmrazor.core.hooks import DistSamplerSeedHook
 from mmrazor.core.optimizer import build_optimizers
 from mmrazor.datasets.utils import split_dataset
+from mmrazor.utils import find_latest_checkpoint
 
 
 def set_random_seed(seed, deterministic=False):
@@ -190,6 +191,12 @@ def train_model(model,
         runner.register_hook(
             eval_hook(val_dataloader, **eval_cfg), priority='LOW')
 
+    resume_from = None
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
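Note: the identical six-line block is added to the mmdet and mmseg entry points below, so behaviour is uniform across the three training APIs. The precedence it implements: an explicit resume_from always wins; auto_resume only fills in when no checkpoint was named; load_from remains the weight-only fallback. A minimal sketch of that decision as a standalone helper (the name resolve_resume_from is ours, not part of the commit; cfg is assumed to be an mmcv Config with dict-style get access):

def resolve_resume_from(cfg):
    """Mirror the hunk above: prefer an explicit cfg.resume_from,
    otherwise fall back to the newest checkpoint in cfg.work_dir
    when auto_resume is enabled. Returns a path or None."""
    from mmrazor.utils import find_latest_checkpoint
    if cfg.resume_from is None and cfg.get('auto_resume'):
        latest = find_latest_checkpoint(cfg.work_dir)
        if latest is not None:
            cfg.resume_from = latest
    return cfg.resume_from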
7 changes: 7 additions & 0 deletions mmrazor/apis/mmdet/train.py
@@ -15,6 +15,7 @@
 from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.hooks import DistSamplerSeedHook
 from mmrazor.core.optimizer import build_optimizers
+from mmrazor.utils import find_latest_checkpoint
 
 
 def set_random_seed(seed, deterministic=False):
@@ -181,6 +182,12 @@ def train_detector(model,
         runner.register_hook(
             eval_hook(val_dataloader, **eval_cfg), priority='LOW')
 
+    resume_from = None
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
7 changes: 7 additions & 0 deletions mmrazor/apis/mmseg/train.py
@@ -12,6 +12,7 @@
 
 from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.optimizer import build_optimizers
+from mmrazor.utils import find_latest_checkpoint
 
 
 def set_random_seed(seed, deterministic=False):
@@ -137,6 +138,12 @@ def train_segmentor(model,
         runner.register_hook(
             eval_hook(val_dataloader, **eval_cfg), priority='LOW')
 
+    resume_from = None
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
3 changes: 2 additions & 1 deletion mmrazor/utils/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .misc import find_latest_checkpoint
 from .setup_env import setup_multi_processes
 
-__all__ = ['setup_multi_processes']
+__all__ = ['find_latest_checkpoint', 'setup_multi_processes']
38 changes: 38 additions & 0 deletions mmrazor/utils/misc.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os.path as osp
+import warnings
+
+
+def find_latest_checkpoint(path, suffix='pth'):
+    """Find the latest checkpoint from the working directory.
+
+    Args:
+        path(str): The path to find checkpoints.
+        suffix(str): File extension. Defaults to pth.
+
+    Returns:
+        latest_path(str | None): File path of the latest checkpoint.
+
+    References:
+        .. [1] https://github.com/microsoft/SoftTeacher
+           /blob/main/ssod/utils/patch.py
+    """
+    if not osp.exists(path):
+        warnings.warn('The path of checkpoints does not exist.')
+        return None
+    if osp.exists(osp.join(path, f'latest.{suffix}')):
+        return osp.join(path, f'latest.{suffix}')
+
+    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
+    if len(checkpoints) == 0:
+        warnings.warn('There are no checkpoints in the path.')
+        return None
+    latest = -1
+    latest_path = None
+    for checkpoint in checkpoints:
+        count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
+        if count > latest:
+            latest = count
+            latest_path = checkpoint
+    return latest_path
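The scan relies on a naming convention: apart from the latest.{suffix} shortcut, every matching file must end in _<number>.<suffix> (the epoch_*.pth / iter_*.pth pattern written by mmcv's checkpoint hook), since int(...split('_')[-1].split('.')[0]) parses the trailing counter. A stray file such as backbone.pth in the work directory would raise a ValueError. A small illustrative run, with hypothetical file names:

import os
import tempfile

from mmrazor.utils import find_latest_checkpoint

with tempfile.TemporaryDirectory() as work_dir:
    # Two epoch checkpoints; the larger trailing counter wins,
    # even though 'epoch_12' sorts before 'epoch_1' numerically-naively.
    for name in ('epoch_1.pth', 'epoch_12.pth'):
        open(os.path.join(work_dir, name), 'w').close()
    print(find_latest_checkpoint(work_dir))  # -> .../epoch_12.pth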
43 changes: 43 additions & 0 deletions tests/test_utils/test_misc.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+
+from mmrazor.utils import find_latest_checkpoint
+
+
+def test_find_latest_checkpoint():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        # There are no checkpoints in the path.
+        assert latest is None
+
+        path = tmpdir + '/none'
+        latest = find_latest_checkpoint(path)
+        # The path does not exist.
+        assert latest is None
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/latest.pth', 'w') as f:
+            f.write('latest')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tmpdir, 'latest.pth')
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/iter_4000.pth', 'w') as f:
+            f.write('iter_4000')
+        with open(tmpdir + '/iter_8000.pth', 'w') as f:
+            f.write('iter_8000')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tmpdir, 'iter_8000.pth')
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/epoch_1.pth', 'w') as f:
+            f.write('epoch_1')
+        with open(tmpdir + '/epoch_2.pth', 'w') as f:
+            f.write('epoch_2')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tmpdir, 'epoch_2.pth')
5 changes: 5 additions & 0 deletions tools/mmcls/train_mmcls.py
@@ -26,6 +26,10 @@ def parse_args():
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically')
     parser.add_argument(
         '--no-validate',
         action='store_true',
@@ -101,6 +105,7 @@ def main():
                                 osp.splitext(osp.basename(args.config))[0])
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
+    cfg.auto_resume = args.auto_resume
     if args.gpus is not None:
         cfg.gpu_ids = range(1)
         warnings.warn('`--gpus` is deprecated because we only support '
5 changes: 5 additions & 0 deletions tools/mmdet/train_mmdet.py
@@ -34,6 +34,10 @@ def parse_args():
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically')
     parser.add_argument(
         '--no-validate',
         action='store_true',
@@ -112,6 +116,7 @@ def main():
                                 osp.splitext(osp.basename(args.config))[0])
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
+    cfg.auto_resume = args.auto_resume
     if args.gpus is not None:
         cfg.gpu_ids = range(1)
         warnings.warn('`--gpus` is deprecated because we only support '
5 changes: 5 additions & 0 deletions tools/mmseg/train_mmseg.py
@@ -37,6 +37,10 @@ def parse_args():
         '--load-from', help='the checkpoint file to load weights from')
     parser.add_argument(
         '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically')
     parser.add_argument(
         '--no-validate',
         action='store_true',
@@ -114,6 +118,7 @@ def main():
         cfg.load_from = args.load_from
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
+    cfg.auto_resume = args.auto_resume
     if args.gpus is not None:
         cfg.gpu_ids = range(1)
         warnings.warn('`--gpus` is deprecated because we only support '
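With the flag wired through all three launchers, an interrupted job can be restarted with the same command plus the new option, e.g. python tools/mmcls/train_mmcls.py ${CONFIG} --work-dir ${WORK_DIR} --auto-resume (placeholders ours). One design detail worth noting: unlike resume_from, cfg.auto_resume = args.auto_resume is assigned unconditionally, so an auto_resume value set in the config file is always overwritten by the CLI, and omitting the flag means False. A tiny sketch of the store_true semantics the tools rely on:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--auto-resume',
    action='store_true',
    help='resume from the latest checkpoint automatically')

assert parser.parse_args([]).auto_resume is False  # flag omitted
assert parser.parse_args(['--auto-resume']).auto_resume is True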
