Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] add readme to default data locations #4037

Merged
merged 4 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
71 changes: 36 additions & 35 deletions nilearn/datasets/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
import pytest
import requests

from nilearn import datasets
from nilearn.datasets import utils
from nilearn.datasets.utils import _get_dataset_descr
Comment on lines -19 to -21
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated but simplified the imports a bit

from nilearn.image import load_img

currdir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -65,15 +63,15 @@ def test_get_dataset_descr_warning():
with pytest.warns(
UserWarning, match="Could not find dataset description."
):
descr = _get_dataset_descr("")
descr = utils._get_dataset_descr("")

assert descr == ""


@pytest.mark.parametrize("name", DATASET_NAMES)
def test_get_dataset_descr(name):
"""Test function ``_get_dataset_descr()``."""
descr = _get_dataset_descr(name)
descr = utils._get_dataset_descr(name)

assert isinstance(descr, str)
assert len(descr) > 0
Expand All @@ -86,7 +84,7 @@ def test_get_dataset_dir(tmp_path):
os.environ.pop("NILEARN_SHARED_DATA", None)

expected_base_dir = os.path.expanduser("~/nilearn_data")
data_dir = datasets.utils._get_dataset_dir("test", verbose=0)
data_dir = utils._get_dataset_dir("test", verbose=0)

assert data_dir == os.path.join(expected_base_dir, "test")
assert os.path.exists(data_dir)
Expand All @@ -95,7 +93,7 @@ def test_get_dataset_dir(tmp_path):

expected_base_dir = str(tmp_path / "test_nilearn_data")
os.environ["NILEARN_DATA"] = expected_base_dir
data_dir = datasets.utils._get_dataset_dir("test", verbose=0)
data_dir = utils._get_dataset_dir("test", verbose=0)

assert data_dir == os.path.join(expected_base_dir, "test")
assert os.path.exists(data_dir)
Expand All @@ -104,7 +102,7 @@ def test_get_dataset_dir(tmp_path):

expected_base_dir = str(tmp_path / "nilearn_shared_data")
os.environ["NILEARN_SHARED_DATA"] = expected_base_dir
data_dir = datasets.utils._get_dataset_dir("test", verbose=0)
data_dir = utils._get_dataset_dir("test", verbose=0)

assert data_dir == os.path.join(expected_base_dir, "test")
assert os.path.exists(data_dir)
Expand All @@ -121,7 +119,13 @@ def test_get_dataset_dir(tmp_path):
match="Nilearn tried to store the dataset in the following "
"directories, but",
):
datasets.utils._get_dataset_dir("test", test_file, verbose=0)
utils._get_dataset_dir("test", test_file, verbose=0)


def test_add_readme_to_default_data_locations(tmp_path):
assert not (tmp_path / "README.md").exists()
utils._get_dataset_dir(dataset_name="test", verbose=0, data_dir=tmp_path)
assert (tmp_path / "README.md").exists()
Comment on lines +125 to +128
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simple test added



