Added grover loss functions #3297

Merged · 3 commits merged into deepchem:master on Mar 29, 2023

Conversation

arunppsg (Contributor):

Pull Request Template

Description

This pull request adds the Grover loss functions.

Type of change

Please check the option that applies to your PR.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
    • In this case, we recommend discussing your modification in a GitHub issue before creating the PR
  • Documentation (changes to documentation)

Checklist

  • My code follows the style guidelines of this project
    • Run yapf -i <modified file> and check no errors (yapf version must be 0.32.0)
    • Run mypy -p deepchem and check no errors
    • Run flake8 <modified file> --count and check no errors
    • Run python -m doctest <modified file> and check no errors
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New unit tests pass locally with my changes
  • I have checked my code and corrected any misspellings

@arunppsg marked this pull request as draft on March 22, 2023 at 14:28
@rbharath (Member) left a comment:

Some simplifications are needed here. I'd suggest taking a pass based on the comments I've made, and we can discuss offline tomorrow. The core idea is that we want to pull out the general concept into a new loss (the atom/bond/functional group terms), but we don't need to pull out things that are just standard MSE loss. Let's talk offline, since this may be easier to talk through on a review.

Can you add comments and usage examples as well?

deepchem/models/losses.py (review thread resolved)
bv_atom_loss, bv_bond_loss, bv_dist_loss = 0.0, 0.0, 0.0
fg_bond_from_atom_loss, fg_bond_from_bond_loss, fg_bond_dist_loss = 0.0, 0.0, 0.0

if preds["av_task"][0] is not None:
rbharath (Member):

Instead of assuming a specific structure for preds, I'd suggest just accepting different tensors for the different sub-loss terms.

rbharath (Member):

For example, you can have atom, bond, and functional group terms.
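A rough sketch of the shape I have in mind (a minimal sketch only; the function and argument names here are illustrative, not a final API):

import torch
import torch.nn.functional as F

def grover_pretrain_loss(atom_vocab_preds: torch.Tensor,
                         bond_vocab_preds: torch.Tensor,
                         fg_preds: torch.Tensor,
                         atom_vocab_targets: torch.Tensor,
                         bond_vocab_targets: torch.Tensor,
                         fg_targets: torch.Tensor) -> torch.Tensor:
    # Atom and bond vocabulary prediction are multi-class tasks, so
    # cross-entropy is a natural fit (assuming class-index targets).
    atom_loss = F.cross_entropy(atom_vocab_preds, atom_vocab_targets)
    bond_loss = F.cross_entropy(bond_vocab_preds, bond_vocab_targets)
    # Functional group prediction is multi-label, so BCE-with-logits
    # (assuming binary indicator targets).
    fg_loss = F.binary_cross_entropy_with_logits(fg_preds, fg_targets)
    # A simple unweighted sum; per-term weights could be added later.
    return atom_loss + bond_loss + fg_loss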

return loss


class GroverFinetuneLoss(Loss):
rbharath (Member):

Is this just MSE loss? (Assuming regression). If so, we may not need to factor this out into a custom loss.

rbharath (Member):

Quick ping on this question. Is this loss something other models may want to use?

arunppsg (Author):

No, other models will not want to use it. It takes two embeddings and makes predictions from them.
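To sketch the idea (illustrative only, not the exact code; assuming a regression target here):

import torch
import torch.nn.functional as F

def grover_finetune_loss(atom_view_preds: torch.Tensor,
                         bond_view_preds: torch.Tensor,
                         targets: torch.Tensor,
                         dist_coff: float = 0.1) -> torch.Tensor:
    # One prediction comes from the atom embeddings and one from the
    # bond embeddings; each is fit against the same targets.
    pred_loss = F.mse_loss(atom_view_preds, targets) + F.mse_loss(bond_view_preds, targets)
    # A distance term encourages the two views to agree with each other.
    dist_loss = F.mse_loss(atom_view_preds, bond_view_preds)
    return pred_loss + dist_coff * dist_loss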

rbharath (Member):

Let's chat offline about this one since I want to make sure I understand the loss more precisely

@rbharath (Member) left a comment:

Did another pass. Looking better, but I think we need to clean this up a bit more to make it more useful.

deepchem/models/losses.py (review thread resolved)

sigmoid = nn.Sigmoid()

av_atom_loss = av_task_loss(preds['av_task'][0], targets["av_task"])
rbharath (Member):

What is the expected structure of preds? Why do we expect it to have a field 'av_task'? We can do one round of simplification here to directly accept atom losses, bond losses, etc. as the arguments to this loss function.

arunppsg (Author):

There are three sets of predictions from the GroverPretrain model, and they are encoded in the dict preds, which has the keys av_task, bv_task, and fg_task.
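Roughly, the structure looks like this (sizes and variable names below are placeholders for illustration):

import torch

n_atoms, n_bonds = 30, 60          # placeholder counts
vocab_size, n_fg_labels = 100, 40  # placeholder label-space sizes
preds = {
    # each value: (predictions from atom embeddings, predictions from bond embeddings)
    "av_task": (torch.randn(n_atoms, vocab_size), torch.randn(n_atoms, vocab_size)),
    "bv_task": (torch.randn(n_bonds, vocab_size), torch.randn(n_bonds, vocab_size)),
    "fg_task": (torch.randn(1, n_fg_labels), torch.randn(1, n_fg_labels)),
}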

rbharath (Member):

Let's separate these out into three specific tensor arguments and document their shapes. In general, it's better not to pass dictionaries around.

arunppsg (Author):

Separated the loss into three separate arguments.
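With that change, the call site looks roughly like this (argument names are illustrative and match the sketch earlier in this review):

loss = grover_pretrain_loss(atom_vocab_preds, bond_vocab_preds, fg_preds,
                            atom_vocab_targets, bond_vocab_targets, fg_targets)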


# print(type(preds))
# TODO: Here, should we need to involve the model status? Using len(preds) is just a hack.
if type(preds) is not tuple:
rbharath (Member):

What is the expected structure of preds here? We shouldn't make use of undocumented structure; otherwise it will be hard to maintain.

arunppsg (Author):

The finetune model produces one set of outputs in training mode and a different set in evaluation mode, so preds can be a tuple or just the predictions.
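Something along these lines (a sketch only; the names are not the real ones, and keying off the model's .training flag would be more explicit than a type check):

def unpack_finetune_output(output):
    # In training mode the model returns the two view predictions as a
    # tuple; in evaluation mode it returns a single prediction tensor.
    if isinstance(output, tuple):
        atom_view_preds, bond_view_preds = output
        return atom_view_preds, bond_view_preds
    return output, None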

deepchem/models/losses.py (review thread resolved)
@rbharath (Member) left a comment:

A few simple comments below (need to add to docs, need a unit test), and I think you have a few TODOs left as well. Let's try to get this merged by tomorrow so we can get the main Grover PR in.

deepchem/models/losses.py (review thread resolved)
deepchem/models/torch_models/grover_layers.py (review thread resolved)
deepchem/models/losses.py (review thread resolved)
deepchem/models/losses.py (review thread resolved)
@rbharath (Member) left a comment:

LGTM, feel free to merge once CI is clear

@arunppsg marked this pull request as ready for review on March 29, 2023 at 17:23
@arunppsg merged commit 011fda8 into deepchem:master on Mar 29, 2023
@arunppsg deleted the grover-loss branch on April 11, 2023 at 09:19