More rigorous shape inference in to_tf_dataset #4763
Conversation
src/datasets/arrow_dataset.py (Outdated)

```diff
@@ -420,6 +420,26 @@ def to_tf_dataset(
     batch_size=batch_size if drop_remainder else None,
 )

+shape_verification_signature, _ = dataset._get_output_signature(
```
Why do you need to call it a second time? Can't this logic be inside `_get_output_signature`?
That would make sense, actually! I'll move it.
@lhoestq I cleaned things up a lot based on your feedback - `_get_output_signature` is only called once, and it now immediately samples 200 batches of size 2 to infer the shape, but then overwrites the batch size element of the inferred shape with the actual batch size.
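The approach described above can be sketched roughly as follows. This is a minimal illustration with a hypothetical `infer_shape` helper, not the actual implementation in `arrow_dataset.py`:

```python
import numpy as np

def infer_shape(sampled_batches):
    """Infer a tensor shape from sampled batches: dimensions whose size
    varies across batches become None (variable); constant dimensions
    stay fixed. Hypothetical helper illustrating the idea only."""
    shapes = [np.asarray(batch).shape for batch in sampled_batches]
    return tuple(
        shapes[0][dim] if len({s[dim] for s in shapes}) == 1 else None
        for dim in range(len(shapes[0]))
    )

# Batches sampled with a small size (2) to maximise observed variability
batches = [np.zeros((2, 512)), np.zeros((2, 512)), np.zeros((2, 387))]
per_batch_shape = infer_shape(batches)        # (2, None)
# The batch dimension is then overwritten with the real batch size:
final_shape = (32,) + per_batch_shape[1:]     # (32, None)
```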
Cool! :)
I also think 10 batches is good by default; going to 200 batches can take too much time for some datasets IMO.
I actually specifically had problems with incorrect inferences when using 10! I think it's preferable for `to_tf_dataset()` to be a little slow sometimes (it's only called once at dataset creation time) than to infer wrong shapes and create tricky bugs for users. If you want, though, I can make `num_test_batches` an argument to `to_tf_dataset`?
> I actually specifically had problems with incorrect inferences when using 10!

Can you explain what problems?
In some cases, sampling 10 batches from the dataset makes it look like the dataset has a constant shape, but actually it doesn't. This is particularly common when datasets have been truncated. For example, if the average length in a dataset before truncation is >> 512, but we truncate at 512, then most batches will have length 512, but if some samples in the dataset have length < 512, then there will occasionally be batches with length < 512 too.
By reducing the batch size for shape inference and increasing the number of batches sampled, this problem is resolved in all the cases I know about!
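A small self-contained simulation (hypothetical numbers, not taken from the PR) shows why small batches expose the variable dimension much faster than large ones:

```python
import random

random.seed(0)
# Hypothetical dataset: pre-truncation lengths are mostly >> 512, but a
# few samples are shorter, so some batches pad to less than 512.
lengths = [min(random.randint(400, 2000), 512) for _ in range(10_000)]

def batch_lengths(lengths, batch_size):
    # After padding, each batch's sequence length is its longest sample
    return [max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size)]

# With a large batch size, almost every batch contains a length-512
# sample, so the first 10 batches typically all look like a constant
# (batch_size, 512) shape...
print(batch_lengths(lengths, 32)[:10])
# ...while with batch size 2, shorter batches show up far more often,
# so sampling many small batches is much more likely to catch the
# variable dimension.
print(sum(1 for m in batch_lengths(lengths, 2) if m < 512))
```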
What about adding a way for users to specify if the shapes are fixed or not? It could be via a new parameter, or by checking if the feature type is `Sequence(..., length=512)`.
I think that's a good idea! We'll still need shape inference but it might be useful, and I can look into adding it when I get back!
@lhoestq Reading the shape from Sequence features has been added!
LGTM 👍
Compare: 4a043fe to 31a6d58
Thanks!
`tf.data` needs to know the shape of tensors emitted from a `tf.data.Dataset`. Although `None` dimensions are possible, overusing them can cause problems - Keras uses the dataset tensor spec at compile-time, and so saying that a dimension is `None` when it's actually constant can hurt performance, or even cause training to fail for dimensions that are needed to determine the shape of weight tensors!

The compromise I used here was to sample several batches from the underlying dataset and apply the `collate_fn` to them, and then to see which dimensions were "empirically variable". There's an obvious problem here, though - if you sample 10 batches and they all have the same shape on a certain dimension, there's still a small chance that the 11th batch will be different, and Keras will throw an error if a dataset tries to emit a tensor whose shape doesn't match the spec.

I encountered this bug in practice once or twice for datasets that were mostly-but-not-totally constant on a given dimension, and I still don't have a perfect solution, but this PR should greatly reduce the risk. It samples many more batches, and also samples very small batches (size 2) - this increases the variability, making it more likely that a few outlier samples will be detected.

Ideally, of course, we'd determine the full output shape analytically, but that's surprisingly tricky when the `collate_fn` can be any arbitrary Python code!
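The last step, visible in the diff's `batch_size=batch_size if drop_remainder else None` line, is swapping the small sampling batch dimension for the real one. A hypothetical helper sketching that behaviour (not the library's actual function):

```python
def finalize_batch_dim(inferred_shape, batch_size, drop_remainder):
    """Replace the sampling batch dimension (size 2) with the real batch
    size, or None when drop_remainder is False and the last batch may be
    smaller. Hypothetical helper illustrating the final step only."""
    first = batch_size if drop_remainder else None
    return (first,) + tuple(inferred_shape[1:])

print(finalize_batch_dim((2, None, 768), 32, drop_remainder=True))
print(finalize_batch_dim((2, 512), 32, drop_remainder=False))
```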