Added grover loss functions #3297
Conversation
Some simplifications needed here. I'd suggest taking a pass based on the comments I've made, and we can discuss offline tomorrow. The core idea is we want to pull out the general concept into a new loss (the atom/bond/functional group terms), but we don't need to pull out things that are just standard mse loss. Let's talk offline since this may be easier to talk through on a review.
Can you add comments and usage examples as well?
deepchem/models/losses.py
Outdated
bv_atom_loss, bv_bond_loss, bv_dist_loss = 0.0, 0.0, 0.0
fg_bond_from_atom_loss, fg_bond_from_bond_loss, fg_bond_dist_loss = 0.0, 0.0, 0.0

if preds["av_task"][0] is not None:
Instead of assuming specific structure on preds, I'd suggest just accepting different tensors for the different sub-loss terms.
For example, you can have atom, bond, and functional group terms.
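A minimal sketch of the suggestion above: a pretrain loss that takes the atom-, bond-, and functional-group-level predictions as separate arguments instead of one `preds` dict. The function names, the MSE/BCE choice per term, and the equal weighting are illustrative assumptions, not deepchem's actual API; plain Python lists stand in for tensors.

```python
import math

def _mse(preds, targets):
    # mean squared error over two equal-length lists of floats
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)

def _bce(probs, targets):
    # binary cross-entropy; probs are probabilities in (0, 1)
    eps = 1e-12
    return -sum(t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)
                for p, t in zip(probs, targets)) / len(probs)

def pretrain_loss(atom_preds, atom_targets,
                  bond_preds, bond_targets,
                  fg_probs, fg_targets):
    # one explicit term per level; the equal weighting is an assumption
    return (_mse(atom_preds, atom_targets)
            + _mse(bond_preds, bond_targets)
            + _bce(fg_probs, fg_targets))
```

With this shape, callers pass each sub-loss input explicitly rather than relying on undocumented dict keys.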
deepchem/models/losses.py
Outdated
return loss


class GroverFinetuneLoss(Loss):
Is this just MSE loss? (Assuming regression). If so, we may not need to factor this out into a custom loss.
Quick ping on this question. Is this loss something other models may want to use?
No, other models will not want to use it. It takes two embeddings and makes a prediction from them.
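A hedged sketch of the idea described above: a fine-tune loss that takes two embedding-derived predictions (one from atom embeddings, one from bond embeddings), scores their average against the target, and adds a distance term that pushes the two heads to agree. The function name, the MSE supervised term, and the `dist_coff` weighting are assumptions for illustration.

```python
def finetune_loss(atom_pred, bond_pred, target, dist_coff=0.1):
    # score the averaged prediction against the target (MSE, regression case)
    mean_pred = (atom_pred + bond_pred) / 2.0
    supervised = (mean_pred - target) ** 2
    # penalize disagreement between the two embedding-derived predictions
    distance = (atom_pred - bond_pred) ** 2
    return supervised + dist_coff * distance
```

Because the loss couples two model-specific heads, it is indeed unlikely to be reusable by other models as-is.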
Let's chat offline about this one since I want to make sure I understand the loss more precisely
Did another pass. Looking better, but we need to clean this up a bit more to make it more useful, I think.
deepchem/models/losses.py
Outdated
sigmoid = nn.Sigmoid()

av_atom_loss = av_task_loss(preds['av_task'][0], targets["av_task"])
What is the expected structure of preds? Why do we expect it to have a field 'av_task'? We can do one round of simplification here to directly accept atom losses, bond losses, etc. as the arguments to this loss function.
There are 3 sets of predictions from the GroverPretrain model and they are encoded in the dict preds. It is a dict with keys av_task, bv_task and fg_task.
Let's separate out these into 3 specific tensor arguments and document their shapes. In general it's better not to pass dictionaries around
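One way to follow this suggestion (argument names, shapes, and the per-term MSE are illustrative assumptions): replace the `preds` dict with explicit per-task arguments whose shapes are documented in the docstring, using lists-of-lists as stand-ins for tensors.

```python
def grover_pretrain_loss(av_preds, av_targets, bv_preds, bv_targets,
                         fg_preds, fg_targets):
    """Grover pretrain loss with explicit per-task arguments.

    Parameters
    ----------
    av_preds, av_targets: list[list[float]], shape (n_atoms, vocab_size)
        Atom-vocabulary predictions and targets.
    bv_preds, bv_targets: list[list[float]], shape (n_bonds, vocab_size)
        Bond-vocabulary predictions and targets.
    fg_preds, fg_targets: list[list[float]], shape (n_mols, n_fg)
        Functional-group predictions and targets.
    """
    def term(preds, targets):
        # per-task mean squared error over all entries
        n = sum(len(row) for row in preds)
        return sum((p - t) ** 2
                   for pr, tr in zip(preds, targets)
                   for p, t in zip(pr, tr)) / n

    return (term(av_preds, av_targets) + term(bv_preds, bv_targets)
            + term(fg_preds, fg_targets))
```

Documenting shapes in the signature makes the contract checkable by the caller, which a dict of undocumented keys does not.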
Separated loss into three separate arguments.
deepchem/models/losses.py
Outdated
# print(type(preds))
# TODO: Here, should we need to involve the model status? Using len(preds) is just a hack.
if type(preds) is not tuple:
What is the expected structure of preds here? We shouldn't make use of undocumented structure; otherwise it will be hard to maintain.
The finetune model in training mode produces one set of outputs, and in evaluation mode it produces a different set of outputs. So preds can be a tuple or just the predictions.
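A sketch of the maintenance point raised above: instead of inferring the model state by sniffing whether `preds` is a tuple, the mode can be passed explicitly. The `training` flag and the output layout here are illustrative assumptions, not the PR's actual code.

```python
def finetune_outputs(preds, training):
    """Return (prediction, extra_outputs) with the mode made explicit."""
    if training:
        # in training mode the model returns extras alongside the prediction
        prediction, atom_out, bond_out = preds
        return prediction, (atom_out, bond_out)
    # in evaluation mode the model returns only the predictions
    return preds, None
```

This keeps the structure documented in one place rather than encoded in a `type(preds) is not tuple` check.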
deepchem/models/losses.py
Outdated
return loss


class GroverFinetuneLoss(Loss):
Quick ping on this question. Is this loss something other models may want to use?
A few simple comments below (need to add to docs, need unit test), and I think you have a few TODOs left as well. Let's try to get this merged in by tomorrow so we can get the main Grover PR in
LGTM, feel free to merge once CI is clear
Pull Request Template
Description
I am adding the Grover loss functions in this pull request.
Type of change
Please check the option that is related to your PR.
Checklist
- I ran yapf -i <modified file> and checked no errors (yapf version must be 0.32.0)
- I ran mypy -p deepchem and checked no errors
- I ran flake8 <modified file> --count and checked no errors
- I ran python -m doctest <modified file> and checked no errors