diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index 2428a769..8cb20594 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -234,6 +234,7 @@ def apply( if isinstance(spatialimage, (str, Path)): spatialimage = _nbload(str(spatialimage)) + singleton_4d = spatialimage.ndim == 4 and spatialimage.shape[-1] == 1 spatialimage = squeeze_image(spatialimage) # Avoid opening the data array just yet @@ -370,7 +371,8 @@ def apply( with suppress(ValueError): resampled = np.squeeze(resampled, axis=3) - moved = spatialimage.__class__(resampled, _ref.affine, hdr) + moved = spatialimage.__class__( + resampled[..., None] if singleton_4d else resampled, _ref.affine, hdr) return moved output_dtype = output_dtype or input_dtype diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py index 3ff8fb36..f9b88577 100644 --- a/nitransforms/tests/test_resampling.py +++ b/nitransforms/tests/test_resampling.py @@ -56,8 +56,13 @@ def test_apply_singleton_time_dimension(): data = np.reshape(np.arange(27, dtype=np.uint8), (3, 3, 3, 1)) nii = nb.Nifti1Image(data, np.eye(4)) - xfm = nitl.Affine(np.eye(4), reference=nii) - apply(xfm, nii) + ref = nb.Nifti1Image(np.zeros((4, 4, 4)), np.eye(4)) + xfm = nitl.Affine(np.eye(4), reference=ref) + movednii = apply(xfm, nii) + assert movednii.shape == ref.shape + (1, ) + + movednii = apply(xfm, nii, reference=ref) + assert movednii.shape == ref.shape + (1, ) @pytest.mark.parametrize(