Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add predict_step to RegressionTask #818

Conversation

isaaccorley
Copy link
Collaborator

Related to #813. This PR adds a predict_step method to RegressionTask so that users can utilize PyTorch Lightning's trainer.predict() function which will automatically loop and predict over a PyTorch DataLoader e.g.:

preds = trainer.predict(model=task, dataloaders=datamodule.test_dataloader())

# Or if your datamodule has a `predict_dataloader` method defined
preds = trainer.predict(model=task, datamodule=datamodule)

@isaaccorley isaaccorley added this to the 0.3.2 milestone Oct 4, 2022
@isaaccorley isaaccorley self-assigned this Oct 4, 2022
@github-actions github-actions bot added testing Continuous integration testing trainers PyTorch Lightning trainers labels Oct 4, 2022
@isaaccorley
Copy link
Collaborator Author

isaaccorley commented Oct 4, 2022

I also found the reason why we need the Tensor type hint is due to RegressionTask.forward (and other tasks) returning Any and not Tensor so mypy assumes that calling self(x) returns Any.

@adamjstewart adamjstewart merged commit a63c73f into microsoft:main Oct 4, 2022
@adamjstewart adamjstewart mentioned this pull request Oct 4, 2022
6 tasks
@isaaccorley isaaccorley deleted the trainers/regressiontask-predict-step branch October 4, 2022 21:35
@isaaccorley isaaccorley changed the title add predict_step to RegressionTask and tests add predict_step to RegressionTask Oct 4, 2022
@adamjstewart adamjstewart modified the milestones: 0.3.2, 0.4.0 Jan 23, 2023
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants