Conversion of transpose into reshape goes wrong after 3D NDHWC convolution #1464

Closed
gerbenvv opened this issue Apr 20, 2021 · 8 comments
Labels
pending on user response Waiting for more information or validation from user

Comments

@gerbenvv

When doing a 3D NDHWC convolution with an optional bias and element-wise operation after it (for example sigmoid), the second transpose (after the convolution) is incorrect. It has 4 dimensions where it should have 5.

Furthermore, the optimization crashes if biases = 1.0 in my example (it assumes a numpy array).

Reproduce:

import tensorflow as tf
import tf2onnx
import numpy as np


@tf.function
def f(x):
    weights = np.random.randn(1, 1, 1, 1, 1).astype(np.float32)
    biases = np.random.randn(1).astype(np.float32)

    x = tf.nn.conv3d(x, tf.constant(weights), strides=[1, 1, 1, 1, 1], padding="VALID", data_format="NDHWC")
    # x = x + biases  # Optional: the bug occurs with or without it, but the optimizer crashes when biases=1.0
    x = tf.math.sigmoid(x)
    
    return x


tf2onnx.convert.from_function(
    function=f,
    input_signature=(
        tf.TensorSpec(shape=(None, 1, 1, 1, 1), name="x"),
    ),
)[0]

I traced it back to the _switch_transpose_and_node method in the TransposeOptimizer. It always applies the NHWC_TO_NCHW permutation, whereas it should check the rank, as _handle_node_having_branches does. So:

            # only nhwc transpose can reach here
            new_shape = [shape[i] for i in NHWC_TO_NCHW]
            self._g.set_shape(node.output[0], new_shape)

should be:

            # only nhwc transpose can reach here
            trans_rank = get_transpose_rank(trans)
            perm = NHWC_TO_NCHW if trans_rank == 4 else NDHWC_TO_NCDHW
            new_shape = [shape[i] for i in perm]
            self._g.set_shape(node.output[0], new_shape)
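
To see in isolation why the fixed permutation loses a dimension, here is a minimal sketch that does not depend on tf2onnx. The two permutation lists are restated locally as an assumption about how such constants are usually defined, `permute_shape_channels_first` is a hypothetical helper, and the rank is taken from the shape itself rather than from the transpose node as in the proposed fix.

```python
# Permutation constants, restated here as an assumption so the sketch
# runs without tf2onnx installed.
NHWC_TO_NCHW = [0, 3, 1, 2]
NDHWC_TO_NCDHW = [0, 4, 1, 2, 3]


def permute_shape_channels_first(shape):
    """Hypothetical helper: choose the permutation by rank, then reorder the shape."""
    rank = len(shape)
    if rank == 4:
        perm = NHWC_TO_NCHW
    elif rank == 5:
        perm = NDHWC_TO_NCDHW
    else:
        raise ValueError(f"unsupported rank {rank}")
    return [shape[i] for i in perm]


shape_2d = [None, 224, 224, 3]      # N, H, W, C
shape_3d = [None, 16, 64, 64, 3]    # N, D, H, W, C

print(permute_shape_channels_first(shape_2d))
print(permute_shape_channels_first(shape_3d))

# Applying the 4-element permutation to a rank-5 shape silently yields a
# rank-4 result, which matches the symptom described in this report.
buggy_shape = [shape_3d[i] for i in NHWC_TO_NCHW]
print(len(buggy_shape))
```

Raising on any other rank is a choice made for this sketch only; the real optimizer may need to handle further ranks differently.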

It seems that many places still assume the graph is a 2D convolutional network, whereas convolutions can of course have any number of spatial dimensions. I'd like to raise awareness because 3D convolutions are widely used in medical machine learning. Also, 1D convolutions are often used for audio.

Furthermore, it would be nice if optimizations could be disabled (perhaps on a per-optimization basis). The optimizations can be finicky and are not strictly necessary to get an ONNX conversion. Ideally, I think the optimizer would be a separate tool altogether, so that it works on any ONNX model.

@gerbenvv
Author

Okay, I checked master and it looks like this may already be fixed there.

@TomWildenhain-Microsoft
Contributor

Thanks for the bug report. We just overhauled the transpose optimizer, so as you said, it might already be fixed. You can install the latest tf2onnx from master with:
pip uninstall tf2onnx
pip install git+https://github.com/onnx/tensorflow-onnx
Try it and let us know if you are still getting this bug. Thanks!

@TomWildenhain-Microsoft TomWildenhain-Microsoft added the pending on user response Waiting for more information or validation from user label Apr 21, 2021
@gerbenvv
Author

I tried, but I am getting the following exception now:

Traceback (most recent call last):
-- snip --
    onnx_model, _ = tf2onnx.convert.from_function(
  File "/xx/site-packages/tf2onnx/convert.py", line 400, in from_function
    tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names)
  File "/xxsite-packages/tf2onnx/convert.py", line 276, in tensor_names_from_structed
    structured_inputs = [t.name for t in tf.nest.flatten(concrete_func.structured_input_signature)]
  File "/xx/site-packages/tf2onnx/convert.py", line 276, in <listcomp>
    structured_inputs = [t.name for t in tf.nest.flatten(concrete_func.structured_input_signature)]
AttributeError: 'UnknownArgument' object has no attribute 'name'

@gerbenvv
Author

For me concrete_func.structured_input_signature is as follows:

(([(<tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc92e2ce0d0>, TensorSpec(shape=(None, None, None, None, 1), dtype=tf.float32, name='inputs/0/1'))],), {})

@TomWildenhain-Microsoft
Contributor

Interesting. It works on my machine. What TF version are you using? I didn't realize structured_input_signature could contain non-TensorSpec objects.

@gerbenvv
Author

It was my actual implementation that threw that error, not the bug reproduction.
But yes, structured_input_signature can pretty much contain anything (constants, captured variables outside the function, etc.).

Example:

import tensorflow as tf

class Foo:
    a = 42

@tf.function
def f(foo, a, b, x):
    if a:
        return x + foo.a
    else:
        return x + b

concrete_function = f.get_concrete_function(
    foo=Foo(),
    a=True,
    b=123,
    x=tf.TensorSpec(shape=(None, 1, 1, 1, 1), name="x"),
)

print(concrete_function.structured_input_signature)

yields:
((<tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f3c2c890460>, True, 123, TensorSpec(shape=(None, 1, 1, 1, 1), dtype=tf.float32, name='x')), {})

I think you want to check for isinstance(.., TensorSpec) before getting its name.
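
A minimal sketch of that guard, written so it runs without TensorFlow: `StandInTensorSpec` and `StandInUnknownArgument` are stand-ins invented for illustration, and `tensor_input_names` is a hypothetical helper, not the function in tf2onnx. With TensorFlow installed, one would presumably pass `tf.TensorSpec` as `spec_type` and the result of `tf.nest.flatten(...)` as the list.

```python
from dataclasses import dataclass


@dataclass
class StandInTensorSpec:
    """Stand-in for tf.TensorSpec (assumption: only .name matters here)."""
    name: str


class StandInUnknownArgument:
    """Stand-in for the opaque placeholder object; it has no .name attribute."""


def tensor_input_names(flat_signature, spec_type=StandInTensorSpec):
    """Collect names only from entries that are tensor specs, skipping everything else."""
    return [t.name for t in flat_signature if isinstance(t, spec_type)]


# Mirrors the flattened signature printed above: an opaque object,
# two Python constants, and one tensor spec.
flat_signature = [StandInUnknownArgument(), True, 123, StandInTensorSpec(name="x")]

print(tensor_input_names(flat_signature))
```

Filtering with `isinstance` rather than `hasattr(t, "name")` keeps unrelated objects that happen to carry a `name` attribute out of the result.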

@TomWildenhain-Microsoft
Contributor

Got it. #1490

@gerbenvv
Author

That fixed it, thanks!
