Describe the bug
The error in the title is produced when converting a TFLite file to ONNX via `tf2onnx.convert`. The TFLite file was produced by converting a JAX function to TFLite via `tf.lite.TFLiteConverter.experimental_from_jax`, as shown in the Colab notebook linked below.
Urgency
No hard deadline.
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04 (Google Colab)
Tensorflow Version: 2.8.0
Python version: 3.7.13
To Reproduce
I've uploaded the model here, but it's 900 MB, so it's probably much faster to give you the notebook that generates the model.
Click Runtime > Run all in this Colab notebook to reproduce: https://colab.research.google.com/drive/1DygMV-Nlae6BEJmZjN_laIdfYL33BbE0
Summary of what the Colab notebook does to generate the model (in case it's helpful at all):
1. Load Flax CLIP
2. Create a `score` function that uses the CLIP model inside it
3. Get the gradient of that score function with `jax.grad(score)`
4. Convert that `jax.grad(score)` function to TFLite using `tf.lite.TFLiteConverter.experimental_from_jax`
5. Check that the TFLite output matches the JAX output
6. Convert that TFLite file to ONNX with `python -m tf2onnx.convert` (which produces the error)
Note that another potential route from JAX to ONNX is via `jax2tf`, but I found that the TensorFlow function produced by `jax2tf` can't be converted to ONNX due to tf2onnx's lack of support for `PartitionedCall`. I thought that perhaps the TFLite-to-ONNX route via `experimental_from_jax` and tf2onnx might work, but ran into this error.
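The reproduction steps above can be sketched as follows. This is a minimal, hypothetical stand-in for the notebook: the toy `score` function below replaces the real Flax CLIP score function, and it assumes jax and tensorflow are installed (the error itself only appears at the final tf2onnx step):

```python
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

# Toy stand-in for the CLIP-based score function used in the notebook.
def score(x):
    return jnp.sum(jnp.tanh(x) ** 2)

# Step 3: gradient of the score function.
grad_fn = jax.grad(score)

# Step 4: convert the gradient function to TFLite via the experimental JAX path.
example_input = np.zeros((1, 4), dtype=np.float32)
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [grad_fn], [[("x", example_input)]]
)
tflite_model = converter.convert()
with open("grad_score.tflite", "wb") as f:
    f.write(tflite_model)

# Step 6 is then the CLI conversion that fails:
#   python -m tf2onnx.convert --tflite grad_score.tflite --output grad_score.onnx --opset 13
```

With the full CLIP model, the intermediate TFLite activations are rank-3 (e.g. [1, 50, 768]), which is what triggers the `keep_num_dims` failure below.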
The error log starts off like this:
```
/usr/lib/python3.7/runpy.py:125: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
  warn(RuntimeWarning(msg))
2022-05-16 18:18:30,858 - INFO - Using tensorflow=2.8.0, onnx=1.11.0, tf2onnx=1.10.1/a37f29
2022-05-16 18:18:30,858 - INFO - Using opset <onnx, 13>
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
2022-05-16 18:18:42,690 - ERROR - Failed to convert node 'xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147' (fct=<bound method TflFullyConnectedOp.to_tf of <class 'tf2onnx.tflite_handlers.tfl_math.TflFullyConnectedOp'>>)
'OP=TFL_FULLY_CONNECTED\nName=xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147\nInputs:\n\txla_computation(score)/jit(main)/add;199=Add, [1, 50, 768], 1\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];=Const, [768, 768], 1\nOutpus:\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147=[1, 50, 768], 1'
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
    func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
    "Only keep_num_dims=False supported for fully connected op")
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
    raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op
```
It then repeats that same ValueError many times and finishes with this:
```
2022-05-16 18:18:44,763 - ERROR - Failed to convert node 'xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139' (fct=<bound method TflFullyConnectedOp.to_tf of <class 'tf2onnx.tflite_handlers.tfl_math.TflFullyConnectedOp'>>)
'OP=TFL_FULLY_CONNECTED\nName=xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139\nInputs:\n\txla_computation(score)/jit(main)/add_any;114=Add, [1, 50, 3072], 1\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];67=Const, [768, 3072], 1\nOutpus:\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139=[1, 50, 768], 1'
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
    func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
    "Only keep_num_dims=False supported for fully connected op")
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
    raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op
Traceback (most recent call last):
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 640, in <module>
    main()
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 287, in main
    output_path=args.output)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 162, in _convert_common
    custom_op_handlers=custom_op_handlers, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 439, in process_tf_graph
    initialized_tables, tensors_to_rename, is_tflite, dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 492, in process_graphs
    dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 512, in process_parsed_graph
    raise exceptions[0]
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
    func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
    "Only keep_num_dims=False supported for fully connected op")
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
    raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op
```
To anyone who finds this later: based on the TFLite docs, `keep_num_dims` is probably enabled when the input to the fully connected op has more than 2 dimensions. Not sure if that will help, but that's that.
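The semantics can be illustrated with plain numpy, using the shapes from the second failing node in the log (input [1, 50, 3072], weights [768, 3072], output [1, 50, 768]; TFLite stores fully connected weights as [out, in]). This is only a sketch of what `keep_num_dims` controls as I understand it from the docs, not tf2onnx's actual lowering:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((1, 50, 3072)).astype(np.float32)  # rank-3 activation, as in the log
w = rng.random((768, 3072)).astype(np.float32)    # fully connected weights, [out, in]

# keep_num_dims=False: the input is flattened to 2-D before the matmul,
# so the output is rank-2 -- the only case tf2onnx's handler supports.
out_flat = x.reshape(-1, 3072) @ w.T              # shape (50, 768)

# keep_num_dims=True: the leading dimensions are preserved, matching the
# [1, 50, 768] output shape in the error log -- the case tf2onnx rejects.
out_kept = x @ w.T                                # shape (1, 50, 768)

assert out_flat.shape == (50, 768)
assert out_kept.shape == (1, 50, 768)
# Numerically the two differ only by the reshape:
assert np.allclose(out_flat, out_kept.reshape(50, 768))
```

Which is presumably why the op shows up here: the CLIP transformer keeps [batch, tokens, features] activations rank-3, so every `dot_general` lowered to TFL_FULLY_CONNECTED gets `keep_num_dims=True`.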