# 357 lines (301 sloc) 11.5 KB
"""Utilities related to disk I/O."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
from collections import defaultdict

import numpy as np
import six

# Upper bound (in bytes) that H5Dict uses for data written as HDF5
# attributes; HDF5 limits object header messages to 64KB.
HDF5_OBJECT_HEADER_LIMIT = 64512

try:
    import h5py
except ImportError:
    h5py = None

if sys.version_info[0] == 3:
    import pickle
else:
    import cPickle as pickle
class HDF5Matrix(object):
    """Representation of HDF5 dataset to be used instead of a Numpy array.

    # Example

    ```python
        x_data = HDF5Matrix('input/file.hdf5', 'data')
    ```

    Providing `start` and `end` allows use of a slice of the dataset.

    Optionally, a normalizer function (or lambda) can be given. This will
    be called on every slice of data retrieved.

    Indices are relative to `start`. Integer, slice, list and `np.ndarray`
    keys are accepted; negative integers and negative slice bounds count
    back from `end`, and slice steps are honoured. Out-of-range keys raise
    `IndexError`.

    # Arguments
        datapath: string, path to a HDF5 file
        dataset: string, name of the HDF5 dataset in the file specified
            in datapath
        start: int, start of desired slice of the specified dataset
        end: int, end of desired slice of the specified dataset
            (defaults to the length of the dataset)
        normalizer: function to be called on data when retrieved

    # Returns
        An array-like HDF5 dataset.

    # Raises
        ImportError: if h5py is not installed.
    """

    # Cache of opened `h5py.File` objects keyed by path. It is a class
    # attribute, so it is shared by all instances and each path is opened
    # at most once.
    refs = defaultdict(int)

    def __init__(self, datapath, dataset, start=0, end=None, normalizer=None):
        if h5py is None:
            raise ImportError('The use of HDF5Matrix requires '
                              'HDF5 and h5py installed.')

        if datapath not in self.refs:
            self.refs[datapath] = h5py.File(datapath)
        self.data = self.refs[datapath][dataset]
        self.start = start
        self.end = self.data.shape[0] if end is None else end
        self.normalizer = normalizer
        # Probe a single row so that `shape` and `dtype` describe what
        # __getitem__ actually returns (i.e. after normalization).
        first_val = self.data[0:1]
        if self.normalizer is not None:
            first_val = self.normalizer(first_val)
        self._base_shape = first_val.shape[1:]
        self._base_dtype = first_val.dtype

    def __len__(self):
        return self.end - self.start

    def __getitem__(self, key):
        n = len(self)
        if isinstance(key, slice):
            start = 0 if key.start is None else key.start
            stop = n if key.stop is None else key.stop
            # Negative bounds are relative to the end of this view, not of
            # the underlying dataset (which may extend past `self.end`).
            if start < 0:
                start = max(start + n, 0)
            if stop < 0:
                stop = max(stop + n, 0)
            if stop > n:
                raise IndexError
            idx = slice(start + self.start, stop + self.start, key.step)
        elif isinstance(key, (int, np.integer)):
            if key < 0:
                key += n
            if not 0 <= key < n:
                raise IndexError
            idx = key + self.start
        elif isinstance(key, np.ndarray):
            if np.max(key) + self.start < self.end:
                idx = (self.start + key).tolist()
            else:
                raise IndexError
        else:
            # Assume list/iterable
            if max(key) + self.start < self.end:
                idx = [x + self.start for x in key]
            else:
                raise IndexError
        if self.normalizer is not None:
            return self.normalizer(self.data[idx])
        return self.data[idx]

    @property
    def shape(self):
        """Gets a numpy-style shape tuple giving the dataset dimensions.

        # Returns
            A numpy-style shape tuple.
        """
        return (self.end - self.start,) + self._base_shape

    @property
    def dtype(self):
        """Gets the datatype of the dataset.

        # Returns
            A numpy dtype string.
        """
        return self._base_dtype

    @property
    def ndim(self):
        """Gets the number of dimensions (rank) of the dataset.

        Derived from `shape`, so it agrees with it even when a normalizer
        changes the rank of the stored data.

        # Returns
            An integer denoting the number of dimensions (rank) of the dataset.
        """
        return len(self.shape)

    @property
    def size(self):
        """Gets the total dataset size (number of elements).

        # Returns
            An integer denoting the number of elements in the dataset.
        """
        return int(np.prod(self.shape))
def ask_to_proceed_with_overwrite(filepath):
    """Produces a prompt asking about overwriting a file.

    The user is asked again until they answer "y" or "n"
    (case-insensitive; surrounding whitespace is ignored).

    # Arguments
        filepath: the path to the file to be overwritten.

    # Returns
        True if we can proceed with overwrite, False otherwise.
    """
    def _ask(prompt):
        # Normalize the answer so that e.g. "Y" or " n " are accepted.
        return six.moves.input(prompt).strip().lower()

    # 1-tuple so that a tuple-valued `filepath` cannot break the formatting.
    overwrite = _ask('[WARNING] %s already exists - overwrite? '
                     '[y/n]' % (filepath,))
    while overwrite not in ('y', 'n'):
        overwrite = _ask('Enter "y" (overwrite) or "n" '
                         '(cancel).')
    if overwrite == 'n':
        return False
    print('[TIP] Next time specify overwrite=True!')
    return True
class H5Dict(object):
    """ A dict-like wrapper around h5py groups (or dicts).

    This allows us to have a single serialization logic
    for both pickling and saving to disk.

    Note: This is not intended to be a generic wrapper.
    There are a lot of edge cases which have been hardcoded,
    and which make sense only in the context of model
    serialization/deserialization.

    # Arguments
        path: Either a string (path on disk), a dict, or a HDF5 Group.
        mode: File open mode (one of `{"a", "r", "w"}`).

    # Raises
        TypeError: if `path` is not a Group, string or dict.
        ImportError: if `path` is a string but h5py is not installed.
    """

    def __init__(self, path, mode='a'):
        # Dicts are checked first so that the in-memory (dict) backend
        # works even when h5py is not installed.
        if isinstance(path, dict):
            self.data = path
            self._is_file = False
            if mode == 'w':
                # Like h5py's 'w' mode: start from an empty container.
                self.data.clear()
            # Flag to check if a dict is user defined data or a sub group:
            self.data['_is_group'] = True
        elif self._is_h5_group(path):
            self.data = path
            self._is_file = False
        elif isinstance(path, six.string_types):
            if h5py is None:
                raise ImportError('The use of H5Dict with a file path '
                                  'requires HDF5 and h5py installed.')
            self.data = h5py.File(path, mode=mode)
            self._is_file = True
        else:
            raise TypeError('Required Group, str or dict. '
                            'Received: {}.'.format(type(path)))
        self.read_only = mode == 'r'

    @staticmethod
    def _is_h5_group(obj):
        """Return True if `obj` is an h5py Group/File (False without h5py)."""
        return h5py is not None and isinstance(obj, h5py.Group)

    def __setitem__(self, attr, val):
        if self.read_only:
            raise ValueError('Cannot set item in read-only mode.')
        is_np = type(val).__module__ == np.__name__
        if isinstance(self.data, dict):
            if isinstance(attr, bytes):
                attr = attr.decode('utf-8')
            if is_np:
                self.data[attr] = pickle.dumps(val)
                # We have to remember to unpickle in __getitem__
                self.data['_{}_pickled'.format(attr)] = True
            else:
                self.data[attr] = val
            return
        if self._is_h5_group(self.data) and attr in self.data:
            raise KeyError('Cannot set attribute. '
                           'Group with name "{}" exists.'.format(attr))
        if is_np:
            dataset = self.data.create_dataset(attr, val.shape,
                                               dtype=val.dtype)
            if not val.shape:
                # scalar
                dataset[()] = val
            else:
                dataset[:] = val
        elif isinstance(val, (list, tuple)):
            self._set_list_attr(attr, val)
        else:
            self.data.attrs[attr] = val

    def _set_list_attr(self, attr, val):
        """Store a list/tuple as HDF5 attribute(s), chunking if needed.

        A list whose array form exceeds `HDF5_OBJECT_HEADER_LIMIT` bytes is
        split over attributes named `attr0`, `attr1`, ... which
        `__getitem__` reassembles.
        """
        def byte_len(item):
            # Measure what is actually written: text is stored as UTF-8.
            if isinstance(item, six.text_type):
                return len(item.encode('utf-8'))
            return len(item)

        # Check that no item in `val` is larger than
        # `HDF5_OBJECT_HEADER_LIMIT`, because then even chunking the array
        # would not make the saving possible (and the loop below would
        # never terminate).
        bad_attributes = [x for x in val
                          if byte_len(x) > HDF5_OBJECT_HEADER_LIMIT]
        # Expecting this to never be true.
        if bad_attributes:
            raise RuntimeError('The following attributes cannot be saved to '
                               'HDF5 file because they are larger than '
                               '%d bytes: %s' % (HDF5_OBJECT_HEADER_LIMIT,
                                                 ', '.join(bad_attributes)))

        if (val and sys.version_info[0] == 3 and isinstance(
                val[0], six.string_types)):
            # convert to bytes
            val = [x.encode('utf-8') for x in val]

        data_npy = np.asarray(val)

        num_chunks = 1
        chunked_data = np.array_split(data_npy, num_chunks)
        # For string/bytes items this terminates thanks to the size check
        # above: a single-item chunk always fits.
        while any(chunk.nbytes > HDF5_OBJECT_HEADER_LIMIT
                  for chunk in chunked_data):
            num_chunks += 1
            chunked_data = np.array_split(data_npy, num_chunks)

        if num_chunks > 1:
            for chunk_id, chunk_data in enumerate(chunked_data):
                self.data.attrs['%s%d' % (attr, chunk_id)] = chunk_data
        else:
            self.data.attrs[attr] = val

    def __getitem__(self, attr):
        if isinstance(self.data, dict):
            if isinstance(attr, bytes):
                attr = attr.decode('utf-8')
            if attr in self.data:
                val = self.data[attr]
                if isinstance(val, dict) and val.get('_is_group'):
                    val = H5Dict(val)
                elif '_{}_pickled'.format(attr) in self.data:
                    # NOTE: pickle.loads can execute arbitrary code; the
                    # wrapped dict must come from a trusted source.
                    val = pickle.loads(val)
                return val
            if self.read_only:
                raise ValueError('Cannot create group in read-only mode.')
            val = {'_is_group': True}
            self.data[attr] = val
            return H5Dict(val)
        if attr in self.data.attrs:
            val = self.data.attrs[attr]
            if type(val).__module__ == np.__name__:
                # `np.bytes_` is the type formerly aliased as `np.string_`
                # (the alias was removed in NumPy 2.0).
                if val.dtype.type == np.bytes_:
                    val = val.tolist()
        elif attr in self.data:
            val = self.data[attr]
            if isinstance(val, h5py.Dataset):
                val = np.asarray(val)
            else:
                val = H5Dict(val)
        else:
            val = self._get_chunked_or_new_group(attr)
        return val

    def _get_chunked_or_new_group(self, attr):
        """Reassemble a list stored as `attr0`, `attr1`, ... attributes.

        If no such chunks exist, create a new sub-group named `attr`
        (not allowed in read-only mode) and return it wrapped in H5Dict.
        """
        chunk_attr = '%s%d' % (attr, 0)
        if chunk_attr not in self.data.attrs:
            if self.read_only:
                raise ValueError('Cannot create group in read-only mode.')
            return H5Dict(self.data.create_group(attr))
        val = []
        chunk_id = 0
        while chunk_attr in self.data.attrs:
            chunk = self.data.attrs[chunk_attr]
            val.extend([x.decode('utf8') for x in chunk])
            chunk_id += 1
            chunk_attr = '%s%d' % (attr, chunk_id)
        return val

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    def iter(self):
        return iter(self.data)

    def __getattr__(self, attr):
        # Only reached for names not found normally. If `data` itself is
        # missing (e.g. an instance created without __init__), looking it
        # up below would recurse forever, so fail fast instead.
        if attr == 'data':
            raise AttributeError(attr)
        value = getattr(self.data, attr)
        if not callable(value):
            # Plain attributes are returned as-is, not wrapped in a function.
            return value

        def h5wrapper(*args, **kwargs):
            out = value(*args, **kwargs)
            # Re-wrap results of the same kind as the backing store
            # (e.g. a sub-group) so callers keep getting H5Dict objects.
            if isinstance(self.data, type(out)):
                return H5Dict(out)
            return out
        return h5wrapper

    def close(self):
        """Flush the HDF5 file and close it if this object opened it.

        Does nothing for dict-backed instances.
        """
        if self._is_h5_group(self.data):
            self.data.file.flush()
            if self._is_file:
                self.data.close()

    def update(self, *args, **kwargs):
        """Update a dict-backed instance like `dict.update`.

        # Raises
            NotImplementedError: for HDF5-backed instances.
        """
        if isinstance(self.data, dict):
            self.data.update(*args, **kwargs)
            return
        raise NotImplementedError

    def __contains__(self, key):
        if isinstance(self.data, dict):
            return key in self.data
        return (key in self.data) or (key in self.data.attrs)

    def get(self, key, default=None):
        if key in self:
            return self[key]
        return default
# Lower-case alias of `H5Dict` (the same class object, not a subclass).
h5dict = H5Dict