
BUG: [jax2tf] Strided pooling with polymorphic shape sometimes fails. #11804

Closed
sdenton4 opened this issue Aug 8, 2022 · 11 comments · Fixed by #11913
Labels: bug (Something isn't working)

sdenton4 commented Aug 8, 2022

Description

Pooling operations sometimes fail to convert. It looks like a None dimension is sometimes slipping through the cracks. The bug depends on the stride value... I'm using 'framed' inputs to ensure that striding evenly divides the input size.

Here's a minimal colab notebook replication:
https://colab.corp.google.com/drive/1FX99EPcaX-1mAVnpnpUQwkkR0WZ_6o3h#scrollTo=NzSMa2VWhzWq
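
For readers without access to that notebook, here is a minimal sketch of the setup (window size, stride, and the polymorphic shape spec are illustrative assumptions, not the exact notebook values):

```python
# Minimal reproduction sketch (assumed sizes; the original notebook is internal).
import tensorflow as tf
from flax import linen as nn
from jax.experimental import jax2tf

def infer_fn(inputs):
  # Framed inputs: the time axis is a multiple of the frame size, so the
  # stride divides the input length evenly.
  return nn.avg_pool(inputs, window_shape=(8,), strides=(4,), padding='SAME')

# Polymorphic time dimension `t`, converted without XLA ops.
converted_infer_fn = jax2tf.convert(
    infer_fn, polymorphic_shapes=['(1, t, 1)'], enable_xla=False)

tf_fn = tf.function(
    lambda inputs: converted_infer_fn(inputs),
    input_signature=[tf.TensorSpec([1, None, 1], tf.float32)],
    autograph=False)

# Tracing triggers the TypeError inside jax2tf's reduce_window lowering.
tf_fn.get_concrete_function()
```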

TypeError: in user code:

    File "<ipython-input-2-e6d49ad967e7>", line 26, in None  *
        lambda inputs: converted_infer_fn(
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 302, in fun_no_kwargs  *
        return fun(*args, **kwargs)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 534, in _interpret_fun  *
        out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] =               _call_wrapped_with_new_constant_cache(fun, in_vals,
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 688, in _call_wrapped_with_new_constant_cache  *
        out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] =         fun.call_wrapped(*in_vals)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/linear_util.py", line 168, in call_wrapped  *
        ans = self.f(*args, **dict(self.params, **kwargs))
    File "<ipython-input-23-a74605bc8077>", line 11, in infer_fn  *
        pooled = nn.pooling.avg_pool(
    File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 72, in avg_pool  *
        y = pool(inputs, 0., lax.add, window_shape, strides, padding)
    File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 52, in pool  *
        y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 79, in reduce_window  *
        return monoid_reducer(operand, window_dimensions, window_strides, padding,
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 127, in _reduce_window_sum  *
        window_dilation=tuple(window_dilation))
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 324, in bind  *
        return self.bind_with_trace(find_top_trace(args), args, params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 327, in bind_with_trace  *
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 977, in invoke_impl  *
        **params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/impl_no_xla.py", line 553, in _reduce_window  *
        tf_padding = pads_to_padtype(operand.shape, window_dimensions, window_strides, padding)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/impl_no_xla.py", line 115, in pads_to_padtype  *
        pads = lax.padtype_to_pads(in_shape, window_shape, window_strides, pad_str)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/lax.py", line 4527, in padtype_to_pads  *
        out_shape = _ceil_divide(in_shape, window_strides)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/lax.py", line 4512, in _ceil_divide  *
        return -np.floor_divide(np.negative(x1), x2)

    TypeError: bad operand type for unary -: 'NoneType'

What jax/jaxlib version are you using?

v0.3.15

Which accelerator(s) are you using?

  • CPU
  • GPU
  • TPU

Additional System Info

No response

sdenton4 added the bug (Something isn't working) label Aug 8, 2022

gnecula (Collaborator) commented Aug 11, 2022

It seems that this is only for the case enable_xla=False.

@marcvanzee PTAL

gnecula assigned marcvanzee and unassigned marcvanzee and gnecula Aug 11, 2022

gnecula commented Aug 11, 2022

This may have been fixed incidentally by a very recent change #11816. Can you please try again at HEAD?

sdenton4 (Author) commented

Ah, that's great! Still getting an error, but the message has changed.

TypeError: in user code:

    File "<ipython-input-2-e6d49ad967e7>", line 26, in None  *
        lambda inputs: converted_infer_fn(
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 302, in fun_no_kwargs  *
        return fun(*args, **kwargs)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 534, in _interpret_fun  *
        out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] =               _call_wrapped_with_new_constant_cache(fun, in_vals,
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 688, in _call_wrapped_with_new_constant_cache  *
        out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] =         fun.call_wrapped(*in_vals)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/linear_util.py", line 168, in call_wrapped  *
        ans = self.f(*args, **dict(self.params, **kwargs))
    File "<ipython-input-3-587f6ee978a3>", line 13, in infer_fn  *
        pooled = nn.pooling.avg_pool(
    File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 72, in avg_pool  *
        y = pool(inputs, 0., lax.add, window_shape, strides, padding)
    File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 52, in pool  *
        y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 79, in reduce_window  *
        return monoid_reducer(operand, window_dimensions, window_strides, padding,
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 127, in _reduce_window_sum  *
        window_dilation=tuple(window_dilation))
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 324, in bind  *
        return self.bind_with_trace(find_top_trace(args), args, params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 327, in bind_with_trace  *
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 977, in invoke_impl  *
        **params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/impl_no_xla.py", line 562, in tf_pool  *
        op = tf.reshape(op, (1,) + operand_shape + (1,))

    TypeError: Failed to convert elements of (1, 1, 16*t, 1, 1) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

sdenton4 (Author) commented

Oh, and things seem to have gotten worse in one regard...
Previously I noticed that conversion succeeds if stride >= pool_size, but that no longer works with the new change applied.

sdenton4 (Author) commented

Looking briefly at the code+error, the reshape here is a bit strange, as the tensor already has batch and channel axes.

marcvanzee (Collaborator) commented

Hi Tom,

Thanks for filing the bug! I removed the batch + channel axes logic because I thought this was not part of the op: looking at the operational semantics of XLA::ReduceWindow, the definition of operands is "A sequence of N multi-dimensional arrays of types T_0,..., T_{N-1}, each representing the base area on which the window is placed."

However, it seems it actually is possible to add batch and feature dimensions. In fact, this seems to be what Flax does as well (so all Flax modules using pooling now fail), as the sketch below illustrates.

I think this is actually quite a good argument for adding some end-to-end Flax --> jax2tf (enable_xla=False) tests, which would catch this. I wrote some tooling in converter_eval a while ago, but I think we should rewrite it into actual tests. I filed #11872 for this.

In any case, I will look at adding back support for batch and channel dimensions.
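
For example (illustrative sizes, not from the issue), Flax-style pooling lowers to a `lax.reduce_window` whose window and stride tuples carry trivial batch and feature entries, so the operand keeps those dimensions:

```python
# Sketch of the Flax-style call: window and strides include size-1 entries for
# the batch and feature axes (sizes here are made up).
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 32, 8))                  # (batch, width, features)
window, strides = (1, 8, 1), (1, 4, 1)    # pool only over the width axis
summed = lax.reduce_window(x, 0., lax.add, window, strides, 'SAME')
avg = summed / 8.0                        # avg_pool then divides by the window size
print(avg.shape)                          # (1, 8, 8)
```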

marcvanzee (Collaborator) commented

I actually found a different bug in our current implementation when doing average pooling with "SAME" padding. It is related to the way TF computes average pooling, and will lead to different outputs than JAX (#11874).

@sdenton4 PTAL at that issue: you are doing average pooling with "SAME" padding as well, so I think the bug affects your code.
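
A toy illustration of the mismatch (made-up 1-D values, not taken from #11874 itself):

```python
import numpy as np
import jax.numpy as jnp
import tensorflow as tf
from flax import linen as nn

x = np.arange(1., 5., dtype=np.float32).reshape(1, 4, 1)   # [1, 2, 3, 4]

# Flax: sums over the window and divides by the full window size, so the
# zero padding at the edges drags the average down.
flax_out = nn.avg_pool(jnp.asarray(x), window_shape=(3,), strides=(1,), padding='SAME')

# TF: averages only over elements that lie inside the input.
tf_out = tf.nn.avg_pool1d(x, ksize=3, strides=1, padding='SAME')

print(np.ravel(flax_out))   # first element: (0 + 1 + 2) / 3 = 1.0
print(np.ravel(tf_out))     # first element: (1 + 2) / 2     = 1.5
```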

marcvanzee (Collaborator) commented

I just thought of a solution to that problem using manual padding whenever we encounter SAME padding. Will try implementing that fix together with this one.
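
Roughly, the idea is to materialize the SAME padding as explicit zeros and then pool with VALID padding, so TF's average divides by the full window the same way the JAX lowering does. A hypothetical sketch (not the actual jax2tf code):

```python
import tensorflow as tf

def avg_pool_same_like_jax(x, window, stride):
  """x: (batch, width, channels); window/stride are ints for the width axis."""
  width = int(x.shape[1])
  out = -(-width // stride)                       # ceil(width / stride), as in SAME
  pad_total = max((out - 1) * stride + window - width, 0)
  lo = pad_total // 2
  hi = pad_total - lo
  x = tf.pad(x, [[0, 0], [lo, hi], [0, 0]])       # explicit zeros instead of implicit padding
  return tf.nn.avg_pool1d(x, ksize=window, strides=stride, padding='VALID')
```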

sdenton4 (Author) commented

Oh, wow; interesting bug on the TF side. Probably worth filing something with them?

It certainly affects me, but the TF behavior seems not too bad for my use-case (averaging embeddings over time). Treating the padding as signal effectively creates a downward bias on the mean. Dealing with the noisier average, without fake data, is probably preferable for me...

If there's anything I can help with feel free to send a ping, of course. I've got a couple other things on the front burners right now, but as you know this is all important to me, so happy to pitch in where I can.

marcvanzee (Collaborator) commented

If there's anything I can help with feel free to send a ping, of course.

No worries, I am taking a look at this issue and the other one; I'll let you know when I have a fix!

copybara-service bot pushed a commit that referenced this issue Aug 15, 2022
* Fixes #11804: we only supported `lax.reduce_window` without batch and channel dimensions, which is wrong. This is supported, and in fact something that most users use (this case is actually not explained in the [operational semantics for XLA::ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)). I have fixed this and clarified a number of test cases with batch and channel dimensions.

* Also, @sdenton4 gave a failing example in a Colab using polymorphic dimensions. I've added this as a test case to make sure it works now.

* Adds support for explicit padding using the existing padding logic from convolutions.

* Fixes #11874: we were not handling SAME padding for `lax.add` correctly, since we used `tf.nn.avg_pool`, which does not include padding tokens in the average (see issue for more details). I resolved it by adding manual padding and added some additional tests for this.

PiperOrigin-RevId: 467532166
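
For reference, the explicit-padding case mentioned above corresponds to `lax.reduce_window`'s per-axis `(low, high)` padding form; a small sketch with made-up sizes:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 32, 8))                          # (batch, width, features)
y = lax.reduce_window(x, 0., lax.add,
                      window_dimensions=(1, 8, 1),
                      window_strides=(1, 4, 1),
                      padding=((0, 0), (2, 2), (0, 0)))  # explicit (low, high) per axis
print(y.shape)                                    # (1, 8, 8)
```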

marcvanzee (Collaborator) commented

@sdenton4 could you please try #11913 on your code and see if it fixes things?

copybara-service bot pushed a commit that referenced this issue Aug 16, 2022
* Fixes #11804: we only supported `lax.reduce_window` without batch and channel dimensions, which is wrong. This is supported, and in fact something that most users use (this case is actually not explained in the [operational semantics for XLA::ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)). I have fixed this and clarified a number of test cases with batch and channel dimensions.

* Also, @sdenton4 gave a failing example in a Colab using polymorphic dimensions. I've added this as a test case to make sure it works now.

* Adds support for explicit padding using the existing padding logic from convolutions.

* Fixes #11874: we were not handling SAME padding for `lax.add` correctly, since we used `tf.nn.avg_pool`, which does not include padding tokens in the average (see issue for more details). I resolved it by adding manual padding and added some additional tests for this.

* Ensures we call eval_shape on a shape containing polynomials before calling a TF op.

* Fixes #11929 (comment): we weren't running any of the shape_poly_test.py tests for `enable_xla=False`.

PiperOrigin-RevId: 467532166