Skip to content

Commit

Permalink
Add dtype field to RasterDataset (#1149)
Browse files Browse the repository at this point in the history
* Adding dtype to RasterDataset

* Removing explicit cast to float for images

* Updating test case

* Reverting

* Compromising on a UserWarning

* Update torchgeo/datasets/geo.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Adding test

* Changing back

* Set the docstring of dtype

* Good grief

* pydocstyle workaround

* REquested changes

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
calebrob6 and adamjstewart committed Apr 24, 2023
1 parent d3c82a5 commit e13de26
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,20 @@ class RasterDataset(GeoDataset):
#: Color map for the dataset, used for plotting
cmap: dict[int, tuple[int, int, int, int]] = {}

@property
def dtype(self) -> torch.dtype:
"""The dtype of the dataset (overrides the dtype of the data file via a cast).
Returns:
the dtype of the dataset
.. versionadded:: 5.0
"""
if self.is_image:
return torch.float32
else:
return torch.long

def __init__(
self,
root: str = "data",
Expand Down Expand Up @@ -429,10 +443,12 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
data = self._merge_files(filepaths, query, self.band_indexes)

sample = {"crs": self.crs, "bbox": query}

data = data.to(self.dtype)
if self.is_image:
sample["image"] = data.float()
sample["image"] = data
else:
sample["mask"] = data.long()
sample["mask"] = data

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down

0 comments on commit e13de26

Please sign in to comment.