/
standard_updater.py
271 lines (219 loc) · 10.3 KB
/
standard_updater.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
import warnings
import six
import chainer
from chainer.backends import cuda
from chainer.dataset import convert
from chainer.dataset import iterator as iterator_module
from chainer import device_resident
from chainer.training import _updater
from chainer.utils import argument
class StandardUpdater(_updater.Updater):

    """StandardUpdater(\
iterator, optimizer, converter=convert.concat_examples, device=None, \
loss_func=None, loss_scale=None, auto_new_epoch=True, *, input_device=None)

    Standard implementation of Updater.

    This is the standard implementation of :class:`~chainer.training.Updater`.
    It accepts one or more training datasets and one or more optimizers.
    The default update routine assumes that there is only one training dataset
    and one optimizer. Users can override this update routine by inheriting
    this class and overriding the :meth:`update_core` method. Each batch is
    converted to input arrays by :func:`chainer.dataset.concat_examples` by
    default, which can also be manually set by ``converter`` argument.

    Args:
        iterator: Dataset iterator for the training dataset. It can also be a
            dictionary that maps strings to iterators.
            If this is just an iterator, then the
            iterator is registered by the name ``'main'``.
        optimizer: Optimizer to update parameters. It can also be a dictionary
            that maps strings to optimizers.
            If this is just an optimizer, then the optimizer is
            registered by the name ``'main'``.
        converter: Converter function to build input arrays. Each batch
            extracted by the main iterator and the ``input_device`` are
            passed to this function.
            :func:`chainer.dataset.concat_examples` is used by default.
        device(device specifier): Device to which the model is sent.
            If ``None``, the device of the model will stay unchanged.
        loss_func: Loss function. The target link of the main optimizer is used
            by default.
        loss_scale (float): Loss scaling factor. Loss scaling is a useful
            technique to mitigate vanishing gradient issue that tends to happen
            when low precision data type like float16 is used during training.
            If you set loss scaling factor, gradients of loss values are to be
            multiplied by the factor before backprop starts. The factor is
            propagated to whole gradients in a computational graph along the
            backprop. The gradients of parameters are divided by the factor
            just before the parameters are to be updated.
        auto_new_epoch (bool): If ``True``,
            :meth:`~chainer.Optimizer.new_epoch` of the main optimizer is
            automatically called when the ``is_new_epoch`` attribute of the
            main iterator is ``True``.
        input_device (device specifier):
            Device to which the training data is sent.
            If ``input_device`` is omitted, it will match the ``device``
            argument.

    Attributes:
        converter: Converter function.
        loss_func: Loss function. If it is ``None``, the target link of the
                   main optimizer is used instead.
        device: Device to which the model is sent.
        input_device: Device to which the training data is sent.
        iteration: Current number of completed updates.
        auto_new_epoch: If ``True``, :meth:`~chainer.Optimizer.new_epoch` is
            automatically called by :meth:`update_core`. In this case, the
            :attr:`~chainer.Optimizer.use_auto_new_epoch` attribute of each
            optimizer is also set to ``True``. If :meth:`update_core` is
            overridden, the implementation should correctly call
            :meth:`~chainer.Optimizer.new_epoch` of each optimizer.

    """

    def __init__(self, iterator, optimizer, converter=convert.concat_examples,
                 device=None, loss_func=None, loss_scale=None,
                 auto_new_epoch=True, **kwargs):
        # ``input_device`` is keyword-only; it is extracted from ``kwargs``.
        input_device, = argument.parse_kwargs(
            kwargs, ('input_device', None))

        if device is not None:
            device = chainer.get_device(device)

        # input_device falls back to device
        if input_device is None:
            input_device = device
        else:
            input_device = chainer.get_device(input_device)

        # A bare iterator / optimizer is registered under the name 'main'.
        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main': iterator}
        self._iterators = iterator

        if not isinstance(optimizer, dict):
            optimizer = {'main': optimizer}
        self._optimizers = optimizer

        # Transfer the model
        if device is not None:
            # NOTE: the loop variable is deliberately not called ``optimizer``
            # so that the constructor argument is not rebound.
            for opt in six.itervalues(self._optimizers):
                if device.xp is cuda.cupy:
                    # Do not transfer between different cupy devices.
                    # Detect GPU-to-GPU transfer and raise FutureWarning.
                    # TODO(niboshi): Eventually replace it with to_device.
                    thread_local = device_resident._thread_local
                    has_gpu_to_gpu = False
                    try:
                        # Turn on GPU-to-GPU detection
                        thread_local.flag_gpu_to_gpu = False
                        with warnings.catch_warnings():
                            warnings.filterwarnings(
                                'ignore',
                                message='to_gpu is deprecated.',
                                category=DeprecationWarning)
                            opt.target.to_gpu(device.device.id)
                        has_gpu_to_gpu = thread_local.flag_gpu_to_gpu
                    finally:
                        # Turn off GPU-to-GPU detection
                        thread_local.flag_gpu_to_gpu = None
                    if has_gpu_to_gpu:
                        warnings.warn(
                            '''\
Transfer between @cupy devices was detected and skipped. \
StandardUpdater normally transfers the model to the specified device, but \
except for between @cupy devices. \
That is, if a part of the model is on @cupy:n device and the specified \
device is @cupy:m device, that part of the model will be left in @cupy:n \
device. This behavior is planned to be changed in near future. \
After that, the model will be transferred to the specified device regardless \
of device combination. \
If you want to keep the model device but only want to transfer the input data \
to a given device, specify the 'input_device' argument instead and leave the \
'device' argument unspecified.
''',
                            FutureWarning)
                else:
                    opt.target.to_device(device)

        self.converter = converter
        self.loss_func = loss_func
        self.iteration = 0
        self._device = device
        self._input_device = input_device

        self.loss_scale = loss_scale
        if loss_scale is not None:
            for opt in six.itervalues(self._optimizers):
                opt.set_loss_scale(loss_scale)

        self.auto_new_epoch = auto_new_epoch
        if auto_new_epoch:
            for opt in six.itervalues(self._optimizers):
                opt.use_auto_new_epoch = True

    @property
    def device(self):
        """Device given as ``device`` (``None`` if it was not specified)."""
        return self._device

    @property
    def input_device(self):
        """Device passed to the converter along with each batch."""
        return self._input_device

    @property
    def epoch(self):
        """``epoch`` attribute of the ``'main'`` iterator."""
        return self._iterators['main'].epoch

    @property
    def epoch_detail(self):
        """``epoch_detail`` attribute of the ``'main'`` iterator."""
        return self._iterators['main'].epoch_detail

    @property
    def previous_epoch_detail(self):
        """``previous_epoch_detail`` attribute of the ``'main'`` iterator."""
        return self._iterators['main'].previous_epoch_detail

    @property
    def is_new_epoch(self):
        """``is_new_epoch`` attribute of the ``'main'`` iterator."""
        return self._iterators['main'].is_new_epoch

    def finalize(self):
        """Finalizes the updater object.

        This method calls the `finalize` method of each iterator that
        this updater has.
        It is called at the end of training loops.

        """
        for iterator in six.itervalues(self._iterators):
            iterator.finalize()

    def get_optimizer(self, name):
        """Gets the optimizer of given name.

        Args:
            name (str): Name of the optimizer.

        Returns:
            ~chainer.Optimizer: Corresponding optimizer.

        """
        return self._optimizers[name]

    def get_all_optimizers(self):
        """Gets a dictionary of all optimizers for this updater.

        A new dictionary is returned on every call, so adding or removing
        keys in the result does not affect the updater.

        Returns:
            dict: Dictionary that maps names to optimizers.

        """
        return dict(self._optimizers)

    def get_iterator(self, name):
        """Gets the dataset iterator of given name.

        Args:
            name (str): Name of the dataset iterator.

        Returns:
            ~chainer.dataset.Iterator: Corresponding dataset iterator.

        """
        return self._iterators[name]

    def update(self):
        """Updates the parameters of the target model.

        This method implements an update formula for the training task,
        including data loading, forward/backward computations, and actual
        updates of parameters.

        This method is called once at each iteration of the training loop.

        """
        self.update_core()
        self.iteration += 1

    def update_core(self):
        """Runs one update step with the ``'main'`` iterator and optimizer.

        The next batch of the main iterator is converted by
        :attr:`converter` together with :attr:`input_device`. The converted
        result is forwarded to the ``update`` method of the main optimizer:
        a tuple is unpacked as positional arguments, a dictionary as keyword
        arguments, and anything else is passed as a single argument.

        If :attr:`auto_new_epoch` is ``True`` and the main iterator reports
        a new epoch, ``new_epoch(auto=True)`` of the main optimizer is called
        afterwards.

        """
        iterator = self._iterators['main']
        batch = iterator.next()
        in_arrays = convert._call_converter(
            self.converter, batch, self.input_device)

        optimizer = self._optimizers['main']
        # Fall back to the optimizer target only when no loss function was
        # given. An explicit ``is None`` test is used (rather than ``or``) so
        # that a user-supplied callable that happens to evaluate as falsy,
        # e.g. one defining ``__len__`` that returns 0, is still honored.
        loss_func = self.loss_func
        if loss_func is None:
            loss_func = optimizer.target

        if isinstance(in_arrays, tuple):
            optimizer.update(loss_func, *in_arrays)
        elif isinstance(in_arrays, dict):
            optimizer.update(loss_func, **in_arrays)
        else:
            optimizer.update(loss_func, in_arrays)

        if self.auto_new_epoch and iterator.is_new_epoch:
            optimizer.new_epoch(auto=True)

    def serialize(self, serializer):
        """Serializes the current state of the updater object.

        Each iterator is serialized under the key ``'iterator:<name>'``,
        each optimizer under ``'optimizer:<name>'`` and the target of each
        optimizer under ``'model:<name>'``. The iteration count is stored
        under ``'iteration'``.

        """
        for name, iterator in six.iteritems(self._iterators):
            iterator.serialize(serializer['iterator:' + name])

        for name, optimizer in six.iteritems(self._optimizers):
            optimizer.serialize(serializer['optimizer:' + name])
            optimizer.target.serialize(serializer['model:' + name])

        self.iteration = serializer('iteration', self.iteration)