Skip to content

Commit

Permalink
Merge 9c9d9c3 into 7aca074
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Mar 26, 2018
2 parents 7aca074 + 9c9d9c3 commit 7c5e299
Show file tree
Hide file tree
Showing 9 changed files with 647 additions and 44 deletions.
24 changes: 22 additions & 2 deletions chainer/serializers/hdf5.py
@@ -1,4 +1,7 @@
import sys

import numpy
import six

from chainer.backends import cuda
from chainer import serializer
def save_hdf5(filename, obj, compression=4):
    """Saves an object to the file in HDF5 format.

    This is a short-cut function to save only one object into an HDF5 file.

    Args:
        filename (str): Target file name.
        obj: Object to be serialized. It must support serialization protocol.
            If it is a dictionary object, the serialization will be skipped.
        compression (int): Gzip compression level.
    """
    _check_available()
    with h5py.File(filename, 'w') as f:
        if isinstance(obj, dict):
            # The object is already a flat mapping of keys to array-like
            # values (e.g. the output of a DictionarySerializer); write each
            # entry directly as a dataset.
            for key, value in obj.items():
                key = '/' + key.lstrip('/')
                arr = numpy.asarray(value)
                # h5py rejects compression for scalar datasets; disable it
                # per-dataset (using a local name so the caller's compression
                # level is not clobbered for subsequent keys).
                comp = None if arr.size <= 1 else compression
                try:
                    f.create_dataset(key, data=arr, compression=comp)
                except TypeError:
                    sys.stderr.write(
                        'A key named "{}" is unable to save in HDF5 '
                        'format.\n'.format(key))
                    # In Chainer, LogReport extension and PlotReport
                    # extension are unable to save in HDF5 format. These
                    # extensions have a data type `numpy.dtype('O')` which
                    # is not supported by h5py.
                    six.reraise(*sys.exc_info())
        else:
            s = HDF5Serializer(f, compression=compression)
            s.save(obj)


class HDF5Deserializer(serializer.Deserializer):
Expand Down
28 changes: 24 additions & 4 deletions chainer/serializers/npz.py
Expand Up @@ -53,6 +53,20 @@ def __call__(self, key, value):
return ret


def serialize(obj):
    """Serializes an object to a dictionary object.

    Args:
        obj: Object to be serialized. It must support serialization protocol.

    Returns:
        dict: Serialized object.
    """
    dict_serializer = DictionarySerializer()
    dict_serializer.save(obj)
    return dict_serializer.target


def save_npz(file, obj, compression=True):
    """Saves an object to the file in NPZ format.

    This is a short-cut function to save only one object into an NPZ file.

    Args:
        file (str or file-like): Target file to write to.
        obj: Object to be serialized. It must support serialization protocol.
            If it is a dictionary object, the serialization will be skipped.
        compression (bool): If ``True``, compression in the resulting zip file
            is enabled.
    """
    if isinstance(file, str):
        # Open the path ourselves and recurse with the file object so the
        # rest of the function only deals with file-like targets.
        with open(file, 'wb') as f:
            save_npz(f, obj, compression)
        return

    if isinstance(obj, dict):
        # Already a flat dictionary; skip the serialization step.
        target = obj
    else:
        # Reuse the module-level helper instead of duplicating the
        # DictionarySerializer dance here.
        target = serialize(obj)

    if compression:
        numpy.savez_compressed(file, **target)
    else:
        numpy.savez(file, **target)


class NpzDeserializer(serializer.Deserializer):
Expand Down
168 changes: 141 additions & 27 deletions chainer/training/extensions/_snapshot.py
@@ -1,9 +1,7 @@
import os
import shutil
import tempfile

from chainer.serializers import npz
from chainer.training import extension
from chainer.training.extensions import snapshot_writers
from chainer.utils import argument


def snapshot_object(target, filename, savefun=npz.save_npz):
    """Returns a trainer extension to take snapshots of a given object.

    This extension serializes the given object and saves it to the output
    directory.

    This extension is called once per epoch by default. To take a snapshot
    at a different interval, a trigger object specifying the required
    interval can be passed along with this extension to the ``extend()``
    method of the trainer.

    The default priority is -100, which is lower than that of most
    built-in extensions.

    Args:
        target: Object to serialize.
        filename (str): Name of the file into which the object is serialized.
            It can be a format string, where the trainer object is passed to
            the :meth:`str.format` method.
        savefun: Function to save the object. It takes two arguments: the
            output file path and the object to serialize.

    Returns:
        Snapshot extension object.

    .. seealso::

        - :meth:`chainer.training.extensions.snapshot`
    """
    # Wrap the save function in a synchronous writer; the heavy lifting is
    # shared with snapshot() via the _Snapshot extension.
    return _Snapshot(
        target=target,
        writer=snapshot_writers.SimpleWriter(savefun=savefun),
        filename=filename)

