Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CI: test optional datasets on every commit #2045

Closed
wants to merge 19 commits into from
Closed
28 changes: 0 additions & 28 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,6 @@ on:
branches:
- release**
jobs:
datasets:
name: datasets
runs-on: ubuntu-latest
steps:
- name: Clone repo
uses: actions/checkout@v4.1.4
- name: Set up python
uses: actions/setup-python@v5.1.0
with:
python-version: "3.12"
- name: Cache dependencies
uses: actions/cache@v4.0.2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-datasets
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install .[tests]
pip cache purge
- name: List pip dependencies
run: pip list
- name: Run pytest checks
run: |
pytest --cov=torchgeo --cov-report=xml --durations=10
python -m torchgeo --help
torchgeo --help
integration:
name: integration
runs-on: ubuntu-latest
Expand Down
33 changes: 33 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,39 @@ jobs:
uses: codecov/codecov-action@v4.3.0
with:
token: ${{ secrets.CODECOV_TOKEN }}
datasets:
name: datasets
runs-on: ubuntu-latest
env:
MPLBACKEND: Agg
steps:
- name: Clone repo
uses: actions/checkout@v4.1.4
- name: Set up python
uses: actions/setup-python@v5.1.0
with:
python-version: "3.12"
- name: Cache dependencies
uses: actions/cache@v4.0.2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/tests.txt') }}
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/required.txt -r requirements/tests.txt
pip cache purge
- name: List pip dependencies
run: pip list
- name: Run pytest checks
run: |
pytest --cov=torchgeo --cov-report=xml --durations=10
python3 -m torchgeo --help
- name: Report coverage
uses: codecov/codecov-action@v4.3.0
with:
token: ${{ secrets.CODECOV_TOKEN }}
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }}
cancel-in-progress: true
1 change: 0 additions & 1 deletion tests/datamodules/test_usavars.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
class TestUSAVarsDataModule:
@pytest.fixture
def datamodule(self, request: SubRequest) -> USAVarsDataModule:
pytest.importorskip('pandas', minversion='1.1.3')
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this PR, but this dep is now required

root = os.path.join('tests', 'data', 'usavars')
batch_size = 1
num_workers = 0
Expand Down
30 changes: 8 additions & 22 deletions tests/datasets/test_advance.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
Expand All @@ -16,6 +14,8 @@
import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE, DatasetNotFoundError

from .utils import importandskip


def download_url(url: str, root: str, *args: str) -> None:
shutil.copy(url, root)
Expand All @@ -37,19 +37,8 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ADVANCE:
transforms = nn.Identity()
return ADVANCE(root, transforms, download=True, checksum=True)

@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'scipy.io':
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(builtins, '__import__', mocked_import)

def test_getitem(self, dataset: ADVANCE) -> None:
pytest.importorskip('scipy', minversion='1.6.2')
pytest.importorskip('scipy', minversion='1.7.2')
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
Expand All @@ -71,17 +60,14 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ADVANCE(str(tmp_path))

def test_mock_missing_module(
self, dataset: ADVANCE, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='scipy is not installed and is required to use this dataset',
):
def test_missing_module(self, dataset: ADVANCE) -> None:
    """Indexing the dataset without scipy should raise an informative ImportError."""
    # NOTE(review): importandskip presumably skips this test when scipy *is*
    # installed, so the assertion below only runs in a scipy-free environment.
    importandskip('scipy')
    expected = 'scipy is not installed and is required to use this dataset'
    with pytest.raises(ImportError, match=expected):
        dataset[0]

def test_plot(self, dataset: ADVANCE) -> None:
pytest.importorskip('scipy', minversion='1.6.2')
pytest.importorskip('scipy', minversion='1.7.2')
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
plt.close()
Expand Down
31 changes: 9 additions & 22 deletions tests/datasets/test_chabud.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
Expand All @@ -17,7 +15,7 @@
import torchgeo.datasets.utils
from torchgeo.datasets import ChaBuD, DatasetNotFoundError

pytest.importorskip('h5py', minversion='3')
from .utils import importandskip


def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
Expand All @@ -29,6 +27,7 @@ class TestChaBuD:
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> ChaBuD:
pytest.importorskip('h5py', minversion='3.6')
monkeypatch.setattr(torchgeo.datasets.chabud, 'download_url', download_url)
data_dir = os.path.join('tests', 'data', 'chabud')
url = os.path.join(data_dir, 'train_eval.hdf5')
Expand All @@ -47,17 +46,6 @@ def dataset(
checksum=True,
)

