ClassificationTask predict step #790
Conversation
I don't understand the failing test here.
Should we add this to all trainers?
No idea what's going on with the failing test. The latest passing commit on master has the same versions of mypy and numpy, so I don't know why it would be failing all of a sudden.
Technically this could go in 0.3.2 since it doesn't actually add a new feature, it overrides an existing feature.
Looks like it's a bug introduced with the latest mypy that breaks with Python 3.10.7. The fix will come out in the next mypy release. python/mypy#13627
The plan is to add this to all the trainers, but each one will have its own specific needs. E.g. you may not want to accumulate predicted segmentation masks into a list in memory and would rather save them in some other way.
That is unfortunate, I guess dependabot can't control exactly which Python version we test. We could pin to 3.10.6 if we want to get the tests working. |
@adamjstewart mypy v0.981 is out, which includes a fix for this failing test. Should I update
Dependabot will update it for you if you wait < 24 hrs (I think it checks once per day). |
Closing and reopening did the trick
```python
y_hat: Tensor = self(x).softmax(dim=-1)
return y_hat
```
Why is the predict step so different from the train/val/test steps? Shouldn't it look something like:

```python
y_hat = self.forward(x)
y_hat_hard = y_hat.argmax(dim=1)
return y_hat_hard
```
Because you will frequently want the per-class probabilities instead of just the top class
And that's not true for train/val/test?
The train/val/test steps aren't meant to return data to a user. The operations that happen are only meant to perform the forward pass and compute metrics (which is why the argmax op is needed). The predict step needs to return the raw softmax outputs because we don't want to assume what a user wants to do with the output.
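To make the contrast concrete, here is a minimal sketch (using a plain `torch.nn.Linear` as a hypothetical stand-in for the task's network, not the real `ClassificationTask`) of the two kinds of output: hard class indices for metric computation versus the full per-class distribution for prediction.

```python
import torch

# Hypothetical stand-in for the task's classifier: 4 features -> 3 classes
model = torch.nn.Linear(4, 3)
x = torch.randn(2, 4)

# Train/val/test style: argmax yields hard class indices for metrics
y_hat = model(x)
y_hat_hard = y_hat.argmax(dim=1)

# Predict style: keep the full per-class probability distribution
probs = model(x).softmax(dim=-1)

print(y_hat_hard.shape)  # torch.Size([2])
print(probs.shape)       # torch.Size([2, 3])
```

The hard labels throw away the confidence information, which is exactly what a downstream user of `predict()` may still need.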
Do we need `y_hat: Tensor`, or can we just use `y_hat`?
- The recommended way to run inference is with `self(x)`, not `self.forward(x)`, as `self(x)` calls hooks while `self.forward(x)` silently doesn't.
- `y_hat` is a `Tensor`, so I don't see the problem. Breaking something into multiple lines to get `mypy` to stop complaining is ridiculous overhead.
Should we change other methods to replace `self.forward(x)` with `self(x)`? Just want to make sure we're consistent.
Yes we should change. We shouldn't be calling forward anywhere.
P.S. I feel like the `y_hat: Tensor` thing is due to missing type hints for `softmax`, but I couldn't figure out exactly where I need to add those. PyTorch's source code is a maze.
I think I was able to fix it by calling it as a function, e.g. `torch.softmax(tensor)`, instead of calling it as a method, e.g. `tensor.softmax()`.
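A small sketch of the workaround described above. The two calls are numerically equivalent; the claim about which form type-checks cleanly depends on the mypy and PyTorch stub versions in use, so that part is an assumption noted in the comments.

```python
import torch

logits = torch.tensor([[1.0, 2.0, 3.0]])

# Method form: with some stub/mypy versions this was reportedly inferred
# loosely, requiring a manual `y_hat: Tensor` annotation
probs_method = logits.softmax(dim=-1)

# Function form: torch.softmax, which the comment above says type-checked
# without the extra annotation
probs_fn = torch.softmax(logits, dim=-1)

# Both produce the same probabilities
assert torch.allclose(probs_method, probs_fn)
```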
* add predict_step method to ClassificationTask
* fix mypy error
Running `trainer.predict()` with the `ClassificationTask` fails because the default pytorch_lightning `predict_step` expects a batch tuple and not a dict. This PR overrides `predict_step` to work with our classification datasets. E.g.