Modular loading from pretrained #3305
Conversation
Some requests for more documentation, since the save/restore algorithm for components has some detail to it.
```diff
-        self.init_emb()
+        if init_emb:
+            self.init_emb()
```
This change is made so that `.restore()` functions as expected; otherwise the bias weights will be filled with 0.
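For context, here is a minimal sketch of the pattern under discussion (module and argument names are illustrative, not the exact DeepChem code): the custom initialization is gated behind an `init_emb` flag so that it is skipped when the weights are about to come from a checkpoint.

```python
import torch.nn as nn


class ExampleComponent(nn.Module):
    """Illustrative component whose custom init zeroes the projection bias."""

    def __init__(self, vocab_size: int = 100, dim: int = 16,
                 init_emb: bool = True):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, dim)
        if init_emb:
            # Only run the custom initialization for a fresh model; skipping it
            # keeps restored/pretrained weights from being overwritten with
            # zeroed biases.
            self.init_emb()

    def init_emb(self):
        nn.init.xavier_uniform_(self.emb.weight)
        nn.init.zeros_(self.proj.bias)
```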
This looks good to me barring a minor comment issue.
@gusty1g Can you do a quick review pass as well since this is based on your earlier prototype?
```python
def load_from_pretrained(  # type: ignore
        self,
        source_model: Optional["ModularTorchModel"] = None,
```
Are all three required: source_model, checkpoint, and model_dir? I think model_dir alone will be sufficient. Given a model_dir, the method loads the state_dict, and if any of the current model's layers or components match keys in the state_dict, the method can update those components' weights.
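A minimal sketch of the model_dir-only approach described here; the checkpoint file name and the `model_state_dict` key are assumptions about the checkpoint layout, not the PR's actual implementation:

```python
import os
import torch


def load_matching_weights(model: torch.nn.Module, model_dir: str) -> None:
    """Update only the parameters whose keys (and shapes) match the checkpoint."""
    # Assumed layout: one checkpoint file containing a 'model_state_dict' entry.
    data = torch.load(os.path.join(model_dir, "checkpoint1.pt"),
                      map_location="cpu")
    pretrained = data["model_state_dict"]
    current = model.state_dict()
    # Keep only keys that exist in the current model with matching shapes.
    matched = {
        key: value for key, value in pretrained.items()
        if key in current and value.shape == current[key].shape
    }
    current.update(matched)
    model.load_state_dict(current)
```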
I think it's relatively harmless to support multiple loading options here; it gives users a bit more flexibility.
I am not sure why we need both `restore` and `load_from_pretrained`.
`restore` is an optional argument in `TorchModel.fit`. We could move that functionality into `load_from_pretrained` and then modify `ModularTorchModel.fit`, but at this point I think it's better to maintain the convention we have in `TorchModel`.
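For reference, the two entry points being contrasted look roughly like this from the caller's side (`model` and `dataset` are placeholders; only the `restore` keyword of `fit` is taken from the existing TorchModel API, the rest is illustrative):

```python
# Existing convention, kept by this PR: resuming from the model's own
# checkpoints goes through fit() via its optional restore argument.
model.fit(dataset, nb_epoch=10, restore=True)

# Loading weights produced by a different (pretraining) model is handled by
# the separate load_from_pretrained method rather than by fit().
model.load_from_pretrained(model_dir="pretrain_dir")
```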
Looking at the discussion so far, I think we will be good to merge once CI is fixed. Looks like we have some flake8 and yapf errors.
LGTM, feel free to merge in once CI is clear
Description
This PR allows ModularTorchModel to load components from disk. This is important for any pretraining regime. The changes are:
Suggestion: we should remove the example usage for ModularTorchModel. It is an abstract class; users are not expected to instantiate ModularTorchModel directly, so having an example is more confusing than helpful. Subclasses will serve as the usage examples.
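To make the intended workflow concrete, here is a hedged end-to-end sketch; `SomePretrainModel` and `SomeFinetuneModel` are hypothetical ModularTorchModel subclasses, the datasets are placeholders, and the keyword arguments follow the signature shown in the diff above:

```python
# Pretraining phase: train a source model and write checkpoints to its
# model_dir. Constructors and datasets here are placeholders.
pretrain_model = SomePretrainModel(model_dir="pretrain_dir")
pretrain_model.fit(unlabeled_dataset, nb_epoch=100)

# Fine-tuning phase: build a new model that shares some components with the
# pretraining model, then pull those components' weights from disk (the
# source_model or checkpoint options discussed above would work as well).
finetune_model = SomeFinetuneModel(model_dir="finetune_dir")
finetune_model.load_from_pretrained(model_dir="pretrain_dir")
finetune_model.fit(labeled_dataset, nb_epoch=20)
```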
Type of change
Please check the option that is related to your PR.
Checklist
- Run `yapf -i <modified file>` and check no errors (yapf version must be 0.32.0)
- Run `mypy -p deepchem` and check no errors
- Run `flake8 <modified file> --count` and check no errors
- Run `python -m doctest <modified file>` and check no errors