-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
serializer.py
54 lines (38 loc) · 1.53 KB
/
serializer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from chainer import serializers
from chainer import utils
def save_and_load(src, dst, filename, saver, loader):
    """Round-trip ``src`` into ``dst`` through a file using a de/serializer.

    This is a quick check that an object's serialization and
    deserialization code agree with each other. ``saver`` writes ``src`` to
    a file, and ``loader`` then reads that same file back into ``dst``. The
    file is placed inside a temporary directory obtained from
    ``utils.tempdir()``.

    Args:
        src: Object whose state is written out.
        dst: Object that receives the loaded state.
        filename (str): Name of the file created inside the temporary
            directory.
        saver (callable): Called as ``saver(path, src)`` to write ``src``.
        loader (callable): Called as ``loader(path, dst)`` to read the file
            back into ``dst``.
    """
    with utils.tempdir() as workdir:
        # Both callables must see the exact same path so the loader reads
        # what the saver just wrote.
        file_path = os.path.join(workdir, filename)
        saver(file_path, src)
        loader(file_path, dst)
def save_and_load_npz(src, dst):
    """Round-trip ``src`` into ``dst`` through an NPZ file.

    This is a shortcut for :func:`save_and_load`. It writes with
    ``serializers.save_npz`` and reads with ``serializers.load_npz``, using
    ``'tmp.npz'`` as the file name.

    Args:
        src: Object to save.
        dst: Object to load into.
    """
    npz_saver = serializers.save_npz
    npz_loader = serializers.load_npz
    save_and_load(src, dst, filename='tmp.npz',
                  saver=npz_saver, loader=npz_loader)
def save_and_load_hdf5(src, dst):
    """Round-trip ``src`` into ``dst`` through an HDF5 file.

    This is a shortcut for :func:`save_and_load`. It writes with
    ``serializers.save_hdf5`` and reads with ``serializers.load_hdf5``,
    using ``'tmp.h5'`` as the file name.

    Args:
        src: Object to save.
        dst: Object to load into.
    """
    hdf5_saver = serializers.save_hdf5
    hdf5_loader = serializers.load_hdf5
    save_and_load(src, dst, filename='tmp.h5',
                  saver=hdf5_saver, loader=hdf5_loader)