
EuroSATDataModule initialised with 3 bands gives error: mean length and number of channels do not match #1634

Closed
robmarkcole opened this issue Oct 9, 2023 · 10 comments · Fixed by #1681
Labels
datamodules PyTorch Lightning datamodules
Milestone

Comments

@robmarkcole
Contributor

robmarkcole commented Oct 9, 2023

Description

I have previously:

  • Initialised the EuroSATDataModule with download=True and all 13 bands
  • Trained a model, etc.

I now:

  • Create the data module as follows:
from torchgeo.datamodules import EuroSATDataModule

batch_size = 16
num_workers = 0
rgb_bands = ("B04", "B03", "B02") # or experiment with all

datamodule = EuroSATDataModule(batch_size=batch_size, root="data", num_workers=num_workers, bands=rgb_bands)

# Do some validation on the train set
datamodule.prepare_data()
datamodule.setup('fit')
datamodule.train_dataset[0]['image'].shape # prints torch.Size([3, 64, 64]), confirms 3 bands

Create a task:

from torchgeo.trainers import ClassificationTask

task = ClassificationTask(
    model="resnet18",
    weights=True, # standard ImageNet
    # weights=ResNet18_Weights.SENTINEL2_ALL_MOCO, # or try sentinel 2
    num_classes=10,
    in_channels=3, # make sure to update
    loss="ce",
    patience=10
)

Train:

trainer.fit(model=task, datamodule=datamodule)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[10], line 1
----> 1 trainer.fit(model=task, datamodule=datamodule)

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    530 self.strategy._lightning_module = model
    531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
    533     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    534 )

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     41     if trainer.strategy.launcher is not None:
     42         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 43     return trainer_fn(*args, **kwargs)
     45 except _TunerExitException:
     46     _call_teardown_hook(trainer)

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:571, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    561 self._data_connector.attach_data(
    562     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    563 )
    565 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    566     self.state.fn,
...
--> 109         raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
    111 # Allow broadcast on channel dimension
    112 if std.shape and std.shape[0] != 1:

ValueError: mean length and number of channels do not match. Got torch.Size([13]) and torch.Size([16, 3, 64, 64]).

This presumably occurs during normalisation, which still expects 13 bands: the mean has 13 entries, while each batch has only 3 channels.
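The failing check can be illustrated with a small, hand-written per-channel normalisation. This is a minimal sketch in plain PyTorch, not Kornia's actual Normalize code: it assumes the statistics are broadcast over a (batch, channels, height, width) tensor, so mean and std must hold exactly one value per channel. The helper name normalize_per_channel and the placeholder statistics are hypothetical.

```python
import torch


def normalize_per_channel(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Normalise a (B, C, H, W) batch with one mean/std value per channel.

    Hypothetical helper for illustration; it mirrors the shape check in the
    traceback rather than reproducing Kornia's implementation.
    """
    if mean.shape[0] != data.shape[1]:
        raise ValueError(
            f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}."
        )
    # Reshape (C,) -> (1, C, 1, 1) so the statistics broadcast over batch and pixels.
    mean = mean.view(1, -1, 1, 1)
    std = std.view(1, -1, 1, 1)
    return (data - mean) / std


batch = torch.rand(16, 3, 64, 64)  # 16 RGB chips, as in the report
stats_13 = torch.ones(13)  # placeholder statistics for all 13 bands
stats_3 = torch.ones(3)  # placeholder statistics for 3 bands

try:
    normalize_per_channel(batch, stats_13, stats_13)
except ValueError as err:
    print(err)  # same message as the traceback above

out = normalize_per_channel(batch, stats_3, stats_3)
print(out.shape)
```

In other words, the statistics have to be subset to the selected bands before the normalisation transform is created.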

Steps to reproduce

As above

Version

0.5.0

@robmarkcole
Contributor Author

I tried to hack around this with:

import torch

MEAN = torch.tensor(
    [
        947.62620298,
        1118.24399958,
        1042.92983953,
    ]
)

STD = torch.tensor(
    [
        593.75055589,
        333.00778264,
        395.09249139,
    ]
)
...
datamodule.mean = MEAN
datamodule.std = STD

But the error still arises.

@isaaccorley
Collaborator

If I understand correctly, what you want to use is SENTINEL2_RGB_MOCO for the weights, not SENTINEL2_ALL_MOCO, which expects all the multispectral bands.

@isaaccorley
Collaborator

Also, your hack wouldn't actually work: you would need to override the Normalize transform itself, since it has already been instantiated. We prefer that users simply subclass the datamodule into their own custom datamodule if they want to override attributes.
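Why reassigning the attributes afterwards has no effect can be shown without TorchGeo. In this sketch the hypothetical ToyDataModule builds its normalisation transform once in __init__ from class-level mean/std, which is the pattern the comment above describes (assumed here, not copied from TorchGeo). Reassigning dm.mean after construction leaves the transform untouched, whereas overriding the class attributes in a subclass works because they are read while the transform is being built. All class names are illustrative only.

```python
import torch
from torch import nn


class PerChannelNormalize(nn.Module):
    """Hypothetical stand-in for a Normalize transform: captures its statistics at construction."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor) -> None:
        super().__init__()
        # Stored as buffers shaped (1, C, 1, 1) for broadcasting over a batch.
        self.register_buffer("mean", mean.clone().view(1, -1, 1, 1))
        self.register_buffer("std", std.clone().view(1, -1, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std


class ToyDataModule:
    # Class-level defaults for 13 bands (placeholder values).
    mean = torch.zeros(13)
    std = torch.ones(13)

    def __init__(self) -> None:
        # The transform is built here from whatever self.mean/self.std are at this moment.
        self.aug = PerChannelNormalize(self.mean, self.std)


class ToyRGBDataModule(ToyDataModule):
    # Overriding in a subclass works: these values are read during __init__.
    mean = torch.zeros(3)
    std = torch.ones(3)


dm = ToyDataModule()
dm.mean = torch.zeros(3)  # too late: dm.aug was already built with 13 values
print(dm.aug.mean.shape[1])  # 13

rgb = ToyRGBDataModule()
print(rgb.aug.mean.shape[1])  # 3
```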

@robmarkcole
Contributor Author

With weights=ResNet18_Weights.SENTINEL2_RGB_MOCO I get the same error; I had assumed the normalisation would be band-aware. So, to summarise, the intended use is:

rgb_bands = ("B04", "B03", "B02") # or experiment with all

datamodule = EuroSATDataModule(
    batch_size=batch_size, 
    root="data", 
    num_workers=num_workers, 
    bands=rgb_bands,
    download=True
)

But I should replace EuroSATDataModule with a subclass that re-implements the normalisation?

My objective is a simple tutorial for a hackathon where people can experiment with the number of bands, pretrained weights, etc.

@robmarkcole
Contributor Author

robmarkcole commented Oct 9, 2023

I've attempted the custom datamodule below:

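from typing import Dict, Optional

import kornia.augmentation as K
import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.datasets import EuroSAT, NonGeoDataset
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask
from torchgeo.transforms import AugmentationSequential

# extract for RGB later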

mins = torch.tensor(
    [
        1013.0,
        676.0,
        448.0,
        247.0,
        269.0,
        253.0,
        243.0,
        189.0,
        61.0,
        4.0,
        33.0,
        11.0,
        186.0,
    ]
)
maxs = torch.tensor(
    [
        2309.0,
        4543.05,
        4720.2,
        5293.05,
        3902.05,
        4473.0,
        5447.0,
        5948.05,
        1829.0,
        23.0,
        4894.05,
        4076.05,
        5846.0,
    ]
)

# use values from https://github.com/microsoft/torchgeo/blob/main/torchgeo/datasets/eurosat.py
bands = {
    "B01": "Coastal Aerosol",
    "B02": "Blue",
    "B03": "Green",
    "B04": "Red",
    "B05": "Vegetation Red Edge 1",
    "B06": "Vegetation Red Edge 2",
    "B07": "Vegetation Red Edge 3",
    "B08": "NIR 1",
    "B08A": "NIR 2",
    "B09": "Water Vapour",
    "B10": "SWIR 1",
    "B11": "SWIR 2",
    "B12": "SWIR 3",
}

rgb_bands = ("B04", "B03", "B02") # or experiment with all

# Get the indices of the keys represented in rgb_bands
rgb_indices = [list(bands.keys()).index(band) for band in rgb_bands]
mins = mins[rgb_indices]
maxs = maxs[rgb_indices]

class MinMaxNormalize(K.IntensityAugmentationBase2D):
    """Normalize channels to the range [0, 1] using min/max values."""

    def __init__(self, mins: Tensor, maxs: Tensor) -> None:
        super().__init__(p=1)
        self.flags = {"mins": mins.view(1, -1, 1, 1), "maxs": maxs.view(1, -1, 1, 1)}

    def apply_transform(
        self,
        input: Tensor,
        params: Dict[str, Tensor],
        flags: Dict[str, int],
        transform: Optional[Tensor] = None,
    ) -> Tensor:
        return (input - flags["mins"]) / (flags["maxs"] - flags["mins"] + 1e-10)

class RGBEuroSATDataModule(NonGeoDataModule):
    def __init__(self, data_dir: str, batch_size: int = 64):
        super().__init__(batch_size=batch_size, dataset_class=NonGeoDataset)
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.rgb_bands = ("B04", "B03", "B02") # or experiment with all

        # Define transforms
        self.train_transforms = AugmentationSequential(
            MinMaxNormalize(mins, maxs),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            K.RandomAffine(degrees=(0, 90), p=0.25),
            K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),
            K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),
            data_keys=["image"],
        )

        self.test_transforms = nn.Sequential(
            MinMaxNormalize(mins, maxs),
        )

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.train_dataset = EuroSAT(root=self.data_dir, split='train', transforms=self.train_transforms, download=True, bands=self.rgb_bands)
            self.val_dataset = EuroSAT(root=self.data_dir, split='val', transforms=self.test_transforms, download=True, bands=self.rgb_bands)
        if stage == 'test' or stage is None:
            self.test_dataset = EuroSAT(root=self.data_dir, split='test', transforms=self.test_transforms, download=True, bands=self.rgb_bands)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)


datamodule = RGBEuroSATDataModule(
    data_dir="data",
    batch_size=batch_size,
)

task = ClassificationTask(
    model="resnet18",
    # weights=True, # standard Imagenet
    # weights=ResNet18_Weights.SENTINEL2_ALL_MOCO, # or try sentinel 2 all bands
    weights=ResNet18_Weights.SENTINEL2_RGB_MOCO, # or try sentinel 2 rgb bands
    num_classes=10,
    in_channels=len(datamodule.rgb_bands), # must match the number of bands; train_dataset only exists after setup()
    loss="ce", 
    patience=10
)

trainer.fit(model=task, datamodule=datamodule)

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[19], line 1
----> 1 trainer.fit(model=task, datamodule=datamodule)

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    530 self.strategy._lightning_module = model
    531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
    533     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    534 )

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     41     if trainer.strategy.launcher is not None:
     42         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 43     return trainer_fn(*args, **kwargs)
     45 except _TunerExitException:
     46     _call_teardown_hook(trainer)

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:571, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    561 self._data_connector.attach_data(
    562     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    563 )
    565 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    566     self.state.fn,
...
--> 199 input_shape = in_tensor.shape
    200 in_tensor = self.transform_tensor(in_tensor)
    201 batch_shape = in_tensor.shape

AttributeError: 'dict' object has no attribute 'shape'
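One plausible reading of this traceback (an assumption, not confirmed in the thread) is that the dataset passes the whole sample dictionary, e.g. {'image': ..., 'label': ...}, to its transforms callable, while nn.Sequential(MinMaxNormalize(...)) forwards that dictionary straight to a transform that expects a tensor. The sketch below reproduces that failure in plain PyTorch and shows a hypothetical ImageOnly wrapper that applies a tensor transform to sample['image'] and passes the other keys through. MinMaxScale is a simplified stand-in for the MinMaxNormalize class above, not the same class.

```python
from typing import Callable, Dict

import torch
from torch import nn


class MinMaxScale(nn.Module):
    """Tensor-only transform: scales each channel of a (C, H, W) image to [0, 1]."""

    def __init__(self, mins: torch.Tensor, maxs: torch.Tensor) -> None:
        super().__init__()
        self.register_buffer("mins", mins.view(-1, 1, 1))
        self.register_buffer("maxs", maxs.view(-1, 1, 1))

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Accessing image.shape is what fails when a dict is passed in.
        if image.shape[0] != self.mins.shape[0]:
            raise ValueError("channel count does not match the statistics")
        return (image - self.mins) / (self.maxs - self.mins + 1e-10)


class ImageOnly:
    """Hypothetical wrapper: apply a tensor transform to sample['image'] only."""

    def __init__(self, transform: Callable[[torch.Tensor], torch.Tensor]) -> None:
        self.transform = transform

    def __call__(self, sample: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        sample = dict(sample)  # shallow copy so the caller's dict is left untouched
        sample["image"] = self.transform(sample["image"])
        return sample


# RGB minimums/maximums taken from the 13-band lists above, in (B04, B03, B02) order.
mins = torch.tensor([247.0, 448.0, 676.0])
maxs = torch.tensor([5293.05, 4720.2, 4543.05])
sample = {"image": torch.full((3, 64, 64), 1000.0), "label": torch.tensor(4)}

try:
    nn.Sequential(MinMaxScale(mins, maxs))(sample)
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'shape'

out = ImageOnly(MinMaxScale(mins, maxs))(sample)
print(out["image"].shape, out["label"])
```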

@robmarkcole
Contributor Author

OK, it appears I just overcomplicated it. Get the correct MEAN and STD for the chosen bands, then just:

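from typing import Any

from torchgeo.datamodules import NonGeoDataModule
from torchgeo.datasets import EuroSAT

class RGBEuroSATDataModule(NonGeoDataModule):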
    mean = MEAN
    std = STD
    rgb_bands = ("B04", "B03", "B02") # or experiment with all

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new RGBEuroSATDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.EuroSAT`.
        """
        super().__init__(EuroSAT, batch_size, num_workers, bands=self.rgb_bands, **kwargs)

@adamjstewart adamjstewart added this to the 0.5.1 milestone Oct 9, 2023
@adamjstewart
Collaborator

I personally consider it a bug that EuroSATDataModule(bands=rgb_bands) doesn't work automatically. OSCD and So2Sat data modules correctly modify mean/std for a change in bands. Want to submit a PR to do something similar?

@adamjstewart adamjstewart added the datamodules PyTorch Lightning datamodules label Oct 9, 2023
@robmarkcole
Contributor Author

A quick check of OSCD shows it supports 'all' or 'rgb' normalisation. Likewise, So2Sat supports this via a 'band_set' approach. However, I prefer the more flexible approach I demonstrated above, which allows any combination of bands. Should I implement that, or the 'band_set' approach?
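The flexible approach (selecting per-band statistics by band name, as the custom module above does for mins/maxs) can be sketched as a small helper. It assumes the 13 statistics are stored in the same order as the bands dictionary above; ALL_BANDS and select_band_stats are hypothetical names, not TorchGeo API.

```python
from typing import Sequence, Tuple

import torch

# Band order assumed to match the 13-band statistics (same order as the `bands` dict above).
ALL_BANDS: Tuple[str, ...] = (
    "B01", "B02", "B03", "B04", "B05", "B06", "B07",
    "B08", "B08A", "B09", "B10", "B11", "B12",
)


def select_band_stats(
    stats: torch.Tensor, bands: Sequence[str], all_bands: Sequence[str] = ALL_BANDS
) -> torch.Tensor:
    """Return the per-band statistics for `bands`, in the order requested."""
    if stats.shape[0] != len(all_bands):
        raise ValueError(f"expected {len(all_bands)} statistics, got {stats.shape[0]}")
    unknown = [b for b in bands if b not in all_bands]
    if unknown:
        raise ValueError(f"unknown bands: {unknown}")
    index = torch.tensor([list(all_bands).index(b) for b in bands])
    return stats[index]


# Minimums for all 13 bands, copied from the custom module above.
mins = torch.tensor(
    [1013.0, 676.0, 448.0, 247.0, 269.0, 253.0, 243.0, 189.0, 61.0, 4.0, 33.0, 11.0, 186.0]
)
print(select_band_stats(mins, ("B04", "B03", "B02")))  # tensor([247., 448., 676.])
print(select_band_stats(mins, ("B08", "B04")))
```

A datamodule could apply the same helper to both mean and std before its normalisation transform is built, so any combination of bands gets statistics of matching length and order.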

@adamjstewart
Collaborator

I'm fine with the more flexible approach. You could even modify OSCD/So2Sat/other datasets to match, but that's too much work to ask you to do.

@robmarkcole
Contributor Author

@adamjstewart Sure, will do after #1646.
