Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SeCo/BYOL: add datamodule, RandomSeasonContrast #1168

Merged
merged 14 commits into from
Mar 17, 2023
Merged

Conversation

adamjstewart
Copy link
Collaborator

@adamjstewart adamjstewart commented Mar 9, 2023

This PR contains the following changes:

  • SeCo: faster initialization
  • SeCo: add datamodule
  • SeCo: add RandomSeasonContrast
  • BYOL: add RandomSeasonContrast

RandomSeasonContrast is the idea from SeCo and SSL4EO where images taken of the same location at different points in time are used as inputs to a contrastive SSL model instead of taking two random crops of the same image.

@adamjstewart adamjstewart added this to the 0.5.0 milestone Mar 9, 2023
@github-actions github-actions bot added datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets documentation Improvements or additions to documentation testing Continuous integration testing trainers PyTorch Lightning trainers labels Mar 9, 2023
self.transforms = transforms
self.download = download
self.checksum = checksum

self._verify()

# TODO: This is slow, I think this should be generated on download and then
Copy link
Collaborator Author

@adamjstewart adamjstewart Mar 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no reason we need to os.walk down 5M directories and 65M files, this is just silly. We know the total number of directories to expect (100k or 1m / 5) and can find subdirectories on the fly.

self.root, self.metadata[self.version]["directory"], f"{index:06}"
)
patch_dirs = glob.glob(os.path.join(directory, "*"))
patch_dirs = random.sample(patch_dirs, self.seasons)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RandomSeasonContrast. User defines how many seasons (patches) they want per location, and we randomly return them. Note that returning all 5 seasons will return them in a random order, but this was already the case in the previous implementation.

sample with an "image" in SCxHxW format where S is the number of seasons

.. versionchanged:: 0.5
Image shape changed from 5xCxHxW to SCxHxW
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kornia requires samples to be B x C x H x W, it doesn't support B x T x C x H x W.


in_channels = self.hyperparams["in_channels"]
assert x.size(1) == in_channels or x.size(1) == 2 * in_channels
if x.size(1) == in_channels:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This trainer now supports both datasets with and without RandomSeasonContrast

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

@@ -401,33 +409,10 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
return loss

def validation_step(self, *args: Any, **kwargs: Any) -> None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it makes sense to include validation/test/predict in an SSL trainer. For datasets without labels, there is no easy way to evaluate performance. For datasets with labels, there is no way to know which task we are trying to evaluate (classification, regression, semantic segmentation, etc.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine. There could be an infinite number of downstream tasks

Copy link
Collaborator Author

@adamjstewart adamjstewart Mar 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://pytorch-lightning.readthedocs.io/en/stable/notebooks/course_UvA-DL/13-contrastive-learning.html defines a validation_step but that's because their trainer only supports validation of classification datasets.

"name,classname",
[
("chesapeake_cvpr_7", ChesapeakeCVPRDataModule),
("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could test any of our 50+ supervised datasets, but I think it makes more sense to test our unsupervised datasets. Right now, we only have SeCo and BYOL, but I'm working on adding SSL4EO and SimCLR/MoCo which will also be tested in a similar fashion.

Copy link
Collaborator

@isaaccorley isaaccorley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@adamjstewart adamjstewart merged commit c1a6fb1 into main Mar 17, 2023
@adamjstewart adamjstewart deleted the datamodules/seco branch March 17, 2023 19:20
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* SeCo/BYOL: add datamodule, RandomSeasonContrast

* black

* Fix length, mypy

* Fix tests

* Fix float length

* Simplify length logic

* Simpler plotting

* Fix axes indexing

* Increase coverage

* Increase coverage

* CVPR prior not compatible with segmentation, but is with BYOL

* Increase coverage

* isort fix

* mypy fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets documentation Improvements or additions to documentation testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants