Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dtype field to RasterDataset #1149

Merged
merged 12 commits into from
Apr 24, 2023
1 change: 1 addition & 0 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Chesapeake(RasterDataset, abc.ABC):
"""

is_image = False
dtype = torch.long
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
calebrob6 marked this conversation as resolved.
Show resolved Hide resolved

# subclasses use the 13 class cmap by default
cmap = {
Expand Down
18 changes: 16 additions & 2 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,18 @@ class RasterDataset(GeoDataset):
#: Color map for the dataset, used for plotting
cmap: dict[int, tuple[int, int, int, int]] = {}

@property
def dtype(self) -> torch.dtype:
"""The dtype of the dataset (overrides the dtype of the data file via a cast).

Returns:
the dtype of the dataset
calebrob6 marked this conversation as resolved.
Show resolved Hide resolved
"""
if self.is_image:
return torch.float32
else:
return torch.long

def __init__(
self,
root: str = "data",
Expand Down Expand Up @@ -429,10 +441,12 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
data = self._merge_files(filepaths, query, self.band_indexes)

sample = {"crs": self.crs, "bbox": query}

data = data.to(self.dtype)
if self.is_image:
sample["image"] = data.float()
sample["image"] = data
else:
sample["mask"] = data.long()
sample["mask"] = data

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down