New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add pytorch backend trainer #247
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! It looks very clean and close to the TF trainer.
The top priority now will be implementing evaluate/predict and adding compilation support. Do you expect any major issue with compilation?
steps_per_execution
is less important and won't be a launch blocker. I'd rather focus on overall performance.
x=np.random.rand(100, 10), | ||
y=np.random.rand(100, 1), | ||
epochs=10, | ||
shuffle=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
QQ: are you able to run integration_tests/numerical_test.py
to check the end-to-end numerics?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not yet, since it requires .evaluate()
to be implemented.
keras_core/backend/torch/trainer.py
Outdated
data = data[0] | ||
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) | ||
|
||
self.train() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you comment on what this does?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one should be removed since it was duplicated to the one in the .fit()
. I commented the one in the .fit()
funciton.
# TODO: respect compiled trainable state | ||
if validation_split and validation_data is None: | ||
# Create the validation data using the training data. Only supported | ||
# for TF/numpy/jax arrays. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should work for torch tensors as well -- we should start adding data adapter tests targeting torch tensors (in a new PR)
@@ -159,7 +158,7 @@ def update(self, current, values=None, finalize=None): | |||
for k in self._values_order: | |||
info += f" - {k}:" | |||
if isinstance(self._values[k], list): | |||
avg = np.mean( | |||
avg = ops.mean( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By the time we hit this I'd expect the logs to be Python/numpy, is it not the case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As I turn the backend to TF, the values returned by the values here are still tf.Tensor
. It is caused by the Trainer.compute_metrics
not returning plain types like in the docs says but backend specific tensors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok we can leave it as is. Thanks!
For torch script, I do not see any major issue, but would expect a series of small issues. |
Have a basic working example of using Keras with PyTorch backend.
Some notable changes to other parts of the code:
backend.standardize_dtype()
when possible, since the dtype in PyTorch is different than other backends. It does not have the.name
attribute, and it is not astr
.ProgBar
, now useops
instead ofnp
for computing the mean. Other backends tensors work seamlessly withnp
, but nottorch.Tensor
.torch/numpy.py
.torch.zerors(*shape)
would not work whenshape
is an empty tuple. Does other ops support this corner case? @nkovela1 @chenmoneygithubMissing features to be added later:
steps_per_execution
..evaluate()
)