BUG: [jax2tf] Strided pooling with polymorphic shape sometimes fails. #11804
Comments
It seems that this only happens with enable_xla=False. @marcvanzee PTAL
This may have been fixed incidentally by a very recent change #11816. Can you please try again at HEAD?
Ah, that's great! Still getting an error, but the message has changed.
Oh, and things seem to have gotten worse in one regard...
Looking briefly at the code+error, the reshape here is a bit strange, as the tensor already has batch and channel axes.
Hi Tom, thanks for filing the bug! I removed the batch + channel axes logic because I thought this was not part of the op: looking at the operational semantics of XLA::ReduceWindow, the definition of … However, it seems it actually is possible to add batch and feature dimensions. In fact, this seems to be what Flax does as well (so all Flax Modules using pooling now fail)! I think this is actually quite a good argument for adding some end-to-end Flax --> jax2tf (enable_xla=False) tests, which would catch this. I wrote some tooling in converter_eval a while ago, but I think we should rewrite it into actual tests. I filed #11872 for this. In any case, I will look at adding back support for batch and channel dimensions.
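For illustration, here is a minimal JAX sketch (not the jax2tf conversion code) of a `lax.reduce_window` call that carries batch and channel dimensions, in the style the comment above attributes to Flax: the batch (N) and channel (C) axes get window size 1 and stride 1, so only the spatial axes are pooled. The helper name `max_pool_nhwc` and the NHWC layout are assumptions made for this example.

```python
import jax.numpy as jnp
from jax import lax


def max_pool_nhwc(x, window=(2, 2), strides=(2, 2), padding="VALID"):
    """Max pooling over the H and W axes of an NHWC array (hypothetical helper).

    Window size 1 and stride 1 on the batch and channel axes mean those axes
    pass through unchanged; only the spatial axes are reduced.
    """
    return lax.reduce_window(
        x,
        -jnp.inf,  # identity element for max
        lax.max,
        window_dimensions=(1, *window, 1),
        window_strides=(1, *strides, 1),
        padding=padding,
    )


# Batch of 2, 4x4 spatial, 3 channels; x[n, h, w, c] = 48*n + 12*h + 3*w + c.
x = jnp.arange(2 * 4 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 4, 3)
y = max_pool_nhwc(x)
print(y.shape)  # (2, 2, 2, 3)
```

A converter that assumes the window only spans spatial axes would have to reshape such an input first, which is the kind of extra reshape questioned earlier in the thread.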
I actually found a different bug in our current implementation when doing average pooling with "SAME" padding. It is related to the way TF computes average pooling, and will lead to different outputs than JAX (#11874). @sdenton4 PTAL at that issue since you are doing average pooling with "SAME" padding as well, so I think the bug affects your code.
I just thought of a solution to that problem using manual padding whenever we encounter SAME padding. Will try implementing that fix together with this one.
Oh, wow; interesting bug on the TF side. Probably worth filing something with them? It certainly affects me, but the TF behavior seems not too bad for my use case (averaging embeddings over time). Treating the padding as signal effectively creates a downward bias on the mean. Dealing with the noisier average, without fake data, is probably preferable for me... If there's anything I can help with, feel free to send a ping, of course. I've got a couple of other things on the front burner right now, but as you know this is all important to me, so happy to pitch in where I can.
No worries, I am taking a look at this issue and the other one. I'll let you know when I have a fix!
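To make the two averaging conventions concrete, here is a minimal JAX sketch on a 1-D signal with SAME padding. Dividing the zero-padded window sum by the full window size counts padding as zeros (the downward bias at the edges mentioned above), while dividing by the number of real elements in each window matches how `tf.nn.avg_pool` treats SAME padding. The function `avg_pool_1d` and its `count_include_pad` flag are hypothetical names chosen for this example.

```python
import jax.numpy as jnp
from jax import lax


def avg_pool_1d(x, window, stride, count_include_pad):
    """Average pooling over a 1-D signal with SAME padding (hypothetical helper).

    count_include_pad=True divides each window sum by the full window size,
    so zero padding pulls edge averages down. count_include_pad=False divides
    by the number of non-padding elements that fall inside each window.
    """
    kwargs = dict(window_dimensions=(window,), window_strides=(stride,), padding="SAME")
    sums = lax.reduce_window(x, 0.0, lax.add, **kwargs)
    if count_include_pad:
        return sums / window
    # Summing ones counts the real elements per window; padded slots contribute 0.
    counts = lax.reduce_window(jnp.ones_like(x), 0.0, lax.add, **kwargs)
    return sums / counts


x = jnp.ones(4, dtype=jnp.float32)
with_pad = avg_pool_1d(x, window=3, stride=1, count_include_pad=True)
without_pad = avg_pool_1d(x, window=3, stride=1, count_include_pad=False)
```

For a constant signal of ones, the first convention yields roughly [0.667, 1, 1, 0.667] while the second yields all ones, which is exactly the kind of mismatch a converter hits if it maps one convention onto the other.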
* Fixes #11804: we only supported `lax.reduce_window` without batch and channel dimensions, which is wrong. Batch and channel dimensions are supported, and are in fact what most users rely on (this case is not explained in the [operational semantics for XLA::ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)). I have fixed this and clarified a number of test cases with batch and channel dimensions.
* Also, @sdenton4 gave a failing example in a Colab using polymorphic dimensions. I've added this as a test case to make sure it works now.
* Adds support for explicit padding using the existing padding logic from convolutions.
* Fixes #11874: we were not handling SAME padding for `lax.reduce_window` with `lax.add` correctly, since we used `tf.nn.avg_pool`, which does not count padding tokens when averaging (see issue for more details). I resolved it by adding manual padding and added some additional tests for this.
* Ensures we call eval_shape on a shape containing polynomials before calling a TF op.
* Fixes #11929 (comment): we weren't running any of the shape_poly_test.py tests for `enable_xla=False`.

PiperOrigin-RevId: 467532166
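The manual-padding idea from the commit message can be sketched in JAX as follows; this is an illustration of the technique, not the actual jax2tf TensorFlow code, and the helper names `same_pads` and `sum_pool_same_via_manual_padding` are hypothetical. SAME padding is turned into explicit zero padding using the standard SAME formula, and the padded array is then pooled with VALID padding, so the padding becomes ordinary zeros that no backend can treat specially.

```python
import math

import jax.numpy as jnp
from jax import lax


def same_pads(in_size, window, stride):
    """(low, high) padding that SAME padding applies along one axis."""
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + window - in_size, 0)
    return total // 2, total - total // 2


def sum_pool_same_via_manual_padding(x, window, stride):
    """1-D sum pooling with SAME semantics: zero-pad explicitly, then pool VALID."""
    lo, hi = same_pads(x.shape[0], window, stride)
    padded = jnp.pad(x, (lo, hi))  # constant zeros on both ends
    return lax.reduce_window(padded, 0.0, lax.add, (window,), (stride,), "VALID")


x = jnp.arange(1.0, 6.0)  # [1, 2, 3, 4, 5]
manual = sum_pool_same_via_manual_padding(x, window=2, stride=2)
reference = lax.reduce_window(x, 0.0, lax.add, (2,), (2,), "SAME")
```

Both `manual` and `reference` should be [3, 7, 5]; dividing by the window size then gives the include-padding average [1.5, 3.5, 2.5], whereas an average that skips padding would report 5.0 for the last window. This sketch assumes a static input length; with polymorphic shapes the padding amounts would themselves be symbolic.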
Description
Pooling operations sometimes fail to convert. It looks like a None dimension is sometimes slipping through the cracks. The bug depends on the stride value... I'm using 'framed' inputs to ensure that striding evenly divides the input size.
Here's a minimal colab notebook replication:
https://colab.corp.google.com/drive/1FX99EPcaX-1mAVnpnpUQwkkR0WZ_6o3h#scrollTo=NzSMa2VWhzWq
What jax/jaxlib version are you using?
v0.3.15
Which accelerator(s) are you using?
Additional System Info
No response