@pytest.mark.parametrize("should_cast_path_to_string", [False, True])
Expand All @@ -130,7 +134,7 @@ def test_get_dataset_dir_path_as_str(should_cast_path_to_string, tmp_path):
expected_dataset_dir = expected_base_dir / "test"
if should_cast_path_to_string:
expected_dataset_dir = str(expected_dataset_dir)
data_dir = datasets.utils._get_dataset_dir(
data_dir = utils._get_dataset_dir(
"test", default_paths=[expected_dataset_dir], verbose=0
)

Expand All @@ -149,7 +153,7 @@ def test_get_dataset_dir_write_access(tmp_path):

expected_base_dir = str(tmp_path / "nilearn_shared_data")
os.environ["NILEARN_SHARED_DATA"] = expected_base_dir
data_dir = datasets.utils._get_dataset_dir(
data_dir = utils._get_dataset_dir(
"test", default_paths=[no_write], verbose=0
)

Expand All @@ -167,9 +171,7 @@ def test_md5_sum_file():
os.write(out, b"abcfeg")
os.close(out)

assert (
datasets.utils._md5_sum_file(f) == "18f32295c556b2a1a3a8e68fe1ad40f7"
)
assert utils._md5_sum_file(f) == "18f32295c556b2a1a3a8e68fe1ad40f7"

os.remove(f)

Expand All @@ -183,7 +185,7 @@ def test_read_md5_sum_file():
b"70886dcabe7bf5c5a1c24ca24e4cbd94 test/some_image.nii",
)
os.close(out)
h = datasets.utils._read_md5_sum_file(f)
h = utils._read_md5_sum_file(f)

assert "/tmp/test" in h
assert "/etc/test" not in h
Expand Down Expand Up @@ -212,7 +214,7 @@ def test_tree():
open(os.path.join(dir11, "file111"), "w").close()
open(os.path.join(dir2, "file21"), "w").close()

tree_ = datasets.utils._tree(parent)
tree_ = utils._tree(parent)

# Check the tree
# assert_equal(tree_[0]['dir1'][0]['dir11'][0], 'file111')
Expand Down Expand Up @@ -253,7 +255,7 @@ def test_movetree():
open(os.path.join(dir12, "file121"), "w").close()
open(os.path.join(dir2, "file21"), "w").close()

datasets.utils.movetree(dir1, dir2)
utils.movetree(dir1, dir2)

assert not os.path.exists(dir11)
assert not os.path.exists(dir12)
Expand Down Expand Up @@ -283,11 +285,11 @@ def test_filter_columns():
list(zip(value1, value2)), dtype=[("INT", int), ("STR", "S1")]
)

f = datasets.utils._filter_columns(values, {"INT": (23, 46)})
f = utils._filter_columns(values, {"INT": (23, 46)})

assert np.sum(f) == 24

f = datasets.utils._filter_columns(values, {"INT": [0, 9, (12, 24)]})
f = utils._filter_columns(values, {"INT": [0, 9, (12, 24)]})

assert np.sum(f) == 15

Expand All @@ -297,23 +299,23 @@ def test_filter_columns():
)

# No filter
f = datasets.utils._filter_columns(values, [])
f = utils._filter_columns(values, [])

assert np.sum(f) == 500

f = datasets.utils._filter_columns(values, {"STR": b"b"})
f = utils._filter_columns(values, {"STR": b"b"})

assert np.sum(f) == 167

f = datasets.utils._filter_columns(values, {"STR": "b"})
f = utils._filter_columns(values, {"STR": "b"})

assert np.sum(f) == 167

f = datasets.utils._filter_columns(values, {"INT": 1, "STR": b"b"})
f = utils._filter_columns(values, {"INT": 1, "STR": b"b"})

assert np.sum(f) == 84

