# snapshot_writers.py
import multiprocessing
import os
import shutil
from six.moves import queue
import threading
from chainer.serializers import npz
from chainer import utils
class Writer(object):

    """Base class of snapshot writers.

    :class:`~chainer.training.extensions.Snapshot` invokes ``__call__`` of this
    class every time it takes a snapshot.
    This class determines how the actual saving function will be invoked.

    .. seealso::

        - :meth:`chainer.training.extensions.snapshot`
    """

    def __init__(self):
        # Callables run, in registration order, at the end of every ``save``.
        self._post_save_hooks = []

    def __call__(self, filename, outdir, target):
        """Invokes the actual snapshot function.

        This method is invoked by a
        :class:`~chainer.training.extensions.Snapshot` object every time it
        takes a snapshot.

        Args:
            filename (str): Name of the file into which the serialized target
                is saved. It is a concrete file name, i.e. not a pre-formatted
                template string.
            outdir (str): Output directory. Corresponds to
                :py:attr:`Trainer.out <chainer.training.Trainer.out>`.
            target (dict): Serialized object which will be saved.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def __del__(self):
        # Give subclasses a chance to release workers when the writer is
        # garbage-collected without an explicit ``finalize`` call.
        self.finalize()

    def finalize(self):
        """Finalizes the writer.

        Like extensions in :class:`~chainer.training.Trainer`, this method
        is invoked at the end of the training. The base implementation does
        nothing.
        """
        pass

    def save(self, filename, outdir, target, savefun, **kwds):
        """Saves ``target`` as ``outdir/filename`` and runs post-save hooks.

        The file is first written by ``savefun`` inside a temporary directory
        obtained from ``utils.tempdir`` (created under ``outdir``) and is
        moved to its final name only after ``savefun`` has returned.

        Args:
            filename (str): Name of the snapshot file.
            outdir (str): Directory that will contain the snapshot file.
            target: Object handed to ``savefun`` as the thing to be saved.
            savefun: Callable invoked as ``savefun(path, target, **kwds)``.
            kwds: Keyword arguments forwarded to ``savefun``.
        """
        prefix = 'tmp' + filename
        with utils.tempdir(prefix=prefix, dir=outdir) as tmpdir:
            tmppath = os.path.join(tmpdir, filename)
            # Forward the keyword arguments; previously they were accepted
            # by this method but never reached ``savefun``.
            savefun(tmppath, target, **kwds)
            shutil.move(tmppath, os.path.join(outdir, filename))

        self._post_save()

    def _add_cleanup_hook(self, hook_fun):
        """Adds cleanup hook function.

        Technically, arbitrary user-defined hook can be called, but
        this is intended for cleaning up stale snapshots.

        Args:
            hook_fun (callable): callable function to be called
                right after save is done. It takes no arguments.
        """
        self._post_save_hooks.append(hook_fun)

    def _post_save(self):
        """Calls every registered hook in registration order."""
        for hook in self._post_save_hooks:
            hook()
class SimpleWriter(Writer):

    """The simplest snapshot writer.

    Calling an instance runs the saving function directly, in the caller's
    own thread, by delegating to :meth:`Writer.save`.

    Args:
        savefun: Callable object. It takes three arguments: the output file
            path, the serialized dictionary object, and the optional keyword
            arguments.
        kwds: Keyword arguments for the ``savefun``.

    .. seealso::

        - :meth:`chainer.training.extensions.snapshot`
    """

    def __init__(self, savefun=npz.save_npz, **kwds):
        super(SimpleWriter, self).__init__()
        # Remember how to save; both are handed to ``save`` on every call.
        self._kwds = kwds
        self._savefun = savefun

    def __call__(self, filename, outdir, target):
        """Saves ``target`` as ``outdir/filename`` before returning."""
        save_function = self._savefun
        save_options = self._kwds
        self.save(filename, outdir, target, save_function, **save_options)
class StandardWriter(Writer):

    """Base class of snapshot writers which use thread or process.

    This class creates a new thread or a process every time when ``__call__``
    is invoked. At most one worker is outstanding at a time: a new call first
    waits for the previous worker to finish.

    Args:
        savefun: Callable object. It takes three arguments: the output file
            path, the serialized dictionary object, and the optional keyword
            arguments.
        kwds: Keyword arguments for the ``savefun``.

    .. seealso::

        - :meth:`chainer.training.extensions.snapshot`
    """

    # Class-level defaults so that ``finalize`` (also reached through
    # ``Writer.__del__``) can run on an instance whose ``__init__`` did not
    # complete.
    _started = False
    _finalized = False
    _worker = None

    def __init__(self, savefun=npz.save_npz, **kwds):
        super(StandardWriter, self).__init__()
        self._savefun = savefun
        self._kwds = kwds
        self._started = False
        self._finalized = False

    def __call__(self, filename, outdir, target):
        """Starts a worker that saves ``target`` as ``outdir/filename``.

        Blocks until the worker of the previous call (if any) has finished,
        then returns as soon as the new worker has been started.
        """
        if self._started:
            self._worker.join()
            self._started = False
        self._filename = filename
        self._worker = self.create_worker(filename, outdir, target,
                                          **self._kwds)
        self._worker.start()
        self._started = True

    def create_worker(self, filename, outdir, target, **kwds):
        """Creates a worker for the snapshot.

        This method creates a thread or a process to take a snapshot. The
        created worker must have :meth:`start` and :meth:`join` methods.

        Args:
            filename (str): Name of the file into which the serialized target
                is saved. It is an already formatted string.
            outdir (str): Output directory. Passed by `trainer.out`.
            target (dict): Serialized object which will be saved.
            kwds: Keyword arguments for the ``savefun``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def finalize(self):
        """Waits for the outstanding worker, if any.

        The worker is joined whenever one has been started and not yet
        joined, including a worker started after an earlier ``finalize``
        call. Previously such a worker was skipped because ``_finalized``
        was already set.
        """
        if self._started:
            self._worker.join()
            self._started = False
        self._finalized = True
