ClassificationTask predict step #790

Merged: 2 commits from classifier/predict-step into microsoft:main on Oct 3, 2022

Conversation

isaaccorley (Collaborator)

Running trainer.predict() with the ClassificationTask fails because the default pytorch_lightning predict_step expects a batch tuple and not a dict. This PR overrides predict_step to work with our classification datasets.

E.g.

preds = trainer.predict(model=task, dataloaders=datamodule.test_dataloader())

# Or if your datamodule has a `predict_dataloader` method defined
preds = trainer.predict(model=task, datamodule=datamodule)
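For context, a minimal sketch of what the override can look like. The batch["image"] key and the exact hook signature are assumptions here (based on TorchGeo's dict-style batches and Lightning's predict_step API), not copied from the diff; the softmax line matches the code reviewed below.

from typing import Any

from torch import Tensor


# Simplified sketch of the method added to ClassificationTask (other methods omitted).
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    # TorchGeo classification datasets yield dict batches, so the image is pulled
    # out by key instead of unpacking the (x, y) tuple Lightning's default expects.
    x = batch["image"]
    y_hat: Tensor = self(x).softmax(dim=-1)  # per-class probabilities
    return y_hat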

@isaaccorley isaaccorley self-assigned this Sep 23, 2022
@isaaccorley isaaccorley added the trainers PyTorch Lightning trainers label Sep 23, 2022
@github-actions github-actions bot added the testing Continuous integration testing label Sep 23, 2022
@isaaccorley isaaccorley added this to the 0.4.0 milestone Sep 23, 2022
@calebrob6 (Member)

I don't understand the failing test here.

@adamjstewart (Collaborator) left a comment

Should we add this to all trainers?

@adamjstewart (Collaborator)

No idea what's going on with the failing test; the latest passing commit on master has the same versions of mypy and numpy, so I don't know why it would be failing all of a sudden.

@adamjstewart adamjstewart modified the milestones: 0.4.0, 0.3.2 Sep 25, 2022
@adamjstewart (Collaborator)

Technically this could go in 0.3.2 since it doesn't actually add a new feature; it overrides an existing one.

@isaaccorley (Collaborator, Author)

Looks like it's a bug introduced in the latest mypy that breaks with Python 3.10.7. The fix will come out in the next mypy release (python/mypy#13627).

@isaaccorley (Collaborator, Author)

The plan is to add this to all the trainers, but each one will have its own specific needs; e.g. you may not want to accumulate predicted segmentation masks into a list in memory and would rather save them in some other way.
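As a purely illustrative sketch of that "save them in some other way" option for a future segmentation trainer: a Lightning BasePredictionWriter callback can write each predicted mask to disk as it is produced, so trainer.predict never has to accumulate them in memory. The class name, output directory, and file format below are assumptions, not part of this PR.

import os

import torch
from pytorch_lightning.callbacks import BasePredictionWriter


class MaskWriter(BasePredictionWriter):
    """Illustrative callback: persist each batch of predicted masks instead of keeping them in RAM."""

    def __init__(self, output_dir: str) -> None:
        super().__init__(write_interval="batch")
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        # Collapse per-pixel class probabilities to hard labels and save one file per batch.
        masks = prediction.argmax(dim=1).cpu()
        torch.save(masks, os.path.join(self.output_dir, f"batch_{batch_idx}.pt"))


# Usage sketch:
# trainer = Trainer(callbacks=[MaskWriter("predictions/")])
# trainer.predict(model=task, datamodule=datamodule, return_predictions=False)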

@adamjstewart (Collaborator)

That is unfortunate; I guess Dependabot can't control exactly which Python version we test. We could pin to 3.10.6 if we want to get the tests working.

@isaaccorley (Collaborator, Author)

@adamjstewart mypy v0.981 is out, which includes a fix for this failing test. Should I update requirements/tests.txt in this PR or in a separate one?

@adamjstewart (Collaborator)

Dependabot will update it for you if you wait < 24 hrs (I think it checks once per day).

@isaaccorley isaaccorley reopened this Sep 28, 2022
@isaaccorley (Collaborator, Author) commented Sep 28, 2022

Not sure why codecov is complaining. When I go to the link it only shows coverage dips on scripts I didn't modify.

Closing and reopening did the trick.

Comment on lines +252 to +253
y_hat: Tensor = self(x).softmax(dim=-1)
return y_hat
Collaborator

Why is the predict step so different from the train/val/test steps? Shouldn't it look something like:

Suggested change:
- y_hat: Tensor = self(x).softmax(dim=-1)
- return y_hat
+ y_hat = self.forward(x)
+ y_hat_hard = y_hat.argmax(dim=1)
+ return y_hat_hard

Member

Because you will frequently want the per-class probabilities instead of just the top class

Collaborator

And that's not true for train/val/test?

Collaborator (Author)

The train/val/test steps aren't meant to return data to a user. The operations that happen are only meant to perform the forward pass and compute metrics (which is why the argmax op is needed). The predict step needs to return the raw softmax outputs because we don't want to assume what a user wants to do with the output.
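Put differently, a user who does want hard labels can recover them from the returned probabilities in a line or two (a sketch; the variable names are illustrative):

import torch

batches = trainer.predict(model=task, datamodule=datamodule)
probs = torch.cat(batches)      # (N, num_classes) per-class probabilities
labels = probs.argmax(dim=-1)   # hard class predictions, only if that's what you need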

Collaborator

Do we need `y_hat: Tensor` or can we just use `y_hat`?

Member

  • The recommended way to run inference is with self(x), not self.forward(x), since self(x) calls hooks while self.forward(x) silently doesn't (see the illustration after this list).
  • y_hat is a Tensor so I don't see the problem. Breaking something into multiple lines to get mypy to stop complaining is ridiculous overhead.
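A quick standalone illustration of the first point (plain torch.nn, not TorchGeo code): forward hooks fire when the module is called, but not when .forward() is invoked directly.

import torch
from torch import nn

module = nn.Linear(4, 2)
module.register_forward_hook(lambda m, inp, out: print("hook fired"))

x = torch.randn(1, 4)
module(x)          # prints "hook fired"
module.forward(x)  # silent: the hook is bypassed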

Collaborator

Should we change other methods to replace self.forward(x) with self(x)? Just want to make sure we're consistent.

Collaborator (Author)

Yes, we should change them. We shouldn't be calling forward anywhere.

Collaborator

P.S. I feel like the `y_hat: Tensor` thing is due to missing type hints for softmax, but I couldn't figure out exactly where I would need to add those. PyTorch's source code is a maze.

Collaborator (Author)

I think I was able to fix it by calling softmax as a function, e.g. torch.softmax(tensor), instead of as a method, e.g. tensor.softmax().
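For anyone hitting the same mypy complaint, the change is roughly this (a sketch based on the comments above, not the exact merged diff):

# Before: mypy needed an explicit annotation for the Tensor.softmax method call.
y_hat: Tensor = self(x).softmax(dim=-1)

# After: the functional form type-checks without the extra annotation.
y_hat = torch.softmax(self(x), dim=-1)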

@adamjstewart adamjstewart merged commit 396b4e3 into microsoft:main Oct 3, 2022
@isaaccorley isaaccorley deleted the classifier/predict-step branch October 3, 2022 18:37
@adamjstewart adamjstewart mentioned this pull request Oct 4, 2022
@adamjstewart adamjstewart modified the milestones: 0.3.2, 0.4.0 Jan 23, 2023
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* add predict_step method to ClassificationTask

* fix mypy error