diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index ba35c1dffbe..83fb1a90588 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -7,34 +7,6 @@ on: branches: - release** jobs: - datasets: - name: datasets - runs-on: ubuntu-latest - steps: - - name: Clone repo - uses: actions/checkout@v4.1.4 - - name: Set up python - uses: actions/setup-python@v5.1.0 - with: - python-version: "3.12" - - name: Cache dependencies - uses: actions/cache@v4.0.2 - id: cache - with: - path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-datasets - - name: Install pip dependencies - if: steps.cache.outputs.cache-hit != 'true' - run: | - pip install .[tests] - pip cache purge - - name: List pip dependencies - run: pip list - - name: Run pytest checks - run: | - pytest --cov=torchgeo --cov-report=xml --durations=10 - python -m torchgeo --help - torchgeo --help integration: name: integration runs-on: ubuntu-latest diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index c572be4633e..b63eb738db6 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -99,6 +99,39 @@ jobs: uses: codecov/codecov-action@v4.3.0 with: token: ${{ secrets.CODECOV_TOKEN }} + datasets: + name: datasets + runs-on: ubuntu-latest + env: + MPLBACKEND: Agg + steps: + - name: Clone repo + uses: actions/checkout@v4.1.4 + - name: Set up python + uses: actions/setup-python@v5.1.0 + with: + python-version: "3.12" + - name: Cache dependencies + uses: actions/cache@v4.0.2 + id: cache + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/tests.txt') }} + - name: Install pip dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: | + pip install -r requirements/required.txt -r requirements/tests.txt + pip cache purge + - name: List pip dependencies + run: pip list + - name: Run pytest checks + run: | + pytest --cov=torchgeo --cov-report=xml --durations=10 + python3 -m torchgeo --help + - name: Report coverage + uses: codecov/codecov-action@v4.3.0 + with: + token: ${{ secrets.CODECOV_TOKEN }} concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }} cancel-in-progress: true diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index 004750c2840..c6d5b7c77bf 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -14,7 +14,6 @@ class TestUSAVarsDataModule: @pytest.fixture def datamodule(self, request: SubRequest) -> USAVarsDataModule: - pytest.importorskip('pandas', minversion='1.1.3') root = os.path.join('tests', 'data', 'usavars') batch_size = 1 num_workers = 0 diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py index efa3ac538f8..9e8ef718a79 100644 --- a/tests/datasets/test_advance.py +++ b/tests/datasets/test_advance.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -16,6 +14,8 @@ import torchgeo.datasets.utils from torchgeo.datasets import ADVANCE, DatasetNotFoundError +from .utils import importandskip + def download_url(url: str, root: str, *args: str) -> None: shutil.copy(url, root) @@ -37,19 +37,8 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ADVANCE: transforms = nn.Identity() return ADVANCE(root, transforms, download=True, checksum=True) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'scipy.io': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - def test_getitem(self, dataset: ADVANCE) -> None: - pytest.importorskip('scipy', minversion='1.6.2') + pytest.importorskip('scipy', minversion='1.7.2') x = dataset[0] assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) @@ -71,17 +60,14 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ADVANCE(str(tmp_path)) - def test_mock_missing_module( - self, dataset: ADVANCE, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='scipy is not installed and is required to use this dataset', - ): + def test_missing_module(self, dataset: ADVANCE) -> None: + importandskip('scipy') + match = 'scipy is not installed and is required to use this dataset' + with pytest.raises(ImportError, match=match): dataset[0] def test_plot(self, dataset: ADVANCE) -> None: - pytest.importorskip('scipy', minversion='1.6.2') + pytest.importorskip('scipy', minversion='1.7.2') x = dataset[0].copy() dataset.plot(x, suptitle='Test') plt.close() diff --git a/tests/datasets/test_chabud.py b/tests/datasets/test_chabud.py index 955104a53fb..f93f036b248 100644 --- a/tests/datasets/test_chabud.py +++ b/tests/datasets/test_chabud.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -17,7 +15,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import ChaBuD, DatasetNotFoundError -pytest.importorskip('h5py', minversion='3') +from .utils import importandskip def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None: @@ -29,6 +27,7 @@ class TestChaBuD: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> ChaBuD: + pytest.importorskip('h5py', minversion='3.6') monkeypatch.setattr(torchgeo.datasets.chabud, 'download_url', download_url) data_dir = os.path.join('tests', 'data', 'chabud') url = os.path.join(data_dir, 'train_eval.hdf5') @@ -47,17 +46,6 @@ def dataset( checksum=True, ) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - def test_getitem(self, dataset: ChaBuD) -> None: x = dataset[0] assert isinstance(x, dict) @@ -82,17 +70,16 @@ def test_already_downloaded(self, dataset: ChaBuD) -> None: ChaBuD(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: + pytest.importorskip('h5py', minversion='3.6') with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ChaBuD(str(tmp_path)) - def test_mock_missing_module( - self, dataset: ChaBuD, tmp_path: Path, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - ChaBuD(dataset.root, download=True, checksum=True) + def test_missing_module(self) -> None: + importandskip('h5py') + root = os.path.join('tests', 'data', 'chabud') + match = 'h5py is not installed and is required to use this dataset' + with pytest.raises(ImportError, match=match): + ChaBuD(root) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index 459c6bcdbdf..1cff011441f 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -29,10 +29,9 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestChesapeake13: - pytest.importorskip('zipfile_deflate64') - @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13: + pytest.importorskip('zipfile_deflate64') monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url) md5 = 'fe35a615b8e749b21270472aa98bb42c' monkeypatch.setattr(Chesapeake13, 'md5', md5) @@ -63,6 +62,7 @@ def test_already_extracted(self, dataset: Chesapeake13) -> None: Chesapeake13(dataset.paths, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: + pytest.importorskip('zipfile_deflate64') url = os.path.join( 'tests', 'data', 'chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip' ) diff --git a/tests/datasets/test_cropharvest.py b/tests/datasets/test_cropharvest.py index f478cdf53ad..69ef07d025d 100644 --- a/tests/datasets/test_cropharvest.py +++ b/tests/datasets/test_cropharvest.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -16,7 +14,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import CropHarvest, DatasetNotFoundError -pytest.importorskip('h5py', minversion='3') +from .utils import importandskip def download_url(url: str, root: str, filename: str, md5: str) -> None: @@ -24,17 +22,6 @@ def download_url(url: str, root: str, filename: str, md5: str) -> None: class TestCropHarvest: - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest: monkeypatch.setattr(torchgeo.datasets.cropharvest, 'download_url', download_url) @@ -62,6 +49,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest: return dataset def test_getitem(self, dataset: CropHarvest) -> None: + pytest.importorskip('h5py', minversion='3.6') x = dataset[0] assert isinstance(x, dict) assert isinstance(x['array'], torch.Tensor) @@ -73,28 +61,27 @@ def test_getitem(self, dataset: CropHarvest) -> None: def test_len(self, dataset: CropHarvest) -> None: assert len(dataset) == 5 - def test_already_downloaded(self, dataset: CropHarvest, tmp_path: Path) -> None: - CropHarvest(root=str(tmp_path), download=False) + def test_already_downloaded(self, dataset: CropHarvest) -> None: + CropHarvest(dataset.root) def test_downloaded_zipped(self, dataset: CropHarvest, tmp_path: Path) -> None: feature_path = os.path.join(tmp_path, 'features') shutil.rmtree(feature_path) - CropHarvest(root=str(tmp_path), download=True) + CropHarvest(root=str(tmp_path)) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): CropHarvest(str(tmp_path)) def test_plot(self, dataset: CropHarvest) -> None: + pytest.importorskip('h5py', minversion='3.6') x = dataset[0].copy() dataset.plot(x, suptitle='Test') plt.close() - def test_mock_missing_module( - self, dataset: CropHarvest, tmp_path: Path, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - CropHarvest(root=str(tmp_path), download=True)[0] + def test_missing_module(self) -> None: + importandskip('h5py') + root = os.path.join('tests', 'data', 'cropharvest') + match = 'h5py is not installed and is required to use this dataset' + with pytest.raises(ImportError, match=match): + CropHarvest(root)[0] diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py index e907213c29c..4726f9c957a 100644 --- a/tests/datasets/test_idtrees.py +++ b/tests/datasets/test_idtrees.py @@ -1,12 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import glob import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -18,7 +16,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, IDTReeS -pytest.importorskip('laspy', minversion='2') +from .utils import importandskip def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -50,20 +48,8 @@ def dataset( transforms = nn.Identity() return IDTReeS(root, split, task, transforms, download=True, checksum=True) - @pytest.fixture(params=['laspy', 'pyvista']) - def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str: - import_orig = builtins.__import__ - package = str(request.param) - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == package: - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - return package - def test_getitem(self, dataset: IDTReeS) -> None: + pytest.importorskip('laspy', minversion='2') x = dataset[0] assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) @@ -91,35 +77,31 @@ def test_already_downloaded(self, dataset: IDTReeS) -> None: IDTReeS(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: + pytest.importorskip('laspy', minversion='2') with pytest.raises(DatasetNotFoundError, match='Dataset not found'): IDTReeS(str(tmp_path)) def test_not_extracted(self, tmp_path: Path) -> None: + pytest.importorskip('laspy', minversion='2') pathname = os.path.join('tests', 'data', 'idtrees', '*.zip') root = str(tmp_path) for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) IDTReeS(root) - def test_mock_missing_module( - self, dataset: IDTReeS, mock_missing_module: str - ) -> None: - package = mock_missing_module - - if package == 'laspy': - with pytest.raises( - ImportError, - match=f'{package} is not installed and is required to use this dataset', - ): - IDTReeS(dataset.root, dataset.split, dataset.task) - elif package == 'pyvista': - with pytest.raises( - ImportError, - match=f'{package} is not installed and is required to plot point cloud', - ): - dataset.plot_las(0) + def test_missing_module(self, dataset: IDTReeS) -> None: + importandskip('laspy') + match = 'laspy is not installed and is required to use this dataset' + with pytest.raises(ImportError, match=match): + dataset[0] + + importandskip('pyvista') + match = 'pyvista is not installed and is required to plot point cloud' + with pytest.raises(ImportError, match=match): + dataset.plot_las(0) def test_plot(self, dataset: IDTReeS) -> None: + pytest.importorskip('laspy', minversion='2') x = dataset[0].copy() dataset.plot(x, suptitle='Test') plt.close() @@ -136,6 +118,7 @@ def test_plot(self, dataset: IDTReeS) -> None: plt.close() def test_plot_las(self, dataset: IDTReeS) -> None: + pytest.importorskip('laspy', minversion='2') pyvista = pytest.importorskip('pyvista', minversion='0.34.2') pyvista.OFF_SCREEN = True diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index 7222ff78bbc..7b2d0b6d3f3 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -80,7 +80,7 @@ class TestLandCoverAI: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> LandCoverAI: - pytest.importorskip('cv2', minversion='4.4.0') + pytest.importorskip('cv2', minversion='4.5.4') monkeypatch.setattr(torchgeo.datasets.landcoverai, 'download_url', download_url) md5 = 'ff8998857cc8511f644d3f7d0f3688d0' monkeypatch.setattr(LandCoverAI, 'md5', md5) @@ -111,7 +111,7 @@ def test_already_extracted(self, dataset: LandCoverAI) -> None: LandCoverAI(root=dataset.root, download=True) def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: - pytest.importorskip('cv2', minversion='4.4.0') + pytest.importorskip('cv2', minversion='4.5.4') sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' monkeypatch.setattr(LandCoverAI, 'sha256', sha256) url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index 59596f2c0c5..112ec10e0c9 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -17,6 +15,8 @@ import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, QuakeSet +from .utils import importandskip + def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -27,6 +27,7 @@ class TestQuakeSet: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> QuakeSet: + pytest.importorskip('h5py', minversion='3.6') monkeypatch.setattr(torchgeo.datasets.quakeset, 'download_url', download_url) url = os.path.join('tests', 'data', 'quakeset', 'earthquakes.h5') md5 = '127d0d6a1f82d517129535f50053a4c9' @@ -39,25 +40,14 @@ def dataset( root, split, transforms=transforms, download=True, checksum=True ) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - - def test_mock_missing_module( - self, dataset: QuakeSet, tmp_path: Path, mock_missing_module: None - ) -> None: + def test_missing_module(self) -> None: + importandskip('h5py') + root = os.path.join('tests', 'data', 'quakeset') with pytest.raises( ImportError, match='h5py is not installed and is required to use this dataset', ): - QuakeSet(dataset.root, download=True, checksum=True) + QuakeSet(root) def test_getitem(self, dataset: QuakeSet) -> None: x = dataset[0] @@ -69,10 +59,11 @@ def test_getitem(self, dataset: QuakeSet) -> None: def test_len(self, dataset: QuakeSet) -> None: assert len(dataset) == 8 - def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None: - QuakeSet(root=str(tmp_path), download=True) + def test_already_downloaded(self, dataset: QuakeSet) -> None: + QuakeSet(dataset.root) def test_not_downloaded(self, tmp_path: Path) -> None: + pytest.importorskip('h5py', minversion='3.6') with pytest.raises(DatasetNotFoundError, match='Dataset not found'): QuakeSet(str(tmp_path)) diff --git a/tests/datasets/test_skippd.py b/tests/datasets/test_skippd.py index 01907d22cea..fc4140cb03f 100644 --- a/tests/datasets/test_skippd.py +++ b/tests/datasets/test_skippd.py @@ -1,12 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from itertools import product from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -18,7 +16,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import SKIPPD, DatasetNotFoundError -pytest.importorskip('h5py', minversion='3') +from .utils import importandskip def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -30,6 +28,7 @@ class TestSKIPPD: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SKIPPD: + pytest.importorskip('h5py', minversion='3.6') task, split = request.param monkeypatch.setattr(torchgeo.datasets.skippd, 'download_url', download_url) @@ -53,31 +52,19 @@ def dataset( checksum=True, ) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - - def test_mock_missing_module( - self, dataset: SKIPPD, tmp_path: Path, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - SKIPPD(dataset.root, download=True, checksum=True) + def test_missing_module(self) -> None: + importandskip('h5py') + root = os.path.join('tests', 'data', 'skippd') + match = 'h5py is not installed and is required to use this dataset' + with pytest.raises(ImportError, match=match): + SKIPPD(root) def test_already_extracted(self, dataset: SKIPPD) -> None: SKIPPD(root=dataset.root, download=True) @pytest.mark.parametrize('task', ['nowcast', 'forecast']) def test_already_downloaded(self, tmp_path: Path, task: str) -> None: + pytest.importorskip('h5py', minversion='3.6') pathname = os.path.join( 'tests', 'data', 'skippd', f'2017_2019_images_pv_processed_{task}.zip' ) @@ -105,6 +92,7 @@ def test_invalid_split(self) -> None: SKIPPD(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: + pytest.importorskip('h5py', minversion='3.6') with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SKIPPD(str(tmp_path)) diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index 6daa4057bed..d97cad5762a 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -1,10 +1,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -15,12 +13,13 @@ from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, So2Sat -pytest.importorskip('h5py', minversion='3') +from .utils import importandskip class TestSo2Sat: @pytest.fixture(params=['train', 'validation', 'test']) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> So2Sat: + pytest.importorskip('h5py', minversion='3.6') md5s_by_version = { '2': { 'train': '56e6fa0edb25b065124a3113372f76e5', @@ -35,17 +34,6 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> So2Sat: transforms = nn.Identity() return So2Sat(root=root, split=split, transforms=transforms, checksum=True) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - def test_getitem(self, dataset: So2Sat) -> None: x = dataset[0] assert isinstance(x, dict) @@ -62,14 +50,17 @@ def test_out_of_bounds(self, dataset: So2Sat) -> None: dataset[2] def test_invalid_split(self) -> None: + pytest.importorskip('h5py', minversion='3.6') with pytest.raises(AssertionError): So2Sat(split='foo') def test_invalid_bands(self) -> None: + pytest.importorskip('h5py', minversion='3.6') with pytest.raises(ValueError): So2Sat(bands=('OK', 'BK')) def test_not_downloaded(self, tmp_path: Path) -> None: + pytest.importorskip('h5py', minversion='3.6') with pytest.raises(DatasetNotFoundError, match='Dataset not found'): So2Sat(str(tmp_path)) @@ -90,11 +81,8 @@ def test_plot_rgb(self, dataset: So2Sat) -> None: ): dataset.plot(dataset[0], suptitle='Single Band') - def test_mock_missing_module( - self, dataset: So2Sat, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - So2Sat(dataset.root) + def test_missing_module(self) -> None: + importandskip('h5py') + match = 'h5py is not installed and is required to use this dataset' + with pytest.raises(ImportError, match=match): + So2Sat() diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 092e8864acf..bb4cfedb38a 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import glob import math import os @@ -38,6 +37,8 @@ working_dir, ) +from .utils import importandskip + class TestDatasetNotFoundError: def test_none(self) -> None: @@ -85,18 +86,6 @@ def test_paths_download(self) -> None: raise DatasetNotFoundError(ds) -@pytest.fixture -def mock_missing_module(monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name in ['radiant_mlhub', 'rarfile', 'zipfile_deflate64']: - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - - class MLHubDataset: def download(self, output_dir: str, **kwargs: str) -> None: glob_path = os.path.join( @@ -127,10 +116,6 @@ def download_url(url: str, root: str, *args: str) -> None: shutil.copy(url, root) -def test_mock_missing_module(mock_missing_module: None) -> None: - import sys # noqa: F401 - - @pytest.mark.parametrize( 'src', [ @@ -150,17 +135,16 @@ def test_extract_archive(src: str, tmp_path: Path) -> None: extract_archive(os.path.join('tests', 'data', src), str(tmp_path)) -def test_missing_rarfile(mock_missing_module: None) -> None: - with pytest.raises( - ImportError, - match='rarfile is not installed and is required to extract this dataset', - ): - extract_archive( - os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar') - ) +def test_missing_rarfile() -> None: + importandskip('rarfile') + match = 'rarfile is not installed and is required to extract this dataset' + with pytest.raises(ImportError, match=match): + path = os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar') + extract_archive(path) -def test_missing_zipfile_deflate64(mock_missing_module: None) -> None: +def test_missing_zipfile_deflate64() -> None: + importandskip('zipfile_deflate64') # Should fallback on Python builtin zipfile extract_archive(os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')) @@ -196,18 +180,13 @@ def test_download_radiant_mlhub_collection( download_radiant_mlhub_collection('', str(tmp_path)) -def test_missing_radiant_mlhub(mock_missing_module: None) -> None: - with pytest.raises( - ImportError, - match='radiant_mlhub is not installed and is required to download this dataset', - ): +def test_missing_radiant_mlhub() -> None: + importandskip('radiant_mlhub') + match = 'radiant_mlhub is not installed and is required to download this {}' + with pytest.raises(ImportError, match=match.format('dataset')): download_radiant_mlhub_dataset('', '') - with pytest.raises( - ImportError, - match='radiant_mlhub is not installed and is required to download this' - + ' collection', - ): + with pytest.raises(ImportError, match=match.format('collection')): download_radiant_mlhub_collection('', '') diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index de4c6c2d507..4479fc8ae9b 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -18,7 +16,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import VHR10, DatasetNotFoundError -pytest.importorskip('pycocotools') +from .utils import importandskip def download_url(url: str, root: str, *args: str) -> None: @@ -30,6 +28,9 @@ class TestVHR10: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> VHR10: + split = request.param + if split == 'positive': + pytest.importorskip('pycocotools') pytest.importorskip('rarfile', minversion='4') monkeypatch.setattr(torchgeo.datasets.vhr10, 'download_url', download_url) monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) @@ -42,21 +43,9 @@ def dataset( md5 = '567c4cd8c12624864ff04865de504c58' monkeypatch.setitem(VHR10.target_meta, 'md5', md5) root = str(tmp_path) - split = request.param transforms = nn.Identity() return VHR10(root, split, transforms, download=True, checksum=True) - @pytest.fixture - def mock_missing_modules(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name in {'pycocotools.coco', 'skimage.measure'}: - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - def test_getitem(self, dataset: VHR10) -> None: for i in range(2): x = dataset[i] @@ -93,25 +82,25 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): VHR10(str(tmp_path)) - def test_mock_missing_module( - self, dataset: VHR10, mock_missing_modules: None - ) -> None: - if dataset.split == 'positive': - with pytest.raises( - ImportError, - match='pycocotools is not installed and is required to use this datase', - ): - VHR10(dataset.root, dataset.split) - - with pytest.raises( - ImportError, - match='scikit-image is not installed and is required to plot masks', - ): - x = dataset[0] - dataset.plot(x) + def test_missing_module(self) -> None: + importandskip('pycocotools') + root = os.path.join('tests', 'data', 'vhr10') + match = 'pycocotools is not installed and is required to use this datase' + with pytest.raises(ImportError, match=match): + VHR10(root, 'positive') + + def test_missing_module_plot(self) -> None: + importandskip('skimage') + root = os.path.join('tests', 'data', 'vhr10') + match = 'scikit-image is not installed and is required to plot masks' + with pytest.raises(ImportError, match=match): + ds = VHR10(root, 'negative') + x = ds[0] + ds.split = 'positive' + ds.plot(x) def test_plot(self, dataset: VHR10) -> None: - pytest.importorskip('skimage', minversion='0.18') + pytest.importorskip('skimage', minversion='0.19') x = dataset[1].copy() dataset.plot(x, suptitle='Test') plt.close() diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py index 330866b36f0..8984e47f05a 100644 --- a/tests/datasets/test_zuericrop.py +++ b/tests/datasets/test_zuericrop.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -16,7 +14,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, ZueriCrop -pytest.importorskip('h5py', minversion='3') +from .utils import importandskip def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -26,6 +24,7 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestZueriCrop: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ZueriCrop: + pytest.importorskip('h5py', minversion='3.6') monkeypatch.setattr(torchgeo.datasets.zuericrop, 'download_url', download_url) data_dir = os.path.join('tests', 'data', 'zuericrop') urls = [ @@ -39,17 +38,6 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ZueriCrop: transforms = nn.Identity() return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - def test_getitem(self, dataset: ZueriCrop) -> None: x = dataset[0] assert isinstance(x, dict) @@ -79,19 +67,19 @@ def test_already_downloaded(self, dataset: ZueriCrop) -> None: ZueriCrop(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: + pytest.importorskip('h5py', minversion='3.6') with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ZueriCrop(str(tmp_path)) - def test_mock_missing_module( - self, dataset: ZueriCrop, tmp_path: Path, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - ZueriCrop(dataset.root, download=True, checksum=True) + def test_missing_module(self) -> None: + importandskip('h5py') + root = os.path.join('tests', 'data', 'zuericrop') + match = 'h5py is not installed and is required to use this dataset' + with pytest.raises(ImportError, match=match): + ZueriCrop(root) def test_invalid_bands(self) -> None: + pytest.importorskip('h5py', minversion='3.6') with pytest.raises(ValueError): ZueriCrop(bands=('OK', 'BK')) diff --git a/tests/datasets/utils.py b/tests/datasets/utils.py new file mode 100644 index 00000000000..ff1fd45b226 --- /dev/null +++ b/tests/datasets/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import pytest + + +def importandskip(modname: str, reason: str | None = None) -> None: + """Exact opposite of :func:`pytest.importorskip`. + + Args: + modname: The name of the module to import. + reason: If given, this reason is shown as the message when the module can + be imported. + """ + try: + __import__(modname) + if reason is None: + reason = f'could import {modname!r}' + pytest.skip(reason) + except ImportError: + pass diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index ca4da48bb78..3402bd0f3fc 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -89,7 +89,7 @@ def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: if name.startswith('so2sat') or name == 'quakeset': - pytest.importorskip('h5py', minversion='3') + pytest.importorskip('h5py', minversion='3.6') config = os.path.join('tests', 'conf', name + '.yaml') diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index ef3c6164d98..c62c808c72f 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -71,7 +71,7 @@ def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: if name == 'skippd': - pytest.importorskip('h5py', minversion='3') + pytest.importorskip('h5py', minversion='3.6') config = os.path.join('tests', 'conf', name + '.yaml') diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 88cd58c0553..d8b207d5d2d 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -87,12 +87,16 @@ class TestSemanticSegmentationTask: def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - if name == 'naipchesapeake': - pytest.importorskip('zipfile_deflate64') - - if name == 'landcoverai': - sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' - monkeypatch.setattr(LandCoverAI, 'sha256', sha256) + match name: + case 'chabud': + pytest.importorskip('h5py', minversion='3.6') + case 'landcoverai': + sha256 = ( + 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' + ) + monkeypatch.setattr(LandCoverAI, 'sha256', sha256) + case 'naipchesapeake': + pytest.importorskip('zipfile_deflate64') config = os.path.join('tests', 'conf', name + '.yaml') diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py index 9887d71db6b..4a383687332 100644 --- a/torchgeo/datasets/cropharvest.py +++ b/torchgeo/datasets/cropharvest.py @@ -111,15 +111,7 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. - ImportError: If h5py is not installed """ - try: - import h5py # noqa: F401 - except ImportError: - raise ImportError( - 'h5py is not installed and is required to use this dataset' - ) - self.root = root self.transforms = transforms self.checksum = checksum @@ -141,6 +133,9 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: single pixel time-series array and label at that index + + Raises: + ImportError: If h5py is not installed """ files = self.files[index] data = self._load_array(files['chip']) @@ -209,7 +204,12 @@ def _load_array(self, path: str) -> Tensor: Returns: the image """ - import h5py + try: + import h5py # noqa: F401 + except ImportError: + raise ImportError( + 'h5py is not installed and is required to use this dataset' + ) filename = os.path.join(path) with h5py.File(filename, 'r') as f: diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index c17df135759..ffd1d460e48 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -166,7 +166,6 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - ImportError: if laspy is not installed DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in ['train', 'test'] @@ -182,13 +181,6 @@ def __init__( self.num_classes = len(self.classes) self._verify() - try: - import laspy # noqa: F401 - except ImportError: - raise ImportError( - 'laspy is not installed and is required to use this dataset' - ) - self.images, self.geometries, self.labels = self._load(root) def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -199,6 +191,9 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and label at that index + + Raises: + ImportError: if laspy is not installed """ path = self.images[index] image = self._load_image(path).to(torch.uint8) @@ -262,7 +257,12 @@ def _load_las(self, path: str) -> Tensor: Returns: the point cloud """ - import laspy + try: + import laspy # noqa: F401 + except ImportError: + raise ImportError( + 'laspy is not installed and is required to use this dataset' + ) las = laspy.read(path) array: 'np.typing.NDArray[np.int_]' = np.stack([las.x, las.y, las.z], axis=0)