diff --git a/hickle/__version__.py b/hickle/__version__.py index 087fa0e3..62b8b628 100644 --- a/hickle/__version__.py +++ b/hickle/__version__.py @@ -10,4 +10,4 @@ # %% VERSIONS # Default/Latest/Current version -__version__ = '4.0.1' +__version__ = '4.1.0' diff --git a/hickle/helpers.py b/hickle/helpers.py index 95e66ed1..9907ce70 100644 --- a/hickle/helpers.py +++ b/hickle/helpers.py @@ -1,116 +1,190 @@ # %% IMPORTS # Built-in imports import re +import operator +import typing +import types +import collections +import numbers # Package imports import dill as pickle -# %% FUNCTION DEFINITIONS -def get_type(h_node): - """ Helper function to return the py_type for an HDF node """ - base_type = h_node.attrs['base_type'] - if base_type != b'pickle': - py_type = pickle.loads(h_node.attrs['type']) - else: - py_type = None - return py_type, base_type - - -def get_type_and_data(h_node): - """ Helper function to return the py_type and data block for an HDF node""" - py_type, base_type = get_type(h_node) - data = h_node[()] - return py_type, base_type, data +# %% EXCEPTION DEFINITIONS +nobody_is_my_name = () -def sort_keys(key_list): - """ Take a list of strings and sort it by integer value within string - - Args: - key_list (list): List of keys - - Returns: - key_list_sorted (list): List of keys, sorted by integer +class NotHicklable(Exception): """ + object can not be mapped to proper hickle HDF5 file structure and + thus shall be converted to pickle string before storing. + """ + pass - # Py3 h5py returns an irritating KeysView object - # Py3 also complains about bytes and strings, convert all keys to bytes - key_list2 = [] - for key in key_list: - if isinstance(key, str): - key = bytes(key, 'ascii') - key_list2.append(key) - key_list = key_list2 - - # Check which keys contain a number - numbered_keys = [re.search(br'\d+', key) for key in key_list] - - # Sort the keys on number if they have it, or normally if not - if(len(key_list) and not numbered_keys.count(None)): - return(sorted(key_list, - key=lambda x: int(re.search(br'\d+', x).group(0)))) - else: - return(sorted(key_list)) - - -def check_is_iterable(py_obj): - """ Check whether a python object is a built-in iterable. - - Note: this treats unicode and string as NON ITERABLE - - Args: - py_obj: python object to test +# %% CLASS DEFINITIONS - Returns: - iter_ok (bool): True if item is iterable, False is item is not +class PyContainer(): + """ + Abstract base class for all PyContainer classes acting as proxy between + h5py.Group and python object represented by the content of the h5py.Group. + Any container type object as well as complex objects are represented + in a tree like structure on HDF5 file which PyContainer objects ensure to + be properly mapped before beeing converted into the final object. + + Parameters: + ----------- + h5_attrs (h5py.AttributeManager): + attributes defined on h5py.Group object represented by this PyContainer + + base_type (bytes): + the basic type used for representation on the HDF5 file + + object_type: + type of Python object to be restored. Dependent upon container may + be used by PyContainer.convert to convert loaded Python object into + final one. + + Attributes: + ----------- + base_type (bytes): + the basic type used for representation on the HDF5 file + + object_type: + type of Python object to be restored. Dependent upon container may + be used by PyContainer.convert to convert loaded Python object into + final one. 
+ """ - # Check if py_obj is an accepted iterable and return - return(isinstance(py_obj, (tuple, list, set))) - + __slots__ = ("base_type", "object_type", "_h5_attrs", "_content","__dict__" ) + + def __init__(self,h5_attrs, base_type, object_type,_content = None): + """ + Parameters (protected): + ----------------------- + _content (default: list): + container to be used to collect the Python objects representing + the sub items or the state of the final Python object. Shall only + be set by derived PyContainer classes and not be set by + + """ + # the base type used to select this PyContainer + self.base_type = base_type + # class of python object represented by this PyContainer + self.object_type = object_type + # the h5_attrs structure of the h5_group to load the object_type from + # can be used by the append and convert methods to obtain more + # information about the container like object to be restored + self._h5_attrs = h5_attrs + # intermediate list, tuple, dict, etc. used to collect and store the sub items + # when calling the append method + self._content = _content if _content is not None else [] + + def filter(self,items): + yield from items + + def append(self,name,item,h5_attrs): + """ + adds the passed item (object) to the content of this container. + + Parameters: + ----------- + name (string): + the name of the h5py.Dataset or h5py.Group subitem was loaded from + + item: + the Python object of the subitem + + h5_attrs: + attributes defined on h5py.Group or h5py.Dataset object sub item + was loaded from. + """ + self._content.append(item) + + def convert(self): + """ + creates the final object and populates it with the items stored in the _content slot + must be implemented by the derived Container classes + + Returns: + -------- + py_obj: The final Python object loaded from file + + + """ + raise NotImplementedError("convert method must be implemented") + + +class H5NodeFilterProxy(): + """ + Proxy class which allows to temporarily modify h5_node.attrs content. + Original attributes of underlying h5_node are left unchanged. + + Parameters: + ----------- + h5_node: + node for which attributes shall be replaced by a temporary value + + """ -def check_is_hashable(py_obj): - """ Check if a python object is hashable + __slots__ = ('_h5_node','attrs','__dict__') + + def __init__(self,h5_node): + self._h5_node = h5_node + self.attrs = collections.ChainMap({},h5_node.attrs) + + def __getattribute__(self,name): + # for attrs and wrapped _h5_node return local copy any other request + # redirect to wrapped _h5_node + if name in {"attrs","_h5_node"}: + return super(H5NodeFilterProxy,self).__getattribute__(name) + _h5_node = super(H5NodeFilterProxy,self).__getattribute__('_h5_node') + return getattr(_h5_node,name) + + def __setattr__(self,name,value): + # if wrapped _h5_node and attrs shall be set store value on local attributes + # otherwise pass on to wrapped _h5_node + if name in {'_h5_node','attrs'}: + super(H5NodeFilterProxy,self).__setattr__(name,value) + return + _h5_node = super(H5NodeFilterProxy,self).__getattribute__('_h5_node') + setattr(_h5_node,name,value) + + def __getitem__(self,*args,**kwargs): + _h5_node = super(H5NodeFilterProxy,self).__getattribute__('_h5_node') + return _h5_node.__getitem__(*args,**kwargs) + # TODO as needed add more function like __getitem__ to fully proxy h5_node + # or consider using metaclass __getattribute__ for handling special methods + + - Note: this function is currently not used, but is useful for future - development. 
+# %% FUNCTION DEFINITIONS - Args: - py_obj: python object to test +def not_dumpable( py_obj, h_group, name, **kwargs): # pragma: nocover + """ + create_dataset method attached to dummy py_objects used to mimic container + groups by older versions of hickle lacking generic PyContainer mapping + h5py.Groups to corresponding py_object + + + Raises: + ------- + RuntimeError: + in any case as this function shall never be called """ - try: - py_obj.__hash__() - return True - except TypeError: - return False - - -def check_iterable_item_type(iter_obj): - """ Check if all items within an iterable are the same type. + raise RuntimeError("types defined by loaders not dumpable") - Args: - iter_obj: iterable object - Returns: - iter_type: type of item contained within the iterable. If - the iterable has many types, a boolean False is returned instead. - References: - http://stackoverflow.com/questions/13252333 +def no_compression(kwargs): """ - - iseq = iter(iter_obj) - - try: - first_type = type(next(iseq)) - except StopIteration: - return False - except Exception: # pragma: no cover - return False - else: - if all([type(x) is first_type for x in iseq]): - return(first_type) - else: - return(False) + filter which temporarily removes any compression or data filter related + arguments from the kwargs dict. + """ + return { + key:value + for key,value in kwargs.items() + if key not in {"compression","shuffle","compression_opts","chunks","fletcher32","scaleoffset"} + } diff --git a/hickle/hickle.py b/hickle/hickle.py index 5b13d1bf..d51493c8 100644 --- a/hickle/hickle.py +++ b/hickle/hickle.py @@ -26,9 +26,12 @@ # %% IMPORTS # Built-in imports import io +import os.path as os_path from pathlib import Path import sys import warnings +import types +import functools as ft # Package imports import dill as pickle @@ -37,11 +40,12 @@ # hickle imports from hickle import __version__ -from hickle.helpers import ( - get_type, sort_keys, check_is_iterable, check_iterable_item_type) -from hickle.lookup import ( - types_dict, hkl_types_dict, types_not_to_sort, dict_key_types_dict, - check_is_ndarray_like, load_loader) +from .helpers import PyContainer, NotHicklable, nobody_is_my_name +from .lookup import ( + hkl_types_dict, hkl_container_dict, load_loader, load_legacy_loader , + create_pickled_dataset, load_nothing, fix_lambda_obj_type +) + # All declaration __all__ = ['dump', 'load'] @@ -54,12 +58,10 @@ class FileError(Exception): """ An exception raised if the file is fishy """ - pass class ClosedFileError(Exception): """ An exception raised if the file is fishy """ - pass class ToDoError(Exception): # pragma: no cover @@ -67,6 +69,10 @@ class ToDoError(Exception): # pragma: no cover def __str__(self): return "Error: this functionality hasn't been implemented yet." +class SerializedWarning(UserWarning): + """ An object type was not understood + The data will be serialized using pickle. + """ # %% FUNCTION DEFINITIONS def file_opener(f, path, mode='r'): @@ -81,7 +87,7 @@ def file_opener(f, path, mode='r'): File to open for dumping or loading purposes. If str, `file_obj` provides the path of the HDF5-file that must be used. - If :obj:`~h5py._hl.group.Group`, the group (or file) in an open + If :obj:`~h5py.Group`, the group (or file) in an open HDF5-file that must be used. path : str Path within HDF5-file or group to dump to/load from. 
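# Usage sketch for the dump()/load() signatures documented above; the file name
# 'example.hkl' and the '/run_01' path are hypothetical. Extra keyword arguments
# such as compression are forwarded to h5py's create_dataset.
import hickle

data = {'series': [1, 2, 3], 'label': 'measurement'}
hickle.dump(data, 'example.hkl', mode='w', path='/run_01', compression='gzip')
restored = hickle.load('example.hkl', path='/run_01')
assert restored == data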
@@ -97,7 +103,7 @@ def file_opener(f, path, mode='r'): # Make sure that the given path always starts with '/' if not path.startswith('/'): - path = '/%s' % (path) + path = "/%s" % path # Were we handed a file object or just a file name string? if isinstance(f, (io.TextIOWrapper, io.BufferedWriter)): @@ -108,7 +114,7 @@ def file_opener(f, path, mode='r'): elif isinstance(f, (str, Path)): filename = f h5f = h5.File(filename, mode) - elif isinstance(f, h5._hl.group.Group): + elif isinstance(f, h5.Group): try: filename = f.file.filename except ValueError: @@ -129,18 +135,14 @@ def file_opener(f, path, mode='r'): raise FileError("Cannot open file. Please pass either a filename " "string, a file object, or a h5py.File") - return(h5f, path, close_flag) + return h5f, path, close_flag ########### # DUMPERS # ########### -# Get list of dumpable dtypes -dumpable_dtypes = [bool, complex, bytes, float, int, str] - - -def _dump(py_obj, h_group, call_id=None, **kwargs): +def _dump(py_obj, h_group, name, attrs={} , **kwargs): """ Dump a python object to a group within an HDF5 file. This function is called recursively by the main dump() function. @@ -148,45 +150,45 @@ def _dump(py_obj, h_group, call_id=None, **kwargs): Args: py_obj: python object to dump. h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. + name (bytes): name of resultin hdf5 group or dataset """ - # Check if we have a unloaded loader for the provided py_obj - load_loader(py_obj) - - # Firstly, check if item is a numpy array. If so, just dump it. - if check_is_ndarray_like(py_obj): - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) - - # Next, check if item is a dict - elif isinstance(py_obj, dict): - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) - - # If not, check if item is iterable - elif check_is_iterable(py_obj): - item_type = check_iterable_item_type(py_obj) - - # item_type == False implies multiple types. Create a dataset - if not item_type: - h_subgroup = create_hkl_group(py_obj, h_group, call_id) - for ii, py_subobj in enumerate(py_obj): - _dump(py_subobj, h_subgroup, call_id=ii, **kwargs) - - # otherwise, subitems have same type. Check if subtype is an iterable - # (e.g. list of lists), or not (e.g. list of ints, which should be - # treated as a single dataset). 
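# Sketch of the loader contract consumed by the new _dump() below: a create_dataset
# method returns the created node together with an iterable of
# (sub_name, sub_object, extra_attrs, sub_kwargs) tuples which _dump() recurses into.
# The "Pair" object with .first/.second attributes is a hypothetical example, not a
# type handled by this patch.
def create_pair_dataset(py_obj, h_group, name, **kwargs):
    # complex objects are represented by a group ...
    h_node = h_group.create_group(name)
    # ... and their parts are handed back to _dump() as sub items
    subitems = (
        ('first', py_obj.first, {}, kwargs),
        ('second', py_obj.second, {}, kwargs),
    )
    return h_node, subitems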
- else: - if item_type in dumpable_dtypes: - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) - else: - h_subgroup = create_hkl_group(py_obj, h_group, call_id) - for ii, py_subobj in enumerate(py_obj): - _dump(py_subobj, h_subgroup, call_id=ii, **kwargs) - - # item is not iterable, so create a dataset for it + # Check if we have a unloaded loader for the provided py_obj and + # retrive the most apropriate method for crating the corresponding + # representation within HDF5 file + if isinstance( + py_obj, + (types.FunctionType, types.BuiltinFunctionType, types.MethodType, types.BuiltinMethodType, type) + ): + py_obj_type,create_dataset,base_type = object,create_pickled_dataset,b'pickle' else: - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) + py_obj_type, (create_dataset, base_type) = load_loader(py_obj.__class__) + try: + h_node,h_subitems = create_dataset(py_obj, h_group, name, **kwargs) + + # loop through list of all subitems and recursively dump them + # to HDF5 file + for h_subname,py_subobj,h_subattrs,sub_kwargs in h_subitems: + _dump(py_subobj,h_node,h_subname,h_subattrs,**sub_kwargs) + # add addtional attributes and set 'base_type' and 'type' + # attributes accordingly + h_node.attrs.update(attrs) + + # only explicitly store base_type and type if not dumped by + # create_pickled_dataset + if create_dataset is not create_pickled_dataset: + h_node.attrs['base_type'] = base_type + h_node.attrs['type'] = np.array(pickle.dumps(py_obj_type)) + return + except NotHicklable: + + # ask pickle to try to store + h_node,h_subitems = create_pickled_dataset(py_obj, h_group, name, reason = str(NotHicklable), **kwargs) + + # dump any sub items if create_pickled_dataset create an object group + for h_subname,py_subobj,h_subattrs,sub_kwargs in h_subitems: + _dump(py_subobj,h_node,h_subname,h_subattrs,**sub_kwargs) + h_node.attrs.update(attrs) def dump(py_obj, file_obj, mode='w', path='/', **kwargs): @@ -201,7 +203,7 @@ def dump(py_obj, file_obj, mode='w', path='/', **kwargs): File in which to store the object. If str, `file_obj` provides the path of the HDF5-file that must be used. - If :obj:`~h5py._hl.group.Group`, the group (or file) in an open + If :obj:`~h5py.Group`, the group (or file) in an open HDF5-file that must be used. mode : str, optional Accepted values are 'r' (read only), 'w' (write; default) or 'a' @@ -212,7 +214,7 @@ def dump(py_obj, file_obj, mode='w', path='/', **kwargs): Defaults to root ('/'). kwargs : keyword arguments Additional keyword arguments that must be provided to the - :meth:`~h5py._hl.group.Group.create_dataset` method. + :meth:`~h5py.Group.create_dataset` method. """ @@ -228,252 +230,54 @@ def dump(py_obj, file_obj, mode='w', path='/', **kwargs): pv = sys.version_info py_ver = "%i.%i.%i" % (pv[0], pv[1], pv[2]) - # Try to create the root group - try: + h_root_group = h5f.get(path,None) + if h_root_group is None: h_root_group = h5f.create_group(path) - - # If that is not possible, check if it is empty - except ValueError as error: - # Raise error if this group is not empty - if len(h5f[path]): - raise error - else: - h_root_group = h5f.get(path) + elif h_root_group.items(): + raise ValueError("Unable to create group (name already exists)") h_root_group.attrs["HICKLE_VERSION"] = __version__ h_root_group.attrs["HICKLE_PYTHON_VERSION"] = py_ver - _dump(py_obj, h_root_group, **kwargs) + _dump(py_obj, h_root_group,'data', **kwargs) finally: # Close the file if requested. 
# Closing a file twice will not cause any problems if close_flag: h5f.close() - -def create_dataset_lookup(py_obj): - """ What type of object are we trying to hickle? This is a python - dictionary based equivalent of a case statement. It returns the correct - helper function for a given data type. - - Args: - py_obj: python object to look-up what function to use to dump to disk - - Returns: - match: function that should be used to dump data to a new dataset - base_type: the base type of the data that will be dumped - """ - - # Obtain the MRO of this object - mro_list = py_obj.__class__.mro() - - # Create a type_map - type_map = map(types_dict.get, mro_list) - - # Loop over the entire type_map until something else than None is found - for type_item in type_map: - if type_item is not None: - return(type_item) - - -def create_hkl_dataset(py_obj, h_group, call_id=None, **kwargs): - """ Create a dataset within the hickle HDF5 file - - Args: - py_obj: python object to dump. - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. - - """ - # lookup dataset creator type based on python object type - create_dataset, base_type = create_dataset_lookup(py_obj) - - # Set the name of this dataset - name = 'data%s' % ("_%i" % (call_id) if call_id is not None else '') - - # Try to create the dataset - try: - h_subgroup = create_dataset(py_obj, h_group, name, **kwargs) - # If that fails, pickle the object instead - except Exception as error: - # Make sure builtins loader is loaded - load_loader(object) - - # Obtain the proper dataset creator and base type - create_dataset, base_type = types_dict[object] - - # Make sure that a group/dataset with name 'name' does not exist - try: - del h_group[name] - except Exception: - pass - - # Create the pickled dataset - h_subgroup = create_dataset(py_obj, h_group, name, error, **kwargs) - - # Save base type of py_obj - h_subgroup.attrs['base_type'] = base_type - - # Save a pickled version of the true type of py_obj if necessary - if base_type != b'pickle' and 'type' not in h_subgroup.attrs: - h_subgroup.attrs['type'] = np.array(pickle.dumps(py_obj.__class__)) - - -def create_hkl_group(py_obj, h_group, call_id=None): - """ Create a new group within the hickle file - - Args: - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. - - """ - - # Set the name of this group - if isinstance(call_id, str): - name = call_id - else: - name = 'data%s' % ("_%i" % (call_id) if call_id is not None else '') - - h_subgroup = h_group.create_group(name) - h_subgroup.attrs['type'] = np.array(pickle.dumps(py_obj.__class__)) - h_subgroup.attrs['base_type'] = create_dataset_lookup(py_obj)[1] - return h_subgroup - - -def create_dict_dataset(py_obj, h_group, name, **kwargs): - """ Creates a data group for each key in dictionary - - Notes: - This is a very important function which uses the recursive _dump - method to build up hierarchical data models stored in the HDF5 file. - As this is critical to functioning, it is kept in the main hickle.py - file instead of in the loaders/ directory. - - Args: - py_obj: python object to dump; should be dictionary - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. 
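# Behavioural sketch of the root-group handling introduced in dump() above (the file
# name 'example.hkl' is hypothetical): dumping into an already populated path now
# raises ValueError instead of silently reusing the group.
import hickle

hickle.dump([1, 2, 3], 'example.hkl', mode='w', path='/first')
hickle.dump({'a': 1}, 'example.hkl', mode='a', path='/second')   # new path, fine
try:
    hickle.dump('oops', 'example.hkl', mode='a', path='/first')  # group not empty
except ValueError as error:
    print(error)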
- """ - - h_dictgroup = h_group.create_group(name) - - for idx, (key, py_subobj) in enumerate(py_obj.items()): - # Obtain the raw string representation of this key - subgroup_key = "%r" % (key) - - # Make sure that the '\\\\' is not in the key, or raise error if so - if '\\\\' in subgroup_key: - del h_group[name] - raise ValueError("Dict item keys containing the '\\\\' string are " - "not supported!") - - # Replace any forward slashes with double backslashes - subgroup_key = subgroup_key.replace('/', '\\\\') - h_subgroup = h_dictgroup.create_group(subgroup_key) - h_subgroup.attrs['base_type'] = b'dict_item' - - h_subgroup.attrs['key_base_type'] = bytes(type(key).__name__, 'ascii') - h_subgroup.attrs['key_type'] = np.array(pickle.dumps(key.__class__)) - - h_subgroup.attrs['key_idx'] = idx - - _dump(py_subobj, h_subgroup, call_id=None, **kwargs) - return(h_dictgroup) - - -# Add create_dict_dataset to types_dict -types_dict[dict] = (create_dict_dataset, b"dict") - - ########### # LOADERS # ########### -class PyContainer(list): - """ A group-like object into which to load datasets. - - In order to build up a tree-like structure, we need to be able - to load datasets into a container with an append() method. - Python tuples and sets do not allow this. This class provides - a list-like object that be converted into a list, tuple, set or dict. +class RootContainer(PyContainer): + """ + PyContainer representing the whole HDF5 file """ - def __init__(self): - super(PyContainer, self).__init__() - self.container_type = None - self.container_base_type = None - self.name = None - self.key_type = None - self.key_base_type = None + __slots__ = () def convert(self): - """ Convert from PyContainer to python core data type. - - Returns: self, either as a list, tuple, set or dict - (or other type specified in lookup.py) - """ - - # If this container is a dict, convert its items properly - if self.container_base_type == b"dict": - # Create empty list of items - items = [[]]*len(self) + return self._content[0] - # Loop over all items in the container - for item in self: - # Obtain the name of this item - key = item.name.split('/')[-1].replace('\\\\', '/') - # Obtain the base type and index of this item's key - key_base_type = item.key_base_type - key_idx = item.key_idx - - # If this key has a type that must be converted, do so - if key_base_type in dict_key_types_dict.keys(): - to_type_fn = dict_key_types_dict[key_base_type] - key = to_type_fn(key) - - # Insert item at the correct index into the list - items[key_idx] = [key, item[0]] - - # Initialize dict using its true type and return - return(self.container_type(items)) - - # In all other cases, return container - else: - # If container has a true type defined, convert to that first - if self.container_type is not None: - return(self.container_type(self)) +class NoMatchContainer(PyContainer): # pragma: no cover + """ + PyContainer used by load when no appropriate container + could be found for specified base_type. 
+ """ - # If not, return the container itself - else: - return(self) + __slots__ = () + def __init__(self,h5_attrs, base_type, object_type): # pragma: no cover + raise RuntimeError("Cannot load container proxy for %s data type " % base_type) + def no_match_load(key): # pragma: no cover - """ If no match is made when loading, need to raise an exception + """ + If no match is made when loading dataset , need to raise an exception """ raise RuntimeError("Cannot load %s data type" % key) - -def load_dataset_lookup(key): - """ What type of object are we trying to unpickle? This is a python - dictionary based equivalent of a case statement. It returns the type - a given 'type' keyword in the hickle file. - - Args: - py_obj: python object to look-up what function to use to dump to disk - - Returns: - match: function that should be used to dump data to a new dataset - """ - - match = hkl_types_dict.get(key, no_match_load) - - return match - - def load(file_obj, path='/', safe=True): """ Load the Python object stored in `file_obj` at `path` and return it. @@ -484,7 +288,7 @@ def load(file_obj, path='/', safe=True): File from which to load the object. If str, `file_obj` provides the path of the HDF5-file that must be used. - If :obj:`~h5py._hl.group.Group`, the group (or file) in an open + If :obj:`~h5py.Group`, the group (or file) in an open HDF5-file that must be used. path : str, optional Path within HDF5-file or group to load data from. @@ -544,20 +348,30 @@ def load(file_obj, path='/', safe=True): return(legacy_v3.load(file_obj, path, safe)) # Else, check if the proper attributes for v4 loading are available - elif all(map(h_root_group.attrs.get, v4_attrs)): + if all(map(h_root_group.attrs.get, v4_attrs)): # Load file - py_container = PyContainer() - py_container = _load(py_container, h_root_group['data']) - return(py_container[0]) + py_container = RootContainer(h_root_group.attrs,b'document_root',RootContainer) + pickle_loads = pickle.loads + hickle_version = h_root_group.attrs["HICKLE_VERSION"].split('.') + if int(hickle_version[0]) == 4 and int(hickle_version[1]) < 1: + # hickle 4.0.x file activate if legacy load fixes for 4.0.x + # eg. pickle of versions < 3.8 do not prevent dumping of lambda functions + # eventhough stated otherwise in documentation. Activate workarrounds + # just in case issues arrise. Especially as corresponding lambdas in + # load_numpy are not needed anymore and thus have been removed. + pickle_loads = fix_lambda_obj_type + _load(py_container, 'data',h_root_group['data'],pickle_loads = fix_lambda_obj_type,load_loader = load_legacy_loader) + return py_container.convert() + # 4.1.x file and newer + _load(py_container, 'data',h_root_group['data'],pickle_loads = pickle.loads,load_loader = load_loader) + return py_container.convert() # Else, raise error - else: # pragma: no cover - raise FileError("HDF5-file does not have the proper attributes!") + raise FileError("HDF5-file does not have the proper attributes!") # If this fails, raise error and provide user with caught error message - except Exception as error: # pragma: no cover - raise ValueError("Provided argument 'file_obj' does not appear to be a" - " valid hickle file! (%s)" % (error)) + except Exception as error: + raise ValueError("Provided argument 'file_obj' does not appear to be a valid hickle file! (%s)" % (error),error) from error finally: # Close the file if requested. 
# Closing a file twice will not cause any problems @@ -565,73 +379,62 @@ def load(file_obj, path='/', safe=True): h5f.close() -def load_dataset(h_node): - """ Load a dataset, converting into its correct python type - - Args: - h_node (h5py dataset): h5py dataset object to read - - Returns: - data: reconstructed python object from loaded data - """ - py_type, base_type = get_type(h_node) - load_loader(py_type) - - load_fn = load_dataset_lookup(base_type) - data = load_fn(h_node) - - # If data is not py_type yet, convert to it (unless it is pickle) - if base_type != b'pickle' and type(data) != py_type: - data = py_type(data) - return data - -def _load(py_container, h_group): +def _load(py_container, h_name, h_node,pickle_loads=pickle.loads,load_loader = load_loader): """ Load a hickle file Recursive funnction to load hdf5 data into a PyContainer() Args: py_container (PyContainer): Python container to load data into - h_group (h5 group or dataset): h5py object, group or dataset, to spider + h_name (string): the name of the resulting h5py object group or dataset + h_node (h5 group or dataset): h5py object, group or dataset, to spider and load all datasets. + pickle_loads (FunctionType,MethodType): defaults to pickle.loads and will + be switched to fix_lambda_obj_type if file to be loaded was created by + hickle 4.0.x version + load_loader (FunctionType,MethodType): defaults to lookup.load_loader and + will be switched to load_legacy_loader if file to be loaded was + created by hickle 4.0.x version """ + # load base_type of node. if not set assume that it contains + # pickled object data to be restored through load_pickled_data or + # PickledContainer object in case of group. + base_type = h_node.attrs.get('base_type',b'pickle') + if base_type == b'pickle': + # pickled dataset or group assume its object_type to be object + # as true object type is anyway handled by load_pickled_data or + # PickledContainer + py_obj_type = object + else: + # extract object_type and ensure loader beeing able to handle is loaded + # loading is controlled through base_type, object_type is just required + # to allow load_fn or py_subcontainer to properly restore and cast + # py_obj to proper object type + py_obj_type = pickle_loads(h_node.attrs.get('type',None)) + py_obj_type,_ = load_loader(py_obj_type) + # Either a file, group, or dataset - if isinstance(h_group, h5._hl.group.Group): - - py_subcontainer = PyContainer() - py_subcontainer.container_base_type = bytes(h_group.attrs['base_type']) - - py_subcontainer.name = h_group.name - - if py_subcontainer.container_base_type == b'dict_item': - py_subcontainer.key_base_type = h_group.attrs['key_base_type'] - py_obj_type = pickle.loads(h_group.attrs['key_type']) - py_subcontainer.key_type = py_obj_type - py_subcontainer.key_idx = h_group.attrs['key_idx'] - else: - py_obj_type = pickle.loads(h_group.attrs['type']) - py_subcontainer.container_type = py_obj_type - - # Check if we have an unloaded loader for the provided py_obj - load_loader(py_obj_type) - - if py_subcontainer.container_base_type not in types_not_to_sort: - h_keys = sort_keys(h_group.keys()) - else: - h_keys = h_group.keys() - - for h_name in h_keys: - h_node = h_group[h_name] - py_subcontainer = _load(py_subcontainer, h_node) - + if isinstance(h_node, h5.Group): + + py_container_class = hkl_container_dict.get(base_type,NoMatchContainer) + py_subcontainer = py_container_class(h_node.attrs,base_type,py_obj_type) + + # NOTE: Sorting of container items according to their key Name is + # to be handled by 
container class provided by loader only + # as loader has all the knowledge required to properly decide + # if sort is necessary and how to sort and at what stage to sort + for h_key,h_subnode in py_subcontainer.filter(h_node.items()): + _load(py_subcontainer, h_key, h_subnode, pickle_loads, load_loader) + + # finalize subitem and append to parent container. sub_data = py_subcontainer.convert() - py_container.append(sub_data) + py_container.append(h_name,sub_data,h_node.attrs) else: - # must be a dataset - subdata = load_dataset(h_group) - py_container.append(subdata) + # must be a dataset load it and append to parent container + load_fn = hkl_types_dict.get(base_type, no_match_load) + data = load_fn(h_node,base_type,py_obj_type) + py_container.append(h_name,data,h_node.attrs) - return py_container diff --git a/hickle/loaders/load_astropy.py b/hickle/loaders/load_astropy.py index 7857d78e..e5c59204 100644 --- a/hickle/loaders/load_astropy.py +++ b/hickle/loaders/load_astropy.py @@ -8,7 +8,6 @@ import numpy as np # hickle imports -from hickle.helpers import get_type_and_data # %% FUNCTION DEFINITIONS @@ -19,15 +18,16 @@ def create_astropy_quantity(py_obj, h_group, name, **kwargs): py_obj: python object to dump; should be a python type (int, float, bool etc) h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. + name (str): the name of the resulting dataset + + Returns: + dataset representing astropy quantity and empty subitems """ d = h_group.create_dataset(name, data=py_obj.value, dtype='float64', **kwargs) - unit = bytes(str(py_obj.unit), 'ascii') - d.attrs['unit'] = unit - return(d) + d.attrs['unit'] = py_obj.unit.to_string().encode('ascii') + return d,() def create_astropy_angle(py_obj, h_group, name, **kwargs): @@ -37,15 +37,16 @@ def create_astropy_angle(py_obj, h_group, name, **kwargs): py_obj: python object to dump; should be a python type (int, float, bool etc) h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. + name (str): the name of the resulting dataset + + Returns: + dataset representing astropy angle and empty subitems """ d = h_group.create_dataset(name, data=py_obj.value, dtype='float64', **kwargs) - unit = str(py_obj.unit).encode('ascii') - d.attrs['unit'] = unit - return(d) + d.attrs['unit'] = py_obj.unit.to_string().encode('ascii') + return d,() def create_astropy_skycoord(py_obj, h_group, name, **kwargs): @@ -55,20 +56,22 @@ def create_astropy_skycoord(py_obj, h_group, name, **kwargs): py_obj: python object to dump; should be a python type (int, float, bool etc) h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. 
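# Sketch of the unit round trip used by the astropy creators above: the unit is
# stored in the 'unit' attribute as its ASCII string form, from which astropy can
# rebuild an equivalent Unit on load (illustration only).
import astropy.units as u

quantity = 3.5 * u.km / u.s
stored = quantity.unit.to_string().encode('ascii')   # value written to d.attrs['unit']
restored_unit = u.Unit(stored.decode('ascii'))
assert restored_unit == quantity.unit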
+ name (str): the name of the resulting dataset + + Returns: + dataset representing astorpy SkyCoord and empty subitems """ - lat = py_obj.data.lat.value lon = py_obj.data.lon.value + lat = py_obj.data.lat.value dd = np.stack((lon, lat), axis=-1) d = h_group.create_dataset(name, data=dd, dtype='float64', **kwargs) - lon_unit = str(py_obj.data.lon.unit).encode('ascii') - lat_unit = str(py_obj.data.lat.unit).encode('ascii') + lon_unit = py_obj.data.lon.unit.to_string().encode('ascii') + lat_unit = py_obj.data.lat.unit.to_string().encode('ascii') d.attrs['lon_unit'] = lon_unit d.attrs['lat_unit'] = lat_unit - return(d) + return d,() def create_astropy_time(py_obj, h_group, name, **kwargs): @@ -78,28 +81,27 @@ def create_astropy_time(py_obj, h_group, name, **kwargs): py_obj: python object to dump; should be a python type (int, float, bool etc) h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. - """ + name (str): the name of the resulting dataset - data = py_obj.value - dtype = str(py_obj.value.dtype) + Returns: + dataset representing string astropy time and empty subitems + """ # Need to catch string times - if ' 64): - py_obj = bytes(str(py_obj), 'ascii') + return h_group.create_dataset(name,data = bytearray(str(py_obj), 'ascii'),**kwargs),() - d = h_group.create_dataset(name, data=py_obj, **kwargs) - return(d) + return h_group.create_dataset(name, data=py_obj, **no_compression(kwargs)),() def create_none_dataset(py_obj, h_group, name, **kwargs): @@ -89,81 +64,345 @@ def create_none_dataset(py_obj, h_group, name, **kwargs): Args: py_obj: python object to dump; must be None object h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. + name (str): the name of the resulting dataset + Returns: + correspoinding h5py.Dataset and empty subitems list """ - d = h_group.create_dataset(name, data=b'None', **kwargs) - return(d) + return h_group.create_dataset(name, data=bytearray(b'None'),**kwargs),() -def create_pickled_dataset(py_obj, h_group, name, reason=None, **kwargs): - """ If no match is made, raise a warning +def check_iterable_item_type(first_item,iter_obj): + """ + checks if for all items of an iterable sequence (list, tuple, etc.) a least common + dtype exists to which all items can be safely be casted. Args: - py_obj: python object to dump; default if item is not matched. 
+ first_item: the first item of the iterable sequence used to initialize the dtype + iter_obj: the remaing items of the iterable sequence + + Returns: + the least common dtype or none if not all items can be casted + """ + + if ( + operator.length_hint(first_item) > 1 or + ( operator.length_hint(first_item) == 1 and not isinstance(first_item,(str,bytes)) ) or + np.ndim(first_item) != 0 + ): + return None + dtype = np.dtype(first_item.__class__) + if dtype.name == 'object' or 'str' in dtype.name or ( 'bytes' in dtype.name and len(first_item) > 1): + return None + for item in iter_obj: + if np.ndim(item) != 0: + return None + common_dtype = np.result_type(np.dtype(item.__class__),dtype) + if common_dtype.name == 'object' or 'str' in common_dtype.name or ( 'bytes' in common_dtype.name and len(item) > 1 ): + return None + if dtype != common_dtype: + dtype = common_dtype + return dtype + +def create_listlike_dataset(py_obj, h_group, name,list_len = -1,item_dtype = None, **kwargs): + """ Dumper for list, set, tuple + + Args: + py_obj: python object to dump; should be list-like h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. + name (str): the name of the resulting dataset + + Returns: + Group or Dataset representing listlike object and a list of subitems toi + be stored within this group. In case of Dataset this list is allways empty """ - reason_str = " (Reason: %s)" % (reason) if reason is not None else "" - pickled_obj = pickle.dumps(py_obj) - d = h_group.create_dataset(name, data=np.array(pickled_obj), **kwargs) - warnings.warn("%r type not understood, data has been serialized%s" - % (py_obj.__class__.__name__, reason_str), SerializedWarning) - return(d) + if isinstance(py_obj,(str,bytes)): + # strings and bytes are stored as array of bytes with strings encoded using utf8 encoding + dataset = h_group.create_dataset( + name, + data = bytearray(py_obj,"utf8") if isinstance(py_obj,str) else bytearray(py_obj), + **kwargs + ) + dataset.attrs["str_type"] = py_obj.__class__.__name__.encode("ascii") + return dataset,() + + if len(py_obj) < 1: + # listlike object is empty just store empty dataset + return h_group.create_dataset(name,shape=None,dtype=int,**no_compression(kwargs)),() + + if list_len < 0: + # neither length nor dtype of items is know compute them now + item_dtype = check_iterable_item_type(py_obj[0],py_obj[1:]) + list_len = len(py_obj) + + if item_dtype or list_len < 1: + # create a dataset and mapp all items to least common dtype + shape = (list_len,) if list_len > 0 else None + dataset = h_group.create_dataset(name,shape = shape,dtype = item_dtype,**kwargs) + for index,item in enumerate(py_obj,0): + dataset[index] = item_dtype.type(item) + return dataset,() + + # crate group and provide generator yielding all subitems to be stored within + item_name = "data{:d}" + def provide_listlike_items(): + for index,item in enumerate(py_obj,0): + yield item_name.format(index),item,{"item_index":index},kwargs + h_subgroup = h_group.create_group(name) + h_subgroup.attrs["num_items"] = list_len + return h_subgroup,provide_listlike_items() + + +def create_setlike_dataset(py_obj,h_group,name,**kwargs): + """ + Creates a dataset or group for setlike objects. + Args: + py_obj: python object to dump; should be list-like + h_group (h5.File.group): group to dump data into. 
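# Behavioural sketch of the dtype probing performed by check_iterable_item_type
# above (illustration only; the exact integer dtype is platform dependent):
from hickle.loaders.load_builtins import check_iterable_item_type

check_iterable_item_type(1, iter([2, 3]))        # common integer dtype -> one dataset
check_iterable_item_type(1, iter([2.5, 3]))      # promoted to float64  -> one dataset
check_iterable_item_type((1, 2), iter([(3, 4)])) # nested items -> None, stored as a group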
+ name (str): the name of the resulting dataset -def load_list_dataset(h_node): - _, _, data = get_type_and_data(h_node) - str_type = h_node.attrs.get('str_type', None) + Returns: + Group or Dataset representing listlike object and a list of subitems toi + be stored within this group. In case of Dataset this list is allways empty + """ - if str_type == b'str': - return(np.array(data, copy=False, dtype=str).tolist()) - else: - return(data.tolist()) + # set objects do not support indexing thus determination of item dtype has to + # be handled specially. Call create_listlike_dataset for proper creation + # of corresponding dataset + if not py_obj: + # dump empty set + return h_group.create_dataset(name,data = list(py_obj),shape = None,dtype = int,**no_compression(kwargs)),() + set_iter = iter(py_obj) + first_item = next(set_iter) + item_dtype = check_iterable_item_type(first_item,set_iter) + return create_listlike_dataset(py_obj,h_group,name,list_len = len(py_obj),item_dtype = item_dtype,**kwargs) + +_byte_slashes = re.compile(b'[\\/]') +_str_slashes = re.compile(r'[\\/]') -def load_tuple_dataset(h_node): - data = load_list_dataset(h_node) - return tuple(data) +def create_dictlike_dataset(py_obj, h_group, name, **kwargs): + """ Creates a data group for each key in dictionary + + Notes: + This is a very important function which uses the recursive _dump + method to build up hierarchical data models stored in the HDF5 file. + As this is critical to functioning, it is kept in the main hickle.py + file instead of in the loaders/ directory. + + Args: + py_obj: python object to dump; should be dictionary + h_group (h5.File.group): group to dump data into. + name (str): h5 node name + iterable. + """ + + h_dictgroup = h_group.create_group(name) + key_value_pair_name = "data{:d}" + + def package_dict_items(): + """ + generator yielding appropriate parameters for dumping each + dict key value pair + """ + for idx, (key, py_subobj) in enumerate(py_obj.items()): + # Obtain the raw string representation of this key + key_base_type = key.__class__.__name__.encode("utf8") + if isinstance(key,str): + if not _str_slashes.search(key): + yield r'"{}"'.format(key),py_subobj,{'key_idx':idx,'key_base_type':key_base_type},kwargs + continue + elif isinstance(key,bytes): + if not _byte_slashes.search(key): + try: + h_key = key.decode("utf8") + except UnicodeError: # pragma nocover + pass + else: + yield r'b"{}"'.format(h_key),py_subobj,{'key_idx':idx,'key_base_type':key_base_type},kwargs + continue + elif key_base_type in dict_key_types_dict: + h_key = "{!r}".format(key) + if not _str_slashes.search(h_key): + yield h_key,py_subobj,{'key_idx':idx,'key_base_type':key_base_type},kwargs + continue + sub_node_name = key_value_pair_name.format(idx) + yield sub_node_name,(key,py_subobj),{'key_idx':idx,'key_base_type':b'key_value'},kwargs + return h_dictgroup,package_dict_items() + + + +def load_scalar_dataset(h_node,base_type,py_obj_type): + """ + loads scalar dataset + + Args: + h_node (h5py.Dataset): the hdf5 node to load data from + base_type (bytes): bytes string denoting base_type + py_obj_type: final type of restored scalar + + Returns: + resulting python object of type py_obj_type + """ + data = h_node[()] if h_node.size < 2 else bytearray(h_node[()]) -def load_set_dataset(h_node): - data = load_list_dataset(h_node) - return set(data) + return py_obj_type(data) if data.__class__ is not py_obj_type else data -def load_none_dataset(h_node): +def load_none_dataset(h_node,base_type,py_obj_type): + """ + returns None 
value as represented by underlying dataset + """ return None + +def load_list_dataset(h_node,base_type,py_obj_type): + """ + loads any kind of list like dataset + Args: + h_node (h5py.Dataset): the hdf5 node to load data from + base_type (bytes): bytes string denoting base_type + py_obj_type: final type of restored scalar -def load_pickled_data(h_node): - _, _, data = get_type_and_data(h_node) - return pickle.loads(data) + Returns: + resulting python object of type py_obj_type + """ + if h_node.shape is None: + # empty list tuple or set just return new instance of py_obj_type + return py_obj_type() if isinstance(py_obj_type,tuple) else py_obj_type(()) + str_type = h_node.attrs.get('str_type', None) + content = h_node[()] + if str_type == b'str': -def load_scalar_dataset(h_node): - _, base_type, data = get_type_and_data(h_node) + if "bytes" in h_node.dtype.name: + # string dataset 4.0.x style convert it back to python string + content = np.array(content, copy=False, dtype=str).tolist() + else: + # decode bytes representing python string before final conversion + content = bytes(content).decode("utf8") + return py_obj_type(content) if content.__class__ is not py_obj_type else content - if(base_type == b'int'): - data = int(data) +class ListLikeContainer(PyContainer): + """ + PyContainer for all list like objects excempt set + """ - return(data) + __slots__ = () + + # regular expression used to extract index value from name of group or dataset + # representing subitem appended to the final list + extract_index = re.compile(r'\d+$') + + # as None can and may be a valid list entry define an alternative marker for + # missing items and indices + def __init__(self,h5_attrs,base_type,object_type): + # if number of items is defind upon group resize content to + # at least match this ammount of subitems + num_items = h5_attrs.get('num_items',0) + super(ListLikeContainer,self).__init__(h5_attrs,base_type,object_type,_content = [nobody_is_my_name] * num_items) + + def append(self,name,item,h5_attrs): + # load item index from attributes if known else extract it from name + index = h5_attrs.get("item_index",None) + if index is None: + index_match = self.extract_index.search(name) + if index_match is None: + if item is nobody_is_my_name: + # dummy data injected likely by load_nothing, ignore it + return + raise KeyError("List like item name '{}' not understood".format(name)) + index = int(index_match.group(0)) + # if index exceeds capacity of extend list apropriately + if len(self._content) <= index: + self._content.extend([nobody_is_my_name] * ( index - len(self._content) + 1 )) + if self._content[index] is not nobody_is_my_name: + raise IndexError("Index {} already set".format(index)) + self._content[index] = item + + def convert(self): + return self._content if self.object_type is self._content.__class__ else self.object_type(self._content) + +class SetLikeContainer(PyContainer): + """ + PyContainer for all set like objects. 
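# Sketch of how ListLikeContainer above reassembles a list: every sub item carries
# its position either in the 'item_index' attribute or as the trailing digits of its
# 'data<N>' name, so items may be appended in any order (in-memory illustration only).
from hickle.loaders.load_builtins import ListLikeContainer

container = ListLikeContainer({'num_items': 3}, b'list', list)
container.append('data1', 'b', {'item_index': 1})
container.append('data0', 'a', {})                 # index recovered from the name
container.append('data2', 'c', {'item_index': 2})
assert container.convert() == ['a', 'b', 'c']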
+ """ + __slots__ = () + + def __init__(self,h5_attrs, base_type, object_type): + super(SetLikeContainer,self).__init__(h5_attrs,base_type,object_type,_content=set()) + + + def append(self,name,item,h5_attrs): + self._content.add(item) + + def convert(self): + return self._content if self._content.__class__ is self.object_type else self.object_type(self._content) + +class DictLikeContainer(PyContainer): + """ + PyContainer for all dict like objects + """ + __slots__ = () + + + _swap_key_slashes = re.compile(r"\\") + + def append(self,name,item,h5_attrs): + key_base_type = h5_attrs.get('key_base_type',b'') + if key_base_type == b'str': + item = ( + name[1:-1] if name[0] == '"' else self._swap_key_slashes.sub(r'/',name)[1:-1], + item + ) + elif key_base_type == b'bytes': + item = ( + name[2:-1].encode("utf8") if name[:2] == 'b"' else self._swap_key_slashes.sub(r'/',name)[1:-1], + item + ) + elif not key_base_type == b'key_value': + load_key = dict_key_types_dict.get(key_base_type,None) + if load_key is None: + if key_base_type not in {b'tuple'}: + raise ValueError("key type '{}' not understood".format(key_base_type.decode("utf8"))) + load_key = eval + item = ( + load_key(self._swap_key_slashes.sub(r'/',name)), + item + ) + key_index = h5_attrs.get('key_idx',None) + if key_index is None: + if item[1] is nobody_is_my_name: + # dummy data injected most likely by load_nothing ignore it + return + raise KeyError("invalid dict item key_index missing") + if len(self._content) <= key_index: + self._content.extend([nobody_is_my_name] * ( key_index - len(self._content) + 1)) + if self._content[key_index] is not nobody_is_my_name: + raise IndexError("Key index {} already set".format(key_index)) + self._content[key_index] = item + + def convert(self): + return self.object_type(self._content) + + # %% REGISTERS class_register = [ - [list, b"list", create_listlike_dataset, load_list_dataset], - [tuple, b"tuple", create_listlike_dataset, load_tuple_dataset], - [set, b"set", create_listlike_dataset, load_set_dataset], - [bytes, b"bytes", create_scalar_dataset, load_scalar_dataset], - [str, b"str", create_scalar_dataset, load_scalar_dataset], + [list, b"list", create_listlike_dataset, load_list_dataset,ListLikeContainer], + [tuple, b"tuple", create_listlike_dataset, load_list_dataset,ListLikeContainer], + [dict, b"dict",create_dictlike_dataset,None,DictLikeContainer], + [set, b"set", create_setlike_dataset, load_list_dataset,SetLikeContainer], + [bytes, b"bytes", create_listlike_dataset, load_list_dataset], + [str, b"str", create_listlike_dataset, load_list_dataset], [int, b"int", create_scalar_dataset, load_scalar_dataset], [float, b"float", create_scalar_dataset, load_scalar_dataset], [complex, b"complex", create_scalar_dataset, load_scalar_dataset], [bool, b"bool", create_scalar_dataset, load_scalar_dataset], - [type(None), b"None", create_none_dataset, load_none_dataset], - [object, b"pickle", create_pickled_dataset, load_pickled_data]] + [None.__class__, b"None", create_none_dataset, load_none_dataset] +] exclude_register = [] diff --git a/hickle/loaders/load_numpy.py b/hickle/loaders/load_numpy.py index 38c40a86..de9487c4 100644 --- a/hickle/loaders/load_numpy.py +++ b/hickle/loaders/load_numpy.py @@ -9,27 +9,14 @@ # %% IMPORTS # Package imports import numpy as np -import dill as pickle +import types # hickle imports -from hickle.helpers import get_type_and_data -from hickle.hickle import _dump +from hickle.loaders.load_builtins import create_listlike_dataset,ListLikeContainer +from hickle.helpers 
import PyContainer,no_compression # %% FUNCTION DEFINITIONS -def check_is_numpy_array(py_obj): - """ Check if a python object is a numpy array (masked or regular) - - Args: - py_obj: python object to check whether it is a numpy array - - Returns - is_numpy (bool): Returns True if it is a numpy array, else False if it - isn't - """ - - return(isinstance(py_obj, np.ndarray)) - def create_np_scalar_dataset(py_obj, h_group, name, **kwargs): """ dumps an np dtype object to h5py file @@ -38,14 +25,16 @@ def create_np_scalar_dataset(py_obj, h_group, name, **kwargs): py_obj: python object to dump; should be a numpy scalar, e.g. np.float16(1) h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. + name (str): the name of the resulting dataset + + Returns: + Dataset and empty list of subitems """ - d = h_group.create_dataset(name, data=py_obj, **kwargs) + d = h_group.create_dataset(name, data=py_obj, **no_compression(kwargs)) - d.attrs["np_dtype"] = bytes(str(d.dtype), 'ascii') - return(d) + d.attrs["np_dtype"] = py_obj.dtype.str.encode("ascii") + return d,() def create_np_dtype(py_obj, h_group, name, **kwargs): @@ -55,11 +44,13 @@ def create_np_dtype(py_obj, h_group, name, **kwargs): py_obj: python object to dump; should be a numpy dtype, e.g. np.float16 h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. + name (str): the name of the resulting dataset + + Returns: + Dataset and empty list of subitems """ - d = h_group.create_dataset(name, data=str(py_obj), **kwargs) - return(d) + d = h_group.create_dataset(name, data=bytearray(py_obj.str,"ascii"), **kwargs) + return d,() def create_np_array_dataset(py_obj, h_group, name, **kwargs): @@ -69,88 +60,176 @@ def create_np_array_dataset(py_obj, h_group, name, **kwargs): py_obj: python object to dump; should be a numpy array or np.ma.array (masked) h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the - iterable. + name (str): the name of the resulting dataset or group + + Returns: + Datset and empty list of subitems or Group and iterable of subitems """ # Obtain dtype of py_obj - dtype = str(py_obj.dtype) + dtype = py_obj.dtype # Check if py_obj contains strings - if ' 1: + + # object could be successfully reduced. Create a dedicated + # object group. Prepare all items of the returned tuple which + # are not None to be dumped into this group. The first two items + # 'create' and 'args' are according to pickle documentation + # manadatory. The remaining ones 'state', 'set', 'list', 'keys' and + # 'values' are optional. They shall be ommited in case either not + # present or None. + # 'create' ... the method to be called upon load to restore the object + # 'args' ..... the arguments to be passed to the 'create' method + # 'state' .... state value to be passed to __setstate__ or 'set' method + # 'set' ...... method to called instead of __setstate__ of object + # 'list' ..... tuple representing the items provided by list iterator + # 'keys' ..... tuple representing dict keys provided by dict iterator + # 'values' ... 
tuple representing dict values provided by dict iterator + object_group = h_group.create_group(name) + subitems = [ + ('create',reduced_obj[0],{},kwargs), + ('args',reduced_obj[1],{},kwargs), + None,None,None,None,None + ] + if len(reduced_obj) < 3: + return object_group,subitems[:2] + next_set = 2 + if reduced_obj[2] is not None: + subitems[next_set] = ('state',reduced_obj[2],{},kwargs) + next_set += 1 + if len(reduced_obj) > 5: + subitems[next_set] =('set',reduced_obj[5],{},kwargs) + next_set += 1 + if len(reduced_obj) < 4: + return object_group,subitems[:next_set] + if reduced_obj[3] is not None: + subitems[next_set] = ('list',tuple(reduced_obj[3]),{},kwargs) + next_set += 1 + if len(reduced_obj) < 5 or reduced_obj[4] is None: + return object_group,subitems[:next_set] + keys,values = zip(*reduced_obj[4]) if operator.length_hint(py_obj) > 0 else ((),()) + subitems[next_set] = ('keys',keys,{},kwargs) + next_set += 1 + subitems[next_set] = ('values',values,{},kwargs) + return object_group,subitems[:next_set+1] + + # for what ever reason py_obj could not be successfully reduced + # ask pickle for help and report to user. + reason_str = " (Reason: %s)" % (reason) if reason is not None else "" + warnings.warn( + "{!r} type not understood, data is serialized:{:s}".format( + py_obj.__class__.__name__, reason_str + ), + SerializedWarning + ) + + # store object as pickle string + pickled_obj = pickle.dumps(py_obj) + d = h_group.create_dataset(name, data=bytearray(pickled_obj), **kwargs) + return d,() + +def load_pickled_data(h_node, base_type, py_obj_type): + """ + loade pickle string and return resulting py_obj + """ + + return pickle.loads(h_node[()]) + +class PickledContainer(PyContainer): + """ + PyContainer handling restore of object from object group + """ + + _notset = () + + def __init__(self,h5_attrs, base_type, object_type): + super(PickledContainer,self).__init__(h5_attrs,base_type,object_type,_content = dict()) + + def append(self,name,item,h5_attrs): + self._content[name] = item + + def convert(self): + + # create the python object + py_obj = self._content['create'](*self._content['args']) + state = self._content.get('state',self._notset) + + if state is not self._notset: + # restore its state + set_state = self._content.get('set',None) + if set_state is None: + set_state = getattr(py_obj.__class__,'__setstate__',None) + if set_state is None: + if isinstance(state,dict): + object_dict = getattr(py_obj,'__dict__',None) + if object_dict is not None: + object_dict.update(state) + elif not isinstance(state,bool) or state: + set_state(py_obj,state) + list_iter = self._content.get('list',None) + if list_iter is not None: + # load any list values + if len(list_iter) < 2: + py_obj.append(list_iter[0]) + else: + extend = getattr(py_obj,'extend',None) + if extend is not None: + extend(list_iter) + else: + for item in list_iter: + py_obj.append(item) + + # load any dict values + keys = self._content.get('keys',None) + if keys is None: + return py_obj + values = self._content.get('values',None) + if values is None: + return py_obj + for key,value in zip(keys,values): + py_obj[key] = value + return py_obj + +# no dump method is registered for object as this is the default for +# any unknown object and for classes, functions and methods +register_class(object,b'pickle',None,load_pickled_data,PickledContainer) + + +def _moc_numpy_array_object_lambda(x): + """ + drop in replacement for lambda object types which seem not + any more be accepted by pickle for Python 3.8 and onward. 
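# Illustration of how the object group written by create_pickled_dataset above maps
# onto the __reduce_ex__ tuple; the "Point" class is a hypothetical example.
# PickledContainer.convert() later feeds these named sub items back into the
# reconstructor.
class Point:
    def __init__(self, x=0, y=0):
        self.x, self.y = x, y

reduced = Point(1, 2).__reduce_ex__(4)
# reduced[0] -> 'create' sub item: callable invoked to rebuild the instance
# reduced[1] -> 'args'   sub item: arguments passed to that callable
# reduced[2] -> 'state'  sub item: here the instance __dict__, {'x': 1, 'y': 2}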
+ see fix_lambda_obj_type function below + + Parameters: + ----------- + x (list): itemlist from which to return first element + + Returns: + first element of provided list + """ + return x[0] + +register_class(_moc_numpy_array_object_lambda,b'moc_lambda',dump_nothing,load_nothing) + +def fix_lambda_obj_type(bytes_object, *, fix_imports=True, encoding="ASCII", errors="strict"): + """ + drop in replacement for pickle.loads method when loading files created by hickle 4.0.x + It captures any TypeError thrown by pickle.loads when encountering a picle string + representing a lambda function used as py_obj_type for a h5py.Dataset or h5py.Group + While in Python <3.8 pickle loads creates the lambda Python >= 3.8 throws an + error when encountering such a pickle string. This is captured and _moc_numpy_array_object_lambda + returned instead. futher some h5py.Group and h5py.Datasets do not provide any + py_obj_type for them object is returned assuming that proper loader has been identified + by other objects already + """ + if bytes_object is None: + return object + try: + return pickle.loads(bytes_object,fix_imports=fix_imports,encoding=encoding,errors=errors) + except TypeError: + print("reporting ",_moc_numpy_array_object_lambda) + return _moc_numpy_array_object_lambda diff --git a/hickle/tests/generate_legacy_4_0_0.py b/hickle/tests/generate_legacy_4_0_0.py new file mode 100644 index 00000000..18166154 --- /dev/null +++ b/hickle/tests/generate_legacy_4_0_0.py @@ -0,0 +1,155 @@ +#! /usr/bin/env python +# encoding: utf-8 +""" +# generate_legacy_4_0_0.py + +Creates datastructure to be dumped to the hickle_4_0_0.hkl file. + +When run as script under hickle 4.0.0 or hickle 4.0.1 it will +result in a valid legacy 4.0.0 file which can be used to tests +that later version are still capable loading hickle 4.0.0 format +files. + +When imported by any of the tests the method generate_py_object +returns the same datasstructure stored to the prior generated file. + +""" +import sys +sys.path.insert(0,"../..") +import hickle +import numpy as np +import scipy +import scipy.sparse +import astropy +import collections +import dill as pickle +import os.path + +def generate_py_object(): + """ + create a data structure covering all or at least the most obvious, + prominent and most likely breaking differences betwen hickle + 4.0.0/4.0.1 version and Versions > 4.1.0 + + Returns: + listobject containing all the relevant data objects and the + filename of the file the data has been stored to or shall be + stored to. 
+ """ + scriptdir = os.path.split(__file__)[0] + some_string = "this is some string to be dumped by hickle 4.0.0" + some_bytes = b"this is the same in bytes instead of utf8" + some_char_list = list(some_string) + some_bytes_list = list(some_bytes) + some_numbers = tuple(range(50)) + some_floats = tuple( float(f) for f in range(50)) + mixed = list( f for f in ( some_numbers[i//2] if i & 1 else some_floats[i//2] for i in range(100) ) ) + wordlist = ["hello","world","i","like","you"] + byteslist = [ s.encode("ascii") for s in wordlist] + mixus = [some_string,some_numbers,12,11] + numpy_array = np.array([ + [ + 0.8918443906408066, 0.5408942506873636, 0.43463333793335346, 0.21382281373491407, + 0.14580527098359963, 0.6869306139451369, 0.22954988509310692, 0.2833880251470392, + 0.8811201329390297, 0.4144190218983931, 0.06595369247674943 + ], [ + 0.8724300029833221, 0.7173303189807705, 0.5721666862018427, 0.8535567654595188, + 0.5806566016388102, 0.9921250367638187, 0.07104048226766191, 0.47131100732975095, + 0.8006065068241431, 0.2804909335297441, 0.1968823602346148 + ], [ + 0.0515177648326276, 0.1852582437284651, 0.22016412062225577, 0.6393104121476216, + 0.7751103631149562, 0.12810902186723572, 0.09634877693000932, 0.2388423061420949, + 0.5730001119950099, 0.1197268172277629, 0.11539619086292308 + ], [ + 0.031751102230864414, 0.21672180477587166, 0.4366501648161476, 0.9549518596659471, + 0.42398684476912474, 0.04490851499559967, 0.7394234049135264, 0.7378312792413693, + 0.9808812550712923, 0.2488404519024885, 0.5158454824458993 + ], [ + 0.07550969197984403, 0.08485317435746553, 0.15760274251917195, 0.18029979414515496, + 0.9501707036126847, 0.1723868250469468, 0.7951538687631865, 0.2546219217084682, + 0.9116518509985955, 0.6930255788272572, 0.9082828280630456 + ], [ + 0.6712307672376565, 0.367223385378443, 0.9522931417348294, 0.714592360187415, + 0.18334824241062575, 0.9322238504996762, 0.3594776411821822, 0.6302097368268973, + 0.6281766915388312, 0.7114942437206809, 0.6977764481953693 + ], [ + 0.9541502922560433, 0.47788295940203784, 0.6511716236981558, 0.4079446664375711, + 0.2747969334307605, 0.3571662787734283, 0.10235638316970186, 0.8567343897483571, + 0.6623468654315807, 0.21377047332104315, 0.860146852430476 + ] + ]) + mask = np.array([ + [0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], + [1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1], + [0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0], + [1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0], + [0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1], + [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1], + [0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1] + ]) + + numpy_array_masked = np.ma.array(numpy_array, dtype='float32', mask=mask) + plenty_dict = { + "string":1, + b'bytes':2, + 12:3, + 0.55:4, + complex(1,4):5, + (1,):6, + tuple(mixus):7, + ():8, + '9':9, + None:10, + 'a/b':11 + } + odrdered_dict = collections.OrderedDict(((3, [3, 0.1]), (7, [5, 0.1]), (5, [3, 0.1]))) + + row = np.array([0, 0, 1, 2, 2, 2]) + col = np.array([0, 2, 2, 0, 1, 2]) + data = np.array([1, 2, 3, 4, 5, 6]) + csr_matrix = scipy.sparse.csr_matrix((data, (row, col)), shape=(3, 3)) + csc_matrix = scipy.sparse.csc_matrix((data, (row, col)), shape=(3, 3)) + + indptr = np.array([0, 2, 3, 6]) + indices = np.array([0, 2, 2, 0, 1, 2]) + data = np.array([1, 2, 3, 4, 5, 6]).repeat(4).reshape(6, 2, 2) + bsr_matrix = scipy.sparse.bsr_matrix((data, indices, indptr), shape=(6, 6)) + numpy_string = np.array(some_string) + numpy_bytes = np.array(some_bytes) + numpy_wordlist = np.array(wordlist) + numpy_dict = np.array({}) + + return [ + some_string , + some_bytes , + some_char_list , + 
        some_bytes_list ,
+        some_numbers ,
+        some_floats ,
+        mixed ,
+        wordlist ,
+        byteslist ,
+        mixus ,
+        numpy_array ,
+        mask ,
+        numpy_array_masked ,
+        plenty_dict ,
+        odrdered_dict ,
+        csr_matrix ,
+        csc_matrix ,
+        bsr_matrix ,
+        numpy_string ,
+        numpy_bytes ,
+        numpy_wordlist ,
+        numpy_dict
+    ],os.path.join(scriptdir,"legacy_hkls","hickle_4.0.0.hkl")
+
+if __name__ == '__main__':
+    # create the file by dumping using hickle but only if
+    # the available hickle version is >= 4.0.0 and < 4.1.0
+    hickle_version = hickle.__version__.split('.')
+    if hickle_version[0] != '4' or hickle_version[1] != '0':
+        raise RuntimeError("Shall be run using hickle 4.0.x only")
+    scriptdir = os.path.split(__file__)[0]
+    now_dumping,testfile = generate_py_object()
+    hickle.dump(now_dumping,testfile)
diff --git a/hickle/tests/hickle_loaders/__init__.py b/hickle/tests/hickle_loaders/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/hickle/tests/hickle_loaders/load_builtins.py b/hickle/tests/hickle_loaders/load_builtins.py
new file mode 100644
index 00000000..7e64ebb6
--- /dev/null
+++ b/hickle/tests/hickle_loaders/load_builtins.py
@@ -0,0 +1,12 @@
+def create_package_test(myclass_type,h_group,name,**kwargs):
+    return h_group,()
+
+def load_package_test(h_node,base_type,py_obj_type):
+    return {12:12}
+
+
+class_register = [
+    ( dict,b'dict',create_package_test,load_package_test )
+]
+
+exclude_register = [b'please_kindly_ignore_me']
diff --git a/hickle/tests/legacy_hkls/hickle_4.0.0.hkl b/hickle/tests/legacy_hkls/hickle_4.0.0.hkl
new file mode 100644
index 00000000..aa10cdb4
Binary files /dev/null and b/hickle/tests/legacy_hkls/hickle_4.0.0.hkl differ
diff --git a/hickle/tests/test_01_hickle_helpers.py b/hickle/tests/test_01_hickle_helpers.py
new file mode 100644
index 00000000..bdb308c3
--- /dev/null
+++ b/hickle/tests/test_01_hickle_helpers.py
@@ -0,0 +1,150 @@
+#! /usr/bin/env python
+# encoding: utf-8
+"""
+# test_hickle_helpers.py
+
+Unit tests for hickle module -- helper functions.
+ +""" + +import pytest + +# %% IMPORTS +# Package imports +import numpy as np +import dill as pickle +import operator + +# hickle imports +from hickle.helpers import PyContainer,H5NodeFilterProxy,no_compression +from py.path import local + +# Set current working directory to the temporary directory +local.get_temproot().chdir() + + +# %% DATA DEFINITIONS + +dummy_data = (1,2,3) + + +# %% FIXTURES + +@pytest.fixture +def h5_data(request): + """ + create dummy hdf5 test data file for testing PyContainer and H5NodeFilterProxy + """ + + # create file and create a dataset the attributes of which will lateron be + # modified + import h5py as h5 + dummy_file = h5.File('hickle_helpers_{}.hdf5'.format(request.function.__name__),'w') + filename = dummy_file.filename + test_data = dummy_file.create_dataset("somedata",data=dummy_data,dtype='i') + test_data.attrs['type'] = np.array(pickle.dumps(tuple)) + test_data.attrs['base_type'] = b'tuple' + test_data.attrs['someattr'] = 12 + test_data.attrs['someother'] = 11 + + # writeout the file reopen it read only + dummy_file.flush() + dummy_file.close() + dummy_file = h5.File(filename,'r') + + # provide the file and close afterwards + yield dummy_file + dummy_file.close() + +# %% FUNCTION DEFINITIONS + +def test_no_compression(): + """ + test no_compression filter for temporarily hiding comression related + kwargs from h5py.create_dataset method + """ + + # simulate kwargs without compression related + kwargs = {'hello':1,'word':2} + assert dict(no_compression(kwargs)) == kwargs + + # simulate kwargs including all relevant keyword arguments + kwargs2 = dict(kwargs) + kwargs2.update({ + "compression":True, + "shuffle":True, + "compression_opts":8, + "chunks":512, + "fletcher32":True, + "scaleoffset":20 + }) + assert dict(no_compression(kwargs2)) == kwargs + +def test_py_container(h5_data): + """ + test abstract PyContainer base class defining container interface + and providing default implementations for append and filter + """ + + # test default implementation of append + container = PyContainer({},b'list',list) + container.append('data0',1,{}) + container.append('data1','b',{}) + + # ensure that default implementation of convert enforces overload by + # derived PyContainer classes by raising NotImplementedError + with pytest.raises(NotImplementedError): + my_list = container.convert() + + # test default implementation of PyContainer.filter method which + # simply shall yield from passed in itrator + assert [ item for item in dummy_data ] == list(dummy_data) + assert dict(container.filter(h5_data.items())) == {'somedata':h5_data['somedata']} + + +def test_H5NodeFilterProxy(h5_data): + """ + tests H5NodeFilterProxy class. This class allows to temporarily rewrite + attributes of h5py.Group and h5py.Dataset nodes before beeing loaded by + hickle._load method. + """ + + # load data and try to directly modify 'type' and 'base_type' Attributes + # which will fail cause hdf5 file is opened for read only + h5_node = h5_data['somedata'] + with pytest.raises(OSError): + h5_node.attrs['type'] = pickle.dumps(list) + with pytest.raises(OSError): + h5_node.attrs['base_type'] = b'list' + + # verify that 'type' expands to tuple before running + # the remaining tests + object_type = pickle.loads(h5_node.attrs['type']) + assert object_type is tuple + assert object_type(h5_node[()].tolist()) == dummy_data + + # Wrap node by H5NodeFilterProxy and rerun the above tests + # again. This time modifying Attributes shall be possible. 
+ h5_node = H5NodeFilterProxy(h5_node) + h5_node.attrs['type'] = pickle.dumps(list) + h5_node.attrs['base_type'] = b'list' + object_type = pickle.loads(h5_node.attrs['type']) + assert object_type is list + + # test proper pass through of item and attribute access + # to wrapped h5py.Group or h5py.Dataset object respective + assert object_type(h5_node[()].tolist()) == list(dummy_data) + assert h5_node.shape == np.array(dummy_data).shape + with pytest.raises(AttributeError,match = r"can't\s+set\s+attribute"): + h5_node.dtype = np.float32 + +# %% MAIN SCRIPT +if __name__ == "__main__": + from _pytest.fixtures import FixtureRequest + + test_no_compression() + for data in h5_data(FixtureRequest(test_py_container)): + test_py_container(data) + for data in h5_data(FixtureRequest(test_py_container)): + test_H5NodeFilterProxy(data) + diff --git a/hickle/tests/test_02_hickle_lookup.py b/hickle/tests/test_02_hickle_lookup.py new file mode 100644 index 00000000..1a963fb1 --- /dev/null +++ b/hickle/tests/test_02_hickle_lookup.py @@ -0,0 +1,698 @@ +#! /usr/bin/env python +# encoding: utf-8 +""" +# test_hickle_lookup.py + +Unit tests for hickle module -- lookup functions. + +""" + +# %% IMPORTS +import pytest +import sys +import types + +# Package imports +import numpy as np +import h5py +import dill as pickle +from importlib.util import find_spec +from importlib import reload +from py.path import local + +# hickle imports +from hickle.helpers import PyContainer,not_dumpable +import hickle.lookup as lookup + +# Set current working directory to the temporary directory +local.get_temproot().chdir() + +# %% DATA DEFINITIONS + +dummy_data = (1,2,3) + + +# %% FIXTURES + +@pytest.fixture +def h5_data(request): + """ + create dummy hdf5 test data file for testing PyContainer and H5NodeFilterProxy + uses name of executed test as part of filename + """ + + dummy_file = h5py.File('hickle_lookup_{}.hdf5'.format(request.function.__name__),'w') + filename = dummy_file.filename + test_data = dummy_file.create_group("root_group") + yield test_data + dummy_file.close() + +@pytest.fixture() +def loader_table(): + + """ + create a class_register and a exclude_register table for testing + register_class and register_class_exclude functions + + 0: dataset only loader + 1: PyContainer only loader + 2: not dumped loader + 3: external loader module trying to overwrite hickle core loader + 4: hickle loader moudle trying to overload hickle core loader + 3: loader defined by hickle core + + """ + + # clear loaded_loaders, types_dict, hkl_types_dict and hkl_contianer_dict + # to ensure no loader preset by hickle core or hickle loader module + # intervenes with test + global lookup + lookup.loaded_loaders.clear() + lookup.types_dict.clear() + lookup.hkl_types_dict.clear() + lookup.hkl_container_dict.clear() + + # simulate loader definitions found within loader modules + def create_test_dataset(myclass_type,h_group,name,**kwargs): + return h_group,() + + def load_test_dataset(h_node,base_type,py_obj_type): + return 12 + + class TestContainer(PyContainer): + def convert(self): + return self._content[0] + + class NotHicklePackage(TestContainer): + """ + checks if container_class provided by module outside + hickle package tries to define alternative loader for + IteratorProxy class handled by hickle core directly + """ + __module__ = "nothickle.loaders.load_builtins" + + class HickleLoadersModule(TestContainer): + """ + checks if container_class provided by + hickle.loaders module tries to define alternative loader for + IteratorProxy 
class handled by hickle core directly + """ + __module__ = "hickle.loaders.load_builtins" + + class IsHickleCore(TestContainer): + """ + Simulates loader registered by hickle.hickle module + """ + __module__ = "hickle.hickle" + + # provide the table + yield [ + (int,b'int',create_test_dataset,load_test_dataset,None), + (list,b'list',create_test_dataset,None,TestContainer), + (tuple,b'tuple',None,load_test_dataset,TestContainer), + (lookup._DictItem,b'dict_item',None,None,NotHicklePackage), + (lookup._DictItem,b'pickle',None,None,HickleLoadersModule), + (lookup._DictItem,b'dict_item',lookup.register_class,None,IsHickleCore) + ] + + # cleanup and reload hickle.lookup module to reset it to its initial state + # in case hickle.hickle has already been preloaded by pytest also reload it + # to ensure no sideffectes occur during later tests + lookup.loaded_loaders.clear() + lookup.types_dict.clear() + lookup.hkl_types_dict.clear() + lookup.hkl_container_dict.clear() + reload(lookup) + lookup = sys.modules[lookup.__name__] + hickle_hickle = sys.modules.get("hickle.hickle",None) + if hickle_hickle is not None: + reload(hickle_hickle) + +# %% CLASS DEFINITIONS + +class ToBeInLoadersOrNotToBe(): + """ + Dummy class used to check that only loaders for Python objects + are accepted by load_loader which are either declared + outside hickle or are pre registered by hickle core through directly + calling register_class or are declared by a load_.py module + within the pickle.loaders package + + Also it is used in simulating reduced object tuple with all trailing + None items removed + """ + __slots__ = () + + def __reduce_ex__(self,proto = pickle.DEFAULT_PROTOCOL): + reduced = super(ToBeInLoadersOrNotToBe,self).__reduce_ex__(proto) + for index,item in enumerate(reduced[:1:-1],0): + if item is not None: + return reduced[:(-index if index > 0 else None)] + return reduced + + def __reduce__(self): + reduced = super(ToBeInLoadersOrNotToBe,self).__reduce__() + for index,item in enumerate(reduced[:1:-1],0): + if item is not None: + return reduced[:(-index if index > 0 else None)] + return reduced + + def __eq__(self,other): + return other.__class__ is self.__class__ + + def __ne__(self,other): + return self != other + + +class MetaClassToDump(type): + """ + Metaclass for ClassToDump allowing to controll which + unbound class methods and magic methods are visible to + create_pickled_dataset method and which not at the + class level + """ + + # any function listed therein is not defined on class + # when called the next time (single shot) + hide_special = set() + + def __getattribute__(self,name): + if name in MetaClassToDump.hide_special: + MetaClassToDump.hide_special.remove(name) + raise AttributeError("") + return super(MetaClassToDump,self).__getattribute__(name) + + +class ClassToDump(metaclass=MetaClassToDump): + """ + Primary class used to test create_pickled_dataset function + """ + def __init__(self,hallo,welt,with_default=1): + self._data = hallo,welt,with_default + + def dump_boundmethod(self): + """ + dummy instance method used to check if instance methods are + either rejected or allways stored as pickle string + """ + pass + + @staticmethod + def dump_staticmethod(): + """ + dummy static method used to check if static methods are allways + stored as pickle string + """ + pass + + @classmethod + def dump_classmethod(cls): + """ + dummy class method used to check if class methods are allways + stored as pickle string + """ + pass + + def __eq__(self,other): + return other.__class__ is self.__class__ 
and self._data == other._data + + def __ne__(self,other): + return self != other + + def __getattribute__(self,name): + # ensure that methods which are hidden by metaclass are also not + # accessible from class instance + if name in MetaClassToDump.hide_special: + raise AttributeError("") + return super(ClassToDump,self).__getattribute__(name) + + def __getstate__(self): + # returns the state of this class when asked by copy protocol handler + return self.__dict__ + + def __setstate__(self,state): + + # set the state from the passed state description + self.__dict__.update(state) + + # controls whether the setstate method is reported as + # sixth element of tuple returned by __reduce_ex__ or + # __reduce__ function or not + extern_setstate = False + + def __reduce_ex__(self,proto = pickle.DEFAULT_PROTOCOL): + state = super(ClassToDump,self).__reduce_ex__(proto) + if len(state) > 5 or not ClassToDump.extern_setstate: + return state + return (*state,*( (None,) * ( 5 - len(state)) ),ClassToDump.__setstate__) + + def __reduce__(self): + state = super(ClassToDump,self).__reduce__() + if len(state) > 5 or not ClassToDump.extern_setstate: + return state + return (*state,*( (None,) * ( 5 - len(state)) ),ClassToDump.__setstate__) + +class SimpleClass(): + """ + simple classe used to check that instance __dict__ is properly dumped and + restored by create_pickled_dataset and PickledContainer + """ + def __init__(self): + self.someattr = "im some attr" + self.someother = 12 + + def __eq__(self,other): + return other.__class__ is self.__class__ and self.__dict__ == other.__dict__ + + def __ne__(self,other): + return self != other + +class NoExtendList(list): + """ + special list class used to test whether append is properly used + when list like object is dumped and restored through create_pickled_dataset + and PickledContainer + """ + + def __getattribute__(self,name): + if name == "extend": + raise AttributeError("no extend") + return super(NoExtendList,self).__getattribute__(name) + +# %% FUNCTION DEFINITIONS + +def function_to_dump(hallo,welt,with_default=1): + """ + non class function to be dumpled and restored through + create_pickled_dataset and load_pickled_data + """ + return hallo,welt,with_default + +def test_register_class(loader_table): + """ + tests the register_class method + """ + + # try to register dataset only loader specified by loader_table + # and retrieve its contents from types_dict and hkl_types_dict + loader_spec = loader_table[0] + lookup.register_class(*loader_spec) + assert lookup.types_dict[loader_spec[0]] == loader_spec[2:0:-1] + assert lookup.hkl_types_dict[loader_spec[1]] == loader_spec[3] + with pytest.raises(KeyError): + lookup.hkl_container_dict[loader_spec[1]] is None + + # try to register PyContainer only loader specified by loader_table + # and retrive its contents from types_dict and hkl_contianer_dict + loader_spec = loader_table[1] + lookup.register_class(*loader_spec) + assert lookup.types_dict[loader_spec[0]] == loader_spec[2:0:-1] + with pytest.raises(KeyError): + lookup.hkl_types_dict[loader_spec[1]] is None + assert lookup.hkl_container_dict[loader_spec[1]] == loader_spec[4] + + + # try to register container without dump_function specified by + # loader table and try to retrive load_function and PyContainer from + # hkl_types_dict and hkl_container_dict + loader_spec = loader_table[2] + lookup.register_class(*loader_spec) + with pytest.raises(KeyError): + lookup.types_dict[loader_spec[0]][1] == loader_spec[1] + assert lookup.hkl_types_dict[loader_spec[1]] 
== loader_spec[3] + assert lookup.hkl_container_dict[loader_spec[1]] == loader_spec[4] + + # try to register loader shadowing loader preset by hickle core + # defined by external loader module + loader_spec = loader_table[3] + with pytest.raises(TypeError,match = r"loader\s+for\s+'\w+'\s+type\s+managed\s+by\s+hickle\s+only"): + lookup.register_class(*loader_spec) + loader_spec = loader_table[4] + + # try to register loader shadowing loader preset by hickle core + # defined by hickle loaders module + with pytest.raises(TypeError,match = r"loader\s+for\s+'\w+'\s+type\s+managed\s+by\s+hickle\s+core\s+only"): + lookup.register_class(*loader_spec) + + # simulate registering loader preset by hickle core + loader_spec = loader_table[5] + lookup.register_class(*loader_spec) + +def test_register_class_exclude(loader_table): + """ + test registr class exclude function + """ + + # try to disable loading of loader preset by hickle core + base_type = loader_table[5][1] + lookup.register_class(*loader_table[2]) + lookup.register_class(*loader_table[5]) + with pytest.raises(ValueError,match = r"excluding\s+'.+'\s+base_type\s+managed\s+by\s+hickle\s+core\s+not\s+possible"): + lookup.register_class_exclude(base_type) + + # disable any of the other loaders + base_type = loader_table[2][1] + lookup.register_class_exclude(base_type) + + +def patch_importlib_util_find_spec(name,package=None): + """ + function used to temporarily redirect search for laoders + to hickle_loader directory in test directory for testing + loading of new loaders + """ + return find_spec("hickle.tests." + name.replace('.','_',1),package) + +def patch_importlib_util_find_no_spec(name,package=None): + """ + function used to simulate situation where no appropriate loader + could be found for object + """ + return None + + +def test_load_loader(loader_table,monkeypatch): + """ + test load_loader function + """ + + # some data to check loader for + # assume loader should be load_builtins loader + py_object = dict() + loader_name = "hickle.loaders.load_builtins" + with monkeypatch.context() as moc_import_lib: + + # hide loader from hickle.lookup.loaded_loaders and check that + # fallback loader for python object is returned + moc_import_lib.setattr("importlib.util.find_spec",patch_importlib_util_find_no_spec) + moc_import_lib.setattr("hickle.lookup.find_spec",patch_importlib_util_find_no_spec) + moc_import_lib.delitem(sys.modules,"hickle.loaders.load_builtins",raising=False) + py_obj_type,nopickleloader = lookup.load_loader(py_object.__class__) + assert py_obj_type is dict and nopickleloader == (lookup.create_pickled_dataset,b'pickle') + + # redirect load_builtins loader to tests/hickle_loader path + moc_import_lib.setattr("importlib.util.find_spec",patch_importlib_util_find_spec) + moc_import_lib.setattr("hickle.lookup.find_spec",patch_importlib_util_find_spec) + + # preload dataset only loader and check that it can be resolved directly + loader_spec = loader_table[0] + lookup.register_class(*loader_spec) + assert lookup.load_loader((12).__class__) == (loader_spec[0],loader_spec[2:0:-1]) + + # try to find appropriate loader for dict object, a moc of this + # loader should be provided by hickle/tests/hickle_loaders/load_builtins + # module ensure that this module is the one found by load_loader function + import hickle.tests.hickle_loaders.load_builtins as load_builtins + moc_import_lib.setitem(sys.modules,loader_name,load_builtins) + assert lookup.load_loader(py_object.__class__) == (dict,(load_builtins.create_package_test,b'dict')) + + # 
remove loader again and undo redirection again. dict should now be + # processed by create_pickled_dataset + moc_import_lib.delitem(sys.modules,loader_name) + del lookup.types_dict[dict] + py_obj_type,nopickleloader = lookup.load_loader(py_object.__class__) + assert py_obj_type is dict and nopickleloader == (lookup.create_pickled_dataset,b'pickle') + + # check that load_loader prevenst redefinition of loaders to be predefined by hickle core + with pytest.raises( + RuntimeError, + match = r"objects\s+defined\s+by\s+hickle\s+core\s+must\s+be" + r"\s+registerd\s+before\s+first\s+dump\s+or\s+load" + ): + py_obj_type,nopickleloader = lookup.load_loader(ToBeInLoadersOrNotToBe) + monkeypatch.setattr(ToBeInLoadersOrNotToBe,'__module__','hickle.loaders') + + # check that load_loaders issues drop warning upon loader definitions for + # dummy objects defined within hickle package but outsied loaders modules + with pytest.warns( + RuntimeWarning, + match = r"ignoring\s+'.+'\s+dummy\s+type\s+not\s+defined\s+by\s+loader\s+module" + ): + py_obj_type,nopickleloader = lookup.load_loader(ToBeInLoadersOrNotToBe) + assert py_obj_type is ToBeInLoadersOrNotToBe + assert nopickleloader == (lookup.create_pickled_dataset,b'pickle') + + # check that loader definitions for dummy objets defined by loaders work as expected + # by loader module + monkeypatch.setattr(ToBeInLoadersOrNotToBe,'__module__',loader_name) + py_obj_type,(create_dataset,base_type) = lookup.load_loader(ToBeInLoadersOrNotToBe) + assert py_obj_type is ToBeInLoadersOrNotToBe and base_type == b'NotHicklable' + assert create_dataset is not_dumpable + + # remove loader_name from list of loaded loaders and check that loader is loaded anew + # and that values returned for dict object correspond to loader + # provided by freshly loaded loader module + lookup.loaded_loaders.remove(loader_name) + py_obj_type,(create_dataset,base_type) = lookup.load_loader(py_object.__class__) + load_builtins_moc = sys.modules.get(loader_name,None) + assert load_builtins_moc is not None + loader_spec = load_builtins_moc.class_register[0] + assert py_obj_type is dict and create_dataset is loader_spec[2] + assert base_type is loader_spec[1] + +def test_type_legacy_mro(): + """ + tests type_legacy_mro function which is used in replacement + for native type.mro function when loading 4.0.0 and 4.0.1 files + it handles cases where type objects passed to load_loader are + functions not classes + """ + + # check that for class object type_legacy_mro function returns + # the mro list provided by type.mro unchanged + assert lookup.type_legacy_mro(SimpleClass) == type.mro(SimpleClass) + + # check that in case function is passed as type object a tuple with + # function as single element is returned + assert lookup.type_legacy_mro(function_to_dump) == (function_to_dump,) + + +def test_create_pickled_dataset(h5_data): + """ + tests create_pickled_dataset, load_pickled_data function and PickledContainer + """ + + # check if create_pickled_dataset issues SerializedWarning for objects which + # either do not support copy protocol + py_object = ClassToDump('hello',1) + data_set_name = "greetings" + with pytest.warns(lookup.SerializedWarning,match = r".*type\s+not\s+understood,\s+data\s+is\s+serialized:.*") as warner: + MetaClassToDump.hide_special.add('__reduce_ex__') + MetaClassToDump.hide_special.add('__reduce__') + h5_node,subitems = lookup.create_pickled_dataset(py_object, h5_data,data_set_name) + MetaClassToDump.hide_special.clear() + assert isinstance(h5_node,h5py.Dataset) and not 
subitems and iter(subitems) + assert bytes(h5_node[()]) == pickle.dumps(py_object) and h5_node.name.split('/')[2] == data_set_name + assert lookup.load_pickled_data(h5_node,b'pickle',object) == py_object + + # check that create_pickled_dataset properly stores object with __getstate__ and __setstate__ + h5_node,subitems = lookup.create_pickled_dataset(py_object, h5_data,"reduced") + assert isinstance(h5_node,h5py.Group) and iter(subitems) + index = -1 + pickled_container = lookup.PickledContainer(h5_node.attrs,b'pickle',object) + valid_item_names = ('create','args','state') + reduced_obj = py_object.__reduce_ex__(pickle.DEFAULT_PROTOCOL) + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == valid_item_names[index] and not attrs and not kwargs + assert item == reduced_obj[index] + pickled_container.append(name,item,attrs) + assert index + 1 == 3 and pickled_container.convert() == py_object + + # check that create_pickled_dataset properly stores object without instance __dict__ + py_object_nostate = ToBeInLoadersOrNotToBe() + h5_node,subitems = lookup.create_pickled_dataset(py_object_nostate, h5_data,"no_state") + assert isinstance(h5_node,h5py.Group) and iter(subitems) + index = -1 + pickled_container = lookup.PickledContainer(h5_node.attrs,b'pickle',object) + valid_item_names = ('create','args') + reduced_obj = py_object_nostate.__reduce_ex__(pickle.DEFAULT_PROTOCOL) + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == valid_item_names[index] and not attrs and not kwargs + assert item == reduced_obj[index] + pickled_container.append(name,item,attrs) + assert index + 1 == 2 and pickled_container.convert() == py_object_nostate + + # check that create_pickled_dataset properly stored object with instance __dict__ only + py_object_dict_state = SimpleClass() + h5_node,subitems = lookup.create_pickled_dataset(py_object_dict_state, h5_data,"dict_state") + assert isinstance(h5_node,h5py.Group) and iter(subitems) + index = -1 + pickled_container = lookup.PickledContainer(h5_node.attrs,b'pickle',object) + valid_item_names = ('create','args','state') + reduced_obj = py_object_dict_state.__reduce_ex__(pickle.DEFAULT_PROTOCOL) + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == valid_item_names[index] and not attrs and not kwargs + assert item == reduced_obj[index] + pickled_container.append(name,item,attrs) + assert index + 1 == 3 and pickled_container.convert() == py_object_dict_state + + # check that create_pickled_dataset stores external setstate method along with + # state dict of object + py_object = ClassToDump('hello',1) + ClassToDump.extern_setstate = True + h5_node,subitems = lookup.create_pickled_dataset(py_object, h5_data,"explicit_setstate") + assert isinstance(h5_node,h5py.Group) and iter(subitems) + index = -1 + pickled_container = lookup.PickledContainer(h5_node.attrs,b'pickle',object) + valid_item_names = ('create','args','state','set') + reduced_obj = py_object.__reduce_ex__(pickle.DEFAULT_PROTOCOL) + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == valid_item_names[index] and not attrs and not kwargs + assert item == reduced_obj[index if name != "set" else 5] + pickled_container.append(name,item,attrs) + assert index + 1 == 4 and pickled_container.convert() == py_object + ClassToDump.extern_setstate = False + + # check that create_python_dtype_dataset stores list iterator reported by __reduce_ex__ or + # __reduce__ + py_object = list(range(50)) + h5_node,subitems = 
lookup.create_pickled_dataset(py_object, h5_data,"with_list_iterator") + assert isinstance(h5_node,h5py.Group) and iter(subitems) + index = -1 + pickled_container = lookup.PickledContainer(h5_node.attrs,b'pickle',object) + valid_item_names = ('create','args','list') + reduced_obj = py_object.__reduce_ex__(pickle.DEFAULT_PROTOCOL) + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == valid_item_names[index] and not attrs and not kwargs + if name != "list": + assert item == reduced_obj[index] + else: + assert item == tuple(reduced_obj[3]) + pickled_container.append(name,item,attrs) + assert index + 1 == 3 and pickled_container.convert() == py_object + + # check that PickledContainer uses append to pass additional list values + # to restored object. These values were before provided by list iterator + py_object = NoExtendList(py_object) + h5_node,subitems = lookup.create_pickled_dataset(py_object, h5_data,"no_extend") + assert isinstance(h5_node,h5py.Group) and iter(subitems) + index = -1 + pickled_container = lookup.PickledContainer(h5_node.attrs,b'pickle',object) + valid_item_names = ('create','args','list') + reduced_obj = py_object.__reduce_ex__(pickle.DEFAULT_PROTOCOL) + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == valid_item_names[index] and not attrs and not kwargs + if name != "list": + assert item == reduced_obj[index] + else: + assert item == tuple(reduced_obj[3]) + pickled_container.append(name,item,attrs) + assert index + 1 == 3 and pickled_container.convert() == py_object + + # check that PickledContainer uses append instead of extend for low numbers of + # items in itrator list + py_object = [1] + h5_node,subitems = lookup.create_pickled_dataset(py_object, h5_data,"short_iterator") + assert isinstance(h5_node,h5py.Group) and iter(subitems) + index = -1 + pickled_container = lookup.PickledContainer(h5_node.attrs,b'pickle',object) + valid_item_names = ('create','args','list') + reduced_obj = py_object.__reduce_ex__(pickle.DEFAULT_PROTOCOL) + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == valid_item_names[index] and not attrs and not kwargs + if name != "list": + assert item == reduced_obj[index] + else: + assert item == tuple(reduced_obj[3]) + pickled_container.append(name,item,attrs) + assert index + 1 == 3 and pickled_container.convert() == py_object + + # check create_pickled_dataset properly stores keys and values provided by + # dict iterator using one dataset for list of keys and the other for list of values + py_object = {"key_{}".format(value):value for value in py_object} + h5_node,subitems = lookup.create_pickled_dataset(py_object, h5_data,"with_dict_iterator") + assert isinstance(h5_node,h5py.Group) and iter(subitems) + index = -1 + pickled_container = lookup.PickledContainer(h5_node.attrs,b'pickle',object) + valid_item_names = ('create','args','keys','values') + reduced_obj = py_object.__reduce_ex__(pickle.DEFAULT_PROTOCOL) + reduced_obj = reduced_obj[:4] + (dict(reduced_obj[4]),) + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == valid_item_names[index] and not attrs and not kwargs + if name == 'keys': + assert item == tuple(reduced_obj[4].keys()) + elif name == 'values': + assert item == tuple(reduced_obj[4].values()) + else: + assert item == reduced_obj[index] + pickled_container.append(name,item,attrs) + assert index + 1 == 4 and pickled_container.convert() == py_object + + # check that create_pickled_dataset ignores key list or value list if either + # is 
not present + h5_node,subitems = lookup.create_pickled_dataset(py_object, h5_data,"no_values") + assert isinstance(h5_node,h5py.Group) and iter(subitems) + index = -1 + pickled_container = lookup.PickledContainer(h5_node.attrs,b'pickle',object) + valid_item_names = ('create','args','keys','values') + reduced_obj = py_object.__reduce_ex__(pickle.DEFAULT_PROTOCOL) + reduced_obj = reduced_obj[:4] + (dict(reduced_obj[4]),) + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == valid_item_names[index] and not attrs and not kwargs + if name == 'keys': + assert item == tuple(reduced_obj[4].keys()) + elif name == 'values': + # drop values + continue + else: + assert item == reduced_obj[index] + pickled_container.append(name,item,attrs) + assert index + 1 == 4 and pickled_container.convert() == {} + +def test__DictItemContainer(): + """ + tests _DictItemContainer class which represent dict_item goup + used by version 4.0.0 files to represent values of dictionary key + """ + container = lookup._DictItemContainer({},b'dict_item',lookup._DictItem) + my_bike_lock = (1,2,3,4) + container.append('my_bike_lock',my_bike_lock,{}) + assert container.convert() is my_bike_lock + + +def test__moc_numpy_array_object_lambda(): + """ + test the _moc_numpy_array_object_lambda function + which mimicks the effect of lambda function created + py pickle when expanding pickle `'type'` string set + for numpy arrays containing a single object not expandable + into a list. Mocking is necessary from Python 3.8.X on + as it seems in Python 3.8 and onwards trying to pickle + a lambda now causes a TypeError whilst it seems to be silently + accepted in Python < 3.8 + """ + data = ['hello','world'] + assert lookup._moc_numpy_array_object_lambda(data) == data[0] + +def test_fix_lambda_obj_type(): + """ + test _moc_numpy_array_object_lambda function it self. When invokded + it should return the first element of the passed list + """ + assert lookup.fix_lambda_obj_type(None) is object + picklestring = pickle.dumps(SimpleClass) + assert lookup.fix_lambda_obj_type(picklestring) is SimpleClass + assert lookup.fix_lambda_obj_type('') is lookup._moc_numpy_array_object_lambda + +# %% MAIN SCRIPT +if __name__ == "__main__": + from _pytest.monkeypatch import monkeypatch + from _pytest.fixtures import FixtureRequest + for table in loader_table(): + test_register_class(table) + for table in loader_table(): + test_register_class_exclude(table) + for monkey in monkeypatch(): + test_load_loader(table,monkey) + test_type_legacy_mro() + for h5_root in h5_data(FixtureRequest(test_create_pickled_dataset)): + test_create_pickled_dataset(h5_root) + test__DictItemContainer() + test__moc_numpy_array_object_lambda() + test_fix_lambda_obj_type() + test_fix_lambda_obj_type() + + + + + diff --git a/hickle/tests/test_03_load_builtins.py b/hickle/tests/test_03_load_builtins.py new file mode 100644 index 00000000..e071263e --- /dev/null +++ b/hickle/tests/test_03_load_builtins.py @@ -0,0 +1,416 @@ +#! /usr/bin/env python +# encoding: utf-8 +""" +# test_load_builtins + +Unit tests for hickle module -- builtins loader. 
+ +""" + +import pytest +import collections +import itertools + +# %% IMPORTS +# Package imports +import h5py as h5 +import numpy as np +from py.path import local + +# hickle imports +import hickle.loaders.load_builtins as load_builtins +import hickle.helpers as helpers + + +# Set current working directory to the temporary directory +local.get_temproot().chdir() + + +# %% TEST DATA + +dummy_data = (1,2,3) + +# %% FIXTURES + +@pytest.fixture +def h5_data(request): + """ + create dummy hdf5 test data file providing parent group + hosting createed datasets and groups. Name of test function + is included in filename + """ + dummy_file = h5.File('load_builtins_{}.hdf5'.format(request.function.__name__),'w') + filename = dummy_file.filename + test_data = dummy_file.create_group("root_group") + yield test_data + dummy_file.close() + + +# %% FUNCTION DEFINITIONS + +def test_scalar_dataset(h5_data): + """ + tests creation and loading of datasets for scalar values + """ + + # check that scalar value is properly handled + floatvalue = 5.2 + h_dataset,subitems= load_builtins.create_scalar_dataset(floatvalue,h5_data,"floatvalue") + assert isinstance(h_dataset,h5.Dataset) and h_dataset[()] == floatvalue + assert not [ item for item in subitems ] + assert load_builtins.load_scalar_dataset(h_dataset,b'float',float) == floatvalue + + # check that intger value less thatn 64 bit is stored as int + intvalue = 11 + h_dataset,subitems = load_builtins.create_scalar_dataset(intvalue,h5_data,"intvalue") + assert isinstance(h_dataset,h5.Dataset) and h_dataset[()] == intvalue + assert not [ item for item in subitems ] + assert load_builtins.load_scalar_dataset(h_dataset,b'int',int) == intvalue + + # check that integer larger than 64 bit is stored as ascii byte string + non_mappable_int = int(2**65) + h_dataset,subitems = load_builtins.create_scalar_dataset(non_mappable_int,h5_data,"non_mappable_int") + assert isinstance(h_dataset,h5.Dataset) + assert bytearray(h_dataset[()]) == str(non_mappable_int).encode('utf8') + assert not [ item for item in subitems ] + assert load_builtins.load_scalar_dataset(h_dataset,b'int',int) == non_mappable_int + + +def test_non_dataset(h5_data): + """ + that None value is properly stored + """ + h_dataset,subitems = load_builtins.create_none_dataset(None,h5_data,"None_value") + assert isinstance(h_dataset,h5.Dataset) and bytearray(h_dataset[()]) == b'None' + assert not [ item for item in subitems ] + assert load_builtins.load_none_dataset(h_dataset,b'None',None.__class__) is None + + +def test_listlike_dataset(h5_data): + """ + test storing and loading of list like data + """ + + # check that empty tuple is stored properly + empty_tuple = () + h_dataset,subitems = load_builtins.create_listlike_dataset(empty_tuple, h5_data, "empty_tuple") + assert isinstance(h_dataset,h5.Dataset) and h_dataset.size is None + assert not subitems and iter(subitems) + assert load_builtins.load_list_dataset(h_dataset,b'tuple',tuple) == empty_tuple + + # check that string data is stored properly stored as array of bytes + # which supports compression + stringdata = "string_data" + h_dataset,subitems = load_builtins.create_listlike_dataset(stringdata, h5_data, "string_data") + assert isinstance(h_dataset,h5.Dataset) and not [ item for item in subitems ] + assert bytearray(h_dataset[()]).decode("utf8") == stringdata + assert h_dataset.attrs["str_type"].decode("ascii") == 'str' + assert load_builtins.load_list_dataset(h_dataset,b'str',str) == stringdata + + # check that byte string is proprly stored as array of 
bytes which + # supports compression + bytesdata = b'bytes_data' + h_dataset,subitems = load_builtins.create_listlike_dataset(bytesdata, h5_data, "bytes_data") + assert isinstance(h_dataset,h5.Dataset) and not [ item for item in subitems ] + assert bytes(h_dataset[()]) == bytesdata + assert h_dataset.attrs["str_type"].decode("ascii") == 'bytes' + assert load_builtins.load_list_dataset(h_dataset,b'bytes',bytes) == bytesdata + + # check that string dataset created by hickle 4.0.x is properly loaded + # utilizing numpy.array method. Mimick dumped data + h_dataset = h5_data.create_dataset("legacy_np_array_bytes_data",data=np.array(stringdata.encode('utf8'))) + h_dataset.attrs['str_type'] = b'str' + assert load_builtins.load_list_dataset(h_dataset,b'str',str) == stringdata + + # check that list of single type is stored as dataset of same type + homogenous_list = [ 1, 2, 3, 4, 5, 6] + h_dataset,subitems = load_builtins.create_listlike_dataset(homogenous_list,h5_data,"homogenous_list") + assert isinstance(h_dataset,h5.Dataset) and not [ item for item in subitems ] + assert h_dataset[()].tolist() == homogenous_list and h_dataset.dtype == int + assert load_builtins.load_list_dataset(h_dataset,b'list',list) == homogenous_list + + # check that list of different scalar types for which a least common type exists + # is stored using a dataset + mixed_dtype_list = [ 1, 2.5, 3.8, 4, 5, 6] + h_dataset,subitems = load_builtins.create_listlike_dataset(mixed_dtype_list,h5_data,"mixed_dtype_list") + assert isinstance(h_dataset,h5.Dataset) and not [ item for item in subitems ] + assert h_dataset[()].tolist() == mixed_dtype_list and h_dataset.dtype == float + assert load_builtins.load_list_dataset(h_dataset,b'list',list) == mixed_dtype_list + + # check that list containing non scalar objects is converted into group + # further check that for groups representing list the index of items is either + # provided via item_index attribute or can be read from name of item + not_so_homogenous_list = [ 1, 2, 3, [4],5 ,6 ] + h_dataset,subitems = load_builtins.create_listlike_dataset(not_so_homogenous_list,h5_data,"not_so_homogenous_list") + assert isinstance(h_dataset,h5.Group) + item_name = "data{:d}" + index = -1 + loaded_list = load_builtins.ListLikeContainer(h_dataset.attrs,b'list',list) + subitems1,subitems2 = itertools.tee(subitems,2) + index_from_string = load_builtins.ListLikeContainer(h_dataset.attrs,b'list',list) + for index,(name,item,attrs,kwargs) in enumerate(iter(subitems1)): + assert item_name.format(index) == name and item == not_so_homogenous_list[index] + assert attrs == {"item_index":index} and kwargs == {} + if isinstance(item,list): + item_dataset,_ = load_builtins.create_listlike_dataset(item,h_dataset,name) + else: + item_dataset = h_dataset.create_dataset(name,data = item) + item_dataset.attrs.update(attrs) + loaded_list.append(name,item,item_dataset.attrs) + index_from_string.append(name,item,{}) + assert index + 1 == len(not_so_homogenous_list) + assert loaded_list.convert() == not_so_homogenous_list + assert index_from_string.convert() == not_so_homogenous_list + + # check that list groups which do not provide num_items attribute + # are automatically expanded to properly cover the highes index encountered + # for any of the list items. 
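+    # Without the num_items attribute the container can only size the resulting
+    # list from each item's "item_index" attribute or, as exercised below where
+    # empty attrs are passed in, from the numeric tail of the item name
+    # (e.g. 'data4' -> index 4).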
+ no_num_items = {key:value for key,value in h_dataset.attrs.items() if key != "num_items"} + no_num_items_container = load_builtins.ListLikeContainer(no_num_items,b'list',list) + for index,(name,item,attrs,kwargs) in enumerate(iter(subitems2)): + assert item_name.format(index) == name and item == not_so_homogenous_list[index] + assert attrs == {"item_index":index} and kwargs == {} + item_dataset = h_dataset.get(name,None) + no_num_items_container.append(name,item,{}) + assert index + 1 == len(not_so_homogenous_list) + assert no_num_items_container.convert() == not_so_homogenous_list + + # check that list the first of which is not a scalar is properly mapped + # to a group. Also check that ListLikeContainer.append raises exception + # in case neither item_index is provided nor an index value can be parsed + # from the taile of its name. Also check that ListLikeContainer.append + # raises exceptoin in case value for item_index already has been loaded + object_list = [ [4, 5 ] ,6, [ 1, 2, 3 ] ] + h_dataset,subitems = load_builtins.create_listlike_dataset(object_list,h5_data,"object_list") + assert isinstance(h_dataset,h5.Group) + item_name = "data{:d}" + wrong_item_name = item_name + "_ni" + index = -1 + loaded_list = load_builtins.ListLikeContainer(h_dataset.attrs,b'list',list) + index_from_string = load_builtins.ListLikeContainer(h_dataset.attrs,b'list',list) + for index,(name,item,attrs,kwargs) in enumerate(iter(subitems)): + assert item_name.format(index) == name and item == object_list[index] + assert attrs == {"item_index":index} and kwargs == {} + if isinstance(item,list): + item_dataset,_ = load_builtins.create_listlike_dataset(item,h_dataset,name) + else: + item_dataset = h_dataset.create_dataset(name,data = item) + item_dataset.attrs.update(attrs) + loaded_list.append(name,item,item_dataset.attrs) + with pytest.raises(KeyError,match = r"List\s+like\s+item name\s+'\w+'\s+not\s+understood"): + index_from_string.append(wrong_item_name.format(index),item,{}) + # check that previous error is not triggerd when + # legacy 4.0.x loader injects the special value helpers.nobody_is_my_name which + # is generated by load_nothing function. this is for example used as load method + # for legacy 4.0.x np.masked.array objects where the mask is injected in parallel + # in the root group of the corresponding values data set. By silently ignoring + # this special value returned by load_nothing it can be assured that for example + # mask datasets of numpy.masked.array objects hickup the loader. 
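+        # In other words: an item name that cannot be parsed only raises for a
+        # real payload; the same name carrying the nobody_is_my_name sentinel is
+        # simply ignored.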
+ index_from_string.append(wrong_item_name.format(index),helpers.nobody_is_my_name,{}) + if index < 1: + continue + with pytest.raises(IndexError, match = r"Index\s+\d+\s+already\s+set"): + loaded_list.append(name,item,{"item_index":index-1}) + assert index + 1 == len(object_list) + + # assert that list of strings where first string has lenght 1 is properly mapped + # to group + string_list = test_set = ['I','confess','appriciate','hickle','times'] + h_dataset,subitems = load_builtins.create_listlike_dataset(string_list,h5_data,"string_list") + assert isinstance(h_dataset,h5.Group) + item_name = "data{:d}" + index = -1 + loaded_list = load_builtins.ListLikeContainer(h_dataset.attrs,b'list',list) + index_from_string = load_builtins.ListLikeContainer(h_dataset.attrs,b'list',list) + for index,(name,item,attrs,kwargs) in enumerate(iter(subitems)): + assert item_name.format(index) == name and item == string_list[index] + assert attrs == {"item_index":index} and kwargs == {} + item_dataset = h_dataset.create_dataset(name,data = item) + item_dataset.attrs.update(attrs) + loaded_list.append(name,item,item_dataset.attrs) + index_from_string.append(name,item,{}) + assert index + 1 == len(string_list) + assert loaded_list.convert() == string_list + assert index_from_string.convert() == string_list + + # assert that list which contains numeric values and strings is properly mapped + # to group + mixed_string_list = test_set = [12,2.8,'I','confess','appriciate','hickle','times'] + h_dataset,subitems = load_builtins.create_listlike_dataset(mixed_string_list,h5_data,"mixed_string_list") + assert isinstance(h_dataset,h5.Group) + item_name = "data{:d}" + index = -1 + loaded_list = load_builtins.ListLikeContainer(h_dataset.attrs,b'list',list) + index_from_string = load_builtins.ListLikeContainer(h_dataset.attrs,b'list',list) + for index,(name,item,attrs,kwargs) in enumerate(iter(subitems)): + assert item_name.format(index) == name and item == mixed_string_list[index] + assert attrs == {"item_index":index} and kwargs == {} + item_dataset = h_dataset.create_dataset(name,data = item) + item_dataset.attrs.update(attrs) + loaded_list.append(name,item,item_dataset.attrs) + index_from_string.append(name,item,{}) + assert index + 1 == len(mixed_string_list) + assert loaded_list.convert() == mixed_string_list + assert index_from_string.convert() == mixed_string_list + + +def test_set_container(h5_data): + """ + tests storing and loading of set + """ + + # check that set of strings is store as group + test_set = {'I','confess','appriciate','hickle','times'} + h_setdataset,subitems = load_builtins.create_setlike_dataset(test_set,h5_data,"test_set") + set_container = load_builtins.SetLikeContainer(h_setdataset.attrs,b'set',set) + for name,item,attrs,kwargs in subitems: + set_container.append(name,item,attrs) + assert set_container.convert() == test_set + + # check that set of single bytes is stored as single dataset + test_set_2 = set(b"hello world") + h_setdataset,subitems = load_builtins.create_setlike_dataset(test_set_2,h5_data,"test_set_2") + assert isinstance(h_setdataset,h5.Dataset) and set(h_setdataset[()]) == test_set_2 + assert not subitems and iter(subitems) + assert load_builtins.load_list_dataset(h_setdataset,b'set',set) == test_set_2 + + # check that set containing byte strings is stored as group + test_set_3 = set((item.encode("utf8") for item in test_set)) + h_setdataset,subitems = load_builtins.create_setlike_dataset(test_set_3,h5_data,"test_set_3") + set_container = 
load_builtins.SetLikeContainer(h_setdataset.attrs,b'set',set) + for name,item,attrs,kwargs in subitems: + set_container.append(name,item,attrs) + assert set_container.convert() == test_set_3 + + # check that empty set is represented by emtpy dataset + h_setdataset,subitems = load_builtins.create_setlike_dataset(set(),h5_data,"empty_set") + assert isinstance(h_setdataset,h5.Dataset) and h_setdataset.size == 0 + assert not subitems and iter(subitems) + assert load_builtins.load_list_dataset(h_setdataset,b'set',set) == set() + + +def test_dictlike_dataset(h5_data): + """ + test storing and loading of dict + """ + + class KeyClass(): + """class used as dict key""" + + allkeys_dict = { + 'string_key':0, + b'bytes_key':1, + 12:2, + 0.25:3, + complex(1,2):4, + None:5, + (1,2,3):6, + KeyClass():7, + KeyClass:8 + } + + # check that dict is stored as group + # check that string and byte string keys are mapped to dataset or group name + # check that scalar dict keys are converted to their string representation + # check that for all other keys a key value pair is created + h_datagroup,subitems = load_builtins.create_dictlike_dataset(allkeys_dict,h5_data,"allkeys_dict") + assert isinstance(h_datagroup,h5.Group) + invalid_key = b'' + last_entry = -1 + load_dict = load_builtins.DictLikeContainer(h_datagroup.attrs,b'dict',dict) + ordered_dict = collections.OrderedDict() + for name,item,attrs,kwargs in subitems: + value = item + if attrs["key_base_type"] == b"str": + key = name[1:-1] + elif attrs["key_base_type"] == b"bytes": + key = name[2:-1].encode("utf8") + elif attrs["key_base_type"] == b'key_value': + key = item[0] + value = item[1] + else: + load_key = load_builtins.dict_key_types_dict.get(attrs["key_base_type"],None) + if load_key is None: + raise ValueError("key_base_type '{}' invalid".format(attrs["key_base_type"])) + key = load_key(name) + assert allkeys_dict.get(key,invalid_key) == value + load_dict.append(name,item,attrs) + last_entry = attrs.get("key_idx",None) + ordered_dict[key] = value + assert last_entry + 1 == len(allkeys_dict) + assert load_dict.convert() == allkeys_dict + + # verify that DictLikeContainer.append raises error incase invalid key_base_type + # is provided + with pytest.raises(ValueError, match = r"key\s+type\s+'.+'\s+not\s+understood"): + load_dict.append("invalid_key_type",12,{"key_idx":9,"key_base_type":b"invalid_type"}) + tuple_key = ('a','b','c') + + # verify that DictLikeContainer.append raises error in case index of key value pair + # within dict is whether provided by key_index attribute nor can be parsed from + # name of corresponding dataset or group + with pytest.raises(KeyError, match = r"invalid\s+dict\s+item\s+key_index\s+missing"): + load_dict.append(str(tuple_key),9,{"item_index":9,"key_base_type":b"tuple"}) + + # check that helpers.nobody_is_my_name injected for example by load_nothing is silently + # ignored in case no key could be retireved from dataset or sub group + load_dict.append( + str(tuple_key), helpers.nobody_is_my_name, + {"item_index":9,"key_base_type":b"tuple"} + ) + with pytest.raises(KeyError): + assert load_dict.convert()[tuple_key] is None + + # check that if key_idx attribute is provided key value pair may be added + load_dict.append(str(tuple_key),9,{"key_idx":9,"key_base_type":b"tuple"}) + assert load_dict.convert()[tuple_key] == 9 + + # verify that DictLikeContainer.append raises error in case item index already + # set + with pytest.raises(IndexError,match = r"Key\s+index\s+\d+\s+already\s+set"): + 
load_dict.append(str(tuple_key),9,{"key_idx":9,"key_base_type":b"tuple"}) + + # check that order of OrderedDict dict keys is not altered on loading data from + # hickle file + h_datagroup,subitems = load_builtins.create_dictlike_dataset(ordered_dict,h5_data,"ordered_dict") + assert isinstance(h_datagroup,h5.Group) + last_entry = -1 + load_ordered_dict = load_builtins.DictLikeContainer(h_datagroup.attrs,b'dict',collections.OrderedDict) + for name,item,attrs,kwargs in subitems: + value = item + if attrs["key_base_type"] == b"str": + key = name[1:-1] + elif attrs["key_base_type"] == b"bytes": + key = name[2:-1].encode("utf8") + elif attrs["key_base_type"] == b'key_value': + key = item[0] + value = item[1] + else: + load_key = load_builtins.dict_key_types_dict.get(attrs["key_base_type"],None) + if load_key is None: + raise ValueError("key_base_type '{}' invalid".format(attrs["key_base_type"])) + key = load_key(name) + assert ordered_dict.get(key,invalid_key) == value + load_ordered_dict.append(name,item,attrs) + last_entry = attrs.get("key_idx",None) + assert last_entry + 1 == len(allkeys_dict) + assert load_ordered_dict.convert() == ordered_dict + + +# %% MAIN SCRIPT +if __name__ == "__main__": + from _pytest.fixtures import FixtureRequest + for h5_root in h5_data(FixtureRequest(test_scalar_dataset)): + test_scalar_dataset(h5_root) + for h5_root in h5_data(FixtureRequest(test_non_dataset)): + test_non_dataset(h5_root) + for h5_root in h5_data(FixtureRequest(test_listlike_dataset)): + test_listlike_dataset(h5_root) + for h5_root in h5_data(FixtureRequest(test_set_container)): + test_set_container(h5_root) + for h5_root in h5_data(FixtureRequest(test_dictlike_dataset)): + test_dictlike_dataset(h5_root) + + + + diff --git a/hickle/tests/test_04_load_numpy.py b/hickle/tests/test_04_load_numpy.py new file mode 100644 index 00000000..7bfe086f --- /dev/null +++ b/hickle/tests/test_04_load_numpy.py @@ -0,0 +1,234 @@ +#! /usr/bin/env python +# encoding: utf-8 +""" +# test_load_numpy + +Unit tests for hickle module -- numpy loader. 
+ +""" +import pytest + +import sys + +# %% IMPORTS +# Package imports +import h5py as h5 +import numpy as np +import hickle.loaders.load_numpy as load_numpy +from py.path import local + + +# Set current working directory to the temporary directory +local.get_temproot().chdir() + +# %% GLOBALS + +NESTED_DICT = { + "level1_1": { + "level2_1": [1, 2, 3], + "level2_2": [4, 5, 6] + }, + "level1_2": { + "level2_1": [1, 2, 3], + "level2_2": [4, 5, 6] + }, + "level1_3": { + "level2_1": { + "level3_1": [1, 2, 3], + "level3_2": [4, 5, 6] + }, + "level2_2": [4, 5, 6] + } +} + +# %% FIXTURES + +@pytest.fixture +def h5_data(request): + """ + create dummy hdf5 test data file for testing PyContainer and H5NodeFilterProxy + """ + dummy_file = h5.File('test_load_builtins.hdf5','w') + dummy_file = h5.File('load_numpy_{}.hdf5'.format(request.function.__name__),'w') + filename = dummy_file.filename + test_data = dummy_file.create_group("root_group") + yield test_data + dummy_file.close() + +# %% FUNCTION DEFINITIONS + +def test_create_np_scalar(h5_data): + """ + tests proper storage and loading of numpy scalars + """ + + # check that scalar dataset is created for nupy scalar + scalar_data = np.float64(np.pi) + dtype = scalar_data.dtype + h_dataset,subitems = load_numpy.create_np_scalar_dataset(scalar_data,h5_data,"scalar_data") + assert isinstance(h_dataset,h5.Dataset) and iter(subitems) and not subitems + assert h_dataset.attrs['np_dtype'] == dtype.str.encode('ascii') + assert h_dataset[()] == scalar_data + assert load_numpy.load_np_scalar_dataset(h_dataset,b'np_scalar',scalar_data.__class__) == scalar_data + + # check that numpy.bool_ scarlar is properly stored and reloaded + scalar_data = np.bool_(True) + dtype = scalar_data.dtype + h_dataset,subitems = load_numpy.create_np_scalar_dataset(scalar_data,h5_data,"generic_data") + assert isinstance(h_dataset,h5.Dataset) and iter(subitems) and not subitems + assert h_dataset.attrs['np_dtype'] == dtype.str.encode('ascii') and h_dataset[()] == scalar_data + assert load_numpy.load_np_scalar_dataset(h_dataset,b'np_scalar',scalar_data.__class__) == scalar_data + +def test_create_np_dtype(h5_data): + """ + test proper creation and loading of dataset representing numpy dtype + """ + dtype = np.dtype(np.int16) + h_dataset,subitems = load_numpy.create_np_dtype(dtype, h5_data,"dtype_string") + assert isinstance(h_dataset,h5.Dataset) and iter(subitems) and not subitems + assert bytes(h_dataset[()]).decode('ascii') == dtype.str + assert load_numpy.load_np_dtype_dataset(h_dataset,'np_dtype',np.dtype) == dtype + +def test_create_np_ndarray(h5_data): + """ + test proper creatoin and loading of numpy ndarray + """ + + # check that numpy array representing python utf8 string is properly + # stored as bytearray dataset and reloaded from + np_array_data = np.array("im python string") + h_dataset,subitems = load_numpy.create_np_array_dataset(np_array_data,h5_data,"numpy_string_array") + assert isinstance(h_dataset,h5.Dataset) and iter(subitems) and not subitems + assert bytes(h_dataset[()]) == np_array_data.tolist().encode("utf8") + assert h_dataset.attrs["np_dtype"] == np_array_data.dtype.str.encode("ascii") + assert load_numpy.load_ndarray_dataset(h_dataset,b'ndarray',np.ndarray) == np_array_data + + # chekc that numpy array representing python bytes string is properly + # stored as bytearray dataset and reloaded from + np_array_data = np.array(b"im python bytes") + h_dataset,subitems = load_numpy.create_np_array_dataset(np_array_data,h5_data,"numpy_bytes_array") + assert 
isinstance(h_dataset,h5.Dataset) and iter(subitems) and not subitems + assert h_dataset[()] == np_array_data.tolist() + assert h_dataset.attrs["np_dtype"] == np_array_data.dtype.str.encode("ascii") + assert load_numpy.load_ndarray_dataset(h_dataset,b'ndarray',np.ndarray) == np_array_data + + # check that numpy array with dtype object representing list of various kinds + # of objects is converted to list before storing and reloaded proprly from this + # list representation + np_array_data = np.array([[NESTED_DICT], ('What is this?',), {1, 2, 3, 7, 1}]) + h_dataset,subitems = load_numpy.create_np_array_dataset(np_array_data,h5_data,"numpy_list_object_array") + ndarray_container = load_numpy.NDArrayLikeContainer(h_dataset.attrs,b'ndarray',np_array_data.__class__) + assert isinstance(h_dataset,h5.Group) and iter(subitems) + assert h_dataset.attrs["np_dtype"] == np_array_data.dtype.str.encode("ascii") + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == "data{:d}".format(index) and attrs.get("item_index",None) == index + assert isinstance(kwargs,dict) and np_array_data[index] == item + ndarray_container.append(name,item,attrs) + assert np.all(ndarray_container.convert() == np_array_data) + + # check that numpy array containing multiple strings of length > 1 + # is properly converted to list of strings and restored from its list + # representation + np_array_data = np.array(["1313e", "was", "maybe?", "here"]) + h_dataset,subitems = load_numpy.create_np_array_dataset(np_array_data,h5_data,"numpy_list_of_strings_array") + ndarray_container = load_numpy.NDArrayLikeContainer(h_dataset.attrs,b'ndarray',np_array_data.__class__) + assert isinstance(h_dataset,h5.Group) and iter(subitems) + assert h_dataset.attrs["np_dtype"] == np_array_data.dtype.str.encode("ascii") + for index,(name,item,attrs,kwargs) in enumerate(subitems): + assert name == "data{:d}".format(index) and attrs.get("item_index",None) == index + assert isinstance(kwargs,dict) and np_array_data[index] == item + ndarray_container.append(name,item,attrs) + assert np.all(ndarray_container.convert() == np_array_data) + + # check that numpy array with object dtype which is converted to single object + # by ndarray.tolist method is properly stored according to type of object and + # restored from this representation accordingly + np_array_data = np.array(NESTED_DICT) + h_dataset,subitems = load_numpy.create_np_array_dataset(np_array_data,h5_data,"numpy_object_array") + ndarray_container = load_numpy.NDArrayLikeContainer(h_dataset.attrs,b'ndarray',np_array_data.__class__) + ndarray_pickle_container = load_numpy.NDArrayLikeContainer(h_dataset.attrs,b'ndarray',np_array_data.__class__) + assert isinstance(h_dataset,h5.Group) and iter(subitems) + assert h_dataset.attrs["np_dtype"] == np_array_data.dtype.str.encode("ascii") + data_set = False + for name,item,attrs,kwargs in subitems: + if name == "data": + assert not data_set and not attrs and isinstance(kwargs,dict) + assert np_array_data[()] == item + data_set = True + ndarray_container.append(name,item,attrs) + attrs = dict(attrs) + attrs["base_type"] = b'pickle' + ndarray_pickle_container.append(name,item,attrs) + else: + raise AssertionError("expected single data object") + assert np.all(ndarray_container.convert() == np_array_data) + assert np.all(ndarray_pickle_container.convert() == np_array_data) + + # check that numpy.matrix type object is properly stored and reloaded from + # hickle file. 
+ # NOTE/TODO: current versions of numpy issue PendingDeprecationWarning when using + # numpy.matrix. In order to indicate to pytest that this is known and can safely + # be ignored the warning is captured here. Shall it be that future numpy verions + # convert PendingDeprecationWarning into any kind of exception like TypeError + # AttributeError, RuntimeError or alike that also capture these Exceptions not + # just PendingDeprecationWarning + with pytest.warns(PendingDeprecationWarning): + np_array_data = np.matrix([[1, 2], [3, 4]]) + h_dataset,subitems = load_numpy.create_np_array_dataset(np_array_data,h5_data,"numpy_matrix") + assert isinstance(h_dataset,h5.Dataset) and iter(subitems) and not subitems + assert np.all(h_dataset[()] == np_array_data) + assert h_dataset.attrs["np_dtype"] == np_array_data.dtype.str.encode("ascii") + np_loaded_array_data = load_numpy.load_ndarray_dataset(h_dataset,b'npmatrix',np.matrix) + assert np.all(np_loaded_array_data == np_array_data) + assert isinstance(np_loaded_array_data,np.matrix) + assert np_loaded_array_data.shape == np_array_data.shape + +def test_create_np_masked_array(h5_data): + """ + test proper creation and loading of numpy.masked arrays + """ + + # check that simple masked array is properly stored and loaded + masked_array = np.ma.array([1, 2, 3, 4], dtype='float32', mask=[0, 1, 0, 0]) + h_datagroup,subitems = load_numpy.create_np_masked_array_dataset(masked_array, h5_data, "masked_array") + masked_array_container = load_numpy.NDMaskedArrayContainer(h_datagroup.attrs,b'ndarray_masked',np.ma.array) + assert isinstance(h_datagroup,h5.Group) and iter(subitems) + assert h_datagroup.attrs["np_dtype"] == masked_array.dtype.str.encode("ascii") + data_set = mask_set = False + for name,item,attrs,kwargs in subitems: + assert isinstance(attrs,dict) and isinstance(kwargs,dict) + if name == "data": + assert not data_set and not attrs and np.all(masked_array.data == item) and item is not masked_array + masked_array_container.append(name,item,attrs) + data_set = True + elif name == "mask": + assert not mask_set and not attrs and np.all(masked_array.mask == item) and item is not masked_array + masked_array_container.append(name,item,attrs) + mask_set = True + else: + raise AssertionError("expected one data and one mask object") + assert np.all(masked_array_container.convert() == masked_array) + + # check that format used by hickle version 4.0.0 to encode is properly recognized + # on loading and masked array is restored accoringly + h_dataset = h5_data.create_dataset("masked_array_dataset",data = masked_array.data) + h_dataset.attrs["np_dtype"] = masked_array.dtype.str.encode("ascii") + with pytest.raises(ValueError,match = r"mask\s+not\s+found"): + loaded_masked_array = load_numpy.load_ndarray_masked_dataset(h_dataset,b'masked_array_data',np.ma.array) + h_mask_dataset = h5_data.create_dataset("masked_array_dataset_mask",data = masked_array.mask) + loaded_masked_array = load_numpy.load_ndarray_masked_dataset(h_dataset,b'masked_array_data',np.ma.array) + assert np.all(loaded_masked_array == masked_array ) + +# %% MAIN SCRIPT +if __name__ == "__main__": + from _pytest.fixtures import FixtureRequest + for h5_root in h5_data(FixtureRequest(test_create_np_scalar)): + test_create_np_scalar(h5_root) + for h5_root in h5_data(FixtureRequest(test_create_np_dtype)): + test_create_np_dtype(h5_root) + for h5_root in h5_data(FixtureRequest(test_create_np_ndarray)): + test_create_np_ndarray(h5_root) + for h5_root in 
h5_data(FixtureRequest(test_create_np_masked_array)): + test_create_np_masked_array(h5_root) + + diff --git a/hickle/tests/test_05_load_scipy.py b/hickle/tests/test_05_load_scipy.py new file mode 100644 index 00000000..ae3310cb --- /dev/null +++ b/hickle/tests/test_05_load_scipy.py @@ -0,0 +1,130 @@ +#! /usr/bin/env python +# encoding: utf-8 +""" +# test_load_scipy + +Unit tests for hickle module -- scipy loader. + +""" +# %% IMPORTS +# Package imports +import pytest +import h5py as h5 +import numpy as np +import dill as pickle +from scipy.sparse import csr_matrix, csc_matrix, bsr_matrix +from py.path import local + +# %% HICKLE imports +import hickle.loaders.load_scipy as load_scipy + +# Set the current working directory to the temporary directory +local.get_temproot().chdir() + +# %% FIXTURES + +@pytest.fixture +def h5_data(request): + """ + create dummy hdf5 test data file for testing PyContainer and H5NodeFilterProxy + """ + dummy_file = h5.File('test_load_builtins.hdf5','w') + dummy_file = h5.File('load_numpy_{}.hdf5'.format(request.function.__name__),'w') + filename = dummy_file.filename + test_data = dummy_file.create_group("root_group") + yield test_data + dummy_file.close() + +# %% FUNCTION DEFINITIONS + +def test_create_sparse_dataset(h5_data): + """ + test creation and loading of sparse matrix + """ + + # create all possible kinds of sparse matrix representations + row = np.array([0, 0, 1, 2, 2, 2]) + col = np.array([0, 2, 2, 0, 1, 2]) + data = np.array([1, 2, 3, 4, 5, 6]) + sm1 = csr_matrix((data, (row, col)), shape=(3, 3)) + sm2 = csc_matrix((data, (row, col)), shape=(3, 3)) + + indptr = np.array([0, 2, 3, 6]) + indices = np.array([0, 2, 2, 0, 1, 2]) + data = np.array([1, 2, 3, 4, 5, 6]).repeat(4).reshape([6, 2, 2]) + sm3 = bsr_matrix((data, indices, indptr), shape=(6, 6)) + + # check that csr type matrix is properly stored and loaded + h_datagroup,subitems = load_scipy.create_sparse_dataset(sm1,h5_data,"csr_matrix") + assert isinstance(h_datagroup,h5.Group) and iter(subitems) + seen_items = dict((key,False) for key in ("data",'indices','indptr','shape')) + sparse_container = load_scipy.SparseMatrixContainer(h_datagroup.attrs,b'csr_matrix',csr_matrix) + for name,item,attrs,kwargs in subitems: + assert not seen_items[name] + seen_items[name] = True + sparse_container.append(name,item,attrs) + reloaded = sparse_container.convert() + assert np.all(reloaded.data == sm1.data) and reloaded.dtype == sm1.dtype and reloaded.shape == sm1.shape + + # check that csc type matrix is properly stored and loaded + h_datagroup,subitems = load_scipy.create_sparse_dataset(sm2,h5_data,"csc_matrix") + assert isinstance(h_datagroup,h5.Group) and iter(subitems) + seen_items = dict((key,False) for key in ("data",'indices','indptr','shape')) + sparse_container = load_scipy.SparseMatrixContainer(h_datagroup.attrs,b'csc_matrix',csc_matrix) + for name,item,attrs,kwargs in subitems: + assert not seen_items[name] + seen_items[name] = True + sparse_container.append(name,item,attrs) + reloaded = sparse_container.convert() + assert np.all(reloaded.data == sm2.data) and reloaded.dtype == sm2.dtype and reloaded.shape == sm2.shape + + # check that bsr type matrix is properly stored and loaded + h_datagroup,subitems = load_scipy.create_sparse_dataset(sm3,h5_data,"bsr_matrix") + assert isinstance(h_datagroup,h5.Group) and iter(subitems) + seen_items = dict((key,False) for key in ("data",'indices','indptr','shape')) + sparse_container = load_scipy.SparseMatrixContainer(h_datagroup.attrs,b'bsr_matrix',bsr_matrix) + 
for name,item,attrs,kwargs in subitems: + assert not seen_items[name] + seen_items[name] = True + sparse_container.append(name,item,attrs) + reloaded = sparse_container.convert() + assert np.all(reloaded.data == sm3.data) and reloaded.dtype == sm3.dtype and reloaded.shape == sm3.shape + + # mimic hickle version 4.0.0 format to represent crs type matrix + h_datagroup,subitems = load_scipy.create_sparse_dataset(sm1,h5_data,"csr_matrix_filtered") + sparse_container = load_scipy.SparseMatrixContainer(h_datagroup.attrs,b'csr_matrix',load_scipy.return_first) + for name,item,attrs,kwargs in subitems: + h_dataset = h_datagroup.create_dataset(name,data=item) + if name == "data": + attrs["type"] = np.array(pickle.dumps(sm1.__class__)) + attrs["base_type"] = b'csr_matrix' + h_dataset.attrs.update(attrs) + + # check that dataset representin hickle 4.0.0 representaiton of sparse matrix + # is properly recognized by SparseMatrixContainer.filter method and sub items of + # sparse matrix group are properly adjusted to be safely loaded by SparseMatrixContainer + for name,h_dataset in sparse_container.filter(h_datagroup.items()): + if name == "shape": + sparse_container.append(name,tuple(h_dataset[()]),h_dataset.attrs) + else: + sparse_container.append(name,np.array(h_dataset[()]),h_dataset.attrs) + reloaded = sparse_container.convert() + assert np.all(reloaded.data == sm1.data) and reloaded.dtype == sm1.dtype and reloaded.shape == sm1.shape + + # verify that SparseMatrixContainer.filter method ignores any items which + # are not recognized by SparseMatrixContainer update or convert method + h_datagroup.create_dataset("ignoreme",data=12) + for name,h_dataset in sparse_container.filter(h_datagroup.items()): + if name == "shape": + sparse_container.append(name,tuple(h_dataset[()]),h_dataset.attrs) + else: + sparse_container.append(name,np.array(h_dataset[()]),h_dataset.attrs) + reloaded = sparse_container.convert() + assert np.all(reloaded.data == sm1.data) and reloaded.dtype == sm1.dtype and reloaded.shape == sm1.shape + + +# %% MAIN SCRIPT +if __name__ == "__main__": + from _pytest.fixtures import FixtureRequest + for h5_root in h5_data(FixtureRequest(test_create_sparse_dataset)): + test_create_sparse_dataset(h5_root) diff --git a/hickle/tests/test_06_load_astropy.py b/hickle/tests/test_06_load_astropy.py new file mode 100644 index 00000000..be7c5e66 --- /dev/null +++ b/hickle/tests/test_06_load_astropy.py @@ -0,0 +1,276 @@ +#! /usr/bin/env python +# encoding: utf-8 +""" +# test_load_astropy + +Unit tests for hickle module -- astropy loader. 
+ +""" +# %% IMPORTS +# Package imports +import h5py as h5 +import numpy as np +import pytest +from astropy.units import Quantity +from astropy.time import Time +from astropy.coordinates import Angle, SkyCoord +import astropy.constants as apc +from astropy.table import Table +import numpy as np +from py.path import local + +# hickle imports +import hickle.loaders.load_astropy as load_astropy + + +# Set the current working directory to the temporary directory +local.get_temproot().chdir() + +# %% FIXTURES + +@pytest.fixture +def h5_data(request): + """ + create dummy hdf5 test data file for testing PyContainer and H5NodeFilterProxy + """ + dummy_file = h5.File('test_load_builtins.hdf5','w') + dummy_file = h5.File('load_numpy_{}.hdf5'.format(request.function.__name__),'w') + filename = dummy_file.filename + test_data = dummy_file.create_group("root_group") + yield test_data + dummy_file.close() + +# %% FUNCTION DEFINITIONS +def test_create_astropy_quantity(h5_data): + """ + test proper storage and loading of astorpy quantities + """ + + for index,uu in enumerate(['m^3', 'm^3 / s', 'kg/pc']): + a = Quantity(7, unit=uu) + h_dataset,subitems = load_astropy.create_astropy_quantity(a,h5_data,"quantity{}".format(index)) + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset.attrs['unit'] == a.unit.to_string().encode("ascii") and h_dataset[()] == a.value + reloaded = load_astropy.load_astropy_quantity_dataset(h_dataset,b'astropy_quantity',Quantity) + assert reloaded == a and reloaded.unit == a.unit + a *= a + h_dataset,subitems = load_astropy.create_astropy_quantity(a,h5_data,"quantity_sqr{}".format(index)) + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset.attrs['unit'] == a.unit.to_string().encode("ascii") and h_dataset[()] == a.value + reloaded = load_astropy.load_astropy_quantity_dataset(h_dataset,b'astropy_quantity',Quantity) + assert reloaded == a and reloaded.unit == a.unit + + +def test_create_astropy_constant(h5_data): + + """ + test proper storage and loading of astropy constants + """ + + h_dataset,subitems = load_astropy.create_astropy_constant(apc.G,h5_data,"apc_G") + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset.attrs["unit"] == apc.G.unit.to_string().encode('ascii') + assert h_dataset.attrs["abbrev"] == apc.G.abbrev.encode('ascii') + assert h_dataset.attrs["name"] == apc.G.name.encode('ascii') + assert h_dataset.attrs["reference"] == apc.G.reference.encode('ascii') + assert h_dataset.attrs["uncertainty"] == apc.G.uncertainty + reloaded = load_astropy.load_astropy_constant_dataset(h_dataset,b'astropy_constant',apc.G.__class__) + assert reloaded == apc.G and reloaded.dtype == apc.G.dtype + + h_dataset,subitems = load_astropy.create_astropy_constant(apc.cgs.e,h5_data,"apc_cgs_e") + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset.attrs["unit"] == apc.cgs.e.unit.to_string().encode('ascii') + assert h_dataset.attrs["abbrev"] == apc.cgs.e.abbrev.encode('ascii') + assert h_dataset.attrs["name"] == apc.cgs.e.name.encode('ascii') + assert h_dataset.attrs["reference"] == apc.cgs.e.reference.encode('ascii') + assert h_dataset.attrs["uncertainty"] == apc.cgs.e.uncertainty + assert h_dataset.attrs["system"] == apc.cgs.e.system.encode('ascii') + reloaded = load_astropy.load_astropy_constant_dataset(h_dataset,b'astropy_constant',apc.cgs.e.__class__) + assert reloaded == apc.cgs.e and reloaded.dtype == apc.cgs.e.dtype + 
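For context, the loader-level checks above mirror the end-user round trip through hickle's public dump/load API. A minimal sketch (illustrative only; the file name is a hypothetical placeholder and the snippet is not part of this test module):

    import astropy.constants as apc
    from astropy.units import Quantity
    import hickle as hkl

    # dump an astropy Quantity and constant, then reload and compare
    q = Quantity(7, unit='m^3 / s')
    hkl.dump({'q': q, 'G': apc.G}, 'astropy_roundtrip.hkl', mode='w')   # hypothetical file name
    restored = hkl.load('astropy_roundtrip.hkl')
    assert restored['q'] == q and restored['q'].unit == q.unit
    assert restored['G'] == apc.G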
+ +def test_astropy_table(h5_data): + """ + test proper storage and loading of astropy table + """ + t = Table([[1, 2], [3, 4]], names=('a', 'b'), meta={'name': 'test_thing'}) + + h_dataset,subitems = load_astropy.create_astropy_table(t,h5_data,"astropy_table") + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert np.all(h_dataset.attrs['colnames'] == [ cname.encode('ascii') for cname in t.colnames]) + for metakey,metavalue in t.meta.items(): + assert h_dataset.attrs[metakey] == metavalue + assert h_dataset.dtype == t.as_array().dtype + reloaded = load_astropy.load_astropy_table(h_dataset,b'astropy_table',t.__class__) + assert reloaded.meta == t.meta and reloaded.dtype == t.dtype + assert np.allclose(t['a'].astype('float32'),reloaded['a'].astype('float32')) + assert np.allclose(t['b'].astype('float32'),reloaded['b'].astype('float32')) + + +def test_astropy_quantity_array(h5_data): + """ + tet proper storage and loading of array of astropy quantities + """ + a = Quantity([1, 2, 3], unit='m') + h_dataset,subitems = load_astropy.create_astropy_quantity(a,h5_data,"quantity_array") + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset.attrs['unit'] == a.unit.to_string().encode("ascii") and np.all(h_dataset[()] == a.value) + reloaded = load_astropy.load_astropy_quantity_dataset(h_dataset,b'astropy_quantity',Quantity) + assert np.all(reloaded == a) and reloaded.unit == a.unit + + +def test_astropy_time_array(h5_data): + """ + test proper storage and loading of astropy time representations + """ + + times = ['1999-01-01T00:00:00.123456789', '2010-01-01T00:00:00'] + t1 = Time(times, format='isot', scale='utc') + + h_dataset,subitems = load_astropy.create_astropy_time(t1,h5_data,'time1') + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset.attrs['format'] == str(t1.format).encode('ascii') + assert h_dataset.attrs['scale'] == str(t1.scale).encode('ascii') + assert h_dataset.attrs['np_dtype'] == t1.value.dtype.str.encode('ascii') + reloaded = load_astropy.load_astropy_time_dataset(h_dataset,b'astropy_time',t1.__class__) + assert reloaded.value.shape == t1.value.shape + assert reloaded.format == t1.format + assert reloaded.scale == t1.scale + for index in range(len(t1)): + assert reloaded.value[index] == t1.value[index] + del h_dataset.attrs['np_dtype'] + + reloaded = load_astropy.load_astropy_time_dataset(h_dataset,b'astropy_time',t1.__class__) + assert reloaded.value.shape == t1.value.shape + assert reloaded.format == t1.format + assert reloaded.scale == t1.scale + for index in range(len(t1)): + assert reloaded.value[index] == t1.value[index] + + times = [58264, 58265, 58266] + t1 = Time(times, format='mjd', scale='utc') + h_dataset,subitems = load_astropy.create_astropy_time(t1,h5_data,'time2') + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset.attrs['format'] == str(t1.format).encode('ascii') + assert h_dataset.attrs['scale'] == str(t1.scale).encode('ascii') + assert h_dataset.attrs['np_dtype'] == t1.value.dtype.str.encode('ascii') + reloaded = load_astropy.load_astropy_time_dataset(h_dataset,b'astropy_time',t1.__class__) + assert reloaded.value.shape == t1.value.shape + assert reloaded.format == t1.format + assert reloaded.scale == t1.scale + for index in range(len(t1)): + assert reloaded.value[index] == t1.value[index] + + +def test_astropy_angle(h5_data): + """ + test proper storage of astropy angles + """ + + for index,uu in 
enumerate(['radian', 'degree']): + a = Angle(1.02, unit=uu) + h_dataset,subitems = load_astropy.create_astropy_angle(a,h5_data,"angle_{}".format(uu)) + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset.attrs['unit'] == a.unit.to_string().encode('ascii') + assert h_dataset[()] == a.value + reloaded = load_astropy.load_astropy_angle_dataset(h_dataset,b'astropy_angle',a.__class__) + assert reloaded == a and reloaded.unit == a.unit + + +def test_astropy_angle_array(h5_data): + """ + test proper storage and loading of arrays of astropy angles + """ + a = Angle([1, 2, 3], unit='degree') + h_dataset,subitems = load_astropy.create_astropy_angle(a,h5_data,"angle_array") + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset.attrs['unit'] == a.unit.to_string().encode('ascii') + assert np.allclose(h_dataset[()] , a.value ) + reloaded = load_astropy.load_astropy_angle_dataset(h_dataset,b'astropy_angle',a.__class__) + assert np.all(reloaded == a) and reloaded.unit == a.unit + +def test_astropy_skycoord(h5_data): + """ + test proper storage and loading of astropy sky coordinates + """ + + ra = Angle('1d20m', unit='degree') + dec = Angle('33d0m0s', unit='degree') + radec = SkyCoord(ra, dec) + h_dataset,subitems = load_astropy.create_astropy_skycoord(radec,h5_data,"astropy_skycoord_1") + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset[()][...,0] == radec.data.lon.value + assert h_dataset[()][...,1] == radec.data.lat.value + assert h_dataset.attrs['lon_unit'] == radec.data.lon.unit.to_string().encode('ascii') + assert h_dataset.attrs['lat_unit'] == radec.data.lat.unit.to_string().encode('ascii') + reloaded = load_astropy.load_astropy_skycoord_dataset(h_dataset,b'astropy_skycoord',radec.__class__) + assert np.allclose(reloaded.ra.value,radec.ra.value) + assert np.allclose(reloaded.dec.value,radec.dec.value) + + ra = Angle('1d20m', unit='hourangle') + dec = Angle('33d0m0s', unit='degree') + radec = SkyCoord(ra, dec) + h_dataset,subitems = load_astropy.create_astropy_skycoord(radec,h5_data,"astropy_skycoord_2") + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert h_dataset[()][...,0] == radec.data.lon.value + assert h_dataset[()][...,1] == radec.data.lat.value + assert h_dataset.attrs['lon_unit'] == radec.data.lon.unit.to_string().encode('ascii') + assert h_dataset.attrs['lat_unit'] == radec.data.lat.unit.to_string().encode('ascii') + reloaded = load_astropy.load_astropy_skycoord_dataset(h_dataset,b'astropy_skycoord',radec.__class__) + assert reloaded.ra.value == radec.ra.value + assert reloaded.dec.value == radec.dec.value + +def test_astropy_skycoord_array(h5_data): + """ + test proper storage and loading of astropy sky coordinates + """ + + ra = Angle(['1d20m', '0d21m'], unit='degree') + dec = Angle(['33d0m0s', '-33d01m'], unit='degree') + radec = SkyCoord(ra, dec) + h_dataset,subitems = load_astropy.create_astropy_skycoord(radec,h5_data,"astropy_skycoord_1") + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert np.allclose(h_dataset[()][...,0],radec.data.lon.value) + assert np.allclose(h_dataset[()][...,1],radec.data.lat.value) + assert h_dataset.attrs['lon_unit'] == radec.data.lon.unit.to_string().encode('ascii') + assert h_dataset.attrs['lat_unit'] == radec.data.lat.unit.to_string().encode('ascii') + reloaded = load_astropy.load_astropy_skycoord_dataset(h_dataset,b'astropy_skycoord',radec.__class__) + 
assert np.allclose(reloaded.ra.value,radec.ra.value) + assert np.allclose(reloaded.dec.value,radec.dec.value) + + ra = Angle([['1d20m', '0d21m'], ['1d20m', '0d21m']], unit='hourangle') + dec = Angle([['33d0m0s', '33d01m'], ['33d0m0s', '33d01m']], unit='degree') + radec = SkyCoord(ra, dec) + h_dataset,subitems = load_astropy.create_astropy_skycoord(radec,h5_data,"astropy_skycoord_2") + assert isinstance(h_dataset,h5.Dataset) and not subitems and iter(subitems) + assert np.allclose(h_dataset[()][...,0],radec.data.lon.value) + assert np.allclose(h_dataset[()][...,1],radec.data.lat.value) + assert h_dataset.attrs['lon_unit'] == radec.data.lon.unit.to_string().encode('ascii') + assert h_dataset.attrs['lat_unit'] == radec.data.lat.unit.to_string().encode('ascii') + reloaded = load_astropy.load_astropy_skycoord_dataset(h_dataset,b'astropy_skycoord',radec.__class__) + assert np.allclose(reloaded.ra.value,radec.ra.value) + assert np.allclose(reloaded.dec.value,radec.dec.value) + assert reloaded.ra.shape == radec.ra.shape + assert reloaded.dec.shape == radec.dec.shape + +# %% MAIN SCRIPT +if __name__ == "__main__": + from _pytest.fixtures import FixtureRequest + for h5_root in h5_data(FixtureRequest(test_create_astropy_quantity)): + test_create_astropy_quantity(h5_root) + for h5_root in h5_data(FixtureRequest(test_create_astropy_constant)): + test_create_astropy_constant(h5_root) + for h5_root in h5_data(FixtureRequest(test_astropy_table)): + test_astropy_table(h5_root) + for h5_root in h5_data(FixtureRequest(test_astropy_quantity_array)): + test_astropy_quantity_array(h5_root) + for h5_root in h5_data(FixtureRequest(test_astropy_time_array)): + test_astropy_time_array(h5_root) + for h5_root in h5_data(FixtureRequest(test_astropy_angle)): + test_astropy_angle(h5_root) + for h5_root in h5_data(FixtureRequest(test_astropy_angle_array)): + test_astropy_angle_array(h5_root) + for h5_root in h5_data(FixtureRequest(test_astropy_skycoord)): + test_astropy_skycoord(h5_root) + for h5_root in h5_data(FixtureRequest(test_astropy_skycoord_array)): + test_astropy_skycoord_array(h5_root) diff --git a/hickle/tests/test_99_hickle_core.py b/hickle/tests/test_99_hickle_core.py new file mode 100644 index 00000000..0dd9dd10 --- /dev/null +++ b/hickle/tests/test_99_hickle_core.py @@ -0,0 +1,384 @@ +#! /usr/bin/env python +# encoding: utf-8 +""" +# test_hickle.py + +Unit tests for hickle module. 
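The core-level tests below exercise hickle.file_opener and the recursive _dump/_load machinery; at the user level this corresponds to dumping into and loading from an already open h5py.File under an explicit path. A minimal sketch, assuming a hypothetical scratch file name (illustrative only, not part of this test module):

    import h5py
    import numpy as np
    from hickle import dump, load

    a = np.arange(5)
    with h5py.File("scratch_core.hkl", "w") as f:   # hypothetical scratch file
        dump(a, f, path="/arr")                     # dump into the open file under /arr
    with h5py.File("scratch_core.hkl", "r") as f:
        assert np.allclose(load(f, "/arr"), a)      # load it back from the same path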
+ +""" + + +# %% IMPORTS +# Built-in imports +from collections import OrderedDict as odict +import os +import re +from pprint import pprint + + +# Package imports +import pytest +import dill as pickle +import h5py +import numpy as np +from py.path import local + +# hickle imports +from hickle import dump, helpers, hickle, load, lookup + +# Set current working directory to the temporary directory +local.get_temproot().chdir() + + +# %% GLOBALS + +# %% HELPER DEFINITIONS + +# %% FIXTURES + +@pytest.fixture +def h5_data(request): + """ + create dummy hdf5 test data file for testing PyContainer and H5NodeFilterProxy + """ + import h5py as h5 + dummy_file = h5.File('hickle_lookup_{}.hdf5'.format(request.function.__name__),'w') + filename = dummy_file.filename + test_data = dummy_file.create_group("root_group") + yield test_data + dummy_file.close() + +@pytest.fixture +def test_file_name(request): + yield "{}.hkl".format(request.function.__name__) + +# %% FUNCTION DEFINITIONS + +def test_file_opener(h5_data,test_file_name): + """ + test file opener function + """ + + # check that file like object is properly initialized for writing + filename = test_file_name.replace(".hkl","_{}.{}") + with open(filename.format("w",".hdf5"),"w") as f: + h5_file,path,close_flag = hickle.file_opener(f,"root","w") + assert isinstance(h5_file,h5py.File) and path == "/root" and h5_file.mode == 'r+' + h5_file.close() + + # check that file like object is properly initialized for reading + with open(filename.format("w",".hdf5"),"r") as f: + h5_file,path,close_flag = hickle.file_opener(f,"root","r") + assert isinstance(h5_file,h5py.File) and path == "/root" and h5_file.mode == 'r' + h5_file.close() + + # check that h5py.File object is properly intialized for writing + h5_file,path,close_flag = hickle.file_opener(h5_data,"","w") + assert isinstance(h5_file,h5py.File) and path == "/root_group" + assert h5_file.mode == 'r+' and not close_flag + + # check that a new file is created for provided filename and properly intialized + h5_file,path,close_flag = hickle.file_opener(filename.format("w",".hkl"),"root_group","w") + assert isinstance(h5_file,h5py.File) and path == "/root_group" + assert h5_file.mode == 'r+' and close_flag + h5_file.close() + + # check that any other object not beein a file like object, a h5py.File object or + # a filename string triggers an FileError exception + with pytest.raises( + hickle.FileError, + match = r"Cannot\s+open\s+file.\s+Please\s+pass\s+either\s+a\s+" + r"filename\s+string,\s+a\s+file\s+object,\s+or\s+a\s+h5py.File" + ): + h5_file,path,close_flag = hickle.file_opener(dict(),"root_group","w") + +def test_recursive_dump(h5_data): + """ + test _dump function and that it properly calls itself recursively + """ + + # check that dump function properly creates a list dataset and + # sets appropriate values for 'type' and 'base_type' attributes + data = simple_list = [1,2,3,4] + hickle._dump(data, h5_data, "simple_list") + dumped_data = h5_data["simple_list"] + assert dumped_data.attrs['type'] == pickle.dumps(data.__class__) + assert dumped_data.attrs['base_type'] == b'list' + assert np.all(dumped_data[()] == simple_list) + + # check that dump function properly creats a group representing + # a dictionary and its keys and values and sets appropriate values + # for 'type', 'base_type' and 'key_base_type' attributes + data = { + '12':12, + (1,2,3):'hallo' + } + hickle._dump(data, h5_data, "some_dict") + dumped_data = h5_data["some_dict"] + assert dumped_data.attrs['type'] == 
pickle.dumps(data.__class__) + + # check that the name of the resulting dataset for the first dict item + # resembles the double quoted string key and that the 'type', 'base_type' and + # 'key_base_type' attributes of the resulting dataset are set accordingly + assert dumped_data.attrs['base_type'] == b'dict' + first_item = dumped_data['"12"'] + assert first_item[()] == 12 and first_item.attrs['key_base_type'] == b'str' + assert first_item.attrs['base_type'] == b'int' + assert first_item.attrs['type'] == pickle.dumps(data['12'].__class__) + + # check that the second item is converted into a key value pair group, that + # the name of that group reads 'data0' and that 'type', 'base_type' and + # 'key_base_type' attributes are set accordingly + second_item = dumped_data.get("data0",None) + if second_item is None: + second_item = dumped_data["data1"] + assert second_item.attrs['key_base_type'] == b'key_value' + assert second_item.attrs['base_type'] == b'tuple' + assert second_item.attrs['type'] == pickle.dumps(tuple) + + # check that the content of the key value pair group resembles key and value of + # the second dict item + key = second_item['data0'] + value = second_item['data1'] + assert np.all(key[()] == (1,2,3)) and key.attrs['base_type'] == b'tuple' + assert key.attrs['type'] == pickle.dumps(tuple) + assert bytes(value[()]) == 'hallo'.encode('utf8') and value.attrs['base_type'] == b'str' + assert value.attrs['type'] == pickle.dumps(str) + + # check that objects for which no loader has been registered or for which the + # available loader raises NotHicklable exception are handled by the + # create_pickled_dataset function + backup_dict_loader = lookup.types_dict[dict] + def fail_create_dict(py_obj,h_group,name,**kwargs): + raise helpers.NotHicklable("test loader shrugg") + lookup.types_dict[dict] = fail_create_dict,backup_dict_loader[1] + hickle._dump(data, h5_data, "pickled_dict") + dumped_data = h5_data["pickled_dict"] + lookup.types_dict[dict] = backup_dict_loader + assert set(key for key in dumped_data.keys()) == {'create','args','keys','values'} + +def test_recursive_load(h5_data): + """ + test _load function and that it properly calls itself recursively + """ + + # check that simple scalar value is properly restored on load from the + # corresponding dataset + data = 42 + data_name = "the_answer" + hickle._dump(data, h5_data, data_name) + py_container = hickle.RootContainer(h5_data.attrs,b'hickle_root',hickle.RootContainer) + hickle._load(py_container, data_name, h5_data[data_name]) + assert py_container.convert() == data + + # check that dict object is properly restored on load from the corresponding group + data = {'question':None,'answer':42} + data_name = "not_formulated" + hickle._dump(data, h5_data, data_name) + py_container = hickle.RootContainer(h5_data.attrs,b'hickle_root',hickle.RootContainer) + hickle._load(py_container, data_name, h5_data[data_name]) + assert py_container.convert() == data + + + # check that objects for which no loader has been registered or for which the + # available loader raises NotHicklable exception are properly restored on load + # from the corresponding copy protocol group or pickled data string + backup_dict_loader = lookup.types_dict[dict] + def fail_create_dict(py_obj,h_group,name,**kwargs): + raise helpers.NotHicklable("test loader shrugg") + lookup.types_dict[dict] = fail_create_dict,backup_dict_loader[1] + data_name = "pickled_dict" + hickle._dump(data, h5_data, data_name) + hickle._load(py_container, data_name, h5_data[data_name]) + assert py_container.convert() == data + lookup.types_dict[dict] =
backup_dict_loader + +# %% ISSUE RELATED TESTS + +def test_invalid_file(): + """ Test if trying to use a non-file object fails. """ + + with pytest.raises(hickle.FileError): + dump('test', ()) + + +def test_binary_file(test_file_name): + """ Test if using a binary file works + + https://github.com/telegraphic/hickle/issues/123""" + + filename = test_file_name.replace(".hkl",".hdf5") + with open(filename, "w") as f: + hickle.dump(None, f) + + with open(filename, "wb") as f: + hickle.dump(None, f) + + +def test_file_open_close(test_file_name,h5_data): + """ https://github.com/telegraphic/hickle/issues/20 """ + import h5py + f = h5py.File(test_file_name.replace(".hkl",".hdf"), 'w') + a = np.arange(5) + + dump(a, test_file_name) + dump(a, test_file_name) + + dump(a, f, mode='w') + f.close() + with pytest.raises(hickle.ClosedFileError): + dump(a, f, mode='w') + h5_data.create_dataset('nothing',data=[]) + with pytest.raises(ValueError,match = r"Unable\s+to\s+create\s+group\s+\(name\s+already\s+exists\)"): + dump(a,h5_data.file,path="/root_group") + + +def test_hdf5_group(test_file_name): + import h5py + hdf5_filename = test_file_name.replace(".hkl",".hdf5") + file = h5py.File(hdf5_filename, 'w') + group = file.create_group('test_group') + a = np.arange(5) + dump(a, group) + file.close() + + a_hkl = load(hdf5_filename, path='/test_group') + assert np.allclose(a_hkl, a) + + file = h5py.File(hdf5_filename, 'r+') + group = file.create_group('test_group2') + b = np.arange(8) + + dump(b, group, path='deeper/and_deeper') + file.close() + + b_hkl = load(hdf5_filename, path='/test_group2/deeper/and_deeper') + assert np.allclose(b_hkl, b) + + file = h5py.File(hdf5_filename, 'r') + b_hkl2 = load(file['test_group2'], path='deeper/and_deeper') + assert np.allclose(b_hkl2, b) + file.close() + + + +def test_with_open_file(test_file_name): + """ + Testing dumping and loading to an open file + + https://github.com/telegraphic/hickle/issues/92""" + + lst = [1] + tpl = (1,) + dct = {1: 1} + arr = np.array([1]) + + with h5py.File(test_file_name, 'w') as file: + dump(lst, file, path='/lst') + dump(tpl, file, path='/tpl') + dump(dct, file, path='/dct') + dump(arr, file, path='/arr') + + with h5py.File(test_file_name, 'r') as file: + assert load(file, '/lst') == lst + assert load(file, '/tpl') == tpl + assert load(file, '/dct') == dct + assert load(file, '/arr') == arr + + +def test_load(test_file_name): + a = set([1, 2, 3, 4]) + b = set([5, 6, 7, 8]) + c = set([9, 10, 11, 12]) + z = (a, b, c) + z = [z, z] + z = (z, z, z, z, z) + + print("Original:") + pprint(z) + dump(z, test_file_name, mode='w') + + print("\nReconstructed:") + z = load(test_file_name) + pprint(z) + + + + +def test_multi_hickle(test_file_name): + """ Dumping to and loading from the same file several times + + https://github.com/telegraphic/hickle/issues/20""" + + a = {'a': 123, 'b': [1, 2, 4]} + + if os.path.exists(test_file_name): + os.remove(test_file_name) + dump(a, test_file_name, path="/test", mode="w") + dump(a, test_file_name, path="/test2", mode="r+") + dump(a, test_file_name, path="/test3", mode="r+") + dump(a, test_file_name, path="/test4", mode="r+") + + load(test_file_name, path="/test") + load(test_file_name, path="/test2") + load(test_file_name, path="/test3") + load(test_file_name, path="/test4") + + +def test_improper_attrs(test_file_name): + """ + test for proper reporting missing mandatory attributes for the various + supported file versions + """ + + # check that missing attributes which disallow to identify + # hickle version are 
reported + data = "my name? Ha I'm Nobody" + dump(data,test_file_name) + manipulated = h5py.File(test_file_name,"r+") + root_group = manipulated.get('/') + root_group.attrs["VERSION"] = root_group.attrs["HICKLE_VERSION"] + del root_group.attrs["HICKLE_VERSION"] + manipulated.flush() + with pytest.raises( + ValueError, + match= r"Provided\s+argument\s+'file_obj'\s+does\s+not\s+appear" + r"\s+to\s+be\s+a\s+valid\s+hickle\s+file!.*" + ): + load(manipulated) + + +# %% MAIN SCRIPT +if __name__ == '__main__': + """ Some tests and examples """ + from _pytest.fixtures import FixtureRequest + + for h5_root,filename in ( + ( h5_data(request),test_file_name(request) ) + for request in (FixtureRequest(test_file_opener),) + ): + test_file_opener(h5_root,filename) + for h5_root in h5_data(FixtureRequest(test_recursive_dump)): + test_recursive_dump(h5_root) + for h5_root in h5_data(FixtureRequest(test_recursive_load)): + test_recursive_load(h5_root) + test_invalid_file() + for filename in test_file_name(FixtureRequest(test_binary_file)): + test_binary_file(filename) + for h5_root,filename in ( + ( h5_data(request),test_file_name(request) ) + for request in (FixtureRequest(test_file_open_close),) + ): + test_file_open_close(h5_root,filename) + for filename in test_file_name(FixtureRequest(test_hdf5_group)): + test_hdf5_group(filename) + for filename in test_file_name(FixtureRequest(test_with_open_file)): + test_with_open_file(filename) + + for filename in test_file_name(FixtureRequest(test_load)): + test_load(filename) + for filename in test_file_name(FixtureRequest(test_multi_hickle)): + test_multi_hickle(filename) + for filename in test_file_name(FixtureRequest(test_improper_attrs)): + test_improper_attrs(filename) + diff --git a/hickle/tests/test_astropy.py b/hickle/tests/test_astropy.py deleted file mode 100644 index 4c2caf96..00000000 --- a/hickle/tests/test_astropy.py +++ /dev/null @@ -1,172 +0,0 @@ -# %% IMPORTS -# Package imports -from astropy.units import Quantity -from astropy.time import Time -from astropy.coordinates import Angle, SkyCoord -import astropy.constants as apc -from astropy.table import Table -import numpy as np -from py.path import local - -# hickle imports -import hickle as hkl - -# Set the current working directory to the temporary directory -local.get_temproot().chdir() - - -# %% FUNCTION DEFINITIONS -def test_astropy_quantity(): - for uu in ['m^3', 'm^3 / s', 'kg/pc']: - a = Quantity(7, unit=uu) - - hkl.dump(a, "test_ap.h5") - b = hkl.load("test_ap.h5") - - assert a == b - assert a.unit == b.unit - - a *= a - hkl.dump(a, "test_ap.h5") - b = hkl.load("test_ap.h5") - assert a == b - assert a.unit == b.unit - - -def test_astropy_constant(): - hkl.dump(apc.G, "test_ap.h5") - gg = hkl.load("test_ap.h5") - assert gg == apc.G - - hkl.dump(apc.cgs.e, 'test_ap.h5') - ee = hkl.load('test_ap.h5') - assert ee == apc.cgs.e - - -def test_astropy_table(): - t = Table([[1, 2], [3, 4]], names=('a', 'b'), meta={'name': 'test_thing'}) - - hkl.dump({'a': t}, "test_ap.h5") - t2 = hkl.load("test_ap.h5")['a'] - - print(t) - print(t.meta) - print(t2) - print(t2.meta) - - print(t.dtype, t2.dtype) - assert t.meta == t2.meta - assert t.dtype == t2.dtype - - assert np.allclose(t['a'].astype('float32'), t2['a'].astype('float32')) - assert np.allclose(t['b'].astype('float32'), t2['b'].astype('float32')) - - -def test_astropy_quantity_array(): - a = Quantity([1, 2, 3], unit='m') - - hkl.dump(a, "test_ap.h5") - b = hkl.load("test_ap.h5") - - assert np.allclose(a.value, b.value) - assert a.unit == b.unit - - 
-def test_astropy_time_array(): - times = ['1999-01-01T00:00:00.123456789', '2010-01-01T00:00:00'] - t1 = Time(times, format='isot', scale='utc') - hkl.dump(t1, "test_ap2.h5") - t2 = hkl.load("test_ap2.h5") - - print(t1) - print(t2) - assert t1.value.shape == t2.value.shape - for ii in range(len(t1)): - assert t1.value[ii] == t2.value[ii] - assert t1.format == t2.format - assert t1.scale == t2.scale - - times = [58264, 58265, 58266] - t1 = Time(times, format='mjd', scale='utc') - hkl.dump(t1, "test_ap2.h5") - t2 = hkl.load("test_ap2.h5") - - print(t1) - print(t2) - assert t1.value.shape == t2.value.shape - assert np.allclose(t1.value, t2.value) - assert t1.format == t2.format - assert t1.scale == t2.scale - - -def test_astropy_angle(): - for uu in ['radian', 'degree']: - a = Angle(1.02, unit=uu) - - hkl.dump(a, "test_ap.h5") - b = hkl.load("test_ap.h5") - assert a == b - assert a.unit == b.unit - - -def test_astropy_angle_array(): - a = Angle([1, 2, 3], unit='degree') - - hkl.dump(a, "test_ap.h5") - b = hkl.load("test_ap.h5") - - assert np.allclose(a.value, b.value) - assert a.unit == b.unit - - -def test_astropy_skycoord(): - ra = Angle('1d20m', unit='degree') - dec = Angle('33d0m0s', unit='degree') - radec = SkyCoord(ra, dec) - hkl.dump(radec, "test_ap.h5") - radec2 = hkl.load("test_ap.h5") - assert radec.ra == radec2.ra - assert radec.dec == radec2.dec - - ra = Angle('1d20m', unit='hourangle') - dec = Angle('33d0m0s', unit='degree') - radec = SkyCoord(ra, dec) - hkl.dump(radec, "test_ap.h5") - radec2 = hkl.load("test_ap.h5") - assert radec.ra == radec2.ra - assert radec.dec == radec2.dec - - -def test_astropy_skycoord_array(): - ra = Angle(['1d20m', '0d21m'], unit='degree') - dec = Angle(['33d0m0s', '-33d01m'], unit='degree') - radec = SkyCoord(ra, dec) - hkl.dump(radec, "test_ap.h5") - radec2 = hkl.load("test_ap.h5") - assert np.allclose(radec.ra.value, radec2.ra.value) - assert np.allclose(radec.dec.value, radec2.dec.value) - assert radec.ra.shape == radec2.ra.shape - assert radec.dec.shape == radec2.dec.shape - - ra = Angle([['1d20m', '0d21m'], ['1d20m', '0d21m']], unit='hourangle') - dec = Angle([['33d0m0s', '33d01m'], ['33d0m0s', '33d01m']], unit='degree') - radec = SkyCoord(ra, dec) - hkl.dump(radec, "test_ap.h5") - radec2 = hkl.load("test_ap.h5") - assert np.allclose(radec.ra.value, radec2.ra.value) - assert np.allclose(radec.dec.value, radec2.dec.value) - assert radec.ra.shape == radec2.ra.shape - assert radec.dec.shape == radec2.dec.shape - - -# %% MAIN SCRIPT -if __name__ == "__main__": - test_astropy_quantity() - test_astropy_constant() - test_astropy_table() - test_astropy_quantity_array() - test_astropy_time_array() - test_astropy_angle() - test_astropy_angle_array() - test_astropy_skycoord() - test_astropy_skycoord_array() diff --git a/hickle/tests/test_hickle.py b/hickle/tests/test_hickle.py index 5ecfb3db..e0fd3f99 100644 --- a/hickle/tests/test_hickle.py +++ b/hickle/tests/test_hickle.py @@ -3,25 +3,28 @@ """ # test_hickle.py -Unit tests for hickle module. +Unit test for hickle package. 
""" # %% IMPORTS + # Built-in imports from collections import OrderedDict as odict import os +import re from pprint import pprint +import dill as pickle + # Package imports -import h5py import numpy as np from py.path import local import pytest # hickle imports -from hickle import dump, helpers, hickle, load, loaders +from hickle import dump, hickle, load # Set current working directory to the temporary directory local.get_temproot().chdir() @@ -47,14 +50,52 @@ } +# %% FIXTURES + +@pytest.fixture +def test_file_name(request): + """ + create test dependent filename path string + """ + yield "{}.hkl".format(request.function.__name__) + + # %% HELPER DEFINITIONS + # Define a test function that must be serialized and unpacked again def func(a, b, c=0): + """ just somethin to do """ return(a, b, c) +# the following is required as package name of with_state is hickle +# and load_loader refuses load any loader module for classes defined inside +# hickle package exempt when defined within load_*.py loaders modules. +# That has to be done by hickle sub modules directly using register_class function +pickle_dumps = pickle.dumps +pickle_loads = pickle.loads + +def make_visible_to_dumps(obj,protocol=None,*,fix_imports=True): + """ + simulate loader functions defined outside hickle package + """ + if obj in {with_state}: + obj.__module__ = re.sub(r'^\s*(?!hickle\.)','hickle.',obj.__module__) + return pickle_dumps(obj,protocol,fix_imports=fix_imports) + +def hide_from_hickle(bytes_obj,*,fix_imports=True,encoding="ASCII",errors="strict"): + """ + simulat loader function defined outside hickle package + """ + obj = pickle_loads(bytes_obj,fix_imports = fix_imports, encoding = encoding, errors = errors) + if obj in {with_state}: + obj.__module__ = re.sub(r'^\s*hickle\.','',obj.__module__) + return obj # Define a class that must always be pickled class with_state(object): + """ + A class that allways must be handled by create_pickled_dataset + """ def __init__(self): self.a = 12 self.b = { @@ -89,82 +130,75 @@ def test_invalid_file(): dump('test', ()) -def test_state_obj(): +def test_state_obj(monkeypatch,test_file_name): """ Dumping and loading a class object with pickle states https://github.com/telegraphic/hickle/issues/125""" - filename, mode = 'test.h5', 'w' - obj = with_state() - with pytest.warns(loaders.load_builtins.SerializedWarning): - dump(obj, filename, mode) - obj_hkl = load(filename) - assert type(obj) == type(obj_hkl) - assert np.allclose(obj[1], obj_hkl[1]) + + with monkeypatch.context() as monkey: + monkey.setattr(with_state,'__module__',re.sub(r'^\s*hickle\.','',with_state.__module__)) + monkey.setattr(pickle,'dumps',make_visible_to_dumps) + mode = 'w' + obj = with_state() + #with pytest.warns(SerializedWarning): + dump(obj, test_file_name, mode) + monkey.setattr(pickle,'loads',hide_from_hickle) + obj_hkl = load(test_file_name) + assert isinstance(obj,obj_hkl.__class__) or isinstance(obj_hkl,obj.__class__) + assert np.allclose(obj[1], obj_hkl[1]) -def test_local_func(): +def test_local_func(test_file_name): """ Dumping and loading a local function https://github.com/telegraphic/hickle/issues/119""" - filename, mode = 'test.h5', 'w' - with pytest.warns(loaders.load_builtins.SerializedWarning): - dump(func, filename, mode) - func_hkl = load(filename) - assert type(func) == type(func_hkl) - assert func(1, 2) == func_hkl(1, 2) - - -def test_binary_file(): - """ Test if using a binary file works - - https://github.com/telegraphic/hickle/issues/123""" - with open("test.hdf5", "w") as f: - 
hickle.dump(None, f) - - with open("test.hdf5", "wb") as f: - hickle.dump(None, f) + mode = 'w' + dump(func, test_file_name, mode) + func_hkl = load(test_file_name) + assert isinstance(func,func_hkl.__class__) or isinstance(func_hkl,func.__class__) + assert func(1, 2) == func_hkl(1, 2) -def test_non_empty_group(): +def test_non_empty_group(test_file_name): """ Test if attempting to dump to a group with data fails """ - hickle.dump(None, 'test.hdf5') + hickle.dump(None, test_file_name) with pytest.raises(ValueError): - dump(None, 'test.hdf5', 'r+') + dump(None, test_file_name, 'r+') -def test_string(): +def test_string(test_file_name): """ Dumping and loading a string """ - filename, mode = 'test.h5', 'w' + mode = 'w' string_obj = "The quick brown fox jumps over the lazy dog" - dump(string_obj, filename, mode) - string_hkl = load(filename) + dump(string_obj, test_file_name, mode) + string_hkl = load(test_file_name) assert isinstance(string_hkl, str) assert string_obj == string_hkl -def test_65bit_int(): +def test_65bit_int(test_file_name): """ Dumping and loading an integer with arbitrary precision https://github.com/telegraphic/hickle/issues/113""" i = 2**65-1 - dump(i, 'test.hdf5') - i_hkl = load('test.hdf5') + dump(i, test_file_name) + i_hkl = load(test_file_name) assert i == i_hkl -def test_list(): +def test_list(test_file_name): """ Dumping and loading a list """ filename, mode = 'test_list.h5', 'w' list_obj = [1, 2, 3, 4, 5] - dump(list_obj, filename, mode=mode) - list_hkl = load(filename) + dump(list_obj, test_file_name, mode=mode) + list_hkl = load(test_file_name) try: assert isinstance(list_hkl, list) assert list_obj == list_hkl import h5py - a = h5py.File(filename, 'r') + a = h5py.File(test_file_name, 'r') a.close() except AssertionError: @@ -174,12 +208,12 @@ def test_list(): raise -def test_set(): +def test_set(test_file_name) : """ Dumping and loading a list """ - filename, mode = 'test_set.h5', 'w' + mode = 'w' list_obj = set([1, 0, 3, 4.5, 11.2]) - dump(list_obj, filename, mode) - list_hkl = load(filename) + dump(list_obj, test_file_name, mode) + list_hkl = load(test_file_name) try: assert isinstance(list_hkl, set) assert list_obj == list_hkl @@ -189,15 +223,15 @@ def test_set(): raise -def test_numpy(): +def test_numpy(test_file_name): """ Dumping and loading numpy array """ - filename, mode = 'test.h5', 'w' + mode = 'w' dtypes = ['float32', 'float64', 'complex64', 'complex128'] for dt in dtypes: array_obj = np.ones(8, dtype=dt) - dump(array_obj, filename, mode) - array_hkl = load(filename) + dump(array_obj, test_file_name, mode) + array_hkl = load(test_file_name) try: assert array_hkl.dtype == array_obj.dtype assert np.all((array_hkl, array_obj)) @@ -207,13 +241,13 @@ def test_numpy(): raise -def test_masked(): +def test_masked(test_file_name): """ Test masked numpy array """ - filename, mode = 'test.h5', 'w' + mode = 'w' a = np.ma.array([1, 2, 3, 4], dtype='float32', mask=[0, 1, 0, 0]) - dump(a, filename, mode) - a_hkl = load(filename) + dump(a, test_file_name, mode) + a_hkl = load(test_file_name) try: assert a_hkl.dtype == a.dtype @@ -224,47 +258,47 @@ def test_masked(): raise -def test_object_numpy(): +def test_object_numpy(test_file_name): """ Dumping and loading a NumPy array containing non-NumPy objects. 
https://github.com/telegraphic/hickle/issues/90""" arr = np.array([[NESTED_DICT], ('What is this?',), {1, 2, 3, 7, 1}]) - dump(arr, 'test.hdf5') - arr_hkl = load('test.hdf5') + dump(arr, test_file_name) + arr_hkl = load(test_file_name) assert np.all(arr == arr_hkl) arr2 = np.array(NESTED_DICT) - dump(arr2, 'test.hdf5') - arr_hkl2 = load('test.hdf5') + dump(arr2, test_file_name) + arr_hkl2 = load(test_file_name) assert np.all(arr2 == arr_hkl2) -def test_string_numpy(): +def test_string_numpy(test_file_name): """ Dumping and loading NumPy arrays containing Python 3 strings. """ arr = np.array(["1313e", "was", "maybe?", "here"]) - dump(arr, 'test.hdf5') - arr_hkl = load('test.hdf5') + dump(arr, test_file_name) + arr_hkl = load(test_file_name) assert np.all(arr == arr_hkl) -def test_list_object_numpy(): +def test_list_object_numpy(test_file_name): """ Dumping and loading a list of NumPy arrays with objects. https://github.com/telegraphic/hickle/issues/90""" lst = [np.array(NESTED_DICT), np.array([('What is this?',), {1, 2, 3, 7, 1}])] - dump(lst, 'test.hdf5') - lst_hkl = load('test.hdf5') + dump(lst, test_file_name) + lst_hkl = load(test_file_name) assert np.all(lst[0] == lst_hkl[0]) assert np.all(lst[1] == lst_hkl[1]) -def test_dict(): +def test_dict(test_file_name): """ Test dictionary dumping and loading """ - filename, mode = 'test.h5', 'w' + mode = 'w' dd = { 'name': b'Danny', @@ -275,8 +309,8 @@ def test_dict(): 'narr': np.array([1, 2, 3]), } - dump(dd, filename, mode) - dd_hkl = load(filename) + dump(dd, test_file_name, mode) + dd_hkl = load(test_file_name) for k in dd.keys(): try: @@ -295,15 +329,15 @@ def test_dict(): raise -def test_odict(): +def test_odict(test_file_name): """ Test ordered dictionary dumping and loading https://github.com/telegraphic/hickle/issues/65""" - filename, mode = 'test.hdf5', 'w' + mode = 'w' od = odict(((3, [3, 0.1]), (7, [5, 0.1]), (5, [3, 0.1]))) - dump(od, filename, mode) - od_hkl = load(filename) + dump(od, test_file_name, mode) + od_hkl = load(test_file_name) assert od.keys() == od_hkl.keys() @@ -311,20 +345,20 @@ def test_odict(): assert od_item == od_hkl_item -def test_empty_dict(): +def test_empty_dict(test_file_name): """ Test empty dictionary dumping and loading https://github.com/telegraphic/hickle/issues/91""" - filename, mode = 'test.h5', 'w' + mode = 'w' - dump({}, filename, mode) - assert load(filename) == {} + dump({}, test_file_name, mode) + assert load(test_file_name) == {} -def test_compression(): +def test_compression(test_file_name): """ Test compression on datasets""" - filename, mode = 'test.h5', 'w' + mode = 'w' dtypes = ['int32', 'float32', 'float64', 'complex64', 'complex128'] comps = [None, 'gzip', 'lzf'] @@ -332,9 +366,9 @@ def test_compression(): for dt in dtypes: for cc in comps: array_obj = np.ones(32768, dtype=dt) - dump(array_obj, filename, mode, compression=cc) - print(cc, os.path.getsize(filename)) - array_hkl = load(filename) + dump(array_obj, test_file_name, mode, compression=cc) + print(cc, os.path.getsize(test_file_name)) + array_hkl = load(test_file_name) try: assert array_hkl.dtype == array_obj.dtype assert np.all((array_hkl, array_obj)) @@ -344,34 +378,34 @@ def test_compression(): raise -def test_dict_int_key(): +def test_dict_int_key(test_file_name): """ Test for dictionaries with integer keys """ - filename, mode = 'test.h5', 'w' + mode = 'w' dd = { 0: "test", 1: "test2" } - dump(dd, filename, mode) - load(filename) + dump(dd, test_file_name, mode) + load(test_file_name) -def test_dict_nested(): +def 
test_dict_nested(test_file_name): """ Test for dictionaries with integer keys """ - filename, mode = 'test.h5', 'w' + mode = 'w' dd = NESTED_DICT - dump(dd, filename, mode) - dd_hkl = load(filename) + dump(dd, test_file_name, mode) + dd_hkl = load(test_file_name) ll_hkl = dd_hkl["level1_3"]["level2_1"]["level3_1"] ll = dd["level1_3"]["level2_1"]["level3_1"] assert ll == ll_hkl -def test_masked_dict(): +def test_masked_dict(test_file_name): """ Test dictionaries with masked arrays """ filename, mode = 'test.h5', 'w' @@ -381,8 +415,8 @@ def test_masked_dict(): "data2": np.array([1, 2, 3, 4, 5]) } - dump(dd, filename, mode) - dd_hkl = load(filename) + dump(dd, test_file_name, mode) + dd_hkl = load(test_file_name) for k in dd.keys(): try: @@ -405,9 +439,9 @@ def test_masked_dict(): raise -def test_np_float(): +def test_np_float(test_file_name): """ Test for singular np dtypes """ - filename, mode = 'np_float.h5', 'w' + mode = 'w' dtype_list = (np.float16, np.float32, np.float64, np.complex64, np.complex128, @@ -417,26 +451,26 @@ def test_np_float(): for dt in dtype_list: dd = dt(1) - dump(dd, filename, mode) - dd_hkl = load(filename) + dump(dd, test_file_name, mode) + dd_hkl = load(test_file_name) assert dd == dd_hkl assert dd.dtype == dd_hkl.dtype dd = {} for dt in dtype_list: dd[str(dt)] = dt(1.0) - dump(dd, filename, mode) - dd_hkl = load(filename) + dump(dd, test_file_name, mode) + dd_hkl = load(test_file_name) print(dd) for dt in dtype_list: assert dd[str(dt)] == dd_hkl[str(dt)] -def test_comp_kwargs(): +def test_comp_kwargs(test_file_name): """ Test compression with some kwargs for shuffle and chunking """ - filename, mode = 'test.h5', 'w' + mode = 'w' dtypes = ['int32', 'float32', 'float64', 'complex64', 'complex128'] comps = [None, 'gzip', 'lzf'] @@ -457,22 +491,22 @@ def test_comp_kwargs(): 'scaleoffset': so } array_obj = NESTED_DICT - dump(array_obj, filename, mode, compression=cc) - print(kwargs, os.path.getsize(filename)) - load(filename) + dump(array_obj, test_file_name, mode, compression=cc) + print(kwargs, os.path.getsize(test_file_name)) + load(test_file_name) -def test_list_numpy(): +def test_list_numpy(test_file_name): """ Test converting a list of numpy arrays """ - filename, mode = 'test.h5', 'w' + mode = 'w' a = np.ones(1024) b = np.zeros(1000) c = [a, b] - dump(c, filename, mode) - dd_hkl = load(filename) + dump(c, test_file_name, mode) + dd_hkl = load(test_file_name) print(dd_hkl) @@ -480,17 +514,17 @@ def test_list_numpy(): assert isinstance(dd_hkl[0], np.ndarray) -def test_tuple_numpy(): +def test_tuple_numpy(test_file_name): """ Test converting a list of numpy arrays """ - filename, mode = 'test.h5', 'w' + mode = 'w' a = np.ones(1024) b = np.zeros(1000) c = (a, b, a) - dump(c, filename, mode) - dd_hkl = load(filename) + dump(c, test_file_name, mode) + dd_hkl = load(test_file_name) print(dd_hkl) @@ -498,79 +532,35 @@ def test_tuple_numpy(): assert isinstance(dd_hkl[0], np.ndarray) -def test_numpy_dtype(): +def test_numpy_dtype(test_file_name): """ Dumping and loading a NumPy dtype """ dtype = np.dtype('float16') - dump(dtype, 'test.hdf5') - dtype_hkl = load('test.hdf5') + dump(dtype, test_file_name) + dtype_hkl = load(test_file_name) assert dtype == dtype_hkl -def test_none(): +def test_none(test_file_name): """ Test None type hickling """ - filename, mode = 'test.h5', 'w' + mode = 'w' a = None - dump(a, filename, mode) - dd_hkl = load(filename) + dump(a, test_file_name, mode) + dd_hkl = load(test_file_name) print(a) print(dd_hkl) assert isinstance(dd_hkl, type(None)) 
-def test_file_open_close(): - """ https://github.com/telegraphic/hickle/issues/20 """ - import h5py - f = h5py.File('test.hdf', 'w') - a = np.arange(5) - - dump(a, 'test.hkl') - dump(a, 'test.hkl') - - dump(a, f, mode='w') - f.close() - try: - dump(a, f, mode='w') - except hickle.ClosedFileError: - print("Tests: Closed file exception caught") - - -def test_hdf5_group(): - import h5py - file = h5py.File('test.hdf5', 'w') - group = file.create_group('test_group') - a = np.arange(5) - dump(a, group) - file.close() - - a_hkl = load('test.hdf5', path='/test_group') - assert np.allclose(a_hkl, a) - - file = h5py.File('test.hdf5', 'r+') - group = file.create_group('test_group2') - b = np.arange(8) - - dump(b, group, path='deeper/and_deeper') - file.close() - - b_hkl = load('test.hdf5', path='/test_group2/deeper/and_deeper') - assert np.allclose(b_hkl, b) - - file = h5py.File('test.hdf5', 'r') - b_hkl2 = load(file['test_group2'], path='deeper/and_deeper') - assert np.allclose(b_hkl2, b) - file.close() - - -def test_list_order(): +def test_list_order(test_file_name): """ https://github.com/telegraphic/hickle/issues/26 """ d = [np.arange(n + 1) for n in range(20)] - dump(d, 'test.h5') - d_hkl = load('test.h5') + dump(d, test_file_name) + d_hkl = load(test_file_name) try: for ii, xx in enumerate(d): @@ -582,13 +572,13 @@ def test_list_order(): raise -def test_embedded_array(): +def test_embedded_array(test_file_name): """ See https://github.com/telegraphic/hickle/issues/24 """ d_orig = [[np.array([10., 20.]), np.array([10, 20, 30])], [np.array([10, 2]), np.array([1.])]] - dump(d_orig, 'test.h5') - d_hkl = load('test.h5') + dump(d_orig, test_file_name) + d_hkl = load(test_file_name) for ii, xx in enumerate(d_orig): for jj, yy in enumerate(xx): @@ -618,171 +608,76 @@ def generate_nested(): z = {'a': a, 'b': b, 'c': c, 'd': d, 'z': z} return z - -def test_is_iterable(): - a = [1, 2, 3] - b = 1 - - assert helpers.check_is_iterable(a) - assert not helpers.check_is_iterable(b) - - -def test_check_iterable_item_type(): - a = [1, 2, 3] - b = [a, a, a] - c = [a, b, 's'] - - type_a = helpers.check_iterable_item_type(a) - type_b = helpers.check_iterable_item_type(b) - type_c = helpers.check_iterable_item_type(c) - - assert type_a is int - assert type_b is list - assert not type_c - - -def test_dump_nested(): +def test_dump_nested(test_file_name): """ Dump a complicated nested object to HDF5 """ z = generate_nested() - dump(z, 'test.hkl', mode='w') - - -def test_with_open_file(): - """ - Testing dumping and loading to an open file - - https://github.com/telegraphic/hickle/issues/92""" - - lst = [1] - tpl = (1,) - dct = {1: 1} - arr = np.array([1]) - - with h5py.File('test.hkl', 'w') as file: - dump(lst, file, path='/lst') - dump(tpl, file, path='/tpl') - dump(dct, file, path='/dct') - dump(arr, file, path='/arr') + dump(z, test_file_name, mode='w') - with h5py.File('test.hkl', 'r') as file: - assert load(file, '/lst') == lst - assert load(file, '/tpl') == tpl - assert load(file, '/dct') == dct - assert load(file, '/arr') == arr - - -def test_load(): - a = set([1, 2, 3, 4]) - b = set([5, 6, 7, 8]) - c = set([9, 10, 11, 12]) - z = (a, b, c) - z = [z, z] - z = (z, z, z, z, z) - - print("Original:") - pprint(z) - dump(z, 'test.hkl', mode='w') - - print("\nReconstructed:") - z = load('test.hkl') - pprint(z) - - -def test_sort_keys(): - keys = [b'data_0', b'data_1', b'data_2', b'data_3', b'data_10'] - keys_sorted = [b'data_0', b'data_1', b'data_2', b'data_3', b'data_10'] - - print(keys) - print(keys_sorted) - assert 
helpers.sort_keys(keys) == keys_sorted - - -def test_ndarray(): +def test_ndarray(test_file_name): a = np.array([1, 2, 3]) b = np.array([2, 3, 4]) z = (a, b) print("Original:") pprint(z) - dump(z, 'test.hkl', mode='w') + dump(z, test_file_name, mode='w') print("\nReconstructed:") - z = load('test.hkl') + z = load(test_file_name) pprint(z) -def test_ndarray_masked(): +def test_ndarray_masked(test_file_name): a = np.ma.array([1, 2, 3]) b = np.ma.array([2, 3, 4], mask=[True, False, True]) z = (a, b) print("Original:") pprint(z) - dump(z, 'test.hkl', mode='w') + dump(z, test_file_name, mode='w') print("\nReconstructed:") - z = load('test.hkl') + z = load(test_file_name) pprint(z) -def test_simple_dict(): +def test_simple_dict(test_file_name): a = {'key1': 1, 'key2': 2} - dump(a, 'test.hkl') - z = load('test.hkl') + dump(a, test_file_name) + z = load(test_file_name) pprint(a) pprint(z) -def test_complex_dict(): +def test_complex_dict(test_file_name): a = {'akey': 1, 'akey2': 2} c = {'ckey': "hello", "ckey2": "hi there"} z = {'zkey1': a, 'zkey2': a, 'zkey3': c} print("Original:") pprint(z) - dump(z, 'test.hkl', mode='w') + dump(z, test_file_name, mode='w') print("\nReconstructed:") - z = load('test.hkl') + z = load(test_file_name) pprint(z) - -def test_multi_hickle(): - """ Dumping to and loading from the same file several times - - https://github.com/telegraphic/hickle/issues/20""" - - a = {'a': 123, 'b': [1, 2, 4]} - - if os.path.exists("test.hkl"): - os.remove("test.hkl") - dump(a, "test.hkl", path="/test", mode="w") - dump(a, "test.hkl", path="/test2", mode="r+") - dump(a, "test.hkl", path="/test3", mode="r+") - dump(a, "test.hkl", path="/test4", mode="r+") - - load("test.hkl", path="/test") - load("test.hkl", path="/test2") - load("test.hkl", path="/test3") - load("test.hkl", path="/test4") - - -def test_complex(): +def test_complex(test_file_name): """ Test complex value dtype is handled correctly https://github.com/telegraphic/hickle/issues/29 """ data = {"A": 1.5, "B": 1.5 + 1j, "C": np.linspace(0, 1, 4) + 2j} - dump(data, "test.hkl") - data2 = load("test.hkl") + dump(data, test_file_name) + data2 = load(test_file_name) for key in data.keys(): assert isinstance(data[key], data2[key].__class__) -def test_nonstring_keys(): +def test_nonstring_keys(test_file_name): """ Test that keys are reconstructed back to their original datatypes https://github.com/telegraphic/hickle/issues/36 """ @@ -803,8 +698,8 @@ def test_nonstring_keys(): } print(data) - dump(data, "test.hkl") - data2 = load("test.hkl") + dump(data, test_file_name) + data2 = load(test_file_name) print(data2) for key in data.keys(): @@ -813,7 +708,7 @@ def test_nonstring_keys(): print(data2) -def test_scalar_compression(): +def test_scalar_compression(test_file_name): """ Test bug where compression causes a crash on scalar datasets (Scalars are incompressible!) @@ -821,49 +716,48 @@ def test_scalar_compression(): """ data = {'a': 0, 'b': np.float(2), 'c': True} - dump(data, "test.hkl", compression='gzip') - data2 = load("test.hkl") + dump(data, test_file_name, compression='gzip') + data2 = load(test_file_name) print(data2) for key in data.keys(): assert isinstance(data[key], data2[key].__class__) -def test_bytes(): +def test_bytes(test_file_name): """ Dumping and loading a string. 
PYTHON3 ONLY """ - filename, mode = 'test.h5', 'w' + mode = 'w' string_obj = b"The quick brown fox jumps over the lazy dog" - dump(string_obj, filename, mode) - string_hkl = load(filename) + dump(string_obj, test_file_name, mode) + string_hkl = load(test_file_name) print(type(string_obj)) print(type(string_hkl)) assert isinstance(string_hkl, bytes) assert string_obj == string_hkl -def test_np_scalar(): +def test_np_scalar(test_file_name): """ Numpy scalar datatype https://github.com/telegraphic/hickle/issues/50 """ - fid = 'test.h5py' r0 = {'test': np.float64(10.)} - dump(r0, fid) - r = load(fid) + dump(r0, test_file_name) + r = load(test_file_name) print(r) assert isinstance(r0['test'], r['test'].__class__) -def test_slash_dict_keys(): +def test_slash_dict_keys(test_file_name): """ Support for having slashes in dict keys https://github.com/telegraphic/hickle/issues/124""" dct = {'a/b': [1, '2'], 1.4: 3} - dump(dct, 'test.hdf5', 'w') - dct_hkl = load('test.hdf5') + dump(dct, test_file_name, 'w') + dct_hkl = load(test_file_name) assert isinstance(dct_hkl, dict) for key, val in dct_hkl.items(): @@ -871,64 +765,92 @@ def test_slash_dict_keys(): # Check that having backslashes in dict keys will serialize the dict dct2 = {'a\\b': [1, '2'], 1.4: 3} - with pytest.warns(loaders.load_builtins.SerializedWarning): - dump(dct2, 'test.hdf5') + with pytest.warns(None) as not_expected: + dump(dct2, test_file_name) + assert not not_expected # %% MAIN SCRIPT if __name__ == '__main__': """ Some tests and examples """ - test_sort_keys() - - test_np_scalar() - test_scalar_compression() - test_complex() - test_file_open_close() - test_hdf5_group() - test_none() - test_masked_dict() - test_list() - test_set() - test_numpy() - test_dict() - test_odict() - test_empty_dict() - test_compression() - test_masked() - test_dict_nested() - test_comp_kwargs() - test_list_numpy() - test_tuple_numpy() - test_list_order() - test_embedded_array() - test_np_float() - test_string() - test_nonstring_keys() - test_bytes() + from _pytest.fixtures import FixtureRequest + + for filename in test_file_name(FixtureRequest(test_np_scalar)): + test_np_scalar(filename) + for filename in test_file_name(FixtureRequest(test_scalar_compression)): + test_scalar_compression(filename) + for filename in test_file_name(FixtureRequest(test_complex)): + test_complex(filename) + for filename in test_file_name(FixtureRequest(test_none)): + test_none(filename) + for filename in test_file_name(FixtureRequest(test_masked_dict)): + test_masked_dict(filename) + for filename in test_file_name(FixtureRequest(test_list)): + test_list(filename) + for filename in test_file_name(FixtureRequest(test_set)): + test_set(filename) + for filename in test_file_name(FixtureRequest(test_numpy)): + test_numpy(filename) + for filename in test_file_name(FixtureRequest(test_dict)): + test_dict(filename) + for filename in test_file_name(FixtureRequest(test_odict)): + test_odict(filename) + for filename in test_file_name(FixtureRequest(test_empty_dict)): + test_empty_dict(filename) + for filename in test_file_name(FixtureRequest(test_compression)): + test_compression(filename) + for filename in test_file_name(FixtureRequest(test_masked)): + test_masked(filename) + for filename in test_file_name(FixtureRequest(test_dict_nested)): + test_dict_nested(filename) + for filename in test_file_name(FixtureRequest(test_comp_kwargs)): + test_comp_kwargs(filename) + for filename in test_file_name(FixtureRequest(test_list_numpy)): + test_list_numpy(filename) + for filename in 
test_file_name(FixtureRequest(test_tuple_numpy)): + test_tuple_numpy(filename) + for filename in test_file_name(FixtureRequest(test_list_order)): + test_list_order(filename) + for filename in test_file_name(FixtureRequest(test_embedded_array)): + test_embedded_array(filename) + for filename in test_file_name(FixtureRequest(test_np_float)): + test_np_float(filename) + for filename in test_file_name(FixtureRequest(test_string)): + test_string(filename) + for filename in test_file_name(FixtureRequest(test_nonstring_keys)): + test_nonstring_keys(filename) + for filename in test_file_name(FixtureRequest(test_bytes)): + test_bytes(filename) # NEW TESTS - test_is_iterable() - test_check_iterable_item_type() - test_dump_nested() - test_with_open_file() - test_load() - test_sort_keys() - test_ndarray() - test_ndarray_masked() - test_simple_dict() - test_complex_dict() - test_multi_hickle() - test_dict_int_key() - test_local_func() - test_binary_file() - test_state_obj() - test_slash_dict_keys() + for filename in test_file_name(FixtureRequest(test_dump_nested)): + test_dump_nested(filename) + for filename in test_file_name(FixtureRequest(test_ndarray)): + test_ndarray(filename) + for filename in test_file_name(FixtureRequest(test_ndarray_masked)): + test_ndarray_masked(filename) + for filename in test_file_name(FixtureRequest(test_simple_dict)): + test_simple_dict(filename) + for filename in test_file_name(FixtureRequest(test_complex_dict)): + test_complex_dict(filename) + for filename in test_file_name(FixtureRequest(test_dict_int_key)): + test_dict_int_key(filename) + for filename in test_file_name(FixtureRequest(test_local_func)): + test_local_func(filename) + for filename in test_file_name(FixtureRequest(test_slash_dict_keys)): + test_slash_dict_keys(filename) test_invalid_file() - test_non_empty_group() - test_numpy_dtype() - test_object_numpy() - test_string_numpy() - test_list_object_numpy() + for filename in test_file_name(FixtureRequest(test_non_empty_group)): + test_non_empty_group(filename) + for filename in test_file_name(FixtureRequest(test_numpy_dtype)): + test_numpy_dtype(filename) + for filename in test_file_name(FixtureRequest(test_object_numpy)): + test_object_numpy(filename) + for filename in test_file_name(FixtureRequest(test_string_numpy)): + test_string_numpy(filename) + for filename in test_file_name(FixtureRequest(test_list_object_numpy)): + test_list_object_numpy(filename) # Cleanup - print("ALL TESTS PASSED!") + for filename in test_file_name(FixtureRequest(print)): + print(filename) diff --git a/hickle/tests/test_hickle_helpers.py b/hickle/tests/test_hickle_helpers.py deleted file mode 100644 index f5dab275..00000000 --- a/hickle/tests/test_hickle_helpers.py +++ /dev/null @@ -1,49 +0,0 @@ -#! /usr/bin/env python -# encoding: utf-8 -""" -# test_hickle_helpers.py - -Unit tests for hickle module -- helper functions. 
- -""" - - -# %% IMPORTS -# Package imports -import numpy as np - -# hickle imports -from hickle.helpers import ( - check_is_hashable, check_is_iterable, check_iterable_item_type) -from hickle.loaders.load_numpy import check_is_numpy_array - - -# %% FUNCTION DEFINITIONS -def test_check_is_iterable(): - assert check_is_iterable([1, 2, 3]) - assert not check_is_iterable(1) - - -def test_check_is_hashable(): - assert check_is_hashable(1) - assert not check_is_hashable([1, 2, 3]) - - -def test_check_iterable_item_type(): - assert check_iterable_item_type([1, 2, 3]) is int - assert not check_iterable_item_type([int(1), float(1)]) - assert not check_iterable_item_type([]) - - -def test_check_is_numpy_array(): - assert check_is_numpy_array(np.array([1, 2, 3])) - assert check_is_numpy_array(np.ma.array([1, 2, 3])) - assert not check_is_numpy_array([1, 2]) - - -# %% MAIN SCRIPT -if __name__ == "__main__": - test_check_is_hashable() - test_check_is_iterable() - test_check_is_numpy_array() - test_check_iterable_item_type() diff --git a/hickle/tests/test_legacy_load.py b/hickle/tests/test_legacy_load.py index caf2bc89..4b807012 100644 --- a/hickle/tests/test_legacy_load.py +++ b/hickle/tests/test_legacy_load.py @@ -3,35 +3,66 @@ import glob from os import path import warnings +import pytest +import scipy.sparse +import numpy as np # Package imports import h5py # hickle imports import hickle as hkl +import dill as pickle # %% FUNCTION DEFINITIONS def test_legacy_load(): dirpath = path.dirname(__file__) - filelist = sorted(glob.glob(path.join(dirpath, 'legacy_hkls/*.hkl'))) + filelist = sorted(glob.glob(path.join(dirpath, 'legacy_hkls/*3_[0-9]_[0-9].hkl'))) # Make all warnings show warnings.simplefilter("always") for filename in filelist: - try: - print(filename) - a = hkl.load(filename) - except Exception: - with h5py.File(filename) as a: - print(a.attrs.items()) - print(a.items()) - for key, item in a.items(): - print(item.attrs.items()) - raise + with pytest.warns( + UserWarning, + match = r"Input\s+argument\s+'file_obj'\s+appears\s+to\s+be\s+a\s+file\s+made" + r"\s+with\s+hickle\s+v3.\s+Using\s+legacy\s+load..." 
+ ): + try: + print(filename) + a = hkl.load(filename) + except Exception: + with h5py.File(filename) as a: + print(a.attrs.items()) + print(a.items()) + for key, item in a.items(): + print(item.attrs.items()) + raise +def test_4_0_0_load(): + """ + test that files created by hickle 4.0.x can be loaded by + hickle 4.1.x properly + """ + dirpath = path.dirname(__file__) + filelist = sorted(glob.glob(path.join(dirpath, 'legacy_hkls/*4.[0-9].[0-9].hkl'))) + from hickle.tests.generate_legacy_4_0_0 import generate_py_object + compare_with,needs_compare = generate_py_object() + for filename in filelist: + content = hkl.load(filename) + if filename != needs_compare: + continue + for content_item,compare_item in ( (content[i],compare_with[i]) for i in range(len(compare_with)) ): + if scipy.sparse.issparse(content_item): + assert np.allclose(content_item.toarray(),compare_item.toarray()) + continue + try: + assert content_item == compare_item + except ValueError: + assert np.all(content_item == compare_item) # %% MAIN SCRIPT if __name__ == "__main__": test_legacy_load() + test_4_0_0_load() diff --git a/hickle/tests/test_scipy.py b/hickle/tests/test_scipy.py deleted file mode 100644 index d7d811f8..00000000 --- a/hickle/tests/test_scipy.py +++ /dev/null @@ -1,59 +0,0 @@ -# %% IMPORTS -# Package imports -import numpy as np -from py.path import local -from scipy.sparse import csr_matrix, csc_matrix, bsr_matrix - -# hickle imports -import hickle -from hickle.loaders.load_scipy import check_is_scipy_sparse_array - -# Set the current working directory to the temporary directory -local.get_temproot().chdir() - - -# %% FUNCTION DEFINITIONS -def test_is_sparse(): - sm0 = csr_matrix((3, 4), dtype=np.int8) - sm1 = csc_matrix((1, 2)) - - assert check_is_scipy_sparse_array(sm0) - assert check_is_scipy_sparse_array(sm1) - - -def test_sparse_matrix(): - row = np.array([0, 0, 1, 2, 2, 2]) - col = np.array([0, 2, 2, 0, 1, 2]) - data = np.array([1, 2, 3, 4, 5, 6]) - sm1 = csr_matrix((data, (row, col)), shape=(3, 3)) - sm2 = csc_matrix((data, (row, col)), shape=(3, 3)) - - indptr = np.array([0, 2, 3, 6]) - indices = np.array([0, 2, 2, 0, 1, 2]) - data = np.array([1, 2, 3, 4, 5, 6]).repeat(4).reshape(6, 2, 2) - sm3 = bsr_matrix((data, indices, indptr), shape=(6, 6)) - - hickle.dump(sm1, 'test_sp.h5') - sm1_h = hickle.load('test_sp.h5') - hickle.dump(sm2, 'test_sp2.h5') - sm2_h = hickle.load('test_sp2.h5') - hickle.dump(sm3, 'test_sp3.h5') - sm3_h = hickle.load('test_sp3.h5') - - assert isinstance(sm1_h, csr_matrix) - assert isinstance(sm2_h, csc_matrix) - assert isinstance(sm3_h, bsr_matrix) - - assert np.allclose(sm1_h.data, sm1.data) - assert np.allclose(sm2_h.data, sm2.data) - assert np.allclose(sm3_h.data, sm3.data) - - assert sm1_h. shape == sm1.shape - assert sm2_h. shape == sm2.shape - assert sm3_h. shape == sm3.shape - - -# %% MAIN SCRIPT -if __name__ == "__main__": - test_sparse_matrix() - test_is_sparse()
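
With test_scipy.py deleted, direct round-trip coverage for scipy sparse matrices now rests on the legacy-file comparison in test_4_0_0_load and on whatever loader-level tests the rest of this patch provides. If a standalone round-trip test were wanted again, a compact parametrized version along these lines would cover the same ground as the removed file (a hypothetical sketch, not part of this PR):

    # illustrative sketch only -- not included in this patch
    import numpy as np
    import pytest
    from scipy.sparse import csr_matrix, csc_matrix, bsr_matrix

    import hickle


    @pytest.mark.parametrize("matrix_class", [csr_matrix, csc_matrix, bsr_matrix])
    def test_sparse_roundtrip(matrix_class, tmp_path):
        # Build a small dense block and convert it to the sparse format under test.
        dense = np.arange(36, dtype=np.float64).reshape(6, 6)
        sparse = matrix_class(dense)

        filename = str(tmp_path / "sparse.hkl")
        hickle.dump(sparse, filename, mode="w")
        loaded = hickle.load(filename)

        # Class, shape and contents should all survive the round trip.
        assert isinstance(loaded, matrix_class)
        assert loaded.shape == sparse.shape
        assert np.allclose(loaded.toarray(), dense)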