Bug in EuroSAT.plot() #1648

Closed
robmarkcole opened this issue Oct 11, 2023 · 5 comments · Fixed by #1650
Labels
datasets Geospatial or benchmark datasets

@robmarkcole
Contributor

robmarkcole commented Oct 11, 2023

Description

I tried to reproduce the output of EuroSAT.plot() when running inference on a single sample, and could not get sample["label"] to agree with the label shown in the plot. To make them agree I have to index into sorted(EuroSAT.classes), which I do not see anywhere in the plot implementation (and which doesn't make sense).

Examples below:

(Three screenshots attached in the original issue.)

Steps to reproduce

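# Note: my version without normalisation, although this shouldn't affect the plot.
# batch_size, num_workers, device and task (the trained model) are defined elsewhere.
from typing import Any, cast

import torch
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.datasets import EuroSAT

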
class EuroSATDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the EuroSAT dataset.

    Uses the train/val/test splits from the dataset.

    .. versionadded:: 0.2
    """

    mean = torch.zeros(13)
    std = torch.ones(13)

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new EuroSATDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.EuroSAT`.
        """
        super().__init__(EuroSAT, batch_size, num_workers, **kwargs)


datamodule = EuroSATDataModule(
    batch_size=batch_size, 
    root="data", 
    num_workers=num_workers, 
    download=True,
)
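
# test_dataset only exists after setup; assumed to have been run in the original session
datamodule.setup("test")
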
sample = datamodule.test_dataset[0]
label = cast(int, sample["label"].item())
image = sample['image'].unsqueeze(0).to(device)
pred = task(image)
pred_index = int(torch.argmax(pred))

result_str = f"label: {sorted(EuroSAT.classes)[label]}, prediction: {EuroSAT.classes[pred_index]}" # sorted shouldn't be required!
fig = datamodule.test_dataset.plot(sample, suptitle=result_str)

Version

0.5.0

adamjstewart added the datasets (Geospatial or benchmark datasets) label on Oct 11, 2023
@adamjstewart
Collaborator

EuroSAT subclasses NonGeoClassificationDataset which subclasses ImageFolder. ImageFolder sorts directory names to determine labels. Are the label names wrong? We can sort the class names if necessary.
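The sorting behaviour described above can be illustrated with a small standard-library sketch. The find_classes helper below only mimics the ImageFolder convention (sorted subdirectory names become labels 0..N-1); it is not the torchvision code, and the short HANDWRITTEN list is a hypothetical stand-in for a class list written in a different order, not the actual EuroSAT.classes.

```python
import os
import tempfile


def find_classes(root):
    """Mimic the ImageFolder convention: sorted subdirectory names become labels 0..N-1."""
    names = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return names, {name: index for index, name in enumerate(names)}


# Hypothetical hand-written class list, deliberately not in sorted order.
HANDWRITTEN = ["Industrial", "Residential", "AnnualCrop", "Forest"]

with tempfile.TemporaryDirectory() as root:
    # One empty folder per class, as an image-folder dataset would have on disk.
    for name in HANDWRITTEN:
        os.mkdir(os.path.join(root, name))
    classes, class_to_idx = find_classes(root)

label = class_to_idx["Forest"]     # integer label a sample from Forest/ would carry
print(classes[label])              # name looked up in the folder-derived order
print(HANDWRITTEN[label])          # name looked up in the hand-written order
print(sorted(HANDWRITTEN)[label])  # sorting the hand-written list restores agreement
```

Under these assumptions, indexing the hand-written list with the integer label returns a different class than the folder-derived list does, which matches the mismatch reported in the issue.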

@robmarkcole
Contributor Author

robmarkcole commented Oct 12, 2023

If people try to use EuroSAT.classes directly, as I have, they will hit the same issue. I will make a small PR to sort the class names, and I suggest doing the same for the other ImageFolder-based datasets too. I also note that some label names differ from the names picked up from the folders.

(Two screenshots attached in the original issue.)

@robmarkcole
Contributor Author

Check other classification datasets:

So only eurosat needs updating
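A check of this kind might look like the sketch below. The dataset names and class lists in CLASS_LISTS are hypothetical placeholders, not real torchgeo datasets; in practice one would pass in each dataset's class-level classes attribute instead.

```python
def unsorted_datasets(class_lists):
    """Return the names whose class list is not already in sorted order."""
    return [
        name
        for name, classes in class_lists.items()
        if list(classes) != sorted(classes)
    ]


# Hypothetical stand-ins for per-dataset class lists.
CLASS_LISTS = {
    "DatasetA": ["cat", "dog", "zebra"],
    "DatasetB": ["river", "forest", "highway"],
}

needs_update = unsorted_datasets(CLASS_LISTS)
print(needs_update)
```

Note that Python's sorted is case-sensitive, which is the same ordering the folder-name sort uses.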

adamjstewart added this to the 0.5.1 milestone on Oct 12, 2023
@adamjstewart
Copy link
Collaborator

The fact that instantiating the dataset overwrites self.classes makes me kinda not want to set it at all for these datasets. We should also ensure that the order documented in the docstring matches. Thoughts?
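The overwrite mentioned here is ordinary Python attribute shadowing, and a minimal sketch may make it concrete. FolderDataset and its folder_names argument are hypothetical stand-ins, not the torchgeo classes; the point is only that an instance attribute set in __init__ hides the class-level list of the same name.

```python
class FolderDataset:
    """Hypothetical stand-in for a dataset whose labels come from folder names."""

    # Class-level list, written by hand in an arbitrary order.
    classes = ["Industrial", "AnnualCrop", "Forest"]

    def __init__(self, folder_names):
        # The instance attribute shadows the class attribute of the same name.
        self.classes = sorted(folder_names)


dataset = FolderDataset(["Forest", "Industrial", "AnnualCrop"])

print(FolderDataset.classes)  # hand-written order, seen without an instance
print(dataset.classes)        # folder-derived order, seen on the instance
```

Code that reads the list from the class and code that reads it from an instance can therefore disagree about which name belongs to a given integer label.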

@robmarkcole
Contributor Author

Agreed, it doesn't need to be included; I assumed it was a convenience for anyone who wanted to use the class names elsewhere. You could, for example, list 'pretty names' like Industrial buildings in the docstring, in label order, and then people could just copy them from there if they need them.
