Add default task weights #821
Conversation
Force-pushed from 4d9f5a3 to 3467fa5
chemprop/nn/loss.py (Outdated)
```diff
@@ -27,15 +27,19 @@
 class LossFunction(nn.Module):
-    def __init__(self, task_weights: Tensor):
+    def __init__(self, task_weights: Tensor | None = None):
```
Does this need to be `Tensor | None`? What about `task_weights: ArrayLike = 1.0`? That corresponds to a default of equal-weight tasks, since scalars are fully broadcast.
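For context, a minimal sketch of the broadcasting behavior being referenced (shapes and values here are illustrative, not from the PR):

```python
import torch

# A scalar weight of 1.0 broadcasts against a per-sample, per-task
# loss matrix of any width, so equal weighting needs no knowledge
# of n_tasks in advance.
losses = torch.rand(8, 3)    # (n_samples, n_tasks), made-up shapes
weighted = 1.0 * losses      # scalar broadcasts to (8, 3)
assert torch.equal(weighted, losses)
```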
I thought about it more and agree `1.0` is a better default than `None`.
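One way the agreed-upon signature could be implemented (a sketch under assumptions: the `ArrayLike` import and the `register_buffer` detail are not confirmed by this diff):

```python
import torch
from torch import nn
from numpy.typing import ArrayLike

class LossFunction(nn.Module):
    """Sketch only: the class name follows the PR, the body is an assumption."""

    def __init__(self, task_weights: ArrayLike = 1.0):
        super().__init__()
        # A (1, t) row vector broadcasts across an (n_samples, n_tasks)
        # loss matrix; a scalar input becomes shape (1, 1).
        w = torch.as_tensor(task_weights, dtype=torch.float).view(1, -1)
        self.register_buffer("task_weights", w)
```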
```diff
@@ -44,18 +44,6 @@
 class Metric(LossFunction):
```
So am I understanding correctly that a user can supply `task_weights` to a `Metric`, but they won't be used? If so, why are you deleting the note in the docstring?
Good point, I'll add a note back to the `Metric` docstring saying that task weights aren't used.
chemprop/nn/loss.py (Outdated)
```python
# From the repr helper: hide task_weights when it is the all-ones default.
if self.task_weights == torch.ones(1).view(1, -1):
    return ""
return f"task_weights={self.task_weights.tolist()}"
```
Suggested change:

```diff
-if self.task_weights == torch.ones(1).view(1, -1):
-    return ""
-return f"task_weights={self.task_weights.tolist()}"
+return f"task_weights={self.task_weights.tolist()}"
```
It's fine to display defaults. See PyTorch's own defaults:

```python
>>> nn.Dropout()
Dropout(p=0.5, inplace=False)
```
Originally I had `None` as the default, which is why I wouldn't print a single 1, but with the change to the default I will also change this.
Co-authored-by: david graff <60193893+davidegraff@users.noreply.github.com>
Description
In #779 we moved task weights to be an attribute of `LossFunction`. This means it is also an attribute of `Metric`, even though it isn't used there. We will address this more thoroughly in v2.1, but requiring task weights when creating a metric would be confusing for users. I thought about it some more and feel task weights should also not be required when making a loss function if all task weights should be one.
Before, I didn't put default task weights of ones in the `__init__` of `LossFunction` because there we don't know how many tasks there will be. Compare to predictors.py, which uses `torch.ones(n_tasks)` to make the default task weights. But then I realized that if we want all ones, we don't need to get the length right because the tensor can be broadcast.
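A quick illustrative check of that broadcasting claim (the shapes are made up, not from the codebase):

```python
import torch

# torch.ones(1).view(1, -1) has shape (1, 1), so it broadcasts over
# any (n_samples, n_tasks) matrix without knowing n_tasks.
task_weights = torch.ones(1).view(1, -1)
losses = torch.rand(4, 5)  # 5 tasks, never told to the loss function
assert torch.equal(task_weights * losses, losses)
```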
This PR allows simple creation of metrics and loss functions when the task weights should be all ones, via `MSEMetric()` / `MSELoss()`. I chose not to include task weights in the repr string when it is a tensor of a single 1, as `None` is the real default and `task_weights == torch.ones(1).view(1, -1)` is just a way to get the calculations to work.
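For reference, the repr fragment for the all-ones default renders like this in plain torch (assuming the `(1, 1)` shape is kept):

```python
>>> import torch
>>> f"task_weights={torch.ones(1).view(1, -1).tolist()}"
'task_weights=[[1.0]]'
```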