f = datasets.utils._filter_columns(
f = utils._filter_columns(
values, {"INT": 1, "STR": b"b"}, combination="or"
)

Expand All @@ -333,7 +335,7 @@ def test_uncompress():
try:
with contextlib.closing(zipfile.ZipFile(ztemp, "w")) as testzip:
testzip.writestr(ftemp, " ")
datasets.utils._uncompress_file(ztemp, verbose=0)
utils._uncompress_file(ztemp, verbose=0)

assert os.path.exists(os.path.join(dtemp, ftemp))

Expand All @@ -348,7 +350,7 @@ def test_uncompress():
os.close(fd)
with contextlib.closing(tarfile.open(ztemp, "w")) as tar:
tar.add(temp, arcname=ftemp)
datasets.utils._uncompress_file(ztemp, verbose=0)
utils._uncompress_file(ztemp, verbose=0)

assert os.path.exists(os.path.join(dtemp, ftemp))

Expand All @@ -357,7 +359,7 @@ def test_uncompress():
dtemp = mkdtemp()
ztemp = os.path.join(dtemp, "test.gz")
gzip.open(ztemp, "wb").close()
datasets.utils._uncompress_file(ztemp, verbose=0)
utils._uncompress_file(ztemp, verbose=0)

# test.gz gets uncompressed into test
assert os.path.exists(os.path.join(dtemp, "test"))
Expand All @@ -382,7 +384,7 @@ def test_safe_extract(tmp_path):
with pytest.raises(
Exception, match="Attempted Path Traversal in Tar File"
):
datasets.utils._uncompress_file(ztemp, verbose=0)
utils._uncompress_file(ztemp, verbose=0)


@pytest.mark.parametrize("should_cast_path_to_string", [False, True])
Expand All @@ -393,7 +395,7 @@ def test_fetch_file_overwrite(
tmp_path = str(tmp_path)

# overwrite non-exiting file.
fil = datasets.utils._fetch_file(
fil = utils._fetch_file(
url="http://foo/", data_dir=str(tmp_path), verbose=0, overwrite=True
)

Expand All @@ -407,7 +409,7 @@ def test_fetch_file_overwrite(
fp.write("some content")

# Don't overwrite existing file.
fil = datasets.utils._fetch_file(
fil = utils._fetch_file(
url="http://foo/", data_dir=str(tmp_path), verbose=0, overwrite=False
)

Expand All @@ -417,8 +419,7 @@ def test_fetch_file_overwrite(
assert fp.read() == "some content"

# Overwrite existing file.
# Overwrite existing file.
fil = datasets.utils._fetch_file(
fil = utils._fetch_file(
url="http://foo/", data_dir=str(tmp_path), verbose=0, overwrite=True
)

Expand All @@ -437,7 +438,7 @@ def test_fetch_files_use_session(

# regression test for https://github.com/nilearn/nilearn/issues/2863
session = MagicMock()
datasets.utils._fetch_files(
utils._fetch_files(
files=[
("example1", "https://example.org/example1", {"overwrite": True}),
("example2", "https://example.org/example2", {"overwrite": True}),
Expand All @@ -458,7 +459,7 @@ def test_fetch_files_overwrite(

# overwrite non-exiting file.
files = ("1.txt", "http://foo/1.txt")
fil = datasets.utils._fetch_files(
fil = utils._fetch_files(
data_dir=str(tmp_path),
verbose=0,
files=[files + (dict(overwrite=True),)],
Expand All @@ -474,7 +475,7 @@ def test_fetch_files_overwrite(
fp.write("some content")

# Don't overwrite existing file.
fil = datasets.utils._fetch_files(
fil = utils._fetch_files(
data_dir=str(tmp_path),
verbose=0,
files=[files + (dict(overwrite=False),)],
Expand All @@ -486,7 +487,7 @@ def test_fetch_files_overwrite(
assert fp.read() == "some content"

# Overwrite existing file.
fil = datasets.utils._fetch_files(
fil = utils._fetch_files(
data_dir=str(tmp_path),
verbose=0,
files=[files + (dict(overwrite=True),)],
Expand Down
21 changes: 21 additions & 0 deletions nilearn/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ def _get_dataset_dir(
if not os.path.exists(path):
try:
os.makedirs(path)
_add_readme_to_default_data_locations(
data_dir=data_dir,
verbose=verbose,
)
if verbose > 0:
print(f"\nDataset created in {path}\n")
return path
Expand All @@ -316,6 +320,23 @@ def _get_dataset_dir(
)


def _add_readme_to_default_data_locations(data_dir=None, verbose=1):
for d in get_data_dirs(data_dir=data_dir):
file = Path(d) / "README.md"
if file.parent.exists() and not file.exists():
with open(file, "w") as f:
f.write(
"""# Nilearn data folder

This directory is used by Nilearn to store datasets
and atlases downloaded from the internet.
It can be safely deleted.
If you delete it, previously downloaded data will be downloaded again."""
)
if verbose > 0:
print(f"\nAdded README.md to {d}\n")


# The functions _is_within_directory and _safe_extract were implemented in
# https://github.com/nilearn/nilearn/pull/3391 to address a directory
# traversal vulnerability https://github.com/advisories/GHSA-gw9q-c7gh-j9vm
Expand Down