Add dtype field to RasterDataset #1149

Merged
merged 12 commits into from Apr 24, 2023

calebrob6
Member

In the great DataModule overhaul of winter 2022/2023 (#992), we made sure that all the dtypes returned by our datasets played nicely with Kornia. As part of this, we assumed in RasterDataset that all "mask" layers should be torch.long and all images should be torch.float. As a result, our RasterDatasets essentially cannot be used for real-valued regression tasks.

This PR aims to fix this by introducing a new field on RasterDataset called dtype that, if set, forces a cast to that type right before the dataset-level transform is called. With this, we can fix the dtype expectations on a per-dataset basis.
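
A minimal sketch of the idea described above, assuming a toy in-memory dataset. TinyRasterDataset, RegressionTargets, and ClassMask are hypothetical names for illustration and are not the torchgeo implementation; only the "cast right before the transform" ordering is taken from the description.

```python
import torch


class TinyRasterDataset:
    """Hypothetical stand-in for RasterDataset (not the torchgeo class)."""

    is_image = True
    dtype = None  # optional per-dataset override; None keeps the defaults

    def __init__(self, data, transforms=None):
        self.data = data
        self.transforms = transforms

    def __getitem__(self, index):
        tensor = torch.as_tensor(self.data[index])
        if self.dtype is not None:
            # Forced cast, applied right before the dataset-level transform.
            tensor = tensor.to(self.dtype)
        else:
            # Defaults described in the PR: float images, long masks.
            tensor = tensor.float() if self.is_image else tensor.long()
        sample = {"image" if self.is_image else "mask": tensor}
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample


class ClassMask(TinyRasterDataset):
    is_image = False  # classification mask, keeps the torch.long default


class RegressionTargets(TinyRasterDataset):
    is_image = False
    dtype = torch.float32  # real-valued targets stay floating point


values = [[[0.25, 1.5], [2.75, 3.0]]]
class_sample = ClassMask(values)[0]
regression_sample = RegressionTargets(values)[0]
```

Here class_sample["mask"] is cast to torch.long (truncating the fractional values), while regression_sample["mask"] keeps them as torch.float32.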

@github-actions github-actions bot added the datasets Geospatial or benchmark datasets label Feb 28, 2023
Collaborator

@adamjstewart adamjstewart left a comment

Technically this is an API change (we introduce a new class attribute) so it should wait until 0.5.0, but it also fixes an unintended consequence of #992, so it could also be argued that it should go in 0.4.1. Not sure which to choose.

We'll need to take this PR into account when thinking about #985.

torchgeo/datasets/geo.py (outdated review thread, resolved)
@adamjstewart
Collaborator

Just to complicate things, different bands of the same image may have different dtypes: #1182

Not sure how to help with this one. We can't stack all bands of the image into one tensor unless the tensor has a single shared dtype.

@calebrob6
Member Author

> Just to complicate things, different bands of the same image may have different dtypes: #1182

That needs to be handled by casting in a pre-processing step. It doesn't make sense for RasterDataset to have different dtypes for different bands because the input to a neural network needs to be a single floating-point dtype (float32, float16, ...), not an int/float mashup.
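
A small sketch of that pre-processing step, assuming two bands already loaded as tensors. The band names and pixel values here are made up for illustration; only the explicit cast before stacking is the point.

```python
import torch

# Hypothetical bands with different on-disk dtypes (values are illustrative).
reflectance = torch.tensor([[0.12, 0.40], [0.33, 0.91]], dtype=torch.float32)
pixel_qa = torch.tensor([[1, 0], [322, 66]], dtype=torch.int16)

bands = [reflectance, pixel_qa]

# Cast every band to one shared dtype, then stack into a (bands, H, W) tensor.
image = torch.stack([band.to(torch.float32) for band in bands])
```

The result is one float32 tensor of shape (2, 2, 2), which is the form a network input would need.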

@github-actions github-actions bot added the testing Continuous integration testing label Mar 22, 2023
@adamjstewart
Collaborator

In the case of #1182 it may make more sense for the PIXEL_QA band to be read as a mask instead of an image.

@github-actions github-actions bot removed the testing Continuous integration testing label Mar 22, 2023
@adamjstewart
Collaborator

Related to this, it may be useful to add a similar setting for resampling algorithm. By default, all resampling is done by nearest neighbors, which is perfect for classification masks, but less useful (albeit fast) for images.

@calebrob6
Member Author

calebrob6 commented Mar 29, 2023

> In the case of #1182 it may make more sense for the PIXEL_QA band to be read as a mask instead of an image.

This is supported: you would make one dataset for the imagery and one dataset for the PIXEL_QA mask, then take their intersection.

> Related to this, it may be useful to add a similar setting for resampling algorithm. By default, all resampling is done by nearest neighbors, which is perfect for classification masks, but less useful (albeit fast) for images.

Agree, I'll open an issue

Is there anything that is holding this PR back?

@adamjstewart
Collaborator

The only thing holding this PR back is that dtype affects masks but not images. This is very counterintuitive, isn't documented, and there's no warning letting you know this. Is there any reason dtype shouldn't affect both images and masks?

@calebrob6
Member Author

Conv2d, with its default float32 weights, expects float32 inputs, so it makes a lot of sense to always make inputs float32. If the naming is confusing, we can rename it mask_dtype or something.

> Is there any reason dtype shouldn't affect both images and masks?

Often the masks will be longs while the inputs will be float32s.

You can try running this:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3, bias=False)

for dtype in [
    torch.float32,
    torch.float64,
    torch.complex64,
    torch.complex128,
    torch.float16,
    torch.bfloat16,
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
    torch.bool,
]:
    x = torch.ones(3, 9, 9, dtype=dtype)
    try:
        y = conv(x)
        print(f"{dtype} works")
    except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
        print(f"{dtype} doesn't work")

You should get:

torch.float32 works
torch.float64 doesn't work
torch.complex64 doesn't work
torch.complex128 doesn't work
torch.float16 doesn't work
torch.bfloat16 doesn't work
torch.uint8 doesn't work
torch.int8 doesn't work
torch.int16 doesn't work
torch.int32 doesn't work
torch.int64 doesn't work
torch.bool doesn't work
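
One reading of the float64 result, offered as a sketch rather than something stated in the thread: the failure comes from the input dtype not matching the layer's default float32 weights, not from Conv2d rejecting float64 as such. Casting the module with .double() (a standard nn.Module method) makes the same input go through.

```python
import torch
import torch.nn as nn

# Batched float64 input: (batch, channels, height, width).
x64 = torch.ones(1, 3, 9, 9, dtype=torch.float64)

# Default layer: weights are float32, so a float64 input is a dtype mismatch.
conv32 = nn.Conv2d(3, 3, 3, bias=False)
try:
    conv32(x64)
    mismatch_failed = False
except RuntimeError:
    mismatch_failed = True

# Same layer cast to float64: input and weights now agree.
conv64 = nn.Conv2d(3, 3, 3, bias=False).double()
y64 = conv64(x64)
```

So the table above reflects the default float32 parameters; it does not change the practical point that images should match the model's dtype, which is float32 unless the model is cast.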

@adamjstewart
Collaborator

Kinda surprised float64 doesn't work.

I would be fine with calling it mask_dtype. We can always rename it later if someone provides a use case for changing image dtype.

@calebrob6
Member Author

Actually, calling it mask_dtype doesn't make sense -- RasterDatasets are either images or masks, they only have one dtype.

@adamjstewart
Collaborator

Let's call it dtype then. I would either let it apply to both images and masks, or raise an error when someone uses not-float32 for an image.

@calebrob6
Member Author

How about a UserWarning? Raising an all-out error seems extreme to me when everything should otherwise work.
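
A rough sketch of what the warning option could look like, assuming a hypothetical helper named resolve_dtype. This is not code from the PR; it only illustrates warning on a non-float32 image dtype instead of raising.

```python
import warnings

import torch


def resolve_dtype(is_image, dtype=None):
    """Hypothetical helper: pick the cast dtype, warning on unusual image dtypes."""
    default = torch.float32 if is_image else torch.long
    if dtype is None:
        return default
    if is_image and dtype != torch.float32:
        # Warn rather than fail: everything else should still work.
        warnings.warn(
            f"Image dtype {dtype} is not torch.float32; "
            "layers with default float32 weights will reject it.",
            UserWarning,
        )
    return dtype


# Record the warning so the behaviour can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    chosen = resolve_dtype(is_image=True, dtype=torch.float64)
```

The requested dtype is still returned, so the caller's choice is respected and only flagged.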

@calebrob6
Member Author

This should help with #849 btw

@adamjstewart
Collaborator

@adamjstewart adamjstewart added this to the 0.5.0 milestone Apr 12, 2023
@github-actions github-actions bot added the testing Continuous integration testing label Apr 23, 2023
Collaborator

@adamjstewart adamjstewart left a comment

I'm fine with this approach. We could also use something like:

@property
def dtype(self) -> torch.dtype:
    if self.is_image:
        return torch.float32
    else:
        return torch.long

Then we would only need to override this for the small number of datasets that have a different dtype.
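
To illustrate the override pattern suggested above, here is a self-contained sketch. BaseRaster and HeightTargets are hypothetical classes, not torchgeo datasets; the base property mirrors the snippet in this comment.

```python
import torch


class BaseRaster:
    """Hypothetical base class carrying the default dtype rule."""

    is_image = True

    @property
    def dtype(self) -> torch.dtype:
        # Default: float32 for images, long for masks.
        if self.is_image:
            return torch.float32
        else:
            return torch.long


class LandCoverMask(BaseRaster):
    is_image = False  # inherits the torch.long default


class HeightTargets(BaseRaster):
    """Hypothetical regression target: a mask with real-valued pixels."""

    is_image = False

    @property
    def dtype(self) -> torch.dtype:
        # Only datasets that differ from the default need this override.
        return torch.float32


default_mask_dtype = LandCoverMask().dtype
regression_dtype = HeightTargets().dtype
```

Most subclasses only set is_image, and the few with real-valued masks override the property.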

torchgeo/datasets/chesapeake.py (outdated review thread, resolved)
torchgeo/datasets/geo.py (outdated review thread, resolved)
@github-actions github-actions bot removed the testing Continuous integration testing label Apr 23, 2023
@adamjstewart adamjstewart merged commit e13de26 into main Apr 24, 2023
18 checks passed
@adamjstewart adamjstewart deleted the datasets/raster branch April 24, 2023 16:17
Labels
datasets Geospatial or benchmark datasets

2 participants