diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 4cd54d5d9b3..9c440da896f 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -296,6 +296,20 @@ class RasterDataset(GeoDataset): #: Color map for the dataset, used for plotting cmap: dict[int, tuple[int, int, int, int]] = {} + @property + def dtype(self) -> torch.dtype: + """The dtype of the dataset (overrides the dtype of the data file via a cast). + + Returns: + the dtype of the dataset + + .. versionadded:: 5.0 + """ + if self.is_image: + return torch.float32 + else: + return torch.long + def __init__( self, root: str = "data", @@ -429,10 +443,12 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: data = self._merge_files(filepaths, query, self.band_indexes) sample = {"crs": self.crs, "bbox": query} + + data = data.to(self.dtype) if self.is_image: - sample["image"] = data.float() + sample["image"] = data else: - sample["mask"] = data.long() + sample["mask"] = data if self.transforms is not None: sample = self.transforms(sample)