Add *_on_batch functions to torch trainer #328
Conversation
Thanks for the PR! Looking good!
keras_core/backend/torch/trainer.py (outdated)
@@ -58,6 +58,64 @@ def train_step(self, data):
        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

    def test_step(self, data):
We should avoid adding new public Trainer APIs. Could these be floating functions in make_*_function?
Done.
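For illustration, a rough sketch of the "floating function" pattern the reviewer suggests, where the per-batch step lives inside the factory rather than becoming a new public Trainer method. The names and bodies here are assumptions for the sketch, not the merged keras_core code:

```python
class Trainer:
    def __init__(self):
        self.test_function = None

    def test_step(self, data):
        # Placeholder: the real implementation runs a forward pass and
        # updates loss/metrics for a single batch.
        return {"loss": 0.0}

    def make_test_function(self, force=False):
        if self.test_function is not None and not force:
            return self.test_function

        def one_step_on_data(data):
            # "Floating" function: private to the factory, so it does not
            # expand the Trainer's public API surface.
            return self.test_step(data)

        self.test_function = one_step_on_data
        return self.test_function
```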
keras_core/backend/torch/trainer.py (outdated)
        if self.train_function is not None and not force:
            return self.train_function

        def one_step_on_data(data):
E.g. just inline test_step here instead of defining a wrapper function (which is not needed until we have step fusing or compilation).
Done.
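A before/after sketch of this inlining suggestion, assuming the wrapper is a pure pass-through (the code below is illustrative, not the actual diff):

```python
# Before: a wrapper that only forwards the call.
def make_test_function(self, force=False):
    if self.test_function is not None and not force:
        return self.test_function

    def one_step_on_data(data):
        return self.test_step(data)  # wrapper adds nothing yet

    self.test_function = one_step_on_data
    return self.test_function

# After: use the step directly; a wrapper only earns its keep once
# step fusing or compilation (e.g. torch.compile) enters the picture.
def make_test_function(self, force=False):
    if self.test_function is not None and not force:
        return self.test_function
    self.test_function = self.test_step
    return self.test_function
```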
        class_weight=None,
        return_dict=False,
    ):
        """Runs a single gradient update on a single batch of data.
Something we should do soon: make sure the base Trainer class has docstrings everywhere, that the Trainer subclasses don't have docstrings, and then programmatically set the docstrings on the subclasses (using the ones from the base class).
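A minimal sketch of what such programmatic docstring propagation could look like; the helper name and usage are illustrative, not from the PR or issue:

```python
def inherit_docstrings(subclass, base):
    """Copy docstrings from base-class methods onto overriding methods
    in `subclass` that were left without a docstring."""
    for name, member in vars(subclass).items():
        if callable(member) and not member.__doc__:
            base_member = getattr(base, name, None)
            if base_member is not None and base_member.__doc__:
                member.__doc__ = base_member.__doc__
    return subclass

# Hypothetical usage:
# inherit_docstrings(TorchTrainer, base_trainer.Trainer)
```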
Cool. Issue created: #329
LGTM, thanks!