
ArrayDataAdapter no longer converts to NumPy and supports sparse tensors #19298

Merged
merged 1 commit into keras-team:master from hertschuh:sliceable_array_adapter on Mar 14, 2024

Conversation

hertschuh
Contributor

ArrayDataAdapter no longer converts to NumPy and supports sparse tensors. Instead, the passed arrays can be sliced or indexed in their native format (a minimal illustration follows below).
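As a hedged illustration of what "native slicing" means (the helper `slice_batch` is hypothetical, not the PR's actual code): batch rows are selected with each framework's own indexing rather than round-tripping through NumPy.

```python
import numpy as np
import tensorflow as tf

# Hypothetical helper: pick batch rows out of an array in its
# native format instead of converting to NumPy first.
def slice_batch(x, indices):
    if isinstance(x, tf.Tensor):
        # Stays on the TF side; no host copy to NumPy and back.
        return tf.gather(x, indices, axis=0)
    if isinstance(x, np.ndarray):
        return x[indices]
    # torch.Tensor, jax.Array, sparse tensors, etc. would each get
    # their own native path in the same spirit.
    raise TypeError(f"Unsupported array type: {type(x)}")

x = tf.random.uniform((1024, 8))
batch = slice_batch(x, np.arange(32))  # still a tf.Tensor
```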

  • This addresses "ArrayDataAdapter will needlessly convert backend native tensors" #18408 and improves performance, especially with TensorFlow and Torch. It improves TF -> TF and Torch -> Torch, but also TF -> Torch and Torch -> TF.
  • This adds support for sparse tensors (`tf.SparseTensor`, `jax.experimental.sparse.BCOO` and `scipy.sparse`). These sparse tensors are sliced as sparse, and the iterators yield sparse tensors in the requested format (either TF or JAX).
  • The `validation_split` argument of `Model.fit()` can now be used with anything supported by `ArrayDataAdapter`; in particular, sparse tensors are now supported (see the usage sketch after the summary list below).

In summary, `ArrayDataAdapter` now supports:

  • native Python arrays
  • NumPy arrays
  • TensorFlow tensors, ragged tensors, sparse tensors (new)
  • JAX arrays and BCOO sparse tensors (new)
  • pandas DataFrames
  • pandas Series
  • scipy sparse matrices (new)
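As a hedged usage sketch (assuming the TensorFlow backend and a trivial model; the model and sizes are made up, not taken from the PR), sparse inputs and `validation_split` can now be combined in a single `fit()` call:

```python
import numpy as np
import scipy.sparse
import keras

# Sparse features in scipy COO format, dense labels.
x = scipy.sparse.random(1000, 20, density=0.1, format="coo", dtype="float32")
y = np.random.randint(0, 2, size=(1000, 1)).astype("float32")

model = keras.Sequential([
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# validation_split now works alongside sparse inputs: ArrayDataAdapter
# slices the sparse matrix as sparse instead of densifying it up front.
model.fit(x, y, epochs=1, batch_size=32, validation_split=0.2)
```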

Also:

  • Fixed a bug where batch-level shuffling would shuffle the different arrays (in particular inputs and labels) inconsistently when using a TF dataset or a NumPy iterator.
  • Fixed a bug where `tf.RaggedTensor`s would only work when using a TF dataset.
  • Fixed a bug where `tf.RaggedTensor`s would not work when doing batch-level shuffling.
  • Added a workaround for a bug where `tf.cast`ing a `tf.SparseTensor` would lose the static shape (sketched after this list).
  • Added test coverage for `tf.RaggedTensor`s and `pandas.Series`.
  • Added verification in tests that inputs and labels are shuffled consistently.
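A minimal sketch of the kind of workaround involved (not the exact code from the PR): cast only the values and rebuild the `tf.SparseTensor` around them, so the indices and `dense_shape`, and hence the static shape, stay intact.

```python
import tensorflow as tf

def cast_sparse(st, dtype):
    # Hypothetical workaround: with_values() copies the SparseTensor
    # with new values, leaving indices and dense_shape untouched,
    # which sidesteps the static-shape loss seen with a plain tf.cast.
    return st.with_values(tf.cast(st.values, dtype))

st = tf.SparseTensor(indices=[[0, 1], [2, 3]], values=[1, 2], dense_shape=[4, 4])
print(cast_sparse(st, tf.float32).shape)  # (4, 4)
```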

@codecov-commenter commented Mar 13, 2024

Codecov Report

Attention: Patch coverage is 89.53168%, with 38 lines in your changes missing coverage. Please review.

Project coverage is 75.70%. Comparing base (c8700f4) to head (769e282).
Report is 94 commits behind head on master.

| Files | Patch % | Lines |
|---|---|---|
| keras/trainers/data_adapters/array_slicing.py | 93.13% | 8 Missing and 8 partials ⚠️ |
| keras/trainers/data_adapters/data_adapter_utils.py | 78.00% | 4 Missing and 7 partials ⚠️ |
| keras/trainers/data_adapters/array_data_adapter.py | 87.27% | 5 Missing and 2 partials ⚠️ |
| ...s/trainers/data_adapters/generator_data_adapter.py | 50.00% | 4 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19298      +/-   ##
==========================================
- Coverage   80.14%   75.70%   -4.44%     
==========================================
  Files         341      366      +25     
  Lines       36163    40069    +3906     
  Branches     7116     7769     +653     
==========================================
+ Hits        28982    30334    +1352     
- Misses       5578     8052    +2474     
- Partials     1603     1683      +80     
| Flag | Coverage Δ |
|---|---|
| keras | 75.55% <89.53%> (-4.44%) ⬇️ |
| keras-jax | 59.88% <89.25%> (-3.18%) ⬇️ |
| keras-numpy | 54.43% <83.19%> (-2.65%) ⬇️ |
| keras-tensorflow | 61.39% <88.42%> (-3.27%) ⬇️ |
| keras-torch | 60.50% <88.70%> (-3.37%) ⬇️ |

Flags with carried forward coverage won't be shown.


@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Mar 13, 2024
@gbaned gbaned requested a review from mattdangerw March 13, 2024 08:34
@fchollet
Member

Great work! Are you seeing any performance impact?

@hertschuh
Contributor Author

> Great work! Are you seeing any performance impact?

Yes, here is a summary of the benchmark. I used a trivial model and clocked `model.fit()`; the table shows the time difference as a percentage, so a negative value is good (-50% means 2x faster). A sketch of the methodology follows after the notes below.

Summary:

  • Torch to Torch is a lot faster
  • TF to TF is quite a bit faster
  • the rest is about the same
| ↓ Input \ Backend (output) → | TensorFlow | JAX | Torch |
|---|---|---|---|
| NumPy arrays | 0 to -21% | -3% to -4% | -1% to -2% |
| TensorFlow tensors | -23% to -44% | 14% to -1% ‡ | -18% |
| JAX arrays | 0 to -20% | -3% to -9% | -2% |
| Torch tensors | -14% to -32% | 2% to -14% | -53% to -55% |

Notes:

  • It's a range, not a single value, because the performance impact differs for shuffle=True / False / "batch".
  • Anything less than 4% is not significant.
  • ‡ TF tensors to JAX are a bit slower only when shuffle=False.
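As a hedged sketch of this benchmark methodology (the model, data sizes, and timing loop are assumptions, not the author's exact script):

```python
import time
import numpy as np
import keras

def time_fit(x, y, epochs=5, batch_size=32, shuffle=True):
    model = keras.Sequential([keras.layers.Dense(1)])  # trivial model
    model.compile(optimizer="sgd", loss="mse")
    model.fit(x, y, epochs=1, batch_size=batch_size, verbose=0)  # warm-up
    start = time.perf_counter()
    model.fit(x, y, epochs=epochs, batch_size=batch_size,
              shuffle=shuffle, verbose=0)
    return time.perf_counter() - start

x = np.random.rand(100_000, 16).astype("float32")
y = np.random.rand(100_000, 1).astype("float32")

# Run with each backend (KERAS_BACKEND=tensorflow|jax|torch) and each
# input type (NumPy, TF, JAX, Torch arrays) to fill in the table above.
print(f"fit time: {time_fit(x, y):.2f}s")
```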

@fchollet (Member) left a comment

LGTM, thank you!

PR Queue automation moved this from Assigned Reviewer to Approved by Reviewer Mar 14, 2024
@google-ml-butler bot added the kokoro:force-run and ready to pull labels Mar 14, 2024
@fchollet fchollet merged commit 818c9fa into keras-team:master Mar 14, 2024
7 checks passed
PR Queue automation moved this from Approved by Reviewer to Merged Mar 14, 2024
@google-ml-butler bot removed the awaiting review, kokoro:force-run, and ready to pull labels Mar 14, 2024
@hertschuh hertschuh deleted the sliceable_array_adapter branch March 14, 2024 18:46