Skip to content

Commit

Permalink
[MAINT] Rewrite write_tmp_imgs function and use it with pytest's `t…
Browse files Browse the repository at this point in the history
…mp_path` (#4094)

* Rewrite write_tmp_imgs as function

* Update tests to use tmp_path

* Add unit tests

* Rename function

* Add whatsnew

* Rename function

* Update whatsnew
  • Loading branch information
ymzayek committed Nov 13, 2023
1 parent 2a31c45 commit aa9c914
Show file tree
Hide file tree
Showing 19 changed files with 437 additions and 403 deletions.
1 change: 1 addition & 0 deletions doc/changes/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ Changes
-------

- :bdg-success:`API` Expose scipy CubicSpline ``extrapolate`` parameter in :func:`~signal.clean` to control the interpolation of censored volumes in both ends of the BOLD signal data (:gh:`4028` by `Jordi Huguet`_).
- :bdg-dark:`Code` Private utility context manager ``write_tmp_imgs`` is refactored into function ``write_imgs_to_path`` (:gh:`4094` by `Yasmin Mzayek`_).
68 changes: 22 additions & 46 deletions nilearn/_utils/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Utilities for testing nilearn."""
# Author: Alexandre Abraham, Philippe Gervais
import contextlib
import functools
import gc
import os
Expand Down Expand Up @@ -113,12 +112,10 @@ def serialize_niimg(img, gzipped=True):
return f.read()


@contextlib.contextmanager
def write_tmp_imgs(*imgs, **kwargs):
"""Context manager for writing Nifti images.
def write_imgs_to_path(*imgs, file_path=None, **kwargs):
"""Write Nifti images on disk.
Write nifti images in a temporary location, and remove them at the end of
the block.
Write nifti images in a specified location.
Parameters
----------
Expand All @@ -144,6 +141,9 @@ def write_tmp_imgs(*imgs, **kwargs):
list of string is returned.
"""
if file_path is None:
file_path = Path.cwd()

valid_keys = {"create_files", "use_wildcards"}
input_keys = set(kwargs.keys())
invalid_keys = input_keys - valid_keys
Expand All @@ -160,49 +160,25 @@ def write_tmp_imgs(*imgs, **kwargs):

if create_files:
filenames = []
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
for img in imgs:
fd, filename = tempfile.mkstemp(
prefix=prefix, suffix=suffix, dir=None
)
os.close(fd)
filenames.append(filename)
img.to_filename(filename)
del img

if use_wildcards:
yield f"{prefix}*{suffix}"
else:
if len(imgs) == 1:
yield filenames[0]
else:
yield filenames
finally:
failures = []
# Ensure all created files are removed
for filename in filenames:
try:
os.remove(filename)
except FileNotFoundError:
# ok, file already removed
pass
except OSError as e:
# problem eg permission, or open file descriptor
failures.append(e)
if failures:
failed_lines = "\n".join(str(e) for e in failures)
raise OSError(
"The following files could not be removed:\n"
f"{failed_lines}"
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
for i, img in enumerate(imgs):
filename = file_path / (prefix + str(i) + suffix)
filenames.append(str(filename))
img.to_filename(filename)
del img

if use_wildcards:
return str(file_path / f"{prefix}*{suffix}")
else:
if len(filenames) == 1:
return filenames[0]
return filenames

else: # No-op
if len(imgs) == 1:
yield imgs[0]
else:
yield imgs
return imgs[0]
return imgs


def are_tests_running():
Expand Down
10 changes: 6 additions & 4 deletions nilearn/_utils/tests/test_niimg.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def test_img_data_dtype(rng, affine_eye, tmp_path):
assert not all(dtype_matches)


def test_load_niimg(img1):
with testing.write_tmp_imgs(img1, create_files=True) as filename:
filename = Path(filename)
load_niimg(filename)
def test_load_niimg(img1, tmp_path):
filename = testing.write_imgs_to_path(
img1, file_path=tmp_path, create_files=True
)
filename = Path(filename)
load_niimg(filename)
103 changes: 51 additions & 52 deletions nilearn/_utils/tests/test_niimg_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_check_same_fov(affine_eye):
)


def test_check_niimg_3d(affine_eye, img_3d_zeros_eye):
def test_check_niimg_3d(affine_eye, img_3d_zeros_eye, tmp_path):
# check error for non-forced but necessary resampling
with pytest.raises(TypeError, match="nibabel format"):
_utils.check_niimg(0)
Expand All @@ -124,8 +124,10 @@ def test_check_niimg_3d(affine_eye, img_3d_zeros_eye):
data[20, 20, 20] = 1
data_img = Nifti1Image(data, affine_eye)

with testing.write_tmp_imgs(data_img, create_files=True) as filename:
_utils.check_niimg_3d(filename)
filename = testing.write_imgs_to_path(
data_img, file_path=tmp_path, create_files=True
)
_utils.check_niimg_3d(filename)

# check data dtype equal with dtype='auto'
img_check = _utils.check_niimg_3d(img_3d_zeros_eye, dtype="auto")
Expand Down Expand Up @@ -242,12 +244,12 @@ def test_check_niimg(img_3d_zeros_eye, img_4d_zeros_eye):
)


def test_check_niimg_pathlike(img_3d_zeros_eye):
with testing.write_tmp_imgs(
img_3d_zeros_eye, create_files=True
) as filename:
filename = Path(filename)
_utils.check_niimg_3d(filename)
def test_check_niimg_pathlike(img_3d_zeros_eye, tmp_path):
filename = testing.write_imgs_to_path(
img_3d_zeros_eye, file_path=tmp_path, create_files=True
)
filename = Path(filename)
_utils.check_niimg_3d(filename)


def test_check_niimg_wildcards_errors():
Expand Down Expand Up @@ -278,58 +280,55 @@ def test_check_niimg_wildcards_errors():
@pytest.mark.parametrize(
"wildcards", [True, False]
) # (With globbing behavior or not)
def test_check_niimg_wildcards(affine_eye, shape, wildcards):
def test_check_niimg_wildcards(affine_eye, shape, wildcards, tmp_path):
# First create some testing data
img = Nifti1Image(np.zeros(shape), affine_eye)

with testing.write_tmp_imgs(img, create_files=True) as filename:
assert_array_equal(
get_data(_utils.check_niimg(filename, wildcards=wildcards)),
get_data(img),
)

filename = testing.write_imgs_to_path(
img, file_path=tmp_path, create_files=True
)
assert_array_equal(
get_data(_utils.check_niimg(filename, wildcards=wildcards)),
get_data(img),
)

def test_check_niimg_wildcards_one_file_name(img_3d_zeros_eye):
tmp_dir = tempfile.tempdir + os.sep

def test_check_niimg_wildcards_one_file_name(img_3d_zeros_eye, tmp_path):
file_not_found_msg = "File not found: '%s'"

# Testing with a glob matching exactly one filename
# Using a glob matching one file containing a 3d image returns a 4d image
# with 1 as last dimension.
with testing.write_tmp_imgs(
img_3d_zeros_eye, create_files=True, use_wildcards=True
) as globs:
glob_input = tmp_dir + globs
assert_array_equal(
get_data(_utils.check_niimg(glob_input))[..., 0],
get_data(img_3d_zeros_eye),
)
globs = testing.write_imgs_to_path(
img_3d_zeros_eye,
file_path=tmp_path,
create_files=True,
use_wildcards=True,
)
assert_array_equal(
get_data(_utils.check_niimg(globs))[..., 0],
get_data(img_3d_zeros_eye),
)
# Disabled globbing behavior should raise an ValueError exception
with testing.write_tmp_imgs(
img_3d_zeros_eye, create_files=True, use_wildcards=True
) as globs:
glob_input = tmp_dir + globs
with pytest.raises(
ValueError, match=file_not_found_msg % re.escape(glob_input)
):
_utils.check_niimg(glob_input, wildcards=False)
with pytest.raises(
ValueError, match=file_not_found_msg % re.escape(globs)
):
_utils.check_niimg(globs, wildcards=False)

# Testing with a glob matching multiple filenames
img_4d = _utils.check_niimg_4d((img_3d_zeros_eye, img_3d_zeros_eye))
with testing.write_tmp_imgs(
globs = testing.write_imgs_to_path(
img_3d_zeros_eye,
img_3d_zeros_eye,
file_path=tmp_path,
create_files=True,
use_wildcards=True,
) as globs:
assert_array_equal(
get_data(_utils.check_niimg(glob_input)), get_data(img_4d)
)
)
assert_array_equal(get_data(_utils.check_niimg(globs)), get_data(img_4d))


def test_check_niimg_wildcards_no_expand_wildcards(
img_3d_zeros_eye, img_4d_zeros_eye
img_3d_zeros_eye, img_4d_zeros_eye, tmp_path
):
nofile_path = "/tmp/nofile"

Expand All @@ -350,20 +349,20 @@ def test_check_niimg_wildcards_no_expand_wildcards(
_utils.check_niimg(nofile_path, wildcards=False)

# Testing with an exact filename matching (3d case)
with testing.write_tmp_imgs(
img_3d_zeros_eye, create_files=True
) as filename:
assert_array_equal(
get_data(_utils.check_niimg(filename)), get_data(img_3d_zeros_eye)
)
filename = testing.write_imgs_to_path(
img_3d_zeros_eye, file_path=tmp_path, create_files=True
)
assert_array_equal(
get_data(_utils.check_niimg(filename)), get_data(img_3d_zeros_eye)
)

# Testing with an exact filename matching (4d case)
with testing.write_tmp_imgs(
img_4d_zeros_eye, create_files=True
) as filename:
assert_array_equal(
get_data(_utils.check_niimg(filename)), get_data(img_4d_zeros_eye)
)
filename = testing.write_imgs_to_path(
img_4d_zeros_eye, file_path=tmp_path, create_files=True
)
assert_array_equal(
get_data(_utils.check_niimg(filename)), get_data(img_4d_zeros_eye)
)

# Reverting to default behavior
ni.EXPAND_PATH_WILDCARDS = True
Expand Down
30 changes: 30 additions & 0 deletions nilearn/_utils/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
assert_memory_less_than,
check_deprecation,
with_memory_profiler,
write_imgs_to_path,
)


Expand Down Expand Up @@ -56,3 +57,32 @@ def dummy_deprecation(start_version, end_version):

def test_check_deprecation():
check_deprecation(dummy_deprecation, "Deprecated")("0.0.1", "0.0.2")


@pytest.mark.parametrize("create_files", [True, False])
@pytest.mark.parametrize("use_wildcards", [True, False])
def test_write_tmp_imgs_default(
monkeypatch, tmp_path, img_3d_mni, create_files, use_wildcards
):
"""Write imgs to default location."""
monkeypatch.chdir(tmp_path)

write_imgs_to_path(
img_3d_mni,
create_files=create_files,
use_wildcards=use_wildcards,
)


@pytest.mark.parametrize("create_files", [True, False])
@pytest.mark.parametrize("use_wildcards", [True, False])
def test_write_tmp_imgs_set_path(
tmp_path, img_3d_mni, create_files, use_wildcards
):
"""Write imgs to a specified location."""
write_imgs_to_path(
img_3d_mni,
file_path=tmp_path,
create_files=create_files,
use_wildcards=use_wildcards,
)
45 changes: 23 additions & 22 deletions nilearn/decomposition/tests/test_canica.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from nibabel import Nifti1Image
from numpy.testing import assert_array_almost_equal

from nilearn._utils.testing import write_tmp_imgs
from nilearn._utils.testing import write_imgs_to_path
from nilearn.conftest import _affine_eye, _rng
from nilearn.decomposition.canica import CanICA
from nilearn.decomposition.tests.test_multi_pca import _tmp_dir
from nilearn.image import get_data, iter_img
from nilearn.maskers import MultiNiftiMasker

Expand Down Expand Up @@ -256,44 +255,46 @@ def test_components_img(canica_data, mask_img):
assert components_img.shape, check_shape


def test_with_globbing_patterns_with_single_subject(mask_img):
def test_with_globbing_patterns_with_single_subject(mask_img, tmp_path):
# single subject
data, *_ = _make_canica_test_data(n_subjects=1)
n_components = 3

canica = CanICA(n_components=n_components, mask=mask_img)

with write_tmp_imgs(data[0], create_files=True, use_wildcards=True) as img:
input_image = _tmp_dir() + img
canica.fit(input_image)
components_img = canica.components_img_
img = write_imgs_to_path(
data[0], file_path=tmp_path, create_files=True, use_wildcards=True
)
canica.fit(img)
components_img = canica.components_img_

assert isinstance(components_img, Nifti1Image)
assert isinstance(components_img, Nifti1Image)

# n_components = 3
check_shape = data[0].shape[:3] + (3,)
# n_components = 3
check_shape = data[0].shape[:3] + (3,)

assert components_img.shape, check_shape
assert components_img.shape, check_shape


def test_with_globbing_patterns_with_multi_subjects(canica_data, mask_img):
def test_with_globbing_patterns_with_multi_subjects(
canica_data, mask_img, tmp_path
):
# Multi subjects
n_components = 3
canica = CanICA(n_components=n_components, mask=mask_img)

with write_tmp_imgs(
*canica_data, create_files=True, use_wildcards=True
) as img:
input_image = _tmp_dir() + img
canica.fit(input_image)
components_img = canica.components_img_
img = write_imgs_to_path(
*canica_data, file_path=tmp_path, create_files=True, use_wildcards=True
)
canica.fit(img)
components_img = canica.components_img_

assert isinstance(components_img, Nifti1Image)
assert isinstance(components_img, Nifti1Image)

# n_components = 3
check_shape = canica_data[0].shape[:3] + (3,)
# n_components = 3
check_shape = canica_data[0].shape[:3] + (3,)

assert components_img.shape, check_shape
assert components_img.shape, check_shape


def test_canica_score(canica_data, mask_img):
Expand Down

0 comments on commit aa9c914

Please sign in to comment.