class ThreadWriter(StandardWriter):

    """Snapshot writer that uses a separate thread.

    Each snapshot is written by a freshly created :class:`threading.Thread`
    whose target is :meth:`Writer.save`.

    .. seealso::

        - :meth:`chainer.training.extensions.snapshot`
    """

    def __init__(self, savefun=npz.save_npz, **kwds):
        super(ThreadWriter, self).__init__(savefun=savefun, **kwds)

    def create_worker(self, filename, outdir, target, **kwds):
        """Returns an unstarted thread that will run ``save``.

        The keyword arguments given to the thread are the ones stored at
        construction time (``self._kwds``); the ``kwds`` parameter is not
        read here.
        """
        save_args = (filename, outdir, target, self._savefun)
        save_kwargs = self._kwds
        worker = threading.Thread(
            target=self.save, args=save_args, kwargs=save_kwargs)
        return worker
class ProcessWriter(StandardWriter):

    """Snapshot writer that uses a separate process.

    Each snapshot is written by a freshly created
    :class:`multiprocessing.Process` whose target is :meth:`Writer.save`.

    .. note::
        Forking a new process from an MPI process might be dangerous.
        Consider using :class:`ThreadWriter` instead of ``ProcessWriter`` if
        you are using MPI.

    .. seealso::

        - :meth:`chainer.training.extensions.snapshot`
    """

    def __init__(self, savefun=npz.save_npz, **kwds):
        super(ProcessWriter, self).__init__(savefun=savefun, **kwds)

    def create_worker(self, filename, outdir, target, **kwds):
        """Returns an unstarted process that will run ``save``.

        The keyword arguments given to the process are the ones stored at
        construction time (``self._kwds``); the ``kwds`` parameter is not
        read here.
        """
        save_args = (filename, outdir, target, self._savefun)
        save_kwargs = self._kwds
        worker = multiprocessing.Process(
            target=self.save, args=save_args, kwargs=save_kwargs)
        return worker
