From e13de2684d03e2f2f9d0326c0d49a7a0e322d881 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Mon, 24 Apr 2023 09:17:15 -0700 Subject: [PATCH] Add dtype field to RasterDataset (#1149) * Adding dtype to RasterDataset * Removing explicit cast to float for images * Updating test case * Reverting * Compromising on a UserWarning * Update torchgeo/datasets/geo.py Co-authored-by: Adam J. Stewart * Adding test * Changing back * Set the docstring of dtype * Good grief * pydocstyle workaround * REquested changes --------- Co-authored-by: Adam J. Stewart --- torchgeo/datasets/geo.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 4cd54d5d9b3..9c440da896f 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -296,6 +296,20 @@ class RasterDataset(GeoDataset): #: Color map for the dataset, used for plotting cmap: dict[int, tuple[int, int, int, int]] = {} + @property + def dtype(self) -> torch.dtype: + """The dtype of the dataset (overrides the dtype of the data file via a cast). + + Returns: + the dtype of the dataset + + .. versionadded:: 5.0 + """ + if self.is_image: + return torch.float32 + else: + return torch.long + def __init__( self, root: str = "data", @@ -429,10 +443,12 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: data = self._merge_files(filepaths, query, self.band_indexes) sample = {"crs": self.crs, "bbox": query} + + data = data.to(self.dtype) if self.is_image: - sample["image"] = data.float() + sample["image"] = data else: - sample["mask"] = data.long() + sample["mask"] = data if self.transforms is not None: sample = self.transforms(sample)