Allow tensors in tf.Datasets to have different dimensions. #19318

Merged 1 commit on Mar 16, 2024

hertschuh
Collaborator

The shape of the `tf.TensorSpec` used for the `tf.data.Dataset` is determined by inspecting several batches and keeping only the dimensions that are identical across them; dimensions that differ are left unspecified (`None`).

Fixes #19124
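
As a rough illustration of the idea described above (not the code merged in this PR), the following sketch merges the shapes of a few inspected batches, keeping a dimension only when every batch agrees on it and using `None` otherwise. The function name `merge_batch_shapes` and the choice to reject batches of different rank are assumptions made for this example.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]


def merge_batch_shapes(shapes: Sequence[Tuple[int, ...]]) -> Shape:
    """Keep dimensions shared by all shapes; mark the others as None.

    Hypothetical helper for illustration only.
    """
    if not shapes:
        raise ValueError("At least one shape is required.")
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        # Assumption: mismatched ranks are treated as an error in this sketch.
        raise ValueError(f"All shapes must have the same rank, got {list(shapes)}")
    # zip(*shapes) groups the i-th dimension of every shape together.
    return tuple(
        dims[0] if all(d == dims[0] for d in dims) else None
        for dims in zip(*shapes)
    )


# Two batches of sequences whose lengths differ in the second dimension.
merged = merge_batch_shapes([(32, 10, 8), (32, 17, 8)])
print(merged)  # (32, None, 8)
```

A tuple like this could then be passed as the `shape` argument of `tf.TensorSpec`, where `None` denotes a dimension of unknown size.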

@codecov-commenter

codecov-commenter commented Mar 15, 2024

Codecov Report

Attention: Patch coverage is 78.37838% with 8 lines in your changes missing coverage. Please review.

Project coverage is 75.73%. Comparing base (c8700f4) to head (228a748).
Report is 632 commits behind head on master.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| keras/trainers/data_adapters/data_adapter_utils.py | 73.91% | 3 Missing and 3 partials ⚠️ |
| ...s/trainers/data_adapters/generator_data_adapter.py | 83.33% | 0 Missing and 1 partial ⚠️ |
| ...rainers/data_adapters/torch_data_loader_adapter.py | 80.00% | 0 Missing and 1 partial ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19318      +/-   ##
==========================================
- Coverage   80.14%   75.73%   -4.41%     
==========================================
  Files         341      366      +25     
  Lines       36163    40193    +4030     
  Branches     7116     7814     +698     
==========================================
+ Hits        28982    30442    +1460     
- Misses       5578     8064    +2486     
- Partials     1603     1687      +84     

| Flag | Coverage Δ |
|---|---|
| keras | 75.59% <78.37%> (-4.40%) ⬇️ |
| keras-jax | 59.95% <78.37%> (-3.11%) ⬇️ |
| keras-numpy | 54.54% <78.37%> (-2.54%) ⬇️ |
| keras-tensorflow | 61.46% <78.37%> (-3.19%) ⬇️ |
| keras-torch | 60.56% <78.37%> (-3.31%) ⬇️ |

Member

@fchollet fchollet left a comment

Thanks for the PR!

@@ -4,6 +4,8 @@
from keras.api_export import keras_export
from keras.utils import tree

NUM_SAMPLES_FOR_TENSOR_SPEC = 4
Member

Why not just 2?

Member

Also, these are batches, not samples

Collaborator Author

I don't know, I just figured the odds of detecting different shapes would be better with more examples (but it will never be 100% correct).

I will rename.

Also, I need to add unit tests.

Member

The 2 main use cases are for images and sequences. In both cases you have a very high likelihood that 2 consecutive batches will have different image sizes or sequence lengths if they are dynamic. So IMO 2 is good enough and minimizes overhead. Of course, it's not 100% accurate, as you say.
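
The trade-off discussed here can be made concrete with a small sketch: inspecting more batches raises the chance of spotting a dynamic dimension but costs more up-front work, and any fixed number can still miss it. The helper names (`peek_batches`, `has_dynamic_dim`) and the use of plain shape tuples in place of real batches are assumptions for illustration, not part of this PR.

```python
import itertools
from typing import Iterable, Iterator, List, Tuple


def peek_batches(batches: Iterable, n: int) -> Tuple[List, Iterator]:
    """Return the first n items and an iterator that still yields everything."""
    it = iter(batches)
    head = list(itertools.islice(it, n))
    # Chain the peeked items back in front so that no batch is lost.
    return head, itertools.chain(head, it)


def has_dynamic_dim(shapes: List[Tuple[int, ...]]) -> bool:
    """Return True if any dimension differs among the given shapes."""
    return any(len(set(dims)) > 1 for dims in zip(*shapes))


# Shapes of consecutive batches; the sequence length first changes at batch 3.
shapes = [(32, 10), (32, 10), (32, 14), (32, 9)]

head2, stream = peek_batches(shapes, 2)
head4, _ = peek_batches(shapes, 4)

print(has_dynamic_dim(head2))  # False: the first two batches happen to match
print(has_dynamic_dim(head4))  # True: four batches reveal the varying length
print(len(list(stream)))       # 4: peeking did not drop any batch
```

In this contrived input the first two batches share a length, which is the kind of case where two batches are not enough; when consecutive batches usually differ, as argued above for images and sequences, two batches would typically suffice.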

Member

@fchollet fchollet left a comment

LGTM, thank you!

@google-ml-butler (bot) added the `kokoro:force-run` and `ready to pull` labels on Mar 16, 2024
@fchollet merged commit 65c6462 into keras-team:master on Mar 16, 2024
6 checks passed
@google-ml-butler (bot) removed the `ready to pull` and `kokoro:force-run` labels on Mar 16, 2024
@hertschuh deleted the dataset_dims branch on March 18, 2024 18:26
Successfully merging this pull request may close these issues.

[Bug] With TensorFlow backend, using PyTorch DataLoader with different per-batch size does not work.