Integrating Huggingface models in Deepchem #3362
A few comments here about more docs and some design questions about fit/predict
inputs, labels, weights = self._prepare_batch(batch)

optimizer.zero_grad()
outputs = self.model(**inputs)
Let's consider whether we should generalize TorchModel instead of reimplementing fit here, since a lot of this code is identical to TorchModel.
There are a couple of ways we can generalize TorchModel here while maintaining backward compatibility:
- Check the type of inputs and, if it is an instance of `Dict`, invoke the model as `model(**inputs)`; otherwise use the old style, `model(inputs)`.
- Delegate the `model(inputs)` call to a function, say `forward()`, which can be implemented in child classes as well.
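The two options above can be sketched as follows. This is a minimal, framework-free illustration and not DeepChem's actual code: `call_model`, `BaseWrapper`, `KeywordWrapper`, and the two toy models are hypothetical names, and plain callables stand in for `torch.nn.Module` instances.

```python
from collections.abc import Mapping
from typing import Any, Callable


def call_model(model: Callable[..., Any], inputs: Any) -> Any:
    """Option 1 (hypothetical helper): dispatch on the type of the inputs."""
    # Mapping also accepts dict-like objects, not only plain dicts.
    if isinstance(inputs, Mapping):
        return model(**inputs)  # keyword style
    return model(inputs)  # old positional style


class BaseWrapper:
    """Option 2 (hypothetical stand-in for a TorchModel-like wrapper)."""

    def __init__(self, model: Callable[..., Any]) -> None:
        self.model = model

    def forward(self, inputs: Any) -> Any:
        # Default behaviour: pass the inputs positionally.
        return self.model(inputs)


class KeywordWrapper(BaseWrapper):
    """Child class that overrides only the model call."""

    def forward(self, inputs: Any) -> Any:
        # Unpack the batch as keyword arguments.
        return self.model(**inputs)


# Toy models used only to exercise both call styles.
def positional_model(x):
    return x * 2


def keyword_model(input_ids, attention_mask):
    return [i * m for i, m in zip(input_ids, attention_mask)]


batch = {"input_ids": [1, 2, 3], "attention_mask": [1, 0, 1]}
positional_result = call_model(positional_model, 3)
keyword_result = call_model(keyword_model, batch)
wrapped_result = KeywordWrapper(keyword_model).forward(batch)
```

Option 1 keeps a single code path at the cost of a type check on every batch; option 2 leaves the existing call untouched and moves the decision into subclasses.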
Documenting offline discussion: it's OK to leave this code as-is for now, but we should come back and consider refactoring for maintainability.
logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
return last_avg_loss

def _predict(self, generator: Iterable[Tuple[Any, Any, Any]],
Same point here; we may want to consider generalizing TorchModel
Same comment as above from offline discussion
LGTM, feel free to merge once CI is clear
Description
Integrating huggingface models in DeepChem.
Type of change
Please check the option that is related to your PR.
Checklist
- `yapf -i <modified file>` and check no errors (yapf version must be 0.32.0)
- `mypy -p deepchem` and check no errors
- `flake8 <modified file> --count` and check no errors
- `python -m doctest <modified file>` and check no errors
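For convenience, the checklist commands could be run together from a small script. The sketch below is only an illustration: `build_check_commands` and `run_checks` are hypothetical helpers that are not part of the repository, `sys.executable` is used in place of `python`, and the script assumes `yapf`, `mypy`, and `flake8` are already installed. It does not verify the yapf version.

```python
import subprocess
import sys
from typing import List


def build_check_commands(path: str) -> List[List[str]]:
    """Return the checklist commands for one modified file (hypothetical helper)."""
    return [
        ["yapf", "-i", path],
        ["mypy", "-p", "deepchem"],
        ["flake8", path, "--count"],
        [sys.executable, "-m", "doctest", path],
    ]


def run_checks(path: str) -> bool:
    """Run every command in order; return True only if all exit with status 0."""
    ok = True
    for cmd in build_check_commands(path):
        result = subprocess.run(cmd)
        if result.returncode != 0:
            print("FAILED:", " ".join(cmd))
            ok = False
    return ok


# Building the commands has no side effects; call run_checks(path) to execute them.
commands = build_check_commands("<modified file>")
```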