From d7e3cd3307b9096aa3266d95065605e4aef64f45 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Tue, 28 Feb 2023 04:53:39 +0000 Subject: [PATCH 01/12] Adding dtype to RasterDataset --- torchgeo/datasets/chesapeake.py | 1 + torchgeo/datasets/geo.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index a472c9b20bf..7df21ae8845 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -40,6 +40,7 @@ class Chesapeake(RasterDataset, abc.ABC): """ is_image = False + dtype = torch.long # subclasses use the 13 class cmap by default cmap = { diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 4cd54d5d9b3..4bef45fa703 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -296,6 +296,9 @@ class RasterDataset(GeoDataset): #: Color map for the dataset, used for plotting cmap: dict[int, tuple[int, int, int, int]] = {} + #: dtype to force onto the dataset (overrides the dtype of the file via a cast) + dtype: Optional[torch.dtype] = None + def __init__( self, root: str = "data", @@ -429,10 +432,14 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: data = self._merge_files(filepaths, query, self.band_indexes) sample = {"crs": self.crs, "bbox": query} + + if self.dtype is not None: + data = data.to(self.dtype) + if self.is_image: sample["image"] = data.float() else: - sample["mask"] = data.long() + sample["mask"] = data if self.transforms is not None: sample = self.transforms(sample) From 1373660f9c786b5b534a073e47f243202c7d6ba6 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Tue, 28 Feb 2023 11:17:08 -0800 Subject: [PATCH 02/12] Removing explicit cast to float for images --- torchgeo/datasets/geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 4bef45fa703..308bba74666 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -437,7 +437,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: data = data.to(self.dtype) if self.is_image: - sample["image"] = data.float() + sample["image"] = data else: sample["mask"] = data From 843750846d79082a3c6e6d5b5e27b5f8275ebf94 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Wed, 22 Mar 2023 15:56:44 +0300 Subject: [PATCH 03/12] Updating test case --- tests/datasets/test_geo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 259c583ebe2..36b8771548a 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -181,7 +181,9 @@ def sentinel(self, request: SubRequest) -> Sentinel2: @pytest.fixture() def custom_dtype_ds(self) -> RasterDataset: root = os.path.join("tests", "data", "raster") - return RasterDataset(root) + ds = RasterDataset(root) + ds.dtype = torch.long + return ds def test_getitem_single_file(self, naip: NAIP) -> None: x = naip[naip.bounds] From 955d874a074d32b2eeb4b83c8959aa2e81f1c13c Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Wed, 22 Mar 2023 18:38:09 +0300 Subject: [PATCH 04/12] Reverting --- tests/datasets/test_geo.py | 4 +--- torchgeo/datasets/geo.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 36b8771548a..259c583ebe2 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -181,9 +181,7 @@ def sentinel(self, request: SubRequest) -> Sentinel2: @pytest.fixture() def custom_dtype_ds(self) -> RasterDataset: root = os.path.join("tests", "data", "raster") - ds = RasterDataset(root) - ds.dtype = torch.long - return ds + return RasterDataset(root) def test_getitem_single_file(self, naip: NAIP) -> None: x = naip[naip.bounds] diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 308bba74666..4bef45fa703 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -437,7 +437,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: data = data.to(self.dtype) if self.is_image: - sample["image"] = data + sample["image"] = data.float() else: sample["mask"] = data From 9617b0c150e8928948466d1e9a3f901aaf23e685 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Mon, 10 Apr 2023 21:34:35 +0000 Subject: [PATCH 05/12] Compromising on a UserWarning --- torchgeo/datasets/geo.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 4bef45fa703..4893fb1745e 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -11,6 +11,7 @@ import sys from collections.abc import Sequence from typing import Any, Callable, Optional, cast +from warnings import warn import fiona import fiona.transform @@ -433,8 +434,13 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: sample = {"crs": self.crs, "bbox": query} - if self.dtype is not None: + if self.dtype is not None and not self.is_image: data = data.to(self.dtype) + if self.dtype is not None and self.is_image: + warn( + "Custom dtype is explicitely set, however the current RasterDataset is" + + " an image, so no action will be taken." + ) if self.is_image: sample["image"] = data.float() From 69148df21a3ab58968254562fdcfd4e2f1af58f2 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Sun, 23 Apr 2023 13:48:23 -0700 Subject: [PATCH 06/12] Update torchgeo/datasets/geo.py Co-authored-by: Adam J. Stewart --- torchgeo/datasets/geo.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 4893fb1745e..ff12850cfa8 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -434,13 +434,11 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: sample = {"crs": self.crs, "bbox": query} - if self.dtype is not None and not self.is_image: - data = data.to(self.dtype) - if self.dtype is not None and self.is_image: - warn( - "Custom dtype is explicitely set, however the current RasterDataset is" - + " an image, so no action will be taken." - ) + if self.dtype is not None: + if self.is_image: + warn("Custom dtype is explicitly set, but dtype is only valid for mask RasterDatasets.") + else: + data = data.to(self.dtype) if self.is_image: sample["image"] = data.float() From a85a1c26c8b45312bccca02444c92c6ec6e3faff Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Sun, 23 Apr 2023 21:00:11 +0000 Subject: [PATCH 07/12] Adding test --- tests/datasets/test_geo.py | 5 +++++ torchgeo/datasets/geo.py | 5 ++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 259c583ebe2..7574866eb43 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -227,6 +227,11 @@ def test_no_all_bands(self) -> None: with pytest.raises(AssertionError, match=msg): CustomSentinelDataset(root, bands=bands, transforms=transforms, cache=cache) + def test_dtype_warning(self, custom_dtype_ds: RasterDataset) -> None: + custom_dtype_ds.dtype = torch.int32 + with pytest.warns(UserWarning, match="Custom dtype is explicitly set*"): + custom_dtype_ds[custom_dtype_ds.bounds] + class TestVectorDataset: @pytest.fixture(scope="class") diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index ff12850cfa8..91389ac4e94 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -436,7 +436,10 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if self.dtype is not None: if self.is_image: - warn("Custom dtype is explicitly set, but dtype is only valid for mask RasterDatasets.") + warn( + "Custom dtype is explicitly set, but dtype is only valid for mask" + + " RasterDatasets." + ) else: data = data.to(self.dtype) From c9cea4cb82d7ecd8b694a7e4a7f95598a0eae5c8 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Sun, 23 Apr 2023 23:09:00 +0000 Subject: [PATCH 08/12] Changing back --- tests/datasets/test_geo.py | 5 ----- torchgeo/datasets/geo.py | 20 ++++++++------------ 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 7574866eb43..259c583ebe2 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -227,11 +227,6 @@ def test_no_all_bands(self) -> None: with pytest.raises(AssertionError, match=msg): CustomSentinelDataset(root, bands=bands, transforms=transforms, cache=cache) - def test_dtype_warning(self, custom_dtype_ds: RasterDataset) -> None: - custom_dtype_ds.dtype = torch.int32 - with pytest.warns(UserWarning, match="Custom dtype is explicitly set*"): - custom_dtype_ds[custom_dtype_ds.bounds] - class TestVectorDataset: @pytest.fixture(scope="class") diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 91389ac4e94..1f0bdc49a83 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -11,7 +11,6 @@ import sys from collections.abc import Sequence from typing import Any, Callable, Optional, cast -from warnings import warn import fiona import fiona.transform @@ -298,7 +297,12 @@ class RasterDataset(GeoDataset): cmap: dict[int, tuple[int, int, int, int]] = {} #: dtype to force onto the dataset (overrides the dtype of the file via a cast) - dtype: Optional[torch.dtype] = None + @property + def dtype(self) -> torch.dtype: + if self.is_image: + return torch.float32 + else: + return torch.long def __init__( self, @@ -434,17 +438,9 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: sample = {"crs": self.crs, "bbox": query} - if self.dtype is not None: - if self.is_image: - warn( - "Custom dtype is explicitly set, but dtype is only valid for mask" - + " RasterDatasets." - ) - else: - data = data.to(self.dtype) - + data = data.to(self.dtype) if self.is_image: - sample["image"] = data.float() + sample["image"] = data else: sample["mask"] = data From b537b64ffc7a95e1baf599b470c11bf163e0e36e Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Sun, 23 Apr 2023 23:11:38 +0000 Subject: [PATCH 09/12] Set the docstring of dtype --- torchgeo/datasets/geo.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 1f0bdc49a83..c4f1d31e6df 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -296,9 +296,13 @@ class RasterDataset(GeoDataset): #: Color map for the dataset, used for plotting cmap: dict[int, tuple[int, int, int, int]] = {} - #: dtype to force onto the dataset (overrides the dtype of the file via a cast) @property def dtype(self) -> torch.dtype: + """dtype of the dataset (overrides the dtype of the data file via a cast). + + Returns: + the dtype of the dataset + """ if self.is_image: return torch.float32 else: From 664f4f8850afa3e292a24724dc09a418ac8e5812 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Sun, 23 Apr 2023 23:14:09 +0000 Subject: [PATCH 10/12] Good grief --- torchgeo/datasets/geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index c4f1d31e6df..da31298624e 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -298,7 +298,7 @@ class RasterDataset(GeoDataset): @property def dtype(self) -> torch.dtype: - """dtype of the dataset (overrides the dtype of the data file via a cast). + """dtype of the dataset (overrides the dtype of the data file via a cast). # noqa: D403, E501 Returns: the dtype of the dataset From e3a6dc031676cf13dbb6753c49a48deaf211c458 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Sun, 23 Apr 2023 23:16:14 +0000 Subject: [PATCH 11/12] pydocstyle workaround --- torchgeo/datasets/geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index da31298624e..91e161c2527 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -298,7 +298,7 @@ class RasterDataset(GeoDataset): @property def dtype(self) -> torch.dtype: - """dtype of the dataset (overrides the dtype of the data file via a cast). # noqa: D403, E501 + """The dtype of the dataset (overrides the dtype of the data file via a cast). Returns: the dtype of the dataset From f89b1170997c6519cbf680640bb47b00220f22a3 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Mon, 24 Apr 2023 15:24:54 +0000 Subject: [PATCH 12/12] REquested changes --- torchgeo/datasets/chesapeake.py | 1 - torchgeo/datasets/geo.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 7df21ae8845..a472c9b20bf 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -40,7 +40,6 @@ class Chesapeake(RasterDataset, abc.ABC): """ is_image = False - dtype = torch.long # subclasses use the 13 class cmap by default cmap = { diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 91e161c2527..9c440da896f 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -302,6 +302,8 @@ def dtype(self) -> torch.dtype: Returns: the dtype of the dataset + + .. versionadded:: 5.0 """ if self.is_image: return torch.float32