Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Save optimizer.state in cpu by default. #966

Merged
merged 5 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/api/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,4 @@ Miscellaneous
requires_executable
requires_package
check_time
apply_to
1 change: 1 addition & 0 deletions docs/zh_cn/api/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,4 @@ Miscellaneous
requires_executable
requires_package
check_time
apply_to
20 changes: 13 additions & 7 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,13 @@
HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS,
RUNNERS, VISUALIZERS, DefaultScope)
from mmengine.utils import digit_version, get_git_hash, is_seq_of
from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing)
from mmengine.visualization import Visualizer
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
find_latest_checkpoint, get_state_dict,
save_checkpoint, weights_to_cpu)
find_latest_checkpoint, save_checkpoint)
from .log_processor import LogProcessor
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
Expand Down Expand Up @@ -2139,14 +2138,21 @@ def save_checkpoint(
model = self.model

checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model)),
'message_hub': self.message_hub.state_dict()
'meta':
meta,
'state_dict':
apply_to(model.state_dict(), lambda x: hasattr(x, 'cpu'),
lambda x: x.cpu()),
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
'message_hub':
apply_to(self.message_hub.state_dict(),
lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()),
}
# save optimizer state dict to checkpoint
if save_optimizer:
if isinstance(self.optim_wrapper, OptimWrapper):
checkpoint['optimizer'] = self.optim_wrapper.state_dict()
checkpoint['optimizer'] = apply_to(
self.optim_wrapper.state_dict(),
lambda x: hasattr(x, 'cpu'), lambda x: x.cpu())
else:
raise TypeError(
'self.optim_wrapper should be an `OptimWrapper` '
Expand Down
7 changes: 4 additions & 3 deletions mmengine/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .manager import ManagerMeta, ManagerMixin
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
deprecated_function, has_method,
from .misc import (apply_to, check_prerequisites, concat_list,
deprecated_api_warning, deprecated_function, has_method,
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
Expand All @@ -27,5 +27,6 @@
'is_abs', 'is_method_overridden', 'has_method', 'digit_version',
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
'TimerError', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress', 'deprecated_function'
'track_parallel_progress', 'track_progress', 'deprecated_function',
'apply_to'
]
47 changes: 47 additions & 0 deletions mmengine/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,53 @@ def concat_list(in_list):
return list(itertools.chain(*in_list))


def apply_to(data: Any, expr: Callable, apply_func: Callable):
    """Apply a function to every element of a nested container that matches
    an expression.

    ``data`` is walked recursively through dicts, lists, tuples and
    namedtuples. Each leaf for which ``expr`` returns a truthy value is
    replaced by ``apply_func(leaf)``; every other leaf is returned as is.
    ``str``, ``bytes`` and ``None`` are always returned unchanged and are
    never passed to ``expr``. New containers are built, the input containers
    are not modified in place.

    Containers keep their type: a list stays a list, a tuple stays a tuple, a
    namedtuple stays the same namedtuple class and a dict subclass (for
    example ``OrderedDict``) is rebuilt with its own class instead of being
    downgraded to a plain ``dict``.

    For example, to convert each element in a list of dict from
    ``np.ndarray`` to ``Tensor``:

    Examples:
        >>> from mmengine.utils import apply_to
        >>> import numpy as np
        >>> import torch
        >>> data = dict(array=[np.array(1)])  # {'array': [array(1)]}
        >>> result = apply_to(data,
        ...                   lambda x: isinstance(x, np.ndarray),
        ...                   lambda x: torch.from_numpy(x))
        >>> print(result)  # {'array': [tensor(1)]}

    Args:
        data (Any): Data to be applied.
        expr (Callable): Expression to tell which data should be applied with
            the function. It should return a boolean.
        apply_func (Callable): Function applied to data.

    Returns:
        Any: The data after applying.

    Raises:
        ValueError: If ``expr`` matches an empty dict, str, bytes, tuple or
            list. Such an ``expr`` targets the containers themselves, which
            are traversed rather than transformed.
    """
    # Probe `expr` with empty instances of the types that are handled
    # structurally below, so a mis-specified `expr` fails loudly up front.
    if any(map(expr, (dict(), '', b'', tuple(), list()))):
        raise ValueError(
            '``expr`` should not match with Mapping, str, bytes, tuple or '
            'list. This kind of `expr` could lead to an unexpected results')

    def _apply(data, expr, apply_func):
        if isinstance(data, dict):
            # Rebuild the mapping with its own class so that subclasses such
            # as OrderedDict are not silently converted to a plain dict.
            try:
                result = type(data)()
            except TypeError:
                # The subclass cannot be built without arguments; fall back
                # to a plain dict rather than failing.
                result = {}
            for key, value in data.items():
                result[key] = _apply(value, expr, apply_func)
            return result
        elif isinstance(data, (str, bytes)) or data is None:
            return data
        elif isinstance(data, tuple) and hasattr(data, '_fields'):
            # namedtuple: its constructor takes positional fields, not a
            # single iterable.
            return type(data)(
                *(_apply(sample, expr, apply_func)
                  for sample in data))  # type: ignore
        elif isinstance(data, (tuple, list)):
            return type(data)(
                _apply(sample, expr, apply_func)
                for sample in data)  # type: ignore
        elif expr(data):
            return apply_func(data)
        else:
            return data

    return _apply(data, expr, apply_func)


def check_prerequisites(
prerequisites,
checker,
Expand Down
17 changes: 16 additions & 1 deletion tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from mmengine import MMLogger
# yapf: disable
from mmengine.utils.misc import (concat_list, deprecated_api_warning,
from mmengine.utils.misc import (apply_to, concat_list, deprecated_api_warning,
deprecated_function, has_method,
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_tuple_of,
Expand Down Expand Up @@ -283,3 +283,18 @@ def deprecated_demo1():

Short summary.""" # noqa: E122
assert expected_docstring.strip(' ') == deprecated_demo1.__doc__


def test_apply_to():
    """Check that ``apply_to`` transforms only matching leaves, recurses into
    nested containers and rejects an ``expr`` that matches a container."""
    # Test only apply `+1` to int object.
    data = dict(a=1, b=2.0)
    result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
    assert result == dict(a=2, b=2.0)

    # Test with nested data
    data = dict(a=[dict(c=1)], b=2.0)
    result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
    # Compare against `result`; asserting the bare dict literal is always
    # true and would never catch a regression.
    assert result == dict(a=[dict(c=2)], b=2.0)

    # An `expr` that matches list must be rejected up front.
    with pytest.raises(ValueError):
        apply_to(data, lambda x: isinstance(x, list), lambda x: x)