adding grover pretrain model as ModularTorchModel #3272
Conversation
Did a first review pass. This is a good start! There's some room for cleanup here. Let's discuss offline since there are some open design questions.
Force-pushed from 636a16a to 71b11f2
I think this PR is still in draft stage and not ready for full review, so I just did a quick pass and pointed out a couple of things that jumped out at me. Let me know once this is ready for full review.
Force-pushed from 71b11f2 to 61d549b
Capturing the offline discussion, we will want to split this into multiple PRs:
- Fixing `GraphData` and `BatchGraphData` to support Grover (and DMPNN)
- Adding Grover's loss to DeepChem's losses
- Adding Grover as a modular torch model

We should also discuss the proposed save/restore fix offline and see if we want to make that a separate PR.
Force-pushed from 73234b3 to 694bf23
What is the issue with GraphData not supporting DMPNN and Grover?
I think this PR is waiting on the earlier graph data handling PR, so I'm not doing a thorough review, but one comment: we still need save/reload/overfit tests.
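A save/reload test in this style generally writes model weights to disk, restores them into a fresh instance, and checks that predictions match. A minimal framework-agnostic sketch of the pattern (all names here are hypothetical stand-ins, not the Grover or DeepChem API):

```python
import os
import pickle
import tempfile

import numpy as np


class TinyModel:
    """Hypothetical stand-in for a model with learnable weights."""

    def __init__(self, w=None):
        self.w = np.ones(4) if w is None else w

    def predict(self, x):
        return x @ self.w

    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump(self.w, f)

    def restore(self, path):
        with open(path, "rb") as f:
            self.w = pickle.load(f)


x = np.arange(8, dtype=float).reshape(2, 4)
m = TinyModel(w=np.array([0.5, -1.0, 2.0, 0.0]))

with tempfile.TemporaryDirectory() as d:
    ckpt = os.path.join(d, "ckpt.pkl")
    m.save(ckpt)
    m2 = TinyModel()       # fresh instance with default weights
    m2.restore(ckpt)       # load the saved weights

# the restored model must reproduce the original predictions exactly
assert np.allclose(m.predict(x), m2.predict(x))
```

An overfit test follows the same shape: fit on a tiny dataset for many epochs and assert the training score exceeds a high threshold.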
deepchem/utils/test/test_grover.py
Outdated
@@ -0,0 +1,34 @@
import pytest
import deepchem as dc
from deepchem.feat.graph_data import BatchGraphData
As a reminder, we need to add the save/reload/overfit tests before we can merge this PR in
There's a lot of code here that I think will go away once we rebase. I'll hold off on doing a full review until we merge in the previous PR and rebase, since it'll get messy otherwise.
It's not that it doesn't support them - I didn't know how to make DMPNN and Grover graph batching work with GraphData. To add a bit of context, the DMPNN and Grover models use the same input representations.
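Batching graphs of this kind usually means concatenating node features into one array and offsetting each graph's edge indices by the number of nodes already batched, so the batch behaves as one large disjoint graph. A minimal numpy sketch of the idea (an illustration only, not DeepChem's actual `BatchGraphData` implementation):

```python
import numpy as np


def batch_graphs(graphs):
    """Merge [(node_features, edge_index), ...] into one disjoint graph.

    edge_index has shape (2, num_edges); indices are offset so edges
    from different graphs never reference each other's nodes.
    """
    feats, edges, graph_index = [], [], []
    offset = 0
    for gid, (nf, ei) in enumerate(graphs):
        feats.append(nf)
        edges.append(ei + offset)                  # shift node ids into batched space
        graph_index.append(np.full(len(nf), gid))  # which graph each node came from
        offset += len(nf)
    return (np.concatenate(feats),
            np.concatenate(edges, axis=1),
            np.concatenate(graph_index))


# two toy graphs: 2 nodes and 3 nodes, one feature per node
g1 = (np.zeros((2, 1)), np.array([[0], [1]]))
g2 = (np.ones((3, 1)), np.array([[0, 1], [1, 2]]))
x, ei, gi = batch_graphs([g1, g2])
# x has 5 rows; g2's edges are shifted to reference nodes 2..4
```

The `graph_index` array is what lets a pooling step recover per-molecule outputs from the flat batch.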
Force-pushed from 9f6b27d to 4ce9c43
The core implementation is looking solid, but I have some requests for improved documentation and additional unit tests below.
return self._prepare_batch_for_finetuning(data)

def _prepare_batch_for_pretraining(self, batch: Tuple[Any, Any, Any]):
    """
Can you add a description to the docstring of the preparation required?
Reminder here as well
added description
This layer is a simple wrapper over GroverTransEncoder layer for retrieving the embeddings from the GroverTransEncoder corresponding to the `embedding_output_type` chosen by the user.

class Grover(ModularTorchModel):
    """Grove model
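The wrapper described above essentially maps the `embedding_output_type` flag to which encoder outputs get returned. A hypothetical sketch of that dispatch (function and variable names are assumed for illustration, not the actual layer):

```python
import numpy as np


def select_embeddings(atom_emb, bond_emb, embedding_output_type):
    """Return encoder embeddings matching the requested output type.

    Hypothetical dispatch: 'atom' and 'bond' each return one array,
    'both' returns the pair.
    """
    if embedding_output_type == "atom":
        return atom_emb
    if embedding_output_type == "bond":
        return bond_emb
    if embedding_output_type == "both":
        return atom_emb, bond_emb
    raise ValueError(f"unknown embedding_output_type: {embedding_output_type!r}")


atoms = np.zeros((5, 16))   # 5 atoms, 16-dim embeddings
bonds = np.ones((8, 16))    # 8 bonds, 16-dim embeddings
out = select_embeddings(atoms, bonds, "atom")
```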
Can you add this model to the docs?
Quick reminder to add to the docs
added to docs
Force-pushed from 4ce9c43 to 4c876d2
Will re-review this PR later, after we merge the earlier PRs and rebase.
Force-pushed from 5e407b7 to d72e192
This is nearly ready, but there are still a number of missing docs. Can you do a pass and make sure all comments are addressed?
This layer is a simple wrapper over GroverTransEncoder layer for retrieving the embeddings from the GroverTransEncoder corresponding to the `embedding_output_type` chosen by the user.

class Grover(ModularTorchModel):
    """Grove model
Quick reminder to add to the docs
return self._prepare_batch_for_finetuning(data)

def _prepare_batch_for_pretraining(self, batch: Tuple[Any, Any, Any]):
    """
Reminder here as well
Force-pushed from d72e192 to 64be63b
Force-pushed from 64be63b to 9994cca
LGTM, please go ahead and merge in once CI is clear.
Nice work getting this to the finish line! It's a big one.
The CI failures are not related to these changes. Going ahead and merging this in.
Pull Request Template
Description
The PR adds GroverPretrainer to train embeddings. The main contribution is the self-supervised pretraining task used to train the embeddings.
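Self-supervised pretraining of this kind typically masks part of the input and trains the model to predict the masked values from context. A minimal numpy illustration of the masking step only (a generic sketch, not the Grover pretraining task itself):

```python
import numpy as np

rng = np.random.default_rng(0)


def mask_node_features(feats, mask_rate=0.25, mask_value=0.0):
    """Randomly blank whole node-feature rows and return targets.

    Returns (masked_features, mask, targets): a model would be trained
    to reconstruct `targets` at the positions where `mask` is True.
    """
    mask = rng.random(len(feats)) < mask_rate
    masked = feats.copy()
    masked[mask] = mask_value
    return masked, mask, feats[mask]


x = np.arange(12, dtype=float).reshape(6, 2)   # 6 nodes, 2 features each
masked, mask, targets = mask_node_features(x)

# masked rows are blanked; unmasked rows are untouched
assert np.all(masked[mask] == 0.0)
assert np.array_equal(masked[~mask], x[~mask])
```

The pretraining loss is then computed only over the masked positions, which is what lets the encoder learn embeddings without labels.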
Type of change
Please check the option that is related to your PR.
Checklist
- Run `yapf -i <modified file>` and check no errors (yapf version must be 0.32.0)
- Run `mypy -p deepchem` and check no errors
- Run `flake8 <modified file> --count` and check no errors
- Run `python -m doctest <modified file>` and check no errors