def snapshot(savefun=None,
             filename='snapshot_iter_{.updater.iteration}', **kwargs):
    """snapshot(savefun=None, filename='snapshot_iter_{.updater.iteration}', *, target=None, condition=None, writer=None)

    Returns a trainer extension to take snapshots of the trainer.

    This extension serializes the trainer object and saves it to the output
    directory. It is used to support resuming the training loop from the
    saved state.

    This extension is called once per epoch by default. To take a snapshot
    at a different interval, a trigger object specifying the required
    interval can be passed along with this extension to the ``extend()``
    method of the trainer.

    The default priority is -100, which is lower than that of most
    built-in extensions.

    Args:
        savefun: Function to save the trainer. It takes two arguments: the
            output file path and the trainer object.
            It is :meth:`chainer.serializers.save_npz` by default.
            If ``writer`` is specified, this argument must be ``None``.
        filename (str): Name of the file into which the trainer is
            serialized. It can be a format string, where the trainer object
            is passed to the :meth:`str.format` method.
        target: Object to serialize. If it is not specified, it will
            be the trainer object.
        condition: Condition object. It must be a callable object that
            returns boolean without any arguments. If it returns ``True``,
            the snapshot will be done. If not, it will be skipped. The
            default is a function that always returns ``True``.
        writer: Writer object. It must be a callable object.
            If ``savefun`` is other than ``None``, this argument must be
            ``None``. In that case, a
            :class:`~chainer.training.extensions.snapshot_writers.SimpleWriter`
            object instantiated with the specified ``savefun`` argument will
            be used.

            This is the list of built-in snapshot writers.

            - :class:`chainer.training.extensions.snapshot_writers.\
SimpleWriter`
            - :class:`chainer.training.extensions.snapshot_writers.\
ThreadWriter`
            - :class:`chainer.training.extensions.snapshot_writers.\
ProcessWriter`
            - :class:`chainer.training.extensions.snapshot_writers.\
ThreadQueueWriter`
            - :class:`chainer.training.extensions.snapshot_writers.\
ProcessQueueWriter`

    Returns:
        Snapshot extension object.

    .. seealso::

        - :meth:`chainer.training.extensions.snapshot_object`
    """
    target = kwargs.pop('target', None)
    condition = kwargs.pop('condition', None)
    writer = kwargs.pop('writer', None)
    # savefun and writer are mutually exclusive: a writer already carries
    # its own save function.
    if savefun is not None and writer is not None:
        raise TypeError(
            'savefun and writer argument cannot be specified together.')
    argument.assert_kwargs_empty(kwargs)

    if writer is None:
        if savefun is None:
            savefun = npz.save_npz
        # Default to a synchronous writer around the given save function.
        writer = snapshot_writers.SimpleWriter(savefun=savefun)

    return _Snapshot(
        target=target, condition=condition, writer=writer, filename=filename)

def _snapshot_object(trainer, target, filename, savefun):
fn = filename.format(trainer)
prefix = 'tmp' + fn

tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
tmppath = os.path.join(tmpdir, fn)
try:
savefun(tmppath, target)
shutil.move(tmppath, os.path.join(trainer.out, fn))
finally:
shutil.rmtree(tmpdir)
def _always_true():
return True


class _Snapshot(extension.Extension):
    """Trainer extension to take snapshots.

    This extension serializes the given object and saves it to the output
    directory.

    This extension is called once per epoch by default. To take a
    snapshot at a different interval, a trigger object specifying the
    required interval can be passed along with this extension
    to the `extend()` method of the trainer.

    The default priority is -100, which is lower than that of most
    built-in extensions.
    """

    trigger = 1, 'epoch'
    priority = -100

    def __init__(
            self, target=None, condition=None, writer=None,
            filename='snapshot_iter_{.updater.iteration}'):
        self._target = target
        self.filename = filename
        # Fall back to "always snapshot" and a synchronous writer when the
        # caller did not supply them.
        self.condition = _always_true if condition is None else condition
        self.writer = (
            snapshot_writers.SimpleWriter() if writer is None else writer)

    def __call__(self, trainer):
        if not self.condition():
            return
        # No explicit target means the trainer itself is snapshotted.
        snapshot_target = self._target if self._target is not None else trainer
        serialized = npz.serialize(snapshot_target)
        name = self.filename
        name = name(trainer) if callable(name) else name.format(trainer)
        self.writer(name, trainer.out, serialized)

    def finalize(self):
        # Let writers with background workers (threads/processes) clean up.
        writer_finalize = getattr(self.writer, 'finalize', None)
        if writer_finalize is not None:
            writer_finalize()

0 comments on commit 7c5e299

Please sign in to comment.