Skip to content

Commit

Permalink
Merge pull request #2906 from CSSFrancis/saving_ragged_arrays_dim2
Browse files Browse the repository at this point in the history
BugFix: Basic way to save ragged arrays with ndim>1
  • Loading branch information
ericpre committed Apr 25, 2022
2 parents 6a1a5f2 + 306821d commit 2fff228
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
20 changes: 19 additions & 1 deletion hyperspy/io_plugins/_hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,14 @@ def group2signaldict(self, group, lazy=False):
exp["package_version"] = ""

data = group['data']
try:
ragged_shape = group["ragged_shapes"]
new_data = np.empty(shape=data.shape, dtype=object)
for i in np.ndindex(data.shape):
new_data[i] = np.reshape(data[i], ragged_shape[i])
data = new_data
except KeyError:
pass
if lazy:
data = da.from_array(data, chunks=data.chunks)
exp['attributes']['_lazy'] = True
Expand Down Expand Up @@ -626,7 +634,17 @@ def overwrite_dataset(cls, group, data, key, signal_axes=None,
del group[key]

_logger.info(f"Chunks used for saving: {chunks}")
cls._store_data(data, dset, group, key, chunks)
if data.dtype == np.dtype('O'):
new_data = np.empty(shape=data.shape, dtype=object)
shapes = np.empty(shape=data.shape, dtype=object)
for i in np.ndindex(data.shape):
new_data[i] = data[i].ravel()
shapes[i] = np.array(data[i].shape)
shape_dset = cls._get_object_dset(group, shapes, "ragged_shapes", shapes.shape, **kwds)
cls._store_data(shapes, shape_dset, group, 'ragged_shapes', chunks=shapes.shape)
cls._store_data(new_data, dset, group, key, chunks)
else:
cls._store_data(data, dset, group, key, chunks)

def write(self):
self.write_signal(self.signal,
Expand Down
16 changes: 16 additions & 0 deletions hyperspy/tests/io/test_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,22 @@ def test_save_ragged_array(tmp_path, file):
assert s.__class__ == s1.__class__


@zspy_marker
def test_save_ragged_dim2(tmp_path, file):
    """Round-trip a ragged signal whose elements are 2-D arrays.

    Each of the five elements is an integer array of shape ``(2, n)`` with
    ``n`` running from 1 to 5, so every element has ``ndim == 2`` and a
    different shape. The signal is saved to ``tmp_path / file``, loaded
    back, and compared element by element with the original.
    """
    x = np.empty(5, dtype=object)
    for n in range(1, 6):
        x[n - 1] = np.array([list(range(n)), list(range(n))])

    s = BaseSignal(x, ragged=True)

    filename = tmp_path / file
    s.save(filename)
    s2 = load(filename)

    # ``zip`` stops at the shorter input, so check the container shape first;
    # otherwise elements dropped during the round trip would go unnoticed.
    assert s2.data.shape == s.data.shape
    for expected, loaded in zip(s.data, s2.data):
        np.testing.assert_array_equal(expected, loaded)


def test_load_missing_extension(caplog):
path = my_path / "hdf5_files" / "hspy_ext_missing.hspy"
with pytest.warns(UserWarning):
Expand Down
1 change: 1 addition & 0 deletions upcoming_changes/2906.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix saving of ragged arrays whose elements have more than one dimension (``ndim > 1``).

0 comments on commit 2fff228

Please sign in to comment.