Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow VectorDataset to accept list of files #1597

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
4ade256
Make RasterDataset accept list of files
Jun 22, 2023
a0f985d
Fix check if str
adriantre Jun 22, 2023
a9f6944
Use isdir and isfile
adriantre Jun 23, 2023
b69625b
Rename root to paths and update type hint
adriantre Jun 23, 2023
b7a51bd
Update children of RasterDataset methods using root
adriantre Jun 23, 2023
d6a0919
Fix check to cast str to list
adriantre Jun 28, 2023
ce5f474
Update conf files for RasterDatasets
adriantre Jun 28, 2023
81833e9
Add initial suggested test
adriantre Jun 28, 2023
b247861
Add workaround for lists LandCoverAIBase
adriantre Jun 29, 2023
f569051
Add method handle_nonlocal_path for users to override
adriantre Jun 29, 2023
d4f757c
Raise RuntimeError to support existing tests
adriantre Jun 29, 2023
0414f63
Remove redundant cast to set
adriantre Jun 29, 2023
2195553
Remove required os.exists for paths
adriantre Jul 3, 2023
61de902
Revert "Remove required os.exists for paths"
adriantre Jul 3, 2023
5bed6f3
Use arg as positional argument not kwarg
adriantre Sep 28, 2023
8e80458
Improve comments and logs about arg paths
adriantre Sep 28, 2023
b736cef
Remove misleading comment
adriantre Sep 28, 2023
2f7df48
Change type hint of 'paths' to Iterable
adriantre Sep 28, 2023
a6e5fe1
Change type hint of 'paths' to Iterable
adriantre Sep 28, 2023
ca5c4bf
Remove premature handling of non-local paths
adriantre Sep 28, 2023
a228cc2
Replace root with paths in docstrings
adriantre Sep 28, 2023
c22a6a9
Add versionadded to list_files docstring
adriantre Sep 28, 2023
44f6eb5
Add versionchanged to docstrings
adriantre Sep 28, 2023
9dae8c4
Update type of paths in children of Raster
adriantre Sep 28, 2023
1311957
Replace docstring for paths in all raster
adriantre Sep 28, 2023
697dfd7
Swap root with paths for conf files for raster
adriantre Sep 28, 2023
026ee11
Add newline before versionchanged
adriantre Sep 29, 2023
628801a
Revert name to root in conf for ChesapeakeCVPR
adriantre Sep 29, 2023
eae2992
Simplify EUDEM tests
adamjstewart Sep 29, 2023
2bc82c8
paths must be a string if you want autodownload support
adamjstewart Sep 29, 2023
d391079
Convert list_files to a property
adamjstewart Sep 29, 2023
66f2f02
Fix type hints
adamjstewart Sep 29, 2023
8ec2e93
Test with a real empty directory
adamjstewart Sep 29, 2023
be29b24
Move property `files` up to GeoDataset
adriantre Sep 29, 2023
0b0ade4
Rename root to paths for VectorDataset
adriantre Sep 29, 2023
a02c9b1
Merge remote-tracking branch 'origin/main' into feature/refactor_vect…
adriantre Sep 29, 2023
1a0d877
Fix mypy
adriantre Sep 29, 2023
55fbf83
Fix tests
adriantre Sep 29, 2023
3504673
Delete duplicate code
adamjstewart Sep 29, 2023
89df751
Delete duplicate code
adamjstewart Sep 29, 2023
10fd27a
Fix test coverage
adamjstewart Sep 29, 2023
110d695
Document name change
adamjstewart Sep 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/datasets/test_cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_or(self, dataset: CanadianBuildingFootprints) -> None:
assert isinstance(ds, UnionDataset)

def test_already_downloaded(self, dataset: CanadianBuildingFootprints) -> None:
CanadianBuildingFootprints(root=dataset.root, download=True)
CanadianBuildingFootprints(dataset.paths, download=True)

def test_plot(self, dataset: CanadianBuildingFootprints) -> None:
query = dataset.bounds
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def dataset(
)
monkeypatch.setattr(
ChesapeakeCVPR,
"files",
"_files",
["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"],
)
root = str(tmp_path)
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_enviroatlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def dataset(
)
monkeypatch.setattr(
EnviroAtlas,
"files",
"_files",
["pittsburgh_pa-2010_1m-train_tiles-debuffered", "spatial_index.geojson"],
)
root = str(tmp_path)
Expand Down
10 changes: 5 additions & 5 deletions tests/datasets/test_openbuildings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> OpenBuildings:

