Allow for reporting of loss curve from KerasModel.fit #2060
Conversation
CC @peastman
It's a little ugly how this forces us to change the API of the base Model class, and also how we have to tell mypy to ignore errors because it can't infer the return type. As an alternative, what if we added an argument
Good point! Let me try reworking this as you suggest.
This is much cleaner! No changes to the API are needed, and it solves our needs.
deepchem/models/keras_model.py
Outdated
@@ -302,16 +303,20 @@ def fit(self,
    callbacks: function or list of functions
      one or more functions of the form f(model, step) that will be invoked after
      every step. This can be used to perform validation, logging, etc.
    all_losses: list, optional (default False)
Should be "default None". Also we can give a more specific type: Optional[List[float]].
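The suggested annotation can be sketched as follows. This is a toy stand-in for the real method (the function body and names other than all_losses are illustrative, not the actual deepchem implementation); note that defaulting to None rather than an empty list also sidesteps Python's mutable-default-argument pitfall, where one list would be shared across every call.

```python
from typing import List, Optional

def fit(losses, all_losses: Optional[List[float]] = None) -> float:
    """Toy stand-in for KerasModel.fit: returns the final average loss and,
    if a list is supplied, appends each intermediate average loss to it."""
    final = 0.0
    for final in losses:
        # Only accumulate when the caller actually asked for the curve.
        if all_losses is not None:
            all_losses.append(final)
    return final

curve: List[float] = []
print(fit([0.9, 0.5, 0.2], all_losses=curve))  # 0.2
print(curve)                                   # [0.9, 0.5, 0.2]
```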
Good suggestion, will fix!
deepchem/models/keras_model.py
Outdated
@@ -421,13 +433,19 @@ def fit_generator(self,
        avg_loss = float(avg_loss) / averaged_batches
        logger.info(
            'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
        avg_losses.append(avg_loss)
The implementation is a bit more complicated than it really needs to be. The only required change to this method is a single addition here:
if all_losses is not None:
    all_losses.append(avg_loss)
Everything else can stay exactly as it was before.
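To illustrate how that single guarded append fits into the logging loop, here is a simplified, hypothetical model of fit_generator's averaging logic (the names avg_loss, averaged_batches, and all_losses come from the quoted diff; the loop structure and log_every parameter are assumptions for the sketch, not deepchem's actual code):

```python
from typing import List, Optional

def fit_generator(batch_losses, log_every: int = 2,
                  all_losses: Optional[List[float]] = None) -> float:
    """Averages batch losses over each logging interval; the only change
    needed to expose the loss curve is the guarded append below."""
    avg_loss, averaged_batches = 0.0, 0
    last_avg = 0.0
    for i, loss in enumerate(batch_losses, start=1):
        avg_loss += loss
        averaged_batches += 1
        if i % log_every == 0:
            last_avg = avg_loss / averaged_batches
            # The single addition from the review comment:
            if all_losses is not None:
                all_losses.append(last_avg)
            avg_loss, averaged_batches = 0.0, 0
    return last_avg

curve: List[float] = []
fit_generator([1.0, 0.8, 0.6, 0.4], all_losses=curve)
print(curve)  # [0.9, 0.5]
```

Because the list is mutated in place, callers who pass nothing see identical behavior to the old code path.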
Good suggestion! I've simplified it in this fashion, with a minor tweak to handle the edge case that was failing in #1944.
That looks much better! I made a few additional minor suggestions. Otherwise, I think it looks good.
Tests are green, so I'm going to go ahead and merge this in now!
This PR allows for the return of the full loss curve from KerasModel.fit by adding a new argument return_loss_curve, as discussed in #2058, and adds a unit test. This PR also fixes the logging issues from #1944.