Add support for float type raster datasets #379

Closed
tritolol opened this issue Jan 31, 2022 · 3 comments · Fixed by #384

@tritolol
Contributor

tritolol commented Jan 31, 2022

I'm trying to load a custom DSM raster dataset of type float32.

I'm implementing a new class based on RasterDataset like so:

from torchgeo.datasets import RasterDataset

class DsmData(RasterDataset):
    filename_glob = "*.tif"

When I call __getitem__() on a DsmData object, I get an array containing int32 data instead.
I traced this to the following line in the RasterDataset definition:

dest = dest.astype(np.int32)

All raster datasets are forced into the int32 type, which should not happen.

Having RasterDatasets with different types will probably cause problems when applying union or intersection operations.
But since a custom collate_fn can be defined, the user is able to provide a solution for this.

I achieved the desired behavior simply by removing the mentioned line.
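
To make the effect of that cast concrete, here is a minimal sketch using plain NumPy rather than torchgeo itself. The array values are hypothetical elevations standing in for a float32 DSM tile; they are not taken from the dataset in this issue.

```python
import numpy as np

# Hypothetical elevation values standing in for a float32 DSM tile.
dsm = np.array([[12.75, 13.5], [-0.4, 101.99]], dtype=np.float32)

# The cast reported above: fractional parts are truncated toward zero.
forced = dsm.astype(np.int32)

# Without the cast, the array keeps the dtype it was read with.
preserved = dsm

print(forced.dtype, forced.tolist())
print(preserved.dtype)
```

The sub-unit precision of the elevation data is lost in `forced`, which is why dropping the cast restores the expected behavior for float rasters.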

@calebrob6
Member

Oof, good catch. We should definitely change this. Do you want to open a PR to get this started?

@adamjstewart
Collaborator

> Having RasterDatasets with different types will probably cause problems when applying union or intersection operations.

Nope, this won't be a problem. All of our collation functions use torch operators that can handle various dtypes:

$ python
>>> import torch
>>> a = torch.zeros(3, 3, dtype=torch.int32)
>>> b = torch.zeros(3, 3, dtype=torch.float32)
>>> torch.stack((a, b))
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
>>> torch.maximum(a, b)
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
>>> torch.cat((a, b))
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
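
The session above relies on PyTorch's type promotion. As a small sketch of how one might check the resulting dtype ahead of time, `torch.promote_types` reports the dtype that two input dtypes promote to; the tensors here simply mirror the `a` and `b` from the session and are not torchgeo samples.

```python
import torch

a = torch.zeros(3, 3, dtype=torch.int32)
b = torch.zeros(3, 3, dtype=torch.float32)

# Ask PyTorch which dtype the two inputs promote to.
promoted = torch.promote_types(a.dtype, b.dtype)

# Stacking mixed-dtype tensors yields a tensor of the promoted dtype.
stacked = torch.stack((a, b))

print(promoted, stacked.dtype, tuple(stacked.shape))
```

This assumes a reasonably recent PyTorch release in which `torch.stack` and `torch.cat` accept mixed dtypes, as they do in the session shown here.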

@adamjstewart adamjstewart added the datasets Geospatial or benchmark datasets label Jan 31, 2022
@tritolol
Contributor Author

tritolol commented Feb 1, 2022

Created PR #384

@adamjstewart adamjstewart added this to the 0.2.1 milestone Mar 19, 2022