Front page example for VHR10 dataset does not work #1686

Closed
grantcurell opened this issue Oct 20, 2023 · 11 comments · Fixed by #1920
Labels
documentation Improvements or additions to documentation

Comments

@grantcurell

grantcurell commented Oct 20, 2023

Description

I started by copying and pasting the example as is from the frontpage:

from torch.utils.data import DataLoader

from torchgeo.datasets import VHR10

dataset = VHR10(root="./raw_data", download=True, checksum=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

for batch in dataloader:
    image = batch["image"]
    label = batch["label"]

    # train a model, or make predictions using a pre-trained model

This produces:

/usr/bin/python3.10 /home/grant/Documents/code/geo_testing/NAIP_test/3_test.py 
Files already downloaded and verified
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
Traceback (most recent call last):
  File "/home/grant/Documents/code/geo_testing/NAIP_test/3_test.py", line 8, in <module>
    for batch in dataloader:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils.py", line 694, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 127, in collate
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 127, in <dictcomp>
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 162, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 663, 794] at entry 0 and [3, 551, 808] at entry 1
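
The final line of the traceback can be reproduced without the dataset. The sketch below is only an illustration: it assumes two all-zero tensors with the shapes reported above (the names `first` and `second` are made up here) and shows that `torch.stack`, which the default collate function calls, rejects inputs of different sizes.

```python
import torch

# Two stand-in "images" with the shapes reported in the traceback.
first = torch.zeros(3, 663, 794)
second = torch.zeros(3, 551, 808)

try:
    # The default collate function stacks the per-sample tensors like this.
    torch.stack([first, second], 0)
    message = ""
except RuntimeError as err:
    message = str(err)

print(message)
```

So the failure comes from the images in a batch having different heights and widths, not from the download or the annotation loading.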

Steps to reproduce

  1. Copy and paste the code from here
  2. Update the root directory to something valid
  3. Run with no additional options

Version

0.5.0

adamjstewart added this to the 0.5.1 milestone Oct 21, 2023

@adamjstewart
Collaborator

We could probably solve this with a Resize augmentation but let's actually choose a simpler dataset, VHR-10 is kind of complicated. @ashnair1 is working on a data module for VHR-10 which will make it easier to use, and will include augmentations like this: #1082

adamjstewart added the documentation (Improvements or additions to documentation) label Oct 21, 2023

@grantcurell
Author

I'm not sure if you want it, but after staring at this for a while to figure out what I'm looking at (I've never used any of this stuff before), this is what I came up with:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchgeo.datasets import VHR10

# Define the resize transform
resize_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((1024, 1024)),
    transforms.ToTensor()
])

# Custom collate function
def custom_collate(batch):
    images = [item["image"] for item in batch]
    labels = [item["labels"] for item in batch]

    resized_images = [resize_transform(img) for img in images]
    resized_images = torch.stack(resized_images)

    # Since labels can have different lengths, we keep them as a list instead of stacking
    return {"image": resized_images, "labels": labels}

# Initialize the dataset
dataset = VHR10(root="./raw_data", download=True, checksum=True)

# Initialize the dataloader with the custom collate function
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4, collate_fn=custom_collate)

# Training loop
for batch in dataloader:
    image = batch["image"]
    labels = batch["labels"]

I can PR it with an explanation for noobies if you want, but like you said, I'm not sure if it's what you want.

@adamjstewart
Collaborator

Our datasets aren't really compatible with torchvision transforms, you'll have much better luck with kornia transforms. Something like:

from kornia.augmentation import Resize
from torchgeo.transforms import AugmentationSequential

transforms = AugmentationSequential(
    Resize(..., ...), data_keys=["image"]
)

See https://torchgeo.readthedocs.io/en/stable/tutorials/transforms.html for more examples.

I tried this and I don't think it's compatible with our current design of VHR10. This should be reworked in #1082. In the meantime, it's probably easier to give an example using a different dataset where images don't require resizing.

@adamjstewart
Collaborator

@grantcurell do you want to submit a PR to change the example dataset from VHR10 to EuroSAT while we wait for #1082 to be merged? EuroSAT should be a much simpler example.

adamjstewart modified the milestones: 0.5.1, 0.5.2 Nov 6, 2023

@grantcurell
Author

> @grantcurell do you want to submit a PR to change the example dataset from VHR10 to EuroSAT while we wait for #1082 to be merged? EuroSAT should be a much simpler example.

Apologies for my delayed response. I've already done all my other modeling with the VHR10 dataset so for me that's what I'll probably stick with.

If I get the chance though, I'll write something up for EuroSAT.

@connorlee77

connorlee77 commented Nov 22, 2023

> We could probably solve this with a Resize augmentation but let's actually choose a simpler dataset, VHR-10 is kind of complicated. @ashnair1 is working on a data module for VHR-10 which will make it easier to use, and will include augmentations like this: #1082

@adamjstewart Can you explain why it's complicated? I'm facing a similar issue with the Chesapeake dataset. I read the corresponding datamodule code, but it's unclear why resizing the image before applying a crop is the ideal solution, especially for applications where the pixel resolution matters. Furthermore, each tile in the dataset is quite large, so I'm also not sure why this resizing is even necessary.

@adamjstewart
Collaborator

VHR-10 is complicated because it has images, masks, and bounding boxes. Chesapeake only has images and masks, so it's much easier. I think the problem with Chesapeake is slightly different since it's a GeoDataset. In theory, resize/crop shouldn't be needed, but it's needed right now because it's not using the RasterDataset base class. If you open a separate issue maybe @calebrob6 can take a look at fixing this.
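
To illustrate why the extra targets matter, here is a rough, PyTorch-only sketch of resizing one sample. It is not how TorchGeo or Kornia implement this. The helper name `resize_sample` is hypothetical, and the boxes are assumed to be in `(xmin, ymin, xmax, ymax)` pixel format. The point is that every box has to be rescaled by the same factors as the image, and masks would need the same treatment (typically with nearest-neighbor interpolation so class values are not blended).

```python
import torch
import torch.nn.functional as F


def resize_sample(image, boxes, size):
    """Resize a (C, H, W) image and rescale its (N, 4) xyxy boxes to match.

    Hypothetical helper for illustration only; assumes xyxy pixel boxes.
    """
    _, height, width = image.shape
    new_height, new_width = size
    # interpolate expects a batch dimension, so add one and drop it again.
    resized = F.interpolate(
        image.unsqueeze(0), size=size, mode="bilinear", align_corners=False
    ).squeeze(0)
    # x coordinates scale with the width, y coordinates with the height.
    scale_x = new_width / width
    scale_y = new_height / height
    scale = torch.tensor([scale_x, scale_y, scale_x, scale_y])
    return resized, boxes * scale


image = torch.rand(3, 100, 200)
boxes = torch.tensor([[20.0, 10.0, 100.0, 50.0]])
resized, scaled = resize_sample(image, boxes, (50, 50))
```

A transform that only knows about the `"image"` key would leave the boxes pointing at the wrong pixels, which is one reason an image-only resize is not enough for a detection dataset.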

@adamjstewart
Collaborator

@ashnair1 has this gotten any better now that #1082 has been merged? Or should we switch to a simpler dataset that does not require transforms to use?

@ashnair1
Collaborator

The following example will work. Though a simpler dataset might be better suited for the README.

import kornia.augmentation as K
import torch
from torch.utils.data import DataLoader

from torchgeo.datamodules.utils import AugPipe, collate_fn_detection
from torchgeo.datasets import VHR10
from torchgeo.transforms import AugmentationSequential

batch_size = 2

# Initialize the dataset
dataset = VHR10(root="./raw_data/", download=True, checksum=True)

# Initialize the dataloader with the custom collate function
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn_detection,
)

# Initialize augs to normalize and resize images to size (512, 512)
aug = AugPipe(
    augs=AugmentationSequential(
        K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
        K.Resize((512, 512)),
        data_keys=["image", "boxes", "masks"],
    ),
    batch_size=batch_size,
)

# Training loop
for batch in dataloader:
    batch = aug(batch)
    images = batch["image"] # List of images
    boxes = batch["boxes"] # List of boxes
    labels = batch["labels"] # List of labels
    masks = batch["masks"] # List of masks
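
For readers wondering why the batch values are lists rather than stacked tensors: a detection-style collate function avoids `torch.stack` on inputs of different sizes. The sketch below is only a rough illustration of that idea, not TorchGeo's `collate_fn_detection`; the name `list_collate` and the sample contents are made up for the example.

```python
import torch


def list_collate(batch):
    """Group samples key by key, keeping each key's values in a plain list.

    Hypothetical helper for illustration; nothing is stacked, so samples
    may have different image sizes and different numbers of boxes.
    """
    return {key: [sample[key] for sample in batch] for key in batch[0]}


samples = [
    {"image": torch.zeros(3, 663, 794), "boxes": torch.zeros(2, 4)},
    {"image": torch.zeros(3, 551, 808), "boxes": torch.zeros(5, 4)},
]
batch = list_collate(samples)
```

A function like this can be passed to `DataLoader` through its `collate_fn` argument; any resizing then happens afterwards, per batch.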

@adamjstewart
Collaborator

I do really like the VHR-10 pic we use in the README though... Want to submit a PR to use that code to fix the README example? I would also accept a PR that uses a different dataset like EuroSAT instead. We're trying to release 0.5.2 tomorrow or Saturday, so it kinda needs to happen fast if we want to get this fixed before the next release.

@ashnair1
Collaborator

Ok, let's go with VHR-10 (#1920) for now. We can always switch later.
