Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add predict_step to BYOLTask #819

Merged

Conversation

isaaccorley
Copy link
Collaborator

Related to #813. This PR adds a predict_step method to BYOLTask so that users can utilize PyTorch Lightning's trainer.predict() function which will automatically loop and predict over a PyTorch DataLoader e.g.:

preds = trainer.predict(model=task, dataloaders=datamodule.test_dataloader())

# Or if your datamodule has a `predict_dataloader` method defined
preds = trainer.predict(model=task, datamodule=datamodule)

Note the way the EncoderWrapper is setup is that calling self(x) on a batch of images will return the projector output embeddings. However, we only need this during training. For linear evaluation or other downstream tasks we want the output of the image encoder (the output of the avg pool layer of the ResNet). To work around this I had to add another hook to store the embeddings on a forward pass in the self._embeddings attribute.

@isaaccorley isaaccorley added this to the 0.3.2 milestone Oct 4, 2022
@isaaccorley isaaccorley self-assigned this Oct 4, 2022
@github-actions github-actions bot added testing Continuous integration testing trainers PyTorch Lightning trainers labels Oct 4, 2022
@adamjstewart adamjstewart mentioned this pull request Oct 4, 2022
6 tasks
@adamjstewart adamjstewart merged commit 0598975 into microsoft:main Oct 5, 2022
@isaaccorley isaaccorley deleted the trainers/byoltask-predict-step branch October 5, 2022 04:59
@adamjstewart adamjstewart modified the milestones: 0.3.2, 0.4.0 Jan 23, 2023
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* add predict_step to BYOLTask and tests

* fixes per suggestions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants