/
hdf5.py
154 lines (117 loc) · 4.59 KB
/
hdf5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import numpy
from chainer import cuda
from chainer import serializer
try:
import h5py
_available = True
except ImportError:
_available = False
def _check_available():
    """Fail fast with an informative error when h5py is not importable.

    Raises:
        RuntimeError: If the optional ``h5py`` dependency failed to import
            at module load time.
    """
    if _available:
        return
    raise RuntimeError(
        'h5py is not installed on your environment.\n'
        'Please install h5py to activate hdf5 serializers.\n'
        '$ pip install h5py')
class HDF5Serializer(serializer.Serializer):
    """Serializer for HDF5 format.

    This is the standard serializer in Chainer. The chain hierarchy is simply
    mapped to HDF5 hierarchical groups.

    Args:
        group (h5py.Group): The group that this serializer represents.
        compression (int): Gzip compression level.
    """

    def __init__(self, group, compression=4):
        _check_available()
        self.group = group
        self.compression = compression

    def __getitem__(self, key):
        # Descend into (creating if needed) the child group for ``key`` and
        # wrap it in a serializer that shares the compression setting.
        child = self.group.require_group(self.group.name + '/' + key)
        return HDF5Serializer(child, self.compression)

    def __call__(self, key, value):
        ret = value
        if isinstance(value, cuda.ndarray):
            value = cuda.to_cpu(value)
        if value is None:
            # An h5py Empty dataset stands in for ``None``; Empty only
            # exists from h5py 2.7.0 onward.
            if h5py.version.version_tuple < (2, 7, 0):
                raise RuntimeError(
                    'h5py>=2.7.0 is required to serialize None.')
            data = h5py.Empty('f')
            compression = None
        else:
            data = numpy.asarray(value)
            # Gzip compression is not applicable to scalar-sized datasets.
            compression = self.compression if data.size > 1 else None
        self.group.create_dataset(key, data=data, compression=compression)
        return ret
def save_hdf5(filename, obj, compression=4):
    """Saves an object to the file in HDF5 format.

    This is a short-cut function to save only one object into an HDF5 file.
    If you want to save multiple objects to one HDF5 file, use
    :class:`HDF5Serializer` directly by passing appropriate
    :class:`h5py.Group` objects.

    Args:
        filename (str): Target file name.
        obj: Object to be serialized. It must support serialization protocol.
        compression (int): Gzip compression level.
    """
    _check_available()
    # The file handle is closed by the context manager even if save() fails.
    with h5py.File(filename, 'w') as f:
        HDF5Serializer(f, compression=compression).save(obj)
class HDF5Deserializer(serializer.Deserializer):
    """Deserializer for HDF5 format.

    This is the standard deserializer in Chainer. This deserializer can be
    used to read an object serialized by :class:`HDF5Serializer`.

    Args:
        group (h5py.Group): The group that the deserialization starts from.
        strict (bool): If ``True``, the deserializer raises an error when an
            expected value is not found in the given HDF5 file. Otherwise,
            it ignores the value and skip deserialization.
    """

    def __init__(self, group, strict=True):
        _check_available()
        self.group = group
        self.strict = strict

    def __getitem__(self, key):
        try:
            child = self.group.require_group(self.group.name + '/' + key)
        except ValueError:
            # In read mode, require_group raises ValueError when the group
            # does not exist. Record the absence with ``None`` and let
            # __call__ decide whether that is an error (strict mode) or not.
            child = None
        return HDF5Deserializer(child, strict=self.strict)

    def __call__(self, key, value):
        if self.group is None:
            if self.strict:
                raise ValueError('Inexistent group is specified')
            return value
        if not self.strict and key not in self.group:
            # Non-strict mode silently skips values missing from the file.
            return value
        dataset = self.group[key]
        if dataset.shape is None:
            # A shapeless (Empty) dataset encodes a serialized ``None``.
            return None
        if value is None:
            return numpy.asarray(dataset)
        if isinstance(value, numpy.ndarray):
            # Fill the caller's array in place.
            dataset.read_direct(value)
            return value
        if isinstance(value, cuda.ndarray):
            # Copy host data onto the device array in place.
            value.set(numpy.asarray(dataset))
            return value
        # Scalars and other types: rebuild via the value's own constructor.
        return type(value)(numpy.asarray(dataset))
def load_hdf5(filename, obj):
    """Loads an object from the file in HDF5 format.

    This is a short-cut function to load from an HDF5 file that contains
    only one object. If you want to load multiple objects from one HDF5
    file, use :class:`HDF5Deserializer` directly by passing appropriate
    :class:`h5py.Group` objects.

    Args:
        filename (str): Name of the file to be loaded.
        obj: Object to be deserialized. It must support serialization
            protocol.
    """
    _check_available()
    # Open read-only; the context manager guarantees the handle is released.
    with h5py.File(filename, 'r') as f:
        HDF5Deserializer(f).load(obj)