
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op #1939

Open
josephrocca opened this issue May 16, 2022 · 3 comments
Labels
bug An unexpected problem or unintended behavior

Comments


josephrocca commented May 16, 2022

Describe the bug
The error in the title is produced when converting a tflite file to ONNX via tf2onnx.convert. The tflite file was produced by converting a JAX function to tflite via tf.lite.TFLiteConverter.experimental_from_jax, as shown in the Colab notebook linked below.

Urgency
No hard deadline.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04 (Google Colab)
  • Tensorflow Version: 2.8.0
  • Python version: 3.7.13

To Reproduce
I've uploaded the model here, but it's 900 MB, so it's probably much faster to use the notebook that generates it.

Click Runtime > Run all in this Colab notebook to reproduce:

https://colab.research.google.com/drive/1DygMV-Nlae6BEJmZjN_laIdfYL33BbE0

Summary of what the Colab notebook does to generate the model (in case it's helpful at all):

  1. Load Flax CLIP
  2. Create a score function that uses the CLIP model inside it
  3. Get the gradient of that score function with jax.grad(score)
  4. Convert that jax.grad(score) function to tflite using tf.lite.TFLiteConverter.experimental_from_jax
  5. Check that the tflite output matches the jax output
  6. Convert that tflite file to ONNX with python -m tf2onnx.convert (which produces the error)
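Step 6 can be scripted so the failing conversion is easy to rerun. This is a minimal sketch, assuming the standard tf2onnx command-line flags `--tflite`, `--output` and `--opset` (the report does not show the exact command; opset 13 matches the log) and placeholder file names `model.tflite` and `model.onnx`. The helpers `tf2onnx_tflite_cmd` and `convert` are hypothetical.

```python
import subprocess
import sys
from typing import List


def tf2onnx_tflite_cmd(tflite_path: str, onnx_path: str, opset: int = 13) -> List[str]:
    """Build the tf2onnx CLI call for a .tflite file (hypothetical helper)."""
    return [sys.executable, "-m", "tf2onnx.convert",
            "--tflite", tflite_path,
            "--output", onnx_path,
            "--opset", str(opset)]


def convert(tflite_path: str, onnx_path: str) -> int:
    # Runs the converter and returns its exit code. With tf2onnx 1.10.1, a model
    # containing FULLY_CONNECTED nodes with keep_num_dims=True fails with the
    # ValueError reported in this issue.
    return subprocess.run(tf2onnx_tflite_cmd(tflite_path, onnx_path)).returncode


if __name__ == "__main__":
    # Placeholder paths; only prints the command, does not run it.
    print(" ".join(tf2onnx_tflite_cmd("model.tflite", "model.onnx")))
```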

Note that another potential route from JAX to ONNX is via jax2tf, but I found that the TensorFlow function produced by jax2tf can't be converted to ONNX because tf2onnx doesn't support PartitionedCall. I thought the TFLite-to-ONNX route via experimental_from_jax and tf2onnx might work instead, but ran into this error.

The error log starts off like this:

/usr/lib/python3.7/runpy.py:125: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
  warn(RuntimeWarning(msg))
2022-05-16 18:18:30,858 - INFO - Using tensorflow=2.8.0, onnx=1.11.0, tf2onnx=1.10.1/a37f29
2022-05-16 18:18:30,858 - INFO - Using opset <onnx, 13>
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
2022-05-16 18:18:42,690 - ERROR - Failed to convert node 'xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147' (fct=<bound method TflFullyConnectedOp.to_tf of <class 'tf2onnx.tflite_handlers.tfl_math.TflFullyConnectedOp'>>)
'OP=TFL_FULLY_CONNECTED\nName=xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147\nInputs:\n\txla_computation(score)/jit(main)/add;199=Add, [1, 50, 768], 1\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];=Const, [768, 768], 1\nOutpus:\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147=[1, 50, 768], 1'
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
    func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
    "Only keep_num_dims=False supported for fully connected op")
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
    raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op

It then repeats that same ValueError many times and finishes with this:

2022-05-16 18:18:44,763 - ERROR - Failed to convert node 'xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139' (fct=<bound method TflFullyConnectedOp.to_tf of <class 'tf2onnx.tflite_handlers.tfl_math.TflFullyConnectedOp'>>)
'OP=TFL_FULLY_CONNECTED\nName=xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139\nInputs:\n\txla_computation(score)/jit(main)/add_any;114=Add, [1, 50, 3072], 1\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];67=Const, [768, 3072], 1\nOutpus:\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139=[1, 50, 768], 1'
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
    func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
    "Only keep_num_dims=False supported for fully connected op")
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
    raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op
Traceback (most recent call last):
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 640, in <module>
    main()
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 287, in main
    output_path=args.output)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 162, in _convert_common
    custom_op_handlers=custom_op_handlers, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 439, in process_tf_graph
    initialized_tables, tensors_to_rename, is_tflite, dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 492, in process_graphs
    dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 512, in process_parsed_graph
    raise exceptions[0]
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
    func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
    "Only keep_num_dims=False supported for fully connected op")
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
    raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op
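Every failing node in the log is a FULLY_CONNECTED op whose rank-3 input keeps its rank in the output (for example [1, 50, 3072] in, [1, 50, 768] out). A minimal shape-inference sketch makes that concrete. It assumes TFLite's weight layout of [units, input_dim] (consistent with the [768, 3072] constant above) and that keep_num_dims=False flattens all leading dimensions into a single batch dimension; `fc_output_shape` is a hypothetical helper, not part of TFLite or tf2onnx.

```python
from math import prod
from typing import List, Sequence


def fc_output_shape(input_shape: Sequence[int],
                    weight_shape: Sequence[int],
                    keep_num_dims: bool) -> List[int]:
    """Output shape of a TFLite-style FULLY_CONNECTED op (sketch).

    Assumes weights are laid out as [units, input_dim].
    """
    units, input_dim = weight_shape
    if keep_num_dims:
        if input_shape[-1] != input_dim:
            raise ValueError("last input dim must equal the weight input dim")
        # Keep every leading dimension; only the last one becomes `units`.
        return list(input_shape[:-1]) + [units]
    total = prod(input_shape)
    if total % input_dim:
        raise ValueError("input size is not divisible by the weight input dim")
    # Flatten all leading dimensions into one batch dimension.
    return [total // input_dim, units]


# First failing node: input [1, 50, 768], weights [768, 768].
print(fc_output_shape([1, 50, 768], [768, 768], keep_num_dims=True))    # [1, 50, 768]
print(fc_output_shape([1, 50, 768], [768, 768], keep_num_dims=False))   # [50, 768]
# Last failing node: input [1, 50, 3072], weights [768, 3072].
print(fc_output_shape([1, 50, 3072], [768, 3072], keep_num_dims=True))  # [1, 50, 768]
```

Only the keep_num_dims=True result matches the output shapes in the log, which is why these nodes trip the check in tfl_math.py.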
@pyl62112991

Have you solved the problem now?


buptlj commented Dec 8, 2023

Same problem.

fatcat-z self-assigned this Dec 13, 2023
fatcat-z added the bug (An unexpected problem or unintended behavior) label Dec 13, 2023
@Doomski99

To anyone who finds this later: based on the TFLite docs, keep_num_dims is probably turned on when the input to the Fully Connected layer has more than two dimensions. Not sure if that will help, but that's that.
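That reading matches the log, where each failing node takes a rank-3 input such as [1, 50, 768] and returns a rank-3 output. A minimal NumPy sketch, assuming TFLite's [units, input_dim] weight layout and ignoring fused activations, shows that keep_num_dims=True differs from keep_num_dims=False only in a final reshape. So one way a converter could support it is flatten, 2-D product, reshape back, or a single ONNX MatMul, which broadcasts like numpy.matmul. The helper `fully_connected` is hypothetical and is not tf2onnx's implementation.

```python
import numpy as np


def fully_connected(x, weights, bias=None, keep_num_dims=False):
    """NumPy reference for a TFLite-style FULLY_CONNECTED op (sketch).

    Assumes weights of shape [units, input_dim]; fused activations are ignored.
    """
    units, input_dim = weights.shape
    # Flatten all leading dimensions into a single batch dimension.
    x2d = x.reshape(-1, input_dim)
    y = x2d @ weights.T
    if bias is not None:
        y = y + bias
    if keep_num_dims:
        # Restore the original leading dimensions, e.g. [1, 50, units].
        y = y.reshape(*x.shape[:-1], units)
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 50, 768))
w = rng.standard_normal((768, 768))
b = rng.standard_normal(768)

y_keep = fully_connected(x, w, b, keep_num_dims=True)
y_flat = fully_connected(x, w, b, keep_num_dims=False)
print(y_keep.shape, y_flat.shape)  # (1, 50, 768) (50, 768)

# keep_num_dims=True is just the flattened result reshaped back, and it equals
# a batched matmul that broadcasts over the leading dimensions.
assert np.allclose(y_keep, y_flat.reshape(1, 50, 768))
assert np.allclose(y_keep, np.matmul(x, w.T) + b)
```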