monkeypatch.setattr(OpenBuildings, "md5s", md5s)
transforms = nn.Identity()
return OpenBuildings(root=root, transforms=transforms)
return OpenBuildings(root, transforms=transforms)

def test_no_shapes_to_rasterize(
self, dataset: OpenBuildings, tmp_path: Path
Expand All @@ -61,19 +61,19 @@ def test_no_building_data_found(self, tmp_path: Path) -> None:
with pytest.raises(
RuntimeError, match="have manually downloaded the dataset as suggested "
):
OpenBuildings(root=false_root)
OpenBuildings(false_root)

def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "000_buildings.csv.gz"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
OpenBuildings(dataset.root, checksum=True)
OpenBuildings(dataset.paths, checksum=True)

def test_no_meta_data_found(self, tmp_path: Path) -> None:
false_root = os.path.join(tmp_path, "empty")
os.makedirs(false_root)
with pytest.raises(FileNotFoundError, match="Meta data file"):
OpenBuildings(root=false_root)
OpenBuildings(false_root)

def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None:
# change meta data to another 'title_url' so that there is no match found
Expand All @@ -85,7 +85,7 @@ def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None:
json.dump(content, f)

with pytest.raises(FileNotFoundError, match="data was found in"):
OpenBuildings(dataset.root)
OpenBuildings(dataset.paths)

def test_getitem(self, dataset: OpenBuildings) -> None:
x = dataset[dataset.bounds]
Expand Down
21 changes: 13 additions & 8 deletions torchgeo/datasets/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""Canadian Building Footprints dataset."""

import os
from typing import Any, Callable, Optional
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
Expand Down Expand Up @@ -60,7 +61,7 @@ class CanadianBuildingFootprints(VectorDataset):

def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 0.00001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
Expand All @@ -70,7 +71,7 @@ def __init__(
"""Initialize a new Dataset instance.

Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
Expand All @@ -83,8 +84,11 @@ def __init__(
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` and data is not found, or
``checksum=True`` and checksums don't match

.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.checksum = checksum

if download:
Expand All @@ -96,16 +100,17 @@ def __init__(
+ "You can use download=True to download it"
)

super().__init__(root, crs, res, transforms)
super().__init__(paths, crs, res, transforms)

def _check_integrity(self) -> bool:
    """Check integrity of dataset.

    Returns:
        True if dataset files are found and/or MD5s match, else False
    """
    assert isinstance(self.paths, str)
    # Every province/territory archive must be present (and, when checksum
    # verification is enabled, match its MD5). all() short-circuits on the
    # first failure, just like the original early-return loop.
    return all(
        check_integrity(
            os.path.join(self.paths, prov_terr + ".zip"),
            md5 if self.checksum else None,
        )
        for prov_terr, md5 in zip(self.provinces_territories, self.md5s)
    )
Expand All @@ -115,11 +120,11 @@ def _download(self) -> None:
if self._check_integrity():
print("Files already downloaded and verified")
return

assert isinstance(self.paths, str)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
download_and_extract_archive(
self.url + prov_terr + ".zip",
self.root,
self.paths,
md5=md5 if self.checksum else None,
)

Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ class ChesapeakeCVPR(GeoDataset):
)

# these are used to check the integrity of the dataset
files = [
_files = [
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
"de_1m_2013_extended-debuffered-test_tiles",
"de_1m_2013_extended-debuffered-train_tiles",
"de_1m_2013_extended-debuffered-val_tiles",
Expand Down Expand Up @@ -704,7 +704,7 @@ def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, filename))

# Check if the extracted files already exist
if all(map(exists, self.files)):
if all(map(exists, self._files)):
return
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

# Check if the zip files have already been downloaded
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/enviroatlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class EnviroAtlas(GeoDataset):
)

# these are used to check the integrity of the dataset
files = [
_files = [
"austin_tx-2012_1m-test_tiles-debuffered",
"austin_tx-2012_1m-val5_tiles-debuffered",
"durham_nc-2012_1m-test_tiles-debuffered",
Expand Down Expand Up @@ -422,7 +422,7 @@ def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename))

# Check if the extracted files already exist
if all(map(exists, self.files)):
if all(map(exists, self._files)):
return

# Check if the zip files have already been downloaded
Expand Down
88 changes: 42 additions & 46 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,17 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
dataset = landsat7 | landsat8
"""

paths: Union[str, Iterable[str]]
_crs = CRS.from_epsg(4326)
_res = 0.0

#: Glob expression used to search for files.
#:
#: This expression should be specific enough that it will not pick up files from
#: other datasets. It should not include a file extension, as the dataset may be in
#: a different file format than what it was originally downloaded as.
filename_glob = "*"
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

# NOTE: according to the Python docs:
#
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
Expand Down Expand Up @@ -269,17 +277,36 @@ def res(self, new_res: float) -> None:
print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}")
self._res = new_res

@property
def files(self) -> set[str]:
    """A set of all files in the dataset.

    Returns:
        All files in the dataset.

    .. versionadded:: 0.5
    """
    # Normalize to an iterable so a single path string and a collection of
    # paths are handled uniformly.
    if isinstance(self.paths, str):
        paths: Iterable[str] = [self.paths]
    else:
        paths = self.paths

    # Using set to remove any duplicates if directories are overlapping
    files: set[str] = set()
    for path in paths:
        if os.path.isdir(path):
            # Recursively search the directory for files matching the glob.
            pathname = os.path.join(path, "**", self.filename_glob)
            files |= set(glob.iglob(pathname, recursive=True))
        else:
            # Anything that is not a directory is treated as a single file;
            # existence is not checked here (matches prior behavior).
            files.add(path)

    return files


class RasterDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as raster files."""

#: Glob expression used to search for files.
#:
#: This expression should be specific enough that it will not pick up files from
#: other datasets. It should not include a file extension, as the dataset may be in
#: a different file format than what it was originally downloaded as.
filename_glob = "*"

#: Regular expression used to extract date from filename.
#:
#: The expression should use named groups. The expression may contain any number of
Expand Down Expand Up @@ -423,32 +450,6 @@ def __init__(
self._crs = cast(CRS, crs)
self._res = cast(float, res)

@property
def files(self) -> set[str]:
    """A list of all files in the dataset.

    Returns:
        All files in the dataset.

    .. versionadded:: 0.5
    """
    # Wrap a bare string so both accepted forms iterate the same way.
    path_list: Iterable[str] = (
        [self.paths] if isinstance(self.paths, str) else self.paths
    )

    # A set removes duplicates that arise from overlapping directories.
    matched: set[str] = set()
    for entry in path_list:
        if not os.path.isdir(entry):
            # Treat anything that is not a directory as a single file.
            matched.add(entry)
            continue
        # Recursively glob inside directories for matching files.
        matched |= set(
            glob.iglob(
                os.path.join(entry, "**", self.filename_glob), recursive=True
            )
        )

    return matched

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.

Expand Down Expand Up @@ -571,16 +572,9 @@ def _load_warp_file(self, filepath: str) -> DatasetReader:
class VectorDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as vector files."""

#: Glob expression used to search for files.
#:
#: This expression should be specific enough that it will not pick up files from
#: other datasets. It should not include a file extension, as the dataset may be in
#: a different file format than what it was originally downloaded as.
filename_glob = "*"

def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 0.0001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
Expand All @@ -589,7 +583,7 @@ def __init__(
"""Initialize a new Dataset instance.

Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
Expand All @@ -603,16 +597,18 @@ def __init__(

.. versionadded:: 0.4
The *label_name* parameter.

.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
super().__init__(transforms)

self.root = root
self.paths = paths
self.label_name = label_name

# Populate the dataset index
i = 0
pathname = os.path.join(root, "**", self.filename_glob)
for filepath in glob.iglob(pathname, recursive=True):
for filepath in self.files:
try:
with fiona.open(filepath) as src:
if crs is None:
Expand All @@ -633,7 +629,7 @@ def __init__(
i += 1

if i == 0:
msg = f"No {self.__class__.__name__} data was found in `root='{root}'`"
msg = f"No {self.__class__.__name__} data was found in `root='{paths}'`"
raise FileNotFoundError(msg)

self._crs = crs
Expand Down
Loading