From 7e7888437a1a681a483d02f25085dc9b196a4afd Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Wed, 22 Mar 2023 18:38:09 +0300 Subject: [PATCH] Reverting --- tests/datasets/test_geo.py | 4 +--- torchgeo/datasets/geo.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 2d684ae1e38..158117c9d88 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -182,9 +182,7 @@ def sentinel(self, request: SubRequest) -> Sentinel2: @pytest.fixture() def custom_dtype_ds(self) -> RasterDataset: root = os.path.join("tests", "data", "raster") - ds = RasterDataset(root) - ds.dtype = torch.long - return ds + return RasterDataset(root) def test_getitem_single_file(self, naip: NAIP) -> None: x = naip[naip.bounds] diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index ddaed22d2db..4f864a0aefe 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -435,7 +435,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: data = data.to(self.dtype) if self.is_image: - sample["image"] = data + sample["image"] = data.float() else: sample["mask"] = data