add pytorch backend trainer #247

Merged — 2 commits merged into main from haifeng-torch on Jun 2, 2023
Conversation

@haifeng-jin (Member) commented Jun 2, 2023

Have a basic working example of using Keras with PyTorch backend.
Some notable changes to other parts of the code:

  • Use backend.standardize_dtype() when possible, since dtypes in PyTorch differ from the other backends: a torch dtype has no .name attribute and is not a str (see the sketch after this list).
  • In ProgBar, use ops instead of np for computing the mean. Tensors from the other backends work seamlessly with np, but torch.Tensor does not.
  • A minor change in torch/numpy.py: torch.zeros(*shape) does not work when shape is an empty tuple (also shown in the sketch below). Do the other ops support this corner case? @nkovela1 @chenmoneygithub
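
For illustration only (not part of the PR), a minimal sketch of the two corner cases called out above, assuming NumPy and a PyTorch 2.x install:

    import numpy as np
    import torch

    # NumPy dtypes carry a `.name` string such as "float64"; per the
    # description above, torch dtypes are torch.dtype objects without
    # that attribute, and str(torch.float32) is "torch.float32" rather
    # than "float32" -- hence backend.standardize_dtype().
    print(np.zeros(3).dtype.name)   # "float64"
    print(str(torch.float32))       # "torch.float32"

    # The zeros corner case: per the description above, unpacking an
    # empty shape tuple (torch.zeros(*shape)) fails, while passing the
    # tuple itself produces a 0-d tensor.
    shape = ()
    print(np.zeros(shape).shape)     # ()
    print(torch.zeros(shape).shape)  # torch.Size([])
    # torch.zeros(*shape)            # reported above to fail when shape == ()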

Missing features to be added later:

  • Use TorchScript for the training step (an illustrative sketch follows this list).
  • Support steps_per_execution.
  • Support validation data. (implement .evaluate())
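
The compiled training step is deferred by this PR. Purely as an illustration of that item (not the PR's implementation), a plain-PyTorch step can be wrapped with torch.compile, used here as a stand-in for the TorchScript work; the toy model, optimizer, and loss below are assumptions of this sketch:

    import torch
    from torch import nn

    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    loss_fn = nn.MSELoss()

    def train_step(x, y):
        # One eager training step: forward, loss, backward, update.
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        return loss

    # Stand-in for the "compile the training step" item above
    # (torch.compile requires PyTorch 2.x).
    compiled_step = torch.compile(train_step)
    print(compiled_step(torch.randn(8, 10), torch.randn(8, 1)))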

@haifeng-jin haifeng-jin requested a review from fchollet June 2, 2023 06:36
@fchollet (Member) left a comment:

Great work! It looks very clean and close to the TF trainer.

The top priority now will be implementing evaluate/predict and adding compilation support. Do you expect any major issue with compilation?

steps_per_execution is less important and won't be a launch blocker. I'd rather focus on overall performance.

x=np.random.rand(100, 10),
y=np.random.rand(100, 1),
epochs=10,
shuffle=False,
@fchollet (Member) commented on the hunk above:

QQ: are you able to run integration_tests/numerical_test.py to check the end-to-end numerics?

@haifeng-jin (Member Author) replied:

Not yet, since it requires .evaluate() to be implemented.

data = data[0]
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)

self.train()
@fchollet (Member) commented on the hunk above:

Can you comment on what this does?

@haifeng-jin (Member Author) replied:

This one should be removed, since it duplicates the call in .fit(). I added a comment on the one in the .fit() function.
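
For context (not part of the diff): in PyTorch, Module.train() switches a module and its children into training mode, which changes the behavior of layers such as Dropout and BatchNorm; Module.eval() is the inference-mode counterpart.

    from torch import nn

    m = nn.Dropout(p=0.5)
    m.train()           # training mode: dropout is active
    print(m.training)   # True
    m.eval()            # eval mode: dropout becomes the identity
    print(m.training)   # False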

# TODO: respect compiled trainable state
if validation_split and validation_data is None:
# Create the validation data using the training data. Only supported
# for TF/numpy/jax arrays.
@fchollet (Member) commented on the hunk above:

Should work for torch tensors as well -- we should start adding data adapter tests targeting torch tensors (in a new PR)
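
A rough sketch of the kind of end-to-end check being suggested here (not part of this PR; the toy model, the keras_core import style, and whether validation_split already accepts torch tensors at this commit are all assumptions of the sketch):

    import torch
    import keras_core

    model = keras_core.Sequential([keras_core.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

    # Torch tensors fed straight into fit() with validation_split,
    # mirroring the existing numpy/TF/jax array paths.
    x = torch.randn(100, 10)
    y = torch.randn(100, 1)
    model.fit(x, y, validation_split=0.2, epochs=1, batch_size=16)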

keras_core/backend/torch/trainer.py (review thread marked outdated and resolved)
@@ -159,7 +158,7 @@ def update(self, current, values=None, finalize=None):
             for k in self._values_order:
                 info += f" - {k}:"
                 if isinstance(self._values[k], list):
-                    avg = np.mean(
+                    avg = ops.mean(
@fchollet (Member) commented on the hunk above:

By the time we hit this I'd expect the logs to be Python/numpy, is that not the case?

@haifeng-jin (Member Author) replied:

When I switch the backend to TF, the values here are still tf.Tensor. This is because Trainer.compute_metrics returns backend-specific tensors rather than the plain types the docs say it should.

@fchollet (Member) replied:

Ok, we can leave it as is. Thanks!
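
For reference, a small illustration (not from the PR) of why np.mean can break on torch tensors, which is what motivated switching ProgBar to ops.mean; the exact failure depends on the tensor (grad tracking, CUDA placement):

    import numpy as np
    import torch

    t = torch.ones(3, requires_grad=True)
    try:
        np.mean(t)  # implicit numpy conversion of a grad-tracking tensor fails
    except (RuntimeError, TypeError) as err:
        print("np.mean failed:", err)

    # Detaching (or using backend-dispatching ops, as the PR does) avoids this.
    print(np.mean(t.detach().numpy()))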

@haifeng-jin (Member Author) commented:

For TorchScript, I do not see any major issues, but I would expect a series of small ones.

@haifeng-jin haifeng-jin requested a review from fchollet June 2, 2023 17:46
@fchollet fchollet merged commit 0724e27 into main Jun 2, 2023
4 checks passed
@haifeng-jin haifeng-jin deleted the haifeng-torch branch June 2, 2023 18:32