From 5d94aa7ada42d4042eaaf8f2bb37a487319a57fc Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 31 Aug 2021 11:50:36 -0500 Subject: [PATCH 01/31] refactoring hspy and zspy to build from each other --- hyperspy/io_plugins/hspy.py | 113 +++--- hyperspy/io_plugins/zspy.py | 176 +++++++++ hyperspy/tests/io/test_zspy.py | 689 +++++++++++++++++++++++++++++++++ 3 files changed, 929 insertions(+), 49 deletions(-) create mode 100644 hyperspy/io_plugins/zspy.py create mode 100644 hyperspy/tests/io/test_zspy.py diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index b799868e6a..7f62b91fc8 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -23,12 +23,15 @@ import ast import h5py +import zarr import numpy as np import dask.array as da from traits.api import Undefined from hyperspy.misc.utils import ensure_unicode, multiply, get_object_package_info from hyperspy.axes import AxesManager +from h5py import Dataset, File, Group + _logger = logging.getLogger(__name__) @@ -122,8 +125,10 @@ def get_hspy_format_version(f): return LooseVersion(version) -def file_reader(filename, backing_store=False, - lazy=False, **kwds): +def file_reader(filename, + backing_store=False, + lazy=False, + **kwds): """Read data from hdf5 files saved with the hyperspy hdf5 format specification Parameters @@ -139,7 +144,7 @@ def file_reader(filename, backing_store=False, except ImportError: pass mode = kwds.pop('mode', 'r') - f = h5py.File(filename, mode=mode, **kwds) + f = File(filename, mode=mode, **kwds) # Getting the format version here also checks if it is a valid HSpy # hdf5 file, so the following two lines must not be deleted or moved # elsewhere. @@ -157,7 +162,7 @@ def file_reader(filename, backing_store=False, standalone_models = [] if 'Analysis/models' in f: try: - m_gr = f.require_group('Analysis/models') + m_gr = f['Analysis/models'] for model_name in m_gr: if '_signal' in m_gr[model_name].attrs: key = m_gr[model_name].attrs['_signal'] @@ -178,7 +183,7 @@ def file_reader(filename, backing_store=False, exp_dict_list = [] if 'Experiments' in f: for ds in f['Experiments']: - if isinstance(f['Experiments'][ds], h5py.Group): + if isinstance(f['Experiments'][ds], Group): if 'data' in f['Experiments'][ds]: experiments.append(ds) # Parse the file @@ -473,7 +478,7 @@ def parse_structure(key, group, value, _type, **kwds): elif isinstance(value, BaseSignal): kn = key if key.startswith('_sig_') else '_sig_' + key write_signal(value, group.require_group(kn)) - elif isinstance(value, (np.ndarray, h5py.Dataset, da.Array)): + elif isinstance(value, (np.ndarray, Dataset, da.Array)): overwrite_dataset(group, value, key, **kwds) elif value is None: group.attrs[key] = '_None_' @@ -559,7 +564,12 @@ def get_signal_chunks(shape, dtype, signal_axes=None): return tuple(int(x) for x in chunks) -def overwrite_dataset(group, data, key, signal_axes=None, chunks=None, **kwds): +def overwrite_dataset(group, + data, + key, + signal_axes=None, + chunks=None, + **kwds): if chunks is None: if isinstance(data, da.Array): # For lazy dataset, by default, we use the current dask chunking @@ -568,53 +578,58 @@ def overwrite_dataset(group, data, key, signal_axes=None, chunks=None, **kwds): # If signal_axes=None, use automatic h5py chunking, otherwise # optimise the chunking to contain at least one signal per chunk chunks = get_signal_chunks(data.shape, data.dtype, signal_axes) - if np.issubdtype(data.dtype, np.dtype('U')): # Saving numpy unicode type is not supported in h5py data = data.astype(np.dtype('S')) if 
data.dtype == np.dtype('O'): - # For saving ragged array - # http://docs.h5py.org/en/stable/special.html#arbitrary-vlen-data - group.require_dataset(key, - chunks, - dtype=h5py.special_dtype(vlen=data[0].dtype), - **kwds) - group[key][:] = data[:] - - maxshape = tuple(None for _ in data.shape) - - got_data = False - while not got_data: - try: - these_kwds = kwds.copy() - these_kwds.update(dict(shape=data.shape, - dtype=data.dtype, - exact=True, - maxshape=maxshape, - chunks=chunks, - shuffle=True,)) - - # If chunks is True, the `chunks` attribute of `dset` below - # contains the chunk shape guessed by h5py - dset = group.require_dataset(key, **these_kwds) - got_data = True - except TypeError: - # if the shape or dtype/etc do not match, - # we delete the old one and create new in the next loop run - del group[key] + dset = get_object_dset(group, data, key, chunks, **kwds) + else: + got_data=False + maxshape = tuple(None for _ in data.shape) + while not got_data: + try: + these_kwds = kwds.copy() + these_kwds.update(dict(shape=data.shape, + dtype=data.dtype, + exact=True, + maxshape=maxshape, + chunks=chunks, + shuffle=True, )) + + # If chunks is True, the `chunks` attribute of `dset` below + # contains the chunk shape guessed by h5py + dset = group.require_dataset(key, **these_kwds) + got_data = True + except TypeError: + # if the shape or dtype/etc do not match, + # we delete the old one and create new in the next loop run + del group[key] if dset == data: # just a reference to already created thing pass else: _logger.info(f"Chunks used for saving: {dset.chunks}") - if isinstance(data, da.Array): - if data.chunks != dset.chunks: - data = data.rechunk(dset.chunks) - da.store(data, dset) - elif data.flags.c_contiguous: - dset.write_direct(data) - else: - dset[:] = data + store_data(data, dset, group, key, chunks, **kwds) + + +def get_object_dset(group, data, key, chunks, **kwds): + # For saving ragged array + # http://docs.h5py.org/en/stable/special.html#arbitrary-vlen-data + dset = group.require_dataset(key, + chunks, + dtype=h5py.special_dtype(vlen=data[0].dtype), + **kwds) + return dset + +def store_data(data, dset, group, key, chunks, **kwds): + if isinstance(data, da.Array): + if data.chunks != dset.chunks: + data = data.rechunk(dset.chunks) + da.store(data, dset) + elif data.flags.c_contiguous: + dset.write_direct(data) + else: + dset[:] = data def hdfgroup2dict(group, dictionary=None, lazy=False): @@ -654,19 +669,19 @@ def hdfgroup2dict(group, dictionary=None, lazy=False): dictionary[key.replace("_datetime_", "")] = date_iso else: dictionary[key] = value - if not isinstance(group, h5py.Dataset): + if not isinstance(group, Dataset): for key in group.keys(): if key.startswith('_sig_'): from hyperspy.io import dict2signal dictionary[key[len('_sig_'):]] = ( dict2signal(hdfgroup2signaldict( group[key], lazy=lazy))) - elif isinstance(group[key], h5py.Dataset): + elif isinstance(group[key], Dataset): dat = group[key] kn = key if key.startswith("_list_"): if (h5py.check_string_dtype(dat.dtype) and - hasattr(dat, 'asstr')): + hasattr(dat, 'asstr')): # h5py 3.0 and newer # https://docs.h5py.org/en/3.0.0/strings.html dat = dat.asstr()[:] @@ -715,8 +730,8 @@ def hdfgroup2dict(group, dictionary=None, lazy=False): group[key], dictionary[key], lazy=lazy) - return dictionary + return dictionary def write_signal(signal, group, **kwds): "Writes a hyperspy signal to a hdf5 group" diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py new file mode 100644 index 0000000000..b47328a726 --- 
/dev/null +++ b/hyperspy/io_plugins/zspy.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- +# Copyright 2007-2021 The HyperSpy developers +# +# This file is part of HyperSpy. +# +# HyperSpy is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# HyperSpy is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with HyperSpy. If not, see . + +from distutils.version import LooseVersion +import warnings +import logging +import datetime +import ast + +import zarr +from zarr import Array as Dataset +from zarr import open as File +import numpy as np +import dask.array as da +from traits.api import Undefined +from hyperspy.misc.utils import ensure_unicode, multiply, get_object_package_info +from hyperspy.axes import AxesManager +from hyperspy.io_plugins.hspy import hdfgroup2signaldict, file_reader, write_signal, overwrite_dataset, get_signal_chunks +import numcodecs + +from hyperspy.io_plugins.hspy import version + +_logger = logging.getLogger(__name__) + + +# Plugin characteristics +# ---------------------- +format_name = 'ZSpy' +description = \ + 'A default file format for HyperSpy based on the zarr standard' +full_support = False +# Recognised file extension +file_extensions = ['zspy'] +default_extension = 0 +# Writing capabilities +non_uniform_axis = True +writes = True + +# ----------------------- +# File format description +# ----------------------- +# The root must contain a group called Experiments +# The experiments group can contain any number of subgroups +# Each subgroup is an experiment or signal +# Each subgroup must contain at least one dataset called data +# The data is an array of arbitrary dimension +# In addition a number equal to the number of dimensions of the data +# dataset + 1 of empty groups called coordinates followed by a number +# must exists with the following attributes: +# 'name' +# 'offset' +# 'scale' +# 'units' +# 'size' +# 'index_in_array' +# The experiment group contains a number of attributes that will be +# directly assigned as class attributes of the Signal instance. 
In
+# addition the experiment groups may contain 'original_metadata' and
+# 'metadata' subgroups that will be
+# assigned to the same name attributes of the Signal instance as
+# DictionaryBrowser instances
+# The Experiments group can contain attributes that may be common to all
+# the experiments and that will be accessible as attributes of the
+# Experiments instance

+def get_object_dset(group, data, key, chunks, **kwds):
+    if data.dtype == np.dtype('O'):
+        # For saving ragged array
+        # https://zarr.readthedocs.io/en/stable/tutorial.html?highlight=ragged%20array#ragged-arrays
+        if chunks is None:
+            chunks = 1  # default to one object (signal) per chunk
+        these_kwds = kwds.copy()
+        these_kwds.update(dict(dtype=object,
+                               exact=True,
+                               chunks=chunks))
+        dset = group.require_dataset(key,
+                                     data.shape,
+                                     object_codec=numcodecs.VLenArray(int),
+                                     **these_kwds)
+        return dset
+
+
+def store_data(data, dset, group, key, chunks, **kwds):
+    if isinstance(data, da.Array):
+        if data.chunks != dset.chunks:
+            data = data.rechunk(dset.chunks)
+        path = group._store.dir_path() + "/" + dset.path
+        data.to_zarr(url=path,
+                     overwrite=True,
+                     **kwds)  # add in compression etc
+    elif data.dtype == np.dtype('O'):
+        group[key][:] = data[:]  # check lazy
+    else:
+        path = group._store.dir_path() + "/" + dset.path
+        dset = zarr.open_array(path,
+                               mode="w",
+                               shape=data.shape,
+                               dtype=data.dtype,
+                               chunks=chunks,
+                               **kwds)
+        dset[:] = data
+
+def get_signal_chunks(shape, dtype, signal_axes=None):
+    """Function that calculates chunks for the signal,
+    preferably at least one chunk per signal space.
+    Parameters
+    ----------
+    shape : tuple
+        the shape of the dataset to be stored / chunked
+    dtype : {dtype, string}
+        the numpy dtype of the data
+    signal_axes: {None, iterable of ints}
+        the axes defining "signal space" of the dataset. If None, the default
+        zarr chunking is performed.
+ """ + typesize = np.dtype(dtype).itemsize + if signal_axes is None: + return None + # chunk size larger than 1 Mb https://zarr.readthedocs.io/en/stable/tutorial.html#chunk-optimizations + # shooting for 100 Mb chunks + total_size = np.prod(shape)*typesize + if total_size < 1e8: # 1 mb + return None +def file_writer(filename, signal, *args, **kwds): + """Writes data to hyperspy's zarr format + Parameters + ---------- + filename: str + signal: a BaseSignal instance + *args, optional + **kwds, optional + """ + if "compressor" not in kwds: + from numcodecs import Blosc + kwds["compressor"] = Blosc(cname='zstd', clevel=1) + store = zarr.storage.NestedDirectoryStore(filename,) + f = zarr.group(store=store, overwrite=True) + f.attrs['file_format'] = "ZSpy" + f.attrs['file_format_version'] = version + exps = f.create_group('Experiments') + group_name = signal.metadata.General.title if \ + signal.metadata.General.title else '__unnamed__' + # / is a invalid character, see #942 + if "/" in group_name: + group_name = group_name.replace("/", "-") + expg = exps.create_group(group_name) + + # Add record_by metadata for backward compatibility + smd = signal.metadata.Signal + if signal.axes_manager.signal_dimension == 1: + smd.record_by = "spectrum" + elif signal.axes_manager.signal_dimension == 2: + smd.record_by = "image" + else: + smd.record_by = "" + try: + write_signal(signal, expg, f, **kwds) + except BaseException: + raise + finally: + del smd.record_by \ No newline at end of file diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py new file mode 100644 index 0000000000..907fd1aec2 --- /dev/null +++ b/hyperspy/tests/io/test_zspy.py @@ -0,0 +1,689 @@ +# -*- coding: utf-8 -*- +# Copyright 2007-2021 The HyperSpy developers +# +# This file is part of HyperSpy. +# +# HyperSpy is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# HyperSpy is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with HyperSpy. If not, see . 
+ +import gc +import os.path +import sys +import tempfile +import time +from os import remove + +import dask.array as da +import h5py +import numpy as np +import pytest +import zarr + +from hyperspy._signals.signal1d import Signal1D +from hyperspy._signals.signal2d import Signal2D +from hyperspy.datasets.example_signals import EDS_TEM_Spectrum +from hyperspy.exceptions import VisibleDeprecationWarning +from hyperspy.axes import AxesManager +from hyperspy.io import load +from hyperspy.misc.test_utils import assert_deep_almost_equal +from hyperspy.misc.test_utils import sanitize_dict as san_dict +from hyperspy.roi import Point2DROI +from hyperspy.signal import BaseSignal +from hyperspy.utils import markers +import hyperspy.api as hs + +my_path = os.path.dirname(__file__) + +data = np.array([4066., 3996., 3932., 3923., 5602., 5288., 7234., 7809., + 4710., 5015., 4366., 4524., 4832., 5474., 5718., 5034., + 4651., 4613., 4637., 4429., 4217.]) +example1_original_metadata = { + 'BEAMDIAM -nm': 100.0, + 'BEAMKV -kV': 120.0, + 'CHOFFSET': -168.0, + 'COLLANGLE-mR': 3.4, + 'CONVANGLE-mR': 1.5, + 'DATATYPE': 'XY', + 'DATE': '01-OCT-1991', + 'DWELLTIME-ms': 100.0, + 'ELSDET': 'SERIAL', + 'EMISSION -uA': 5.5, + 'FORMAT': 'EMSA/MAS Spectral Data File', + 'MAGCAM': 100.0, + 'NCOLUMNS': 1.0, + 'NPOINTS': 20.0, + 'OFFSET': 520.13, + 'OPERMODE': 'IMAG', + 'OWNER': 'EMSA/MAS TASK FORCE', + 'PROBECUR -nA': 12.345, + 'SIGNALTYPE': 'ELS', + 'THICKNESS-nm': 50.0, + 'TIME': '12:00', + 'TITLE': 'NIO EELS OK SHELL', + 'VERSION': '1.0', + 'XLABEL': 'Energy', + 'XPERCHAN': 3.1, + 'XUNITS': 'eV', + 'YLABEL': 'Counts', + 'YUNITS': 'Intensity'} + + +class Example1: + "Used as a base class for the TestExample classes below" + + def test_data(self): + assert ( + [4066.0, + 3996.0, + 3932.0, + 3923.0, + 5602.0, + 5288.0, + 7234.0, + 7809.0, + 4710.0, + 5015.0, + 4366.0, + 4524.0, + 4832.0, + 5474.0, + 5718.0, + 5034.0, + 4651.0, + 4613.0, + 4637.0, + 4429.0, + 4217.0] == self.s.data.tolist()) + + def test_original_metadata(self): + assert ( + example1_original_metadata == + self.s.original_metadata.as_dictionary()) + + @pytest.mark.xfail( + reason="dill is not guaranteed to load across Python versions") + def test_binary_string(self): + import dill + # apparently pickle is not "full" and marshal is not + # backwards-compatible + f = dill.loads(self.s.metadata.test.binary_string) + assert f(3.5) == 4.5 + + +class TestSavingMetadataContainers: + + def setup_method(self, method): + self.s = BaseSignal([0.1]) + + def test_save_unicode(self, tmp_path): + s = self.s + s.metadata.set_item('test', ['a', 'b', '\u6f22\u5b57']) + fname = tmp_path / 'test.zspy' + s.save(fname) + s.save("test.zspy") + l = load(fname) + assert isinstance(l.metadata.test[0], str) + assert isinstance(l.metadata.test[1], str) + assert isinstance(l.metadata.test[2], str) + assert l.metadata.test[2] == '\u6f22\u5b57' + + def test_save_long_list(self, tmp_path): + s = self.s + s.metadata.set_item('long_list', list(range(10000))) + start = time.time() + fname = tmp_path / 'test.zspy' + s.save(fname) + end = time.time() + assert end - start < 1.0 # It should finish in less that 1 s. 
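Because the `.zspy` store written by these tests is a plain nested directory of zarr groups, the round-tripped metadata can also be inspected outside HyperSpy. A rough sketch only, assuming the default '__unnamed__' group name that `file_writer` uses for untitled signals; the path and keys are illustrative.

    import zarr

    root = zarr.open('test.zspy', mode='r')
    # scalar metadata items are stored as attributes of the 'metadata' group
    print(root['Experiments/__unnamed__/metadata'].attrs.asdict())
    print(root['Experiments/__unnamed__/data'].shape)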
+ + def test_numpy_only_inner_lists(self, tmp_path): + s = self.s + s.metadata.set_item('test', [[1., 2], ('3', 4)]) + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + assert isinstance(l.metadata.test, list) + assert isinstance(l.metadata.test[0], list) + assert isinstance(l.metadata.test[1], tuple) + + @pytest.mark.xfail(sys.platform == 'win32', + reason="randomly fails in win32") + def test_numpy_general_type(self, tmp_path): + s = self.s + s.metadata.set_item('test', np.array([[1., 2], ['3', 4]])) + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + np.testing.assert_array_equal(l.metadata.test, s.metadata.test) + + @pytest.mark.xfail(sys.platform == 'win32', + reason="randomly fails in win32") + def test_list_general_type(self, tmp_path): + s = self.s + s.metadata.set_item('test', [[1., 2], ['3', 4]]) + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + assert isinstance(l.metadata.test[0][0], float) + assert isinstance(l.metadata.test[0][1], float) + assert isinstance(l.metadata.test[1][0], str) + assert isinstance(l.metadata.test[1][1], str) + + @pytest.mark.xfail(sys.platform == 'win32', + reason="randomly fails in win32") + def test_general_type_not_working(self, tmp_path): + s = self.s + s.metadata.set_item('test', (BaseSignal([1]), 0.1, 'test_string')) + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + assert isinstance(l.metadata.test, tuple) + assert isinstance(l.metadata.test[0], Signal1D) + assert isinstance(l.metadata.test[1], float) + assert isinstance(l.metadata.test[2], str) + + def test_unsupported_type(self, tmp_path): + s = self.s + s.metadata.set_item('test', Point2DROI(1, 2)) + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + assert 'test' not in l.metadata + + def test_date_time(self, tmp_path): + s = self.s + date, time = "2016-08-05", "15:00:00.450" + s.metadata.General.date = date + s.metadata.General.time = time + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + assert l.metadata.General.date == date + assert l.metadata.General.time == time + + def test_general_metadata(self, tmp_path): + s = self.s + notes = "Dummy notes" + authors = "Author 1, Author 2" + doi = "doi" + s.metadata.General.notes = notes + s.metadata.General.authors = authors + s.metadata.General.doi = doi + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + assert l.metadata.General.notes == notes + assert l.metadata.General.authors == authors + assert l.metadata.General.doi == doi + + def test_quantity(self, tmp_path): + s = self.s + quantity = "Intensity (electron)" + s.metadata.Signal.quantity = quantity + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + assert l.metadata.Signal.quantity == quantity + + def test_title(self, tmp_path): + s = self.s + fname = tmp_path / 'test.zspy' + s.metadata.General.title = '__unnamed__' + s.save(fname) + l = load(fname) + assert l.metadata.General.title is "" + + def test_save_bytes(self, tmp_path): + s = self.s + byte_message = bytes("testing", 'utf-8') + s.metadata.set_item('test', byte_message) + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + assert l.metadata.test == s.metadata.test.decode() + + def test_save_empty_tuple(self, tmp_path): + s = self.s + s.metadata.set_item('test', ()) + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + #strange becuase you need the encoding... 
+ assert l.metadata.test == s.metadata.test + + def test_save_axes_manager(self, tmp_path): + s = self.s + s.metadata.set_item('test', s.axes_manager) + fname = tmp_path / 'test.zspy' + s.save(fname) + l = load(fname) + #strange becuase you need the encoding... + assert isinstance(l.metadata.test, AxesManager) + +class TestLoadingOOMReadOnly: + + def setup_method(self, method): + s = BaseSignal(np.empty((5, 5, 5))) + s.save('tmp.zspy', overwrite=True) + self.shape = (10000, 10000, 100) + del s + f = zarr.open('tmp.zspy', mode='r+') + s = f['Experiments/__unnamed__'] + del s['data'] + s.create_dataset( + 'data', + shape=self.shape, + dtype='float64', + chunks=True) + + def test_oom_loading(self): + s = load('tmp.zspy', lazy=True) + assert self.shape == s.data.shape + assert isinstance(s.data, da.Array) + assert s._lazy + s.close_file() + + def teardown_method(self, method): + gc.collect() # Make sure any memmaps are closed first! + try: + remove('tmp.zspy') + except BaseException: + # Don't fail tests if we cannot remove + pass + + +class TestPassingArgs: + def test_compression_opts(self, tmp_path): + self.filename = tmp_path / 'testfile.zspy' + from numcodecs import Blosc + comp = Blosc(cname='zstd', clevel=1, shuffle=Blosc.SHUFFLE) + BaseSignal([1, 2, 3]).save(self.filename, compressor=comp) + f = zarr.open(self.filename.__str__(), mode='r+') + d = f['Experiments/__unnamed__/data'] + assert (d.compressor == comp) + + +class TestAxesConfiguration: + + def test_axes_configuration(self, tmp_path): + self.filename = tmp_path / 'testfile.zspy' + s = BaseSignal(np.zeros((2, 2, 2, 2, 2))) + s.axes_manager.signal_axes[0].navigate = True + s.axes_manager.signal_axes[0].navigate = True + s.save(self.filename) + s = load(self.filename) + assert s.axes_manager.navigation_axes[0].index_in_array == 4 + assert s.axes_manager.navigation_axes[1].index_in_array == 3 + assert s.axes_manager.signal_dimension == 3 + + +class TestAxesConfigurationBinning: + def test_axes_configuration(self): + self.filename = 'testfile.zspy' + s = BaseSignal(np.zeros((2, 2, 2))) + s.axes_manager.signal_axes[-1].is_binned = True + s.save(self.filename) + s = load(self.filename) + assert s.axes_manager.signal_axes[-1].is_binned == True + + +class Test_permanent_markers_io: + + def test_save_permanent_marker(self): + s = Signal2D(np.arange(100).reshape(10, 10)) + m = markers.point(x=5, y=5) + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/testsavefile.zspy' + s.save(filename) + + def test_save_load_empty_metadata_markers(self): + s = Signal2D(np.arange(100).reshape(10, 10)) + m = markers.point(x=5, y=5) + m.name = "test" + s.add_marker(m, permanent=True) + del s.metadata.Markers.test + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/testsavefile.zspy' + s.save(filename) + s1 = load(filename) + assert len(s1.metadata.Markers) == 0 + + def test_save_load_permanent_marker(self): + x, y = 5, 2 + color = 'red' + size = 10 + name = 'testname' + s = Signal2D(np.arange(100).reshape(10, 10)) + m = markers.point(x=x, y=y, color=color, size=size) + m.name = name + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/testloadfile.zspy' + s.save(filename) + s1 = load(filename) + assert s1.metadata.Markers.has_item(name) + m1 = s1.metadata.Markers.get_item(name) + assert m1.get_data_position('x1') == x + assert m1.get_data_position('y1') == y + assert m1.get_data_position('size') == size + assert m1.marker_properties['color'] == 
color + assert m1.name == name + + def test_save_load_permanent_marker_all_types(self): + x1, y1, x2, y2 = 5, 2, 1, 8 + s = Signal2D(np.arange(100).reshape(10, 10)) + m0_list = [ + markers.point(x=x1, y=y1), + markers.horizontal_line(y=y1), + markers.horizontal_line_segment(x1=x1, x2=x2, y=y1), + markers.line_segment(x1=x1, x2=x2, y1=y1, y2=y2), + markers.rectangle(x1=x1, x2=x2, y1=y1, y2=y2), + markers.text(x=x1, y=y1, text="test"), + markers.vertical_line(x=x1), + markers.vertical_line_segment(x=x1, y1=y1, y2=y2), + ] + for m in m0_list: + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/testallmarkersfile.zspy' + s.save(filename) + s1 = load(filename) + markers_dict = s1.metadata.Markers + m0_dict_list = [] + m1_dict_list = [] + for m in m0_list: + m0_dict_list.append(san_dict(m._to_dictionary())) + m1_dict_list.append( + san_dict(markers_dict.get_item(m.name)._to_dictionary())) + assert len(list(s1.metadata.Markers)) == 8 + for m0_dict, m1_dict in zip(m0_dict_list, m1_dict_list): + assert m0_dict == m1_dict + + def test_save_load_horizontal_line_marker(self): + y = 8 + color = 'blue' + linewidth = 2.5 + name = "horizontal_line_test" + s = Signal2D(np.arange(100).reshape(10, 10)) + m = markers.horizontal_line(y=y, color=color, linewidth=linewidth) + m.name = name + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/test_save_horizontal_line_marker.zspy' + s.save(filename) + s1 = load(filename) + m1 = s1.metadata.Markers.get_item(name) + assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) + + def test_save_load_horizontal_line_segment_marker(self): + x1, x2, y = 1, 5, 8 + color = 'red' + linewidth = 1.2 + name = "horizontal_line_segment_test" + s = Signal2D(np.arange(100).reshape(10, 10)) + m = markers.horizontal_line_segment( + x1=x1, x2=x2, y=y, color=color, linewidth=linewidth) + m.name = name + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/test_save_horizontal_line_segment_marker.zspy' + s.save(filename) + s1 = load(filename) + m1 = s1.metadata.Markers.get_item(name) + assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) + + def test_save_load_vertical_line_marker(self): + x = 9 + color = 'black' + linewidth = 3.5 + name = "vertical_line_test" + s = Signal2D(np.arange(100).reshape(10, 10)) + m = markers.vertical_line(x=x, color=color, linewidth=linewidth) + m.name = name + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/test_save_vertical_line_marker.zspy' + s.save(filename) + s1 = load(filename) + m1 = s1.metadata.Markers.get_item(name) + assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) + + def test_save_load_vertical_line_segment_marker(self): + x, y1, y2 = 2, 1, 3 + color = 'white' + linewidth = 4.2 + name = "vertical_line_segment_test" + s = Signal2D(np.arange(100).reshape(10, 10)) + m = markers.vertical_line_segment( + x=x, y1=y1, y2=y2, color=color, linewidth=linewidth) + m.name = name + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/test_save_vertical_line_segment_marker.zspy' + s.save(filename) + s1 = load(filename) + m1 = s1.metadata.Markers.get_item(name) + assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) + + def test_save_load_line_segment_marker(self): + x1, x2, y1, y2 = 1, 9, 4, 7 + color = 'cyan' + linewidth = 0.7 + name = "line_segment_test" + s = 
Signal2D(np.arange(100).reshape(10, 10)) + m = markers.line_segment( + x1=x1, x2=x2, y1=y1, y2=y2, color=color, linewidth=linewidth) + m.name = name + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/test_save_line_segment_marker.zspy' + s.save(filename) + s1 = load(filename) + m1 = s1.metadata.Markers.get_item(name) + assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) + + def test_save_load_point_marker(self): + x, y = 9, 8 + color = 'purple' + name = "point test" + s = Signal2D(np.arange(100).reshape(10, 10)) + m = markers.point( + x=x, y=y, color=color) + m.name = name + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/test_save_point_marker.zspy' + s.save(filename) + s1 = load(filename) + m1 = s1.metadata.Markers.get_item(name) + assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) + + def test_save_load_rectangle_marker(self): + x1, x2, y1, y2 = 2, 4, 1, 3 + color = 'yellow' + linewidth = 5 + name = "rectangle_test" + s = Signal2D(np.arange(100).reshape(10, 10)) + m = markers.rectangle( + x1=x1, x2=x2, y1=y1, y2=y2, color=color, linewidth=linewidth) + m.name = name + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/test_save_rectangle_marker.zspy' + s.save(filename) + s1 = load(filename) + m1 = s1.metadata.Markers.get_item(name) + assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) + + def test_metadata_update_to_v3_1(self): + md = {'Acquisition_instrument': {'SEM': {'Stage': {'tilt_alpha': 5.0}}, + 'TEM': {'Detector': {'Camera': {'exposure': 0.20000000000000001}}, + 'Stage': {'tilt_alpha': 10.0}, + 'acquisition_mode': 'TEM', + 'beam_current': 0.0, + 'beam_energy': 200.0, + 'camera_length': 320.00000000000006, + 'microscope': 'FEI Tecnai'}}, + 'General': {'date': '2014-07-09', + 'original_filename': 'test_diffraction_pattern.dm3', + 'time': '18:56:37', + 'title': 'test_diffraction_pattern'}, + 'Signal': {'Noise_properties': {'Variance_linear_model': {'gain_factor': 1.0, + 'gain_offset': 0.0}}, + 'quantity': 'Intensity', + 'signal_type': ''}, + '_HyperSpy': {'Folding': {'original_axes_manager': None, + 'original_shape': None, + 'signal_unfolded': False, + 'unfolded': False}}} + s = load(os.path.join( + my_path, + "hdf5_files", + 'example2_v3.1.hspy')) + assert_deep_almost_equal(s.metadata.as_dictionary(), md) + + def test_save_load_text_marker(self): + x, y = 3, 9.5 + color = 'brown' + name = "text_test" + text = "a text" + s = Signal2D(np.arange(100).reshape(10, 10)) + m = markers.text( + x=x, y=y, text=text, color=color) + m.name = name + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/test_save_text_marker.zspy' + s.save(filename) + s1 = load(filename) + m1 = s1.metadata.Markers.get_item(name) + assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) + + def test_save_load_multidim_navigation_marker(self): + x, y = (1, 2, 3), (5, 6, 7) + name = 'test point' + s = Signal2D(np.arange(300).reshape(3, 10, 10)) + m = markers.point(x=x, y=y) + m.name = name + s.add_marker(m, permanent=True) + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/test_save_multidim_nav_marker.zspy' + s.save(filename) + s1 = load(filename) + m1 = s1.metadata.Markers.get_item(name) + assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) + assert m1.get_data_position('x1') == x[0] + assert m1.get_data_position('y1') == y[0] + 
s1.axes_manager.navigation_axes[0].index = 1 + assert m1.get_data_position('x1') == x[1] + assert m1.get_data_position('y1') == y[1] + s1.axes_manager.navigation_axes[0].index = 2 + assert m1.get_data_position('x1') == x[2] + assert m1.get_data_position('y1') == y[2] + + + + +@pytest.mark.parametrize("compressor", (None, "default", "blosc")) +def test_compression(compressor, tmp_path): + if compressor is "blosc": + from numcodecs import Blosc + compressor = Blosc(cname='zstd', clevel=3, shuffle=Blosc.BITSHUFFLE) + s = Signal1D(np.ones((3, 3))) + s.save(tmp_path / 'test_compression.zspy', + overwrite=True, + compressor=compressor) + load(tmp_path / 'test_compression.zspy') + + +def test_strings_from_py2(): + s = EDS_TEM_Spectrum() + assert isinstance(s.metadata.Sample.elements, list) + + +def test_save_ragged_array(tmp_path): + a = np.array([0, 1]) + b = np.array([0, 1, 2]) + s = BaseSignal(np.array([a, b], dtype=object)).T + fname = tmp_path / 'test_save_ragged_array.zspy' + s.save(fname) + s1 = load(fname) + for i in range(len(s.data)): + np.testing.assert_allclose(s.data[i], s1.data[i]) + assert s.__class__ == s1.__class__ + + +def test_save_chunks_signal_metadata(): + N = 10 + dim = 3 + s = Signal1D(np.arange(N ** dim).reshape([N] * dim)) + s.navigator = s.sum(-1) + s.change_dtype('float') + s.decomposition() + with tempfile.TemporaryDirectory() as tmp: + filename = os.path.join(tmp, 'test_save_chunks_signal_metadata.zspy') + chunks = (5, 2, 2) + s.save(filename, chunks=chunks) + s2 = load(filename, lazy=True) + assert tuple([c[0] for c in s2.data.chunks]) == chunks + + +def test_chunking_saving_lazy(): + s = Signal2D(da.zeros((50, 100, 100))).as_lazy() + s.data = s.data.rechunk([50, 25, 25]) + with tempfile.TemporaryDirectory() as tmp: + filename = os.path.join(tmp, 'test_chunking_saving_lazy.zspy') + filename2 = os.path.join(tmp, 'test_chunking_saving_lazy_chunks_True.zspy') + filename3 = os.path.join(tmp, 'test_chunking_saving_lazy_chunks_specified.zspy') + s.save(filename) + s1 = load(filename, lazy=True) + assert s.data.chunks == s1.data.chunks + + s.save(filename2) + s2 = load(filename2, lazy=True) + assert tuple([c[0] for c in s2.data.chunks]) == (50, 25, 25) + + # specify chunks + chunks = (50, 10, 10) + s.save(filename3, chunks=chunks) + s3 = load(filename3, lazy=True) + assert tuple([c[0] for c in s3.data.chunks]) == chunks + +def test_data_lazy(): + s = Signal2D(da.ones((5, 10, 10))).as_lazy() + s.data = s.data.rechunk([5, 2, 2]) + with tempfile.TemporaryDirectory() as tmp: + filename = os.path.join(tmp, 'test_chunking_saving_lazy.zspy') + s.save(filename) + s1 = load(filename) + np.testing.assert_array_almost_equal(s1.data, s.data) + + +class TestZspy: + @pytest.fixture + def signal(self): + data = np.ones((10,10,10,10)) + s = Signal1D(data) + return s + + def test_save_load_model(self, signal): + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/testmodels.zspy' + m = signal.create_model() + m.append(hs.model.components1D.Gaussian()) + m.store("test") + signal.save(filename) + signal2 = hs.load(filename) + m2 = signal2.models.restore("test") + assert m.signal == m2.signal \ No newline at end of file From 2270bd775125fc35480b72442241ba25c9f6d0d9 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 1 Sep 2021 08:08:08 -0500 Subject: [PATCH 02/31] Added in zspy as file format --- hyperspy/io.py | 6 ++--- hyperspy/io_plugins/__init__.py | 2 ++ hyperspy/io_plugins/zspy.py | 43 +++++++++++++++++++++++++++++++-- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git 
a/hyperspy/io.py b/hyperspy/io.py index eb0248d0ad..f47f60bc45 100644 --- a/hyperspy/io.py +++ b/hyperspy/io.py @@ -334,7 +334,7 @@ def load(filenames=None, filenames = _escape_square_brackets(filenames) filenames = natsorted([f for f in glob.glob(filenames) - if os.path.isfile(f)]) + if os.path.isfile(f) or os.path.isdir(f)]) if not filenames: raise ValueError(f'No filename matches the pattern "{pattern}"') @@ -342,7 +342,7 @@ def load(filenames=None, elif isinstance(filenames, Path): # Just convert to list for now, pathlib.Path not # fully supported in io_plugins - filenames = [f for f in [filenames] if f.is_file()] + filenames = [f for f in [filenames] if f.is_file() or f.is_dir()] elif isgenerator(filenames): filenames = list(filenames) @@ -449,7 +449,7 @@ def load_single_file(filename, **kwds): Data loaded from the file. """ - if not os.path.isfile(filename): + if not os.path.isfile(filename) and not os.path.isdir(filename): raise FileNotFoundError(f"File: {filename} not found!") # File extension without "." separator diff --git a/hyperspy/io_plugins/__init__.py b/hyperspy/io_plugins/__init__.py index 6e4224512c..3e19e898f4 100644 --- a/hyperspy/io_plugins/__init__.py +++ b/hyperspy/io_plugins/__init__.py @@ -40,6 +40,7 @@ semper_unf, sur, tiff, + zspy, ) io_plugins = [ @@ -63,6 +64,7 @@ semper_unf, sur, tiff, + zspy, ] diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index b47328a726..50b090d181 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -30,7 +30,7 @@ from traits.api import Undefined from hyperspy.misc.utils import ensure_unicode, multiply, get_object_package_info from hyperspy.axes import AxesManager -from hyperspy.io_plugins.hspy import hdfgroup2signaldict, file_reader, write_signal, overwrite_dataset, get_signal_chunks +from hyperspy.io_plugins.hspy import hdfgroup2signaldict, dict2hdfgroup, file_reader, write_signal, overwrite_dataset, get_signal_chunks import numcodecs from hyperspy.io_plugins.hspy import version @@ -136,6 +136,45 @@ def get_signal_chunks(shape, dtype, signal_axes=None): total_size = np.prod(shape)*typesize if total_size < 1e8: # 1 mb return None + +def write_signal(signal, group, f=None, **kwds): + """Writes a hyperspy signal to a zarr group""" + + group.attrs.update(get_object_package_info(signal)) + metadata = "metadata" + original_metadata = "original_metadata" + + if 'compressor' not in kwds: + kwds['compressor'] = None + + for axis in signal.axes_manager._axes: + axis_dict = axis.get_axis_dictionary() + coord_group = group.create_group( + 'axis-%s' % axis.index_in_array) + dict2hdfgroup(axis_dict, coord_group, **kwds) + mapped_par = group.create_group(metadata) + metadata_dict = signal.metadata.as_dictionary() + overwrite_dataset(group, signal.data, 'data', + signal_axes=signal.axes_manager.signal_indices_in_array, + **kwds) + # Remove chunks from the kwds since it wouldn't have the same rank as the + # dataset and can't be used + kwds.pop('chunks', None) + dict2hdfgroup(metadata_dict, mapped_par, **kwds) + original_par = group.create_group(original_metadata) + dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, + **kwds) + learning_results = group.create_group('learning_results') + dict2hdfgroup(signal.learning_results.__dict__, + learning_results, **kwds) + + if len(signal.models) and f is not None: + model_group = f.require_group('Analysis/models') + dict2hdfgroup(signal.models._models.as_dictionary(), + model_group, **kwds) + for model in model_group.values(): + 
model.attrs['_signal'] = group.name + def file_writer(filename, signal, *args, **kwds): """Writes data to hyperspy's zarr format Parameters @@ -173,4 +212,4 @@ def file_writer(filename, signal, *args, **kwds): except BaseException: raise finally: - del smd.record_by \ No newline at end of file + del smd.record_by From 0e6d7058efccc48f199dfb27e8646f8686d69c1a Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 3 Sep 2021 10:26:13 -0500 Subject: [PATCH 03/31] Added in HyperspyReader and HyperspyWriter Classes for easier organization --- hyperspy/io_plugins/hspy.py | 1325 ++++++++++++++++++----------------- hyperspy/io_plugins/spy.py | 0 2 files changed, 666 insertions(+), 659 deletions(-) create mode 100644 hyperspy/io_plugins/spy.py diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index 7f62b91fc8..91630b3928 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -23,7 +23,6 @@ import ast import h5py -import zarr import numpy as np import dask.array as da from traits.api import Undefined @@ -106,682 +105,658 @@ current_file_version = None # Format version of the file being read default_version = LooseVersion(version) - -def get_hspy_format_version(f): - if "file_format_version" in f.attrs: - version = f.attrs["file_format_version"] - if isinstance(version, bytes): - version = version.decode() - if isinstance(version, float): - version = str(round(version, 2)) - elif "Experiments" in f: - # Chances are that this is a HSpy hdf5 file version 1.0 - version = "1.0" - elif "Analysis" in f: - # Starting version 2.0 we have "Analysis" field as well - version = "2.0" - else: - raise IOError(not_valid_format) - return LooseVersion(version) - - -def file_reader(filename, - backing_store=False, - lazy=False, - **kwds): - """Read data from hdf5 files saved with the hyperspy hdf5 format specification - - Parameters - ---------- - filename: str - lazy: bool - Load image lazily using dask - **kwds, optional - """ - try: - # in case blosc compression is used - import hdf5plugin - except ImportError: - pass - mode = kwds.pop('mode', 'r') - f = File(filename, mode=mode, **kwds) - # Getting the format version here also checks if it is a valid HSpy - # hdf5 file, so the following two lines must not be deleted or moved - # elsewhere. - global current_file_version - current_file_version = get_hspy_format_version(f) - global default_version - if current_file_version > default_version: - warnings.warn( - "This file was written using a newer version of the " - "HyperSpy hdf5 file format. 
I will attempt to load it, but, " - "if I fail, it is likely that I will be more successful at " - "this and other tasks if you upgrade me.") - - models_with_signals = [] - standalone_models = [] - if 'Analysis/models' in f: - try: - m_gr = f['Analysis/models'] - for model_name in m_gr: - if '_signal' in m_gr[model_name].attrs: - key = m_gr[model_name].attrs['_signal'] - # del m_gr[model_name].attrs['_signal'] - res = hdfgroup2dict( - m_gr[model_name], - lazy=lazy) - del res['_signal'] - models_with_signals.append((key, {model_name: res})) - else: - standalone_models.append( - {model_name: hdfgroup2dict( - m_gr[model_name], lazy=lazy)}) - except TypeError: +class HyperspyReader: + def __init__(self, file): + self.file = file + self.version = self.get_hspy_format_version() + + def get_hspy_format_version(self): + if "file_format_version" in self.file.attrs: + version = self.file.attrs["file_format_version"] + if isinstance(version, bytes): + version = version.decode() + if isinstance(version, float): + version = str(round(version, 2)) + elif "Experiments" in self.file: + # Chances are that this is a HSpy hdf5 file version 1.0 + version = "1.0" + elif "Analysis" in self.f: + # Starting version 2.0 we have "Analysis" field as well + version = "2.0" + else: raise IOError(not_valid_format) + return LooseVersion(version) + + def hdfgroup2signaldict(self, + group, + lazy=False): + global default_version + if self.version < LooseVersion("1.2"): + metadata = "mapped_parameters" + original_metadata = "original_parameters" + else: + metadata = "metadata" + original_metadata = "original_metadata" + + exp = {'metadata': self.hdfgroup2dict( + group[metadata], lazy=lazy), + 'original_metadata': self.hdfgroup2dict( + group[original_metadata], lazy=lazy), + 'attributes': {} + } + if "package" in group.attrs: + # HyperSpy version is >= 1.5 + exp["package"] = group.attrs["package"] + exp["package_version"] = group.attrs["package_version"] + else: + # Prior to v1.4 we didn't store the package information. Since there + # were already external package we cannot assume any package provider so + # we leave this empty. 
+ exp["package"] = "" + exp["package_version"] = "" + + data = group['data'] + if lazy: + data = da.from_array(data, chunks=data.chunks) + exp['attributes']['_lazy'] = True + else: + data = np.asanyarray(data) + exp['data'] = data + axes = [] + for i in range(len(exp['data'].shape)): + try: + axes.append(self.hdfgroup2dict(group['axis-%i' % i])) + axis = axes[-1] + for key, item in axis.items(): + if isinstance(item, np.bool_): + axis[key] = bool(item) + else: + axis[key] = ensure_unicode(item) + except KeyError: + break + if len(axes) != len(exp['data'].shape): # broke from the previous loop + try: + axes = [i for k, i in sorted(iter(self.hdfgroup2dict( + group['_list_' + str(len(exp['data'].shape)) + '_axes'], + lazy=lazy).items()))] + except KeyError: + raise IOError(not_valid_format) + exp['axes'] = axes + if 'learning_results' in group.keys(): + exp['attributes']['learning_results'] = \ + self.hdfgroup2dict( + group['learning_results'], + lazy=lazy) + if 'peak_learning_results' in group.keys(): + exp['attributes']['peak_learning_results'] = \ + self.hdfgroup2dict( + group['peak_learning_results'], + lazy=lazy) - experiments = [] - exp_dict_list = [] - if 'Experiments' in f: - for ds in f['Experiments']: - if isinstance(f['Experiments'][ds], Group): - if 'data' in f['Experiments'][ds]: - experiments.append(ds) - # Parse the file - for experiment in experiments: - exg = f['Experiments'][experiment] - exp = hdfgroup2signaldict(exg, lazy) - # assign correct models, if found: - _tmp = {} - for (key, _dict) in reversed(models_with_signals): - if key == exg.name: - _tmp.update(_dict) - models_with_signals.remove((key, _dict)) - exp['models'] = _tmp - - exp_dict_list.append(exp) - - for _, m in models_with_signals: - standalone_models.append(m) - - exp_dict_list.extend(standalone_models) - if not len(exp_dict_list): - raise IOError('This is not a valid HyperSpy HDF5 file. ' - 'You can still load the data using a hdf5 reader, ' - 'e.g. h5py, and manually create a Signal. ' - 'Please, refer to the User Guide for details') - if not lazy: - f.close() - return exp_dict_list - - -def hdfgroup2signaldict(group, lazy=False): - global current_file_version - global default_version - if current_file_version < LooseVersion("1.2"): - metadata = "mapped_parameters" - original_metadata = "original_parameters" - else: - metadata = "metadata" - original_metadata = "original_metadata" - - exp = {'metadata': hdfgroup2dict( - group[metadata], lazy=lazy), - 'original_metadata': hdfgroup2dict( - group[original_metadata], lazy=lazy), - 'attributes': {} - } - if "package" in group.attrs: - # HyperSpy version is >= 1.5 - exp["package"] = group.attrs["package"] - exp["package_version"] = group.attrs["package_version"] - else: - # Prior to v1.4 we didn't store the package information. Since there - # were already external package we cannot assume any package provider so - # we leave this empty. 
- exp["package"] = "" - exp["package_version"] = "" - - data = group['data'] - if lazy: - data = da.from_array(data, chunks=data.chunks) - exp['attributes']['_lazy'] = True - else: - data = np.asanyarray(data) - exp['data'] = data - axes = [] - for i in range(len(exp['data'].shape)): - try: - axes.append(hdfgroup2dict(group['axis-%i' % i])) - axis = axes[-1] - for key, item in axis.items(): - if isinstance(item, np.bool_): - axis[key] = bool(item) - else: - axis[key] = ensure_unicode(item) - except KeyError: - break - if len(axes) != len(exp['data'].shape): # broke from the previous loop - try: - axes = [i for k, i in sorted(iter(hdfgroup2dict( - group['_list_' + str(len(exp['data'].shape)) + '_axes'], - lazy=lazy).items()))] - except KeyError: - raise IOError(not_valid_format) - exp['axes'] = axes - if 'learning_results' in group.keys(): - exp['attributes']['learning_results'] = \ - hdfgroup2dict( - group['learning_results'], - lazy=lazy) - if 'peak_learning_results' in group.keys(): - exp['attributes']['peak_learning_results'] = \ - hdfgroup2dict( - group['peak_learning_results'], - lazy=lazy) - - # If the title was not defined on writing the Experiment is - # then called __unnamed__. The next "if" simply sets the title - # back to the empty string - if "General" in exp["metadata"] and "title" in exp["metadata"]["General"]: - if '__unnamed__' == exp['metadata']['General']['title']: - exp['metadata']["General"]['title'] = '' - - if current_file_version < LooseVersion("1.1"): - # Load the decomposition results written with the old name, - # mva_results - if 'mva_results' in group.keys(): - exp['attributes']['learning_results'] = hdfgroup2dict( - group['mva_results'], lazy=lazy) - if 'peak_mva_results' in group.keys(): - exp['attributes']['peak_learning_results'] = hdfgroup2dict( - group['peak_mva_results'], lazy=lazy) - # Replace the old signal and name keys with their current names - if 'signal' in exp['metadata']: - if "Signal" not in exp["metadata"]: - exp["metadata"]["Signal"] = {} - exp['metadata']["Signal"]['signal_type'] = \ - exp['metadata']['signal'] - del exp['metadata']['signal'] - - if 'name' in exp['metadata']: - if "General" not in exp["metadata"]: - exp["metadata"]["General"] = {} - exp['metadata']['General']['title'] = \ - exp['metadata']['name'] - del exp['metadata']['name'] - - if current_file_version < LooseVersion("1.2"): - if '_internal_parameters' in exp['metadata']: - exp['metadata']['_HyperSpy'] = \ - exp['metadata']['_internal_parameters'] - del exp['metadata']['_internal_parameters'] - if 'stacking_history' in exp['metadata']['_HyperSpy']: - exp['metadata']['_HyperSpy']["Stacking_history"] = \ - exp['metadata']['_HyperSpy']['stacking_history'] - del exp['metadata']['_HyperSpy']["stacking_history"] - if 'folding' in exp['metadata']['_HyperSpy']: - exp['metadata']['_HyperSpy']["Folding"] = \ - exp['metadata']['_HyperSpy']['folding'] - del exp['metadata']['_HyperSpy']["folding"] - if 'Variance_estimation' in exp['metadata']: - if "Noise_properties" not in exp["metadata"]: - exp["metadata"]["Noise_properties"] = {} - exp['metadata']['Noise_properties']["Variance_linear_model"] = \ - exp['metadata']['Variance_estimation'] - del exp['metadata']['Variance_estimation'] - if "TEM" in exp["metadata"]: - if "Acquisition_instrument" not in exp["metadata"]: - exp["metadata"]["Acquisition_instrument"] = {} - exp["metadata"]["Acquisition_instrument"]["TEM"] = \ - exp["metadata"]["TEM"] - del exp["metadata"]["TEM"] - tem = exp["metadata"]["Acquisition_instrument"]["TEM"] - if 
"EELS" in tem: - if "dwell_time" in tem: - tem["EELS"]["dwell_time"] = tem["dwell_time"] - del tem["dwell_time"] - if "dwell_time_units" in tem: - tem["EELS"]["dwell_time_units"] = tem["dwell_time_units"] - del tem["dwell_time_units"] - if "exposure" in tem: - tem["EELS"]["exposure"] = tem["exposure"] - del tem["exposure"] - if "exposure_units" in tem: - tem["EELS"]["exposure_units"] = tem["exposure_units"] - del tem["exposure_units"] - if "Detector" not in tem: - tem["Detector"] = {} - tem["Detector"] = tem["EELS"] - del tem["EELS"] - if "EDS" in tem: - if "Detector" not in tem: - tem["Detector"] = {} - if "EDS" not in tem["Detector"]: - tem["Detector"]["EDS"] = {} - tem["Detector"]["EDS"] = tem["EDS"] - del tem["EDS"] - del tem - if "SEM" in exp["metadata"]: - if "Acquisition_instrument" not in exp["metadata"]: - exp["metadata"]["Acquisition_instrument"] = {} - exp["metadata"]["Acquisition_instrument"]["SEM"] = \ - exp["metadata"]["SEM"] - del exp["metadata"]["SEM"] - sem = exp["metadata"]["Acquisition_instrument"]["SEM"] - if "EDS" in sem: - if "Detector" not in sem: - sem["Detector"] = {} - if "EDS" not in sem["Detector"]: - sem["Detector"]["EDS"] = {} - sem["Detector"]["EDS"] = sem["EDS"] - del sem["EDS"] - del sem - - if "Sample" in exp["metadata"] and "Xray_lines" in exp[ - "metadata"]["Sample"]: - exp["metadata"]["Sample"]["xray_lines"] = exp[ - "metadata"]["Sample"]["Xray_lines"] - del exp["metadata"]["Sample"]["Xray_lines"] - - for key in ["title", "date", "time", "original_filename"]: - if key in exp["metadata"]: - if "General" not in exp["metadata"]: - exp["metadata"]["General"] = {} - exp["metadata"]["General"][key] = exp["metadata"][key] - del exp["metadata"][key] - for key in ["record_by", "signal_origin", "signal_type"]: - if key in exp["metadata"]: + # If the title was not defined on writing the Experiment is + # then called __unnamed__. 
The next "if" simply sets the title + # back to the empty string + if "General" in exp["metadata"] and "title" in exp["metadata"]["General"]: + if '__unnamed__' == exp['metadata']['General']['title']: + exp['metadata']["General"]['title'] = '' + + if current_file_version < LooseVersion("1.1"): + # Load the decomposition results written with the old name, + # mva_results + if 'mva_results' in group.keys(): + exp['attributes']['learning_results'] = hdfgroup2dict( + group['mva_results'], lazy=lazy) + if 'peak_mva_results' in group.keys(): + exp['attributes']['peak_learning_results'] = hdfgroup2dict( + group['peak_mva_results'], lazy=lazy) + # Replace the old signal and name keys with their current names + if 'signal' in exp['metadata']: if "Signal" not in exp["metadata"]: exp["metadata"]["Signal"] = {} - exp["metadata"]["Signal"][key] = exp["metadata"][key] - del exp["metadata"][key] - - if current_file_version < LooseVersion("3.0"): - if "Acquisition_instrument" in exp["metadata"]: - # Move tilt_stage to Stage.tilt_alpha - # Move exposure time to Detector.Camera.exposure_time - if "TEM" in exp["metadata"]["Acquisition_instrument"]: + exp['metadata']["Signal"]['signal_type'] = \ + exp['metadata']['signal'] + del exp['metadata']['signal'] + + if 'name' in exp['metadata']: + if "General" not in exp["metadata"]: + exp["metadata"]["General"] = {} + exp['metadata']['General']['title'] = \ + exp['metadata']['name'] + del exp['metadata']['name'] + + if self.version < LooseVersion("1.2"): + if '_internal_parameters' in exp['metadata']: + exp['metadata']['_HyperSpy'] = \ + exp['metadata']['_internal_parameters'] + del exp['metadata']['_internal_parameters'] + if 'stacking_history' in exp['metadata']['_HyperSpy']: + exp['metadata']['_HyperSpy']["Stacking_history"] = \ + exp['metadata']['_HyperSpy']['stacking_history'] + del exp['metadata']['_HyperSpy']["stacking_history"] + if 'folding' in exp['metadata']['_HyperSpy']: + exp['metadata']['_HyperSpy']["Folding"] = \ + exp['metadata']['_HyperSpy']['folding'] + del exp['metadata']['_HyperSpy']["folding"] + if 'Variance_estimation' in exp['metadata']: + if "Noise_properties" not in exp["metadata"]: + exp["metadata"]["Noise_properties"] = {} + exp['metadata']['Noise_properties']["Variance_linear_model"] = \ + exp['metadata']['Variance_estimation'] + del exp['metadata']['Variance_estimation'] + if "TEM" in exp["metadata"]: + if "Acquisition_instrument" not in exp["metadata"]: + exp["metadata"]["Acquisition_instrument"] = {} + exp["metadata"]["Acquisition_instrument"]["TEM"] = \ + exp["metadata"]["TEM"] + del exp["metadata"]["TEM"] tem = exp["metadata"]["Acquisition_instrument"]["TEM"] - exposure = None - if "tilt_stage" in tem: - tem["Stage"] = {"tilt_alpha": tem["tilt_stage"]} - del tem["tilt_stage"] - if "exposure" in tem: - exposure = "exposure" - # Digital_micrograph plugin was parsing to 'exposure_time' - # instead of 'exposure': need this to be compatible with - # previous behaviour - if "exposure_time" in tem: - exposure = "exposure_time" - if exposure is not None: + if "EELS" in tem: + if "dwell_time" in tem: + tem["EELS"]["dwell_time"] = tem["dwell_time"] + del tem["dwell_time"] + if "dwell_time_units" in tem: + tem["EELS"]["dwell_time_units"] = tem["dwell_time_units"] + del tem["dwell_time_units"] + if "exposure" in tem: + tem["EELS"]["exposure"] = tem["exposure"] + del tem["exposure"] + if "exposure_units" in tem: + tem["EELS"]["exposure_units"] = tem["exposure_units"] + del tem["exposure_units"] if "Detector" not in tem: - tem["Detector"] = 
{"Camera": { - "exposure": tem[exposure]}} - tem["Detector"]["Camera"] = {"exposure": tem[exposure]} - del tem[exposure] - # Move tilt_stage to Stage.tilt_alpha - if "SEM" in exp["metadata"]["Acquisition_instrument"]: + tem["Detector"] = {} + tem["Detector"] = tem["EELS"] + del tem["EELS"] + if "EDS" in tem: + if "Detector" not in tem: + tem["Detector"] = {} + if "EDS" not in tem["Detector"]: + tem["Detector"]["EDS"] = {} + tem["Detector"]["EDS"] = tem["EDS"] + del tem["EDS"] + del tem + if "SEM" in exp["metadata"]: + if "Acquisition_instrument" not in exp["metadata"]: + exp["metadata"]["Acquisition_instrument"] = {} + exp["metadata"]["Acquisition_instrument"]["SEM"] = \ + exp["metadata"]["SEM"] + del exp["metadata"]["SEM"] sem = exp["metadata"]["Acquisition_instrument"]["SEM"] - if "tilt_stage" in sem: - sem["Stage"] = {"tilt_alpha": sem["tilt_stage"]} - del sem["tilt_stage"] - - return exp - + if "EDS" in sem: + if "Detector" not in sem: + sem["Detector"] = {} + if "EDS" not in sem["Detector"]: + sem["Detector"]["EDS"] = {} + sem["Detector"]["EDS"] = sem["EDS"] + del sem["EDS"] + del sem + + if "Sample" in exp["metadata"] and "Xray_lines" in exp[ + "metadata"]["Sample"]: + exp["metadata"]["Sample"]["xray_lines"] = exp[ + "metadata"]["Sample"]["Xray_lines"] + del exp["metadata"]["Sample"]["Xray_lines"] + + for key in ["title", "date", "time", "original_filename"]: + if key in exp["metadata"]: + if "General" not in exp["metadata"]: + exp["metadata"]["General"] = {} + exp["metadata"]["General"][key] = exp["metadata"][key] + del exp["metadata"][key] + for key in ["record_by", "signal_origin", "signal_type"]: + if key in exp["metadata"]: + if "Signal" not in exp["metadata"]: + exp["metadata"]["Signal"] = {} + exp["metadata"]["Signal"][key] = exp["metadata"][key] + del exp["metadata"][key] + + if self.version < LooseVersion("3.0"): + if "Acquisition_instrument" in exp["metadata"]: + # Move tilt_stage to Stage.tilt_alpha + # Move exposure time to Detector.Camera.exposure_time + if "TEM" in exp["metadata"]["Acquisition_instrument"]: + tem = exp["metadata"]["Acquisition_instrument"]["TEM"] + exposure = None + if "tilt_stage" in tem: + tem["Stage"] = {"tilt_alpha": tem["tilt_stage"]} + del tem["tilt_stage"] + if "exposure" in tem: + exposure = "exposure" + # Digital_micrograph plugin was parsing to 'exposure_time' + # instead of 'exposure': need this to be compatible with + # previous behaviour + if "exposure_time" in tem: + exposure = "exposure_time" + if exposure is not None: + if "Detector" not in tem: + tem["Detector"] = {"Camera": { + "exposure": tem[exposure]}} + tem["Detector"]["Camera"] = {"exposure": tem[exposure]} + del tem[exposure] + # Move tilt_stage to Stage.tilt_alpha + if "SEM" in exp["metadata"]["Acquisition_instrument"]: + sem = exp["metadata"]["Acquisition_instrument"]["SEM"] + if "tilt_stage" in sem: + sem["Stage"] = {"tilt_alpha": sem["tilt_stage"]} + del sem["tilt_stage"] + + return exp + + def hdfgroup2dict(self, + group, + dictionary=None, + lazy=False): + if dictionary is None: + dictionary = {} + for key, value in group.attrs.items(): + if isinstance(value, bytes): + value = value.decode() + if isinstance(value, (np.string_, str)): + if value == '_None_': + value = None + elif isinstance(value, np.bool_): + value = bool(value) + elif isinstance(value, np.ndarray) and value.dtype.char == "S": + # Convert strings to unicode + value = value.astype("U") + if value.dtype.str.endswith("U1"): + value = value.tolist() + # skip signals - these are handled below. 
+ if key.startswith('_sig_'): + pass + elif key.startswith('_list_empty_'): + dictionary[key[len('_list_empty_'):]] = [] + elif key.startswith('_tuple_empty_'): + dictionary[key[len('_tuple_empty_'):]] = () + elif key.startswith('_bs_'): + dictionary[key[len('_bs_'):]] = value.tobytes() + # The following two elif stataments enable reading date and time from + # v < 2 of HyperSpy's metadata specifications + elif key.startswith('_datetime_date'): + date_iso = datetime.date( + *ast.literal_eval(value[value.index("("):])).isoformat() + dictionary[key.replace("_datetime_", "")] = date_iso + elif key.startswith('_datetime_time'): + date_iso = datetime.time( + *ast.literal_eval(value[value.index("("):])).isoformat() + dictionary[key.replace("_datetime_", "")] = date_iso + else: + dictionary[key] = value + if not isinstance(group, Dataset): + for key in group.keys(): + if key.startswith('_sig_'): + from hyperspy.io import dict2signal + dictionary[key[len('_sig_'):]] = ( + dict2signal(self.hdfgroup2signaldict( + group[key], lazy=lazy))) + elif isinstance(group[key], Dataset): + dat = group[key] + kn = key + if key.startswith("_list_"): + if (h5py.check_string_dtype(dat.dtype) and + hasattr(dat, 'asstr')): + # h5py 3.0 and newer + # https://docs.h5py.org/en/3.0.0/strings.html + dat = dat.asstr()[:] + ans = np.array(dat) + ans = ans.tolist() + kn = key[6:] + elif key.startswith("_tuple_"): + ans = np.array(dat) + ans = tuple(ans.tolist()) + kn = key[7:] + elif dat.dtype.char == "S": + ans = np.array(dat) + try: + ans = ans.astype("U") + except UnicodeDecodeError: + # There are some strings that must stay in binary, + # for example dill pickles. This will obviously also + # let "wrong" binary string fail somewhere else... + pass + elif lazy: + ans = da.from_array(dat, chunks=dat.chunks) + else: + ans = np.array(dat) + dictionary[kn] = ans + elif key.startswith('_hspy_AxesManager_'): + dictionary[key[len('_hspy_AxesManager_'):]] = AxesManager( + [i for k, i in sorted(iter( + self.hdfgroup2dict( + group[key], lazy=lazy).items() + ))]) + elif key.startswith('_list_'): + dictionary[key[7 + key[6:].find('_'):]] = \ + [i for k, i in sorted(iter( + self.hdfgroup2dict( + group[key], lazy=lazy).items() + ))] + elif key.startswith('_tuple_'): + dictionary[key[8 + key[7:].find('_'):]] = tuple( + [i for k, i in sorted(iter( + self.hdfgroup2dict( + group[key], lazy=lazy).items() + ))]) + else: + dictionary[key] = {} + self.hdfgroup2dict( + group[key], + dictionary[key], + lazy=lazy) -def dict2hdfgroup(dictionary, group, **kwds): - "Recursive writer of dicts and signals" + return dictionary - from hyperspy.misc.utils import DictionaryTreeBrowser - from hyperspy.signal import BaseSignal +class HyperspyWriter: + """An object used to simplify and orgainize the process for + writing a hyperspy signal. 
(.hspy format) + """ + def __init__(self, file): + + def dict2hdfgroup(self, dictionary, group, **kwds): + "Recursive writer of dicts and signals" + + from hyperspy.misc.utils import DictionaryTreeBrowser + from hyperspy.signal import BaseSignal + + for key, value in dictionary.items(): + if isinstance(value, dict): + self.dict2hdfgroup(value, group.create_group(key), + **kwds) + elif isinstance(value, DictionaryTreeBrowser): + self.dict2hdfgroup(value.as_dictionary(), + group.create_group(key), + **kwds) + elif isinstance(value, BaseSignal): + kn = key if key.startswith('_sig_') else '_sig_' + key + self.write_signal(value, group.require_group(kn)) + elif isinstance(value, (np.ndarray, Dataset, da.Array)): + overwrite_dataset(group, value, key, **kwds) + elif value is None: + group.attrs[key] = '_None_' + elif isinstance(value, bytes): + try: + # binary string if has any null characters (otherwise not + # supported by hdf5) + value.index(b'\x00') + group.attrs['_bs_' + key] = np.void(value) + except ValueError: + group.attrs[key] = value.decode() + elif isinstance(value, str): + group.attrs[key] = value + elif isinstance(value, AxesManager): + dict2hdfgroup(value.as_dictionary(), + group.create_group('_hspy_AxesManager_' + key), + **kwds) + elif isinstance(value, list): + if len(value): + parse_structure(key, group, value, '_list_', **kwds) + else: + group.attrs['_list_empty_' + key] = '_None_' + elif isinstance(value, tuple): + if len(value): + parse_structure(key, group, value, '_tuple_', **kwds) + else: + group.attrs['_tuple_empty_' + key] = '_None_' - def parse_structure(key, group, value, _type, **kwds): - try: - # Here we check if there are any signals in the container, as - # casting a long list of signals to a numpy array takes a very long - # time. 
So we check if there are any, and save numpy the trouble - if np.any([isinstance(t, BaseSignal) for t in value]): - tmp = np.array([[0]]) + elif value is Undefined: + continue else: - tmp = np.array(value) - except ValueError: - tmp = np.array([[0]]) - if tmp.dtype == np.dtype('O') or tmp.ndim != 1: - dict2hdfgroup(dict(zip( - [str(i) for i in range(len(value))], value)), - group.create_group(_type + str(len(value)) + '_' + key), - **kwds) - elif tmp.dtype.type is np.unicode_: - if _type + key in group: - del group[_type + key] - group.create_dataset(_type + key, - tmp.shape, - dtype=h5py.special_dtype(vlen=str), - **kwds) - group[_type + key][:] = tmp[:] - else: - if _type + key in group: - del group[_type + key] - group.create_dataset( - _type + key, - data=tmp, - **kwds) - - for key, value in dictionary.items(): - if isinstance(value, dict): - dict2hdfgroup(value, group.create_group(key), - **kwds) - elif isinstance(value, DictionaryTreeBrowser): - dict2hdfgroup(value.as_dictionary(), - group.create_group(key), - **kwds) - elif isinstance(value, BaseSignal): - kn = key if key.startswith('_sig_') else '_sig_' + key - write_signal(value, group.require_group(kn)) - elif isinstance(value, (np.ndarray, Dataset, da.Array)): - overwrite_dataset(group, value, key, **kwds) - elif value is None: - group.attrs[key] = '_None_' - elif isinstance(value, bytes): - try: - # binary string if has any null characters (otherwise not - # supported by hdf5) - value.index(b'\x00') - group.attrs['_bs_' + key] = np.void(value) - except ValueError: - group.attrs[key] = value.decode() - elif isinstance(value, str): - group.attrs[key] = value - elif isinstance(value, AxesManager): - dict2hdfgroup(value.as_dictionary(), - group.create_group('_hspy_AxesManager_' + key), - **kwds) - elif isinstance(value, list): - if len(value): - parse_structure(key, group, value, '_list_', **kwds) - else: - group.attrs['_list_empty_' + key] = '_None_' - elif isinstance(value, tuple): - if len(value): - parse_structure(key, group, value, '_tuple_', **kwds) + try: + group.attrs[key] = value + except BaseException: + _logger.exception( + "The hdf5 writer could not write the following " + "information in the file: %s : %s", key, value) + + + def get_signal_chunks(shape, dtype, signal_axes=None): + """Function that calculates chunks for the signal, preferably at least one + chunk per signal space. + + Parameters + ---------- + shape : tuple + the shape of the dataset to be sored / chunked + dtype : {dtype, string} + the numpy dtype of the data + signal_axes: {None, iterable of ints} + the axes defining "signal space" of the dataset. If None, the default + h5py chunking is performed. 
+ """ + typesize = np.dtype(dtype).itemsize + if signal_axes is None: + return h5py._hl.filters.guess_chunk(shape, None, typesize) + + # largely based on the guess_chunk in h5py + CHUNK_MAX = 1024 * 1024 + want_to_keep = multiply([shape[i] for i in signal_axes]) * typesize + if want_to_keep >= CHUNK_MAX: + chunks = [1 for _ in shape] + for i in signal_axes: + chunks[i] = shape[i] + return tuple(chunks) + + chunks = [i for i in shape] + idx = 0 + navigation_axes = tuple(i for i in range(len(shape)) if i not in + signal_axes) + nchange = len(navigation_axes) + while True: + chunk_bytes = multiply(chunks) * typesize + + if chunk_bytes < CHUNK_MAX: + break + + if multiply([chunks[i] for i in navigation_axes]) == 1: + break + change = navigation_axes[idx % nchange] + chunks[change] = np.ceil(chunks[change] / 2.0) + idx += 1 + return tuple(int(x) for x in chunks) + + def overwrite_dataset(self, + group, + data, + key, + signal_axes=None, + chunks=None, + **kwds): + if chunks is None: + if isinstance(data, da.Array): + # For lazy dataset, by default, we use the current dask chunking + chunks = tuple([c[0] for c in data.chunks]) else: - group.attrs['_tuple_empty_' + key] = '_None_' - - elif value is Undefined: - continue - else: - try: - group.attrs[key] = value - except BaseException: - _logger.exception( - "The hdf5 writer could not write the following " - "information in the file: %s : %s", key, value) - - -def get_signal_chunks(shape, dtype, signal_axes=None): - """Function that calculates chunks for the signal, preferably at least one - chunk per signal space. - - Parameters - ---------- - shape : tuple - the shape of the dataset to be sored / chunked - dtype : {dtype, string} - the numpy dtype of the data - signal_axes: {None, iterable of ints} - the axes defining "signal space" of the dataset. If None, the default - h5py chunking is performed. 
- """ - typesize = np.dtype(dtype).itemsize - if signal_axes is None: - return h5py._hl.filters.guess_chunk(shape, None, typesize) - - # largely based on the guess_chunk in h5py - CHUNK_MAX = 1024 * 1024 - want_to_keep = multiply([shape[i] for i in signal_axes]) * typesize - if want_to_keep >= CHUNK_MAX: - chunks = [1 for _ in shape] - for i in signal_axes: - chunks[i] = shape[i] - return tuple(chunks) - - chunks = [i for i in shape] - idx = 0 - navigation_axes = tuple(i for i in range(len(shape)) if i not in - signal_axes) - nchange = len(navigation_axes) - while True: - chunk_bytes = multiply(chunks) * typesize - - if chunk_bytes < CHUNK_MAX: - break - - if multiply([chunks[i] for i in navigation_axes]) == 1: - break - change = navigation_axes[idx % nchange] - chunks[change] = np.ceil(chunks[change] / 2.0) - idx += 1 - return tuple(int(x) for x in chunks) - - -def overwrite_dataset(group, - data, - key, - signal_axes=None, - chunks=None, - **kwds): - if chunks is None: - if isinstance(data, da.Array): - # For lazy dataset, by default, we use the current dask chunking - chunks = tuple([c[0] for c in data.chunks]) + # If signal_axes=None, use automatic h5py chunking, otherwise + # optimise the chunking to contain at least one signal per chunk + chunks = self.get_signal_chunks(data.shape, data.dtype, signal_axes) + if np.issubdtype(data.dtype, np.dtype('U')): + # Saving numpy unicode type is not supported in h5py + data = data.astype(np.dtype('S')) + if data.dtype == np.dtype('O'): + dset = self.get_object_dset(group, data, key, chunks, **kwds) else: - # If signal_axes=None, use automatic h5py chunking, otherwise - # optimise the chunking to contain at least one signal per chunk - chunks = get_signal_chunks(data.shape, data.dtype, signal_axes) - if np.issubdtype(data.dtype, np.dtype('U')): - # Saving numpy unicode type is not supported in h5py - data = data.astype(np.dtype('S')) - if data.dtype == np.dtype('O'): - dset = get_object_dset(group, data, key, chunks, **kwds) - else: - got_data=False - maxshape = tuple(None for _ in data.shape) - while not got_data: - try: - these_kwds = kwds.copy() - these_kwds.update(dict(shape=data.shape, - dtype=data.dtype, - exact=True, - maxshape=maxshape, - chunks=chunks, - shuffle=True, )) - - # If chunks is True, the `chunks` attribute of `dset` below - # contains the chunk shape guessed by h5py - dset = group.require_dataset(key, **these_kwds) - got_data = True - except TypeError: - # if the shape or dtype/etc do not match, - # we delete the old one and create new in the next loop run - del group[key] - if dset == data: - # just a reference to already created thing - pass - else: - _logger.info(f"Chunks used for saving: {dset.chunks}") - store_data(data, dset, group, key, chunks, **kwds) - - -def get_object_dset(group, data, key, chunks, **kwds): - # For saving ragged array - # http://docs.h5py.org/en/stable/special.html#arbitrary-vlen-data - dset = group.require_dataset(key, - chunks, - dtype=h5py.special_dtype(vlen=data[0].dtype), - **kwds) - return dset - -def store_data(data, dset, group, key, chunks, **kwds): - if isinstance(data, da.Array): - if data.chunks != dset.chunks: - data = data.rechunk(dset.chunks) - da.store(data, dset) - elif data.flags.c_contiguous: - dset.write_direct(data) - else: - dset[:] = data - - -def hdfgroup2dict(group, dictionary=None, lazy=False): - if dictionary is None: - dictionary = {} - for key, value in group.attrs.items(): - if isinstance(value, bytes): - value = value.decode() - if isinstance(value, (np.string_, 
str)): - if value == '_None_': - value = None - elif isinstance(value, np.bool_): - value = bool(value) - elif isinstance(value, np.ndarray) and value.dtype.char == "S": - # Convert strings to unicode - value = value.astype("U") - if value.dtype.str.endswith("U1"): - value = value.tolist() - # skip signals - these are handled below. - if key.startswith('_sig_'): + got_data=False + maxshape = tuple(None for _ in data.shape) + while not got_data: + try: + these_kwds = kwds.copy() + these_kwds.update(dict(shape=data.shape, + dtype=data.dtype, + exact=True, + maxshape=maxshape, + chunks=chunks, + shuffle=True, )) + + # If chunks is True, the `chunks` attribute of `dset` below + # contains the chunk shape guessed by h5py + dset = group.require_dataset(key, **these_kwds) + got_data = True + except TypeError: + # if the shape or dtype/etc do not match, + # we delete the old one and create new in the next loop run + del group[key] + if dset == data: + # just a reference to already created thing pass - elif key.startswith('_list_empty_'): - dictionary[key[len('_list_empty_'):]] = [] - elif key.startswith('_tuple_empty_'): - dictionary[key[len('_tuple_empty_'):]] = () - elif key.startswith('_bs_'): - dictionary[key[len('_bs_'):]] = value.tobytes() - # The following two elif stataments enable reading date and time from - # v < 2 of HyperSpy's metadata specifications - elif key.startswith('_datetime_date'): - date_iso = datetime.date( - *ast.literal_eval(value[value.index("("):])).isoformat() - dictionary[key.replace("_datetime_", "")] = date_iso - elif key.startswith('_datetime_time'): - date_iso = datetime.time( - *ast.literal_eval(value[value.index("("):])).isoformat() - dictionary[key.replace("_datetime_", "")] = date_iso else: - dictionary[key] = value - if not isinstance(group, Dataset): - for key in group.keys(): - if key.startswith('_sig_'): - from hyperspy.io import dict2signal - dictionary[key[len('_sig_'):]] = ( - dict2signal(hdfgroup2signaldict( - group[key], lazy=lazy))) - elif isinstance(group[key], Dataset): - dat = group[key] - kn = key - if key.startswith("_list_"): - if (h5py.check_string_dtype(dat.dtype) and - hasattr(dat, 'asstr')): - # h5py 3.0 and newer - # https://docs.h5py.org/en/3.0.0/strings.html - dat = dat.asstr()[:] - ans = np.array(dat) - ans = ans.tolist() - kn = key[6:] - elif key.startswith("_tuple_"): - ans = np.array(dat) - ans = tuple(ans.tolist()) - kn = key[7:] - elif dat.dtype.char == "S": - ans = np.array(dat) - try: - ans = ans.astype("U") - except UnicodeDecodeError: - # There are some strings that must stay in binary, - # for example dill pickles. This will obviously also - # let "wrong" binary string fail somewhere else... 
- pass - elif lazy: - ans = da.from_array(dat, chunks=dat.chunks) - else: - ans = np.array(dat) - dictionary[kn] = ans - elif key.startswith('_hspy_AxesManager_'): - dictionary[key[len('_hspy_AxesManager_'):]] = AxesManager( - [i for k, i in sorted(iter( - hdfgroup2dict( - group[key], lazy=lazy).items() - ))]) - elif key.startswith('_list_'): - dictionary[key[7 + key[6:].find('_'):]] = \ - [i for k, i in sorted(iter( - hdfgroup2dict( - group[key], lazy=lazy).items() - ))] - elif key.startswith('_tuple_'): - dictionary[key[8 + key[7:].find('_'):]] = tuple( - [i for k, i in sorted(iter( - hdfgroup2dict( - group[key], lazy=lazy).items() - ))]) - else: - dictionary[key] = {} - hdfgroup2dict( - group[key], - dictionary[key], - lazy=lazy) - - return dictionary + _logger.info(f"Chunks used for saving: {dset.chunks}") + self.store_data(data, dset, group, key, chunks, **kwds) + + def get_object_dset(self, group, data, key, chunks, **kwds): + # For saving ragged array + # http://docs.h5py.org/en/stable/special.html#arbitrary-vlen-data + dset = group.require_dataset(key, + chunks, + dtype=h5py.special_dtype(vlen=data[0].dtype), + **kwds) + return dset + + def store_data(self, data, dset, group, key, chunks, **kwds): + if isinstance(data, da.Array): + if data.chunks != dset.chunks: + data = data.rechunk(dset.chunks) + da.store(data, dset) + elif data.flags.c_contiguous: + dset.write_direct(data) + else: + dset[:] = data -def write_signal(signal, group, **kwds): - "Writes a hyperspy signal to a hdf5 group" + def write_signal(self, signal, group, **kwds): + "Writes a hyperspy signal to a hdf5 group" - group.attrs.update(get_object_package_info(signal)) - if default_version < LooseVersion("1.2"): - metadata = "mapped_parameters" - original_metadata = "original_parameters" - else: - metadata = "metadata" - original_metadata = "original_metadata" - - if 'compression' not in kwds: - kwds['compression'] = 'gzip' - - for axis in signal.axes_manager._axes: - axis_dict = axis.get_axis_dictionary() - coord_group = group.create_group( - 'axis-%s' % axis.index_in_array) - dict2hdfgroup(axis_dict, coord_group, **kwds) - mapped_par = group.create_group(metadata) - metadata_dict = signal.metadata.as_dictionary() - overwrite_dataset(group, signal.data, 'data', - signal_axes=signal.axes_manager.signal_indices_in_array, + group.attrs.update(get_object_package_info(signal)) + if default_version < LooseVersion("1.2"): + metadata = "mapped_parameters" + original_metadata = "original_parameters" + else: + metadata = "metadata" + original_metadata = "original_metadata" + + if 'compression' not in kwds: + kwds['compression'] = 'gzip' + + for axis in signal.axes_manager._axes: + axis_dict = axis.get_axis_dictionary() + coord_group = group.create_group( + 'axis-%s' % axis.index_in_array) + self.dict2hdfgroup(axis_dict, coord_group, **kwds) + mapped_par = group.create_group(metadata) + metadata_dict = signal.metadata.as_dictionary() + self.overwrite_dataset(group, signal.data, 'data', + signal_axes=signal.axes_manager.signal_indices_in_array, + **kwds) + if default_version < LooseVersion("1.2"): + metadata_dict["_internal_parameters"] = \ + metadata_dict.pop("_HyperSpy") + # Remove chunks from the kwds since it wouldn't have the same rank as the + # dataset and can't be used + kwds.pop('chunks', None) + self.dict2hdfgroup(metadata_dict, mapped_par, **kwds) + original_par = group.create_group(original_metadata) + self.dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, **kwds) - if default_version < 
LooseVersion("1.2"): - metadata_dict["_internal_parameters"] = \ - metadata_dict.pop("_HyperSpy") - # Remove chunks from the kwds since it wouldn't have the same rank as the - # dataset and can't be used - kwds.pop('chunks', None) - dict2hdfgroup(metadata_dict, mapped_par, **kwds) - original_par = group.create_group(original_metadata) - dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, - **kwds) - learning_results = group.create_group('learning_results') - dict2hdfgroup(signal.learning_results.__dict__, - learning_results, **kwds) - if hasattr(signal, 'peak_learning_results'): - peak_learning_results = group.create_group( - 'peak_learning_results') - dict2hdfgroup(signal.peak_learning_results.__dict__, - peak_learning_results, **kwds) - - if len(signal.models): - model_group = group.file.require_group('Analysis/models') - dict2hdfgroup(signal.models._models.as_dictionary(), - model_group, **kwds) - for model in model_group.values(): - model.attrs['_signal'] = group.name + learning_results = group.create_group('learning_results') + self.dict2hdfgroup(signal.learning_results.__dict__, + learning_results, **kwds) + if hasattr(signal, 'peak_learning_results'): + peak_learning_results = group.create_group( + 'peak_learning_results') + self.dict2hdfgroup(signal.peak_learning_results.__dict__, + peak_learning_results, **kwds) + + if len(signal.models): + model_group = group.file.require_group('Analysis/models') + self.dict2hdfgroup(signal.models._models.as_dictionary(), + model_group, **kwds) + for model in model_group.values(): + model.attrs['_signal'] = group.name + + def file_reader(self, + filename, + backing_store=False, + lazy=False, + **kwds): + """Read data from hdf5 files saved with the hyperspy hdf5 format specification + + Parameters + ---------- + filename: str + lazy: bool + Load image lazily using dask + **kwds, optional + """ + try: + # in case blosc compression is used + import hdf5plugin + except ImportError: + pass + mode = kwds.pop('mode', 'r') + f = File(filename, mode=mode, **kwds) + # Getting the format version here also checks if it is a valid HSpy + # hdf5 file, so the following two lines must not be deleted or moved + # elsewhere. + reader = HyperspyReader(f) + if reader.version > version: + warnings.warn( + "This file was written using a newer version of the " + "HyperSpy hdf5 file format. 
I will attempt to load it, but, " + "if I fail, it is likely that I will be more successful at " + "this and other tasks if you upgrade me.") + + models_with_signals = [] + standalone_models = [] + if 'Analysis/models' in f: + try: + m_gr = f['Analysis/models'] + for model_name in m_gr: + if '_signal' in m_gr[model_name].attrs: + key = m_gr[model_name].attrs['_signal'] + # del m_gr[model_name].attrs['_signal'] + res = self.hdfgroup2dict( + m_gr[model_name], + lazy=lazy) + del res['_signal'] + models_with_signals.append((key, {model_name: res})) + else: + standalone_models.append( + {model_name: hdfgroup2dict( + m_gr[model_name], lazy=lazy)}) + except TypeError: + raise IOError(not_valid_format) + + experiments = [] + exp_dict_list = [] + if 'Experiments' in f: + for ds in f['Experiments']: + if isinstance(f['Experiments'][ds], Group): + if 'data' in f['Experiments'][ds]: + experiments.append(ds) + # Parse the file + for experiment in experiments: + exg = f['Experiments'][experiment] + exp = self.hdfgroup2signaldict(exg, lazy) + # assign correct models, if found: + _tmp = {} + for (key, _dict) in reversed(models_with_signals): + if key == exg.name: + _tmp.update(_dict) + models_with_signals.remove((key, _dict)) + exp['models'] = _tmp + + exp_dict_list.append(exp) + + for _, m in models_with_signals: + standalone_models.append(m) + + exp_dict_list.extend(standalone_models) + if not len(exp_dict_list): + raise IOError('This is not a valid HyperSpy HDF5 file. ' + 'You can still load the data using a hdf5 reader, ' + 'e.g. h5py, and manually create a Signal. ' + 'Please, refer to the User Guide for details') + if not lazy: + f.close() + return exp_dict_list def file_writer(filename, signal, *args, **kwds): @@ -819,3 +794,35 @@ def file_writer(filename, signal, *args, **kwds): raise finally: del smd.record_by + + +def parse_structure(key, group, value, _type, **kwds): + try: + # Here we check if there are any signals in the container, as + # casting a long list of signals to a numpy array takes a very long + # time. 
So we check if there are any, and save numpy the trouble + if np.any([isinstance(t, BaseSignal) for t in value]): + tmp = np.array([[0]]) + else: + tmp = np.array(value) + except ValueError: + tmp = np.array([[0]]) + if tmp.dtype == np.dtype('O') or tmp.ndim != 1: + self.dict2hdfgroup(dict(zip( + [str(i) for i in range(len(value))], value)), + group.create_group(_type + str(len(value)) + '_' + key), + **kwds) + elif tmp.dtype.type is np.unicode_: + if _type + key in group: + del group[_type + key] + group.create_dataset(_type + key, + tmp.shape, + dtype=h5py.special_dtype(vlen=str), + **kwds) + group[_type + key][:] = tmp[:] + else: + if _type + key in group: + del group[_type + key] + group.create_dataset(_type + key, + data=tmp, + **kwds) \ No newline at end of file diff --git a/hyperspy/io_plugins/spy.py b/hyperspy/io_plugins/spy.py new file mode 100644 index 0000000000..e69de29bb2 From 37d5fe09e9e21c6bd8172e9e48c72094e8743062 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 3 Sep 2021 11:36:52 -0500 Subject: [PATCH 04/31] Hyperspy file loading as a class set of functions --- hyperspy/io_plugins/emd.py | 2 +- hyperspy/io_plugins/hspy.py | 364 +++++++++++++++++++---------------- hyperspy/io_plugins/nexus.py | 2 +- hyperspy/io_plugins/zspy.py | 2 +- 4 files changed, 196 insertions(+), 174 deletions(-) diff --git a/hyperspy/io_plugins/emd.py b/hyperspy/io_plugins/emd.py index aeac092490..8773d218b4 100644 --- a/hyperspy/io_plugins/emd.py +++ b/hyperspy/io_plugins/emd.py @@ -43,7 +43,7 @@ from hyperspy.exceptions import VisibleDeprecationWarning from hyperspy.misc.elements import atomic_number2name import hyperspy.misc.io.fei_stream_readers as stream_readers -from hyperspy.io_plugins.hspy import get_signal_chunks +#from hyperspy.io_plugins.hspy import get_signal_chunks # Plugin characteristics diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index 91630b3928..20d06cca2d 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -127,6 +127,61 @@ def get_hspy_format_version(self): raise IOError(not_valid_format) return LooseVersion(version) + def read(self, lazy): + models_with_signals = [] + standalone_models = [] + if 'Analysis/models' in self.file: + try: + m_gr = self.file['Analysis/models'] + for model_name in m_gr: + if '_signal' in m_gr[model_name].attrs: + key = m_gr[model_name].attrs['_signal'] + # del m_gr[model_name].attrs['_signal'] + res = self.hdfgroup2dict( + m_gr[model_name], + lazy=lazy) + del res['_signal'] + models_with_signals.append((key, {model_name: res})) + else: + standalone_models.append( + {model_name: self.hdfgroup2dict( + m_gr[model_name], lazy=lazy)}) + except TypeError: + raise IOError(not_valid_format) + experiments = [] + exp_dict_list = [] + if 'Experiments' in self.file: + for ds in self.file['Experiments']: + if isinstance(self.file['Experiments'][ds], Group): + if 'data' in self.file['Experiments'][ds]: + experiments.append(ds) + # Parse the file + for experiment in experiments: + exg = self.file['Experiments'][experiment] + exp = self.hdfgroup2signaldict(exg, lazy) + # assign correct models, if found: + _tmp = {} + for (key, _dict) in reversed(models_with_signals): + if key == exg.name: + _tmp.update(_dict) + models_with_signals.remove((key, _dict)) + exp['models'] = _tmp + + exp_dict_list.append(exp) + + for _, m in models_with_signals: + standalone_models.append(m) + + exp_dict_list.extend(standalone_models) + if not len(exp_dict_list): + raise IOError('This is not a valid HyperSpy HDF5 file. 
' + 'You can still load the data using a hdf5 reader, ' + 'e.g. h5py, and manually create a Signal. ' + 'Please, refer to the User Guide for details') + if not lazy: + self.file.close() + return exp_dict_list + def hdfgroup2signaldict(self, group, lazy=False): @@ -200,14 +255,14 @@ def hdfgroup2signaldict(self, if '__unnamed__' == exp['metadata']['General']['title']: exp['metadata']["General"]['title'] = '' - if current_file_version < LooseVersion("1.1"): + if self.version < LooseVersion("1.1"): # Load the decomposition results written with the old name, # mva_results if 'mva_results' in group.keys(): - exp['attributes']['learning_results'] = hdfgroup2dict( + exp['attributes']['learning_results'] = self.hdfgroup2dict( group['mva_results'], lazy=lazy) if 'peak_mva_results' in group.keys(): - exp['attributes']['peak_learning_results'] = hdfgroup2dict( + exp['attributes']['peak_learning_results'] = self.hdfgroup2dict( group['peak_mva_results'], lazy=lazy) # Replace the old signal and name keys with their current names if 'signal' in exp['metadata']: @@ -450,7 +505,70 @@ class HyperspyWriter: """An object used to simplify and orgainize the process for writing a hyperspy signal. (.hspy format) """ - def __init__(self, file): + def __init__(self, + file, + signal, + expg, + **kwds): + self.file = file + self.signal = signal + self.expg = expg + self.kwds = kwds + + def write(self): + self.write_signal(self.signal, + self.expg, + **self.kwds) + + def write_signal(self, signal, group, **kwds): + "Writes a hyperspy signal to a hdf5 group" + + group.attrs.update(get_object_package_info(signal)) + if default_version < LooseVersion("1.2"): + metadata = "mapped_parameters" + original_metadata = "original_parameters" + else: + metadata = "metadata" + original_metadata = "original_metadata" + + if 'compression' not in kwds: + kwds['compression'] = 'gzip' + + for axis in signal.axes_manager._axes: + axis_dict = axis.get_axis_dictionary() + coord_group = group.create_group( + 'axis-%s' % axis.index_in_array) + self.dict2hdfgroup(axis_dict, coord_group, **kwds) + mapped_par = group.create_group(metadata) + metadata_dict = signal.metadata.as_dictionary() + self.overwrite_dataset(group, signal.data, 'data', + signal_axes=signal.axes_manager.signal_indices_in_array, + **kwds) + if default_version < LooseVersion("1.2"): + metadata_dict["_internal_parameters"] = \ + metadata_dict.pop("_HyperSpy") + # Remove chunks from the kwds since it wouldn't have the same rank as the + # dataset and can't be used + kwds.pop('chunks', None) + self.dict2hdfgroup(metadata_dict, mapped_par, **kwds) + original_par = group.create_group(original_metadata) + self.dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, + **kwds) + learning_results = group.create_group('learning_results') + self.dict2hdfgroup(signal.learning_results.__dict__, + learning_results, **kwds) + if hasattr(signal, 'peak_learning_results'): + peak_learning_results = group.create_group( + 'peak_learning_results') + self.dict2hdfgroup(signal.peak_learning_results.__dict__, + peak_learning_results, **kwds) + + if len(signal.models): + model_group = group.file.require_group('Analysis/models') + self.dict2hdfgroup(signal.models._models.as_dictionary(), + model_group, **kwds) + for model in model_group.values(): + model.attrs['_signal'] = group.name def dict2hdfgroup(self, dictionary, group, **kwds): "Recursive writer of dicts and signals" @@ -470,7 +588,7 @@ def dict2hdfgroup(self, dictionary, group, **kwds): kn = key if key.startswith('_sig_') else 
'_sig_' + key self.write_signal(value, group.require_group(kn)) elif isinstance(value, (np.ndarray, Dataset, da.Array)): - overwrite_dataset(group, value, key, **kwds) + self.overwrite_dataset(group, value, key, **kwds) elif value is None: group.attrs[key] = '_None_' elif isinstance(value, bytes): @@ -484,17 +602,17 @@ def dict2hdfgroup(self, dictionary, group, **kwds): elif isinstance(value, str): group.attrs[key] = value elif isinstance(value, AxesManager): - dict2hdfgroup(value.as_dictionary(), + self.dict2hdfgroup(value.as_dictionary(), group.create_group('_hspy_AxesManager_' + key), **kwds) elif isinstance(value, list): if len(value): - parse_structure(key, group, value, '_list_', **kwds) + self.parse_structure(key, group, value, '_list_', **kwds) else: group.attrs['_list_empty_' + key] = '_None_' elif isinstance(value, tuple): if len(value): - parse_structure(key, group, value, '_tuple_', **kwds) + self.parse_structure(key, group, value, '_tuple_', **kwds) else: group.attrs['_tuple_empty_' + key] = '_None_' @@ -509,7 +627,10 @@ def dict2hdfgroup(self, dictionary, group, **kwds): "information in the file: %s : %s", key, value) - def get_signal_chunks(shape, dtype, signal_axes=None): + def get_signal_chunks(self, + shape, + dtype, + signal_axes=None): """Function that calculates chunks for the signal, preferably at least one chunk per signal space. @@ -621,142 +742,71 @@ def store_data(self, data, dset, group, key, chunks, **kwds): else: dset[:] = data - def write_signal(self, signal, group, **kwds): - "Writes a hyperspy signal to a hdf5 group" - - group.attrs.update(get_object_package_info(signal)) - if default_version < LooseVersion("1.2"): - metadata = "mapped_parameters" - original_metadata = "original_parameters" - else: - metadata = "metadata" - original_metadata = "original_metadata" - - if 'compression' not in kwds: - kwds['compression'] = 'gzip' - - for axis in signal.axes_manager._axes: - axis_dict = axis.get_axis_dictionary() - coord_group = group.create_group( - 'axis-%s' % axis.index_in_array) - self.dict2hdfgroup(axis_dict, coord_group, **kwds) - mapped_par = group.create_group(metadata) - metadata_dict = signal.metadata.as_dictionary() - self.overwrite_dataset(group, signal.data, 'data', - signal_axes=signal.axes_manager.signal_indices_in_array, - **kwds) - if default_version < LooseVersion("1.2"): - metadata_dict["_internal_parameters"] = \ - metadata_dict.pop("_HyperSpy") - # Remove chunks from the kwds since it wouldn't have the same rank as the - # dataset and can't be used - kwds.pop('chunks', None) - self.dict2hdfgroup(metadata_dict, mapped_par, **kwds) - original_par = group.create_group(original_metadata) - self.dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, - **kwds) - learning_results = group.create_group('learning_results') - self.dict2hdfgroup(signal.learning_results.__dict__, - learning_results, **kwds) - if hasattr(signal, 'peak_learning_results'): - peak_learning_results = group.create_group( - 'peak_learning_results') - self.dict2hdfgroup(signal.peak_learning_results.__dict__, - peak_learning_results, **kwds) - - if len(signal.models): - model_group = group.file.require_group('Analysis/models') - self.dict2hdfgroup(signal.models._models.as_dictionary(), - model_group, **kwds) - for model in model_group.values(): - model.attrs['_signal'] = group.name - - def file_reader(self, - filename, - backing_store=False, - lazy=False, - **kwds): - """Read data from hdf5 files saved with the hyperspy hdf5 format specification - - Parameters - 
---------- - filename: str - lazy: bool - Load image lazily using dask - **kwds, optional - """ + def parse_structure(self, key, group, value, _type, **kwds): + from hyperspy.signal import BaseSignal try: - # in case blosc compression is used - import hdf5plugin - except ImportError: - pass - mode = kwds.pop('mode', 'r') - f = File(filename, mode=mode, **kwds) - # Getting the format version here also checks if it is a valid HSpy - # hdf5 file, so the following two lines must not be deleted or moved - # elsewhere. - reader = HyperspyReader(f) - if reader.version > version: - warnings.warn( - "This file was written using a newer version of the " - "HyperSpy hdf5 file format. I will attempt to load it, but, " - "if I fail, it is likely that I will be more successful at " - "this and other tasks if you upgrade me.") - - models_with_signals = [] - standalone_models = [] - if 'Analysis/models' in f: - try: - m_gr = f['Analysis/models'] - for model_name in m_gr: - if '_signal' in m_gr[model_name].attrs: - key = m_gr[model_name].attrs['_signal'] - # del m_gr[model_name].attrs['_signal'] - res = self.hdfgroup2dict( - m_gr[model_name], - lazy=lazy) - del res['_signal'] - models_with_signals.append((key, {model_name: res})) - else: - standalone_models.append( - {model_name: hdfgroup2dict( - m_gr[model_name], lazy=lazy)}) - except TypeError: - raise IOError(not_valid_format) + # Here we check if there are any signals in the container, as + # casting a long list of signals to a numpy array takes a very long + # time. So we check if there are any, and save numpy the trouble + if np.any([isinstance(t, BaseSignal) for t in value]): + tmp = np.array([[0]]) + else: + tmp = np.array(value) + except ValueError: + tmp = np.array([[0]]) + if tmp.dtype == np.dtype('O') or tmp.ndim != 1: + self.dict2hdfgroup(dict(zip( + [str(i) for i in range(len(value))], value)), + group.create_group(_type + str(len(value)) + '_' + key), + **kwds) + elif tmp.dtype.type is np.unicode_: + if _type + key in group: + del group[_type + key] + group.create_dataset(_type + key, + tmp.shape, + dtype=h5py.special_dtype(vlen=str), + **kwds) + group[_type + key][:] = tmp[:] + else: + if _type + key in group: + del group[_type + key] + group.create_dataset(_type + key, + data=tmp, + **kwds) - experiments = [] - exp_dict_list = [] - if 'Experiments' in f: - for ds in f['Experiments']: - if isinstance(f['Experiments'][ds], Group): - if 'data' in f['Experiments'][ds]: - experiments.append(ds) - # Parse the file - for experiment in experiments: - exg = f['Experiments'][experiment] - exp = self.hdfgroup2signaldict(exg, lazy) - # assign correct models, if found: - _tmp = {} - for (key, _dict) in reversed(models_with_signals): - if key == exg.name: - _tmp.update(_dict) - models_with_signals.remove((key, _dict)) - exp['models'] = _tmp - exp_dict_list.append(exp) - for _, m in models_with_signals: - standalone_models.append(m) +def file_reader( + filename, + lazy=False, + **kwds): + """Read data from hdf5 files saved with the hyperspy hdf5 format specification - exp_dict_list.extend(standalone_models) - if not len(exp_dict_list): - raise IOError('This is not a valid HyperSpy HDF5 file. ' - 'You can still load the data using a hdf5 reader, ' - 'e.g. h5py, and manually create a Signal. 
' - 'Please, refer to the User Guide for details') - if not lazy: - f.close() - return exp_dict_list + Parameters + ---------- + filename: str + lazy: bool + Load image lazily using dask + **kwds, optional + """ + try: + # in case blosc compression is used + import hdf5plugin + except ImportError: + pass + mode = kwds.pop('mode', 'r') + f = File(filename, mode=mode, **kwds) + # Getting the format version here also checks if it is a valid HSpy + # hdf5 file, so the following two lines must not be deleted or moved + # elsewhere. + reader = HyperspyReader(f) + if reader.version > version: + warnings.warn( + "This file was written using a newer version of the " + "HyperSpy hdf5 file format. I will attempt to load it, but, " + "if I fail, it is likely that I will be more successful at " + "this and other tasks if you upgrade me.") + return reader.read(lazy=lazy) def file_writer(filename, signal, *args, **kwds): @@ -789,40 +839,12 @@ def file_writer(filename, signal, *args, **kwds): else: smd.record_by = "" try: - write_signal(signal, expg, **kwds) + writer = HyperspyWriter(f, signal, expg, **kwds) + writer.write() + #write_signal(signal, expg, **kwds) except BaseException: raise finally: del smd.record_by -def parse_structure(key, group, value, _type, **kwds): - try: - # Here we check if there are any signals in the container, as - # casting a long list of signals to a numpy array takes a very long - # time. So we check if there are any, and save numpy the trouble - if np.any([isinstance(t, BaseSignal) for t in value]): - tmp = np.array([[0]]) - else: - tmp = np.array(value) - except ValueError: - tmp = np.array([[0]]) - if tmp.dtype == np.dtype('O') or tmp.ndim != 1: - self.dict2hdfgroup(dict(zip( - [str(i) for i in range(len(value))], value)), - group.create_group(_type + str(len(value)) + '_' + key), - **kwds) - elif tmp.dtype.type is np.unicode_: - if _type + key in group: - del group[_type + key] - group.create_dataset(_type + key, - tmp.shape, - dtype=h5py.special_dtype(vlen=str), - **kwds) - group[_type + key][:] = tmp[:] - else: - if _type + key in group: - del group[_type + key] - group.create_dataset(_type + key, - data=tmp, - **kwds) \ No newline at end of file diff --git a/hyperspy/io_plugins/nexus.py b/hyperspy/io_plugins/nexus.py index c27514a084..e33e86cff9 100644 --- a/hyperspy/io_plugins/nexus.py +++ b/hyperspy/io_plugins/nexus.py @@ -25,7 +25,7 @@ import h5py import pprint import traits.api as t -from hyperspy.io_plugins.hspy import overwrite_dataset, get_signal_chunks +#from hyperspy.io_plugins.hspy import overwrite_dataset, get_signal_chunks from hyperspy.misc.utils import DictionaryTreeBrowser from hyperspy.exceptions import VisibleDeprecationWarning _logger = logging.getLogger(__name__) diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index 50b090d181..b30adc8132 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -30,7 +30,7 @@ from traits.api import Undefined from hyperspy.misc.utils import ensure_unicode, multiply, get_object_package_info from hyperspy.axes import AxesManager -from hyperspy.io_plugins.hspy import hdfgroup2signaldict, dict2hdfgroup, file_reader, write_signal, overwrite_dataset, get_signal_chunks +#from hyperspy.io_plugins.hspy import hdfgroup2signaldict, dict2hdfgroup, file_reader, write_signal, overwrite_dataset, get_signal_chunks import numcodecs from hyperspy.io_plugins.hspy import version From 8e27d65526b0aa44404bdaea4fe76c7c8bd15b96 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 3 Sep 2021 14:53:40 
-0500 Subject: [PATCH 05/31] Zspy classes for saving/loading data added. Still need some additional reduction --- hyperspy/io_plugins/hspy.py | 30 ++-- hyperspy/io_plugins/zspy.py | 276 +++++++++++++++++++++++------------- 2 files changed, 198 insertions(+), 108 deletions(-) diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index 20d06cca2d..c181e28a0c 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -105,10 +105,14 @@ current_file_version = None # Format version of the file being read default_version = LooseVersion(version) + class HyperspyReader: - def __init__(self, file): + def __init__(self, file, Group, Dataset): self.file = file self.version = self.get_hspy_format_version() + self.Dataset = Dataset + self.Group = Group + def get_hspy_format_version(self): if "file_format_version" in self.file.attrs: @@ -152,11 +156,15 @@ def read(self, lazy): exp_dict_list = [] if 'Experiments' in self.file: for ds in self.file['Experiments']: - if isinstance(self.file['Experiments'][ds], Group): + print(type(self.file['Experiments'][ds])) + print(self.Group) + if isinstance(self.file['Experiments'][ds], self.Group): if 'data' in self.file['Experiments'][ds]: experiments.append(ds) # Parse the file + print(experiments) for experiment in experiments: + exg = self.file['Experiments'][experiment] exp = self.hdfgroup2signaldict(exg, lazy) # assign correct models, if found: @@ -178,8 +186,6 @@ def read(self, lazy): 'You can still load the data using a hdf5 reader, ' 'e.g. h5py, and manually create a Signal. ' 'Please, refer to the User Guide for details') - if not lazy: - self.file.close() return exp_dict_list def hdfgroup2signaldict(self, @@ -437,14 +443,14 @@ def hdfgroup2dict(self, dictionary[key.replace("_datetime_", "")] = date_iso else: dictionary[key] = value - if not isinstance(group, Dataset): + if not isinstance(group, self.Dataset): for key in group.keys(): if key.startswith('_sig_'): from hyperspy.io import dict2signal dictionary[key[len('_sig_'):]] = ( dict2signal(self.hdfgroup2signaldict( group[key], lazy=lazy))) - elif isinstance(group[key], Dataset): + elif isinstance(group[key], self.Dataset): dat = group[key] kn = key if key.startswith("_list_"): @@ -514,6 +520,7 @@ def __init__(self, self.signal = signal self.expg = expg self.kwds = kwds + self.Dataset= Dataset def write(self): self.write_signal(self.signal, @@ -587,7 +594,7 @@ def dict2hdfgroup(self, dictionary, group, **kwds): elif isinstance(value, BaseSignal): kn = key if key.startswith('_sig_') else '_sig_' + key self.write_signal(value, group.require_group(kn)) - elif isinstance(value, (np.ndarray, Dataset, da.Array)): + elif isinstance(value, (np.ndarray, self.Dataset, da.Array)): self.overwrite_dataset(group, value, key, **kwds) elif value is None: group.attrs[key] = '_None_' @@ -720,7 +727,7 @@ def overwrite_dataset(self, # just a reference to already created thing pass else: - _logger.info(f"Chunks used for saving: {dset.chunks}") + _logger.info(f"Chunks used for saving: {chunks}") self.store_data(data, dset, group, key, chunks, **kwds) def get_object_dset(self, group, data, key, chunks, **kwds): @@ -799,14 +806,17 @@ def file_reader( # Getting the format version here also checks if it is a valid HSpy # hdf5 file, so the following two lines must not be deleted or moved # elsewhere. 
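# A rough sketch of what the extra Group/Dataset arguments buy — the same
# HyperspyReader can walk an h5py file or a zarr store, because its
# isinstance checks use whichever classes the caller hands in.  Usage
# sketch only, assuming both backends are installed and HyperspyReader is
# in scope (not the plugin code itself):
import h5py
import zarr

def open_hierarchical(path, lazy=False):
    # .zspy stores are zarr directories; everything else goes to h5py
    if path.endswith('.zspy'):
        store = zarr.open(path, mode='r')
        return HyperspyReader(store, zarr.Group, Dataset=zarr.Array).read(lazy)
    f = h5py.File(path, mode='r')
    return HyperspyReader(f, h5py.Group, Dataset=h5py.Dataset).read(lazy)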
- reader = HyperspyReader(f) + reader = HyperspyReader(f, Group, Dataset) if reader.version > version: warnings.warn( "This file was written using a newer version of the " "HyperSpy hdf5 file format. I will attempt to load it, but, " "if I fail, it is likely that I will be more successful at " "this and other tasks if you upgrade me.") - return reader.read(lazy=lazy) + exp_dict_list = reader.read(lazy=lazy) + if not lazy: + f.close() + return exp_dict_list def file_writer(filename, signal, *args, **kwds): diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index b30adc8132..d704b4da5e 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -23,14 +23,14 @@ import ast import zarr -from zarr import Array as Dataset +from zarr import Array, Group from zarr import open as File import numpy as np import dask.array as da from traits.api import Undefined from hyperspy.misc.utils import ensure_unicode, multiply, get_object_package_info from hyperspy.axes import AxesManager -#from hyperspy.io_plugins.hspy import hdfgroup2signaldict, dict2hdfgroup, file_reader, write_signal, overwrite_dataset, get_signal_chunks +from hyperspy.io_plugins.hspy import HyperspyReader, HyperspyWriter import numcodecs from hyperspy.io_plugins.hspy import version @@ -78,102 +78,158 @@ # the experiments and that will be accessible as attributes of the # Experiments instance -def get_object_dset(group, data, key, chunks, **kwds): - if data.dtype == np.dtype('O'): - # For saving ragged array - # https://zarr.readthedocs.io/en/stable/tutorial.html?highlight=ragged%20array#ragged-arrays - if chunks is None: - chunks == 1 - these_kwds = kwds.copy() - these_kwds.update(dict(dtype=object, - exact=True, - chunks=chunks)) - dset = group.require_dataset(key, - data.shape, - object_codec=numcodecs.VLenArray(int), - **these_kwds) - return data, dset - - -def store_data(data, dset, group, key, chunks, **kwds): - if isinstance(data, da.Array): - if data.chunks != dset.chunks: - data = data.rechunk(dset.chunks) - path = group._store.dir_path() + "/" + dset.path - data.to_zarr(url=path, - overwrite=True, - **kwds) # add in compression etc - elif data.dtype == np.dtype('O'): - group[key][:] = data[:] # check lazy - else: - path = group._store.dir_path() + "/" + dset.path - dset = zarr.open_array(path, - mode="w", - shape=data.shape, - dtype=data.dtype, - chunks=chunks, - **kwds) - dset[:] = data - -def get_signal_chunks(shape, dtype, signal_axes=None): - """Function that calculates chunks for the signal, - preferably at least one chunk per signal space. - Parameters - ---------- - shape : tuple - the shape of the dataset to be sored / chunked - dtype : {dtype, string} - the numpy dtype of the data - signal_axes: {None, iterable of ints} - the axes defining "signal space" of the dataset. If None, the default - zarr chunking is performed. 
- """ - typesize = np.dtype(dtype).itemsize - if signal_axes is None: - return None - # chunk size larger than 1 Mb https://zarr.readthedocs.io/en/stable/tutorial.html#chunk-optimizations - # shooting for 100 Mb chunks - total_size = np.prod(shape)*typesize - if total_size < 1e8: # 1 mb - return None - -def write_signal(signal, group, f=None, **kwds): - """Writes a hyperspy signal to a zarr group""" - - group.attrs.update(get_object_package_info(signal)) - metadata = "metadata" - original_metadata = "original_metadata" - - if 'compressor' not in kwds: - kwds['compressor'] = None - - for axis in signal.axes_manager._axes: - axis_dict = axis.get_axis_dictionary() - coord_group = group.create_group( - 'axis-%s' % axis.index_in_array) - dict2hdfgroup(axis_dict, coord_group, **kwds) - mapped_par = group.create_group(metadata) - metadata_dict = signal.metadata.as_dictionary() - overwrite_dataset(group, signal.data, 'data', - signal_axes=signal.axes_manager.signal_indices_in_array, + +class ZspyWriter(HyperspyWriter): + def __init__(self, + file, + signal, + expg, **kwargs): + super().__init__(file, signal, expg, **kwargs) + self.Dataset = Array + + def write_signal(self, + signal, + group, + **kwds): + """Overrides the hyperspy store data function for using zarr as the backend + """ + + group.attrs.update(get_object_package_info(signal)) + metadata = "metadata" + original_metadata = "original_metadata" + + if 'compressor' not in kwds: + kwds['compressor'] = None + + for axis in signal.axes_manager._axes: + axis_dict = axis.get_axis_dictionary() + coord_group = group.create_group( + 'axis-%s' % axis.index_in_array) + self.dict2hdfgroup(axis_dict, coord_group, **kwds) + mapped_par = group.create_group(metadata) + metadata_dict = signal.metadata.as_dictionary() + self.overwrite_dataset(group, signal.data, 'data', + signal_axes=signal.axes_manager.signal_indices_in_array, + **kwds) + # Remove chunks from the kwds since it wouldn't have the same rank as the + # dataset and can't be used + kwds.pop('chunks', None) + self.dict2hdfgroup(metadata_dict, mapped_par, **kwds) + original_par = group.create_group(original_metadata) + self.dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, **kwds) - # Remove chunks from the kwds since it wouldn't have the same rank as the - # dataset and can't be used - kwds.pop('chunks', None) - dict2hdfgroup(metadata_dict, mapped_par, **kwds) - original_par = group.create_group(original_metadata) - dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, - **kwds) - learning_results = group.create_group('learning_results') - dict2hdfgroup(signal.learning_results.__dict__, - learning_results, **kwds) - - if len(signal.models) and f is not None: - model_group = f.require_group('Analysis/models') - dict2hdfgroup(signal.models._models.as_dictionary(), - model_group, **kwds) - for model in model_group.values(): - model.attrs['_signal'] = group.name + learning_results = group.create_group('learning_results') + self.dict2hdfgroup(signal.learning_results.__dict__, + learning_results, **kwds) + + if len(signal.models): + model_group = self.file.require_group('Analysis/models') + self.dict2hdfgroup(signal.models._models.as_dictionary(), + model_group, **kwds) + for model in model_group.values(): + model.attrs['_signal'] = group.name + + def store_data(self, + data, + dset, + group, + key, + chunks, + **kwds): + """Overrides the hyperspy store data function for using zarr as the backend + """ + if isinstance(data, da.Array): + if data.chunks != 
dset.chunks: + data = data.rechunk(dset.chunks) + path = group._store.dir_path() + "/" + dset.path + data.to_zarr(url=path, + overwrite=True, + **kwds) # add in compression etc + elif data.dtype == np.dtype('O'): + group[key][:] = data[:] # check lazy + else: + path = group._store.dir_path() + "/" + dset.path + dset = zarr.open_array(path, + mode="w", + shape=data.shape, + dtype=data.dtype, + chunks=chunks, + **kwds) + dset[:] = data + + def get_object_dset(self, group, data, key, chunks, **kwds): + """Overrides the hyperspy get object dset function for using zarr as the backend + """ + if data.dtype == np.dtype('O'): + # For saving ragged array + # https://zarr.readthedocs.io/en/stable/tutorial.html?highlight=ragged%20array#ragged-arrays + if chunks is None: + chunks == 1 + these_kwds = kwds.copy() + these_kwds.update(dict(dtype=object, + exact=True, + chunks=chunks)) + dset = group.require_dataset(key, + data.shape, + object_codec=numcodecs.VLenArray(int), + **these_kwds) + return data, dset + + def get_signal_chunks(self, shape, dtype, signal_axes=None): + """Function that calculates chunks for the signal, + preferably at least one chunk per signal space. + Parameters + ---------- + shape : tuple + the shape of the dataset to be sored / chunked + dtype : {dtype, string} + the numpy dtype of the data + signal_axes: {None, iterable of ints} + the axes defining "signal space" of the dataset. If None, the default + zarr chunking is performed. + """ + typesize = np.dtype(dtype).itemsize + if signal_axes is None: + return None + # chunk size larger than 1 Mb https://zarr.readthedocs.io/en/stable/tutorial.html#chunk-optimizations + # shooting for 100 Mb chunks + total_size = np.prod(shape) * typesize + if total_size < 1e8: # 1 mb + return None + + def parse_structure(self, key, group, value, _type, **kwds): + from hyperspy.signal import BaseSignal + try: + # Here we check if there are any signals in the container, as + # casting a long list of signals to a numpy array takes a very long + # time. 
So we check if there are any, and save numpy the trouble + if np.any([isinstance(t, BaseSignal) for t in value]): + tmp = np.array([[0]]) + else: + tmp = np.array(value) + except ValueError: + tmp = np.array([[0]]) + if tmp.dtype == np.dtype('O') or tmp.ndim != 1: + self.dict2hdfgroup(dict(zip( + [str(i) for i in range(len(value))], value)), + group.create_group(_type + str(len(value)) + '_' + key), + **kwds) + elif tmp.dtype.type is np.unicode_: + if _type + key in group: + del group[_type + key] + group.create_dataset(_type + key, + data=tmp, + dtype=object, + object_codec=numcodecs.JSON(), + **kwds) + else: + if _type + key in group: + del group[_type + key] + group.create_dataset( + _type + key, + data=tmp, + **kwds) + def file_writer(filename, signal, *args, **kwds): """Writes data to hyperspy's zarr format @@ -208,8 +264,32 @@ def file_writer(filename, signal, *args, **kwds): else: smd.record_by = "" try: - write_signal(signal, expg, f, **kwds) + writer = ZspyWriter(f, signal, expg, **kwds) + writer.write() except BaseException: raise finally: del smd.record_by + + +def file_reader(filename, + lazy=False, + **kwds): + """Read data from zspy files saved with the hyperspy zspy format specification + Parameters + ---------- + filename: str + lazy: bool + Load image lazily using dask + **kwds, optional + """ + mode = kwds.pop('mode', 'r') + f = zarr.open(filename, mode=mode, **kwds) + reader = HyperspyReader(f, Group, Dataset=Array) + if reader.version > version: + warnings.warn( + "This file was written using a newer version of the " + "HyperSpy zspy file format. I will attempt to load it, but, " + "if I fail, it is likely that I will be more successful at " + "this and other tasks if you upgrade me.") + return reader.read(lazy=lazy) From af590eeffe2bd5a3e7f685b92730cbb102de04b2 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 3 Sep 2021 16:07:11 -0500 Subject: [PATCH 06/31] Tried to reduce code. 
Still could remove parse data function --- hyperspy/io_plugins/hspy.py | 15 ++++++---- hyperspy/io_plugins/zspy.py | 55 +++---------------------------------- 2 files changed, 13 insertions(+), 57 deletions(-) diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index c181e28a0c..cfc1a0b015 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -520,7 +520,8 @@ def __init__(self, self.signal = signal self.expg = expg self.kwds = kwds - self.Dataset= Dataset + self.Dataset = Dataset + self.unicode_kwds = {"dtype":h5py.special_dtype(vlen=str)} def write(self): self.write_signal(self.signal, @@ -538,8 +539,7 @@ def write_signal(self, signal, group, **kwds): metadata = "metadata" original_metadata = "original_metadata" - if 'compression' not in kwds: - kwds['compression'] = 'gzip' + for axis in signal.axes_manager._axes: axis_dict = axis.get_axis_dictionary() @@ -571,7 +571,7 @@ def write_signal(self, signal, group, **kwds): peak_learning_results, **kwds) if len(signal.models): - model_group = group.file.require_group('Analysis/models') + model_group = self.file.require_group('Analysis/models') self.dict2hdfgroup(signal.models._models.as_dictionary(), model_group, **kwds) for model in model_group.values(): @@ -769,9 +769,11 @@ def parse_structure(self, key, group, value, _type, **kwds): elif tmp.dtype.type is np.unicode_: if _type + key in group: del group[_type + key] + print(self.unicode_kwds) + print(kwds) group.create_dataset(_type + key, tmp.shape, - dtype=h5py.special_dtype(vlen=str), + dtype = self.unicode_kwds["dtype"], **kwds) group[_type + key][:] = tmp[:] else: @@ -782,7 +784,6 @@ def parse_structure(self, key, group, value, _type, **kwds): **kwds) - def file_reader( filename, lazy=False, @@ -829,6 +830,8 @@ def file_writer(filename, signal, *args, **kwds): *args, optional **kwds, optional """ + if 'compression' not in kwds: + kwds['compression'] = 'gzip' with h5py.File(filename, mode='w') as f: f.attrs['file_format'] = "HyperSpy" f.attrs['file_format_version'] = version diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index d704b4da5e..062f7565a6 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -16,20 +16,15 @@ # You should have received a copy of the GNU General Public License # along with HyperSpy. If not, see . 
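# Sketch of the backend-specific string handling this patch factors out
# (assuming h5py, zarr and numcodecs are importable): h5py writes
# variable-length strings through a special vlen dtype, while zarr needs
# an object dtype plus a codec, so each writer carries its own
# unicode_kwds for the string branch of parse_structure.
import h5py
import numcodecs

hspy_unicode_kwds = {"dtype": h5py.special_dtype(vlen=str)}
zspy_unicode_kwds = {"dtype": object, "object_codec": numcodecs.JSON()}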
-from distutils.version import LooseVersion import warnings import logging -import datetime -import ast + import zarr from zarr import Array, Group -from zarr import open as File import numpy as np import dask.array as da -from traits.api import Undefined -from hyperspy.misc.utils import ensure_unicode, multiply, get_object_package_info -from hyperspy.axes import AxesManager +from hyperspy.misc.utils import get_object_package_info from hyperspy.io_plugins.hspy import HyperspyReader, HyperspyWriter import numcodecs @@ -86,48 +81,7 @@ def __init__(self, expg, **kwargs): super().__init__(file, signal, expg, **kwargs) self.Dataset = Array - - def write_signal(self, - signal, - group, - **kwds): - """Overrides the hyperspy store data function for using zarr as the backend - """ - - group.attrs.update(get_object_package_info(signal)) - metadata = "metadata" - original_metadata = "original_metadata" - - if 'compressor' not in kwds: - kwds['compressor'] = None - - for axis in signal.axes_manager._axes: - axis_dict = axis.get_axis_dictionary() - coord_group = group.create_group( - 'axis-%s' % axis.index_in_array) - self.dict2hdfgroup(axis_dict, coord_group, **kwds) - mapped_par = group.create_group(metadata) - metadata_dict = signal.metadata.as_dictionary() - self.overwrite_dataset(group, signal.data, 'data', - signal_axes=signal.axes_manager.signal_indices_in_array, - **kwds) - # Remove chunks from the kwds since it wouldn't have the same rank as the - # dataset and can't be used - kwds.pop('chunks', None) - self.dict2hdfgroup(metadata_dict, mapped_par, **kwds) - original_par = group.create_group(original_metadata) - self.dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, - **kwds) - learning_results = group.create_group('learning_results') - self.dict2hdfgroup(signal.learning_results.__dict__, - learning_results, **kwds) - - if len(signal.models): - model_group = self.file.require_group('Analysis/models') - self.dict2hdfgroup(signal.models._models.as_dictionary(), - model_group, **kwds) - for model in model_group.values(): - model.attrs['_signal'] = group.name + self.unicode_kwds = {"dtype" : object, "object_codec" : numcodecs.JSON()} def store_data(self, data, @@ -220,8 +174,7 @@ def parse_structure(self, key, group, value, _type, **kwds): group.create_dataset(_type + key, data=tmp, dtype=object, - object_codec=numcodecs.JSON(), - **kwds) + object_codec=numcodecs.JSON(), **kwds) else: if _type + key in group: del group[_type + key] From 023630bf61352887e2dbc6af91c5c86de096f264 Mon Sep 17 00:00:00 2001 From: cssfrancis Date: Wed, 8 Sep 2021 15:40:52 -0500 Subject: [PATCH 07/31] Messed up the file reading a little bit. Need to Fix Nexus and EMD formats. 
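The common reading and writing machinery moves into the new
hyperspy/io_plugins/hierarchical.py module; hspy.py (h5py backend) and
zspy.py (zarr backend) shrink to thin subclasses that plug in their
backend-specific Dataset/Group types and keyword defaults. A rough sketch of
the intended class layout (illustrative only, simplified from the code below):

    import h5py
    import zarr

    class HierarchicalReader:
        """Backend-agnostic logic for reading .hspy/.zspy style files."""
        def __init__(self, file):
            self.file = file
            self.Dataset = None  # backend dataset type, set by subclasses
            self.Group = None    # backend group type, set by subclasses

    class HyperspyReader(HierarchicalReader):
        """.hspy flavour: HDF5 objects from h5py."""
        def __init__(self, file):
            super().__init__(file)
            self.Dataset, self.Group = h5py.Dataset, h5py.Group

    class ZspyReader(HierarchicalReader):
        """.zspy flavour: zarr arrays and groups."""
        def __init__(self, file):
            super().__init__(file)
            self.Dataset, self.Group = zarr.Array, zarr.Group

The writer side mirrors this split: ZspyWriter overrides store_data,
get_object_dset and get_signal_chunks and supplies unicode_kwds/ragged_kwds
built on numcodecs object codecs, while the h5py writer keeps the vlen
special-dtype behaviour.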
--- hyperspy/io_plugins/hierarchical.py | 749 ++++++++++++++++++++++++++++ hyperspy/io_plugins/hspy.py | 688 +------------------------ hyperspy/io_plugins/nexus.py | 60 +-- hyperspy/io_plugins/zspy.py | 78 +-- 4 files changed, 794 insertions(+), 781 deletions(-) create mode 100644 hyperspy/io_plugins/hierarchical.py diff --git a/hyperspy/io_plugins/hierarchical.py b/hyperspy/io_plugins/hierarchical.py new file mode 100644 index 0000000000..f92a28ac15 --- /dev/null +++ b/hyperspy/io_plugins/hierarchical.py @@ -0,0 +1,749 @@ +from distutils.version import LooseVersion +import warnings +import logging +import datetime +import ast + +import h5py + +import numpy as np +import dask.array as da +from traits.api import Undefined +from hyperspy.misc.utils import ensure_unicode, multiply, get_object_package_info +from hyperspy.axes import AxesManager + +version = "3.1" +default_version = LooseVersion(version) + +not_valid_format = 'The file is not a valid HyperSpy hdf5 file' + +_logger = logging.getLogger(__name__) + + +def _overwrite_dataset(group, + data, + key, + signal_axes=None, + chunks=None, + get_signal_chunks=None, + get_object_dset=None, + **kwds): + """Overwrites some dataset in a hirarchical dataset. + + Parameters + ---------- + group: Zarr.Group or h5py.Group + The group to write the data to + data: Array-like + The data to be written + key: str + The key for the data + signal_axes: tuple + The indexes of the signal axes + chunks: tuple + The chunks for the dataset. If None then get_signal_chunks will be called + get_signal_chunks: func + A function to get the signal chunks for the dataset + get_object_dset: + A function to get the object dset for saving ragged arrays. + kwds: + Any additional keywords + Returns + ------- + + """ + if chunks is None: + if isinstance(data, da.Array): + # For lazy dataset, by default, we use the current dask chunking + chunks = tuple([c[0] for c in data.chunks]) + else: + # If signal_axes=None, use automatic h5py chunking, otherwise + # optimise the chunking to contain at least one signal per chunk + chunks = get_signal_chunks(data.shape, data.dtype, signal_axes) + if np.issubdtype(data.dtype, np.dtype('U')): + # Saving numpy unicode type is not supported in h5py + data = data.astype(np.dtype('S')) + if data.dtype == np.dtype('O'): + dset = get_object_dset(group, data, key, chunks, **kwds) + + else: + got_data = False + maxshape = tuple(None for _ in data.shape) + while not got_data: + try: + these_kwds = kwds.copy() + these_kwds.update(dict(shape=data.shape, + dtype=data.dtype, + exact=True, + maxshape=maxshape, + chunks=chunks, + shuffle=True, )) + + # If chunks is True, the `chunks` attribute of `dset` below + # contains the chunk shape guessed by h5py + dset = group.require_dataset(key, **these_kwds) + got_data = True + except TypeError: + # if the shape or dtype/etc do not match, + # we delete the old one and create new in the next loop run + del group[key] + if dset == data: + # just a reference to already created thing + pass + else: + _logger.info(f"Chunks used for saving: {chunks}") + store_data(data, dset, group, key, chunks, **kwds) + + def get_signal_chunks(self, + shape, + dtype, + signal_axes=None): + """Function that calculates chunks for the signal, preferably at least one + chunk per signal space. + + Parameters + ---------- + shape : tuple + the shape of the dataset to be sored / chunked + dtype : {dtype, string} + the numpy dtype of the data + signal_axes: {None, iterable of ints} + the axes defining "signal space" of the dataset. 
If None, the default + h5py chunking is performed. + """ + typesize = np.dtype(dtype).itemsize + if signal_axes is None: + return h5py._hl.filters.guess_chunk(shape, None, typesize) + + # largely based on the guess_chunk in h5py + CHUNK_MAX = 1024 * 1024 + want_to_keep = multiply([shape[i] for i in signal_axes]) * typesize + if want_to_keep >= CHUNK_MAX: + chunks = [1 for _ in shape] + for i in signal_axes: + chunks[i] = shape[i] + return tuple(chunks) + + chunks = [i for i in shape] + idx = 0 + navigation_axes = tuple(i for i in range(len(shape)) if i not in + signal_axes) + nchange = len(navigation_axes) + while True: + chunk_bytes = multiply(chunks) * typesize + + if chunk_bytes < CHUNK_MAX: + break + + if multiply([chunks[i] for i in navigation_axes]) == 1: + break + change = navigation_axes[idx % nchange] + chunks[change] = np.ceil(chunks[change] / 2.0) + idx += 1 + return tuple(int(x) for x in chunks) + +def _get_object_dset( group, key, chunks, **kwds): + """Creates a Dataset or Zarr Array object for saving ragged data + Parameters + ---------- + self + group + data + key + chunks + kwds + + Returns + ------- + + """ + # For saving ragged array + if chunks is None: + chunks = 1 + dset = group.require_dataset(key, + **kwds) + return dset + +def _store_data(data, dset, group, key, chunks, **kwds): + if isinstance(data, da.Array): + if data.chunks != dset.chunks: + data = data.rechunk(dset.chunks) + da.store(data, dset) + elif data.flags.c_contiguous: + dset.write_direct(data) + else: + dset[:] = data + + +class HierarchicalReader: + """A generic Reader class for reading data from hierarchical file types.""" + def __init__(self, + file): + self.file = file + self.version = self.get_format_version() + self.Dataset = None + self.Group = None + self.unicode_kwds = None + self.ragged_kwds = None + self.store_data + self.get_object_dset + + + def get_format_version(self): + if "file_format_version" in self.file.attrs: + version = self.file.attrs["file_format_version"] + if isinstance(version, bytes): + version = version.decode() + if isinstance(version, float): + version = str(round(version, 2)) + elif "Experiments" in self.file: + # Chances are that this is a HSpy hdf5 file version 1.0 + version = "1.0" + elif "Analysis" in self.f: + # Starting version 2.0 we have "Analysis" field as well + version = "2.0" + else: + raise IOError(not_valid_format) + return LooseVersion(version) + + def read(self, lazy): + models_with_signals = [] + standalone_models = [] + if 'Analysis/models' in self.file: + try: + m_gr = self.file['Analysis/models'] + for model_name in m_gr: + if '_signal' in m_gr[model_name].attrs: + key = m_gr[model_name].attrs['_signal'] + # del m_gr[model_name].attrs['_signal'] + res = self.group2dict( + m_gr[model_name], + lazy=lazy) + del res['_signal'] + models_with_signals.append((key, {model_name: res})) + else: + standalone_models.append( + {model_name: self.group2dict( + m_gr[model_name], lazy=lazy)}) + except TypeError: + raise IOError(not_valid_format) + experiments = [] + exp_dict_list = [] + if 'Experiments' in self.file: + for ds in self.file['Experiments']: + print(type(self.file['Experiments'][ds])) + print(self.Group) + if isinstance(self.file['Experiments'][ds], self.Group): + if 'data' in self.file['Experiments'][ds]: + experiments.append(ds) + # Parse the file + print(experiments) + for experiment in experiments: + + exg = self.file['Experiments'][experiment] + exp = self.group2signaldict(exg, lazy) + # assign correct models, if found: + _tmp = {} + for (key, _dict) 
in reversed(models_with_signals): + if key == exg.name: + _tmp.update(_dict) + models_with_signals.remove((key, _dict)) + exp['models'] = _tmp + + exp_dict_list.append(exp) + + for _, m in models_with_signals: + standalone_models.append(m) + + exp_dict_list.extend(standalone_models) + if not len(exp_dict_list): + raise IOError('This is not a valid HyperSpy HDF5 file. ' + 'You can still load the data using a hdf5 reader, ' + 'e.g. h5py, and manually create a Signal. ' + 'Please, refer to the User Guide for details') + return exp_dict_list + + def group2signaldict(self, + group, + lazy=False): + if self.version < LooseVersion("1.2"): + metadata = "mapped_parameters" + original_metadata = "original_parameters" + else: + metadata = "metadata" + original_metadata = "original_metadata" + + exp = {'metadata': self.group2dict( + group[metadata], lazy=lazy), + 'original_metadata': self.group2dict( + group[original_metadata], lazy=lazy), + 'attributes': {} + } + if "package" in group.attrs: + # HyperSpy version is >= 1.5 + exp["package"] = group.attrs["package"] + exp["package_version"] = group.attrs["package_version"] + else: + # Prior to v1.4 we didn't store the package information. Since there + # were already external package we cannot assume any package provider so + # we leave this empty. + exp["package"] = "" + exp["package_version"] = "" + + data = group['data'] + if lazy: + data = da.from_array(data, chunks=data.chunks) + exp['attributes']['_lazy'] = True + else: + data = np.asanyarray(data) + exp['data'] = data + axes = [] + for i in range(len(exp['data'].shape)): + try: + axes.append(self.group2dict(group['axis-%i' % i])) + axis = axes[-1] + for key, item in axis.items(): + if isinstance(item, np.bool_): + axis[key] = bool(item) + else: + axis[key] = ensure_unicode(item) + except KeyError: + break + if len(axes) != len(exp['data'].shape): # broke from the previous loop + try: + axes = [i for k, i in sorted(iter(self.group2dict( + group['_list_' + str(len(exp['data'].shape)) + '_axes'], + lazy=lazy).items()))] + except KeyError: + raise IOError(not_valid_format) + exp['axes'] = axes + if 'learning_results' in group.keys(): + exp['attributes']['learning_results'] = \ + self.group2dict( + group['learning_results'], + lazy=lazy) + if 'peak_learning_results' in group.keys(): + exp['attributes']['peak_learning_results'] = \ + self.group2dict( + group['peak_learning_results'], + lazy=lazy) + + # If the title was not defined on writing the Experiment is + # then called __unnamed__. 
The next "if" simply sets the title + # back to the empty string + if "General" in exp["metadata"] and "title" in exp["metadata"]["General"]: + if '__unnamed__' == exp['metadata']['General']['title']: + exp['metadata']["General"]['title'] = '' + + if self.version < LooseVersion("1.1"): + # Load the decomposition results written with the old name, + # mva_results + if 'mva_results' in group.keys(): + exp['attributes']['learning_results'] = self.group2dict( + group['mva_results'], lazy=lazy) + if 'peak_mva_results' in group.keys(): + exp['attributes']['peak_learning_results'] = self.group2dict( + group['peak_mva_results'], lazy=lazy) + # Replace the old signal and name keys with their current names + if 'signal' in exp['metadata']: + if "Signal" not in exp["metadata"]: + exp["metadata"]["Signal"] = {} + exp['metadata']["Signal"]['signal_type'] = \ + exp['metadata']['signal'] + del exp['metadata']['signal'] + + if 'name' in exp['metadata']: + if "General" not in exp["metadata"]: + exp["metadata"]["General"] = {} + exp['metadata']['General']['title'] = \ + exp['metadata']['name'] + del exp['metadata']['name'] + + if self.version < LooseVersion("1.2"): + if '_internal_parameters' in exp['metadata']: + exp['metadata']['_HyperSpy'] = \ + exp['metadata']['_internal_parameters'] + del exp['metadata']['_internal_parameters'] + if 'stacking_history' in exp['metadata']['_HyperSpy']: + exp['metadata']['_HyperSpy']["Stacking_history"] = \ + exp['metadata']['_HyperSpy']['stacking_history'] + del exp['metadata']['_HyperSpy']["stacking_history"] + if 'folding' in exp['metadata']['_HyperSpy']: + exp['metadata']['_HyperSpy']["Folding"] = \ + exp['metadata']['_HyperSpy']['folding'] + del exp['metadata']['_HyperSpy']["folding"] + if 'Variance_estimation' in exp['metadata']: + if "Noise_properties" not in exp["metadata"]: + exp["metadata"]["Noise_properties"] = {} + exp['metadata']['Noise_properties']["Variance_linear_model"] = \ + exp['metadata']['Variance_estimation'] + del exp['metadata']['Variance_estimation'] + if "TEM" in exp["metadata"]: + if "Acquisition_instrument" not in exp["metadata"]: + exp["metadata"]["Acquisition_instrument"] = {} + exp["metadata"]["Acquisition_instrument"]["TEM"] = \ + exp["metadata"]["TEM"] + del exp["metadata"]["TEM"] + tem = exp["metadata"]["Acquisition_instrument"]["TEM"] + if "EELS" in tem: + if "dwell_time" in tem: + tem["EELS"]["dwell_time"] = tem["dwell_time"] + del tem["dwell_time"] + if "dwell_time_units" in tem: + tem["EELS"]["dwell_time_units"] = tem["dwell_time_units"] + del tem["dwell_time_units"] + if "exposure" in tem: + tem["EELS"]["exposure"] = tem["exposure"] + del tem["exposure"] + if "exposure_units" in tem: + tem["EELS"]["exposure_units"] = tem["exposure_units"] + del tem["exposure_units"] + if "Detector" not in tem: + tem["Detector"] = {} + tem["Detector"] = tem["EELS"] + del tem["EELS"] + if "EDS" in tem: + if "Detector" not in tem: + tem["Detector"] = {} + if "EDS" not in tem["Detector"]: + tem["Detector"]["EDS"] = {} + tem["Detector"]["EDS"] = tem["EDS"] + del tem["EDS"] + del tem + if "SEM" in exp["metadata"]: + if "Acquisition_instrument" not in exp["metadata"]: + exp["metadata"]["Acquisition_instrument"] = {} + exp["metadata"]["Acquisition_instrument"]["SEM"] = \ + exp["metadata"]["SEM"] + del exp["metadata"]["SEM"] + sem = exp["metadata"]["Acquisition_instrument"]["SEM"] + if "EDS" in sem: + if "Detector" not in sem: + sem["Detector"] = {} + if "EDS" not in sem["Detector"]: + sem["Detector"]["EDS"] = {} + sem["Detector"]["EDS"] = sem["EDS"] + del 
sem["EDS"] + del sem + + if "Sample" in exp["metadata"] and "Xray_lines" in exp[ + "metadata"]["Sample"]: + exp["metadata"]["Sample"]["xray_lines"] = exp[ + "metadata"]["Sample"]["Xray_lines"] + del exp["metadata"]["Sample"]["Xray_lines"] + + for key in ["title", "date", "time", "original_filename"]: + if key in exp["metadata"]: + if "General" not in exp["metadata"]: + exp["metadata"]["General"] = {} + exp["metadata"]["General"][key] = exp["metadata"][key] + del exp["metadata"][key] + for key in ["record_by", "signal_origin", "signal_type"]: + if key in exp["metadata"]: + if "Signal" not in exp["metadata"]: + exp["metadata"]["Signal"] = {} + exp["metadata"]["Signal"][key] = exp["metadata"][key] + del exp["metadata"][key] + + if self.version < LooseVersion("3.0"): + if "Acquisition_instrument" in exp["metadata"]: + # Move tilt_stage to Stage.tilt_alpha + # Move exposure time to Detector.Camera.exposure_time + if "TEM" in exp["metadata"]["Acquisition_instrument"]: + tem = exp["metadata"]["Acquisition_instrument"]["TEM"] + exposure = None + if "tilt_stage" in tem: + tem["Stage"] = {"tilt_alpha": tem["tilt_stage"]} + del tem["tilt_stage"] + if "exposure" in tem: + exposure = "exposure" + # Digital_micrograph plugin was parsing to 'exposure_time' + # instead of 'exposure': need this to be compatible with + # previous behaviour + if "exposure_time" in tem: + exposure = "exposure_time" + if exposure is not None: + if "Detector" not in tem: + tem["Detector"] = {"Camera": { + "exposure": tem[exposure]}} + tem["Detector"]["Camera"] = {"exposure": tem[exposure]} + del tem[exposure] + # Move tilt_stage to Stage.tilt_alpha + if "SEM" in exp["metadata"]["Acquisition_instrument"]: + sem = exp["metadata"]["Acquisition_instrument"]["SEM"] + if "tilt_stage" in sem: + sem["Stage"] = {"tilt_alpha": sem["tilt_stage"]} + del sem["tilt_stage"] + + return exp + + def group2dict(self, + group, + dictionary=None, + lazy=False): + if dictionary is None: + dictionary = {} + for key, value in group.attrs.items(): + if isinstance(value, bytes): + value = value.decode() + if isinstance(value, (np.string_, str)): + if value == '_None_': + value = None + elif isinstance(value, np.bool_): + value = bool(value) + elif isinstance(value, np.ndarray) and value.dtype.char == "S": + # Convert strings to unicode + value = value.astype("U") + if value.dtype.str.endswith("U1"): + value = value.tolist() + # skip signals - these are handled below. 
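+            # Other reserved key prefixes restore Python types that plain
+            # HDF5/zarr attributes cannot hold: '_list_empty_' and
+            # '_tuple_empty_' mark empty containers, '_bs_' marks binary
+            # strings, and '_datetime_date'/'_datetime_time' carry dates and
+            # times written by metadata specifications older than v2.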
+ if key.startswith('_sig_'): + pass + elif key.startswith('_list_empty_'): + dictionary[key[len('_list_empty_'):]] = [] + elif key.startswith('_tuple_empty_'): + dictionary[key[len('_tuple_empty_'):]] = () + elif key.startswith('_bs_'): + dictionary[key[len('_bs_'):]] = value.tobytes() + # The following two elif stataments enable reading date and time from + # v < 2 of HyperSpy's metadata specifications + elif key.startswith('_datetime_date'): + date_iso = datetime.date( + *ast.literal_eval(value[value.index("("):])).isoformat() + dictionary[key.replace("_datetime_", "")] = date_iso + elif key.startswith('_datetime_time'): + date_iso = datetime.time( + *ast.literal_eval(value[value.index("("):])).isoformat() + dictionary[key.replace("_datetime_", "")] = date_iso + else: + dictionary[key] = value + if not isinstance(group, self.Dataset): + for key in group.keys(): + if key.startswith('_sig_'): + from hyperspy.io import dict2signal + dictionary[key[len('_sig_'):]] = ( + dict2signal(self.group2signaldict( + group[key], lazy=lazy))) + elif isinstance(group[key], self.Dataset): + dat = group[key] + kn = key + if key.startswith("_list_"): + if (h5py.check_string_dtype(dat.dtype) and + hasattr(dat, 'asstr')): + # h5py 3.0 and newer + # https://docs.h5py.org/en/3.0.0/strings.html + dat = dat.asstr()[:] + ans = np.array(dat) + ans = ans.tolist() + kn = key[6:] + elif key.startswith("_tuple_"): + ans = np.array(dat) + ans = tuple(ans.tolist()) + kn = key[7:] + elif dat.dtype.char == "S": + ans = np.array(dat) + try: + ans = ans.astype("U") + except UnicodeDecodeError: + # There are some strings that must stay in binary, + # for example dill pickles. This will obviously also + # let "wrong" binary string fail somewhere else... + pass + elif lazy: + ans = da.from_array(dat, chunks=dat.chunks) + else: + ans = np.array(dat) + dictionary[kn] = ans + elif key.startswith('_hspy_AxesManager_'): + dictionary[key[len('_hspy_AxesManager_'):]] = AxesManager( + [i for k, i in sorted(iter( + self.group2dict( + group[key], lazy=lazy).items() + ))]) + elif key.startswith('_list_'): + dictionary[key[7 + key[6:].find('_'):]] = \ + [i for k, i in sorted(iter( + self.group2dict( + group[key], lazy=lazy).items() + ))] + elif key.startswith('_tuple_'): + dictionary[key[8 + key[7:].find('_'):]] = tuple( + [i for k, i in sorted(iter( + self.group2dict( + group[key], lazy=lazy).items() + ))]) + else: + dictionary[key] = {} + self.group2dict( + group[key], + dictionary[key], + lazy=lazy) + + return dictionary + +class HierarchicalWriter: + """An object used to simplify and orgainize the process for + writing a Hierachical signal. 
(.hspy format) + """ + def __init__(self, + file, + signal, + expg, + **kwds): + self.file = file + self.signal = signal + self.expg = expg + self.Dataset = None + self.Group = None + self.unicode_kwds = None + self.ragged_kwds = None + self.kwds = kwds + + def write(self): + self.write_signal(self.signal, + self.expg, + **self.kwds) + + def write_signal(self, signal, group, **kwds): + "Writes a hyperspy signal to a hdf5 group" + + group.attrs.update(get_object_package_info(signal)) + if default_version < LooseVersion("1.2"): + metadata = "mapped_parameters" + original_metadata = "original_parameters" + else: + metadata = "metadata" + original_metadata = "original_metadata" + + + + for axis in signal.axes_manager._axes: + axis_dict = axis.get_axis_dictionary() + coord_group = group.create_group( + 'axis-%s' % axis.index_in_array) + self.dict2hdfgroup(axis_dict, coord_group, **kwds) + mapped_par = group.create_group(metadata) + metadata_dict = signal.metadata.as_dictionary() + self.overwrite_dataset(group, signal.data, 'data', + signal_axes=signal.axes_manager.signal_indices_in_array, + **kwds) + if default_version < LooseVersion("1.2"): + metadata_dict["_internal_parameters"] = \ + metadata_dict.pop("_HyperSpy") + # Remove chunks from the kwds since it wouldn't have the same rank as the + # dataset and can't be used + kwds.pop('chunks', None) + self.dict2hdfgroup(metadata_dict, mapped_par, **kwds) + original_par = group.create_group(original_metadata) + self.dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, + **kwds) + learning_results = group.create_group('learning_results') + self.dict2hdfgroup(signal.learning_results.__dict__, + learning_results, **kwds) + if hasattr(signal, 'peak_learning_results'): + peak_learning_results = group.create_group( + 'peak_learning_results') + self.dict2hdfgroup(signal.peak_learning_results.__dict__, + peak_learning_results, **kwds) + + if len(signal.models): + model_group = self.file.require_group('Analysis/models') + self.dict2group(signal.models._models.as_dictionary(), + model_group, **kwds) + for model in model_group.values(): + model.attrs['_signal'] = group.name + + def dict2group(self, dictionary, group, **kwds): + "Recursive writer of dicts and signals" + + from hyperspy.misc.utils import DictionaryTreeBrowser + from hyperspy.signal import BaseSignal + + for key, value in dictionary.items(): + if isinstance(value, dict): + self.dict2hdfgroup(value, group.create_group(key), + **kwds) + elif isinstance(value, DictionaryTreeBrowser): + self.dict2hdfgroup(value.as_dictionary(), + group.create_group(key), + **kwds) + elif isinstance(value, BaseSignal): + kn = key if key.startswith('_sig_') else '_sig_' + key + self.write_signal(value, group.require_group(kn)) + elif isinstance(value, (np.ndarray, self.Dataset, da.Array)): + self.overwrite_dataset(group, value, key, **kwds) + elif value is None: + group.attrs[key] = '_None_' + elif isinstance(value, bytes): + try: + # binary string if has any null characters (otherwise not + # supported by hdf5) + value.index(b'\x00') + group.attrs['_bs_' + key] = np.void(value) + except ValueError: + group.attrs[key] = value.decode() + elif isinstance(value, str): + group.attrs[key] = value + elif isinstance(value, AxesManager): + self.dict2hdfgroup(value.as_dictionary(), + group.create_group('_hspy_AxesManager_' + key), + **kwds) + elif isinstance(value, list): + if len(value): + self.parse_structure(key, group, value, '_list_', **kwds) + else: + group.attrs['_list_empty_' + key] = '_None_' + elif 
isinstance(value, tuple): + if len(value): + self.parse_structure(key, group, value, '_tuple_', **kwds) + else: + group.attrs['_tuple_empty_' + key] = '_None_' + + elif value is Undefined: + continue + else: + try: + group.attrs[key] = value + except BaseException: + _logger.exception( + "The hdf5 writer could not write the following " + "information in the file: %s : %s", key, value) + + + + + def parse_structure(self, key, group, value, _type, **kwds): + from hyperspy.signal import BaseSignal + try: + # Here we check if there are any signals in the container, as + # casting a long list of signals to a numpy array takes a very long + # time. So we check if there are any, and save numpy the trouble + if np.any([isinstance(t, BaseSignal) for t in value]): + tmp = np.array([[0]]) + else: + tmp = np.array(value) + except ValueError: + tmp = np.array([[0]]) + if tmp.dtype == np.dtype('O') or tmp.ndim != 1: + self.dict2hdfgroup(dict(zip( + [str(i) for i in range(len(value))], value)), + group.create_group(_type + str(len(value)) + '_' + key), + **kwds) + elif tmp.dtype.type is np.unicode_: + if _type + key in group: + del group[_type + key] + print(self.unicode_kwds) + print(kwds) + group.create_dataset(_type + key, + shape=tmp.shape, + **self.unicode_kwds, + **kwds) + group[_type + key][:] = tmp[:] + else: + if _type + key in group: + del group[_type + key] + group.create_dataset(_type + key, + data=tmp, + **kwds) + + diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index cfc1a0b015..e9e1988197 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -19,18 +19,10 @@ from distutils.version import LooseVersion import warnings import logging -import datetime -import ast - import h5py -import numpy as np -import dask.array as da -from traits.api import Undefined -from hyperspy.misc.utils import ensure_unicode, multiply, get_object_package_info -from hyperspy.axes import AxesManager - from h5py import Dataset, File, Group +from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader _logger = logging.getLogger(__name__) @@ -106,408 +98,15 @@ default_version = LooseVersion(version) -class HyperspyReader: - def __init__(self, file, Group, Dataset): - self.file = file - self.version = self.get_hspy_format_version() +class HyperspyReader(HierarchicalReader): + def __init__(self, file): + super().__init__(file) self.Dataset = Dataset self.Group = Group + self.unicode_kwds = {"dtype":h5py.special_dtype(vlen=str)} - def get_hspy_format_version(self): - if "file_format_version" in self.file.attrs: - version = self.file.attrs["file_format_version"] - if isinstance(version, bytes): - version = version.decode() - if isinstance(version, float): - version = str(round(version, 2)) - elif "Experiments" in self.file: - # Chances are that this is a HSpy hdf5 file version 1.0 - version = "1.0" - elif "Analysis" in self.f: - # Starting version 2.0 we have "Analysis" field as well - version = "2.0" - else: - raise IOError(not_valid_format) - return LooseVersion(version) - - def read(self, lazy): - models_with_signals = [] - standalone_models = [] - if 'Analysis/models' in self.file: - try: - m_gr = self.file['Analysis/models'] - for model_name in m_gr: - if '_signal' in m_gr[model_name].attrs: - key = m_gr[model_name].attrs['_signal'] - # del m_gr[model_name].attrs['_signal'] - res = self.hdfgroup2dict( - m_gr[model_name], - lazy=lazy) - del res['_signal'] - models_with_signals.append((key, {model_name: res})) - else: - standalone_models.append( - 
{model_name: self.hdfgroup2dict( - m_gr[model_name], lazy=lazy)}) - except TypeError: - raise IOError(not_valid_format) - experiments = [] - exp_dict_list = [] - if 'Experiments' in self.file: - for ds in self.file['Experiments']: - print(type(self.file['Experiments'][ds])) - print(self.Group) - if isinstance(self.file['Experiments'][ds], self.Group): - if 'data' in self.file['Experiments'][ds]: - experiments.append(ds) - # Parse the file - print(experiments) - for experiment in experiments: - - exg = self.file['Experiments'][experiment] - exp = self.hdfgroup2signaldict(exg, lazy) - # assign correct models, if found: - _tmp = {} - for (key, _dict) in reversed(models_with_signals): - if key == exg.name: - _tmp.update(_dict) - models_with_signals.remove((key, _dict)) - exp['models'] = _tmp - - exp_dict_list.append(exp) - - for _, m in models_with_signals: - standalone_models.append(m) - - exp_dict_list.extend(standalone_models) - if not len(exp_dict_list): - raise IOError('This is not a valid HyperSpy HDF5 file. ' - 'You can still load the data using a hdf5 reader, ' - 'e.g. h5py, and manually create a Signal. ' - 'Please, refer to the User Guide for details') - return exp_dict_list - - def hdfgroup2signaldict(self, - group, - lazy=False): - global default_version - if self.version < LooseVersion("1.2"): - metadata = "mapped_parameters" - original_metadata = "original_parameters" - else: - metadata = "metadata" - original_metadata = "original_metadata" - - exp = {'metadata': self.hdfgroup2dict( - group[metadata], lazy=lazy), - 'original_metadata': self.hdfgroup2dict( - group[original_metadata], lazy=lazy), - 'attributes': {} - } - if "package" in group.attrs: - # HyperSpy version is >= 1.5 - exp["package"] = group.attrs["package"] - exp["package_version"] = group.attrs["package_version"] - else: - # Prior to v1.4 we didn't store the package information. Since there - # were already external package we cannot assume any package provider so - # we leave this empty. - exp["package"] = "" - exp["package_version"] = "" - - data = group['data'] - if lazy: - data = da.from_array(data, chunks=data.chunks) - exp['attributes']['_lazy'] = True - else: - data = np.asanyarray(data) - exp['data'] = data - axes = [] - for i in range(len(exp['data'].shape)): - try: - axes.append(self.hdfgroup2dict(group['axis-%i' % i])) - axis = axes[-1] - for key, item in axis.items(): - if isinstance(item, np.bool_): - axis[key] = bool(item) - else: - axis[key] = ensure_unicode(item) - except KeyError: - break - if len(axes) != len(exp['data'].shape): # broke from the previous loop - try: - axes = [i for k, i in sorted(iter(self.hdfgroup2dict( - group['_list_' + str(len(exp['data'].shape)) + '_axes'], - lazy=lazy).items()))] - except KeyError: - raise IOError(not_valid_format) - exp['axes'] = axes - if 'learning_results' in group.keys(): - exp['attributes']['learning_results'] = \ - self.hdfgroup2dict( - group['learning_results'], - lazy=lazy) - if 'peak_learning_results' in group.keys(): - exp['attributes']['peak_learning_results'] = \ - self.hdfgroup2dict( - group['peak_learning_results'], - lazy=lazy) - - # If the title was not defined on writing the Experiment is - # then called __unnamed__. 
The next "if" simply sets the title - # back to the empty string - if "General" in exp["metadata"] and "title" in exp["metadata"]["General"]: - if '__unnamed__' == exp['metadata']['General']['title']: - exp['metadata']["General"]['title'] = '' - - if self.version < LooseVersion("1.1"): - # Load the decomposition results written with the old name, - # mva_results - if 'mva_results' in group.keys(): - exp['attributes']['learning_results'] = self.hdfgroup2dict( - group['mva_results'], lazy=lazy) - if 'peak_mva_results' in group.keys(): - exp['attributes']['peak_learning_results'] = self.hdfgroup2dict( - group['peak_mva_results'], lazy=lazy) - # Replace the old signal and name keys with their current names - if 'signal' in exp['metadata']: - if "Signal" not in exp["metadata"]: - exp["metadata"]["Signal"] = {} - exp['metadata']["Signal"]['signal_type'] = \ - exp['metadata']['signal'] - del exp['metadata']['signal'] - - if 'name' in exp['metadata']: - if "General" not in exp["metadata"]: - exp["metadata"]["General"] = {} - exp['metadata']['General']['title'] = \ - exp['metadata']['name'] - del exp['metadata']['name'] - - if self.version < LooseVersion("1.2"): - if '_internal_parameters' in exp['metadata']: - exp['metadata']['_HyperSpy'] = \ - exp['metadata']['_internal_parameters'] - del exp['metadata']['_internal_parameters'] - if 'stacking_history' in exp['metadata']['_HyperSpy']: - exp['metadata']['_HyperSpy']["Stacking_history"] = \ - exp['metadata']['_HyperSpy']['stacking_history'] - del exp['metadata']['_HyperSpy']["stacking_history"] - if 'folding' in exp['metadata']['_HyperSpy']: - exp['metadata']['_HyperSpy']["Folding"] = \ - exp['metadata']['_HyperSpy']['folding'] - del exp['metadata']['_HyperSpy']["folding"] - if 'Variance_estimation' in exp['metadata']: - if "Noise_properties" not in exp["metadata"]: - exp["metadata"]["Noise_properties"] = {} - exp['metadata']['Noise_properties']["Variance_linear_model"] = \ - exp['metadata']['Variance_estimation'] - del exp['metadata']['Variance_estimation'] - if "TEM" in exp["metadata"]: - if "Acquisition_instrument" not in exp["metadata"]: - exp["metadata"]["Acquisition_instrument"] = {} - exp["metadata"]["Acquisition_instrument"]["TEM"] = \ - exp["metadata"]["TEM"] - del exp["metadata"]["TEM"] - tem = exp["metadata"]["Acquisition_instrument"]["TEM"] - if "EELS" in tem: - if "dwell_time" in tem: - tem["EELS"]["dwell_time"] = tem["dwell_time"] - del tem["dwell_time"] - if "dwell_time_units" in tem: - tem["EELS"]["dwell_time_units"] = tem["dwell_time_units"] - del tem["dwell_time_units"] - if "exposure" in tem: - tem["EELS"]["exposure"] = tem["exposure"] - del tem["exposure"] - if "exposure_units" in tem: - tem["EELS"]["exposure_units"] = tem["exposure_units"] - del tem["exposure_units"] - if "Detector" not in tem: - tem["Detector"] = {} - tem["Detector"] = tem["EELS"] - del tem["EELS"] - if "EDS" in tem: - if "Detector" not in tem: - tem["Detector"] = {} - if "EDS" not in tem["Detector"]: - tem["Detector"]["EDS"] = {} - tem["Detector"]["EDS"] = tem["EDS"] - del tem["EDS"] - del tem - if "SEM" in exp["metadata"]: - if "Acquisition_instrument" not in exp["metadata"]: - exp["metadata"]["Acquisition_instrument"] = {} - exp["metadata"]["Acquisition_instrument"]["SEM"] = \ - exp["metadata"]["SEM"] - del exp["metadata"]["SEM"] - sem = exp["metadata"]["Acquisition_instrument"]["SEM"] - if "EDS" in sem: - if "Detector" not in sem: - sem["Detector"] = {} - if "EDS" not in sem["Detector"]: - sem["Detector"]["EDS"] = {} - sem["Detector"]["EDS"] = sem["EDS"] - 
del sem["EDS"] - del sem - - if "Sample" in exp["metadata"] and "Xray_lines" in exp[ - "metadata"]["Sample"]: - exp["metadata"]["Sample"]["xray_lines"] = exp[ - "metadata"]["Sample"]["Xray_lines"] - del exp["metadata"]["Sample"]["Xray_lines"] - - for key in ["title", "date", "time", "original_filename"]: - if key in exp["metadata"]: - if "General" not in exp["metadata"]: - exp["metadata"]["General"] = {} - exp["metadata"]["General"][key] = exp["metadata"][key] - del exp["metadata"][key] - for key in ["record_by", "signal_origin", "signal_type"]: - if key in exp["metadata"]: - if "Signal" not in exp["metadata"]: - exp["metadata"]["Signal"] = {} - exp["metadata"]["Signal"][key] = exp["metadata"][key] - del exp["metadata"][key] - - if self.version < LooseVersion("3.0"): - if "Acquisition_instrument" in exp["metadata"]: - # Move tilt_stage to Stage.tilt_alpha - # Move exposure time to Detector.Camera.exposure_time - if "TEM" in exp["metadata"]["Acquisition_instrument"]: - tem = exp["metadata"]["Acquisition_instrument"]["TEM"] - exposure = None - if "tilt_stage" in tem: - tem["Stage"] = {"tilt_alpha": tem["tilt_stage"]} - del tem["tilt_stage"] - if "exposure" in tem: - exposure = "exposure" - # Digital_micrograph plugin was parsing to 'exposure_time' - # instead of 'exposure': need this to be compatible with - # previous behaviour - if "exposure_time" in tem: - exposure = "exposure_time" - if exposure is not None: - if "Detector" not in tem: - tem["Detector"] = {"Camera": { - "exposure": tem[exposure]}} - tem["Detector"]["Camera"] = {"exposure": tem[exposure]} - del tem[exposure] - # Move tilt_stage to Stage.tilt_alpha - if "SEM" in exp["metadata"]["Acquisition_instrument"]: - sem = exp["metadata"]["Acquisition_instrument"]["SEM"] - if "tilt_stage" in sem: - sem["Stage"] = {"tilt_alpha": sem["tilt_stage"]} - del sem["tilt_stage"] - - return exp - - def hdfgroup2dict(self, - group, - dictionary=None, - lazy=False): - if dictionary is None: - dictionary = {} - for key, value in group.attrs.items(): - if isinstance(value, bytes): - value = value.decode() - if isinstance(value, (np.string_, str)): - if value == '_None_': - value = None - elif isinstance(value, np.bool_): - value = bool(value) - elif isinstance(value, np.ndarray) and value.dtype.char == "S": - # Convert strings to unicode - value = value.astype("U") - if value.dtype.str.endswith("U1"): - value = value.tolist() - # skip signals - these are handled below. 
- if key.startswith('_sig_'): - pass - elif key.startswith('_list_empty_'): - dictionary[key[len('_list_empty_'):]] = [] - elif key.startswith('_tuple_empty_'): - dictionary[key[len('_tuple_empty_'):]] = () - elif key.startswith('_bs_'): - dictionary[key[len('_bs_'):]] = value.tobytes() - # The following two elif stataments enable reading date and time from - # v < 2 of HyperSpy's metadata specifications - elif key.startswith('_datetime_date'): - date_iso = datetime.date( - *ast.literal_eval(value[value.index("("):])).isoformat() - dictionary[key.replace("_datetime_", "")] = date_iso - elif key.startswith('_datetime_time'): - date_iso = datetime.time( - *ast.literal_eval(value[value.index("("):])).isoformat() - dictionary[key.replace("_datetime_", "")] = date_iso - else: - dictionary[key] = value - if not isinstance(group, self.Dataset): - for key in group.keys(): - if key.startswith('_sig_'): - from hyperspy.io import dict2signal - dictionary[key[len('_sig_'):]] = ( - dict2signal(self.hdfgroup2signaldict( - group[key], lazy=lazy))) - elif isinstance(group[key], self.Dataset): - dat = group[key] - kn = key - if key.startswith("_list_"): - if (h5py.check_string_dtype(dat.dtype) and - hasattr(dat, 'asstr')): - # h5py 3.0 and newer - # https://docs.h5py.org/en/3.0.0/strings.html - dat = dat.asstr()[:] - ans = np.array(dat) - ans = ans.tolist() - kn = key[6:] - elif key.startswith("_tuple_"): - ans = np.array(dat) - ans = tuple(ans.tolist()) - kn = key[7:] - elif dat.dtype.char == "S": - ans = np.array(dat) - try: - ans = ans.astype("U") - except UnicodeDecodeError: - # There are some strings that must stay in binary, - # for example dill pickles. This will obviously also - # let "wrong" binary string fail somewhere else... - pass - elif lazy: - ans = da.from_array(dat, chunks=dat.chunks) - else: - ans = np.array(dat) - dictionary[kn] = ans - elif key.startswith('_hspy_AxesManager_'): - dictionary[key[len('_hspy_AxesManager_'):]] = AxesManager( - [i for k, i in sorted(iter( - self.hdfgroup2dict( - group[key], lazy=lazy).items() - ))]) - elif key.startswith('_list_'): - dictionary[key[7 + key[6:].find('_'):]] = \ - [i for k, i in sorted(iter( - self.hdfgroup2dict( - group[key], lazy=lazy).items() - ))] - elif key.startswith('_tuple_'): - dictionary[key[8 + key[7:].find('_'):]] = tuple( - [i for k, i in sorted(iter( - self.hdfgroup2dict( - group[key], lazy=lazy).items() - ))]) - else: - dictionary[key] = {} - self.hdfgroup2dict( - group[key], - dictionary[key], - lazy=lazy) - - return dictionary - -class HyperspyWriter: +class HyperspyWriter(HierarchicalWriter): """An object used to simplify and orgainize the process for writing a hyperspy signal. 
(.hspy format) """ @@ -516,273 +115,14 @@ def __init__(self, signal, expg, **kwds): - self.file = file - self.signal = signal - self.expg = expg - self.kwds = kwds + super().__init__(file, + signal, + expg, + **kwds) self.Dataset = Dataset - self.unicode_kwds = {"dtype":h5py.special_dtype(vlen=str)} - - def write(self): - self.write_signal(self.signal, - self.expg, - **self.kwds) - - def write_signal(self, signal, group, **kwds): - "Writes a hyperspy signal to a hdf5 group" - - group.attrs.update(get_object_package_info(signal)) - if default_version < LooseVersion("1.2"): - metadata = "mapped_parameters" - original_metadata = "original_parameters" - else: - metadata = "metadata" - original_metadata = "original_metadata" - - - - for axis in signal.axes_manager._axes: - axis_dict = axis.get_axis_dictionary() - coord_group = group.create_group( - 'axis-%s' % axis.index_in_array) - self.dict2hdfgroup(axis_dict, coord_group, **kwds) - mapped_par = group.create_group(metadata) - metadata_dict = signal.metadata.as_dictionary() - self.overwrite_dataset(group, signal.data, 'data', - signal_axes=signal.axes_manager.signal_indices_in_array, - **kwds) - if default_version < LooseVersion("1.2"): - metadata_dict["_internal_parameters"] = \ - metadata_dict.pop("_HyperSpy") - # Remove chunks from the kwds since it wouldn't have the same rank as the - # dataset and can't be used - kwds.pop('chunks', None) - self.dict2hdfgroup(metadata_dict, mapped_par, **kwds) - original_par = group.create_group(original_metadata) - self.dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, - **kwds) - learning_results = group.create_group('learning_results') - self.dict2hdfgroup(signal.learning_results.__dict__, - learning_results, **kwds) - if hasattr(signal, 'peak_learning_results'): - peak_learning_results = group.create_group( - 'peak_learning_results') - self.dict2hdfgroup(signal.peak_learning_results.__dict__, - peak_learning_results, **kwds) - - if len(signal.models): - model_group = self.file.require_group('Analysis/models') - self.dict2hdfgroup(signal.models._models.as_dictionary(), - model_group, **kwds) - for model in model_group.values(): - model.attrs['_signal'] = group.name - - def dict2hdfgroup(self, dictionary, group, **kwds): - "Recursive writer of dicts and signals" - - from hyperspy.misc.utils import DictionaryTreeBrowser - from hyperspy.signal import BaseSignal - - for key, value in dictionary.items(): - if isinstance(value, dict): - self.dict2hdfgroup(value, group.create_group(key), - **kwds) - elif isinstance(value, DictionaryTreeBrowser): - self.dict2hdfgroup(value.as_dictionary(), - group.create_group(key), - **kwds) - elif isinstance(value, BaseSignal): - kn = key if key.startswith('_sig_') else '_sig_' + key - self.write_signal(value, group.require_group(kn)) - elif isinstance(value, (np.ndarray, self.Dataset, da.Array)): - self.overwrite_dataset(group, value, key, **kwds) - elif value is None: - group.attrs[key] = '_None_' - elif isinstance(value, bytes): - try: - # binary string if has any null characters (otherwise not - # supported by hdf5) - value.index(b'\x00') - group.attrs['_bs_' + key] = np.void(value) - except ValueError: - group.attrs[key] = value.decode() - elif isinstance(value, str): - group.attrs[key] = value - elif isinstance(value, AxesManager): - self.dict2hdfgroup(value.as_dictionary(), - group.create_group('_hspy_AxesManager_' + key), - **kwds) - elif isinstance(value, list): - if len(value): - self.parse_structure(key, group, value, '_list_', **kwds) - else: - 
group.attrs['_list_empty_' + key] = '_None_' - elif isinstance(value, tuple): - if len(value): - self.parse_structure(key, group, value, '_tuple_', **kwds) - else: - group.attrs['_tuple_empty_' + key] = '_None_' - - elif value is Undefined: - continue - else: - try: - group.attrs[key] = value - except BaseException: - _logger.exception( - "The hdf5 writer could not write the following " - "information in the file: %s : %s", key, value) - - - def get_signal_chunks(self, - shape, - dtype, - signal_axes=None): - """Function that calculates chunks for the signal, preferably at least one - chunk per signal space. - - Parameters - ---------- - shape : tuple - the shape of the dataset to be sored / chunked - dtype : {dtype, string} - the numpy dtype of the data - signal_axes: {None, iterable of ints} - the axes defining "signal space" of the dataset. If None, the default - h5py chunking is performed. - """ - typesize = np.dtype(dtype).itemsize - if signal_axes is None: - return h5py._hl.filters.guess_chunk(shape, None, typesize) - - # largely based on the guess_chunk in h5py - CHUNK_MAX = 1024 * 1024 - want_to_keep = multiply([shape[i] for i in signal_axes]) * typesize - if want_to_keep >= CHUNK_MAX: - chunks = [1 for _ in shape] - for i in signal_axes: - chunks[i] = shape[i] - return tuple(chunks) - - chunks = [i for i in shape] - idx = 0 - navigation_axes = tuple(i for i in range(len(shape)) if i not in - signal_axes) - nchange = len(navigation_axes) - while True: - chunk_bytes = multiply(chunks) * typesize - - if chunk_bytes < CHUNK_MAX: - break - - if multiply([chunks[i] for i in navigation_axes]) == 1: - break - change = navigation_axes[idx % nchange] - chunks[change] = np.ceil(chunks[change] / 2.0) - idx += 1 - return tuple(int(x) for x in chunks) - - def overwrite_dataset(self, - group, - data, - key, - signal_axes=None, - chunks=None, - **kwds): - if chunks is None: - if isinstance(data, da.Array): - # For lazy dataset, by default, we use the current dask chunking - chunks = tuple([c[0] for c in data.chunks]) - else: - # If signal_axes=None, use automatic h5py chunking, otherwise - # optimise the chunking to contain at least one signal per chunk - chunks = self.get_signal_chunks(data.shape, data.dtype, signal_axes) - if np.issubdtype(data.dtype, np.dtype('U')): - # Saving numpy unicode type is not supported in h5py - data = data.astype(np.dtype('S')) - if data.dtype == np.dtype('O'): - dset = self.get_object_dset(group, data, key, chunks, **kwds) - else: - got_data=False - maxshape = tuple(None for _ in data.shape) - while not got_data: - try: - these_kwds = kwds.copy() - these_kwds.update(dict(shape=data.shape, - dtype=data.dtype, - exact=True, - maxshape=maxshape, - chunks=chunks, - shuffle=True, )) - - # If chunks is True, the `chunks` attribute of `dset` below - # contains the chunk shape guessed by h5py - dset = group.require_dataset(key, **these_kwds) - got_data = True - except TypeError: - # if the shape or dtype/etc do not match, - # we delete the old one and create new in the next loop run - del group[key] - if dset == data: - # just a reference to already created thing - pass - else: - _logger.info(f"Chunks used for saving: {chunks}") - self.store_data(data, dset, group, key, chunks, **kwds) - - def get_object_dset(self, group, data, key, chunks, **kwds): - # For saving ragged array - # http://docs.h5py.org/en/stable/special.html#arbitrary-vlen-data - dset = group.require_dataset(key, - chunks, - dtype=h5py.special_dtype(vlen=data[0].dtype), - **kwds) - return dset - - def 
store_data(self, data, dset, group, key, chunks, **kwds): - if isinstance(data, da.Array): - if data.chunks != dset.chunks: - data = data.rechunk(dset.chunks) - da.store(data, dset) - elif data.flags.c_contiguous: - dset.write_direct(data) - else: - dset[:] = data - - def parse_structure(self, key, group, value, _type, **kwds): - from hyperspy.signal import BaseSignal - try: - # Here we check if there are any signals in the container, as - # casting a long list of signals to a numpy array takes a very long - # time. So we check if there are any, and save numpy the trouble - if np.any([isinstance(t, BaseSignal) for t in value]): - tmp = np.array([[0]]) - else: - tmp = np.array(value) - except ValueError: - tmp = np.array([[0]]) - if tmp.dtype == np.dtype('O') or tmp.ndim != 1: - self.dict2hdfgroup(dict(zip( - [str(i) for i in range(len(value))], value)), - group.create_group(_type + str(len(value)) + '_' + key), - **kwds) - elif tmp.dtype.type is np.unicode_: - if _type + key in group: - del group[_type + key] - print(self.unicode_kwds) - print(kwds) - group.create_dataset(_type + key, - tmp.shape, - dtype = self.unicode_kwds["dtype"], - **kwds) - group[_type + key][:] = tmp[:] - else: - if _type + key in group: - del group[_type + key] - group.create_dataset(_type + key, - data=tmp, - **kwds) - + self.Group = Group + self.unicode_kwds = {"dtype": h5py.special_dtype(vlen=str)} + self.ragged_kwds = {"dtype": h5py.special_dtype(vlen=signal.data[0].dtype)} def file_reader( filename, @@ -807,7 +147,7 @@ def file_reader( # Getting the format version here also checks if it is a valid HSpy # hdf5 file, so the following two lines must not be deleted or moved # elsewhere. - reader = HyperspyReader(f, Group, Dataset) + reader = HyperspyReader(f) if reader.version > version: warnings.warn( "This file was written using a newer version of the " diff --git a/hyperspy/io_plugins/nexus.py b/hyperspy/io_plugins/nexus.py index e33e86cff9..8fa6baab63 100644 --- a/hyperspy/io_plugins/nexus.py +++ b/hyperspy/io_plugins/nexus.py @@ -25,7 +25,7 @@ import h5py import pprint import traits.api as t -#from hyperspy.io_plugins.hspy import overwrite_dataset, get_signal_chunks +from hyperspy.io_plugins.hspy import HyperspyReader, HyperspyWriter from hyperspy.misc.utils import DictionaryTreeBrowser from hyperspy.exceptions import VisibleDeprecationWarning _logger = logging.getLogger(__name__) @@ -46,6 +46,11 @@ # ---------------------- +class NexusReader(HyperspyReader): + +class NexusWriter(HyperspyWriter): + + def _byte_to_string(value): """Decode a byte string. @@ -96,7 +101,7 @@ def _parse_from_file(value, lazy=False): if value.chunks: toreturn = da.from_array(value, value.chunks) else: - chunks = get_signal_chunks(value.shape, value.dtype) + chunks = self.get_signal_chunks(value.shape, value.dtype) toreturn = da.from_array(value, chunks) else: toreturn = np.array(value) @@ -965,58 +970,7 @@ def find_searchkeys_in_tree(myDict, rootname): return metadata_dict -def _write_nexus_groups(dictionary, group, skip_keys=None, **kwds): - """Recursively iterate throuh dictionary and write groups to nexus. 
- - Parameters - ---------- - dictionary : dict - dictionary contents to store to hdf group - group : hdf group - location to store dictionary - skip_keys : str or list of str - the key(s) to skip when writing into the group - **kwds : additional keywords - additional keywords to pass to h5py.create_dataset method - - """ - if skip_keys is None: - skip_keys = [] - elif isinstance(skip_keys, str): - skip_keys = [skip_keys] - for key, value in dictionary.items(): - if key == 'attrs' or key in skip_keys: - # we will handle attrs later... and skip unwanted keys - continue - if isinstance(value, dict): - if "attrs" in value: - if "NX_class" in value["attrs"] and \ - value["attrs"]["NX_class"] == "NXdata": - continue - if 'value' in value.keys() \ - and not isinstance(value["value"], dict) \ - and len(set(list(value.keys()) + ["attrs", "value"])) == 2: - value = value["value"] - else: - _write_nexus_groups(value, group.require_group(key), - skip_keys=skip_keys, **kwds) - if isinstance(value, (list, tuple, np.ndarray, da.Array)): - if all(isinstance(v, dict) for v in value): - # a list of dictionary is from the axes of HyperSpy signal - for i, ax_dict in enumerate(value): - ax_key = '_axes_' + str(i) - _write_nexus_groups(ax_dict, group.require_group(ax_key), - skip_keys=skip_keys, **kwds) - else: - data = _parse_to_file(value) - overwrite_dataset(group, data, key, chunks=None, **kwds) - elif isinstance(value, (int, float, str, bytes)): - group.create_dataset(key, data=_parse_to_file(value)) - else: - if value is not None and value != t.Undefined and key not in group: - _write_nexus_groups(value, group.require_group(key), - skip_keys=skip_keys, **kwds) def _write_nexus_attr(dictionary, group, skip_keys=None): diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index 062f7565a6..f7e88e9aab 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -19,16 +19,14 @@ import warnings import logging - import zarr from zarr import Array, Group import numpy as np import dask.array as da -from hyperspy.misc.utils import get_object_package_info -from hyperspy.io_plugins.hspy import HyperspyReader, HyperspyWriter +from hyperspy.io_plugins.hspy import version import numcodecs -from hyperspy.io_plugins.hspy import version +from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader _logger = logging.getLogger(__name__) @@ -74,14 +72,23 @@ # Experiments instance -class ZspyWriter(HyperspyWriter): +class ZspyReader(HierarchicalReader): + def __init__(self, file): + super(ZspyReader, self).__init__(file) + self.Dataset = Array + self.Group = Group + +class ZspyWriter(HierarchicalWriter): def __init__(self, file, signal, expg, **kwargs): super().__init__(file, signal, expg, **kwargs) self.Dataset = Array - self.unicode_kwds = {"dtype" : object, "object_codec" : numcodecs.JSON()} + self.unicode_kwds = {"dtype": object, "object_codec": numcodecs.JSON()} + self.ragged_kwds = {"dtype": object, + "object_codec": numcodecs.VLenArray(int), + "exact": True} def store_data(self, data, @@ -111,23 +118,18 @@ def store_data(self, **kwds) dset[:] = data - def get_object_dset(self, group, data, key, chunks, **kwds): + def get_object_dset(self, group,data_shape, key, chunks, **kwds): """Overrides the hyperspy get object dset function for using zarr as the backend """ - if data.dtype == np.dtype('O'): - # For saving ragged array - # https://zarr.readthedocs.io/en/stable/tutorial.html?highlight=ragged%20array#ragged-arrays - if chunks is None: - chunks == 1 - 
these_kwds = kwds.copy() - these_kwds.update(dict(dtype=object, - exact=True, - chunks=chunks)) - dset = group.require_dataset(key, - data.shape, - object_codec=numcodecs.VLenArray(int), - **these_kwds) - return data, dset + these_kwds = kwds.copy() + these_kwds.update(dict(dtype=object, + exact=True, + chunks=chunks)) + dset = group.require_dataset(key, + data_shape, + object_codec=numcodecs.VLenArray(int), + **these_kwds) + return dset def get_signal_chunks(self, shape, dtype, signal_axes=None): """Function that calculates chunks for the signal, @@ -151,38 +153,6 @@ def get_signal_chunks(self, shape, dtype, signal_axes=None): if total_size < 1e8: # 1 mb return None - def parse_structure(self, key, group, value, _type, **kwds): - from hyperspy.signal import BaseSignal - try: - # Here we check if there are any signals in the container, as - # casting a long list of signals to a numpy array takes a very long - # time. So we check if there are any, and save numpy the trouble - if np.any([isinstance(t, BaseSignal) for t in value]): - tmp = np.array([[0]]) - else: - tmp = np.array(value) - except ValueError: - tmp = np.array([[0]]) - if tmp.dtype == np.dtype('O') or tmp.ndim != 1: - self.dict2hdfgroup(dict(zip( - [str(i) for i in range(len(value))], value)), - group.create_group(_type + str(len(value)) + '_' + key), - **kwds) - elif tmp.dtype.type is np.unicode_: - if _type + key in group: - del group[_type + key] - group.create_dataset(_type + key, - data=tmp, - dtype=object, - object_codec=numcodecs.JSON(), **kwds) - else: - if _type + key in group: - del group[_type + key] - group.create_dataset( - _type + key, - data=tmp, - **kwds) - def file_writer(filename, signal, *args, **kwds): """Writes data to hyperspy's zarr format @@ -238,7 +208,7 @@ def file_reader(filename, """ mode = kwds.pop('mode', 'r') f = zarr.open(filename, mode=mode, **kwds) - reader = HyperspyReader(f, Group, Dataset=Array) + reader = ZspyReader(f) if reader.version > version: warnings.warn( "This file was written using a newer version of the " From 28350a5de711600b1284c247e0d7db5a50285d97 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 16 Sep 2021 10:09:49 -0500 Subject: [PATCH 08/31] Broken-- Starting with a new branch --- hyperspy/io_plugins/nexus.py | 411 ++++++++++++++++++++--------------- hyperspy/io_plugins/zspy.py | 2 +- 2 files changed, 234 insertions(+), 179 deletions(-) diff --git a/hyperspy/io_plugins/nexus.py b/hyperspy/io_plugins/nexus.py index 8fa6baab63..8fc30ffdf5 100644 --- a/hyperspy/io_plugins/nexus.py +++ b/hyperspy/io_plugins/nexus.py @@ -25,7 +25,7 @@ import h5py import pprint import traits.api as t -from hyperspy.io_plugins.hspy import HyperspyReader, HyperspyWriter +from hyperspy.io_plugins.hspy import HyperspyReader, HierarchicalWriter from hyperspy.misc.utils import DictionaryTreeBrowser from hyperspy.exceptions import VisibleDeprecationWarning _logger = logging.getLogger(__name__) @@ -47,9 +47,232 @@ class NexusReader(HyperspyReader): + def get_format_version(self): + pass + def read(self, + lazy, + dataset_key=None, + dataset_path=None, + metadata_key=None, + skip_array_metadata=False, + nxdata_only=False, + hardlinks_only=False, + use_default=False, + **kwds): + signal_dict_list = [] + if 'dataset_keys' in kwds: + warnings.warn("The `dataset_keys` keyword is deprecated. " + "Use `dataset_key` instead.", VisibleDeprecationWarning) + dataset_key = kwds['dataset_keys'] + + if 'dataset_paths' in kwds: + warnings.warn("The `dataset_paths` keyword is deprecated. 
" + "Use `dataset_path` instead.", VisibleDeprecationWarning) + dataset_path = kwds['dataset_paths'] + + if 'metadata_keys' in kwds: + warnings.warn("The `metadata_keys` keyword is deprecated. " + "Use `metadata_key` instead.", VisibleDeprecationWarning) + metadata_key = kwds['metadata_keys'] + + dataset_key = _check_search_keys(dataset_key) + dataset_path = _check_search_keys(dataset_path) + metadata_key = _check_search_keys(metadata_key) + original_metadata = _load_metadata(fin, + lazy=lazy, + skip_array_metadata=skip_array_metadata) + # some default values... + nexus_data_paths = [] + hdf_data_paths = [] + # check if a default dataset is defined + if use_default: + nexus_data_paths, hdf_data_paths = _find_data(fin, + search_keys=None, + hardlinks_only=False) + nxentry = None + nxdata = None + if "attrs" in original_metadata: + if "default" in original_metadata["attrs"]: + nxentry = original_metadata["attrs"]["default"] + else: + rootlist = list(original_metadata.keys()) + rootlist.remove("attrs") + if rootlist and len(rootlist) == 1: + nxentry == rootlist[0] + if nxentry: + if "default" in original_metadata[nxentry]["attrs"]: + nxdata = original_metadata[nxentry]["attrs"]["default"] + if nxentry and nxdata: + nxdata = "/" + nxentry + "/" + nxdata + if nxdata: + hdf_data_paths = [] + nexus_data_paths = [nxpath for nxpath in nexus_data_paths + if nxdata in nxpath] + + # if no default found then search for the data as normal + if not nexus_data_paths and not hdf_data_paths: + nexus_data_paths, hdf_data_paths = \ + _find_data(fin, search_keys=dataset_key, + hardlinks_only=hardlinks_only, + absolute_path=dataset_path) + + for data_path in nexus_data_paths: + dictionary = _nexus_dataset_to_signal(fin, data_path, lazy=lazy) + entryname = _text_split(data_path, "/")[0] + dictionary["mapping"] = mapping + title = dictionary["metadata"]["General"]["title"] + if entryname in original_metadata: + if metadata_key is None: + dictionary["original_metadata"] = \ + original_metadata[entryname] + else: + dictionary["original_metadata"] = \ + _find_search_keys_in_dict(original_metadata, + search_keys=metadata_key) + # test if it's a hyperspy_nexus format and update metadata + # as appropriate. 
+ if "attrs" in original_metadata and \ + "file_writer" in original_metadata["attrs"]: + if original_metadata["attrs"]["file_writer"] == \ + "hyperspy_nexus_v3": + orig_metadata = original_metadata[entryname] + if "auxiliary" in orig_metadata: + oma = orig_metadata["auxiliary"] + if "learning_results" in oma: + learning = oma["learning_results"] + dictionary["attributes"] = {} + dictionary["attributes"]["learning_results"] = \ + learning + if "original_metadata" in oma: + if metadata_key is None: + dictionary["original_metadata"] = \ + (oma["original_metadata"]) + else: + dictionary["original_metadata"] = \ + _find_search_keys_in_dict( + (oma["original_metadata"]), + search_keys=metadata_key) + # reconstruct the axes_list for axes_manager + for k, v in oma['original_metadata'].items(): + if k.startswith('_sig_'): + hyper_ax = [ax_v for ax_k, ax_v in v.items() + if ax_k.startswith('_axes')] + oma['original_metadata'][k]['axes'] = hyper_ax + if "hyperspy_metadata" in oma: + hyper_metadata = oma["hyperspy_metadata"] + hyper_metadata.update(dictionary["metadata"]) + dictionary["metadata"] = hyper_metadata + else: + dictionary["original_metadata"] = {} + + signal_dict_list.append(dictionary) + + if not nxdata_only: + for data_path in hdf_data_paths: + datadict = _extract_hdf_dataset(fin, data_path, lazy=lazy) + if datadict: + title = data_path[1:].replace('/', '_') + basic_metadata = {'General': + {'original_filename': + os.path.split(filename)[1], + 'title': title}} + datadict["metadata"].update(basic_metadata) + signal_dict_list.append(datadict) + + return signal_dict_list + + +class NexusWriter(HeirachicalWriter): + def write_signal(self, + signals, + expg, + save_original_metadata=True, + skip_metadata_key=None, + use_default=False, + **kwds): + """Write the signal and metadata as a nexus file. + + This will save the signal in NXdata format in the file. + As the form of the metadata can vary and is not validated it will + be stored as an NXcollection (an unvalidated collection) + + Parameters + ---------- + filename : str + Path of the file to write + signals : signal or list of signals + Signal(s) to be written + save_original_metadata : bool , default : False + Option to save hyperspy.original_metadata with the signal. + A loaded Nexus file may have a large amount of data + when loaded which you may wish to omit on saving + skip_metadata_key : str or list of str, default : None + the key(s) to skip when it is saving original metadata. This is useful + when some metadata's keys are to be ignored. + use_default : bool , default : False + Option to define the default dataset in the file. + If set to True the signal or first signal in the list of signals + will be defined as the default (following Nexus v3 data rules). 
+ + See Also + -------- + * :py:meth:`~.io_plugins.nexus.file_reader` + * :py:meth:`~.io_plugins.nexus.list_datasets_in_file` + * :py:meth:`~.io_plugins.nexus.read_metadata_from_file` + + """ + for i, sig in enumerate(signals): + nxentry = f.create_group("entry%d" % (i + 1)) + nxentry.attrs["NX_class"] = _parse_to_file("NXentry") + if isinstance(sig.metadata, dict): + sig.metadata = DictionaryTreeBrowser(sig.metadata) + if isinstance(sig.original_metadata, dict): + sig.original_metadata = DictionaryTreeBrowser( + sig.original_metadata) + + signal_name = sig.metadata.General.title \ + if sig.metadata.General.title else 'unnamed__%d' % i + if "/" in signal_name: + signal_name = signal_name.replace("/", "_") + if signal_name.startswith("__"): + signal_name = signal_name[2:] + + if i == 0 and use_default: + nxentry.attrs["default"] = signal_name -class NexusWriter(HyperspyWriter): + nxaux = nxentry.create_group("auxiliary") + nxaux.attrs["NX_class"] = _parse_to_file("NXentry") + _write_signal(sig, nxentry, signal_name, **kwds) + if sig.learning_results: + nxlearn = nxaux.create_group('learning_results') + nxlearn.attrs["NX_class"] = _parse_to_file("NXcollection") + learn = sig.learning_results.__dict__ + _write_nexus_groups(learn, nxlearn, **kwds) + _write_nexus_attr(learn, nxlearn) + # + # write metadata + # + if save_original_metadata: + if sig.original_metadata: + ometa = sig.original_metadata.as_dictionary() + + nxometa = nxaux.create_group('original_metadata') + nxometa.attrs["NX_class"] = _parse_to_file("NXcollection") + # write the groups and structure + _write_nexus_groups(ometa, nxometa, + skip_keys=skip_metadata_key, **kwds) + _write_nexus_attr(ometa, nxometa, + skip_keys=skip_metadata_key) + + if sig.metadata: + meta = sig.metadata.as_dictionary() + + nxometa = nxaux.create_group('hyperspy_metadata') + nxometa.attrs["NX_class"] = _parse_to_file("NXcollection") + # write the groups and structure + _write_nexus_groups(meta, nxometa, **kwds) + _write_nexus_attr(meta, nxometa) def _byte_to_string(value): """Decode a byte string. @@ -506,126 +729,6 @@ def file_reader(filename, lazy=False, dataset_key=None, dataset_path=None, fin = h5py.File(filename, "r") signal_dict_list = [] - if 'dataset_keys' in kwds: - warnings.warn("The `dataset_keys` keyword is deprecated. " - "Use `dataset_key` instead.", VisibleDeprecationWarning) - dataset_key = kwds['dataset_keys'] - - if 'dataset_paths' in kwds: - warnings.warn("The `dataset_paths` keyword is deprecated. " - "Use `dataset_path` instead.", VisibleDeprecationWarning) - dataset_path = kwds['dataset_paths'] - - if 'metadata_keys' in kwds: - warnings.warn("The `metadata_keys` keyword is deprecated. " - "Use `metadata_key` instead.", VisibleDeprecationWarning) - metadata_key = kwds['metadata_keys'] - - dataset_key = _check_search_keys(dataset_key) - dataset_path = _check_search_keys(dataset_path) - metadata_key = _check_search_keys(metadata_key) - original_metadata = _load_metadata(fin, lazy=lazy, - skip_array_metadata=skip_array_metadata) - # some default values... 
- nexus_data_paths = [] - hdf_data_paths = [] - # check if a default dataset is defined - if use_default: - nexus_data_paths, hdf_data_paths = _find_data(fin, - search_keys=None, - hardlinks_only=False) - nxentry = None - nxdata = None - if "attrs" in original_metadata: - if "default" in original_metadata["attrs"]: - nxentry = original_metadata["attrs"]["default"] - else: - rootlist = list(original_metadata.keys()) - rootlist.remove("attrs") - if rootlist and len(rootlist) == 1: - nxentry == rootlist[0] - if nxentry: - if "default" in original_metadata[nxentry]["attrs"]: - nxdata = original_metadata[nxentry]["attrs"]["default"] - if nxentry and nxdata: - nxdata = "/"+nxentry+"/"+nxdata - if nxdata: - hdf_data_paths = [] - nexus_data_paths = [nxpath for nxpath in nexus_data_paths - if nxdata in nxpath] - - # if no default found then search for the data as normal - if not nexus_data_paths and not hdf_data_paths: - nexus_data_paths, hdf_data_paths = \ - _find_data(fin, search_keys=dataset_key, - hardlinks_only=hardlinks_only, - absolute_path=dataset_path) - - for data_path in nexus_data_paths: - dictionary = _nexus_dataset_to_signal(fin, data_path, lazy=lazy) - entryname = _text_split(data_path, "/")[0] - dictionary["mapping"] = mapping - title = dictionary["metadata"]["General"]["title"] - if entryname in original_metadata: - if metadata_key is None: - dictionary["original_metadata"] = \ - original_metadata[entryname] - else: - dictionary["original_metadata"] = \ - _find_search_keys_in_dict(original_metadata, - search_keys=metadata_key) - # test if it's a hyperspy_nexus format and update metadata - # as appropriate. - if "attrs" in original_metadata and \ - "file_writer" in original_metadata["attrs"]: - if original_metadata["attrs"]["file_writer"] == \ - "hyperspy_nexus_v3": - orig_metadata = original_metadata[entryname] - if "auxiliary" in orig_metadata: - oma = orig_metadata["auxiliary"] - if "learning_results" in oma: - learning = oma["learning_results"] - dictionary["attributes"] = {} - dictionary["attributes"]["learning_results"] = \ - learning - if "original_metadata" in oma: - if metadata_key is None: - dictionary["original_metadata"] = \ - (oma["original_metadata"]) - else: - dictionary["original_metadata"] = \ - _find_search_keys_in_dict( - (oma["original_metadata"]), - search_keys=metadata_key) - # reconstruct the axes_list for axes_manager - for k, v in oma['original_metadata'].items(): - if k.startswith('_sig_'): - hyper_ax = [ax_v for ax_k, ax_v in v.items() - if ax_k.startswith('_axes')] - oma['original_metadata'][k]['axes'] = hyper_ax - if "hyperspy_metadata" in oma: - hyper_metadata = oma["hyperspy_metadata"] - hyper_metadata.update(dictionary["metadata"]) - dictionary["metadata"] = hyper_metadata - else: - dictionary["original_metadata"] = {} - - signal_dict_list.append(dictionary) - - if not nxdata_only: - for data_path in hdf_data_paths: - datadict = _extract_hdf_dataset(fin, data_path, lazy=lazy) - if datadict: - title = data_path[1:].replace('/', '_') - basic_metadata = {'General': - {'original_filename': - os.path.split(filename)[1], - 'title': title}} - datadict["metadata"].update(basic_metadata) - signal_dict_list.append(datadict) - - return signal_dict_list - def _is_linear_axis(data): """Check if the data is linearly incrementing. 
@@ -1219,60 +1322,12 @@ def file_writer(filename, kwds['compression'] = 'gzip' if use_default: f.attrs["default"] = "entry1" - # - # write the signals - # - for i, sig in enumerate(signals): - nxentry = f.create_group("entry%d" % (i + 1)) - nxentry.attrs["NX_class"] = _parse_to_file("NXentry") - - if isinstance(sig.metadata, dict): - sig.metadata = DictionaryTreeBrowser(sig.metadata) - if isinstance(sig.original_metadata, dict): - sig.original_metadata = DictionaryTreeBrowser( - sig.original_metadata) - - signal_name = sig.metadata.General.title \ - if sig.metadata.General.title else 'unnamed__%d' % i - if "/" in signal_name: - signal_name = signal_name.replace("/", "_") - if signal_name.startswith("__"): - signal_name = signal_name[2:] - - if i == 0 and use_default: - nxentry.attrs["default"] = signal_name - - nxaux = nxentry.create_group("auxiliary") - nxaux.attrs["NX_class"] = _parse_to_file("NXentry") - _write_signal(sig, nxentry, signal_name, **kwds) - - if sig.learning_results: - nxlearn = nxaux.create_group('learning_results') - nxlearn.attrs["NX_class"] = _parse_to_file("NXcollection") - learn = sig.learning_results.__dict__ - _write_nexus_groups(learn, nxlearn, **kwds) - _write_nexus_attr(learn, nxlearn) - # - # write metadata - # - if save_original_metadata: - if sig.original_metadata: - ometa = sig.original_metadata.as_dictionary() - - nxometa = nxaux.create_group('original_metadata') - nxometa.attrs["NX_class"] = _parse_to_file("NXcollection") - # write the groups and structure - _write_nexus_groups(ometa, nxometa, - skip_keys=skip_metadata_key, **kwds) - _write_nexus_attr(ometa, nxometa, - skip_keys=skip_metadata_key) - - if sig.metadata: - meta = sig.metadata.as_dictionary() - - nxometa = nxaux.create_group('hyperspy_metadata') - nxometa.attrs["NX_class"] = _parse_to_file("NXcollection") - # write the groups and structure - _write_nexus_groups(meta, nxometa, **kwds) - _write_nexus_attr(meta, nxometa) + writer = NexusWriter(f, + signals, + save_original_metadata=True, + skip_metadata_key=None, + use_default=False, + *args, + **kwds) + writer.write_signal() diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index f7e88e9aab..22f784a4f5 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -129,7 +129,7 @@ def get_object_dset(self, group,data_shape, key, chunks, **kwds): data_shape, object_codec=numcodecs.VLenArray(int), **these_kwds) - return dset + return dset def get_signal_chunks(self, shape, dtype, signal_axes=None): """Function that calculates chunks for the signal, From c6f8178168931ce0342279191ccd5c75d8f20c03 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 16 Sep 2021 12:09:47 -0500 Subject: [PATCH 09/31] Cleaned up get_signal_chunks and overwrite dataset functions --- hyperspy/io_plugins/emd.py | 2 +- hyperspy/io_plugins/hierarchical.py | 113 +----- hyperspy/io_plugins/hspy.py | 97 ++++- hyperspy/io_plugins/nexus.py | 549 ++++++++++++---------------- hyperspy/io_plugins/zspy.py | 140 +++---- 5 files changed, 416 insertions(+), 485 deletions(-) diff --git a/hyperspy/io_plugins/emd.py b/hyperspy/io_plugins/emd.py index 8773d218b4..aeac092490 100644 --- a/hyperspy/io_plugins/emd.py +++ b/hyperspy/io_plugins/emd.py @@ -43,7 +43,7 @@ from hyperspy.exceptions import VisibleDeprecationWarning from hyperspy.misc.elements import atomic_number2name import hyperspy.misc.io.fei_stream_readers as stream_readers -#from hyperspy.io_plugins.hspy import get_signal_chunks +from hyperspy.io_plugins.hspy import get_signal_chunks # Plugin 
characteristics diff --git a/hyperspy/io_plugins/hierarchical.py b/hyperspy/io_plugins/hierarchical.py index f92a28ac15..3b4bf131d3 100644 --- a/hyperspy/io_plugins/hierarchical.py +++ b/hyperspy/io_plugins/hierarchical.py @@ -27,6 +27,7 @@ def _overwrite_dataset(group, chunks=None, get_signal_chunks=None, get_object_dset=None, + store_data=None, **kwds): """Overwrites some dataset in a hirarchical dataset. @@ -94,86 +95,6 @@ def _overwrite_dataset(group, _logger.info(f"Chunks used for saving: {chunks}") store_data(data, dset, group, key, chunks, **kwds) - def get_signal_chunks(self, - shape, - dtype, - signal_axes=None): - """Function that calculates chunks for the signal, preferably at least one - chunk per signal space. - - Parameters - ---------- - shape : tuple - the shape of the dataset to be sored / chunked - dtype : {dtype, string} - the numpy dtype of the data - signal_axes: {None, iterable of ints} - the axes defining "signal space" of the dataset. If None, the default - h5py chunking is performed. - """ - typesize = np.dtype(dtype).itemsize - if signal_axes is None: - return h5py._hl.filters.guess_chunk(shape, None, typesize) - - # largely based on the guess_chunk in h5py - CHUNK_MAX = 1024 * 1024 - want_to_keep = multiply([shape[i] for i in signal_axes]) * typesize - if want_to_keep >= CHUNK_MAX: - chunks = [1 for _ in shape] - for i in signal_axes: - chunks[i] = shape[i] - return tuple(chunks) - - chunks = [i for i in shape] - idx = 0 - navigation_axes = tuple(i for i in range(len(shape)) if i not in - signal_axes) - nchange = len(navigation_axes) - while True: - chunk_bytes = multiply(chunks) * typesize - - if chunk_bytes < CHUNK_MAX: - break - - if multiply([chunks[i] for i in navigation_axes]) == 1: - break - change = navigation_axes[idx % nchange] - chunks[change] = np.ceil(chunks[change] / 2.0) - idx += 1 - return tuple(int(x) for x in chunks) - -def _get_object_dset( group, key, chunks, **kwds): - """Creates a Dataset or Zarr Array object for saving ragged data - Parameters - ---------- - self - group - data - key - chunks - kwds - - Returns - ------- - - """ - # For saving ragged array - if chunks is None: - chunks = 1 - dset = group.require_dataset(key, - **kwds) - return dset - -def _store_data(data, dset, group, key, chunks, **kwds): - if isinstance(data, da.Array): - if data.chunks != dset.chunks: - data = data.rechunk(dset.chunks) - da.store(data, dset) - elif data.flags.c_contiguous: - dset.write_direct(data) - else: - dset[:] = data - class HierarchicalReader: """A generic Reader class for reading data from hierarchical file types.""" @@ -185,8 +106,6 @@ def __init__(self, self.Group = None self.unicode_kwds = None self.ragged_kwds = None - self.store_data - self.get_object_dset def get_format_version(self): @@ -598,6 +517,7 @@ def __init__(self, self.unicode_kwds = None self.ragged_kwds = None self.kwds = kwds + self.overwrite_dataset=None def write(self): self.write_signal(self.signal, @@ -621,29 +541,31 @@ def write_signal(self, signal, group, **kwds): axis_dict = axis.get_axis_dictionary() coord_group = group.create_group( 'axis-%s' % axis.index_in_array) - self.dict2hdfgroup(axis_dict, coord_group, **kwds) + self.dict2group(axis_dict, coord_group, **kwds) mapped_par = group.create_group(metadata) metadata_dict = signal.metadata.as_dictionary() - self.overwrite_dataset(group, signal.data, 'data', - signal_axes=signal.axes_manager.signal_indices_in_array, - **kwds) + self.overwrite_dataset(group, + signal.data, + 'data', + 
signal_axes=signal.axes_manager.signal_indices_in_array, + **kwds) if default_version < LooseVersion("1.2"): metadata_dict["_internal_parameters"] = \ metadata_dict.pop("_HyperSpy") # Remove chunks from the kwds since it wouldn't have the same rank as the # dataset and can't be used kwds.pop('chunks', None) - self.dict2hdfgroup(metadata_dict, mapped_par, **kwds) + self.dict2group(metadata_dict, mapped_par, **kwds) original_par = group.create_group(original_metadata) - self.dict2hdfgroup(signal.original_metadata.as_dictionary(), original_par, + self.dict2group(signal.original_metadata.as_dictionary(), original_par, **kwds) learning_results = group.create_group('learning_results') - self.dict2hdfgroup(signal.learning_results.__dict__, + self.dict2group(signal.learning_results.__dict__, learning_results, **kwds) if hasattr(signal, 'peak_learning_results'): peak_learning_results = group.create_group( 'peak_learning_results') - self.dict2hdfgroup(signal.peak_learning_results.__dict__, + self.dict2group(signal.peak_learning_results.__dict__, peak_learning_results, **kwds) if len(signal.models): @@ -661,10 +583,10 @@ def dict2group(self, dictionary, group, **kwds): for key, value in dictionary.items(): if isinstance(value, dict): - self.dict2hdfgroup(value, group.create_group(key), + self.dict2group(value, group.create_group(key), **kwds) elif isinstance(value, DictionaryTreeBrowser): - self.dict2hdfgroup(value.as_dictionary(), + self.dict2group(value.as_dictionary(), group.create_group(key), **kwds) elif isinstance(value, BaseSignal): @@ -685,7 +607,7 @@ def dict2group(self, dictionary, group, **kwds): elif isinstance(value, str): group.attrs[key] = value elif isinstance(value, AxesManager): - self.dict2hdfgroup(value.as_dictionary(), + self.dict2group(value.as_dictionary(), group.create_group('_hspy_AxesManager_' + key), **kwds) elif isinstance(value, list): @@ -709,9 +631,6 @@ def dict2group(self, dictionary, group, **kwds): "The hdf5 writer could not write the following " "information in the file: %s : %s", key, value) - - - def parse_structure(self, key, group, value, _type, **kwds): from hyperspy.signal import BaseSignal try: @@ -725,7 +644,7 @@ def parse_structure(self, key, group, value, _type, **kwds): except ValueError: tmp = np.array([[0]]) if tmp.dtype == np.dtype('O') or tmp.ndim != 1: - self.dict2hdfgroup(dict(zip( + self.dict2group(dict(zip( [str(i) for i in range(len(value))], value)), group.create_group(_type + str(len(value)) + '_' + key), **kwds) diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index e9e1988197..1ddb9fc06a 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -19,10 +19,14 @@ from distutils.version import LooseVersion import warnings import logging +from functools import partial import h5py +import numpy as np +import dask.array as da from h5py import Dataset, File, Group -from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader +from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader, _overwrite_dataset +from hyperspy.misc.utils import multiply _logger = logging.getLogger(__name__) @@ -98,6 +102,96 @@ default_version = LooseVersion(version) +def get_signal_chunks(shape, + dtype, + signal_axes=None): + """Function that calculates chunks for the signal, preferably at least one + chunk per signal space. 
+ + Parameters + ---------- + shape : tuple + the shape of the dataset to be sored / chunked + dtype : {dtype, string} + the numpy dtype of the data + signal_axes: {None, iterable of ints} + the axes defining "signal space" of the dataset. If None, the default + h5py chunking is performed. + """ + typesize = np.dtype(dtype).itemsize + if signal_axes is None: + return h5py._hl.filters.guess_chunk(shape, None, typesize) + + # largely based on the guess_chunk in h5py + CHUNK_MAX = 1024 * 1024 + want_to_keep = multiply([shape[i] for i in signal_axes]) * typesize + if want_to_keep >= CHUNK_MAX: + chunks = [1 for _ in shape] + for i in signal_axes: + chunks[i] = shape[i] + return tuple(chunks) + + chunks = [i for i in shape] + idx = 0 + navigation_axes = tuple(i for i in range(len(shape)) if i not in + signal_axes) + nchange = len(navigation_axes) + while True: + chunk_bytes = multiply(chunks) * typesize + + if chunk_bytes < CHUNK_MAX: + break + + if multiply([chunks[i] for i in navigation_axes]) == 1: + break + change = navigation_axes[idx % nchange] + chunks[change] = np.ceil(chunks[change] / 2.0) + idx += 1 + return tuple(int(x) for x in chunks) + + +def _get_object_dset(group, data, key, chunks, **kwds): + """Creates a Dataset or Zarr Array object for saving ragged data + Parameters + ---------- + self + group + data + key + chunks + kwds + + Returns + ------- + + """ + # For saving ragged array + if chunks is None: + chunks = 1 + dset = group.require_dataset(key, + chunks, + dtype=h5py.special_dtype(vlen=data[0].dtype), + **kwds) + return dset + + +def _store_data(data, dset, group, key, chunks, **kwds): + if isinstance(data, da.Array): + if data.chunks != dset.chunks: + data = data.rechunk(dset.chunks) + da.store(data, dset) + elif data.flags.c_contiguous: + dset.write_direct(data) + else: + dset[:] = data + + +overwrite_dataset = partial(_overwrite_dataset, + get_signal_chunks=get_signal_chunks, + get_object_dset=_get_object_dset, + store_data=_store_data) + + class HyperspyReader(HierarchicalReader): def __init__(self, file): super().__init__(file) @@ -123,6 +217,7 @@ def __init__(self, self.Group = Group self.unicode_kwds = {"dtype": h5py.special_dtype(vlen=str)} self.ragged_kwds = {"dtype": h5py.special_dtype(vlen=signal.data[0].dtype)} + self.overwrite_dataset = overwrite_dataset def file_reader( filename, diff --git a/hyperspy/io_plugins/nexus.py b/hyperspy/io_plugins/nexus.py index 8fc30ffdf5..2f7b801e75 100644 --- a/hyperspy/io_plugins/nexus.py +++ b/hyperspy/io_plugins/nexus.py @@ -25,7 +25,7 @@ import h5py import pprint import traits.api as t -from hyperspy.io_plugins.hspy import HyperspyReader, HierarchicalWriter +from hyperspy.io_plugins.hspy import overwrite_dataset, get_signal_chunks from hyperspy.misc.utils import DictionaryTreeBrowser from hyperspy.exceptions import VisibleDeprecationWarning _logger = logging.getLogger(__name__) @@ -46,246 +46,15 @@ # ---------------------- -class NexusReader(HyperspyReader): - def get_format_version(self): - pass - def read(self, - lazy, - dataset_key=None, - dataset_path=None, - metadata_key=None, - skip_array_metadata=False, - nxdata_only=False, - hardlinks_only=False, - use_default=False, - **kwds): - signal_dict_list = [] - if 'dataset_keys' in kwds: - warnings.warn("The `dataset_keys` keyword is deprecated. " - "Use `dataset_key` instead.", VisibleDeprecationWarning) - dataset_key = kwds['dataset_keys'] - - if 'dataset_paths' in kwds: - warnings.warn("The `dataset_paths` keyword is deprecated. 
" - "Use `dataset_path` instead.", VisibleDeprecationWarning) - dataset_path = kwds['dataset_paths'] - - if 'metadata_keys' in kwds: - warnings.warn("The `metadata_keys` keyword is deprecated. " - "Use `metadata_key` instead.", VisibleDeprecationWarning) - metadata_key = kwds['metadata_keys'] - - dataset_key = _check_search_keys(dataset_key) - dataset_path = _check_search_keys(dataset_path) - metadata_key = _check_search_keys(metadata_key) - original_metadata = _load_metadata(fin, - lazy=lazy, - skip_array_metadata=skip_array_metadata) - # some default values... - nexus_data_paths = [] - hdf_data_paths = [] - # check if a default dataset is defined - if use_default: - nexus_data_paths, hdf_data_paths = _find_data(fin, - search_keys=None, - hardlinks_only=False) - nxentry = None - nxdata = None - if "attrs" in original_metadata: - if "default" in original_metadata["attrs"]: - nxentry = original_metadata["attrs"]["default"] - else: - rootlist = list(original_metadata.keys()) - rootlist.remove("attrs") - if rootlist and len(rootlist) == 1: - nxentry == rootlist[0] - if nxentry: - if "default" in original_metadata[nxentry]["attrs"]: - nxdata = original_metadata[nxentry]["attrs"]["default"] - if nxentry and nxdata: - nxdata = "/" + nxentry + "/" + nxdata - if nxdata: - hdf_data_paths = [] - nexus_data_paths = [nxpath for nxpath in nexus_data_paths - if nxdata in nxpath] - - # if no default found then search for the data as normal - if not nexus_data_paths and not hdf_data_paths: - nexus_data_paths, hdf_data_paths = \ - _find_data(fin, search_keys=dataset_key, - hardlinks_only=hardlinks_only, - absolute_path=dataset_path) - - for data_path in nexus_data_paths: - dictionary = _nexus_dataset_to_signal(fin, data_path, lazy=lazy) - entryname = _text_split(data_path, "/")[0] - dictionary["mapping"] = mapping - title = dictionary["metadata"]["General"]["title"] - if entryname in original_metadata: - if metadata_key is None: - dictionary["original_metadata"] = \ - original_metadata[entryname] - else: - dictionary["original_metadata"] = \ - _find_search_keys_in_dict(original_metadata, - search_keys=metadata_key) - # test if it's a hyperspy_nexus format and update metadata - # as appropriate. 
- if "attrs" in original_metadata and \ - "file_writer" in original_metadata["attrs"]: - if original_metadata["attrs"]["file_writer"] == \ - "hyperspy_nexus_v3": - orig_metadata = original_metadata[entryname] - if "auxiliary" in orig_metadata: - oma = orig_metadata["auxiliary"] - if "learning_results" in oma: - learning = oma["learning_results"] - dictionary["attributes"] = {} - dictionary["attributes"]["learning_results"] = \ - learning - if "original_metadata" in oma: - if metadata_key is None: - dictionary["original_metadata"] = \ - (oma["original_metadata"]) - else: - dictionary["original_metadata"] = \ - _find_search_keys_in_dict( - (oma["original_metadata"]), - search_keys=metadata_key) - # reconstruct the axes_list for axes_manager - for k, v in oma['original_metadata'].items(): - if k.startswith('_sig_'): - hyper_ax = [ax_v for ax_k, ax_v in v.items() - if ax_k.startswith('_axes')] - oma['original_metadata'][k]['axes'] = hyper_ax - if "hyperspy_metadata" in oma: - hyper_metadata = oma["hyperspy_metadata"] - hyper_metadata.update(dictionary["metadata"]) - dictionary["metadata"] = hyper_metadata - else: - dictionary["original_metadata"] = {} - - signal_dict_list.append(dictionary) - - if not nxdata_only: - for data_path in hdf_data_paths: - datadict = _extract_hdf_dataset(fin, data_path, lazy=lazy) - if datadict: - title = data_path[1:].replace('/', '_') - basic_metadata = {'General': - {'original_filename': - os.path.split(filename)[1], - 'title': title}} - datadict["metadata"].update(basic_metadata) - signal_dict_list.append(datadict) - - return signal_dict_list - - -class NexusWriter(HeirachicalWriter): - def write_signal(self, - signals, - expg, - save_original_metadata=True, - skip_metadata_key=None, - use_default=False, - **kwds): - """Write the signal and metadata as a nexus file. - - This will save the signal in NXdata format in the file. - As the form of the metadata can vary and is not validated it will - be stored as an NXcollection (an unvalidated collection) - - Parameters - ---------- - filename : str - Path of the file to write - signals : signal or list of signals - Signal(s) to be written - save_original_metadata : bool , default : False - Option to save hyperspy.original_metadata with the signal. - A loaded Nexus file may have a large amount of data - when loaded which you may wish to omit on saving - skip_metadata_key : str or list of str, default : None - the key(s) to skip when it is saving original metadata. This is useful - when some metadata's keys are to be ignored. - use_default : bool , default : False - Option to define the default dataset in the file. - If set to True the signal or first signal in the list of signals - will be defined as the default (following Nexus v3 data rules). 
- - See Also - -------- - * :py:meth:`~.io_plugins.nexus.file_reader` - * :py:meth:`~.io_plugins.nexus.list_datasets_in_file` - * :py:meth:`~.io_plugins.nexus.read_metadata_from_file` - - """ - for i, sig in enumerate(signals): - nxentry = f.create_group("entry%d" % (i + 1)) - nxentry.attrs["NX_class"] = _parse_to_file("NXentry") - if isinstance(sig.metadata, dict): - sig.metadata = DictionaryTreeBrowser(sig.metadata) - if isinstance(sig.original_metadata, dict): - sig.original_metadata = DictionaryTreeBrowser( - sig.original_metadata) - - signal_name = sig.metadata.General.title \ - if sig.metadata.General.title else 'unnamed__%d' % i - if "/" in signal_name: - signal_name = signal_name.replace("/", "_") - if signal_name.startswith("__"): - signal_name = signal_name[2:] - - if i == 0 and use_default: - nxentry.attrs["default"] = signal_name - - nxaux = nxentry.create_group("auxiliary") - nxaux.attrs["NX_class"] = _parse_to_file("NXentry") - _write_signal(sig, nxentry, signal_name, **kwds) - - if sig.learning_results: - nxlearn = nxaux.create_group('learning_results') - nxlearn.attrs["NX_class"] = _parse_to_file("NXcollection") - learn = sig.learning_results.__dict__ - _write_nexus_groups(learn, nxlearn, **kwds) - _write_nexus_attr(learn, nxlearn) - # - # write metadata - # - if save_original_metadata: - if sig.original_metadata: - ometa = sig.original_metadata.as_dictionary() - - nxometa = nxaux.create_group('original_metadata') - nxometa.attrs["NX_class"] = _parse_to_file("NXcollection") - # write the groups and structure - _write_nexus_groups(ometa, nxometa, - skip_keys=skip_metadata_key, **kwds) - _write_nexus_attr(ometa, nxometa, - skip_keys=skip_metadata_key) - - if sig.metadata: - meta = sig.metadata.as_dictionary() - - nxometa = nxaux.create_group('hyperspy_metadata') - nxometa.attrs["NX_class"] = _parse_to_file("NXcollection") - # write the groups and structure - _write_nexus_groups(meta, nxometa, **kwds) - _write_nexus_attr(meta, nxometa) - def _byte_to_string(value): """Decode a byte string. - Parameters ---------- value : byte str - Returns ------- str decoded version of input value - """ try: text = value.decode("utf-8") @@ -296,24 +65,20 @@ def _byte_to_string(value): def _parse_from_file(value, lazy=False): """To convert values from the hdf file to compatible formats. - When reading string arrays we convert or keep string arrays as byte_strings (some io_plugins only supports byte-strings arrays so this ensures inter-compatibility across io_plugins) Arrays of length 1 - return the single value stored. Large datasets are returned as dask arrays if lazy=True. - Parameters ---------- value : input read from hdf file (array,list,tuple,string,int,float) lazy : bool {default: False} The lazy flag is only applied to values of size >=2 - Returns ------- str,int, float, ndarray dask Array parsed value. - """ toreturn = value if isinstance(value, h5py.Dataset): @@ -324,7 +89,7 @@ def _parse_from_file(value, lazy=False): if value.chunks: toreturn = da.from_array(value, value.chunks) else: - chunks = self.get_signal_chunks(value.shape, value.dtype) + chunks = get_signal_chunks(value.shape, value.dtype) toreturn = da.from_array(value, chunks) else: toreturn = np.array(value) @@ -342,18 +107,14 @@ def _parse_from_file(value, lazy=False): def _parse_to_file(value): """Convert to a suitable format for writing to HDF5. - For example unicode values are not compatible with hdf5 so conversion to byte strings is required. 
- Parameters ---------- value - input object to write to the hdf file - Returns ------- parsed value - """ totest = value toreturn = totest @@ -373,17 +134,14 @@ def _parse_to_file(value): def _text_split(s, sep): """Split a string based of list of seperators. - Parameters ---------- s : str sep : str - seperator or list of seperators e.g. '.' or ['_','/'] - Returns ------- list String sections split based on the seperators - """ stack = [s] for char in sep: @@ -398,16 +156,13 @@ def _text_split(s, sep): def _getlink(h5group, rootkey, key): """Return the link target path. - If a hdf group is a soft link or has a target attribute this method will return the target path. If no link is found return None. - Returns ------- str Soft link path if it exists, otherwise None - """ _target = None if rootkey != '/': @@ -427,14 +182,12 @@ def _getlink(h5group, rootkey, key): def _get_nav_list(data, dataentry): """Get the list with information of each axes of the dataset - Parameters ---------- data : hdf dataset the dataset to be loaded. dataentry : hdf group the group with corresponding attributes. - Returns ------- nav_list : list @@ -524,7 +277,6 @@ def _get_nav_list(data, dataentry): def _extract_hdf_dataset(group, dataset, lazy=False): """Import data from hdf path. - Parameters ---------- group : hdf group @@ -533,12 +285,10 @@ def _extract_hdf_dataset(group, dataset, lazy=False): path to the dataset within the group lazy : bool {default:True} If true use lazy opening, if false read into memory - Returns ------- dict A signal dictionary which can be used to instantiate a signal. - """ data = group[dataset] @@ -573,7 +323,6 @@ def _extract_hdf_dataset(group, dataset, lazy=False): def _nexus_dataset_to_signal(group, nexus_dataset_path, lazy=False): """Load an NXdata set as a hyperspy signal. - Parameters ---------- group : hdf group containing the NXdata @@ -581,12 +330,10 @@ def _nexus_dataset_to_signal(group, nexus_dataset_path, lazy=False): Path to the NXdata set in the group lazy : bool, default : True lazy loading of data - Returns ------- dict A signal dictionary which can be used to instantiate a signal. - """ interpretation = None @@ -662,7 +409,6 @@ def file_reader(filename, lazy=False, dataset_key=None, dataset_path=None, use_default=False, **kwds): """Read NXdata class or hdf datasets from a file and return signal(s). - Note ---- Loading all datasets can result in a large number of signals @@ -670,9 +416,7 @@ def file_reader(filename, lazy=False, dataset_key=None, dataset_path=None, the datasets of interest. "keys" is a special keywords and prepended with "fix" in the metadata structure to avoid any issues. - Datasets are all arrays with size>2 (arrays, lists) - Parameters ---------- filename : str @@ -708,18 +452,13 @@ def file_reader(filename, lazy=False, dataset_key=None, dataset_path=None, signal. This will ignore the other keyword options. If True and no default is defined the file will be loaded according to the keyword options. - Returns ------- dict : signal dictionary or list of signal dictionaries - - See Also -------- * :py:meth:`~.io_plugins.nexus.list_datasets_in_file` * :py:meth:`~.io_plugins.nexus.read_metadata_from_file` - - """ # search for NXdata sets... @@ -729,19 +468,136 @@ def file_reader(filename, lazy=False, dataset_key=None, dataset_path=None, fin = h5py.File(filename, "r") signal_dict_list = [] + if 'dataset_keys' in kwds: + warnings.warn("The `dataset_keys` keyword is deprecated. 
" + "Use `dataset_key` instead.", VisibleDeprecationWarning) + dataset_key = kwds['dataset_keys'] + + if 'dataset_paths' in kwds: + warnings.warn("The `dataset_paths` keyword is deprecated. " + "Use `dataset_path` instead.", VisibleDeprecationWarning) + dataset_path = kwds['dataset_paths'] + + if 'metadata_keys' in kwds: + warnings.warn("The `metadata_keys` keyword is deprecated. " + "Use `metadata_key` instead.", VisibleDeprecationWarning) + metadata_key = kwds['metadata_keys'] + + dataset_key = _check_search_keys(dataset_key) + dataset_path = _check_search_keys(dataset_path) + metadata_key = _check_search_keys(metadata_key) + original_metadata = _load_metadata(fin, lazy=lazy, + skip_array_metadata=skip_array_metadata) + # some default values... + nexus_data_paths = [] + hdf_data_paths = [] + # check if a default dataset is defined + if use_default: + nexus_data_paths, hdf_data_paths = _find_data(fin, + search_keys=None, + hardlinks_only=False) + nxentry = None + nxdata = None + if "attrs" in original_metadata: + if "default" in original_metadata["attrs"]: + nxentry = original_metadata["attrs"]["default"] + else: + rootlist = list(original_metadata.keys()) + rootlist.remove("attrs") + if rootlist and len(rootlist) == 1: + nxentry == rootlist[0] + if nxentry: + if "default" in original_metadata[nxentry]["attrs"]: + nxdata = original_metadata[nxentry]["attrs"]["default"] + if nxentry and nxdata: + nxdata = "/"+nxentry+"/"+nxdata + if nxdata: + hdf_data_paths = [] + nexus_data_paths = [nxpath for nxpath in nexus_data_paths + if nxdata in nxpath] + + # if no default found then search for the data as normal + if not nexus_data_paths and not hdf_data_paths: + nexus_data_paths, hdf_data_paths = \ + _find_data(fin, search_keys=dataset_key, + hardlinks_only=hardlinks_only, + absolute_path=dataset_path) + + for data_path in nexus_data_paths: + dictionary = _nexus_dataset_to_signal(fin, data_path, lazy=lazy) + entryname = _text_split(data_path, "/")[0] + dictionary["mapping"] = mapping + title = dictionary["metadata"]["General"]["title"] + if entryname in original_metadata: + if metadata_key is None: + dictionary["original_metadata"] = \ + original_metadata[entryname] + else: + dictionary["original_metadata"] = \ + _find_search_keys_in_dict(original_metadata, + search_keys=metadata_key) + # test if it's a hyperspy_nexus format and update metadata + # as appropriate. 
+ if "attrs" in original_metadata and \ + "file_writer" in original_metadata["attrs"]: + if original_metadata["attrs"]["file_writer"] == \ + "hyperspy_nexus_v3": + orig_metadata = original_metadata[entryname] + if "auxiliary" in orig_metadata: + oma = orig_metadata["auxiliary"] + if "learning_results" in oma: + learning = oma["learning_results"] + dictionary["attributes"] = {} + dictionary["attributes"]["learning_results"] = \ + learning + if "original_metadata" in oma: + if metadata_key is None: + dictionary["original_metadata"] = \ + (oma["original_metadata"]) + else: + dictionary["original_metadata"] = \ + _find_search_keys_in_dict( + (oma["original_metadata"]), + search_keys=metadata_key) + # reconstruct the axes_list for axes_manager + for k, v in oma['original_metadata'].items(): + if k.startswith('_sig_'): + hyper_ax = [ax_v for ax_k, ax_v in v.items() + if ax_k.startswith('_axes')] + oma['original_metadata'][k]['axes'] = hyper_ax + if "hyperspy_metadata" in oma: + hyper_metadata = oma["hyperspy_metadata"] + hyper_metadata.update(dictionary["metadata"]) + dictionary["metadata"] = hyper_metadata + else: + dictionary["original_metadata"] = {} + + signal_dict_list.append(dictionary) + + if not nxdata_only: + for data_path in hdf_data_paths: + datadict = _extract_hdf_dataset(fin, data_path, lazy=lazy) + if datadict: + title = data_path[1:].replace('/', '_') + basic_metadata = {'General': + {'original_filename': + os.path.split(filename)[1], + 'title': title}} + datadict["metadata"].update(basic_metadata) + signal_dict_list.append(datadict) + + return signal_dict_list + def _is_linear_axis(data): """Check if the data is linearly incrementing. - Parameters ---------- data : dask or numpy array - Returns ------- bool True or False - """ steps = np.diff(data) est_steps = np.array([steps[0]]*len(steps)) @@ -750,16 +606,13 @@ def _is_linear_axis(data): def _is_numeric_data(data): """Check that data contains numeric data. - Parameters ---------- data : dask or numpy array - Returns ------- bool True or False - """ try: data.astype(float) @@ -770,16 +623,13 @@ def _is_numeric_data(data): def _is_int(s): """Check that s in an integer. - Parameters ---------- s : python object to test - Returns ------- bool True or False - """ try: int(s) @@ -806,7 +656,6 @@ def _check_search_keys(search_keys): def _find_data(group, search_keys=None, hardlinks_only=False, absolute_path=None): """Read from a nexus or hdf file and return a list of the dataset entries. - The method iterates through group attributes and returns NXdata or hdf datasets of size >=2 if they're not already NXdata blocks and returns a list of the entries @@ -815,8 +664,6 @@ def _find_data(group, search_keys=None, hardlinks_only=False, h5py.visit or visititems does not visit soft links or external links so an implementation of a recursive search is required. See https://github.com/h5py/h5py/issues/671 - - Parameters ---------- group : hdf group or File @@ -828,14 +675,12 @@ def _find_data(group, search_keys=None, hardlinks_only=False, Option to ignore links (soft or External) within the file. absolute_path : string, list of strings or None, default: None Return items with the exact specified absolute path - Returns ------- nx_dataset_list, hdf_dataset_list nx_dataset_list is a list of all NXdata paths hdf_dataset_list is a list of all hdf_datasets not linked to an NXdata set. 
- """ _check_search_keys(search_keys) _check_search_keys(absolute_path) @@ -917,11 +762,9 @@ def find_data_in_tree(group, rootname): def _load_metadata(group, lazy=False, skip_array_metadata=False): """Search through a hdf group and return the group structure. - h5py.visit or visititems does not visit soft links or external links so an implementation of a recursive search is required. See https://github.com/h5py/h5py/issues/671 - Parameters ---------- group : hdf group @@ -930,13 +773,10 @@ def _load_metadata(group, lazy=False, skip_array_metadata=False): Option for lazy loading skip_array_metadata : bool, default : False whether to skip loading array metadata - Returns ------- dict dictionary of group contents - - """ rootname = "" @@ -1001,21 +841,16 @@ def find_meta_in_tree(group, rootname, lazy=False, def _fix_exclusion_keys(key): """Exclude hyperspy specific keys. - Signal and DictionaryBrowser break if a a key is a dict method - e.g. {"keys":2.0}. - This method prepends the key with ``fix_`` so the information is still present to work around this issue - Parameters ---------- key : str - Returns ------- str - """ if key.startswith("keys"): return "fix_"+key @@ -1025,10 +860,8 @@ def _fix_exclusion_keys(key): def _find_search_keys_in_dict(tree, search_keys=None): """Search through a dict for search keys. - This is a convenience method to inspect a file for a value rather than loading the file as a signal - Parameters ---------- tree : h5py File object @@ -1036,13 +869,11 @@ def _find_search_keys_in_dict(tree, search_keys=None): Only return items which contain the strings .e.g search_keys = ["instrument","Fe"] will return hdf entries with instrument or Fe in their hdf path. - Returns ------- dict When search_list is specified only full paths containing one or more search_keys will be returned - """ _check_search_keys(search_keys) metadata_dict = {} @@ -1073,21 +904,67 @@ def find_searchkeys_in_tree(myDict, rootname): return metadata_dict +def _write_nexus_groups(dictionary, group, skip_keys=None, **kwds): + """Recursively iterate throuh dictionary and write groups to nexus. + Parameters + ---------- + dictionary : dict + dictionary contents to store to hdf group + group : hdf group + location to store dictionary + skip_keys : str or list of str + the key(s) to skip when writing into the group + **kwds : additional keywords + additional keywords to pass to h5py.create_dataset method + """ + if skip_keys is None: + skip_keys = [] + elif isinstance(skip_keys, str): + skip_keys = [skip_keys] + for key, value in dictionary.items(): + if key == 'attrs' or key in skip_keys: + # we will handle attrs later... 
and skip unwanted keys + continue + if isinstance(value, dict): + if "attrs" in value: + if "NX_class" in value["attrs"] and \ + value["attrs"]["NX_class"] == "NXdata": + continue + if 'value' in value.keys() \ + and not isinstance(value["value"], dict) \ + and len(set(list(value.keys()) + ["attrs", "value"])) == 2: + value = value["value"] + else: + _write_nexus_groups(value, group.require_group(key), + skip_keys=skip_keys, **kwds) + if isinstance(value, (list, tuple, np.ndarray, da.Array)): + if all(isinstance(v, dict) for v in value): + # a list of dictionary is from the axes of HyperSpy signal + for i, ax_dict in enumerate(value): + ax_key = '_axes_' + str(i) + _write_nexus_groups(ax_dict, group.require_group(ax_key), + skip_keys=skip_keys, **kwds) + else: + data = _parse_to_file(value) + overwrite_dataset(group, data, key, chunks=None, **kwds) + elif isinstance(value, (int, float, str, bytes)): + group.create_dataset(key, data=_parse_to_file(value)) + else: + if value is not None and value != t.Undefined and key not in group: + _write_nexus_groups(value, group.require_group(key), + skip_keys=skip_keys, **kwds) def _write_nexus_attr(dictionary, group, skip_keys=None): """Recursively iterate through dictionary and write "attrs" dictionaries. - This step is called after the groups and datasets have been created - Parameters ---------- dictionary : dict Input dictionary to be written to the hdf group group : hdf group location to store the attrs sections of the dictionary - """ if skip_keys is None: skip_keys = [] @@ -1114,12 +991,10 @@ def read_metadata_from_file(filename, metadata_key=None, lazy=False, verbose=False, skip_array_metadata=False): """Read the metadata from a nexus or hdf file. - This method iterates through the file and returns a dictionary of the entries. This is a convenience method to inspect a file for a value rather than loading the file as a signal. - Parameters ---------- filename : str @@ -1136,19 +1011,15 @@ def read_metadata_from_file(filename, metadata_key=None, Whether to skip loading array metadata. This is useful as a lot of large array may be present in the metadata and it is redundant with dataset itself. - Returns ------- dict Metadata dictionary. - See Also -------- * :py:meth:`~.io_plugins.nexus.file_reader` * :py:meth:`~.io_plugins.nexus.file_writer` * :py:meth:`~.io_plugins.nexus.list_datasets_in_file` - - """ search_keys = _check_search_keys(metadata_key) fin = h5py.File(filename, "r") @@ -1169,14 +1040,12 @@ def list_datasets_in_file(filename, dataset_key=None, hardlinks_only=False, verbose=True): """Read from a nexus or hdf file and return a list of the dataset paths. - This method is used to inspect the contents of a Nexus file. The method iterates through group attributes and returns NXdata or hdf datasets of size >=2 if they're not already NXdata blocks and returns a list of the entries. This is a convenience method to inspect a file to list datasets present rather than loading all the datasets in the file as signals. - Parameters ---------- filename : str @@ -1190,21 +1059,15 @@ def list_datasets_in_file(filename, dataset_key=None, If true any links (soft or External) will be ignored when loading. 
verbose : boolean, default : True Prints the results to screen - - Returns ------- list list of paths to datasets - - See Also -------- * :py:meth:`~.io_plugins.nexus.file_reader` * :py:meth:`~.io_plugins.nexus.file_writer` * :py:meth:`~.io_plugins.nexus.read_metadata_from_file` - - """ search_keys = _check_search_keys(dataset_key) fin = h5py.File(filename, "r") @@ -1231,7 +1094,6 @@ def list_datasets_in_file(filename, dataset_key=None, def _write_signal(signal, nxgroup, signal_name, **kwds): """Store the signal data as an NXdata dataset. - Parameters ---------- signal : Hyperspy signal @@ -1239,7 +1101,6 @@ def _write_signal(signal, nxgroup, signal_name, **kwds): Entry at which to save signal data signal_name : str Name under which to store the signal entry in the file - """ smd = signal.metadata.Signal if signal.axes_manager.signal_dimension == 1: @@ -1282,11 +1143,9 @@ def file_writer(filename, use_default=False, *args, **kwds): """Write the signal and metadata as a nexus file. - This will save the signal in NXdata format in the file. As the form of the metadata can vary and is not validated it will be stored as an NXcollection (an unvalidated collection) - Parameters ---------- filename : str @@ -1304,13 +1163,11 @@ def file_writer(filename, Option to define the default dataset in the file. If set to True the signal or first signal in the list of signals will be defined as the default (following Nexus v3 data rules). - See Also -------- * :py:meth:`~.io_plugins.nexus.file_reader` * :py:meth:`~.io_plugins.nexus.list_datasets_in_file` * :py:meth:`~.io_plugins.nexus.read_metadata_from_file` - """ if not isinstance(signals, list): signals = [signals] @@ -1322,12 +1179,60 @@ def file_writer(filename, kwds['compression'] = 'gzip' if use_default: f.attrs["default"] = "entry1" + # + # write the signals + # + + for i, sig in enumerate(signals): + nxentry = f.create_group("entry%d" % (i + 1)) + nxentry.attrs["NX_class"] = _parse_to_file("NXentry") - writer = NexusWriter(f, - signals, - save_original_metadata=True, - skip_metadata_key=None, - use_default=False, - *args, - **kwds) - writer.write_signal() + if isinstance(sig.metadata, dict): + sig.metadata = DictionaryTreeBrowser(sig.metadata) + if isinstance(sig.original_metadata, dict): + sig.original_metadata = DictionaryTreeBrowser( + sig.original_metadata) + + signal_name = sig.metadata.General.title \ + if sig.metadata.General.title else 'unnamed__%d' % i + if "/" in signal_name: + signal_name = signal_name.replace("/", "_") + if signal_name.startswith("__"): + signal_name = signal_name[2:] + + if i == 0 and use_default: + nxentry.attrs["default"] = signal_name + + nxaux = nxentry.create_group("auxiliary") + nxaux.attrs["NX_class"] = _parse_to_file("NXentry") + _write_signal(sig, nxentry, signal_name, **kwds) + + if sig.learning_results: + nxlearn = nxaux.create_group('learning_results') + nxlearn.attrs["NX_class"] = _parse_to_file("NXcollection") + learn = sig.learning_results.__dict__ + _write_nexus_groups(learn, nxlearn, **kwds) + _write_nexus_attr(learn, nxlearn) + # + # write metadata + # + if save_original_metadata: + if sig.original_metadata: + ometa = sig.original_metadata.as_dictionary() + + nxometa = nxaux.create_group('original_metadata') + nxometa.attrs["NX_class"] = _parse_to_file("NXcollection") + # write the groups and structure + _write_nexus_groups(ometa, nxometa, + skip_keys=skip_metadata_key, **kwds) + _write_nexus_attr(ometa, nxometa, + skip_keys=skip_metadata_key) + + if sig.metadata: + meta = sig.metadata.as_dictionary() 
+ + nxometa = nxaux.create_group('hyperspy_metadata') + nxometa.attrs["NX_class"] = _parse_to_file("NXcollection") + # write the groups and structure + _write_nexus_groups(meta, nxometa, **kwds) + _write_nexus_attr(meta, nxometa) \ No newline at end of file diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index 22f784a4f5..3ffb92c1c4 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -18,6 +18,7 @@ import warnings import logging +from functools import partial import zarr from zarr import Array, Group @@ -26,7 +27,8 @@ from hyperspy.io_plugins.hspy import version import numcodecs -from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader + +from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader, _overwrite_dataset _logger = logging.getLogger(__name__) @@ -72,12 +74,84 @@ # Experiments instance +def get_object_dset(group, data, key, chunks, **kwds): + """Overrides the hyperspy get object dset function for using zarr as the backend + """ + these_kwds = kwds.copy() + these_kwds.update(dict(dtype=object, + exact=True, + chunks=chunks)) + dset = group.require_dataset(key, + data.shape, + object_codec=numcodecs.VLenArray(int), + **these_kwds) + return dset + + +def _get_signal_chunks(shape, dtype, signal_axes=None): + """Function that calculates chunks for the signal, + preferably at least one chunk per signal space. + Parameters + ---------- + shape : tuple + the shape of the dataset to be sored / chunked + dtype : {dtype, string} + the numpy dtype of the data + signal_axes: {None, iterable of ints} + the axes defining "signal space" of the dataset. If None, the default + zarr chunking is performed. + """ + typesize = np.dtype(dtype).itemsize + if signal_axes is None: + return None + # chunk size larger than 1 Mb https://zarr.readthedocs.io/en/stable/tutorial.html#chunk-optimizations + # shooting for 100 Mb chunks + total_size = np.prod(shape) * typesize + if total_size < 1e8: # 1 mb + return None + + +def _store_data(data, + dset, + group, + key, + chunks, + **kwds): + """Overrides the hyperspy store data function for using zarr as the backend + """ + if isinstance(data, da.Array): + if data.chunks != dset.chunks: + data = data.rechunk(dset.chunks) + path = group._store.dir_path() + "/" + dset.path + data.to_zarr(url=path, + overwrite=True, + **kwds) # add in compression etc + elif data.dtype == np.dtype('O'): + group[key][:] = data[:] # check lazy + else: + path = group._store.dir_path() + "/" + dset.path + dset = zarr.open_array(path, + mode="w", + shape=data.shape, + dtype=data.dtype, + chunks=chunks, + **kwds) + dset[:] = data + + +overwrite_dataset = partial(_overwrite_dataset, + get_signal_chunks=_get_signal_chunks, + get_object_dset=get_object_dset, + store_data=_store_data) + + class ZspyReader(HierarchicalReader): def __init__(self, file): super(ZspyReader, self).__init__(file) self.Dataset = Array self.Group = Group + class ZspyWriter(HierarchicalWriter): def __init__(self, file, @@ -89,69 +163,7 @@ def __init__(self, self.ragged_kwds = {"dtype": object, "object_codec": numcodecs.VLenArray(int), "exact": True} - - def store_data(self, - data, - dset, - group, - key, - chunks, - **kwds): - """Overrides the hyperspy store data function for using zarr as the backend - """ - if isinstance(data, da.Array): - if data.chunks != dset.chunks: - data = data.rechunk(dset.chunks) - path = group._store.dir_path() + "/" + dset.path - data.to_zarr(url=path, - overwrite=True, - **kwds) # add in 
compression etc - elif data.dtype == np.dtype('O'): - group[key][:] = data[:] # check lazy - else: - path = group._store.dir_path() + "/" + dset.path - dset = zarr.open_array(path, - mode="w", - shape=data.shape, - dtype=data.dtype, - chunks=chunks, - **kwds) - dset[:] = data - - def get_object_dset(self, group,data_shape, key, chunks, **kwds): - """Overrides the hyperspy get object dset function for using zarr as the backend - """ - these_kwds = kwds.copy() - these_kwds.update(dict(dtype=object, - exact=True, - chunks=chunks)) - dset = group.require_dataset(key, - data_shape, - object_codec=numcodecs.VLenArray(int), - **these_kwds) - return dset - - def get_signal_chunks(self, shape, dtype, signal_axes=None): - """Function that calculates chunks for the signal, - preferably at least one chunk per signal space. - Parameters - ---------- - shape : tuple - the shape of the dataset to be sored / chunked - dtype : {dtype, string} - the numpy dtype of the data - signal_axes: {None, iterable of ints} - the axes defining "signal space" of the dataset. If None, the default - zarr chunking is performed. - """ - typesize = np.dtype(dtype).itemsize - if signal_axes is None: - return None - # chunk size larger than 1 Mb https://zarr.readthedocs.io/en/stable/tutorial.html#chunk-optimizations - # shooting for 100 Mb chunks - total_size = np.prod(shape) * typesize - if total_size < 1e8: # 1 mb - return None + self.overwrite_dataset = overwrite_dataset def file_writer(filename, signal, *args, **kwds): From 6021f3fe630987dca4310b379a76623d1668b781 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 16 Sep 2021 13:12:36 -0500 Subject: [PATCH 10/31] Added in the ability to define your storage container --- hyperspy/io_plugins/hierarchical.py | 3 --- hyperspy/io_plugins/zspy.py | 13 ++++++++++--- hyperspy/tests/io/test_zspy.py | 21 ++++++++++++++++++++- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/hyperspy/io_plugins/hierarchical.py b/hyperspy/io_plugins/hierarchical.py index 3b4bf131d3..5dd9b4fce8 100644 --- a/hyperspy/io_plugins/hierarchical.py +++ b/hyperspy/io_plugins/hierarchical.py @@ -150,13 +150,10 @@ def read(self, lazy): exp_dict_list = [] if 'Experiments' in self.file: for ds in self.file['Experiments']: - print(type(self.file['Experiments'][ds])) - print(self.Group) if isinstance(self.file['Experiments'][ds], self.Group): if 'data' in self.file['Experiments'][ds]: experiments.append(ds) # Parse the file - print(experiments) for experiment in experiments: exg = self.file['Experiments'][experiment] diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index 3ffb92c1c4..5381740165 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -19,6 +19,7 @@ import warnings import logging from functools import partial +from collections import MutableMapping import zarr from zarr import Array, Group @@ -166,7 +167,10 @@ def __init__(self, self.overwrite_dataset = overwrite_dataset -def file_writer(filename, signal, *args, **kwds): +def file_writer(filename, + signal, + *args, + **kwds): """Writes data to hyperspy's zarr format Parameters ---------- @@ -178,8 +182,11 @@ def file_writer(filename, signal, *args, **kwds): if "compressor" not in kwds: from numcodecs import Blosc kwds["compressor"] = Blosc(cname='zstd', clevel=1) - store = zarr.storage.NestedDirectoryStore(filename,) - f = zarr.group(store=store, overwrite=True) + if "write_to_storage" in kwds and kwds["write_to_storage"]: + f = zarr.open(filename) + else: + store = 
zarr.storage.NestedDirectoryStore(filename,) + f = zarr.group(store=store, overwrite=True) f.attrs['file_format'] = "ZSpy" f.attrs['file_format_version'] = version exps = f.create_group('Experiments') diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py index 907fd1aec2..7325c4b2d9 100644 --- a/hyperspy/tests/io/test_zspy.py +++ b/hyperspy/tests/io/test_zspy.py @@ -25,6 +25,7 @@ import dask.array as da import h5py +import lmdb import numpy as np import pytest import zarr @@ -686,4 +687,22 @@ def test_save_load_model(self, signal): signal.save(filename) signal2 = hs.load(filename) m2 = signal2.models.restore("test") - assert m.signal == m2.signal \ No newline at end of file + assert m.signal == m2.signal + + def test_save_N5_type(self,signal): + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/testmodels.zspy' + store = zarr.N5Store(path=filename) + signal.save(store.path, write_to_storage=True) + signal2 = hs.load(filename) + np.testing.assert_array_equal(signal2.data, signal.data) + + @pytest.mark.skip(reason="lmdb must be installed to test") + def test_save_lmdb_type(self, signal): + with tempfile.TemporaryDirectory() as tmp: + os.mkdir(tmp+"/testmodels.zspy") + filename = tmp + '/testmodels.zspy/' + store = zarr.LMDBStore(path=filename) + signal.save(store.path, write_to_storage=True) + signal2 = hs.load(store.path) + np.testing.assert_array_equal(signal2.data, signal.data) \ No newline at end of file From db105d5dbe62d7f0dfac5decad0db66480112087 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 16 Sep 2021 13:40:15 -0500 Subject: [PATCH 11/31] Updated Documentation including zarr format --- doc/user_guide/io.rst | 60 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/doc/user_guide/io.rst b/doc/user_guide/io.rst index 59cf4417ae..c2a8c5d384 100644 --- a/doc/user_guide/io.rst +++ b/doc/user_guide/io.rst @@ -251,6 +251,8 @@ HyperSpy. The "lazy" column specifies if lazy evaluation is supported. +-----------------------------------+--------+--------+--------+ | hspy | Yes | Yes | Yes | +-----------------------------------+--------+--------+--------+ + | zspy | Yes | Yes | Yes | + +-----------------------------------+--------+--------+--------+ | Image: e.g. jpg, png, tif, ... | Yes | Yes | Yes | +-----------------------------------+--------+--------+--------+ | TIFF | Yes | Yes | Yes | @@ -418,6 +420,64 @@ Extra saving arguments saving a file, be aware that it may not be possible to load it in some platforms. +.. _zspy-format: + +ZSpy - HyperSpy's Zarr Specification +------------------------------------ + +This is the an additional format which guarantees that no +information will be lost in the writing process and that supports saving data +of arbitrary dimensions. It is based on the `Zarr project `_. Which exists as a drop in +replacement for hdf5 with the intention to fix some of the speed and scaling +issues with the hdf5 format. **If you are working with very large datasets lazily +we recommend using .zspy for saving and loading your data** + +.. code-block:: python + + >>> s = hs.signals.BaseSignal([0]) + >>> s.save('test.zspy') # will save in nested directory + >>> hs.load('test.zspy') # loads the directory + + +When saving to ``zspy``, all supported objects in the signal's +:py:attr:`~.signal.BaseSignal.metadata` is stored. This includes lists, tuples and signals. 
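As an editorial aside (not part of the patch itself): the metadata round trip that the added documentation describes can be sketched with the public HyperSpy API that the new tests also exercise. This is a minimal illustration only; the metadata item names under ``Acquisition`` are invented for the example, and it assumes HyperSpy is installed with this zspy plugin available.

.. code-block:: python

    import numpy as np
    import hyperspy.api as hs

    # build a small signal and attach nested metadata of the kinds the
    # zspy writer has to serialize (lists, tuples, numbers, strings)
    s = hs.signals.Signal1D(np.arange(10.0))
    s.metadata.set_item('Acquisition.notes', ['scan 1', 'scan 2'])
    s.metadata.set_item('Acquisition.binning', (2, 2))

    # a .zspy path is written as a nested directory store on disk
    s.save('test.zspy', overwrite=True)

    # loading the directory restores the signal and its metadata tree
    s2 = hs.load('test.zspy')
    print(s2.metadata.Acquisition.notes)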
+
+Please note that in order to increase saving efficiency and speed, if possible,
+the inner-most structures are converted to numpy arrays when saved. This
+procedure homogenizes any types of the objects inside, most notably casting
+numbers as strings if any other strings are present:
+
+Extra saving arguments
+^^^^^^^^^^^^^^^^^^^^^^
+
+- ``compressor``: A `Numcodecs Codec `_.
+  A compressor can be passed to the save function to compress the data efficiently. The default
+  is to use a Blosc compressor object.
+
+.. code-block:: python
+
+    >>> from numcodecs import Blosc
+    >>> compressor = Blosc(cname='zstd', clevel=1, shuffle=Blosc.SHUFFLE)
+    >>> s.save('test.zspy', compressor=compressor) # will save with Blosc compression
+
+.. note::
+
+    Compression can significantly increase the saving speed. If file size is not
+    an issue, it can be disabled by setting ``compressor=None``. In general we recommend
+    compressing your datasets as it can greatly reduce I/O overhead.
+
+- ``write_to_storage``: The write to storage option allows you to pass the path to a directory (or database)
+  and write directly to the storage container. This gives you access to the `different storage methods
+  `_
+  available through zarr, for example an SQL, MongoDB or LMDB database. Additional downloads may need
+  to be configured to use these features.
+
+.. code-block:: python
+
+    >>> filename = 'test.zspy/'
+    >>> os.mkdir('test.zspy')
+    >>> store = zarr.LMDBStore(path=filename)
+    >>> signal.save(store.path, write_to_storage=True) # saved to LMDB
+
 .. _netcdf-format:
 
 NetCDF

From fc2302da169e922820c6491f723d74302a3732e4 Mon Sep 17 00:00:00 2001
From: shaw
Date: Thu, 16 Sep 2021 13:50:22 -0500
Subject: [PATCH 12/31] Updated Big data documentation

---
 doc/user_guide/big_data.rst | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/doc/user_guide/big_data.rst b/doc/user_guide/big_data.rst
index 768aa91fed..f44efbf5d2 100644
--- a/doc/user_guide/big_data.rst
+++ b/doc/user_guide/big_data.rst
@@ -397,6 +397,14 @@ Other minor differences
    convenience, ``nansum``, ``nanmean`` and other ``nan*`` signal methods were
    added to mimic the workflow as closely as possible.
 
+Saving Big Data
+^^^^^^^^^^^^^^^^^
+
+* **Zspy Format** When saving big data the hdf5 format struggles to compress,
+  save, load and operate on datasets > 50 GB. A better option is to use the
+  :ref:`zspy-format`.
+
+This also allows for smooth interaction with dask-distributed for efficient scaling.
 
 .. _lazy_details:

From f766284f0311b400439ba12c35af6cb6d4f4fad5 Mon Sep 17 00:00:00 2001
From: shaw
Date: Thu, 16 Sep 2021 13:53:58 -0500
Subject: [PATCH 13/31] Added in change description

---
 upcoming_changes/2798.new.rst | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 upcoming_changes/2798.new.rst

diff --git a/upcoming_changes/2798.new.rst b/upcoming_changes/2798.new.rst
new file mode 100644
index 0000000000..a3f3344930
--- /dev/null
+++ b/upcoming_changes/2798.new.rst
@@ -0,0 +1,2 @@
+Add the `zspy` saving specification for saving large data. This helps with saving and loading large datasets.
+Using `zarr` background \ No newline at end of file From 7fe289a842788f4e47c09da705f5fe4fa317e574 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 16 Sep 2021 14:02:41 -0500 Subject: [PATCH 14/31] Added Zarr as a requirement --- conda_environment.yml | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/conda_environment.yml b/conda_environment.yml index 0ef9dc73fc..4abc06a009 100644 --- a/conda_environment.yml +++ b/conda_environment.yml @@ -28,5 +28,6 @@ dependencies: - toolz - tqdm - traits +- zarr diff --git a/setup.py b/setup.py index 309eb2d43b..238fc6c5af 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,7 @@ # included in stdlib since v3.8, but this required version requires Python 3.10 # We can remove this requirement when the minimum supported version becomes Python 3.10 'importlib_metadata>=3.6', + 'zarr' ] extras_require = { From 57d96f17d7e90d9fca341ea02d92c1a4fae78340 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 16 Sep 2021 14:25:34 -0500 Subject: [PATCH 15/31] removed unused imports --- hyperspy/tests/io/test_zspy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py index 7325c4b2d9..0f73194b46 100644 --- a/hyperspy/tests/io/test_zspy.py +++ b/hyperspy/tests/io/test_zspy.py @@ -24,8 +24,6 @@ from os import remove import dask.array as da -import h5py -import lmdb import numpy as np import pytest import zarr From f0a2144a16a2e9f8bf83159a6b832324e18dc396 Mon Sep 17 00:00:00 2001 From: Carter Francis Date: Tue, 5 Oct 2021 09:47:19 -0500 Subject: [PATCH 16/31] Change wording to be more formal Co-authored-by: Eric Prestat --- doc/user_guide/big_data.rst | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/doc/user_guide/big_data.rst b/doc/user_guide/big_data.rst index f44efbf5d2..ffc98d49a0 100644 --- a/doc/user_guide/big_data.rst +++ b/doc/user_guide/big_data.rst @@ -400,9 +400,7 @@ Other minor differences Saving Big Data ^^^^^^^^^^^^^^^^^ -* **Zspy Format** When saving big data the hdf5 format struggles to compress, -save, load and operate on datasets > 50 GB. A better option is to use the -:ref:`zspy-format`. +The most efficient format supported by HyperSpy to write data is the :ref:` zspy format `, mainly because it supports writing currently from concurrently from multiple threads or processes. This also allows for smooth interaction with dask-distributed for efficient scaling. From 4132db5ece464c72ae36a078345224a7518c82d8 Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 5 Oct 2021 12:21:27 -0500 Subject: [PATCH 17/31] Cleaning up after review --- doc/user_guide/big_data.rst | 5 +- doc/user_guide/io.rst | 8 +- hyperspy/io.py | 9 +- hyperspy/io_plugins/hierarchical.py | 74 ++-- hyperspy/io_plugins/hspy.py | 8 +- hyperspy/io_plugins/nexus.py | 78 ++++ hyperspy/io_plugins/spy.py | 0 hyperspy/io_plugins/zspy.py | 7 +- hyperspy/tests/io/test_hdf5.py | 342 +++++++++------ hyperspy/tests/io/test_zspy.py | 641 +--------------------------- upcoming_changes/2798.new.rst | 3 +- 11 files changed, 375 insertions(+), 800 deletions(-) delete mode 100644 hyperspy/io_plugins/spy.py diff --git a/doc/user_guide/big_data.rst b/doc/user_guide/big_data.rst index ffc98d49a0..2da66f9e12 100644 --- a/doc/user_guide/big_data.rst +++ b/doc/user_guide/big_data.rst @@ -397,10 +397,13 @@ Other minor differences convenience, ``nansum``, ``nanmean`` and other ``nan*`` signal methods were added to mimic the workflow as closely as possible. +.. 
_big_data.saving:
+
 
 Saving Big Data
 ^^^^^^^^^^^^^^^^^
 
-The most efficient format supported by HyperSpy to write data is the :ref:` zspy format `, mainly because it supports writing currently from concurrently from multiple threads or processes.
+The most efficient format supported by HyperSpy to write data is the :ref:` zspy format `,
+mainly because it supports writing concurrently from multiple threads or processes.
 
 This also allows for smooth interaction with dask-distributed for efficient scaling.
 
diff --git a/doc/user_guide/io.rst b/doc/user_guide/io.rst
index c2a8c5d384..e7eb82aed4 100644
--- a/doc/user_guide/io.rst
+++ b/doc/user_guide/io.rst
@@ -425,12 +425,12 @@ Extra saving arguments
 ZSpy - HyperSpy's Zarr Specification
 ------------------------------------
 
-This is the an additional format which guarantees that no
+Similarly to the :ref:`hspy format `, the zspy format guarantees that no
 information will be lost in the writing process and that supports saving data
 of arbitrary dimensions. It is based on the `Zarr project `_. Which exists as a drop in
 replacement for hdf5 with the intention to fix some of the speed and scaling
-issues with the hdf5 format. **If you are working with very large datasets lazily
-we recommend using .zspy for saving and loading your data**
+issues with the hdf5 format and is therefore suitable for saving :ref:`big data `.
+
 
 .. code-block:: python
 
@@ -439,7 +439,7 @@ we recommend using .zspy for saving and loading your data**
     >>> hs.load('test.zspy') # loads the directory
 
 
-When saving to ``zspy``, all supported objects in the signal's
+When saving to `zspy `_, all supported objects in the signal's
 :py:attr:`~.signal.BaseSignal.metadata` is stored. This includes lists, tuples and signals.
diff --git a/hyperspy/io.py b/hyperspy/io.py
index f47f60bc45..c5aa5db287 100644
--- a/hyperspy/io.py
+++ b/hyperspy/io.py
@@ -334,7 +334,8 @@ def load(filenames=None,
             filenames = _escape_square_brackets(filenames)
 
         filenames = natsorted([f for f in glob.glob(filenames)
-                               if os.path.isfile(f) or os.path.isdir(f)])
+                               if os.path.isfile(f) or (os.path.isdir(f) and
+                                                        os.path.splitext(f)[1] == '.zspy')])
 
         if not filenames:
             raise ValueError(f'No filename matches the pattern "{pattern}"')
@@ -342,7 +343,8 @@ def load(filenames=None,
     elif isinstance(filenames, Path):
         # Just convert to list for now, pathlib.Path not
         # fully supported in io_plugins
-        filenames = [f for f in [filenames] if f.is_file() or f.is_dir()]
+        filenames = [f for f in [filenames]
+                     if f.is_file() or (f.is_dir() and ".zspy" in f.name)]
 
     elif isgenerator(filenames):
         filenames = list(filenames)
@@ -449,7 +451,8 @@ def load_single_file(filename, **kwds):
         Data loaded from the file.
 
     """
-    if not os.path.isfile(filename) and not os.path.isdir(filename):
+    if not os.path.isfile(filename) and not (os.path.isdir(filename) or
+                                             os.path.splitext(filename)[1] == '.zspy'):
        raise FileNotFoundError(f"File: {filename} not found!")
 
     # File extension without "." 
separator diff --git a/hyperspy/io_plugins/hierarchical.py b/hyperspy/io_plugins/hierarchical.py index 5dd9b4fce8..8351700a52 100644 --- a/hyperspy/io_plugins/hierarchical.py +++ b/hyperspy/io_plugins/hierarchical.py @@ -1,4 +1,4 @@ -from distutils.version import LooseVersion +from packaging.version import Version import warnings import logging import datetime @@ -12,8 +12,10 @@ from hyperspy.misc.utils import ensure_unicode, multiply, get_object_package_info from hyperspy.axes import AxesManager + version = "3.1" -default_version = LooseVersion(version) + +default_version = Version(version) not_valid_format = 'The file is not a valid HyperSpy hdf5 file' @@ -21,15 +23,15 @@ def _overwrite_dataset(group, - data, - key, - signal_axes=None, - chunks=None, - get_signal_chunks=None, - get_object_dset=None, - store_data=None, - **kwds): - """Overwrites some dataset in a hirarchical dataset. + data, + key, + signal_axes=None, + chunks=None, + get_signal_chunks=None, + get_object_dset=None, + store_data=None, + **kwds): + """Overwrites some dataset in a hierarchical dataset. Parameters ---------- @@ -45,10 +47,13 @@ def _overwrite_dataset(group, The chunks for the dataset. If None then get_signal_chunks will be called get_signal_chunks: func A function to get the signal chunks for the dataset - get_object_dset: + get_object_dset: func A function to get the object dset for saving ragged arrays. + store_data: func + A function to store the data in some hierarchical data format kwds: - Any additional keywords + Any additional keywords for to be passed to the store data function + The store data function is passed for each hierarchical data format Returns ------- @@ -100,6 +105,13 @@ class HierarchicalReader: """A generic Reader class for reading data from hierarchical file types.""" def __init__(self, file): + """ Initalizes a general reader for hierarchical signals. + + Parmeters + --------- + file: str + A file to be read. 
+ """ self.file = file self.version = self.get_format_version() self.Dataset = None @@ -107,7 +119,6 @@ def __init__(self, self.unicode_kwds = None self.ragged_kwds = None - def get_format_version(self): if "file_format_version" in self.file.attrs: version = self.file.attrs["file_format_version"] @@ -123,7 +134,7 @@ def get_format_version(self): version = "2.0" else: raise IOError(not_valid_format) - return LooseVersion(version) + return Version(version) def read(self, lazy): models_with_signals = [] @@ -182,7 +193,7 @@ def read(self, lazy): def group2signaldict(self, group, lazy=False): - if self.version < LooseVersion("1.2"): + if self.version < Version("1.2"): metadata = "mapped_parameters" original_metadata = "original_parameters" else: @@ -251,7 +262,7 @@ def group2signaldict(self, if '__unnamed__' == exp['metadata']['General']['title']: exp['metadata']["General"]['title'] = '' - if self.version < LooseVersion("1.1"): + if self.version < Version("1.1"): # Load the decomposition results written with the old name, # mva_results if 'mva_results' in group.keys(): @@ -275,7 +286,7 @@ def group2signaldict(self, exp['metadata']['name'] del exp['metadata']['name'] - if self.version < LooseVersion("1.2"): + if self.version < Version("1.2"): if '_internal_parameters' in exp['metadata']: exp['metadata']['_HyperSpy'] = \ exp['metadata']['_internal_parameters'] @@ -361,7 +372,7 @@ def group2signaldict(self, exp["metadata"]["Signal"][key] = exp["metadata"][key] del exp["metadata"][key] - if self.version < LooseVersion("3.0"): + if self.version < Version("3.0"): if "Acquisition_instrument" in exp["metadata"]: # Move tilt_stage to Stage.tilt_alpha # Move exposure time to Detector.Camera.exposure_time @@ -504,11 +515,24 @@ class HierarchicalWriter: def __init__(self, file, signal, - expg, + group, **kwds): + """Initialize a generic file writer for hierachical data storage types. + + Parameters + ---------- + file: str + The file where the signal is to be saved + signal: BaseSignal + A BaseSignal to be saved + group: Group + A group to where the experimental data will be saved. + kwds: + Any additional keywords used for saving the data. 
+ """ self.file = file self.signal = signal - self.expg = expg + self.group = group self.Dataset = None self.Group = None self.unicode_kwds = None @@ -518,14 +542,14 @@ def __init__(self, def write(self): self.write_signal(self.signal, - self.expg, + self.group, **self.kwds) def write_signal(self, signal, group, **kwds): "Writes a hyperspy signal to a hdf5 group" group.attrs.update(get_object_package_info(signal)) - if default_version < LooseVersion("1.2"): + if Version(version) < Version("1.2"): metadata = "mapped_parameters" original_metadata = "original_parameters" else: @@ -546,7 +570,7 @@ def write_signal(self, signal, group, **kwds): 'data', signal_axes=signal.axes_manager.signal_indices_in_array, **kwds) - if default_version < LooseVersion("1.2"): + if default_version < Version("1.2"): metadata_dict["_internal_parameters"] = \ metadata_dict.pop("_HyperSpy") # Remove chunks from the kwds since it wouldn't have the same rank as the @@ -648,8 +672,6 @@ def parse_structure(self, key, group, value, _type, **kwds): elif tmp.dtype.type is np.unicode_: if _type + key in group: del group[_type + key] - print(self.unicode_kwds) - print(kwds) group.create_dataset(_type + key, shape=tmp.shape, **self.unicode_kwds, diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index 1c1ac57ad5..802fef3dbc 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -25,7 +25,7 @@ import dask.array as da from h5py import Dataset, File, Group -from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader, _overwrite_dataset +from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader, _overwrite_dataset, version from hyperspy.misc.utils import multiply _logger = logging.getLogger(__name__) @@ -43,7 +43,7 @@ # Writing capabilities writes = True non_uniform_axis = True -version = "3.1" +version = version # ---------------------- # ----------------------- @@ -197,7 +197,7 @@ def __init__(self, file): super().__init__(file) self.Dataset = Dataset self.Group = Group - self.unicode_kwds = {"dtype":h5py.special_dtype(vlen=str)} + self.unicode_kwds = {"dtype": h5py.special_dtype(vlen=str)} class HyperspyWriter(HierarchicalWriter): @@ -243,7 +243,7 @@ def file_reader( # hdf5 file, so the following two lines must not be deleted or moved # elsewhere. reader = HyperspyReader(f) - if reader.version > version: + if reader.version > Version(version): warnings.warn( "This file was written using a newer version of the " "HyperSpy hdf5 file format. I will attempt to load it, but, " diff --git a/hyperspy/io_plugins/nexus.py b/hyperspy/io_plugins/nexus.py index 2f7b801e75..2c3095bb65 100644 --- a/hyperspy/io_plugins/nexus.py +++ b/hyperspy/io_plugins/nexus.py @@ -48,6 +48,7 @@ def _byte_to_string(value): """Decode a byte string. + Parameters ---------- value : byte str @@ -55,6 +56,7 @@ def _byte_to_string(value): ------- str decoded version of input value + """ try: text = value.decode("utf-8") @@ -65,20 +67,24 @@ def _byte_to_string(value): def _parse_from_file(value, lazy=False): """To convert values from the hdf file to compatible formats. + When reading string arrays we convert or keep string arrays as byte_strings (some io_plugins only supports byte-strings arrays so this ensures inter-compatibility across io_plugins) Arrays of length 1 - return the single value stored. Large datasets are returned as dask arrays if lazy=True. 
+ Parameters ---------- value : input read from hdf file (array,list,tuple,string,int,float) lazy : bool {default: False} The lazy flag is only applied to values of size >=2 + Returns ------- str,int, float, ndarray dask Array parsed value. + """ toreturn = value if isinstance(value, h5py.Dataset): @@ -107,14 +113,18 @@ def _parse_from_file(value, lazy=False): def _parse_to_file(value): """Convert to a suitable format for writing to HDF5. + For example unicode values are not compatible with hdf5 so conversion to byte strings is required. + Parameters ---------- value - input object to write to the hdf file + Returns ------- parsed value + """ totest = value toreturn = totest @@ -134,14 +144,17 @@ def _parse_to_file(value): def _text_split(s, sep): """Split a string based of list of seperators. + Parameters ---------- s : str sep : str - seperator or list of seperators e.g. '.' or ['_','/'] + Returns ------- list String sections split based on the seperators + """ stack = [s] for char in sep: @@ -156,13 +169,16 @@ def _text_split(s, sep): def _getlink(h5group, rootkey, key): """Return the link target path. + If a hdf group is a soft link or has a target attribute this method will return the target path. If no link is found return None. + Returns ------- str Soft link path if it exists, otherwise None + """ _target = None if rootkey != '/': @@ -182,16 +198,19 @@ def _getlink(h5group, rootkey, key): def _get_nav_list(data, dataentry): """Get the list with information of each axes of the dataset + Parameters ---------- data : hdf dataset the dataset to be loaded. dataentry : hdf group the group with corresponding attributes. + Returns ------- nav_list : list contains information about each axes. + """ detector_index = 0 @@ -277,6 +296,7 @@ def _get_nav_list(data, dataentry): def _extract_hdf_dataset(group, dataset, lazy=False): """Import data from hdf path. + Parameters ---------- group : hdf group @@ -285,10 +305,12 @@ def _extract_hdf_dataset(group, dataset, lazy=False): path to the dataset within the group lazy : bool {default:True} If true use lazy opening, if false read into memory + Returns ------- dict A signal dictionary which can be used to instantiate a signal. + """ data = group[dataset] @@ -323,6 +345,7 @@ def _extract_hdf_dataset(group, dataset, lazy=False): def _nexus_dataset_to_signal(group, nexus_dataset_path, lazy=False): """Load an NXdata set as a hyperspy signal. + Parameters ---------- group : hdf group containing the NXdata @@ -330,10 +353,12 @@ def _nexus_dataset_to_signal(group, nexus_dataset_path, lazy=False): Path to the NXdata set in the group lazy : bool, default : True lazy loading of data + Returns ------- dict A signal dictionary which can be used to instantiate a signal. + """ interpretation = None @@ -409,6 +434,7 @@ def file_reader(filename, lazy=False, dataset_key=None, dataset_path=None, use_default=False, **kwds): """Read NXdata class or hdf datasets from a file and return signal(s). + Note ---- Loading all datasets can result in a large number of signals @@ -417,6 +443,7 @@ def file_reader(filename, lazy=False, dataset_key=None, dataset_path=None, "keys" is a special keywords and prepended with "fix" in the metadata structure to avoid any issues. Datasets are all arrays with size>2 (arrays, lists) + Parameters ---------- filename : str @@ -452,13 +479,18 @@ def file_reader(filename, lazy=False, dataset_key=None, dataset_path=None, signal. This will ignore the other keyword options. 
If True and no default is defined the file will be loaded according to the keyword options. + Returns ------- dict : signal dictionary or list of signal dictionaries + + See Also -------- * :py:meth:`~.io_plugins.nexus.list_datasets_in_file` * :py:meth:`~.io_plugins.nexus.read_metadata_from_file` + + """ # search for NXdata sets... @@ -591,13 +623,16 @@ def file_reader(filename, lazy=False, dataset_key=None, dataset_path=None, def _is_linear_axis(data): """Check if the data is linearly incrementing. + Parameters ---------- data : dask or numpy array + Returns ------- bool True or False + """ steps = np.diff(data) est_steps = np.array([steps[0]]*len(steps)) @@ -606,13 +641,16 @@ def _is_linear_axis(data): def _is_numeric_data(data): """Check that data contains numeric data. + Parameters ---------- data : dask or numpy array + Returns ------- bool True or False + """ try: data.astype(float) @@ -623,13 +661,16 @@ def _is_numeric_data(data): def _is_int(s): """Check that s in an integer. + Parameters ---------- s : python object to test + Returns ------- bool True or False + """ try: int(s) @@ -656,6 +697,7 @@ def _check_search_keys(search_keys): def _find_data(group, search_keys=None, hardlinks_only=False, absolute_path=None): """Read from a nexus or hdf file and return a list of the dataset entries. + The method iterates through group attributes and returns NXdata or hdf datasets of size >=2 if they're not already NXdata blocks and returns a list of the entries @@ -664,6 +706,8 @@ def _find_data(group, search_keys=None, hardlinks_only=False, h5py.visit or visititems does not visit soft links or external links so an implementation of a recursive search is required. See https://github.com/h5py/h5py/issues/671 + + Parameters ---------- group : hdf group or File @@ -675,12 +719,14 @@ def _find_data(group, search_keys=None, hardlinks_only=False, Option to ignore links (soft or External) within the file. absolute_path : string, list of strings or None, default: None Return items with the exact specified absolute path + Returns ------- nx_dataset_list, hdf_dataset_list nx_dataset_list is a list of all NXdata paths hdf_dataset_list is a list of all hdf_datasets not linked to an NXdata set. + """ _check_search_keys(search_keys) _check_search_keys(absolute_path) @@ -762,9 +808,11 @@ def find_data_in_tree(group, rootname): def _load_metadata(group, lazy=False, skip_array_metadata=False): """Search through a hdf group and return the group structure. + h5py.visit or visititems does not visit soft links or external links so an implementation of a recursive search is required. See https://github.com/h5py/h5py/issues/671 + Parameters ---------- group : hdf group @@ -773,10 +821,13 @@ def _load_metadata(group, lazy=False, skip_array_metadata=False): Option for lazy loading skip_array_metadata : bool, default : False whether to skip loading array metadata + Returns ------- dict dictionary of group contents + + """ rootname = "" @@ -841,16 +892,20 @@ def find_meta_in_tree(group, rootname, lazy=False, def _fix_exclusion_keys(key): """Exclude hyperspy specific keys. + Signal and DictionaryBrowser break if a a key is a dict method - e.g. {"keys":2.0}. This method prepends the key with ``fix_`` so the information is still present to work around this issue + Parameters ---------- key : str + Returns ------- str + """ if key.startswith("keys"): return "fix_"+key @@ -860,8 +915,10 @@ def _fix_exclusion_keys(key): def _find_search_keys_in_dict(tree, search_keys=None): """Search through a dict for search keys. 
+ This is a convenience method to inspect a file for a value rather than loading the file as a signal + Parameters ---------- tree : h5py File object @@ -869,11 +926,13 @@ def _find_search_keys_in_dict(tree, search_keys=None): Only return items which contain the strings .e.g search_keys = ["instrument","Fe"] will return hdf entries with instrument or Fe in their hdf path. + Returns ------- dict When search_list is specified only full paths containing one or more search_keys will be returned + """ _check_search_keys(search_keys) metadata_dict = {} @@ -906,6 +965,7 @@ def find_searchkeys_in_tree(myDict, rootname): def _write_nexus_groups(dictionary, group, skip_keys=None, **kwds): """Recursively iterate throuh dictionary and write groups to nexus. + Parameters ---------- dictionary : dict @@ -916,6 +976,7 @@ def _write_nexus_groups(dictionary, group, skip_keys=None, **kwds): the key(s) to skip when writing into the group **kwds : additional keywords additional keywords to pass to h5py.create_dataset method + """ if skip_keys is None: skip_keys = [] @@ -958,13 +1019,16 @@ def _write_nexus_groups(dictionary, group, skip_keys=None, **kwds): def _write_nexus_attr(dictionary, group, skip_keys=None): """Recursively iterate through dictionary and write "attrs" dictionaries. + This step is called after the groups and datasets have been created + Parameters ---------- dictionary : dict Input dictionary to be written to the hdf group group : hdf group location to store the attrs sections of the dictionary + """ if skip_keys is None: skip_keys = [] @@ -991,10 +1055,12 @@ def read_metadata_from_file(filename, metadata_key=None, lazy=False, verbose=False, skip_array_metadata=False): """Read the metadata from a nexus or hdf file. + This method iterates through the file and returns a dictionary of the entries. This is a convenience method to inspect a file for a value rather than loading the file as a signal. + Parameters ---------- filename : str @@ -1011,10 +1077,12 @@ def read_metadata_from_file(filename, metadata_key=None, Whether to skip loading array metadata. This is useful as a lot of large array may be present in the metadata and it is redundant with dataset itself. + Returns ------- dict Metadata dictionary. + See Also -------- * :py:meth:`~.io_plugins.nexus.file_reader` @@ -1040,12 +1108,14 @@ def list_datasets_in_file(filename, dataset_key=None, hardlinks_only=False, verbose=True): """Read from a nexus or hdf file and return a list of the dataset paths. + This method is used to inspect the contents of a Nexus file. The method iterates through group attributes and returns NXdata or hdf datasets of size >=2 if they're not already NXdata blocks and returns a list of the entries. This is a convenience method to inspect a file to list datasets present rather than loading all the datasets in the file as signals. + Parameters ---------- filename : str @@ -1059,10 +1129,12 @@ def list_datasets_in_file(filename, dataset_key=None, If true any links (soft or External) will be ignored when loading. verbose : boolean, default : True Prints the results to screen + Returns ------- list list of paths to datasets + See Also -------- * :py:meth:`~.io_plugins.nexus.file_reader` @@ -1094,6 +1166,7 @@ def list_datasets_in_file(filename, dataset_key=None, def _write_signal(signal, nxgroup, signal_name, **kwds): """Store the signal data as an NXdata dataset. 
+ Parameters ---------- signal : Hyperspy signal @@ -1101,6 +1174,7 @@ def _write_signal(signal, nxgroup, signal_name, **kwds): Entry at which to save signal data signal_name : str Name under which to store the signal entry in the file + """ smd = signal.metadata.Signal if signal.axes_manager.signal_dimension == 1: @@ -1143,9 +1217,11 @@ def file_writer(filename, use_default=False, *args, **kwds): """Write the signal and metadata as a nexus file. + This will save the signal in NXdata format in the file. As the form of the metadata can vary and is not validated it will be stored as an NXcollection (an unvalidated collection) + Parameters ---------- filename : str @@ -1163,11 +1239,13 @@ def file_writer(filename, Option to define the default dataset in the file. If set to True the signal or first signal in the list of signals will be defined as the default (following Nexus v3 data rules). + See Also -------- * :py:meth:`~.io_plugins.nexus.file_reader` * :py:meth:`~.io_plugins.nexus.list_datasets_in_file` * :py:meth:`~.io_plugins.nexus.read_metadata_from_file` + """ if not isinstance(signals, list): signals = [signals] diff --git a/hyperspy/io_plugins/spy.py b/hyperspy/io_plugins/spy.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index 5381740165..e2ccc19c2b 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -19,16 +19,17 @@ import warnings import logging from functools import partial -from collections import MutableMapping +from packaging.version import Version import zarr from zarr import Array, Group import numpy as np import dask.array as da -from hyperspy.io_plugins.hspy import version +from hyperspy.io_plugins.hierarchical import version import numcodecs + from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader, _overwrite_dataset _logger = logging.getLogger(__name__) @@ -228,7 +229,7 @@ def file_reader(filename, mode = kwds.pop('mode', 'r') f = zarr.open(filename, mode=mode, **kwds) reader = ZspyReader(f) - if reader.version > version: + if reader.version > Version(version): warnings.warn( "This file was written using a newer version of the " "HyperSpy zspy file format. 
I will attempt to load it, but, " diff --git a/hyperspy/tests/io/test_hdf5.py b/hyperspy/tests/io/test_hdf5.py index a2587e810e..7f6e6ba082 100644 --- a/hyperspy/tests/io/test_hdf5.py +++ b/hyperspy/tests/io/test_hdf5.py @@ -29,7 +29,7 @@ import pytest from hyperspy.io import load -from hyperspy.axes import DataAxis, UniformDataAxis, FunctionalDataAxis +from hyperspy.axes import DataAxis, UniformDataAxis, FunctionalDataAxis, AxesManager from hyperspy.signal import BaseSignal from hyperspy._signals.signal1d import Signal1D from hyperspy._signals.signal2d import Signal2D @@ -188,15 +188,17 @@ def test_binary_string(self): assert f(3.5) == 4.5 + class TestSavingMetadataContainers: def setup_method(self, method): self.s = BaseSignal([0.1]) - def test_save_unicode(self, tmp_path): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_unicode(self, tmp_path, file): s = self.s s.metadata.set_item('test', ['a', 'b', '\u6f22\u5b57']) - fname = tmp_path / 'test.hspy' + fname = tmp_path / file s.save(fname) l = load(fname) assert isinstance(l.metadata.test[0], str) @@ -204,19 +206,21 @@ def test_save_unicode(self, tmp_path): assert isinstance(l.metadata.test[2], str) assert l.metadata.test[2] == '\u6f22\u5b57' - def test_save_long_list(self, tmp_path): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_long_list(self, tmp_path, file): s = self.s s.metadata.set_item('long_list', list(range(10000))) start = time.time() - fname = tmp_path / 'test.hspy' + fname = tmp_path / file s.save(fname) end = time.time() assert end - start < 1.0 # It should finish in less that 1 s. - def test_numpy_only_inner_lists(self, tmp_path): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_numpy_only_inner_lists(self, tmp_path, file): s = self.s s.metadata.set_item('test', [[1., 2], ('3', 4)]) - fname = tmp_path / 'test.hspy' + fname = tmp_path / file s.save(fname) l = load(fname) assert isinstance(l.metadata.test, list) @@ -225,20 +229,22 @@ def test_numpy_only_inner_lists(self, tmp_path): @pytest.mark.xfail(sys.platform == 'win32', reason="randomly fails in win32") - def test_numpy_general_type(self, tmp_path): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_numpy_general_type(self, tmp_path, file): s = self.s s.metadata.set_item('test', np.array([[1., 2], ['3', 4]])) - fname = tmp_path / 'test.hspy' + fname = tmp_path / file s.save(fname) l = load(fname) np.testing.assert_array_equal(l.metadata.test, s.metadata.test) @pytest.mark.xfail(sys.platform == 'win32', reason="randomly fails in win32") - def test_list_general_type(self, tmp_path): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_list_general_type(self, tmp_path, file): s = self.s s.metadata.set_item('test', [[1., 2], ['3', 4]]) - fname = tmp_path / 'test.hspy' + fname = tmp_path / file s.save(fname) l = load(fname) assert isinstance(l.metadata.test[0][0], float) @@ -248,10 +254,11 @@ def test_list_general_type(self, tmp_path): @pytest.mark.xfail(sys.platform == 'win32', reason="randomly fails in win32") - def test_general_type_not_working(self, tmp_path): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_general_type_not_working(self, tmp_path, file): s = self.s s.metadata.set_item('test', (BaseSignal([1]), 0.1, 'test_string')) - fname = tmp_path / 'test.hspy' + fname = tmp_path / file s.save(fname) l = load(fname) assert isinstance(l.metadata.test, tuple) @@ -259,26 +266,29 @@ def 
test_general_type_not_working(self, tmp_path): assert isinstance(l.metadata.test[1], float) assert isinstance(l.metadata.test[2], str) - def test_unsupported_type(self, tmp_path): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_unsupported_type(self, tmp_path, file): s = self.s s.metadata.set_item('test', Point2DROI(1, 2)) - fname = tmp_path / 'test.hspy' + fname = tmp_path / file s.save(fname) l = load(fname) assert 'test' not in l.metadata - def test_date_time(self, tmp_path): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_date_time(self, tmp_path, file): s = self.s date, time = "2016-08-05", "15:00:00.450" s.metadata.General.date = date s.metadata.General.time = time - fname = tmp_path / 'test.hspy' + fname = tmp_path / file s.save(fname) l = load(fname) assert l.metadata.General.date == date assert l.metadata.General.time == time - def test_general_metadata(self, tmp_path): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_general_metadata(self, tmp_path, file): s = self.s notes = "Dummy notes" authors = "Author 1, Author 2" @@ -286,22 +296,62 @@ def test_general_metadata(self, tmp_path): s.metadata.General.notes = notes s.metadata.General.authors = authors s.metadata.General.doi = doi - fname = tmp_path / 'test.hspy' + fname = tmp_path / file s.save(fname) l = load(fname) assert l.metadata.General.notes == notes assert l.metadata.General.authors == authors assert l.metadata.General.doi == doi - def test_quantity(self, tmp_path): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_quantity(self, tmp_path, file): s = self.s quantity = "Intensity (electron)" s.metadata.Signal.quantity = quantity - fname = tmp_path / 'test.hspy' + fname = tmp_path / file s.save(fname) l = load(fname) assert l.metadata.Signal.quantity == quantity + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_axes_manager(self, tmp_path, file): + s = self.s + s.metadata.set_item('test', s.axes_manager) + fname = tmp_path / file + s.save(fname) + l = load(fname) + #strange becuase you need the encoding... + assert isinstance(l.metadata.test, AxesManager) + + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_title(self, tmp_path, file): + s = self.s + fname = tmp_path / file + s.metadata.General.title = '__unnamed__' + s.save(fname) + l = load(fname) + assert l.metadata.General.title is "" + + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_empty_tuple(self, tmp_path, file): + s = self.s + s.metadata.set_item('test', ()) + fname = tmp_path / file + s.save(fname) + l = load(fname) + #strange becuase you need the encoding... 
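+        # the saved empty tuple should compare equal to the loaded value after the round trip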
+ assert l.metadata.test == s.metadata.test + + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_bytes(self, tmp_path,file): + s = self.s + byte_message = bytes("testing", 'utf-8') + s.metadata.set_item('test', byte_message) + fname = tmp_path / file + s.save(fname) + l = load(fname) + assert l.metadata.test == s.metadata.test.decode() + def test_metadata_binned_deprecate(self): with pytest.warns(UserWarning, match="Loading old file"): s = load(os.path.join(my_path, "hdf5_files", 'example2_v2.2.hspy')) @@ -348,28 +398,30 @@ def test_rgba16(): data = np.load(os.path.join( my_path, "npy_files", "test_rgba16.npy")) assert (s.data == data).all() - -def test_nonuniformaxis(): +@pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) +def test_nonuniformaxis(tmp_path, file): + fname = tmp_path / file data = np.arange(10) - axis = DataAxis(axis = 1/np.arange(1,data.size+1), navigate = False) - s = Signal1D(data, axes = (axis.get_axis_dictionary(), )) - s.save('tmp.hdf5', overwrite = True) - s2 = load('tmp.hdf5') + axis = DataAxis(axis=1/np.arange(1, data.size+1), navigate=False) + s = Signal1D(data, axes=(axis.get_axis_dictionary(), )) + s.save(fname, overwrite=True) + s2 = load(fname) np.testing.assert_array_almost_equal(s.axes_manager[0].axis, s2.axes_manager[0].axis) assert(s2.axes_manager[0].is_uniform == False) assert(s2.axes_manager[0].navigate == False) assert(s2.axes_manager[0].size == data.size) - -def test_nonuniformFDA(): +@pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) +def test_nonuniformFDA(tmp_path, file): + fname = tmp_path / file data = np.arange(10) x0 = UniformDataAxis(size=data.size, offset=1) axis = FunctionalDataAxis(expression = '1/x', x = x0, navigate = False) - s = Signal1D(data, axes = (axis.get_axis_dictionary(), )) + s = Signal1D(data, axes=(axis.get_axis_dictionary(), )) print(axis.get_axis_dictionary()) - s.save('tmp.hdf5', overwrite = True) - s2 = load('tmp.hdf5') + s.save(fname, overwrite=True) + s2 = load(fname) np.testing.assert_array_almost_equal(s.axes_manager[0].axis, s2.axes_manager[0].axis) assert(s2.axes_manager[0].is_uniform == False) @@ -428,63 +480,53 @@ def teardown_method(self, method): class TestAxesConfiguration: - - - def setup_method(self, method): - self.s = BaseSignal(np.zeros((2, 2, 2, 2, 2))) - self.s.axes_manager.signal_axes[0].navigate = True - self.s.axes_manager.signal_axes[0].navigate = True - - def test_axes_configuration(self): - self.filename = 'testfile.hdf5' - self.s.save(self.filename, overwrite = True) - s = load(self.filename) + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_axes_binning(self, tmp_path, file): + fname = tmp_path / file + s = BaseSignal(np.zeros((2, 2, 2))) + s.axes_manager.signal_axes[-1].is_binned = True + s.save(fname) + s = load(fname) + assert s.axes_manager.signal_axes[-1].is_binned + + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_axes_configuration(self, tmp_path, file): + fname = tmp_path / file + s = BaseSignal(np.zeros((2, 2, 2, 2, 2))) + s.axes_manager.signal_axes[0].navigate = True + s.axes_manager.signal_axes[0].navigate = True + s.save(fname) + s = load(fname) assert s.axes_manager.navigation_axes[0].index_in_array == 4 assert s.axes_manager.navigation_axes[1].index_in_array == 3 assert s.axes_manager.signal_dimension == 3 - def teardown_method(self, method): - remove(self.filename) - -class TestAxesConfigurationBinning: - - def setup_method(self, method): - self.filename = 'testfile.hdf5' - s = 
BaseSignal(np.zeros((2, 2, 2))) - s.axes_manager.signal_axes[-1].is_binned = True - s.save(self.filename) - - def test_axes_configuration(self): - s = load(self.filename) - assert s.axes_manager.signal_axes[-1].is_binned == True - - def teardown_method(self, method): - remove(self.filename) - class Test_permanent_markers_io: - def test_save_permanent_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_permanent_marker(self, tmp_path, file): + filename = tmp_path / file s = Signal2D(np.arange(100).reshape(10, 10)) m = markers.point(x=5, y=5) s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/testsavefile.hdf5' s.save(filename) - def test_save_load_empty_metadata_markers(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_empty_metadata_markers(self, tmp_path, file): + filename = tmp_path / file s = Signal2D(np.arange(100).reshape(10, 10)) m = markers.point(x=5, y=5) m.name = "test" s.add_marker(m, permanent=True) del s.metadata.Markers.test - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/testsavefile.hdf5' s.save(filename) s1 = load(filename) assert len(s1.metadata.Markers) == 0 - def test_save_load_permanent_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_permanent_marker(self, tmp_path, file): + filename = tmp_path / file x, y = 5, 2 color = 'red' size = 10 @@ -493,8 +535,6 @@ def test_save_load_permanent_marker(self): m = markers.point(x=x, y=y, color=color, size=size) m.name = name s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/testloadfile.hdf5' s.save(filename) s1 = load(filename) assert s1.metadata.Markers.has_item(name) @@ -505,7 +545,9 @@ def test_save_load_permanent_marker(self): assert m1.marker_properties['color'] == color assert m1.name == name - def test_save_load_permanent_marker_all_types(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_permanent_marker_all_types(self, tmp_path, file): + filename = tmp_path / file x1, y1, x2, y2 = 5, 2, 1, 8 s = Signal2D(np.arange(100).reshape(10, 10)) m0_list = [ @@ -520,8 +562,6 @@ def test_save_load_permanent_marker_all_types(self): ] for m in m0_list: s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/testallmarkersfile.hdf5' s.save(filename) s1 = load(filename) markers_dict = s1.metadata.Markers @@ -535,7 +575,9 @@ def test_save_load_permanent_marker_all_types(self): for m0_dict, m1_dict in zip(m0_dict_list, m1_dict_list): assert m0_dict == m1_dict - def test_save_load_horizontal_line_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_horizontal_line_marker(self,tmp_path,file): + filename = tmp_path / file y = 8 color = 'blue' linewidth = 2.5 @@ -544,14 +586,14 @@ def test_save_load_horizontal_line_marker(self): m = markers.horizontal_line(y=y, color=color, linewidth=linewidth) m.name = name s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_horizontal_line_marker.hdf5' s.save(filename) s1 = load(filename) m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - def test_save_load_horizontal_line_segment_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_horizontal_line_segment_marker(self, tmp_path, file): 
+ filename = tmp_path / file x1, x2, y = 1, 5, 8 color = 'red' linewidth = 1.2 @@ -561,14 +603,14 @@ def test_save_load_horizontal_line_segment_marker(self): x1=x1, x2=x2, y=y, color=color, linewidth=linewidth) m.name = name s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_horizontal_line_segment_marker.hdf5' s.save(filename) s1 = load(filename) m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - def test_save_load_vertical_line_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_vertical_line_marker(self, tmp_path, file): + filename = tmp_path / file x = 9 color = 'black' linewidth = 3.5 @@ -577,14 +619,14 @@ def test_save_load_vertical_line_marker(self): m = markers.vertical_line(x=x, color=color, linewidth=linewidth) m.name = name s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_vertical_line_marker.hdf5' s.save(filename) s1 = load(filename) m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - def test_save_load_vertical_line_segment_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_vertical_line_segment_marker(self, tmp_path, file): + filename = tmp_path / file x, y1, y2 = 2, 1, 3 color = 'white' linewidth = 4.2 @@ -594,14 +636,14 @@ def test_save_load_vertical_line_segment_marker(self): x=x, y1=y1, y2=y2, color=color, linewidth=linewidth) m.name = name s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_vertical_line_segment_marker.hdf5' s.save(filename) s1 = load(filename) m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - def test_save_load_line_segment_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_line_segment_marker(self, tmp_path, file): + filename = tmp_path / file x1, x2, y1, y2 = 1, 9, 4, 7 color = 'cyan' linewidth = 0.7 @@ -611,14 +653,14 @@ def test_save_load_line_segment_marker(self): x1=x1, x2=x2, y1=y1, y2=y2, color=color, linewidth=linewidth) m.name = name s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_line_segment_marker.hdf5' s.save(filename) s1 = load(filename) m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - def test_save_load_point_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_point_marker(self, tmp_path, file): + filename = tmp_path / file x, y = 9, 8 color = 'purple' name = "point test" @@ -627,14 +669,14 @@ def test_save_load_point_marker(self): x=x, y=y, color=color) m.name = name s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_point_marker.hdf5' s.save(filename) s1 = load(filename) m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - def test_save_load_rectangle_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_rectangle_marker(self, tmp_path, file): + filename = tmp_path / file x1, x2, y1, y2 = 2, 4, 1, 3 color = 'yellow' linewidth = 5 @@ -644,14 +686,14 @@ def test_save_load_rectangle_marker(self): x1=x1, x2=x2, y1=y1, y2=y2, 
color=color, linewidth=linewidth) m.name = name s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_rectangle_marker.hdf5' s.save(filename) s1 = load(filename) m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - def test_save_load_text_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_text_marker(self, tmp_path, file): + filename = tmp_path / file x, y = 3, 9.5 color = 'brown' name = "text_test" @@ -661,22 +703,20 @@ def test_save_load_text_marker(self): x=x, y=y, text=text, color=color) m.name = name s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_text_marker.hdf5' s.save(filename) s1 = load(filename) m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - def test_save_load_multidim_navigation_marker(self): + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_load_multidim_navigation_marker(self, tmp_path, file): + filename = tmp_path / file x, y = (1, 2, 3), (5, 6, 7) name = 'test point' s = Signal2D(np.arange(300).reshape(3, 10, 10)) m = markers.point(x=x, y=y) m.name = name s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_multidim_nav_marker.hdf5' s.save(filename) s1 = load(filename) m1 = s1.metadata.Markers.get_item(name) @@ -713,6 +753,20 @@ def test_load_missing_y2_value(self): assert len(s.metadata.Markers) == 5 +@pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) +def test_save_load_model(tmp_path, file): + from hyperspy._components.gaussian import Gaussian + filename = tmp_path / file + s = Signal1D(np.ones((10, 10, 10, 10))) + m = s.create_model() + m.append(Gaussian()) + m.store("test") + s.save(filename) + signal2 = load(filename) + m2 = signal2.models.restore("test") + assert m.signal == m2.signal + + @pytest.mark.parametrize("compression", (None, "gzip", "lzf")) def test_compression(compression, tmp_path): s = Signal1D(np.ones((3,3))) @@ -725,12 +779,12 @@ def test_strings_from_py2(): s = EDS_TEM_Spectrum() assert isinstance(s.metadata.Sample.elements, list) - -def test_save_ragged_array(tmp_path): +@pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) +def test_save_ragged_array(tmp_path, file): a = np.array([0, 1]) b = np.array([0, 1, 2]) s = BaseSignal(np.array([a, b], dtype=object)).T - fname = tmp_path / 'test_save_ragged_array.hspy' + fname = tmp_path / file s.save(fname) s1 = load(fname) for i in range(len(s.data)): @@ -746,40 +800,50 @@ def test_load_missing_extension(caplog): with pytest.raises(ImportError): _ = s.models.restore("a") +class TestChunking: + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_save_chunks_signal_metadata(self, tmp_path, file): + N = 10 + dim = 3 + s = Signal1D(np.arange(N**dim).reshape([N]*dim)) + s.navigator = s.sum(-1) + s.change_dtype('float') + s.decomposition() + filename = tmp_path / file + chunks = (5, 5, 10) + s.save(filename, chunks=chunks) + s2 = load(filename, lazy=True) + assert tuple([c[0] for c in s2.data.chunks]) == chunks + + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_chunking_saving_lazy(self, tmp_path, file): + filename = tmp_path / file + s = Signal2D(da.zeros((50, 100, 100))).as_lazy() + s.data = s.data.rechunk([50, 25, 25]) + s.save(filename) + s1 = load(filename, lazy=True) + assert 
s.data.chunks == s1.data.chunks + + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_chunking_saving_lazy_True(self, tmp_path, file): + filename = tmp_path / file + s = Signal2D(da.zeros((50, 100, 100))).as_lazy() + s.data = s.data.rechunk([50, 25, 25]) + s.save(filename, chunks=True) + s1 = load(filename, lazy=True) + if file == "test.hspy": + assert tuple([c[0] for c in s1.data.chunks]) == (7, 25, 25) + else: + assert tuple([c[0] for c in s1.data.chunks]) == (25, 50, 50) + + @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + def test_chunking_saving_lazy_specify(self, tmp_path, file): + filename = tmp_path / file + s = Signal2D(da.zeros((50, 100, 100))).as_lazy() + # specify chunks + chunks = (50, 10, 10) + s.data = s.data.rechunk([50, 25, 25]) + s.save(filename, chunks=chunks) + s1 = load(filename, lazy=True) + assert tuple([c[0] for c in s1.data.chunks]) == chunks -def test_save_chunks_signal_metadata(): - N = 10 - dim = 3 - s = Signal1D(np.arange(N**dim).reshape([N]*dim)) - s.navigator = s.sum(-1) - s.change_dtype('float') - s.decomposition() - with tempfile.TemporaryDirectory() as tmp: - filename = os.path.join(tmp, 'test_save_chunks_signal_metadata.hspy') - chunks = (5, 5, 10) - s.save(filename, chunks=chunks) - s2 = load(filename, lazy=True) - assert tuple([c[0] for c in s2.data.chunks]) == chunks - - -def test_chunking_saving_lazy(): - s = Signal2D(da.zeros((50, 100, 100))).as_lazy() - s.data = s.data.rechunk([50, 25, 25]) - with tempfile.TemporaryDirectory() as tmp: - filename = os.path.join(tmp, 'test_chunking_saving_lazy.hspy') - filename2 = os.path.join(tmp, 'test_chunking_saving_lazy_chunks_True.hspy') - filename3 = os.path.join(tmp, 'test_chunking_saving_lazy_chunks_specified.hspy') - s.save(filename) - s1 = load(filename, lazy=True) - assert s.data.chunks == s1.data.chunks - - # with chunks=True, use h5py chunking - s.save(filename2, chunks=True) - s2 = load(filename2, lazy=True) - assert tuple([c[0] for c in s2.data.chunks]) == (7, 25, 25) - - # specify chunks - chunks = (50, 10, 10) - s.save(filename3, chunks=chunks) - s3 = load(filename3, lazy=True) - assert tuple([c[0] for c in s3.data.chunks]) == chunks diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py index 0f73194b46..be295f9c9a 100644 --- a/hyperspy/tests/io/test_zspy.py +++ b/hyperspy/tests/io/test_zspy.py @@ -18,9 +18,7 @@ import gc import os.path -import sys import tempfile -import time from os import remove import dask.array as da @@ -29,245 +27,11 @@ import zarr from hyperspy._signals.signal1d import Signal1D -from hyperspy._signals.signal2d import Signal2D -from hyperspy.datasets.example_signals import EDS_TEM_Spectrum -from hyperspy.exceptions import VisibleDeprecationWarning -from hyperspy.axes import AxesManager from hyperspy.io import load -from hyperspy.misc.test_utils import assert_deep_almost_equal -from hyperspy.misc.test_utils import sanitize_dict as san_dict -from hyperspy.roi import Point2DROI from hyperspy.signal import BaseSignal -from hyperspy.utils import markers -import hyperspy.api as hs my_path = os.path.dirname(__file__) -data = np.array([4066., 3996., 3932., 3923., 5602., 5288., 7234., 7809., - 4710., 5015., 4366., 4524., 4832., 5474., 5718., 5034., - 4651., 4613., 4637., 4429., 4217.]) -example1_original_metadata = { - 'BEAMDIAM -nm': 100.0, - 'BEAMKV -kV': 120.0, - 'CHOFFSET': -168.0, - 'COLLANGLE-mR': 3.4, - 'CONVANGLE-mR': 1.5, - 'DATATYPE': 'XY', - 'DATE': '01-OCT-1991', - 'DWELLTIME-ms': 100.0, - 'ELSDET': 'SERIAL', - 
'EMISSION -uA': 5.5, - 'FORMAT': 'EMSA/MAS Spectral Data File', - 'MAGCAM': 100.0, - 'NCOLUMNS': 1.0, - 'NPOINTS': 20.0, - 'OFFSET': 520.13, - 'OPERMODE': 'IMAG', - 'OWNER': 'EMSA/MAS TASK FORCE', - 'PROBECUR -nA': 12.345, - 'SIGNALTYPE': 'ELS', - 'THICKNESS-nm': 50.0, - 'TIME': '12:00', - 'TITLE': 'NIO EELS OK SHELL', - 'VERSION': '1.0', - 'XLABEL': 'Energy', - 'XPERCHAN': 3.1, - 'XUNITS': 'eV', - 'YLABEL': 'Counts', - 'YUNITS': 'Intensity'} - - -class Example1: - "Used as a base class for the TestExample classes below" - - def test_data(self): - assert ( - [4066.0, - 3996.0, - 3932.0, - 3923.0, - 5602.0, - 5288.0, - 7234.0, - 7809.0, - 4710.0, - 5015.0, - 4366.0, - 4524.0, - 4832.0, - 5474.0, - 5718.0, - 5034.0, - 4651.0, - 4613.0, - 4637.0, - 4429.0, - 4217.0] == self.s.data.tolist()) - - def test_original_metadata(self): - assert ( - example1_original_metadata == - self.s.original_metadata.as_dictionary()) - - @pytest.mark.xfail( - reason="dill is not guaranteed to load across Python versions") - def test_binary_string(self): - import dill - # apparently pickle is not "full" and marshal is not - # backwards-compatible - f = dill.loads(self.s.metadata.test.binary_string) - assert f(3.5) == 4.5 - - -class TestSavingMetadataContainers: - - def setup_method(self, method): - self.s = BaseSignal([0.1]) - - def test_save_unicode(self, tmp_path): - s = self.s - s.metadata.set_item('test', ['a', 'b', '\u6f22\u5b57']) - fname = tmp_path / 'test.zspy' - s.save(fname) - s.save("test.zspy") - l = load(fname) - assert isinstance(l.metadata.test[0], str) - assert isinstance(l.metadata.test[1], str) - assert isinstance(l.metadata.test[2], str) - assert l.metadata.test[2] == '\u6f22\u5b57' - - def test_save_long_list(self, tmp_path): - s = self.s - s.metadata.set_item('long_list', list(range(10000))) - start = time.time() - fname = tmp_path / 'test.zspy' - s.save(fname) - end = time.time() - assert end - start < 1.0 # It should finish in less that 1 s. 
- - def test_numpy_only_inner_lists(self, tmp_path): - s = self.s - s.metadata.set_item('test', [[1., 2], ('3', 4)]) - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - assert isinstance(l.metadata.test, list) - assert isinstance(l.metadata.test[0], list) - assert isinstance(l.metadata.test[1], tuple) - - @pytest.mark.xfail(sys.platform == 'win32', - reason="randomly fails in win32") - def test_numpy_general_type(self, tmp_path): - s = self.s - s.metadata.set_item('test', np.array([[1., 2], ['3', 4]])) - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - np.testing.assert_array_equal(l.metadata.test, s.metadata.test) - - @pytest.mark.xfail(sys.platform == 'win32', - reason="randomly fails in win32") - def test_list_general_type(self, tmp_path): - s = self.s - s.metadata.set_item('test', [[1., 2], ['3', 4]]) - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - assert isinstance(l.metadata.test[0][0], float) - assert isinstance(l.metadata.test[0][1], float) - assert isinstance(l.metadata.test[1][0], str) - assert isinstance(l.metadata.test[1][1], str) - - @pytest.mark.xfail(sys.platform == 'win32', - reason="randomly fails in win32") - def test_general_type_not_working(self, tmp_path): - s = self.s - s.metadata.set_item('test', (BaseSignal([1]), 0.1, 'test_string')) - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - assert isinstance(l.metadata.test, tuple) - assert isinstance(l.metadata.test[0], Signal1D) - assert isinstance(l.metadata.test[1], float) - assert isinstance(l.metadata.test[2], str) - - def test_unsupported_type(self, tmp_path): - s = self.s - s.metadata.set_item('test', Point2DROI(1, 2)) - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - assert 'test' not in l.metadata - - def test_date_time(self, tmp_path): - s = self.s - date, time = "2016-08-05", "15:00:00.450" - s.metadata.General.date = date - s.metadata.General.time = time - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - assert l.metadata.General.date == date - assert l.metadata.General.time == time - - def test_general_metadata(self, tmp_path): - s = self.s - notes = "Dummy notes" - authors = "Author 1, Author 2" - doi = "doi" - s.metadata.General.notes = notes - s.metadata.General.authors = authors - s.metadata.General.doi = doi - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - assert l.metadata.General.notes == notes - assert l.metadata.General.authors == authors - assert l.metadata.General.doi == doi - - def test_quantity(self, tmp_path): - s = self.s - quantity = "Intensity (electron)" - s.metadata.Signal.quantity = quantity - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - assert l.metadata.Signal.quantity == quantity - - def test_title(self, tmp_path): - s = self.s - fname = tmp_path / 'test.zspy' - s.metadata.General.title = '__unnamed__' - s.save(fname) - l = load(fname) - assert l.metadata.General.title is "" - - def test_save_bytes(self, tmp_path): - s = self.s - byte_message = bytes("testing", 'utf-8') - s.metadata.set_item('test', byte_message) - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - assert l.metadata.test == s.metadata.test.decode() - - def test_save_empty_tuple(self, tmp_path): - s = self.s - s.metadata.set_item('test', ()) - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - #strange becuase you need the encoding... 
- assert l.metadata.test == s.metadata.test - - def test_save_axes_manager(self, tmp_path): - s = self.s - s.metadata.set_item('test', s.axes_manager) - fname = tmp_path / 'test.zspy' - s.save(fname) - l = load(fname) - #strange becuase you need the encoding... - assert isinstance(l.metadata.test, AxesManager) class TestLoadingOOMReadOnly: @@ -301,374 +65,6 @@ def teardown_method(self, method): pass -class TestPassingArgs: - def test_compression_opts(self, tmp_path): - self.filename = tmp_path / 'testfile.zspy' - from numcodecs import Blosc - comp = Blosc(cname='zstd', clevel=1, shuffle=Blosc.SHUFFLE) - BaseSignal([1, 2, 3]).save(self.filename, compressor=comp) - f = zarr.open(self.filename.__str__(), mode='r+') - d = f['Experiments/__unnamed__/data'] - assert (d.compressor == comp) - - -class TestAxesConfiguration: - - def test_axes_configuration(self, tmp_path): - self.filename = tmp_path / 'testfile.zspy' - s = BaseSignal(np.zeros((2, 2, 2, 2, 2))) - s.axes_manager.signal_axes[0].navigate = True - s.axes_manager.signal_axes[0].navigate = True - s.save(self.filename) - s = load(self.filename) - assert s.axes_manager.navigation_axes[0].index_in_array == 4 - assert s.axes_manager.navigation_axes[1].index_in_array == 3 - assert s.axes_manager.signal_dimension == 3 - - -class TestAxesConfigurationBinning: - def test_axes_configuration(self): - self.filename = 'testfile.zspy' - s = BaseSignal(np.zeros((2, 2, 2))) - s.axes_manager.signal_axes[-1].is_binned = True - s.save(self.filename) - s = load(self.filename) - assert s.axes_manager.signal_axes[-1].is_binned == True - - -class Test_permanent_markers_io: - - def test_save_permanent_marker(self): - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.point(x=5, y=5) - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/testsavefile.zspy' - s.save(filename) - - def test_save_load_empty_metadata_markers(self): - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.point(x=5, y=5) - m.name = "test" - s.add_marker(m, permanent=True) - del s.metadata.Markers.test - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/testsavefile.zspy' - s.save(filename) - s1 = load(filename) - assert len(s1.metadata.Markers) == 0 - - def test_save_load_permanent_marker(self): - x, y = 5, 2 - color = 'red' - size = 10 - name = 'testname' - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.point(x=x, y=y, color=color, size=size) - m.name = name - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/testloadfile.zspy' - s.save(filename) - s1 = load(filename) - assert s1.metadata.Markers.has_item(name) - m1 = s1.metadata.Markers.get_item(name) - assert m1.get_data_position('x1') == x - assert m1.get_data_position('y1') == y - assert m1.get_data_position('size') == size - assert m1.marker_properties['color'] == color - assert m1.name == name - - def test_save_load_permanent_marker_all_types(self): - x1, y1, x2, y2 = 5, 2, 1, 8 - s = Signal2D(np.arange(100).reshape(10, 10)) - m0_list = [ - markers.point(x=x1, y=y1), - markers.horizontal_line(y=y1), - markers.horizontal_line_segment(x1=x1, x2=x2, y=y1), - markers.line_segment(x1=x1, x2=x2, y1=y1, y2=y2), - markers.rectangle(x1=x1, x2=x2, y1=y1, y2=y2), - markers.text(x=x1, y=y1, text="test"), - markers.vertical_line(x=x1), - markers.vertical_line_segment(x=x1, y1=y1, y2=y2), - ] - for m in m0_list: - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + 
'/testallmarkersfile.zspy' - s.save(filename) - s1 = load(filename) - markers_dict = s1.metadata.Markers - m0_dict_list = [] - m1_dict_list = [] - for m in m0_list: - m0_dict_list.append(san_dict(m._to_dictionary())) - m1_dict_list.append( - san_dict(markers_dict.get_item(m.name)._to_dictionary())) - assert len(list(s1.metadata.Markers)) == 8 - for m0_dict, m1_dict in zip(m0_dict_list, m1_dict_list): - assert m0_dict == m1_dict - - def test_save_load_horizontal_line_marker(self): - y = 8 - color = 'blue' - linewidth = 2.5 - name = "horizontal_line_test" - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.horizontal_line(y=y, color=color, linewidth=linewidth) - m.name = name - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_horizontal_line_marker.zspy' - s.save(filename) - s1 = load(filename) - m1 = s1.metadata.Markers.get_item(name) - assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - - def test_save_load_horizontal_line_segment_marker(self): - x1, x2, y = 1, 5, 8 - color = 'red' - linewidth = 1.2 - name = "horizontal_line_segment_test" - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.horizontal_line_segment( - x1=x1, x2=x2, y=y, color=color, linewidth=linewidth) - m.name = name - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_horizontal_line_segment_marker.zspy' - s.save(filename) - s1 = load(filename) - m1 = s1.metadata.Markers.get_item(name) - assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - - def test_save_load_vertical_line_marker(self): - x = 9 - color = 'black' - linewidth = 3.5 - name = "vertical_line_test" - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.vertical_line(x=x, color=color, linewidth=linewidth) - m.name = name - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_vertical_line_marker.zspy' - s.save(filename) - s1 = load(filename) - m1 = s1.metadata.Markers.get_item(name) - assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - - def test_save_load_vertical_line_segment_marker(self): - x, y1, y2 = 2, 1, 3 - color = 'white' - linewidth = 4.2 - name = "vertical_line_segment_test" - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.vertical_line_segment( - x=x, y1=y1, y2=y2, color=color, linewidth=linewidth) - m.name = name - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_vertical_line_segment_marker.zspy' - s.save(filename) - s1 = load(filename) - m1 = s1.metadata.Markers.get_item(name) - assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - - def test_save_load_line_segment_marker(self): - x1, x2, y1, y2 = 1, 9, 4, 7 - color = 'cyan' - linewidth = 0.7 - name = "line_segment_test" - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.line_segment( - x1=x1, x2=x2, y1=y1, y2=y2, color=color, linewidth=linewidth) - m.name = name - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_line_segment_marker.zspy' - s.save(filename) - s1 = load(filename) - m1 = s1.metadata.Markers.get_item(name) - assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - - def test_save_load_point_marker(self): - x, y = 9, 8 - color = 'purple' - name = "point test" - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.point( - x=x, y=y, color=color) - m.name = name - 
s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_point_marker.zspy' - s.save(filename) - s1 = load(filename) - m1 = s1.metadata.Markers.get_item(name) - assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - - def test_save_load_rectangle_marker(self): - x1, x2, y1, y2 = 2, 4, 1, 3 - color = 'yellow' - linewidth = 5 - name = "rectangle_test" - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.rectangle( - x1=x1, x2=x2, y1=y1, y2=y2, color=color, linewidth=linewidth) - m.name = name - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_rectangle_marker.zspy' - s.save(filename) - s1 = load(filename) - m1 = s1.metadata.Markers.get_item(name) - assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - - def test_metadata_update_to_v3_1(self): - md = {'Acquisition_instrument': {'SEM': {'Stage': {'tilt_alpha': 5.0}}, - 'TEM': {'Detector': {'Camera': {'exposure': 0.20000000000000001}}, - 'Stage': {'tilt_alpha': 10.0}, - 'acquisition_mode': 'TEM', - 'beam_current': 0.0, - 'beam_energy': 200.0, - 'camera_length': 320.00000000000006, - 'microscope': 'FEI Tecnai'}}, - 'General': {'date': '2014-07-09', - 'original_filename': 'test_diffraction_pattern.dm3', - 'time': '18:56:37', - 'title': 'test_diffraction_pattern'}, - 'Signal': {'Noise_properties': {'Variance_linear_model': {'gain_factor': 1.0, - 'gain_offset': 0.0}}, - 'quantity': 'Intensity', - 'signal_type': ''}, - '_HyperSpy': {'Folding': {'original_axes_manager': None, - 'original_shape': None, - 'signal_unfolded': False, - 'unfolded': False}}} - s = load(os.path.join( - my_path, - "hdf5_files", - 'example2_v3.1.hspy')) - assert_deep_almost_equal(s.metadata.as_dictionary(), md) - - def test_save_load_text_marker(self): - x, y = 3, 9.5 - color = 'brown' - name = "text_test" - text = "a text" - s = Signal2D(np.arange(100).reshape(10, 10)) - m = markers.text( - x=x, y=y, text=text, color=color) - m.name = name - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_text_marker.zspy' - s.save(filename) - s1 = load(filename) - m1 = s1.metadata.Markers.get_item(name) - assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - - def test_save_load_multidim_navigation_marker(self): - x, y = (1, 2, 3), (5, 6, 7) - name = 'test point' - s = Signal2D(np.arange(300).reshape(3, 10, 10)) - m = markers.point(x=x, y=y) - m.name = name - s.add_marker(m, permanent=True) - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/test_save_multidim_nav_marker.zspy' - s.save(filename) - s1 = load(filename) - m1 = s1.metadata.Markers.get_item(name) - assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - assert m1.get_data_position('x1') == x[0] - assert m1.get_data_position('y1') == y[0] - s1.axes_manager.navigation_axes[0].index = 1 - assert m1.get_data_position('x1') == x[1] - assert m1.get_data_position('y1') == y[1] - s1.axes_manager.navigation_axes[0].index = 2 - assert m1.get_data_position('x1') == x[2] - assert m1.get_data_position('y1') == y[2] - - - - -@pytest.mark.parametrize("compressor", (None, "default", "blosc")) -def test_compression(compressor, tmp_path): - if compressor is "blosc": - from numcodecs import Blosc - compressor = Blosc(cname='zstd', clevel=3, shuffle=Blosc.BITSHUFFLE) - s = Signal1D(np.ones((3, 3))) - s.save(tmp_path / 'test_compression.zspy', - overwrite=True, - compressor=compressor) - 
load(tmp_path / 'test_compression.zspy') - - -def test_strings_from_py2(): - s = EDS_TEM_Spectrum() - assert isinstance(s.metadata.Sample.elements, list) - - -def test_save_ragged_array(tmp_path): - a = np.array([0, 1]) - b = np.array([0, 1, 2]) - s = BaseSignal(np.array([a, b], dtype=object)).T - fname = tmp_path / 'test_save_ragged_array.zspy' - s.save(fname) - s1 = load(fname) - for i in range(len(s.data)): - np.testing.assert_allclose(s.data[i], s1.data[i]) - assert s.__class__ == s1.__class__ - - -def test_save_chunks_signal_metadata(): - N = 10 - dim = 3 - s = Signal1D(np.arange(N ** dim).reshape([N] * dim)) - s.navigator = s.sum(-1) - s.change_dtype('float') - s.decomposition() - with tempfile.TemporaryDirectory() as tmp: - filename = os.path.join(tmp, 'test_save_chunks_signal_metadata.zspy') - chunks = (5, 2, 2) - s.save(filename, chunks=chunks) - s2 = load(filename, lazy=True) - assert tuple([c[0] for c in s2.data.chunks]) == chunks - - -def test_chunking_saving_lazy(): - s = Signal2D(da.zeros((50, 100, 100))).as_lazy() - s.data = s.data.rechunk([50, 25, 25]) - with tempfile.TemporaryDirectory() as tmp: - filename = os.path.join(tmp, 'test_chunking_saving_lazy.zspy') - filename2 = os.path.join(tmp, 'test_chunking_saving_lazy_chunks_True.zspy') - filename3 = os.path.join(tmp, 'test_chunking_saving_lazy_chunks_specified.zspy') - s.save(filename) - s1 = load(filename, lazy=True) - assert s.data.chunks == s1.data.chunks - - s.save(filename2) - s2 = load(filename2, lazy=True) - assert tuple([c[0] for c in s2.data.chunks]) == (50, 25, 25) - - # specify chunks - chunks = (50, 10, 10) - s.save(filename3, chunks=chunks) - s3 = load(filename3, lazy=True) - assert tuple([c[0] for c in s3.data.chunks]) == chunks - -def test_data_lazy(): - s = Signal2D(da.ones((5, 10, 10))).as_lazy() - s.data = s.data.rechunk([5, 2, 2]) - with tempfile.TemporaryDirectory() as tmp: - filename = os.path.join(tmp, 'test_chunking_saving_lazy.zspy') - s.save(filename) - s1 = load(filename) - np.testing.assert_array_almost_equal(s1.data, s.data) - - class TestZspy: @pytest.fixture def signal(self): @@ -676,23 +72,12 @@ def signal(self): s = Signal1D(data) return s - def test_save_load_model(self, signal): - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/testmodels.zspy' - m = signal.create_model() - m.append(hs.model.components1D.Gaussian()) - m.store("test") - signal.save(filename) - signal2 = hs.load(filename) - m2 = signal2.models.restore("test") - assert m.signal == m2.signal - def test_save_N5_type(self,signal): with tempfile.TemporaryDirectory() as tmp: filename = tmp + '/testmodels.zspy' store = zarr.N5Store(path=filename) signal.save(store.path, write_to_storage=True) - signal2 = hs.load(filename) + signal2 = load(filename) np.testing.assert_array_equal(signal2.data, signal.data) @pytest.mark.skip(reason="lmdb must be installed to test") @@ -702,5 +87,25 @@ def test_save_lmdb_type(self, signal): filename = tmp + '/testmodels.zspy/' store = zarr.LMDBStore(path=filename) signal.save(store.path, write_to_storage=True) - signal2 = hs.load(store.path) - np.testing.assert_array_equal(signal2.data, signal.data) \ No newline at end of file + signal2 = load(store.path) + np.testing.assert_array_equal(signal2.data, signal.data) + + def test_compression_opts(self, tmp_path): + self.filename = tmp_path / 'testfile.zspy' + from numcodecs import Blosc + comp = Blosc(cname='zstd', clevel=1, shuffle=Blosc.SHUFFLE) + BaseSignal([1, 2, 3]).save(self.filename, compressor=comp) + f = 
zarr.open(self.filename.__str__(), mode='r+') + d = f['Experiments/__unnamed__/data'] + assert (d.compressor == comp)\ + + @pytest.mark.parametrize("compressor", (None, "default", "blosc")) + def test_compression(self, compressor, tmp_path): + if compressor is "blosc": + from numcodecs import Blosc + compressor = Blosc(cname='zstd', clevel=3, shuffle=Blosc.BITSHUFFLE) + s = Signal1D(np.ones((3, 3))) + s.save(tmp_path / 'test_compression.zspy', + overwrite=True, + compressor=compressor) + load(tmp_path / 'test_compression.zspy') \ No newline at end of file diff --git a/upcoming_changes/2798.new.rst b/upcoming_changes/2798.new.rst index a3f3344930..d456938353 100644 --- a/upcoming_changes/2798.new.rst +++ b/upcoming_changes/2798.new.rst @@ -1,2 +1 @@ -Add in `zspy` saving specification for saving large data. Helps with saving and loading large datasets. -Using `zarr` background \ No newline at end of file +Add in `zspy` format: hspy specification with the zarr format. Particularly useful to speed up loading and saving large by using concurrency. From b83ec5dab567f2cbcf1fd64899b4a05a45dc9b73 Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 5 Oct 2021 12:24:59 -0500 Subject: [PATCH 18/31] Added in spacing for nexus file --- hyperspy/io_plugins/nexus.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/hyperspy/io_plugins/nexus.py b/hyperspy/io_plugins/nexus.py index 2c3095bb65..205eff6115 100644 --- a/hyperspy/io_plugins/nexus.py +++ b/hyperspy/io_plugins/nexus.py @@ -52,6 +52,7 @@ def _byte_to_string(value): Parameters ---------- value : byte str + Returns ------- str @@ -210,7 +211,6 @@ def _get_nav_list(data, dataentry): ------- nav_list : list contains information about each axes. - """ detector_index = 0 @@ -442,6 +442,7 @@ def file_reader(filename, lazy=False, dataset_key=None, dataset_path=None, the datasets of interest. "keys" is a special keywords and prepended with "fix" in the metadata structure to avoid any issues. + Datasets are all arrays with size>2 (arrays, lists) Parameters @@ -895,6 +896,7 @@ def _fix_exclusion_keys(key): Signal and DictionaryBrowser break if a a key is a dict method - e.g. {"keys":2.0}. 
+ This method prepends the key with ``fix_`` so the information is still present to work around this issue @@ -1088,6 +1090,8 @@ def read_metadata_from_file(filename, metadata_key=None, * :py:meth:`~.io_plugins.nexus.file_reader` * :py:meth:`~.io_plugins.nexus.file_writer` * :py:meth:`~.io_plugins.nexus.list_datasets_in_file` + + """ search_keys = _check_search_keys(metadata_key) fin = h5py.File(filename, "r") @@ -1130,16 +1134,20 @@ def list_datasets_in_file(filename, dataset_key=None, verbose : boolean, default : True Prints the results to screen + Returns ------- list list of paths to datasets + See Also -------- * :py:meth:`~.io_plugins.nexus.file_reader` * :py:meth:`~.io_plugins.nexus.file_writer` * :py:meth:`~.io_plugins.nexus.read_metadata_from_file` + + """ search_keys = _check_search_keys(dataset_key) fin = h5py.File(filename, "r") From 9fffe4c336d9778b01bd2362f10a2674fda00f50 Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 5 Oct 2021 12:34:11 -0500 Subject: [PATCH 19/31] Removed maxshape and shuffle for the `require_dataset` function --- hyperspy/io_plugins/hierarchical.py | 5 ++--- hyperspy/io_plugins/nexus.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/hyperspy/io_plugins/hierarchical.py b/hyperspy/io_plugins/hierarchical.py index 8351700a52..625d4afb69 100644 --- a/hyperspy/io_plugins/hierarchical.py +++ b/hyperspy/io_plugins/hierarchical.py @@ -74,16 +74,15 @@ def _overwrite_dataset(group, else: got_data = False - maxshape = tuple(None for _ in data.shape) while not got_data: try: these_kwds = kwds.copy() these_kwds.update(dict(shape=data.shape, dtype=data.dtype, exact=True, - maxshape=maxshape, chunks=chunks, - shuffle=True, )) + ) + ) # If chunks is True, the `chunks` attribute of `dset` below # contains the chunk shape guessed by h5py diff --git a/hyperspy/io_plugins/nexus.py b/hyperspy/io_plugins/nexus.py index 205eff6115..c27514a084 100644 --- a/hyperspy/io_plugins/nexus.py +++ b/hyperspy/io_plugins/nexus.py @@ -1321,4 +1321,4 @@ def file_writer(filename, nxometa.attrs["NX_class"] = _parse_to_file("NXcollection") # write the groups and structure _write_nexus_groups(meta, nxometa, **kwds) - _write_nexus_attr(meta, nxometa) \ No newline at end of file + _write_nexus_attr(meta, nxometa) From cf8632e1bce492b56f615f937514a3e44af79c4d Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 5 Oct 2021 14:13:24 -0500 Subject: [PATCH 20/31] Overwrite working properly --- hyperspy/io.py | 3 ++- hyperspy/misc/io/tools.py | 2 +- hyperspy/tests/io/test_zspy.py | 15 +++++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/hyperspy/io.py b/hyperspy/io.py index c5aa5db287..f20081b1a0 100644 --- a/hyperspy/io.py +++ b/hyperspy/io.py @@ -784,7 +784,8 @@ def save(filename, signal, overwrite=None, **kwds): # Create the directory if it does not exist ensure_directory(filename.parent) - is_file = filename.is_file() + is_file = filename.is_file() or (filename.is_dir() and + os.path.splitext(filename)[1] == '.zspy') if overwrite is None: write = overwrite_method(filename) # Ask what to do diff --git a/hyperspy/misc/io/tools.py b/hyperspy/misc/io/tools.py index 1fb3fe812e..9942fff5fd 100644 --- a/hyperspy/misc/io/tools.py +++ b/hyperspy/misc/io/tools.py @@ -104,7 +104,7 @@ def overwrite(fname): Whether to overwrite file. 
""" - if Path(fname).is_file(): + if Path(fname).is_file() or Path(fname).is_dir(): message = f"Overwrite '{fname}' (y/n)?\n" try: answer = input(message) diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py index be295f9c9a..0f1067cb11 100644 --- a/hyperspy/tests/io/test_zspy.py +++ b/hyperspy/tests/io/test_zspy.py @@ -80,6 +80,21 @@ def test_save_N5_type(self,signal): signal2 = load(filename) np.testing.assert_array_equal(signal2.data, signal.data) + @pytest.mark.parametrize("overwrite",[None, True, False]) + def test_overwrite(self,signal,overwrite): + with tempfile.TemporaryDirectory() as tmp: + filename = tmp + '/testmodels.zspy' + signal.save(filename=filename) + signal2 = signal*2 + signal2.save(filename=filename, overwrite=overwrite) + if overwrite is None: + np.testing.assert_array_equal(signal.data,load(filename).data) + elif overwrite: + np.testing.assert_array_equal(signal2.data,load(filename).data) + else: + np.testing.assert_array_equal(signal.data,load(filename).data) + + @pytest.mark.skip(reason="lmdb must be installed to test") def test_save_lmdb_type(self, signal): with tempfile.TemporaryDirectory() as tmp: From fb4c048faaea26f75963dbea1f9ce6f7e51b6fd7 Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 5 Oct 2021 14:43:52 -0500 Subject: [PATCH 21/31] Added in more checking for directory and zspy format --- hyperspy/misc/io/tools.py | 4 +++- hyperspy/tests/io/test_zspy.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/hyperspy/misc/io/tools.py b/hyperspy/misc/io/tools.py index 9942fff5fd..9f26d2a5f9 100644 --- a/hyperspy/misc/io/tools.py +++ b/hyperspy/misc/io/tools.py @@ -20,6 +20,7 @@ import logging import xml.etree.ElementTree as ET from pathlib import Path +import os from hyperspy.misc.utils import DictionaryTreeBrowser @@ -104,7 +105,8 @@ def overwrite(fname): Whether to overwrite file. """ - if Path(fname).is_file() or Path(fname).is_dir(): + if Path(fname).is_file() or (Path(fname).is_dir() and + os.path.splitext(fname)[1] == '.zspy'): message = f"Overwrite '{fname}' (y/n)?\n" try: answer = input(message) diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py index 0f1067cb11..427785d75e 100644 --- a/hyperspy/tests/io/test_zspy.py +++ b/hyperspy/tests/io/test_zspy.py @@ -123,4 +123,4 @@ def test_compression(self, compressor, tmp_path): s.save(tmp_path / 'test_compression.zspy', overwrite=True, compressor=compressor) - load(tmp_path / 'test_compression.zspy') \ No newline at end of file + load(tmp_path / 'test_compression.zspy') From 48c88f66369da9aad705b479290f18ee2f55020e Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 5 Oct 2021 17:18:23 -0500 Subject: [PATCH 22/31] Add in shuffle kwd argument as default for hspy format. 
--- hyperspy/io_plugins/hierarchical.py | 3 +-- hyperspy/io_plugins/hspy.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/hyperspy/io_plugins/hierarchical.py b/hyperspy/io_plugins/hierarchical.py index 625d4afb69..d817211cc3 100644 --- a/hyperspy/io_plugins/hierarchical.py +++ b/hyperspy/io_plugins/hierarchical.py @@ -81,8 +81,7 @@ def _overwrite_dataset(group, dtype=data.dtype, exact=True, chunks=chunks, - ) - ) + )) # If chunks is True, the `chunks` attribute of `dset` below # contains the chunk shape guessed by h5py diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index 802fef3dbc..a8fe412ebc 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -287,6 +287,8 @@ def file_writer(filename, signal, *args, **kwds): else: smd.record_by = "" try: + if "shuffle" not in kwds: + kwds["shuffle"] = True writer = HyperspyWriter(f, signal, expg, **kwds) writer.write() #write_signal(signal, expg, **kwds) From efd10868f1243bead4b2299484dde7b26f0d7287 Mon Sep 17 00:00:00 2001 From: Eric Prestat Date: Wed, 6 Oct 2021 11:37:47 +0100 Subject: [PATCH 23/31] Fix typo error message. --- hyperspy/io_plugins/hierarchical.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/hyperspy/io_plugins/hierarchical.py b/hyperspy/io_plugins/hierarchical.py index d817211cc3..bc4b49c053 100644 --- a/hyperspy/io_plugins/hierarchical.py +++ b/hyperspy/io_plugins/hierarchical.py @@ -1,5 +1,4 @@ from packaging.version import Version -import warnings import logging import datetime import ast @@ -9,7 +8,7 @@ import numpy as np import dask.array as da from traits.api import Undefined -from hyperspy.misc.utils import ensure_unicode, multiply, get_object_package_info +from hyperspy.misc.utils import ensure_unicode, get_object_package_info from hyperspy.axes import AxesManager @@ -647,8 +646,8 @@ def dict2group(self, dictionary, group, **kwds): group.attrs[key] = value except BaseException: _logger.exception( - "The hdf5 writer could not write the following " - "information in the file: %s : %s", key, value) + "The writer could not write the following " + f"information in the file: {key} : {value}") def parse_structure(self, key, group, value, _type, **kwds): from hyperspy.signal import BaseSignal @@ -681,5 +680,3 @@ def parse_structure(self, key, group, value, _type, **kwds): group.create_dataset(_type + key, data=tmp, **kwds) - - From b9be6fca52ebd1e550e974af03001ed6f60e015b Mon Sep 17 00:00:00 2001 From: Eric Prestat Date: Wed, 6 Oct 2021 12:54:22 +0100 Subject: [PATCH 24/31] Simplify inheritance structure --- .../{hierarchical.py => _hierarchical.py} | 168 +++++++++--------- hyperspy/io_plugins/hspy.py | 102 +++++------ hyperspy/io_plugins/nexus.py | 13 +- hyperspy/io_plugins/zspy.py | 149 +++++++--------- hyperspy/tests/io/test_zspy.py | 78 +++----- 5 files changed, 235 insertions(+), 275 deletions(-) rename hyperspy/io_plugins/{hierarchical.py => _hierarchical.py} (89%) diff --git a/hyperspy/io_plugins/hierarchical.py b/hyperspy/io_plugins/_hierarchical.py similarity index 89% rename from hyperspy/io_plugins/hierarchical.py rename to hyperspy/io_plugins/_hierarchical.py index bc4b49c053..f1d9cafa71 100644 --- a/hyperspy/io_plugins/hierarchical.py +++ b/hyperspy/io_plugins/_hierarchical.py @@ -1,15 +1,15 @@ -from packaging.version import Version -import logging import datetime import ast +import logging +from packaging.version import Version +import dask.array as da import h5py - import numpy as np -import dask.array as da from 
traits.api import Undefined -from hyperspy.misc.utils import ensure_unicode, get_object_package_info + from hyperspy.axes import AxesManager +from hyperspy.misc.utils import ensure_unicode, get_object_package_info version = "3.1" @@ -21,83 +21,6 @@ _logger = logging.getLogger(__name__) -def _overwrite_dataset(group, - data, - key, - signal_axes=None, - chunks=None, - get_signal_chunks=None, - get_object_dset=None, - store_data=None, - **kwds): - """Overwrites some dataset in a hierarchical dataset. - - Parameters - ---------- - group: Zarr.Group or h5py.Group - The group to write the data to - data: Array-like - The data to be written - key: str - The key for the data - signal_axes: tuple - The indexes of the signal axes - chunks: tuple - The chunks for the dataset. If None then get_signal_chunks will be called - get_signal_chunks: func - A function to get the signal chunks for the dataset - get_object_dset: func - A function to get the object dset for saving ragged arrays. - store_data: func - A function to store the data in some hierarchical data format - kwds: - Any additional keywords for to be passed to the store data function - The store data function is passed for each hierarchical data format - Returns - ------- - - """ - if chunks is None: - if isinstance(data, da.Array): - # For lazy dataset, by default, we use the current dask chunking - chunks = tuple([c[0] for c in data.chunks]) - else: - # If signal_axes=None, use automatic h5py chunking, otherwise - # optimise the chunking to contain at least one signal per chunk - chunks = get_signal_chunks(data.shape, data.dtype, signal_axes) - if np.issubdtype(data.dtype, np.dtype('U')): - # Saving numpy unicode type is not supported in h5py - data = data.astype(np.dtype('S')) - if data.dtype == np.dtype('O'): - dset = get_object_dset(group, data, key, chunks, **kwds) - - else: - got_data = False - while not got_data: - try: - these_kwds = kwds.copy() - these_kwds.update(dict(shape=data.shape, - dtype=data.dtype, - exact=True, - chunks=chunks, - )) - - # If chunks is True, the `chunks` attribute of `dset` below - # contains the chunk shape guessed by h5py - dset = group.require_dataset(key, **these_kwds) - got_data = True - except TypeError: - # if the shape or dtype/etc do not match, - # we delete the old one and create new in the next loop run - del group[key] - if dset == data: - # just a reference to already created thing - pass - else: - _logger.info(f"Chunks used for saving: {chunks}") - store_data(data, dset, group, key, chunks, **kwds) - - class HierarchicalReader: """A generic Reader class for reading data from hierarchical file types.""" def __init__(self, @@ -535,7 +458,86 @@ def __init__(self, self.unicode_kwds = None self.ragged_kwds = None self.kwds = kwds - self.overwrite_dataset=None + + @staticmethod + def _get_signal_chunks(): + raise NotImplementedError( + "This method must be implemented by subclasses.") + + @staticmethod + def _get_object_dset(): + raise NotImplementedError( + "This method must be implemented by subclasses.") + + @staticmethod + def _store_data(): + raise NotImplementedError( + "This method must be implemented by subclasses.") + + @classmethod + def overwrite_dataset(cls, group, data, key, signal_axes=None, + chunks=None, **kwds): + """Overwrites some dataset in a hierarchical dataset. 
+ + Parameters + ---------- + group: Zarr.Group or h5py.Group + The group to write the data to + data: Array-like + The data to be written + key: str + The key for the data + signal_axes: tuple + The indexes of the signal axes + chunks: tuple + The chunks for the dataset. If None then get_signal_chunks will be + called + kwds: + Any additional keywords for to be passed to the store data function + The store data function is passed for each hierarchical data format + """ + if chunks is None: + if isinstance(data, da.Array): + # For lazy dataset, by default, we use the current dask chunking + chunks = tuple([c[0] for c in data.chunks]) + else: + # If signal_axes=None, use automatic h5py chunking, otherwise + # optimise the chunking to contain at least one signal per chunk + chunks = cls._get_signal_chunks( + data.shape, data.dtype, signal_axes + ) + if np.issubdtype(data.dtype, np.dtype('U')): + # Saving numpy unicode type is not supported in h5py + data = data.astype(np.dtype('S')) + if data.dtype == np.dtype('O'): + dset = cls._get_object_dset(group, data, key, chunks, **kwds) + + else: + got_data = False + while not got_data: + try: + these_kwds = kwds.copy() + these_kwds.update(dict(shape=data.shape, + dtype=data.dtype, + exact=True, + chunks=chunks, + )) + + # If chunks is True, the `chunks` attribute of `dset` below + # contains the chunk shape guessed by h5py + dset = group.require_dataset(key, **these_kwds) + got_data = True + except TypeError: + # if the shape or dtype/etc do not match, + # we delete the old one and create new in the next loop run + del group[key] + if dset == data: + # just a reference to already created thing + pass + else: + _logger.info(f"Chunks used for saving: {chunks}") + cls._store_data(data, dset, group, key, chunks, **kwds) + def write(self): self.write_signal(self.signal, diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index a8fe412ebc..911cc47128 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -19,13 +19,15 @@ from packaging.version import Version import warnings import logging -from functools import partial + + +import dask.array as da import h5py import numpy as np -import dask.array as da -from h5py import Dataset, File, Group -from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader, _overwrite_dataset, version +from hyperspy.io_plugins._hierarchical import ( + HierarchicalWriter, HierarchicalReader, version + ) from hyperspy.misc.utils import multiply _logger = logging.getLogger(__name__) @@ -68,7 +70,7 @@ # The experiment group contains a number of attributes that will be # directly assigned as class attributes of the Signal instance. In # addition the experiment groups may contain 'original_metadata' and -# 'metadata'-subgroup that will be assigned to the same name attributes +# 'metadata'-subgroup that will be assigned to the same name attributes # of the Signal instance as a Dictionary Browser. # The Experiments group can contain attributes that may be common to all # the experiments and that will be accessible as attributes of the @@ -102,9 +104,7 @@ default_version = Version(version) -def get_signal_chunks(shape, - dtype, - signal_axes=None): +def get_signal_chunks(shape, dtype, signal_axes=None): """Function that calculates chunks for the signal, preferably at least one chunk per signal space. 
@@ -150,53 +150,11 @@ def get_signal_chunks(shape, return tuple(int(x) for x in chunks) -def _get_object_dset(group, data, key, chunks, **kwds): - """Creates a Dataset or Zarr Array object for saving ragged data - Parameters - ---------- - self - group - data - key - chunks - kwds - - Returns - ------- - - """ - # For saving ragged array - if chunks is None: - chunks = 1 - dset = group.require_dataset(key, - chunks, - dtype=h5py.special_dtype(vlen=data[0].dtype), - **kwds) - return dset - - -def _store_data(data, dset, group, key, chunks, **kwds): - if isinstance(data, da.Array): - if data.chunks != dset.chunks: - data = data.rechunk(dset.chunks) - da.store(data, dset) - elif data.flags.c_contiguous: - dset.write_direct(data) - else: - dset[:] = data - - -overwrite_dataset = partial(_overwrite_dataset, - get_signal_chunks=get_signal_chunks, - get_object_dset=_get_object_dset, - store_data=_store_data) - - class HyperspyReader(HierarchicalReader): def __init__(self, file): super().__init__(file) - self.Dataset = Dataset - self.Group = Group + self.Dataset = h5py.Dataset + self.Group = h5py.Group self.unicode_kwds = {"dtype": h5py.special_dtype(vlen=str)} @@ -213,11 +171,38 @@ def __init__(self, signal, expg, **kwds) - self.Dataset = Dataset - self.Group = Group + self.Dataset = h5py.Dataset + self.Group = h5py.Group self.unicode_kwds = {"dtype": h5py.special_dtype(vlen=str)} self.ragged_kwds = {"dtype": h5py.special_dtype(vlen=signal.data[0].dtype)} - self.overwrite_dataset = overwrite_dataset + + @staticmethod + def _get_signal_chunks(shape, dtype, signal_axes=None): + return get_signal_chunks(shape, dtype, signal_axes) + + @staticmethod + def _store_data(data, dset, group, key, chunks, **kwds): + if isinstance(data, da.Array): + if data.chunks != dset.chunks: + data = data.rechunk(dset.chunks) + da.store(data, dset) + elif data.flags.c_contiguous: + dset.write_direct(data) + else: + dset[:] = data + + @staticmethod + def _get_object_dset(group, data, key, chunks, **kwds): + """Creates a h5py dataset object for saving ragged data""" + # For saving ragged array + if chunks is None: + chunks = 1 + dset = group.require_dataset(key, + chunks, + dtype=h5py.special_dtype(vlen=data[0].dtype), + **kwds) + return dset + def file_reader( filename, @@ -238,7 +223,7 @@ def file_reader( except ImportError: pass mode = kwds.pop('mode', 'r') - f = File(filename, mode=mode, **kwds) + f = h5py.File(filename, mode=mode, **kwds) # Getting the format version here also checks if it is a valid HSpy # hdf5 file, so the following two lines must not be deleted or moved # elsewhere. 
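The `_get_object_dset` helper above stores ragged (object-dtype) data as an h5py variable-length dataset. A minimal round-trip exercising that path, mirroring `test_save_ragged_array` (the filename is a placeholder):

    import numpy as np
    from hyperspy.signal import BaseSignal
    from hyperspy.io import load

    a = np.array([0, 1])
    b = np.array([0, 1, 2])
    # object-dtype (ragged) data goes through the vlen dataset branch
    s = BaseSignal(np.array([a, b], dtype=object)).T
    s.save("ragged_example.hspy", overwrite=True)
    s1 = load("ragged_example.hspy")
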
@@ -267,6 +252,9 @@ def file_writer(filename, signal, *args, **kwds): """ if 'compression' not in kwds: kwds['compression'] = 'gzip' + if "shuffle" not in kwds: + # Use shuffle by default to improve compression + kwds["shuffle"] = True with h5py.File(filename, mode='w') as f: f.attrs['file_format'] = "HyperSpy" f.attrs['file_format_version'] = version @@ -287,8 +275,6 @@ def file_writer(filename, signal, *args, **kwds): else: smd.record_by = "" try: - if "shuffle" not in kwds: - kwds["shuffle"] = True writer = HyperspyWriter(f, signal, expg, **kwds) writer.write() #write_signal(signal, expg, **kwds) diff --git a/hyperspy/io_plugins/nexus.py b/hyperspy/io_plugins/nexus.py index c27514a084..7f4fe2fff7 100644 --- a/hyperspy/io_plugins/nexus.py +++ b/hyperspy/io_plugins/nexus.py @@ -19,15 +19,19 @@ # import logging import warnings -import numpy as np -import dask.array as da import os + +import dask.array as da import h5py +import numpy as np import pprint import traits.api as t -from hyperspy.io_plugins.hspy import overwrite_dataset, get_signal_chunks + +from hyperspy.io_plugins.hspy import HyperspyWriter, get_signal_chunks from hyperspy.misc.utils import DictionaryTreeBrowser from hyperspy.exceptions import VisibleDeprecationWarning + + _logger = logging.getLogger(__name__) # Plugin characteristics @@ -46,6 +50,9 @@ # ---------------------- +overwrite_dataset = HyperspyWriter.overwrite_dataset + + def _byte_to_string(value): """Decode a byte string. diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index e2ccc19c2b..585133eeec 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -18,20 +18,18 @@ import warnings import logging -from functools import partial from packaging.version import Version -import zarr -from zarr import Array, Group -import numpy as np import dask.array as da -from hyperspy.io_plugins.hierarchical import version import numcodecs +import numpy as np +import zarr +from hyperspy.io_plugins._hierarchical import ( + HierarchicalWriter, HierarchicalReader, version + ) -from hyperspy.io_plugins.hierarchical import HierarchicalWriter, HierarchicalReader, _overwrite_dataset - _logger = logging.getLogger(__name__) @@ -76,82 +74,11 @@ # Experiments instance -def get_object_dset(group, data, key, chunks, **kwds): - """Overrides the hyperspy get object dset function for using zarr as the backend - """ - these_kwds = kwds.copy() - these_kwds.update(dict(dtype=object, - exact=True, - chunks=chunks)) - dset = group.require_dataset(key, - data.shape, - object_codec=numcodecs.VLenArray(int), - **these_kwds) - return dset - - -def _get_signal_chunks(shape, dtype, signal_axes=None): - """Function that calculates chunks for the signal, - preferably at least one chunk per signal space. - Parameters - ---------- - shape : tuple - the shape of the dataset to be sored / chunked - dtype : {dtype, string} - the numpy dtype of the data - signal_axes: {None, iterable of ints} - the axes defining "signal space" of the dataset. If None, the default - zarr chunking is performed. 
- """ - typesize = np.dtype(dtype).itemsize - if signal_axes is None: - return None - # chunk size larger than 1 Mb https://zarr.readthedocs.io/en/stable/tutorial.html#chunk-optimizations - # shooting for 100 Mb chunks - total_size = np.prod(shape) * typesize - if total_size < 1e8: # 1 mb - return None - - -def _store_data(data, - dset, - group, - key, - chunks, - **kwds): - """Overrides the hyperspy store data function for using zarr as the backend - """ - if isinstance(data, da.Array): - if data.chunks != dset.chunks: - data = data.rechunk(dset.chunks) - path = group._store.dir_path() + "/" + dset.path - data.to_zarr(url=path, - overwrite=True, - **kwds) # add in compression etc - elif data.dtype == np.dtype('O'): - group[key][:] = data[:] # check lazy - else: - path = group._store.dir_path() + "/" + dset.path - dset = zarr.open_array(path, - mode="w", - shape=data.shape, - dtype=data.dtype, - chunks=chunks, - **kwds) - dset[:] = data - - -overwrite_dataset = partial(_overwrite_dataset, - get_signal_chunks=_get_signal_chunks, - get_object_dset=get_object_dset, - store_data=_store_data) - - class ZspyReader(HierarchicalReader): def __init__(self, file): super(ZspyReader, self).__init__(file) - self.Dataset = Array - self.Group = Group + self.Dataset = zarr.Array + self.Group = zarr.Group class ZspyWriter(HierarchicalWriter): @@ -160,12 +87,70 @@ def __init__(self, signal, expg, **kwargs): super().__init__(file, signal, expg, **kwargs) - self.Dataset = Array + self.Dataset = zarr.Array self.unicode_kwds = {"dtype": object, "object_codec": numcodecs.JSON()} self.ragged_kwds = {"dtype": object, "object_codec": numcodecs.VLenArray(int), "exact": True} - self.overwrite_dataset = overwrite_dataset + + @staticmethod + def _get_signal_chunks(shape, dtype, signal_axes=None): + """Function that calculates chunks for the signal, + preferably at least one chunk per signal space. + + Parameters + ---------- + shape : tuple + the shape of the dataset to be sored / chunked + dtype : {dtype, string} + the numpy dtype of the data + signal_axes: {None, iterable of ints} + the axes defining "signal space" of the dataset. If None, the default + zarr chunking is performed. 
+ """ + typesize = np.dtype(dtype).itemsize + if signal_axes is None: + return None + # chunk size larger than 1 Mb, shooting for 100 Mb chunks, see + # https://zarr.readthedocs.io/en/stable/tutorial.html#chunk-optimizations + total_size = np.prod(shape) * typesize + if total_size < 1e8: # 1 mb + return None + + @staticmethod + def _get_object_dset(group, data, key, chunks, **kwds): + """Creates a Zarr Array object for saving ragged data""" + these_kwds = kwds.copy() + these_kwds.update(dict(dtype=object, + exact=True, + chunks=chunks)) + dset = group.require_dataset(key, + data.shape, + object_codec=numcodecs.VLenArray(int), + **these_kwds) + return dset + + @staticmethod + def _store_data(data, dset, group, key, chunks, **kwds): + """Overrides the hyperspy store data function for using zarr format.""" + if isinstance(data, da.Array): + if data.chunks != dset.chunks: + data = data.rechunk(dset.chunks) + path = group._store.dir_path() + "/" + dset.path + data.to_zarr(url=path, + overwrite=True, + **kwds) # add in compression etc + elif data.dtype == np.dtype('O'): + group[key][:] = data[:] # check lazy + else: + path = group._store.dir_path() + "/" + dset.path + dset = zarr.open_array(path, + mode="w", + shape=data.shape, + dtype=data.dtype, + chunks=chunks, + **kwds) + dset[:] = data def file_writer(filename, diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py index 427785d75e..129ba7ab39 100644 --- a/hyperspy/tests/io/test_zspy.py +++ b/hyperspy/tests/io/test_zspy.py @@ -16,10 +16,7 @@ # You should have received a copy of the GNU General Public License # along with HyperSpy. If not, see . -import gc import os.path -import tempfile -from os import remove import dask.array as da import numpy as np @@ -33,57 +30,42 @@ my_path = os.path.dirname(__file__) -class TestLoadingOOMReadOnly: +def test_lazy_loading_read_only(tmp_path): + s = BaseSignal(np.empty((5, 5, 5))) + fname = tmp_path / 'tmp.zspy' + s.save(fname, overwrite=True) + shape = (10000, 10000, 100) + del s + f = zarr.open(fname, mode='r+') + group = f['Experiments/__unnamed__'] + del group['data'] + group.create_dataset('data', shape=shape, dtype=float, chunks=True) - def setup_method(self, method): - s = BaseSignal(np.empty((5, 5, 5))) - s.save('tmp.zspy', overwrite=True) - self.shape = (10000, 10000, 100) - del s - f = zarr.open('tmp.zspy', mode='r+') - s = f['Experiments/__unnamed__'] - del s['data'] - s.create_dataset( - 'data', - shape=self.shape, - dtype='float64', - chunks=True) - - def test_oom_loading(self): - s = load('tmp.zspy', lazy=True) - assert self.shape == s.data.shape - assert isinstance(s.data, da.Array) - assert s._lazy - s.close_file() - - def teardown_method(self, method): - gc.collect() # Make sure any memmaps are closed first! 
- try: - remove('tmp.zspy') - except BaseException: - # Don't fail tests if we cannot remove - pass + s2 = load(fname, lazy=True) + assert shape == s2.data.shape + assert isinstance(s2.data, da.Array) + assert s2._lazy + s2.close_file() class TestZspy: + @pytest.fixture def signal(self): data = np.ones((10,10,10,10)) s = Signal1D(data) return s - def test_save_N5_type(self,signal): - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/testmodels.zspy' + def test_save_N5_type(self, signal, tmp_path): + filename = tmp_path / 'testmodels.zspy' store = zarr.N5Store(path=filename) signal.save(store.path, write_to_storage=True) signal2 = load(filename) np.testing.assert_array_equal(signal2.data, signal.data) @pytest.mark.parametrize("overwrite",[None, True, False]) - def test_overwrite(self,signal,overwrite): - with tempfile.TemporaryDirectory() as tmp: - filename = tmp + '/testmodels.zspy' + def test_overwrite(self,signal, overwrite, tmp_path): + filename = tmp_path / 'testmodels.zspy' signal.save(filename=filename) signal2 = signal*2 signal2.save(filename=filename, overwrite=overwrite) @@ -94,16 +76,14 @@ def test_overwrite(self,signal,overwrite): else: np.testing.assert_array_equal(signal.data,load(filename).data) - - @pytest.mark.skip(reason="lmdb must be installed to test") - def test_save_lmdb_type(self, signal): - with tempfile.TemporaryDirectory() as tmp: - os.mkdir(tmp+"/testmodels.zspy") - filename = tmp + '/testmodels.zspy/' - store = zarr.LMDBStore(path=filename) - signal.save(store.path, write_to_storage=True) - signal2 = load(store.path) - np.testing.assert_array_equal(signal2.data, signal.data) + def test_save_lmdb_type(self, signal, tmp_path): + pytest.importorskip("lmdb") + path = tmp_path / "testmodels.zspy" + os.mkdir(path) + store = zarr.LMDBStore(path=path) + signal.save(store.path, write_to_storage=True) + signal2 = load(store.path) + np.testing.assert_array_equal(signal2.data, signal.data) def test_compression_opts(self, tmp_path): self.filename = tmp_path / 'testfile.zspy' @@ -116,7 +96,7 @@ def test_compression_opts(self, tmp_path): @pytest.mark.parametrize("compressor", (None, "default", "blosc")) def test_compression(self, compressor, tmp_path): - if compressor is "blosc": + if compressor == "blosc": from numcodecs import Blosc compressor = Blosc(cname='zstd', clevel=3, shuffle=Blosc.BITSHUFFLE) s = Signal1D(np.ones((3, 3))) From 9840bcc7a07a8f8a95b00e3941a73c90c7596efd Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 7 Oct 2021 12:13:04 -0500 Subject: [PATCH 25/31] Replaced get_signal_chunks_function --- hyperspy/io_plugins/_hierarchical.py | 54 ++++++++++++++++++++++++---- hyperspy/io_plugins/emd.py | 4 ++- hyperspy/io_plugins/hspy.py | 51 ++------------------------ hyperspy/io_plugins/nexus.py | 4 +-- hyperspy/io_plugins/zspy.py | 26 ++------------ hyperspy/tests/io/test_hdf5.py | 8 +++++ hyperspy/tests/io/test_zspy.py | 11 +----- 7 files changed, 66 insertions(+), 92 deletions(-) diff --git a/hyperspy/io_plugins/_hierarchical.py b/hyperspy/io_plugins/_hierarchical.py index f1d9cafa71..38dacc2058 100644 --- a/hyperspy/io_plugins/_hierarchical.py +++ b/hyperspy/io_plugins/_hierarchical.py @@ -10,6 +10,7 @@ from hyperspy.axes import AxesManager from hyperspy.misc.utils import ensure_unicode, get_object_package_info +from hyperspy.misc.utils import multiply version = "3.1" @@ -460,9 +461,49 @@ def __init__(self, self.kwds = kwds @staticmethod - def _get_signal_chunks(): - raise NotImplementedError( - "This method must be implemented by subclasses.") + def 
_get_signal_chunks(shape, dtype, signal_axes=None, target_size=1e6): + """Function that calculates chunks for the signal, preferably at least one + chunk per signal space. + + Parameters + ---------- + shape : tuple + the shape of the dataset to be sored / chunked + dtype : {dtype, string} + the numpy dtype of the data + signal_axes: {None, iterable of ints} + the axes defining "signal space" of the dataset. If None, the default + h5py chunking is performed. + target_size: int + The target number of bytes for one chunk + """ + typesize = np.dtype(dtype).itemsize + if signal_axes is None: + return h5py._hl.filters.guess_chunk(shape, None, typesize) + + # largely based on the guess_chunk in h5py + bytes_per_signal = multiply([shape[i] for i in signal_axes]) * typesize + signals_per_chunk = np.floor_divide(target_size,bytes_per_signal) + navigation_axes = tuple(i for i in range(len(shape)) if i not in + signal_axes) + num_nav_axes = len(navigation_axes) + num_signals = np.prod([shape[i] for i in navigation_axes]) + if signals_per_chunk < 2 or num_nav_axes==0: + # signal is larger than chunk max + chunks = [s if i in signal_axes else 1 for i, s in enumerate(shape)] + return tuple(chunks) + elif signals_per_chunk > num_signals: + return shape + else: + # signal is smaller than chunk max + sig_axes_chunk = np.floor(signals_per_chunk**(1/num_nav_axes)) + remainder = np.floor_divide(signals_per_chunk - (sig_axes_chunk**num_nav_axes), + sig_axes_chunk) + if remainder<0: + remainder =0 + chunks = [s if i in signal_axes else sig_axes_chunk for i, s in enumerate(shape)] + chunks[navigation_axes[0]] = chunks[navigation_axes[0]]+remainder + return tuple(int(x) for x in chunks) @staticmethod def _get_object_dset(): @@ -490,8 +531,9 @@ def overwrite_dataset(cls, group, data, key, signal_axes=None, signal_axes: tuple The indexes of the signal axes chunks: tuple - The chunks for the dataset. If None then get_signal_chunks will be - called + The chunks for the dataset. If ``None`` and saving lazy signal, the chunks + of the dask array will be used otherwise the chunks will determine by + ``get_signal_chunks``. 
kwds: Any additional keywords for to be passed to the store data function The store data function is passed for each hierarchical data format @@ -504,7 +546,7 @@ def overwrite_dataset(cls, group, data, key, signal_axes=None, # If signal_axes=None, use automatic h5py chunking, otherwise # optimise the chunking to contain at least one signal per chunk chunks = cls._get_signal_chunks( - data.shape, data.dtype, signal_axes + data.shape, data.dtype, signal_axes, cls.target_size ) if np.issubdtype(data.dtype, np.dtype('U')): # Saving numpy unicode type is not supported in h5py diff --git a/hyperspy/io_plugins/emd.py b/hyperspy/io_plugins/emd.py index aeac092490..b292bb30d0 100644 --- a/hyperspy/io_plugins/emd.py +++ b/hyperspy/io_plugins/emd.py @@ -43,7 +43,9 @@ from hyperspy.exceptions import VisibleDeprecationWarning from hyperspy.misc.elements import atomic_number2name import hyperspy.misc.io.fei_stream_readers as stream_readers -from hyperspy.io_plugins.hspy import get_signal_chunks +from hyperspy.io_plugins.hspy import HyperspyWriter + +get_signal_chunks = HyperspyWriter._get_signal_chunks # Plugin characteristics diff --git a/hyperspy/io_plugins/hspy.py b/hyperspy/io_plugins/hspy.py index 911cc47128..d223ddb885 100644 --- a/hyperspy/io_plugins/hspy.py +++ b/hyperspy/io_plugins/hspy.py @@ -104,52 +104,6 @@ default_version = Version(version) -def get_signal_chunks(shape, dtype, signal_axes=None): - """Function that calculates chunks for the signal, preferably at least one - chunk per signal space. - - Parameters - ---------- - shape : tuple - the shape of the dataset to be sored / chunked - dtype : {dtype, string} - the numpy dtype of the data - signal_axes: {None, iterable of ints} - the axes defining "signal space" of the dataset. If None, the default - h5py chunking is performed. - """ - typesize = np.dtype(dtype).itemsize - if signal_axes is None: - return h5py._hl.filters.guess_chunk(shape, None, typesize) - - # largely based on the guess_chunk in h5py - CHUNK_MAX = 1024 * 1024 - want_to_keep = multiply([shape[i] for i in signal_axes]) * typesize - if want_to_keep >= CHUNK_MAX: - chunks = [1 for _ in shape] - for i in signal_axes: - chunks[i] = shape[i] - return tuple(chunks) - - chunks = [i for i in shape] - idx = 0 - navigation_axes = tuple(i for i in range(len(shape)) if i not in - signal_axes) - nchange = len(navigation_axes) - while True: - chunk_bytes = multiply(chunks) * typesize - - if chunk_bytes < CHUNK_MAX: - break - - if multiply([chunks[i] for i in navigation_axes]) == 1: - break - change = navigation_axes[idx % nchange] - chunks[change] = np.ceil(chunks[change] / 2.0) - idx += 1 - return tuple(int(x) for x in chunks) - - class HyperspyReader(HierarchicalReader): def __init__(self, file): super().__init__(file) @@ -162,6 +116,8 @@ class HyperspyWriter(HierarchicalWriter): """An object used to simplify and orgainize the process for writing a hyperspy signal. 
(.hspy format) """ + target_size = 1e6 + def __init__(self, file, signal, @@ -176,9 +132,6 @@ def __init__(self, self.unicode_kwds = {"dtype": h5py.special_dtype(vlen=str)} self.ragged_kwds = {"dtype": h5py.special_dtype(vlen=signal.data[0].dtype)} - @staticmethod - def _get_signal_chunks(shape, dtype, signal_axes=None): - return get_signal_chunks(shape, dtype, signal_axes) @staticmethod def _store_data(data, dset, group, key, chunks, **kwds): diff --git a/hyperspy/io_plugins/nexus.py b/hyperspy/io_plugins/nexus.py index 7f4fe2fff7..5cf68247c5 100644 --- a/hyperspy/io_plugins/nexus.py +++ b/hyperspy/io_plugins/nexus.py @@ -27,7 +27,7 @@ import pprint import traits.api as t -from hyperspy.io_plugins.hspy import HyperspyWriter, get_signal_chunks +from hyperspy.io_plugins.hspy import HyperspyWriter from hyperspy.misc.utils import DictionaryTreeBrowser from hyperspy.exceptions import VisibleDeprecationWarning @@ -51,7 +51,7 @@ overwrite_dataset = HyperspyWriter.overwrite_dataset - +get_signal_chunks = HyperspyWriter._get_signal_chunks def _byte_to_string(value): """Decode a byte string. diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index 585133eeec..88993db5ec 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -82,6 +82,7 @@ def __init__(self, file): class ZspyWriter(HierarchicalWriter): + target_size = 1e8 def __init__(self, file, signal, @@ -92,30 +93,7 @@ def __init__(self, self.ragged_kwds = {"dtype": object, "object_codec": numcodecs.VLenArray(int), "exact": True} - - @staticmethod - def _get_signal_chunks(shape, dtype, signal_axes=None): - """Function that calculates chunks for the signal, - preferably at least one chunk per signal space. - - Parameters - ---------- - shape : tuple - the shape of the dataset to be sored / chunked - dtype : {dtype, string} - the numpy dtype of the data - signal_axes: {None, iterable of ints} - the axes defining "signal space" of the dataset. If None, the default - zarr chunking is performed. 
- """ - typesize = np.dtype(dtype).itemsize - if signal_axes is None: - return None - # chunk size larger than 1 Mb, shooting for 100 Mb chunks, see - # https://zarr.readthedocs.io/en/stable/tutorial.html#chunk-optimizations - total_size = np.prod(shape) * typesize - if total_size < 1e8: # 1 mb - return None + self.target_size = 1e8 @staticmethod def _get_object_dset(group, data, key, chunks, **kwds): diff --git a/hyperspy/tests/io/test_hdf5.py b/hyperspy/tests/io/test_hdf5.py index 7f6e6ba082..3c96b544a1 100644 --- a/hyperspy/tests/io/test_hdf5.py +++ b/hyperspy/tests/io/test_hdf5.py @@ -41,6 +41,7 @@ from hyperspy.roi import Point2DROI from hyperspy.signal import BaseSignal from hyperspy.utils import markers +from hyperspy.io_plugins._hierarchical import HierarchicalWriter my_path = os.path.dirname(__file__) @@ -847,3 +848,10 @@ def test_chunking_saving_lazy_specify(self, tmp_path, file): s1 = load(filename, lazy=True) assert tuple([c[0] for c in s1.data.chunks]) == chunks +@pytest.mark.parametrize("target_size", (1e6,1e7)) +def test_get_signal_chunks(target_size): + chunks = HierarchicalWriter._get_signal_chunks(shape=[15, 15, 256, 256], + dtype=int, + signal_axes=(2, 3), + target_size=target_size) + assert (np.prod(chunks)*8 < target_size) \ No newline at end of file diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py index 129ba7ab39..6f74681791 100644 --- a/hyperspy/tests/io/test_zspy.py +++ b/hyperspy/tests/io/test_zspy.py @@ -36,7 +36,7 @@ def test_lazy_loading_read_only(tmp_path): s.save(fname, overwrite=True) shape = (10000, 10000, 100) del s - f = zarr.open(fname, mode='r+') + f = zarr.open(fname.name, mode='r+') group = f['Experiments/__unnamed__'] del group['data'] group.create_dataset('data', shape=shape, dtype=float, chunks=True) @@ -76,15 +76,6 @@ def test_overwrite(self,signal, overwrite, tmp_path): else: np.testing.assert_array_equal(signal.data,load(filename).data) - def test_save_lmdb_type(self, signal, tmp_path): - pytest.importorskip("lmdb") - path = tmp_path / "testmodels.zspy" - os.mkdir(path) - store = zarr.LMDBStore(path=path) - signal.save(store.path, write_to_storage=True) - signal2 = load(store.path) - np.testing.assert_array_equal(signal2.data, signal.data) - def test_compression_opts(self, tmp_path): self.filename = tmp_path / 'testfile.zspy' from numcodecs import Blosc From 31d27b073006ed709ac62bbd25b3c2ad6e82295c Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 7 Oct 2021 14:20:27 -0500 Subject: [PATCH 26/31] Allow MutableMapping Objects to be passed in --- doc/user_guide/io.rst | 7 +++-- hyperspy/io.py | 56 +++++++++++++++++++--------------- hyperspy/io_plugins/zspy.py | 11 +++++-- hyperspy/signal.py | 9 ++++-- hyperspy/tests/io/test_zspy.py | 45 +++++++++++++++------------ 5 files changed, 76 insertions(+), 52 deletions(-) diff --git a/doc/user_guide/io.rst b/doc/user_guide/io.rst index e7eb82aed4..605f410c9d 100644 --- a/doc/user_guide/io.rst +++ b/doc/user_guide/io.rst @@ -461,9 +461,10 @@ Extra saving arguments .. note:: - Compression can significantly increase the saving speed. If file size is not - an issue, it can be disabled by setting ``compressor=None``. In general we recommend - compressing your datasets as it can greatly reduce i-o overhead + Lazy operations are often i-o bound, reading and writing the data creates a bottle neck in processes + due to the slow read write speed of many hard disks. In these cases, compressing your data is often + beneficial to the speed of some operation. 
Compression speeds up the process as there is less to + read/write with the trade off of slightly more computational work on the CPU." - ``write_to_storage``: The write to storage option allows you to pass the path to a directory (or database) and write directly to the storage container. This gives you access to the `different storage methods diff --git a/hyperspy/io.py b/hyperspy/io.py index f20081b1a0..3241ce8040 100644 --- a/hyperspy/io.py +++ b/hyperspy/io.py @@ -26,6 +26,7 @@ from natsort import natsorted from inspect import isgenerator from pathlib import Path +from collections import MutableMapping from hyperspy.drawing.marker import markers_metadata_dict_to_markers from hyperspy.exceptions import VisibleDeprecationWarning @@ -327,7 +328,6 @@ def load(filenames=None, lazy = load_ui.lazy if filenames is None: raise ValueError("No file provided to reader") - if isinstance(filenames, str): pattern = filenames if escape_square_brackets: @@ -734,11 +734,14 @@ def save(filename, signal, overwrite=None, **kwds): None """ - filename = Path(filename).resolve() - extension = filename.suffix - if extension == '': - extension = ".hspy" - filename = filename.with_suffix(extension) + if isinstance(filename, MutableMapping): + extension =".zspy" + else: + filename = Path(filename).resolve() + extension = filename.suffix + if extension == '': + extension = ".hspy" + filename = filename.with_suffix(extension) writer = None for plugin in io_plugins: @@ -783,25 +786,30 @@ def save(filename, signal, overwrite=None, **kwds): ) # Create the directory if it does not exist - ensure_directory(filename.parent) - is_file = filename.is_file() or (filename.is_dir() and - os.path.splitext(filename)[1] == '.zspy') - - if overwrite is None: - write = overwrite_method(filename) # Ask what to do - elif overwrite is True or (overwrite is False and not is_file): - write = True # Write the file - elif overwrite is False and is_file: - write = False # Don't write the file + if not isinstance(filename, MutableMapping): + ensure_directory(filename.parent) + is_file = filename.is_file() or (filename.is_dir() and + os.path.splitext(filename)[1] == '.zspy') + + if overwrite is None: + write = overwrite_method(filename) # Ask what to do + elif overwrite is True or (overwrite is False and not is_file): + write = True # Write the file + elif overwrite is False and is_file: + write = False # Don't write the file + else: + raise ValueError("`overwrite` parameter can only be None, True or " + "False.") else: - raise ValueError("`overwrite` parameter can only be None, True or " - "False.") + write=True if write: # Pass as a string for now, pathlib.Path not # properly supported in io_plugins - writer.file_writer(str(filename), signal, **kwds) - - _logger.info(f'{filename} was created') - signal.tmp_parameters.set_item('folder', filename.parent) - signal.tmp_parameters.set_item('filename', filename.stem) - signal.tmp_parameters.set_item('extension', extension) + if not isinstance(filename, MutableMapping): + writer.file_writer(str(filename), signal, **kwds) + _logger.info(f'{filename} was created') + signal.tmp_parameters.set_item('folder', filename.parent) + signal.tmp_parameters.set_item('filename', filename.stem) + signal.tmp_parameters.set_item('extension', extension) + else: + writer.file_writer(filename, signal, **kwds) \ No newline at end of file diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index 88993db5ec..89d1c03459 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -19,6 
+19,7 @@ import warnings import logging from packaging.version import Version +from collections import MutableMapping import dask.array as da import numcodecs @@ -146,11 +147,15 @@ def file_writer(filename, if "compressor" not in kwds: from numcodecs import Blosc kwds["compressor"] = Blosc(cname='zstd', clevel=1) - if "write_to_storage" in kwds and kwds["write_to_storage"]: - f = zarr.open(filename) + if ("write_to_storage" in kwds and kwds["write_to_storage"]) or isinstance(filename, MutableMapping): + if isinstance(filename, MutableMapping): + store = filename.path + else: + store = filename else: store = zarr.storage.NestedDirectoryStore(filename,) - f = zarr.group(store=store, overwrite=True) + f = zarr.group(store=store, overwrite=True, cache_attrs=False) + print(f) f.attrs['file_format'] = "ZSpy" f.attrs['file_format_version'] = version exps = f.create_group('Experiments') diff --git a/hyperspy/signal.py b/hyperspy/signal.py index a67a26f4c6..ddda9b36f6 100644 --- a/hyperspy/signal.py +++ b/hyperspy/signal.py @@ -33,6 +33,7 @@ from matplotlib import pyplot as plt import traits.api as t import numbers +from collections import MutableMapping from hyperspy.axes import AxesManager from hyperspy import io @@ -2856,9 +2857,11 @@ def save(self, filename=None, overwrite=None, extension=None, else: raise ValueError('File name not defined') - filename = Path(filename) - if extension is not None: - filename = filename.with_suffix(f".{extension}") + print(type(filename)) + if not isinstance(filename, MutableMapping): + filename = Path(filename) + if extension is not None: + filename = filename.with_suffix(f".{extension}") io.save(filename, self, overwrite=overwrite, **kwds) def _replot(self): diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py index 6f74681791..738568d3f2 100644 --- a/hyperspy/tests/io/test_zspy.py +++ b/hyperspy/tests/io/test_zspy.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU General Public License # along with HyperSpy. If not, see . 
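In other words, after this change a zarr store object (any MutableMapping) can be handed to save() directly instead of a path. A minimal sketch of that usage, mirroring the tests added later in this patch (the filename is illustrative):

    import numpy as np
    import zarr
    import hyperspy.api as hs

    s = hs.signals.Signal1D(np.arange(100.0).reshape(10, 10))

    # Passing a store rather than a filename routes the signal through the
    # MutableMapping branch of io.save() and the zspy writer.
    store = zarr.N5Store("testmodels.zspy")
    s.save(store, write_to_storage=True)

    s2 = hs.load("testmodels.zspy")
    np.testing.assert_array_equal(s2.data, s.data)

A numcodecs compressor can be supplied through the same call, for example ``compressor=Blosc(cname='zstd', clevel=1)``, which is also the default the writer falls back to above.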
-import os.path +import os import dask.array as da import numpy as np @@ -30,24 +30,6 @@ my_path = os.path.dirname(__file__) -def test_lazy_loading_read_only(tmp_path): - s = BaseSignal(np.empty((5, 5, 5))) - fname = tmp_path / 'tmp.zspy' - s.save(fname, overwrite=True) - shape = (10000, 10000, 100) - del s - f = zarr.open(fname.name, mode='r+') - group = f['Experiments/__unnamed__'] - del group['data'] - group.create_dataset('data', shape=shape, dtype=float, chunks=True) - - s2 = load(fname, lazy=True) - assert shape == s2.data.shape - assert isinstance(s2.data, da.Array) - assert s2._lazy - s2.close_file() - - class TestZspy: @pytest.fixture @@ -57,6 +39,13 @@ def signal(self): return s def test_save_N5_type(self, signal, tmp_path): + filename = tmp_path / 'testmodels.zspy' + store = zarr.N5Store(path=filename) + signal.save(store, write_to_storage=True) + signal2 = load(filename) + np.testing.assert_array_equal(signal2.data, signal.data) + + def test_save_N5_type_path(self, signal, tmp_path): filename = tmp_path / 'testmodels.zspy' store = zarr.N5Store(path=filename) signal.save(store.path, write_to_storage=True) @@ -95,3 +84,21 @@ def test_compression(self, compressor, tmp_path): overwrite=True, compressor=compressor) load(tmp_path / 'test_compression.zspy') + + def test_lazy_loading_read_only(self, tmp_path): + s = BaseSignal(np.ones((5, 5, 5))) + + fname = tmp_path / 'tmp.zspy' + s.save(fname, overwrite=True) + shape = (10000, 10000, 100) + del s + f = zarr.open(fname.name, mode='r+') + group = f['Experiments/__unnamed__'] + del group['data'] + group.create_dataset('data', shape=shape, dtype=float, chunks=True, overwrite=True) + + s2 = load(fname, lazy=True) + assert shape == s2.data.shape + assert isinstance(s2.data, da.Array) + assert s2._lazy + s2.close_file() From 0e98caefdbf3578186c4a586ce7abeb76e5f2214 Mon Sep 17 00:00:00 2001 From: cssfrancis Date: Mon, 18 Oct 2021 13:59:54 -0500 Subject: [PATCH 27/31] Cleaning up print statements and saving output --- hyperspy/io.py | 9 +++++++-- hyperspy/io_plugins/zspy.py | 3 +-- hyperspy/signal.py | 1 - hyperspy/tests/io/test_zspy.py | 15 ++++----------- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/hyperspy/io.py b/hyperspy/io.py index 3241ce8040..3f300f8602 100644 --- a/hyperspy/io.py +++ b/hyperspy/io.py @@ -801,7 +801,7 @@ def save(filename, signal, overwrite=None, **kwds): raise ValueError("`overwrite` parameter can only be None, True or " "False.") else: - write=True + write = True # file does not exist (creating it) if write: # Pass as a string for now, pathlib.Path not # properly supported in io_plugins @@ -812,4 +812,9 @@ def save(filename, signal, overwrite=None, **kwds): signal.tmp_parameters.set_item('filename', filename.stem) signal.tmp_parameters.set_item('extension', extension) else: - writer.file_writer(filename, signal, **kwds) \ No newline at end of file + writer.file_writer(filename, signal, **kwds) + if hasattr(filename, "path"): + file = Path(filename.path).resolve() + signal.tmp_parameters.set_item('folder', file.parent) + signal.tmp_parameters.set_item('filename', file.stem) + signal.tmp_parameters.set_item('extension', extension) diff --git a/hyperspy/io_plugins/zspy.py b/hyperspy/io_plugins/zspy.py index 89d1c03459..31556803ef 100644 --- a/hyperspy/io_plugins/zspy.py +++ b/hyperspy/io_plugins/zspy.py @@ -154,8 +154,7 @@ def file_writer(filename, store = filename else: store = zarr.storage.NestedDirectoryStore(filename,) - f = zarr.group(store=store, overwrite=True, cache_attrs=False) - print(f) + 
f = zarr.group(store=store, overwrite=True) f.attrs['file_format'] = "ZSpy" f.attrs['file_format_version'] = version exps = f.create_group('Experiments') diff --git a/hyperspy/signal.py b/hyperspy/signal.py index ddda9b36f6..edc0317c27 100644 --- a/hyperspy/signal.py +++ b/hyperspy/signal.py @@ -2857,7 +2857,6 @@ def save(self, filename=None, overwrite=None, extension=None, else: raise ValueError('File name not defined') - print(type(filename)) if not isinstance(filename, MutableMapping): filename = Path(filename) if extension is not None: diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py index 738568d3f2..7eea1f71e4 100644 --- a/hyperspy/tests/io/test_zspy.py +++ b/hyperspy/tests/io/test_zspy.py @@ -85,20 +85,13 @@ def test_compression(self, compressor, tmp_path): compressor=compressor) load(tmp_path / 'test_compression.zspy') - def test_lazy_loading_read_only(self, tmp_path): + def test_overwrite(self, tmp_path): s = BaseSignal(np.ones((5, 5, 5))) fname = tmp_path / 'tmp.zspy' s.save(fname, overwrite=True) - shape = (10000, 10000, 100) - del s - f = zarr.open(fname.name, mode='r+') - group = f['Experiments/__unnamed__'] - del group['data'] - group.create_dataset('data', shape=shape, dtype=float, chunks=True, overwrite=True) + shape = (10, 10, 10) + s2 = BaseSignal(np.ones(shape)) + s2.save(fname, overwrite=True) - s2 = load(fname, lazy=True) assert shape == s2.data.shape - assert isinstance(s2.data, da.Array) - assert s2._lazy - s2.close_file() From 88b38bcc74a3382fac85281cb7863246fc465ea5 Mon Sep 17 00:00:00 2001 From: Carter Francis Date: Tue, 19 Oct 2021 14:07:58 -0500 Subject: [PATCH 28/31] Update hyperspy/io_plugins/_hierarchical.py Co-authored-by: Eric Prestat --- hyperspy/io_plugins/_hierarchical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hyperspy/io_plugins/_hierarchical.py b/hyperspy/io_plugins/_hierarchical.py index 38dacc2058..9778e9306c 100644 --- a/hyperspy/io_plugins/_hierarchical.py +++ b/hyperspy/io_plugins/_hierarchical.py @@ -50,7 +50,7 @@ def get_format_version(self): elif "Experiments" in self.file: # Chances are that this is a HSpy hdf5 file version 1.0 version = "1.0" - elif "Analysis" in self.f: + elif "Analysis" in self.file: # Starting version 2.0 we have "Analysis" field as well version = "2.0" else: From cb459582704ffe46ceffa3b7c5ded1ad5eb76422 Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 19 Oct 2021 16:30:57 -0500 Subject: [PATCH 29/31] Specify dtype for testing on 32bit operating systems --- hyperspy/io_plugins/_hierarchical.py | 2 +- hyperspy/tests/io/test_hdf5.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hyperspy/io_plugins/_hierarchical.py b/hyperspy/io_plugins/_hierarchical.py index 38dacc2058..f7d78a5cb4 100644 --- a/hyperspy/io_plugins/_hierarchical.py +++ b/hyperspy/io_plugins/_hierarchical.py @@ -483,7 +483,7 @@ def _get_signal_chunks(shape, dtype, signal_axes=None, target_size=1e6): # largely based on the guess_chunk in h5py bytes_per_signal = multiply([shape[i] for i in signal_axes]) * typesize - signals_per_chunk = np.floor_divide(target_size,bytes_per_signal) + signals_per_chunk = np.floor_divide(target_size, bytes_per_signal) navigation_axes = tuple(i for i in range(len(shape)) if i not in signal_axes) num_nav_axes = len(navigation_axes) diff --git a/hyperspy/tests/io/test_hdf5.py b/hyperspy/tests/io/test_hdf5.py index 3c96b544a1..273d7128d0 100644 --- a/hyperspy/tests/io/test_hdf5.py +++ b/hyperspy/tests/io/test_hdf5.py @@ -851,7 +851,7 @@ def 
test_chunking_saving_lazy_specify(self, tmp_path, file): @pytest.mark.parametrize("target_size", (1e6,1e7)) def test_get_signal_chunks(target_size): chunks = HierarchicalWriter._get_signal_chunks(shape=[15, 15, 256, 256], - dtype=int, + dtype=np.int64, signal_axes=(2, 3), target_size=target_size) assert (np.prod(chunks)*8 < target_size) \ No newline at end of file From 3132e225d3127b0cafa4252161db456cd24f4c7c Mon Sep 17 00:00:00 2001 From: Eric Prestat Date: Wed, 20 Oct 2021 10:47:20 +0100 Subject: [PATCH 30/31] Add fsspec to conda environment and setup.py since this is not a dask dependency for dask version < 2021.3.1 --- conda_environment.yml | 4 +++- setup.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/conda_environment.yml b/conda_environment.yml index 4abc06a009..340cac56e9 100644 --- a/conda_environment.yml +++ b/conda_environment.yml @@ -2,8 +2,10 @@ name: test_env channels: - conda-forge dependencies: -- dask-core >2.0 +- dask-core >2.1 - dill +# dask-core < 2021.3.1 doesn't have fsspec as dependency +- fsspec - h5py - imageio - ipyparallel diff --git a/setup.py b/setup.py index b5b3b56cf2..790981ef08 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,8 @@ 'python-dateutil>=2.5.0', 'ipyparallel', 'dask[array]>2.1.0', + # fsspec is missing from dask dependencies for dask < 2021.3.1 + 'fsspec', 'scikit-image>=0.15', 'pint>=0.10', 'numexpr', From 2b236f5537ea7786f062974932ec610c62dac8b4 Mon Sep 17 00:00:00 2001 From: Eric Prestat Date: Wed, 20 Oct 2021 12:13:48 +0100 Subject: [PATCH 31/31] Install zarr on x86_64 machines until pypi and conda packages are available for numcodecs on other platforms. Skip zspy tests when zarr not installed --- hyperspy/io_plugins/__init__.py | 11 +++- hyperspy/tests/io/test_hdf5.py | 98 +++++++++++++++++---------------- hyperspy/tests/io/test_zspy.py | 18 +----- setup.py | 3 +- 4 files changed, 65 insertions(+), 65 deletions(-) diff --git a/hyperspy/io_plugins/__init__.py b/hyperspy/io_plugins/__init__.py index 3e19e898f4..689c03ed97 100644 --- a/hyperspy/io_plugins/__init__.py +++ b/hyperspy/io_plugins/__init__.py @@ -40,7 +40,6 @@ semper_unf, sur, tiff, - zspy, ) io_plugins = [ @@ -64,7 +63,6 @@ semper_unf, sur, tiff, - zspy, ] @@ -100,6 +98,15 @@ "the mrcz package is not installed." ) +try: + from hyperspy.io_plugins import zspy + + io_plugins.append(zspy) +except ImportError: + _logger.info( + "The zspy IO plugin is not available because " + "the zarr package is not installed." 
+ ) default_write_ext = set() for plugin in io_plugins: diff --git a/hyperspy/tests/io/test_hdf5.py b/hyperspy/tests/io/test_hdf5.py index 273d7128d0..35c5fdb3e1 100644 --- a/hyperspy/tests/io/test_hdf5.py +++ b/hyperspy/tests/io/test_hdf5.py @@ -19,7 +19,6 @@ import gc import os.path import sys -import tempfile import time from os import remove @@ -35,17 +34,24 @@ from hyperspy._signals.signal2d import Signal2D from hyperspy.datasets.example_signals import EDS_TEM_Spectrum from hyperspy.exceptions import VisibleDeprecationWarning -from hyperspy.io import load from hyperspy.misc.test_utils import assert_deep_almost_equal from hyperspy.misc.test_utils import sanitize_dict as san_dict from hyperspy.roi import Point2DROI -from hyperspy.signal import BaseSignal from hyperspy.utils import markers from hyperspy.io_plugins._hierarchical import HierarchicalWriter + my_path = os.path.dirname(__file__) +try: + # zarr (because of numcodecs) is only supported on x86_64 machines + import zarr + zspy_marker = pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) +except ImportError: + zspy_marker = pytest.mark.parametrize("file", ["test.hspy"]) + + data = np.array([4066., 3996., 3932., 3923., 5602., 5288., 7234., 7809., 4710., 5015., 4366., 4524., 4832., 5474., 5718., 5034., 4651., 4613., 4637., 4429., 4217.]) @@ -195,7 +201,7 @@ class TestSavingMetadataContainers: def setup_method(self, method): self.s = BaseSignal([0.1]) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_unicode(self, tmp_path, file): s = self.s s.metadata.set_item('test', ['a', 'b', '\u6f22\u5b57']) @@ -207,7 +213,7 @@ def test_save_unicode(self, tmp_path, file): assert isinstance(l.metadata.test[2], str) assert l.metadata.test[2] == '\u6f22\u5b57' - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_long_list(self, tmp_path, file): s = self.s s.metadata.set_item('long_list', list(range(10000))) @@ -217,7 +223,7 @@ def test_save_long_list(self, tmp_path, file): end = time.time() assert end - start < 1.0 # It should finish in less that 1 s. 
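The conditional parametrization used for ``zspy_marker`` above can be reused anywhere both backends should be exercised when zarr happens to be installed. A sketch of that pattern (the helper and test names here are illustrative, not part of the patch):

    import numpy as np
    import pytest

    from hyperspy.io import load
    from hyperspy.signal import BaseSignal

    try:
        # numcodecs (and therefore zarr) is only available on x86_64/AMD64,
        # so fall back to hspy-only parametrization when it is missing.
        import zarr  # noqa: F401
        both_formats = pytest.mark.parametrize("file", ["test.hspy", "test.zspy"])
    except ImportError:
        both_formats = pytest.mark.parametrize("file", ["test.hspy"])

    @both_formats
    def test_roundtrip(tmp_path, file):
        s = BaseSignal(np.arange(10.0))
        s.save(tmp_path / file, overwrite=True)
        np.testing.assert_array_equal(load(tmp_path / file).data, s.data)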
- @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_numpy_only_inner_lists(self, tmp_path, file): s = self.s s.metadata.set_item('test', [[1., 2], ('3', 4)]) @@ -230,7 +236,7 @@ def test_numpy_only_inner_lists(self, tmp_path, file): @pytest.mark.xfail(sys.platform == 'win32', reason="randomly fails in win32") - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_numpy_general_type(self, tmp_path, file): s = self.s s.metadata.set_item('test', np.array([[1., 2], ['3', 4]])) @@ -241,7 +247,7 @@ def test_numpy_general_type(self, tmp_path, file): @pytest.mark.xfail(sys.platform == 'win32', reason="randomly fails in win32") - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_list_general_type(self, tmp_path, file): s = self.s s.metadata.set_item('test', [[1., 2], ['3', 4]]) @@ -255,7 +261,7 @@ def test_list_general_type(self, tmp_path, file): @pytest.mark.xfail(sys.platform == 'win32', reason="randomly fails in win32") - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_general_type_not_working(self, tmp_path, file): s = self.s s.metadata.set_item('test', (BaseSignal([1]), 0.1, 'test_string')) @@ -267,7 +273,7 @@ def test_general_type_not_working(self, tmp_path, file): assert isinstance(l.metadata.test[1], float) assert isinstance(l.metadata.test[2], str) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_unsupported_type(self, tmp_path, file): s = self.s s.metadata.set_item('test', Point2DROI(1, 2)) @@ -276,7 +282,7 @@ def test_unsupported_type(self, tmp_path, file): l = load(fname) assert 'test' not in l.metadata - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_date_time(self, tmp_path, file): s = self.s date, time = "2016-08-05", "15:00:00.450" @@ -288,7 +294,7 @@ def test_date_time(self, tmp_path, file): assert l.metadata.General.date == date assert l.metadata.General.time == time - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_general_metadata(self, tmp_path, file): s = self.s notes = "Dummy notes" @@ -304,7 +310,7 @@ def test_general_metadata(self, tmp_path, file): assert l.metadata.General.authors == authors assert l.metadata.General.doi == doi - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_quantity(self, tmp_path, file): s = self.s quantity = "Intensity (electron)" @@ -314,7 +320,7 @@ def test_quantity(self, tmp_path, file): l = load(fname) assert l.metadata.Signal.quantity == quantity - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_axes_manager(self, tmp_path, file): s = self.s s.metadata.set_item('test', s.axes_manager) @@ -324,7 +330,7 @@ def test_save_axes_manager(self, tmp_path, file): #strange becuase you need the encoding... assert isinstance(l.metadata.test, AxesManager) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_title(self, tmp_path, file): s = self.s fname = tmp_path / file @@ -333,7 +339,7 @@ def test_title(self, tmp_path, file): l = load(fname) assert l.metadata.General.title is "" - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_empty_tuple(self, tmp_path, file): s = self.s s.metadata.set_item('test', ()) @@ -343,7 +349,7 @@ def test_save_empty_tuple(self, tmp_path, file): #strange becuase you need the encoding... 
assert l.metadata.test == s.metadata.test - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_bytes(self, tmp_path,file): s = self.s byte_message = bytes("testing", 'utf-8') @@ -358,7 +364,7 @@ def test_metadata_binned_deprecate(self): s = load(os.path.join(my_path, "hdf5_files", 'example2_v2.2.hspy')) assert s.metadata.has_item('Signal.binned') == False assert s.axes_manager[-1].is_binned == False - + def test_metadata_update_to_v3_1(self): md = {'Acquisition_instrument': {'SEM': {'Stage': {'tilt_alpha': 5.0}}, @@ -399,7 +405,7 @@ def test_rgba16(): data = np.load(os.path.join( my_path, "npy_files", "test_rgba16.npy")) assert (s.data == data).all() -@pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) +@zspy_marker def test_nonuniformaxis(tmp_path, file): fname = tmp_path / file data = np.arange(10) @@ -407,13 +413,13 @@ def test_nonuniformaxis(tmp_path, file): s = Signal1D(data, axes=(axis.get_axis_dictionary(), )) s.save(fname, overwrite=True) s2 = load(fname) - np.testing.assert_array_almost_equal(s.axes_manager[0].axis, + np.testing.assert_array_almost_equal(s.axes_manager[0].axis, s2.axes_manager[0].axis) assert(s2.axes_manager[0].is_uniform == False) assert(s2.axes_manager[0].navigate == False) assert(s2.axes_manager[0].size == data.size) -@pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) +@zspy_marker def test_nonuniformFDA(tmp_path, file): fname = tmp_path / file data = np.arange(10) @@ -423,13 +429,13 @@ def test_nonuniformFDA(tmp_path, file): print(axis.get_axis_dictionary()) s.save(fname, overwrite=True) s2 = load(fname) - np.testing.assert_array_almost_equal(s.axes_manager[0].axis, + np.testing.assert_array_almost_equal(s.axes_manager[0].axis, s2.axes_manager[0].axis) assert(s2.axes_manager[0].is_uniform == False) assert(s2.axes_manager[0].navigate == False) - assert(s2.axes_manager[0].size == data.size) + assert(s2.axes_manager[0].size == data.size) + - class TestLoadingOOMReadOnly: def setup_method(self, method): @@ -481,7 +487,7 @@ def teardown_method(self, method): class TestAxesConfiguration: - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_axes_binning(self, tmp_path, file): fname = tmp_path / file s = BaseSignal(np.zeros((2, 2, 2))) @@ -490,7 +496,7 @@ def test_axes_binning(self, tmp_path, file): s = load(fname) assert s.axes_manager.signal_axes[-1].is_binned - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_axes_configuration(self, tmp_path, file): fname = tmp_path / file s = BaseSignal(np.zeros((2, 2, 2, 2, 2))) @@ -505,7 +511,7 @@ def test_axes_configuration(self, tmp_path, file): class Test_permanent_markers_io: - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_permanent_marker(self, tmp_path, file): filename = tmp_path / file s = Signal2D(np.arange(100).reshape(10, 10)) @@ -513,7 +519,7 @@ def test_save_permanent_marker(self, tmp_path, file): s.add_marker(m, permanent=True) s.save(filename) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_empty_metadata_markers(self, tmp_path, file): filename = tmp_path / file s = Signal2D(np.arange(100).reshape(10, 10)) @@ -525,7 +531,7 @@ def test_save_load_empty_metadata_markers(self, tmp_path, file): s1 = load(filename) assert len(s1.metadata.Markers) == 0 - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_permanent_marker(self, tmp_path, file): filename = 
tmp_path / file x, y = 5, 2 @@ -546,7 +552,7 @@ def test_save_load_permanent_marker(self, tmp_path, file): assert m1.marker_properties['color'] == color assert m1.name == name - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_permanent_marker_all_types(self, tmp_path, file): filename = tmp_path / file x1, y1, x2, y2 = 5, 2, 1, 8 @@ -576,7 +582,7 @@ def test_save_load_permanent_marker_all_types(self, tmp_path, file): for m0_dict, m1_dict in zip(m0_dict_list, m1_dict_list): assert m0_dict == m1_dict - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_horizontal_line_marker(self,tmp_path,file): filename = tmp_path / file y = 8 @@ -592,7 +598,7 @@ def test_save_load_horizontal_line_marker(self,tmp_path,file): m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_horizontal_line_segment_marker(self, tmp_path, file): filename = tmp_path / file x1, x2, y = 1, 5, 8 @@ -609,7 +615,7 @@ def test_save_load_horizontal_line_segment_marker(self, tmp_path, file): m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_vertical_line_marker(self, tmp_path, file): filename = tmp_path / file x = 9 @@ -625,7 +631,7 @@ def test_save_load_vertical_line_marker(self, tmp_path, file): m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_vertical_line_segment_marker(self, tmp_path, file): filename = tmp_path / file x, y1, y2 = 2, 1, 3 @@ -642,7 +648,7 @@ def test_save_load_vertical_line_segment_marker(self, tmp_path, file): m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_line_segment_marker(self, tmp_path, file): filename = tmp_path / file x1, x2, y1, y2 = 1, 9, 4, 7 @@ -659,7 +665,7 @@ def test_save_load_line_segment_marker(self, tmp_path, file): m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_point_marker(self, tmp_path, file): filename = tmp_path / file x, y = 9, 8 @@ -675,7 +681,7 @@ def test_save_load_point_marker(self, tmp_path, file): m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_rectangle_marker(self, tmp_path, file): filename = tmp_path / file x1, x2, y1, y2 = 2, 4, 1, 3 @@ -692,7 +698,7 @@ def test_save_load_rectangle_marker(self, tmp_path, file): m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == san_dict(m._to_dictionary()) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_text_marker(self, tmp_path, file): filename = tmp_path / file x, y = 3, 9.5 @@ -709,7 +715,7 @@ def test_save_load_text_marker(self, tmp_path, file): m1 = s1.metadata.Markers.get_item(name) assert san_dict(m1._to_dictionary()) == 
san_dict(m._to_dictionary()) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_load_multidim_navigation_marker(self, tmp_path, file): filename = tmp_path / file x, y = (1, 2, 3), (5, 6, 7) @@ -754,7 +760,7 @@ def test_load_missing_y2_value(self): assert len(s.metadata.Markers) == 5 -@pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) +@zspy_marker def test_save_load_model(tmp_path, file): from hyperspy._components.gaussian import Gaussian filename = tmp_path / file @@ -780,7 +786,7 @@ def test_strings_from_py2(): s = EDS_TEM_Spectrum() assert isinstance(s.metadata.Sample.elements, list) -@pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) +@zspy_marker def test_save_ragged_array(tmp_path, file): a = np.array([0, 1]) b = np.array([0, 1, 2]) @@ -802,7 +808,7 @@ def test_load_missing_extension(caplog): _ = s.models.restore("a") class TestChunking: - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_save_chunks_signal_metadata(self, tmp_path, file): N = 10 dim = 3 @@ -816,7 +822,7 @@ def test_save_chunks_signal_metadata(self, tmp_path, file): s2 = load(filename, lazy=True) assert tuple([c[0] for c in s2.data.chunks]) == chunks - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_chunking_saving_lazy(self, tmp_path, file): filename = tmp_path / file s = Signal2D(da.zeros((50, 100, 100))).as_lazy() @@ -825,7 +831,7 @@ def test_chunking_saving_lazy(self, tmp_path, file): s1 = load(filename, lazy=True) assert s.data.chunks == s1.data.chunks - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_chunking_saving_lazy_True(self, tmp_path, file): filename = tmp_path / file s = Signal2D(da.zeros((50, 100, 100))).as_lazy() @@ -837,7 +843,7 @@ def test_chunking_saving_lazy_True(self, tmp_path, file): else: assert tuple([c[0] for c in s1.data.chunks]) == (25, 50, 50) - @pytest.mark.parametrize("file", ["test.hspy", "test.zspy"]) + @zspy_marker def test_chunking_saving_lazy_specify(self, tmp_path, file): filename = tmp_path / file s = Signal2D(da.zeros((50, 100, 100))).as_lazy() @@ -854,4 +860,4 @@ def test_get_signal_chunks(target_size): dtype=np.int64, signal_axes=(2, 3), target_size=target_size) - assert (np.prod(chunks)*8 < target_size) \ No newline at end of file + assert (np.prod(chunks)*8 < target_size) diff --git a/hyperspy/tests/io/test_zspy.py b/hyperspy/tests/io/test_zspy.py index 7eea1f71e4..247110541b 100644 --- a/hyperspy/tests/io/test_zspy.py +++ b/hyperspy/tests/io/test_zspy.py @@ -16,18 +16,15 @@ # You should have received a copy of the GNU General Public License # along with HyperSpy. If not, see . 
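Following the chunking tests above, an explicit chunk shape can apparently be forwarded straight through save() to the writer. A short usage sketch under that assumption (the chunk values are illustrative):

    import dask.array as da
    from hyperspy._signals.signal2d import Signal2D
    from hyperspy.io import load

    s = Signal2D(da.zeros((50, 100, 100))).as_lazy()
    s.save("stack.zspy", chunks=(5, 25, 25), overwrite=True)  # explicit chunking

    s1 = load("stack.zspy", lazy=True)
    # Each stored chunk should now have the requested shape.
    print(tuple(c[0] for c in s1.data.chunks))  # (5, 25, 25)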
-import os - -import dask.array as da import numpy as np import pytest -import zarr from hyperspy._signals.signal1d import Signal1D from hyperspy.io import load from hyperspy.signal import BaseSignal -my_path = os.path.dirname(__file__) +# zarr (because of numcodecs) is only supported on x86_64 machines +zarr = pytest.importorskip("zarr", reason="zarr not installed") class TestZspy: @@ -84,14 +81,3 @@ def test_compression(self, compressor, tmp_path): overwrite=True, compressor=compressor) load(tmp_path / 'test_compression.zspy') - - def test_overwrite(self, tmp_path): - s = BaseSignal(np.ones((5, 5, 5))) - - fname = tmp_path / 'tmp.zspy' - s.save(fname, overwrite=True) - shape = (10, 10, 10) - s2 = BaseSignal(np.ones(shape)) - s2.save(fname, overwrite=True) - - assert shape == s2.data.shape diff --git a/setup.py b/setup.py index 790981ef08..eb81df2203 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,8 @@ # included in stdlib since v3.8, but this required version requires Python 3.10 # We can remove this requirement when the minimum supported version becomes Python 3.10 'importlib_metadata>=3.6', - 'zarr' + # numcodecs currently only supported on x86_64/AMD64 machines + 'zarr;platform_machine=="x86_64" or platform_machine=="AMD64"', ] extras_require = {