SeCo/BYOL: add datamodule, RandomSeasonContrast #1168
Conversation
```python
self.transforms = transforms
self.download = download
self.checksum = checksum

self._verify()

# TODO: This is slow, I think this should be generated on download and then
```
There's no reason we need to `os.walk` down 5M directories and 65M files, this is just silly. We know the total number of directories to expect (100K, or 1M / 5) and can find subdirectories on the fly.
torchgeo/datasets/seco.py (outdated):

```python
    self.root, self.metadata[self.version]["directory"], f"{index:06}"
)
patch_dirs = glob.glob(os.path.join(directory, "*"))
patch_dirs = random.sample(patch_dirs, self.seasons)
```
RandomSeasonContrast. User defines how many seasons (patches) they want per location, and we randomly return them. Note that returning all 5 seasons will return them in a random order, but this was already the case in the previous implementation.
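A minimal sketch of the sampling behavior described above (the function name is made up; only `random.sample` over the per-location patch directories comes from the diff):

```python
import random

def sample_season_dirs(patch_dirs: list[str], seasons: int) -> list[str]:
    # Pick `seasons` patch directories for one location.
    # random.sample also shuffles, so requesting all available seasons
    # still returns them in a random order, matching the old behavior.
    return random.sample(patch_dirs, seasons)

dirs = ["2018_03", "2018_06", "2018_09", "2018_12", "2019_03"]
picked = sample_season_dirs(dirs, 2)
```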
```
sample with an "image" in SCxHxW format where S is the number of seasons

.. versionchanged:: 0.5
   Image shape changed from 5xCxHxW to SCxHxW
```
Kornia requires samples to be B x C x H x W, it doesn't support B x T x C x H x W.
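A small sketch of the shape change this implies, illustrated with NumPy (the concrete sizes are made up): the season axis is folded into the channel axis so that augmentation libraries which only accept B x C x H x W can still process the sample.

```python
import numpy as np

# Hypothetical sample: S seasons of C x H x W patches.
s, c, h, w = 5, 3, 64, 64
image = np.arange(s * c * h * w, dtype=np.float32).reshape(s, c, h, w)

# Fold seasons into channels: S x C x H x W -> (S*C) x H x W.
flat = image.reshape(s * c, h, w)

# The reshape is lossless; the seasonal view can be recovered.
restored = flat.reshape(s, c, h, w)
```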
```python
in_channels = self.hyperparams["in_channels"]
assert x.size(1) == in_channels or x.size(1) == 2 * in_channels
if x.size(1) == in_channels:
```
This trainer now supports both datasets with and without RandomSeasonContrast
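Reading the hunk above, a plausible sketch of that branching (the function name and shape conventions are assumptions, illustrated with NumPy rather than torch):

```python
import numpy as np

def split_views(x: np.ndarray, in_channels: int) -> tuple[np.ndarray, np.ndarray]:
    # Without RandomSeasonContrast the batch has `in_channels` channels and
    # both BYOL views come from augmenting the same image; with it, two
    # seasonal views arrive stacked along the channel axis.
    if x.shape[1] == in_channels:
        return x, x
    assert x.shape[1] == 2 * in_channels
    return x[:, :in_channels], x[:, in_channels:]

batch = np.zeros((8, 6, 32, 32))  # two 3-channel seasonal views stacked
v1, v2 = split_views(batch, 3)
```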
Nice
```diff
@@ -401,33 +409,10 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
         return loss

     def validation_step(self, *args: Any, **kwargs: Any) -> None:
```
I don't think it makes sense to include validation/test/predict in an SSL trainer. For datasets without labels, there is no easy way to evaluate performance. For datasets with labels, there is no way to know which task we are trying to evaluate (classification, regression, semantic segmentation, etc.)
I think this is fine. There could be an infinite number of downstream tasks
https://pytorch-lightning.readthedocs.io/en/stable/notebooks/course_UvA-DL/13-contrastive-learning.html defines a validation_step but that's because their trainer only supports validation of classification datasets.
"name,classname", | ||
[ | ||
("chesapeake_cvpr_7", ChesapeakeCVPRDataModule), | ||
("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule), |
We could test any of our 50+ supervised datasets, but I think it makes more sense to test our unsupervised datasets. Right now, we only have SeCo and BYOL, but I'm working on adding SSL4EO and SimCLR/MoCo which will also be tested in a similar fashion.
LGTM
* SeCo/BYOL: add datamodule, RandomSeasonContrast
* black
* Fix length, mypy
* Fix tests
* Fix float length
* Simplify length logic
* Simpler plotting
* Fix axes indexing
* Increase coverage
* Increase coverage
* CVPR prior not compatible with segmentation, but is with BYOL
* Increase coverage
* isort fix
* mypy fix
This PR contains the following changes:
RandomSeasonContrast is the idea, from SeCo and SSL4EO, of using images taken of the same location at different points in time as the inputs to a contrastive SSL model, instead of taking two random crops of the same image.
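A toy sketch of the difference in how positive pairs are formed under the two schemes (purely illustrative; all names are made up):

```python
import random

def crop_pair(image: str) -> tuple[str, str]:
    # Classic SimCLR/BYOL-style positives: two random crops of one image.
    return (f"crop({image})", f"crop({image})")

def season_pair(location_images: list[str]) -> tuple[str, str]:
    # SeCo/SSL4EO-style positives: two different acquisitions of the same
    # location, so the encoder learns invariance to seasonal change.
    a, b = random.sample(location_images, 2)
    return (a, b)

pair = season_pair(["loc1_spring", "loc1_summer", "loc1_fall"])
```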