-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ModularTorchModel #3242
ModularTorchModel #3242
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is basically ready to merge once all tests look stable and my minor comment about the function name is fixed
self.model = self.build_model() | ||
self.model.to(self.device) | ||
|
||
def load_from_modular(self, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it may be better to just call this load_from_pretrained
and disable the mypy
complaint? Feels like introducing a new API is more confusing than a slight mismatch in the call structure
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But we're still introducing a new API. We're just naming it the same as an old one.
It looks like some additional tests are failing as well. By comparison to #3241 which has 20 green, we only have 15 tests green suggesting there may be some breakages |
@rbharath I believe this is ready for a final review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Almost at the finish line I think. Just needs a unit test for unfreeze_components
raise NotImplementedError("Subclass must define the loss function") | ||
|
||
def freeze_components(self, components: List[str]): | ||
"""Freezes or unfreezes the parameters of the specified components. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slight edit, "unfreezes" should no longer be in this docstring
self.model = self.build_model() | ||
self.model.to(self.device) | ||
|
||
def unfreeze_components(self, components: List[str]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add a unit test for unfreeze_components
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this all looks good. Only question we are at 22 passing and not 23. If that's transient I think we are good to merge
Pull Request Template
Description
This PR adds in a new abstract model, ModularTorchModel. The premise of ModularTorchModel is that we may want to take components from different models and combine them, or modify a model to work with a different task. This is useful for transfer learning where we may want to take an embedding from a pretrained model, or modify a model's prediction head to work with a pretraining task.
Type of change
Checklist
yapf -i <modified file>
and check no errors (yapf version must be 0.32.0)mypy -p deepchem
and check no errorsflake8 <modified file> --count
and check no errorspython -m doctest <modified file>
and check no errors