class QueueWriter(Writer):

    """Base class of queue snapshot writers.

    This class is a base class of snapshot writers that use a queue.
    A Queue is created when this class is constructed, and every time when
    ``__call__`` is invoked, a snapshot task is put into the queue.

    Args:
        savefun: Callable object which is passed to the :meth:`create_task`
            if the task is ``None``. It takes three arguments: the output file
            path, the serialized dictionary object, and the optional keyword
            arguments.
        task: Callable object. Its ``__call__`` must have a same interface to
            ``Writer.__call__``. This object is directly put into the queue.

    .. seealso::

        - :meth:`chainer.training.extensions.snapshot`
    """

    # Class-level defaults so that ``finalize`` (also reached through
    # ``Writer.__del__``) can run on an instance whose ``__init__`` did not
    # complete, e.g. when ``create_queue`` raised.
    _started = False
    _finalized = False
    _queue = None
    _consumer = None

    def __init__(self, savefun=npz.save_npz, task=None):
        super(QueueWriter, self).__init__()
        if task is None:
            self._task = self.create_task(savefun)
        else:
            self._task = task
        self._queue = self.create_queue()
        self._consumer = self.create_consumer(self._queue)
        self._consumer.start()
        self._started = True
        self._finalized = False

    def __call__(self, filename, outdir, target):
        """Enqueues one snapshot request and returns immediately.

        Note that once ``finalize`` has stopped the consumer, items put here
        are no longer consumed.
        """
        self._queue.put([self._task, filename, outdir, target])

    def create_task(self, savefun):
        """Returns the default task: a :class:`SimpleWriter` on ``savefun``."""
        return SimpleWriter(savefun=savefun)

    def create_queue(self):
        """Returns the queue object; subclasses must override this method."""
        raise NotImplementedError

    def create_consumer(self, q):
        """Returns the unstarted consumer worker for ``q``.

        Subclasses must override this method. The returned object must have
        ``start`` and ``join`` methods.
        """
        raise NotImplementedError

    def consume(self, q):
        """Processes items of ``q`` until the ``None`` sentinel is received.

        Each non-sentinel item is a sequence
        ``[task, filename, outdir, target]`` and is executed as
        ``task(filename, outdir, target)``.

        ``q.task_done()`` is called for every item, even when the task
        raises. A failing task has its traceback printed and the loop keeps
        running. Previously the exception ended the consumer without
        acknowledging the item, which made ``finalize`` block forever on
        ``queue.join()``.
        """
        import traceback

        while True:
            task = q.get()
            try:
                if task is None:
                    # Stop sentinel put by ``finalize``.
                    return
                task[0](task[1], task[2], task[3])
            except Exception:
                # The exception cannot reach the producer from here; report
                # it and stay alive so later snapshots and the sentinel are
                # still handled.
                traceback.print_exc()
            finally:
                q.task_done()

    def finalize(self):
        """Stops the consumer after all queued snapshots were processed.

        Puts the ``None`` sentinel, waits until every queued item has been
        acknowledged, then joins the consumer. Calling it again is a no-op.
        """
        if self._started:
            if not self._finalized:
                self._queue.put(None)
                self._queue.join()
                self._consumer.join()
            self._started = False
        self._finalized = True
class ThreadQueueWriter(QueueWriter):

    """Snapshot writer that uses a thread queue.

    The queue comes from the :mod:`queue` module and the consumer is a
    :class:`threading.Thread` running :meth:`QueueWriter.consume`. The thread
    that calls the writer acts as the producer.

    .. seealso::

        - :meth:`chainer.training.extensions.snapshot`
    """

    def __init__(self, savefun=npz.save_npz, task=None):
        super(ThreadQueueWriter, self).__init__(savefun=savefun, task=task)

    def create_queue(self):
        """Returns a new ``queue.Queue``."""
        task_queue = queue.Queue()
        return task_queue

    def create_consumer(self, q):
        """Returns an unstarted thread that will run ``consume(q)``."""
        consumer_args = (q,)
        return threading.Thread(target=self.consume, args=consumer_args)
class ProcessQueueWriter(QueueWriter):

    """Snapshot writer that uses process queue.

    Both the queue (:class:`multiprocessing.JoinableQueue`) and the consumer
    (:class:`multiprocessing.Process` running :meth:`QueueWriter.consume`)
    come from the :mod:`multiprocessing` module. The process that calls the
    writer acts as the producer.

    .. note::
        Forking a new process from an MPI process might be dangerous.
        Consider using :class:`ThreadQueueWriter` instead of
        ``ProcessQueueWriter`` if you are using MPI.

    .. seealso::

        - :meth:`chainer.training.extensions.snapshot`
    """

    def __init__(self, savefun=npz.save_npz, task=None):
        super(ProcessQueueWriter, self).__init__(savefun=savefun, task=task)

    def create_queue(self):
        """Returns a new ``multiprocessing.JoinableQueue``."""
        task_queue = multiprocessing.JoinableQueue()
        return task_queue

    def create_consumer(self, q):
        """Returns an unstarted process that will run ``consume(q)``."""
        consumer_args = (q,)
        return multiprocessing.Process(
            target=self.consume, args=consumer_args)