From dbd27dbc136f06a22c3b8628707068958313c36d Mon Sep 17 00:00:00 2001 From: Ethan Blackwood Date: Wed, 15 May 2024 22:44:14 -0400 Subject: [PATCH 1/2] Clean up special cases for loading from HDF5 and fix None handling --- caiman/utils/utils.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/caiman/utils/utils.py b/caiman/utils/utils.py index caf82d906..3fe65ea44 100644 --- a/caiman/utils/utils.py +++ b/caiman/utils/utils.py @@ -561,25 +561,16 @@ def recursively_load_dict_contents_from_group(h5file:h5py.File, path:str) -> dic for key, item in h5file[path].items(): if isinstance(item, h5py._hl.dataset.Dataset): - val_set = np.nan - if isinstance(item[()], str): - if item[()] == 'NoneType': - ans[key] = None - else: - ans[key] = item[()] - - elif key in ['dims', 'medw', 'sigma_smooth_snmf', 'dxy', 'max_shifts', 'strides', 'overlaps']: - if isinstance(item[()], np.ndarray): - ans[key] = tuple(item[()]) - else: - ans[key] = item[()] + val = item[()] + if val == 'NoneType' or val == b'NoneType': + ans[key] = None + elif key in ['dims', 'medw', 'sigma_smooth_snmf', + 'dxy', 'max_shifts', 'strides', 'overlaps'] and isinstance(val, np.ndarray): + ans[key] = tuple(val) + elif isinstance(val, np.bool_): # sigh + ans[key] = bool(val) else: - if isinstance(item[()], np.bool_): # sigh - ans[key] = bool(item[()]) - else: - ans[key] = item[()] - if isinstance(ans[key], bytes) and ans[key] == b'NoneType': - ans[key] = None + ans[key] = item[()] elif isinstance(item, h5py._hl.group.Group): if key in ('A', 'W', 'Ab', 'downscale_matrix', 'upscale_matrix'): From 09a8a0273a04d5d23516113b3f0fbf4618c2b9fe Mon Sep 17 00:00:00 2001 From: Ethan Blackwood Date: Wed, 15 May 2024 22:51:49 -0400 Subject: [PATCH 2/2] Fix invalid comparison for non-scalars --- caiman/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caiman/utils/utils.py b/caiman/utils/utils.py index 3fe65ea44..e0dd46f15 100644 --- a/caiman/utils/utils.py +++ b/caiman/utils/utils.py @@ -562,7 +562,7 @@ def recursively_load_dict_contents_from_group(h5file:h5py.File, path:str) -> dic for key, item in h5file[path].items(): if isinstance(item, h5py._hl.dataset.Dataset): val = item[()] - if val == 'NoneType' or val == b'NoneType': + if isinstance(val, str) and val == 'NoneType' or isinstance(val, bytes) and val == b'NoneType': ans[key] = None elif key in ['dims', 'medw', 'sigma_smooth_snmf', 'dxy', 'max_shifts', 'strides', 'overlaps'] and isinstance(val, np.ndarray):