diff --git a/tf2onnx/tf_utils.py b/tf2onnx/tf_utils.py index 064506748..4c5ed9505 100644 --- a/tf2onnx/tf_utils.py +++ b/tf2onnx/tf_utils.py @@ -203,7 +203,7 @@ def tflist_to_onnx(g, shape_override, const_node_values=None): elif a == "output_shapes": # we should not need it since we pull the shapes above already pass - elif a in {"body", "cond", "then_branch", "else_branch"}: + elif a in {"body", "cond", "then_branch", "else_branch", "f"}: input_shapes = [inp.get_shape() for inp in node.inputs] nattr = get_tf_node_attr(node, a) attr[a] = nattr.name