
Add default task weights #821

Merged · 5 commits merged into chemprop:v2/dev on Apr 22, 2024

Conversation

KnathanM (Contributor)

Description

In #779 we moved task weights to be an attribute of `LossFunction`. This means it is also an attribute of `Metric`, even though metrics don't use it. We will address this more thoroughly in v2.1, but requiring task weights when creating a metric would be confusing for users. On further thought, I feel task weights should also not be required when making a loss function if all task weights should be one.

Previously I didn't give `LossFunction` default task weights of ones in its init because, inside `LossFunction`, we don't know how many tasks there will be. Compare `predictors.py`, which uses `torch.ones(n_tasks)` to make default task weights. But then I realized that if we want all ones, the length doesn't have to be right, because the tensor can be broadcast (see the sketch below).
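
A minimal standalone illustration in plain PyTorch (not chemprop code; shapes are hypothetical): a length-1 tensor of ones broadcasts against a per-task loss tensor of any width, so the task count doesn't need to be known up front.

```python
import torch

per_task_loss = torch.rand(8, 3)  # hypothetical (batch, n_tasks) losses

# A (1, 1) all-ones weight tensor broadcasts to any (batch, n_tasks) shape.
task_weights = torch.ones(1).view(1, -1)

weighted = task_weights * per_task_loss  # broadcast multiply, still (8, 3)
assert torch.equal(weighted, per_task_loss)  # all-ones weights change nothing
```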

This PR allows simple creation of metrics and loss functions when the task weights should all be one, via `MSEMetric()`/`MSELoss()` (a usage sketch follows below). I chose not to include task weights in the repr string when it is a tensor holding a single 1, since `None` is the real default and `task_weights == torch.ones(1).view(1, -1)` is just a way to get the calculations to work.
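
A brief usage sketch; the import paths are assumed from the files touched in this PR, and the accepted weight type reflects the signature at this point in the discussion.

```python
import torch

from chemprop.nn.loss import MSELoss
from chemprop.nn.metrics import MSEMetric

# No task weights needed when they should all be one.
loss_fn = MSELoss()
metric = MSEMetric()

# Explicit weights still work for unequally weighted tasks.
weighted_loss_fn = MSELoss(task_weights=torch.tensor([1.0, 2.0]))
```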

@KnathanM KnathanM mentioned this pull request Apr 21, 2024
@KnathanM KnathanM added this to the v2.0.0 milestone Apr 21, 2024
```diff
@@ -27,15 +27,19 @@
 class LossFunction(nn.Module):
-    def __init__(self, task_weights: Tensor):
+    def __init__(self, task_weights: Tensor | None = None):
```
Contributor

Does this need to be `Tensor | None`? What about `task_weights: ArrayLike = 1.0`? This corresponds to a default of equal-weight tasks, as scalars are fully broadcast (see the sketch below).
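
To illustrate the point (plain PyTorch, shapes hypothetical): a scalar weight broadcasts over any number of tasks, so no sentinel `None` is needed.

```python
import torch

per_task_loss = torch.rand(8, 3)  # hypothetical (batch, n_tasks) losses
weighted = 1.0 * per_task_loss    # a scalar broadcasts over every element
assert torch.equal(weighted, per_task_loss)
```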

KnathanM (Author)

I thought about it more and agree `1.0` is a better default than `None`.

```diff
@@ -44,18 +44,6 @@
 class Metric(LossFunction):
```
Contributor

So am I understanding correctly that a user can supply `task_weights` to a `Metric` but they won't be used? If so, why are you deleting the note in the docstring?

KnathanM (Author)

Good point, I'll add a note back to the `Metric` docstring saying that task weights aren't used.

Comment on lines 88 to 90:

```python
if self.task_weights == torch.ones(1).view(1, -1):
    return ""
return f"task_weights={self.task_weights.tolist()}"
```
Contributor

Suggested change:

```diff
-if self.task_weights == torch.ones(1).view(1, -1):
-    return ""
-return f"task_weights={self.task_weights.tolist()}"
+return f"task_weights={self.task_weights.tolist()}"
```

It's fine to display defaults. See PyTorch's defaults:

```python
>>> nn.Dropout()
Dropout(p=0.5, inplace=False)
```

KnathanM (Author)

Originally I had `None` as the default, which is why I didn't print a single 1, but with the change to the default I will also change this (roughly the behavior sketched below).
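
A hedged sketch of the intended result; the class name is from this PR, but the exact repr text is illustrative and depends on how chemprop composes module reprs.

```python
>>> MSELoss()  # task weights now shown even when left at the default
MSELoss(task_weights=[[1.0]])
```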

(Further review threads on chemprop/nn/loss.py and chemprop/nn/metrics.py were marked outdated and resolved.)
KnathanM and others added 2 commits April 22, 2024 12:07
Co-authored-by: david graff <60193893+davidegraff@users.noreply.github.com>
@KnathanM KnathanM enabled auto-merge (squash) April 22, 2024 16:19
@KnathanM KnathanM merged commit e746b2c into chemprop:v2/dev Apr 22, 2024
11 of 12 checks passed