
Modular loading from pretrained #3305

Merged: 20 commits merged into deepchem:master from modular_loading on Mar 29, 2023

Conversation

tonydavis629 (Collaborator):

Description

This PR allows ModularTorchModel to load components from disk, which is important for any pretraining regime. The changes are as follows (a hedged sketch of the resulting save/restore flow appears after the list):

  • load_pretrained_components is removed.
  • TorchModel.load_from_pretrained is modified to accept components and to load both the components and the full model. Mypy incompatible-signature errors are ignored.
  • save_checkpoint is modified to save the components as well as the full model to a single checkpoint.
  • restore is modified to accept components. Mypy incompatible-signature errors are ignored.
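
For illustration, here is a minimal sketch of the single-checkpoint save/restore flow described above. The function names, the checkpoint filename, and the key layout ("model_state_dict" plus one entry per component) are assumptions for this example, not the exact DeepChem implementation:

```python
import os
import torch


def save_checkpoint_sketch(model, components, model_dir):
    """Write the full model and every component into one checkpoint file."""
    data = {"model_state_dict": model.state_dict()}
    for name, module in components.items():
        # One extra entry per component, e.g. "encoder_state_dict".
        data[name + "_state_dict"] = module.state_dict()
    torch.save(data, os.path.join(model_dir, "checkpoint1.pt"))


def restore_sketch(model, components, model_dir):
    """Reload the full model, then overwrite any components that were saved."""
    data = torch.load(os.path.join(model_dir, "checkpoint1.pt"))
    model.load_state_dict(data["model_state_dict"])
    for name, module in components.items():
        key = name + "_state_dict"
        if key in data:
            module.load_state_dict(data[key])
```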

Suggestion: we should remove the example usage for ModularTorchModel. It is an abstract class; users are not expected to instantiate ModularTorchModel directly, so an example is more confusing than helpful. The subclasses will serve as the usage examples.

Type of change

Please check the option related to your PR.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
    • In this case, we recommend discussing your modification in a GitHub issue before creating the PR
  • Documentation (modification to documents)

Checklist

  • My code follows the style guidelines of this project
  • Run yapf -i <modified file> and check no errors (yapf version must be 0.32.0)
  • Run mypy -p deepchem and check no errors
  • Run flake8 <modified file> --count and check no errors
  • Run python -m doctest <modified file> and check no errors
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New unit tests pass locally with my changes
  • I have checked my code and corrected any misspellings

@rbharath (Member) left a comment:

Some requests for more documentation, since the save/restore algorithm for components has some detail to it.

[3 resolved inline comment threads on deepchem/models/torch_models/modular.py]
Comment on lines 226 to 235:

```diff
-        self.init_emb()
+        if init_emb:
+            self.init_emb()
```
tonydavis629 (Collaborator, Author) commented:

This change is made so that .restore() functions as expected; otherwise the bias weights would be filled with 0.
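
A minimal sketch of the pattern (the module and layer names here are hypothetical): gating the call behind an init_emb flag lets restore() rebuild the module without immediately overwriting the weights it is about to load.

```python
import torch.nn as nn


class EmbeddingBlock(nn.Module):
    """Hypothetical module illustrating the guarded-initialization pattern."""

    def __init__(self, dim: int, init_emb: bool = True):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        # With an unconditional self.init_emb() call, restoring a checkpoint
        # would rebuild the module and immediately zero the bias again,
        # clobbering the weights loaded from disk.
        if init_emb:
            self.init_emb()

    def init_emb(self):
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
```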

@rbharath (Member) left a comment:

This looks good to me barring a minor comment issue.

@gusty1g Can you do a quick review pass as well since this is based on your earlier prototype?

[Resolved inline comment thread on deepchem/models/torch_models/modular.py]

```python
    def load_from_pretrained(  # type: ignore
            self,
            source_model: Optional["ModularTorchModel"] = None,
```
A contributor commented:

Are all three (source_model, checkpoint, model_dir) required? I think model_dir alone would be sufficient. Given a model_dir, the method loads the state_dict, and if any of the current model's layers or components match the keys in the state_dict, the method can update those components' weights.
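
A hedged sketch of that model_dir-only approach (assuming the checkpoint stores a flat state_dict; the filename and helper name are hypothetical):

```python
import os
import torch


def load_matching_weights(model, model_dir):
    """Copy over only the saved entries whose keys and shapes match."""
    saved = torch.load(os.path.join(model_dir, "checkpoint1.pt"))
    current = model.state_dict()
    matched = {k: v for k, v in saved.items()
               if k in current and current[k].shape == v.shape}
    current.update(matched)
    model.load_state_dict(current)
```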

@rbharath (Member) replied:
I think it's relatively harmless to support multiple loading options here; it gives users a bit more flexibility.
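
For example, the three options would support call patterns like the following (a usage sketch; the keyword names mirror the signature quoted above, while MyModularModel, the paths, and the component names are assumptions):

```python
model = MyModularModel()  # hypothetical ModularTorchModel subclass

# 1. Copy weights directly from an in-memory source model:
model.load_from_pretrained(source_model=pretrained, components=["encoder"])

# 2. Load from an explicit checkpoint file:
model.load_from_pretrained(checkpoint="checkpoint1.pt", components=["encoder"])

# 3. Load the most recent checkpoint found in a model directory:
model.load_from_pretrained(model_dir="./pretrained", components=["encoder"])
```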

@arunppsg (Contributor) commented:

I am not sure why we need both load_from_pretrained and restore. Will it be possible to add the functionality of load_from_pretrained to restore?

@tonydavis629 (Collaborator, Author) replied:

> I am not sure why we need both load_from_pretrained and restore. Will it be possible to add the functionality of load_from_pretrained to restore?

restore is an optional argument to TorchModel.fit. We could move the functionality into load_from_pretrained and then modify ModularTorchModel.fit, but at this point I think it is better to maintain the convention we have in TorchModel.
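
In other words, the convention being kept looks roughly like this (a usage sketch, assuming DeepChem's fit semantics where restore=True resumes from the most recent checkpoint):

```python
# Resume training from the most recent checkpoint via fit's restore flag:
model.fit(train_dataset, nb_epoch=10, restore=True)

# Versus explicitly loading pretrained components before a fresh fit:
model.load_from_pretrained(model_dir="./pretrained", components=["encoder"])
model.fit(train_dataset, nb_epoch=10)
```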

@rbharath (Member) left a comment:

Looking at the discussion so far, I think we will be good to merge once CI is fixed. It looks like we have some flake8 and yapf errors.


@rbharath (Member) left a comment:

LGTM, feel free to merge in once CI is clear

@tonydavis629 merged commit 3dc137a into deepchem:master on Mar 29, 2023.
@tonydavis629 deleted the modular_loading branch on March 30, 2023.