@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'h5py':
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(builtins, '__import__', mocked_import)

def test_getitem(self, dataset: ChaBuD) -> None:
x = dataset[0]
assert isinstance(x, dict)
Expand All @@ -82,17 +70,16 @@ def test_already_downloaded(self, dataset: ChaBuD) -> None:
ChaBuD(root=dataset.root, download=True)

def test_not_downloaded(self, tmp_path: Path) -> None:
pytest.importorskip('h5py', minversion='3.6')
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ChaBuD(str(tmp_path))

def test_mock_missing_module(
self, dataset: ChaBuD, tmp_path: Path, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='h5py is not installed and is required to use this dataset',
):
ChaBuD(dataset.root, download=True, checksum=True)
def test_missing_module(self) -> None:
    """Constructing ChaBuD without h5py should raise an informative ImportError."""
    # NOTE(review): importandskip presumably skips this test when h5py *is*
    # installed — confirm helper semantics in tests/datasets/utils.
    importandskip('h5py')
    dataset_root = os.path.join('tests', 'data', 'chabud')
    message = 'h5py is not installed and is required to use this dataset'
    with pytest.raises(ImportError, match=message):
        ChaBuD(dataset_root)

def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):
Expand Down
4 changes: 2 additions & 2 deletions tests/datasets/test_chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:


class TestChesapeake13:
pytest.importorskip('zipfile_deflate64')

@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13:
pytest.importorskip('zipfile_deflate64')
monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url)
md5 = 'fe35a615b8e749b21270472aa98bb42c'
monkeypatch.setattr(Chesapeake13, 'md5', md5)
Expand Down Expand Up @@ -63,6 +62,7 @@ def test_already_extracted(self, dataset: Chesapeake13) -> None:
Chesapeake13(dataset.paths, download=True)

def test_already_downloaded(self, tmp_path: Path) -> None:
pytest.importorskip('zipfile_deflate64')
url = os.path.join(
'tests', 'data', 'chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip'
)
Expand Down
37 changes: 12 additions & 25 deletions tests/datasets/test_cropharvest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
Expand All @@ -16,25 +14,14 @@
import torchgeo.datasets.utils
from torchgeo.datasets import CropHarvest, DatasetNotFoundError

pytest.importorskip('h5py', minversion='3')
from .utils import importandskip


def download_url(url: str, root: str, filename: str, md5: str) -> None:
shutil.copy(url, os.path.join(root, filename))


class TestCropHarvest:
@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'h5py':
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(builtins, '__import__', mocked_import)

@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest:
monkeypatch.setattr(torchgeo.datasets.cropharvest, 'download_url', download_url)
Expand Down Expand Up @@ -62,6 +49,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest:
return dataset

def test_getitem(self, dataset: CropHarvest) -> None:
pytest.importorskip('h5py', minversion='3.6')
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['array'], torch.Tensor)
Expand All @@ -73,28 +61,27 @@ def test_getitem(self, dataset: CropHarvest) -> None:
def test_len(self, dataset: CropHarvest) -> None:
assert len(dataset) == 5

def test_already_downloaded(self, dataset: CropHarvest, tmp_path: Path) -> None:
CropHarvest(root=str(tmp_path), download=False)
def test_already_downloaded(self, dataset: CropHarvest) -> None:
CropHarvest(dataset.root)

def test_downloaded_zipped(self, dataset: CropHarvest, tmp_path: Path) -> None:
feature_path = os.path.join(tmp_path, 'features')
shutil.rmtree(feature_path)
CropHarvest(root=str(tmp_path), download=True)
CropHarvest(root=str(tmp_path))

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
CropHarvest(str(tmp_path))

def test_plot(self, dataset: CropHarvest) -> None:
pytest.importorskip('h5py', minversion='3.6')
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
plt.close()

def test_mock_missing_module(
self, dataset: CropHarvest, tmp_path: Path, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='h5py is not installed and is required to use this dataset',
):
CropHarvest(root=str(tmp_path), download=True)[0]
def test_missing_module(self) -> None:
    """Indexing CropHarvest without h5py should raise an informative ImportError."""
    # NOTE(review): importandskip presumably skips this test when h5py *is*
    # installed — confirm helper semantics in tests/datasets/utils.
    importandskip('h5py')
    data_root = os.path.join('tests', 'data', 'cropharvest')
    message = 'h5py is not installed and is required to use this dataset'
    with pytest.raises(ImportError, match=message):
        CropHarvest(data_root)[0]