Skip to content

Commit

Permalink
[Fix] Save optimizer.state_dict() in cpu by default (#966)
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE committed Apr 26, 2023
1 parent 9868131 commit 6ba667c
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/en/api/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,4 @@ Miscellaneous
requires_executable
requires_package
check_time
apply_to
1 change: 1 addition & 0 deletions docs/zh_cn/api/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,4 @@ Miscellaneous
requires_executable
requires_package
check_time
apply_to
12 changes: 6 additions & 6 deletions mmengine/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from mmengine.fileio import load as load_file
from mmengine.logging import print_log
from mmengine.model import BaseTTAModel, is_model_wrapper
from mmengine.utils import deprecated_function, digit_version, mkdir_or_exist
from mmengine.utils import (apply_to, deprecated_function, digit_version,
mkdir_or_exist)
from mmengine.utils.dl_utils import load_url

# `MMENGINE_HOME` is the highest priority directory to save checkpoints
Expand Down Expand Up @@ -622,12 +623,11 @@ def weights_to_cpu(state_dict):
Returns:
OrderedDict: Model weights on GPU.
"""
state_dict_cpu = OrderedDict()
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'),
lambda x: x.cpu())
# Keep metadata in state_dict
state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
return state_dict_cpu
state_dict._metadata = getattr(state_dict, '_metadata', OrderedDict())
return state_dict


@deprecated_function(
Expand Down
20 changes: 13 additions & 7 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS,
RUNNERS, VISUALIZERS, DefaultScope)
from mmengine.utils import digit_version, get_git_hash, is_seq_of
from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing)
from mmengine.visualization import Visualizer
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
find_latest_checkpoint, get_state_dict,
save_checkpoint, weights_to_cpu)
find_latest_checkpoint, save_checkpoint,
weights_to_cpu)
from .log_processor import LogProcessor
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
Expand Down Expand Up @@ -2164,14 +2164,20 @@ def save_checkpoint(
model = self.model

checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model)),
'message_hub': self.message_hub.state_dict()
'meta':
meta,
'state_dict':
weights_to_cpu(model.state_dict()),
'message_hub':
apply_to(self.message_hub.state_dict(),
lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()),
}
# save optimizer state dict to checkpoint
if save_optimizer:
if isinstance(self.optim_wrapper, OptimWrapper):
checkpoint['optimizer'] = self.optim_wrapper.state_dict()
checkpoint['optimizer'] = apply_to(
self.optim_wrapper.state_dict(),
lambda x: hasattr(x, 'cpu'), lambda x: x.cpu())
else:
raise TypeError(
'self.optim_wrapper should be an `OptimWrapper` '
Expand Down
7 changes: 4 additions & 3 deletions mmengine/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .manager import ManagerMeta, ManagerMixin
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
deprecated_function, has_method,
from .misc import (apply_to, check_prerequisites, concat_list,
deprecated_api_warning, deprecated_function, has_method,
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
Expand All @@ -27,5 +27,6 @@
'is_abs', 'is_method_overridden', 'has_method', 'digit_version',
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
'TimerError', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress', 'deprecated_function'
'track_parallel_progress', 'track_progress', 'deprecated_function',
'apply_to'
]
41 changes: 41 additions & 0 deletions mmengine/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,47 @@ def concat_list(in_list):
return list(itertools.chain(*in_list))


def apply_to(data: Any, expr: Callable, apply_func: Callable):
"""Apply function to each element in dict, list or tuple that matches with
the expression.
For examples, if you want to convert each element in a list of dict from
`np.ndarray` to `Tensor`. You can use the following code:
Examples:
>>> from mmengine.utils import apply_to
>>> import numpy as np
>>> import torch
>>> data = dict(array=[np.array(1)]) # {'array': [array(1)]}
>>> result = apply_to(data, lambda x: isinstance(x, np.ndarray), lambda x: torch.from_numpy(x))
>>> print(result) # {'array': [tensor(1)]}
Args:
data (Any): Data to be applied.
expr (Callable): Expression to tell which data should be applied with
the function. It should return a boolean.
apply_func (Callable): Function applied to data.
Returns:
Any: The data after applying.
""" # noqa: E501
if isinstance(data, dict):
# Keep the original dict type
res = type(data)()
for key, value in data.items():
res[key] = apply_to(value, expr, apply_func)
return res
elif isinstance(data, tuple) and hasattr(data, '_fields'):
# namedtuple
return type(data)(*(apply_to(sample, expr, apply_func) for sample in data)) # type: ignore # noqa: E501 # yapf:disable
elif isinstance(data, (tuple, list)):
return type(data)(apply_to(sample, expr, apply_func) for sample in data) # type: ignore # noqa: E501 # yapf:disable
elif expr(data):
return apply_func(data)
else:
return data


def check_prerequisites(
prerequisites,
checker,
Expand Down
46 changes: 45 additions & 1 deletion tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import namedtuple

import numpy as np
import pytest
import torch

from mmengine import MMLogger
# yapf: disable
from mmengine.utils.misc import (concat_list, deprecated_api_warning,
from mmengine.utils.misc import (apply_to, concat_list, deprecated_api_warning,
deprecated_function, has_method,
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_tuple_of,
Expand Down Expand Up @@ -283,3 +287,43 @@ def deprecated_demo1():
Short summary.""" # noqa: E122
assert expected_docstring.strip(' ') == deprecated_demo1.__doc__


def test_apply_to():
# Test only apply `+1` to int object.
data = dict(a=1, b=2.0)
result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
assert result == dict(a=2, b=2.0)

# Test with nested data
data = dict(a=[dict(c=1)], b=2.0)
result = apply_to(data, lambda x: isinstance(x, int), lambda x: x + 1)
assert result == dict(a=[dict(c=2)], b=2.0)

# Tensor to numpy
data = dict(a=[dict(c=torch.tensor(1))], b=torch.tensor(2))
result = apply_to(data, lambda x: isinstance(x, torch.Tensor),
lambda x: x.numpy())
assert isinstance(result['b'], np.ndarray)
assert isinstance(result['a'][0]['c'], np.ndarray)

# Tuple and convert string
data = (1, dict(a=[dict(b=2.0)]), 'test')
result = apply_to(
data, lambda x: isinstance(x, int) or x == 'test',
lambda x: torch.Tensor(x) if isinstance(x, int) else 'train')
assert isinstance(result, tuple)
assert isinstance(result[0], torch.Tensor)
assert isinstance(result[1]['a'][0]['b'], float)
assert result[2] == 'train'

# Named Tuple
dataclass = namedtuple('Data', ['a', 'b'])
data = dataclass('test', dict(a=[dict(c=1)], b=2.0))
result = apply_to(
data, lambda x: isinstance(x, int) or x == 'test',
lambda x: torch.Tensor(x) if isinstance(x, int) else 'train')
assert isinstance(result, dataclass)
assert result[0] == 'train'
assert isinstance(result.b['a'][0]['c'], torch.Tensor)
assert isinstance(result.b['b'], float)

0 comments on commit 6ba667c

Please sign in to comment.